# Page header residue (not code): zachzzc's picture
# Upload tts playground and serving engine
# 07f1f64
import contextlib
from contextlib import contextmanager
from functools import wraps
import torch
from transformers.integrations import is_deepspeed_available
if is_deepspeed_available():
from deepspeed.utils import groups as deepspeed_groups
from deepspeed.sequence.layer import _SeqAllToAll
else:
deepspeed_groups = None
_SeqAllToAll = None
def _ceil_to_nearest(n, round_to):
return (n + round_to - 1) // round_to * round_to
def count_parameters(model, trainable_only=True):
    """Return the total number of elements over ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()``, yielding items with
            ``numel()`` and ``requires_grad``.
        trainable_only: When truthy, only parameters whose ``requires_grad``
            is set are counted; otherwise every parameter is counted.

    Returns:
        The summed ``numel()`` of the selected parameters.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
# TODO(sxjscience): Consider moving this function to audio_processing/utils.py
def build_delay_pattern_mask(
    input_ids: torch.LongTensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Apply the delay pattern from "Simple and Controllable Music Generation" (https://arxiv.org/pdf/2306.05284).

    Codebook ``i`` is shifted right by ``i`` steps. The gap opened at the start of a
    row is filled with the delay token and the gap left at the end of a row is
    filled with the pad token, so the output length is ``seq_len + num_codebooks - 1``.

    Example with 4 codebooks and an audio sequence of length 5:

        - [ *, *, *, *, *, P, P, P]
        - [ B, *, *, *, *, *, P, P]
        - [ B, B, *, *, *, *, *, P]
        - [ B, B, B, *, *, *, *, *]

    where ``B`` is the delay token id, ``P`` is the padding token id and ``*`` is an
    original audio token.

    When the input is an audio prompt to condition on, e.g. the non-delayed codes

        - [a, b]
        - [c, d]
        - [e, f]
        - [g, h]

    the second return value has the form

        - [a, b, -1, -1, -1]
        - [B, c, d, -1, -1]
        - [B, B, e, f, -1]
        - [B, B, B, g, h]

    where ``-1`` marks a slot that must be overridden by a newly generated token
    during auto-regressive generation.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The input ids of the prompt, with shape (bsz, num_codebooks, seq_len).
        bos_token_id (:obj:`int`):
            The id of the special delay token.
        pad_token_id (:obj:`int`):
            The id of the padding token. Should be the same as eos_token_id.

    Returns:
        input_ids (:obj:`torch.LongTensor`):
            The delayed ids with trailing slots set to ``pad_token_id``.
            Shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
        input_ids_with_gen_mask (:obj:`torch.LongTensor`):
            Same as above, except that the trailing slots hold ``-1`` to indicate
            tokens that still have to be generated.
    """
    bsz, num_codebooks, seq_len = input_ids.shape
    delayed_len = seq_len + num_codebooks - 1
    device = input_ids.device
    full_shape = (bsz, num_codebooks, delayed_len)

    # shift[c, t] = t - c: negative inside the leading delay region of codebook c,
    # and >= seq_len inside its trailing region.
    codebook_idx = torch.arange(num_codebooks, device=device).unsqueeze(1)
    step_idx = torch.arange(delayed_len, device=device).unsqueeze(0)
    shift = step_idx - codebook_idx
    bos_mask = (shift < 0).unsqueeze(0).expand(full_shape)
    eos_mask = (shift >= seq_len).unsqueeze(0).expand(full_shape)

    input_ids_with_gen_mask = torch.ones(full_shape, dtype=torch.long, device=device)
    input_ids_with_gen_mask[bos_mask] = bos_token_id
    # Every slot that is neither delay nor trailing receives an original token (row-major order).
    input_ids_with_gen_mask[~(bos_mask | eos_mask)] = input_ids.reshape(-1)

    delayed_ids = input_ids_with_gen_mask.clone()
    delayed_ids[eos_mask] = pad_token_id
    input_ids_with_gen_mask[eos_mask] = -1
    return delayed_ids, input_ids_with_gen_mask
def revert_delay_pattern(data):
    """Undo the delay pattern on a single (un-batched) sample.

    Args:
        data (:obj:`torch.Tensor`):
            The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).

    Returns:
        ret (:obj:`torch.Tensor`):
            Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
    """
    assert data.dim() == 2
    num_codebooks, delayed_len = data.shape
    seq_len = delayed_len - num_codebooks + 1
    # Row i was shifted right by i steps, so its payload lives in columns [i, i + seq_len).
    rows = [data[i : i + 1, i : i + seq_len] for i in range(num_codebooks)]
    return torch.cat(rows, dim=0)
def merge_input_ids_with_audio_features(
    audio_features_embed,
    audio_features_length,
    audio_in_embed,
    audio_in_ids_start,
    audio_out_embed,
    audio_out_ids_start,
    audio_in_token_idx,
    audio_out_token_idx,
    inputs_embeds,
    input_ids,
    attention_mask,
    label_ids,
    pad_token_id,
    ignore_index=-100,
    round_to=8,
    left_padding=True,
):
    """
    Merge input_ids with audio features into final embeddings.

    Every audio placeholder token in `input_ids` is expanded into as many positions as the
    corresponding audio has embeddings; the text tokens are moved to their new positions and
    all outputs are padded to a common length that is a multiple of `round_to`.

    Any of `audio_features_embed`, `audio_in_embed` and `audio_out_embed` may be `None` or have
    a first dimension of size 0, in which case that source is skipped.

    Args:
        audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
            Encoded vectors of all audios in the batch (obtained from the semantic encoder)
        audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
            The length of audio embeddings of each audio as stacked in `audio_features_embed`
        audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
            The embeddings of audio-in tokens
        audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-in tokens for each audio
        audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
            The embeddings of audio-out tokens
        audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-out tokens for each audio
        audio_in_token_idx
            The index of the audio-in token in the vocabulary
        audio_out_token_idx
            The index of the audio-out token in the vocabulary
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with audio embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input_ids of tokens, possibly filled with audio token
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
            labels need to be recalculated to support training (if provided).
            When `None`, `final_labels` is returned as `None`.
        pad_token_id (`int`):
            The index of the pad token in the vocabulary
        ignore_index
            The index to ignore in the loss calculation
        round_to
            The number to round to for padding
        left_padding
            Whether to apply left padding. When `None`, it is inferred from the input:
            left padding is assumed if any sequence starts with an attention-mask value of 0.

    Returns:
        final_embedding
            The final embeddings after merging audio embeddings with text embeddings.
        final_attention_mask
            The final attention mask after merging audio embeddings with text embeddings.
        final_labels
            The labels for the text stream
        position_ids
            Positional ids for the merged data
        final_input_ids
            The final input_ids after merging audio embeddings with text embeddings.
        final_audio_in_mask
            Mask for audio-in embeddings
        final_audio_in_discrete_codes_mask
            Mask for audio-in discrete tokens
        final_audio_out_mask
            Mask for audio-out embeddings

    Explanation:
        each audio has variable length embeddings, with length specified by
        - audio_features_length
        - audio_in_ids_start
        - audio_out_ids_start

        Task:
        - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
        - fill each <|AUDIO_OUT|> with the audio-out embeddings

        Example:
            <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
            <|AUDIO|>: Z (8 tokens)

            X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]

    """
    # Labels are optional: without them no label tensor is built and None is returned in its place.
    if label_ids is None:
        skip_labels = True
    else:
        skip_labels = False
    # Treat an empty audio tensor (first dimension of size 0) exactly like a missing one.
    if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
        audio_features_embed = None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None

    batch_size, sequence_length, embed_dim = inputs_embeds.shape

    target_device = inputs_embeds.device
    if left_padding is None:
        # Infer the padding side: a masked first position in any row implies left padding.
        left_padding = torch.any(attention_mask[:, 0] == 0)

    audio_in_token_mask = input_ids == audio_in_token_idx
    audio_out_token_mask = input_ids == audio_out_token_idx
    # Everything that is not an audio placeholder, which includes any padding tokens in input_ids.
    text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)

    # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
    # Every token occupies one slot by default; placeholders are overwritten with their expanded length.
    token_placeholder_num = torch.ones_like(input_ids)

    if audio_features_embed is not None:
        num_audios, max_audio_tokens, _ = audio_features_embed.shape
        # True for the first audio_features_length[i] positions of audio i.
        audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
            audio_features_length.device
        ) < audio_features_length.unsqueeze(1)
        # Keep only the valid feature vectors, flattened to (total_valid_tokens, embed_dim).
        masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
        # Boolean-mask assignment fills in row-major order.
        # NOTE(review): assumes the k-th <|AUDIO|> placeholder (row-major) belongs to the k-th audio - confirm.
        token_placeholder_num[audio_in_token_mask] = audio_features_length.long()

    if audio_in_embed is not None:
        # Per-audio code length = difference between consecutive start offsets; the last audio
        # runs up to the end of audio_in_embed.
        audio_in_codes_length = torch.concat(
            [
                audio_in_ids_start[1:] - audio_in_ids_start[:-1],
                torch.tensor(
                    [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
                    device=audio_in_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        if audio_features_embed is not None:
            # Both sources present: the placeholder expands to features followed by discrete codes.
            token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
        else:
            token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()

    if audio_out_embed is not None:
        # Same offset-difference computation as for the audio-in codes.
        audio_out_codes_length = torch.concat(
            [
                audio_out_ids_start[1:] - audio_out_ids_start[:-1],
                torch.tensor(
                    [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
                    device=audio_out_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()

    # Index of the LAST expanded slot occupied by each original token.
    new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
    # Longest expanded row, rounded up to a multiple of round_to (this is a 0-dim tensor).
    max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
    # Number of unused slots after the last token of each row.
    nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]

    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding

    # 2. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        (batch_size, max_token_num, embed_dim),
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
    final_attention_mask = torch.zeros(
        (batch_size, max_token_num),
        dtype=attention_mask.dtype,
        device=inputs_embeds.device,
    )
    final_input_ids = torch.full(
        (batch_size, max_token_num),
        pad_token_id,
        dtype=input_ids.dtype,
        device=inputs_embeds.device,
    )
    if skip_labels:
        final_labels = None
    else:
        final_labels = torch.full(
            (batch_size, max_token_num),
            ignore_index,
            dtype=label_ids.dtype,
            device=inputs_embeds.device,
        )

    final_audio_in_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_in_discrete_codes_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_out_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    # 3. Get the audio-in token positions and audio-out token positions
    batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
    audio_in_batch_id = batch_id[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_batch_id = batch_id[audio_out_token_mask]  # Shape (num_audio_out,)
    audio_features_token_ends = new_token_positions[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_embed_ends = new_token_positions[audio_out_token_mask]  # Shape (num_audio_out,)

    if audio_in_embed is not None:
        # Fill in the audio-in embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_in_ids_start.shape[0], max_token_num)
        )
        # The discrete codes take the LAST audio_in_codes_length slots of each placeholder span.
        audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
        # Row k of the grid selects every column in [start_k, end_k]; torch.where returns them
        # audio by audio, in increasing column order.
        # NOTE(review): relies on audio_in_embed being laid out audio-by-audio in the same order - confirm.
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        # Translate "k-th audio" into the batch row that contains its placeholder.
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_in_embed
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
        final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
        # Move the end pointer in front of the codes so that the continuous features (if any)
        # are written immediately before them.
        audio_features_token_ends = audio_features_token_ends - audio_in_codes_length

    if audio_features_embed is not None:
        # Fill in the audio features
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_features_embed.shape[0], max_token_num)
        )
        audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_features_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = masked_audio_in_features
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        # Only the combined audio-in mask is set here; the discrete-codes mask stays False.
        final_audio_in_mask[batch_indices, col_indices] = True

    if audio_out_embed is not None:
        # Fill in the audio-out embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_out_ids_start.shape[0], max_token_num)
        )
        audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
        )
        batch_indices = audio_out_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_out_embed
        final_input_ids[batch_indices, col_indices] = audio_out_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_out_mask[batch_indices, col_indices] = True

    # Fill in the original text embeddings and labels
    batch_indices, non_audio_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    # Every expanded audio slot is attended to.
    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask

    # Trim the tensor if there are redundant padding tokens
    if left_padding:
        # First column attended to by at least one row, rounded DOWN to a multiple of round_to.
        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        # We have done right padding, so we need to trim the mask
        # One past the last attended column, rounded UP to a multiple of round_to.
        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]

    # Position of each attended token among the attended tokens of its row; masked slots get 1.
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )
def is_deepspeed_ulysses_enabled():
    """Check if sequence parallelism is enabled.

    Returns:
        ``False`` when the module-level ``deepspeed_groups`` is ``None`` (set that way
        when DeepSpeed is not available); otherwise whether the sequence parallel
        world size reported by DeepSpeed is greater than 1.
    """
    # The docstring must be the first statement: placed after the early return it was
    # an unreachable string expression and ``__doc__`` stayed ``None``.
    if deepspeed_groups is None:
        return False
    return deepspeed_groups._get_sequence_parallel_world_size() > 1
def support_deepspeed_ulysses(module):
    """Class decorator that exposes sequence-parallel info on a PyTorch module.

    It attaches three lazily evaluated, cached properties to ``module``:

    - ``sp_size``: sequence parallel size (1 when sequence parallelism is not enabled)
    - ``sp_rank``: sequence parallel rank (0 when sequence parallelism is not enabled)
    - ``sp_group``: sequence parallel process group (``None`` when not enabled)

    Returns:
        The same ``module`` object, modified in place.
    """
    # Class-level "not computed yet" markers, overridden per instance on first access.
    for cache_attr in ("_sp_size", "_sp_rank", "_sp_group"):
        setattr(module, cache_attr, None)

    def sp_size(self):
        if self._sp_size is None:
            self._sp_size = 1
            if is_deepspeed_ulysses_enabled():
                self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
        return self._sp_size

    def sp_rank(self):
        if self._sp_rank is None:
            self._sp_rank = 0
            if is_deepspeed_ulysses_enabled():
                self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
        return self._sp_rank

    def sp_group(self):
        # Unlike size/rank there is no default value: the group stays None until
        # sequence parallelism is enabled.
        if self._sp_group is None and is_deepspeed_ulysses_enabled():
            self._sp_group = deepspeed_groups._get_sequence_parallel_group()
        return self._sp_group

    module.sp_size = property(sp_size)
    module.sp_rank = property(sp_rank)
    module.sp_group = property(sp_group)

    return module
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
    """Perform all-to-all before and after the attention function.

    Args:
        seq_dim: Index of the sequence-length dimension of the first three positional
            arguments and of the attention output.
        head_dim: Index of the num-heads dimension of those tensors.

    Returns:
        A decorator. The decorated function behaves like the original when sequence
        parallelism is disabled. When it is enabled, its first three positional
        arguments (presumably query, key and value - they must be passed positionally)
        are scattered over heads and gathered over the sequence before the call, and
        the output is scattered over the sequence and gathered over heads afterwards.
    """

    def attention_decorator(attn_func=None):
        @wraps(attn_func)
        def wrapped(*args, **kwargs):
            # Decide once per call so the all-to-all after attention always mirrors the
            # one before it, and ``sp_group`` is guaranteed to be bound when it is used.
            ulysses_enabled = is_deepspeed_ulysses_enabled()
            batch_dim_idx = 0
            if ulysses_enabled:
                sp_group = deepspeed_groups._get_sequence_parallel_group()
                scatter_idx = head_dim  # Scatter on num_heads dimension
                gather_idx = seq_dim  # Gather on seq_len dimension
                args = list(args)
                args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
                args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
                args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
                args = tuple(args)

            attn_output = attn_func(*args, **kwargs)

            if ulysses_enabled:
                scatter_idx = seq_dim  # Scatter back on seq_len dimension
                gather_idx = head_dim  # Gather on num_heads dimension
                attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)

            return attn_output

        return wrapped

    return attention_decorator
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
    """Slice the corresponding cos and sin chunks for rope.

    Args:
        state_seq_dim: Sequence dimension of the first positional argument; its size is
            used as the per-rank chunk length.
        trig_seq_dim: Sequence dimension of the third and fourth positional arguments
            (presumably cos and sin), which are narrowed to this rank's chunk.

    Returns:
        A decorator. The decorated function behaves like the original when sequence
        parallelism is disabled. When it is enabled, positional arguments 2 and 3 are
        replaced by the slice ``[sp_rank * chunk, (sp_rank + 1) * chunk)`` along
        ``trig_seq_dim`` before the call.
    """

    def rope_decorator(rope_func=None):
        @wraps(rope_func)
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_rank = deepspeed_groups._get_sequence_parallel_rank()
                args = list(args)
                seq_chunk_size = args[0].size(state_seq_dim)
                chunk_start = sp_rank * seq_chunk_size
                args[2] = torch.narrow(args[2], trig_seq_dim, chunk_start, seq_chunk_size)
                args[3] = torch.narrow(args[3], trig_seq_dim, chunk_start, seq_chunk_size)
                args = tuple(args)

            return rope_func(*args, **kwargs)

        return wrapped

    return rope_decorator
def _gather_tensors(input_, group=None):
    """All-gather a tensor from every rank in ``group`` without concatenating.

    The shapes are exchanged first so that each receive buffer is allocated with
    the shape reported by the corresponding rank.

    Returns:
        The contiguous input itself when the world size is 1; otherwise a list with
        one tensor per rank.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    device = input_.device

    # Step 1: exchange shapes.
    all_shapes = [torch.empty(input_.dim(), dtype=torch.int64, device=device) for _ in range(world_size)]
    local_shape = torch.tensor(input_.size(), dtype=torch.int64, device=device)
    torch.distributed.all_gather(all_shapes, local_shape, group=group)

    # Step 2: exchange the data into buffers sized per rank.
    per_rank_tensors = [torch.empty(shape.tolist(), dtype=input_.dtype, device=device) for shape in all_shapes]
    torch.distributed.all_gather(per_rank_tensors, input_, group=group)
    return per_rank_tensors
def _scatter_tensors(input_, group=None):
    """Pick this rank's entry out of ``input_``.

    Returns:
        ``input_`` unchanged when the world size of ``group`` is 1; otherwise
        ``input_[rank]`` where ``rank`` is this process's rank in ``group``.
    """
    if torch.distributed.get_world_size(group) == 1:
        return input_
    return input_[torch.distributed.get_rank(group)]
class _GatherTensors(torch.autograd.Function):
    """All gather tensors among the ranks.

    Forward packs the per-rank tensors into a jagged nested tensor; backward hands
    each rank the gradient entry selected by ``_scatter_tensors``.
    """

    @staticmethod
    def symbolic(graph, input_, group):
        return _gather_tensors(input_, group)

    @staticmethod
    def forward(ctx, input_, group):
        # Remember the group so backward selects the gradient with the same group.
        ctx.group = group
        per_rank = _gather_tensors(input_, group)
        return torch.nested.as_nested_tensor(per_rank, layout=torch.jagged)

    @staticmethod
    def backward(ctx, grad_output):
        local_grad = _scatter_tensors(grad_output, ctx.group)
        # ``group`` is not differentiable.
        return local_grad, None
def all_gather_tensors(input_, size=None, dim=0, group=None):
    """Gather ``input_`` from all ranks in ``group`` and concatenate along ``dim``.

    Args:
        input_: The local tensor.
        size: Optional. When truthy, it is iterated in lockstep with the gathered
            per-rank tensors; each item's ``tolist()`` gives the split sizes for that
            rank's tensor. The pieces are then interleaved: piece 0 of every rank,
            then piece 1 of every rank, and so on.
        dim: Dimension along which the final pieces are concatenated.
        group: Process group to gather over.

    Returns:
        ``input_`` unchanged when the world size is 1; otherwise the contiguous
        concatenation of the gathered pieces.
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_

    gathered = _GatherTensors.apply(input_, group)
    if size:
        pieces_per_rank = [
            torch.split(rank_tensor, split_sizes.tolist()) for split_sizes, rank_tensor in zip(size, gathered)
        ]
        # Transpose rank-major into piece-major order, then flatten.
        gathered = [piece for same_index_pieces in zip(*pieces_per_rank) for piece in same_index_pieces]

    return torch.cat(gathered, dim).contiguous()
def get_sequence_data_parallel_world_size():
    """Return the world size of the default process group."""
    return torch.distributed.get_world_size()
def get_sequence_data_parallel_rank():
    """Return this process's rank in the default process group."""
    return torch.distributed.get_rank()
def get_sequence_data_parallel_group():
    """Return the world process group (``torch.distributed.group.WORLD``)."""
    return torch.distributed.group.WORLD
# Replace DeepSpeed's sequence-data-parallel accessors with the helpers of this module,
# which all resolve to the default (world) process group. This runs at import time.
# NOTE(review): presumably the whole world acts as one sequence-data-parallel group here - confirm.
if is_deepspeed_available():
    deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
    deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
    deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
def _gather_tokens(input_, dim=0, group=None):
    """Gather equally-shaped tensors from all ranks and concatenate them along ``dim``.

    Returns:
        The contiguous input itself when the world size of ``group`` is 1; otherwise
        the concatenation, in rank order, of every rank's tensor along ``dim``.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    numel = input_.numel()
    flat_buffer = torch.empty(world_size * numel, dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(flat_buffer, input_, group=group)

    if dim == 0:
        # Rank-major flat layout already equals concatenation along dim 0: just reshape.
        out_shape = list(input_.size())
        out_shape[0] = out_shape[0] * world_size
        return flat_buffer.view(out_shape)

    # For any other dim, cut the buffer back into per-rank tensors and concatenate them.
    per_rank = [flat_buffer.narrow(0, numel * rank, numel).view_as(input_) for rank in range(world_size)]
    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(per_rank, dim=dim).contiguous()
def _drop_tokens(input_, dim=0, group=None):
    """Keep only this rank's equal share of ``input_`` along ``dim``.

    Returns:
        ``input_`` unchanged when the world size of ``group`` is 1; otherwise the
        ``rank``-th of ``world_size`` equal chunks along ``dim`` (as returned by
        ``torch.narrow``).

    Raises:
        AssertionError: If the size of ``dim`` is not divisible by the world size.
    """
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    this_rank = torch.distributed.get_rank(group)
    dim_size = input_.shape[dim]
    assert dim_size % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
    )
    chunk_size = dim_size // world_size
    chunk_start = this_rank * chunk_size
    return torch.narrow(input_, dim, chunk_start, chunk_size)
class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the sequence parallel ranks.

    Forward keeps this rank's chunk; backward gathers the gradient chunks from all
    ranks and divides the result by ``grad_scale`` when it is not 1.
    """

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        # Stash the non-tensor arguments needed to mirror the operation in backward.
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.grad_scale
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if scale != 1:
            grad_input /= scale
        # One gradient slot per forward argument; only ``input_`` is differentiable.
        return grad_input, None, None, None
class _GatherTokens(torch.autograd.Function):
    """Gather tokens among the sequence parallel ranks.

    Forward concatenates every rank's tensor along ``dim``; backward keeps this
    rank's chunk of the gradient and multiplies it by ``grad_scale`` when it is not 1.
    """

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # ``_drop_tokens`` returns a ``torch.narrow`` view of ``grad_output`` (or
            # ``grad_output`` itself), so scaling in place would silently modify the
            # incoming gradient. Scale out of place instead.
            grad_input = grad_input * ctx.grad_scale
        # One gradient slot per forward argument; only ``input_`` is differentiable.
        return grad_input, None, None, None
def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    """Differentiable split of ``input_`` along ``dim`` across the ranks of ``group``.

    Returns ``input_`` unchanged when the world size is 1 (no sequence parallelism);
    otherwise the result of ``_DropTokens.apply``.
    """
    sequence_parallel = torch.distributed.get_world_size(group) != 1
    if sequence_parallel:
        return _DropTokens.apply(input_, dim, group, grad_scale)
    return input_
def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    """Differentiable gather of ``input_`` along ``dim`` across the ranks of ``group``.

    Returns ``input_`` unchanged when the world size is 1 (no sequence parallelism);
    otherwise the result of ``_GatherTokens.apply``.
    """
    sequence_parallel = torch.distributed.get_world_size(group) != 1
    if sequence_parallel:
        return _GatherTokens.apply(input_, dim, group, grad_scale)
    return input_
def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs to create chunks per the sequence parallel rank. This is used for the context parallel training.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        *args:
            Tensors to slice. All of them must have the same size along `dim`.
        dim (`int`):
            The dimension to slice

    Returns:
        With `sp_size == 1` the inputs are returned untouched (the single input itself, or
        the tuple of inputs). Otherwise the `sp_rank`-th of `sp_size` equal chunks along
        `dim` of every input: a single tensor for one input, a tuple for several.
    """
    if sp_size == 1:
        if len(args) == 1:
            return args[0]
        return args

    seq_length = args[0].size(dim)
    for arg in args[1:]:
        assert arg.size(dim) == seq_length, (
            f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    chunk_len = seq_length // sp_size
    chunk_start = sp_rank * chunk_len
    chunks = tuple(torch.narrow(tensor, dim, chunk_start, chunk_len) for tensor in args)
    return chunks[0] if len(chunks) == 1 else chunks
@contextmanager
def disable_deepspeed_ulysses():
    """Disable deepspeed ulysses (sequence parallelism) if it is enabled.

    While the context is active, DeepSpeed's ``_get_sequence_parallel_world_size`` is
    replaced by a function returning 1; the original function is restored on exit,
    including when the body raises. If sequence parallelism is not enabled, the
    context does nothing.
    """
    if not is_deepspeed_ulysses_enabled():
        # Nothing to disable.
        yield
        return

    _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size

    def _get_sequence_parallel_world_size():
        return 1

    deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
    try:
        yield
    finally:
        deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size