Spaces:
Running
on
Zero
Running
on
Zero
File size: 30,819 Bytes
07f1f64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 |
import contextlib
from contextlib import contextmanager
from functools import wraps
import torch
from transformers.integrations import is_deepspeed_available
if is_deepspeed_available():
from deepspeed.utils import groups as deepspeed_groups
from deepspeed.sequence.layer import _SeqAllToAll
else:
deepspeed_groups = None
_SeqAllToAll = None
def _ceil_to_nearest(n, round_to):
return (n + round_to - 1) // round_to * round_to
def count_parameters(model, trainable_only=True):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors.
        trainable_only: When true, only parameters with ``requires_grad`` set
            are counted; otherwise every parameter is counted.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
# TODO(sxjscience) Consider to move the function to audio_processing/utils.py
def build_delay_pattern_mask(
    input_ids: torch.LongTensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Apply the delay pattern from "Simple and Controllable Music Generation" (https://arxiv.org/pdf/2306.05284).

    Each codebook is shifted one step to the right relative to the previous one.
    Positions that precede a delayed codebook are filled with the delay (BOS)
    token, and positions after a codebook has finished are filled with the pad
    token. With 4 codebooks and an audio length of 5, the result has length
    ``seq_len + num_codebooks - 1``:

        - [ *, *, *, *, *, P, P, P]
        - [ B, *, *, *, *, *, P, P]
        - [ B, B, *, *, *, *, *, P]
        - [ B, B, B, *, *, *, *, *]

    ``B`` is the delay token id, ``P`` is the padding token id and ``*`` is an
    original audio token.

    When the input is a prompt to condition on, e.g. the non-delayed codes

        - [a, b]
        - [c, d]
        - [e, f]
        - [g, h]

    the second return value marks the trailing positions with ``-1``:

        - [a, b, -1, -1, -1]
        - [B, c, d, -1, -1]
        - [B, B, e, f, -1]
        - [B, B, B, g, h]

    The ``-1`` entries are the slots that should be overridden by newly
    generated tokens during auto-regressive generation.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The input ids of the prompt, of shape (bsz, num_codebooks, seq_len).
        bos_token_id (:obj:`int`):
            The id of the special delay token.
        pad_token_id (:obj:`int`):
            The id of the padding token. Should be the same as eos_token_id.

    Returns:
        input_ids (:obj:`torch.LongTensor`):
            The delayed ids with trailing slots set to ``pad_token_id``.
            Shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
        input_ids_with_gen_mask (:obj:`torch.LongTensor`):
            The delayed ids with trailing slots set to ``-1`` (to be generated).
    """
    bsz, num_codebooks, seq_len = input_ids.shape
    delayed_len = seq_len + num_codebooks - 1

    # An all-ones canvas lets tril/triu carve out the leading and trailing triangles.
    canvas = torch.ones((bsz, num_codebooks, delayed_len), dtype=torch.long, device=input_ids.device)
    leading = torch.tril(canvas, -1) > 0  # slots before codebook k starts
    trailing = torch.triu(canvas, seq_len) > 0  # slots after codebook k has ended
    content = ~(leading | trailing)  # exactly seq_len slots per codebook row

    canvas[leading] = bos_token_id
    canvas[content] = input_ids.reshape(-1)

    delayed_ids = canvas.clone()
    delayed_ids[trailing] = pad_token_id
    canvas[trailing] = -1
    return delayed_ids, canvas
def revert_delay_pattern(data):
    """Undo the delay pattern on a single (un-batched) sample.

    Args:
        data (:obj:`torch.Tensor`):
            Delayed data of shape (num_codebooks, seq_len + num_codebooks - 1).

    Returns:
        ret (:obj:`torch.Tensor`):
            The recovered data of shape (num_codebooks, seq_len), where row ``k``
            is the window of ``seq_len`` entries starting at column ``k``.
    """
    assert len(data.shape) == 2
    num_codebooks, delayed_len = data.shape
    seq_len = delayed_len - num_codebooks + 1
    # Codebook k was shifted right by k steps, so its payload starts at column k.
    rows = [data[k : k + 1, k : k + seq_len] for k in range(num_codebooks)]
    return torch.cat(rows, dim=0)
def merge_input_ids_with_audio_features(
    audio_features_embed,
    audio_features_length,
    audio_in_embed,
    audio_in_ids_start,
    audio_out_embed,
    audio_out_ids_start,
    audio_in_token_idx,
    audio_out_token_idx,
    inputs_embeds,
    input_ids,
    attention_mask,
    label_ids,
    pad_token_id,
    ignore_index=-100,
    round_to=8,
    left_padding=True,
):
    """
    Merge input_ids with audio features into final embeddings.

    Every audio placeholder token in `input_ids` is expanded into as many positions as the
    corresponding audio needs, the audio embeddings are written into those positions, and the
    remaining text embeddings / labels / masks are moved to their new positions. The merged
    sequence length is rounded up to a multiple of `round_to`, and columns that are padding for
    every sample are trimmed in whole multiples of `round_to`.

    Args:
        audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
            Encoded vectors of all audios in the batch (obtained from the semantic encoder)
        audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
            The length of audio embeddings of each audio as stacked in `audio_features_embed`
        audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
            The embeddings of audio-in tokens
        audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-in tokens for each audio
        audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
            The embeddings of audio-out tokens
        audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-out tokens for each audio
        audio_in_token_idx
            The index of the audio-in token in the vocabulary
        audio_out_token_idx
            The index of the audio-out token in the vocabulary
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with audio embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input_ids of tokens, possibly filled with audio token
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
            labels need to be recalculated to support training (if provided)
        pad_token_id (`int`):
            The index of the pad token in the vocabulary
        ignore_index
            The index to ignore in the loss calculation
        round_to
            The number to round to for padding
        left_padding
            Whether to apply left padding. If `None`, left padding is assumed when the first
            position of any row in `attention_mask` is 0.

    Any of `audio_features_embed`, `audio_in_embed` and `audio_out_embed` may be `None`; a tensor
    whose first dimension is 0 is treated the same as `None`.

    Returns:
        final_embedding
            The final embeddings after merging audio embeddings with text embeddings.
        final_attention_mask
            The final attention mask after merging audio embeddings with text embeddings.
        final_labels
            The labels for the text stream (`None` when `label_ids` is `None`)
        position_ids
            Positional ids for the merged data
        final_input_ids
            The final input_ids after merging audio embeddings with text embeddings.
        final_audio_in_mask
            Mask for audio-in embeddings
        final_audio_in_discrete_codes_mask
            Mask for audio-in discrete tokens
        final_audio_out_mask
            Mask for audio-out embeddings

    Explanation:
        each audio has variable length embeddings, with length specified by
        - audio_features_length
        - audio_in_ids_start
        - audio_out_ids_start

        Task:
        - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
        - fill each <|AUDIO_OUT|> with the audio-out embeddings

        Example:
            <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
            <|AUDIO|>: Z (8 tokens)

            X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]

    """
    # Labels are optional; when absent every label-related write below is skipped.
    if label_ids is None:
        skip_labels = True
    else:
        skip_labels = False
    # Normalize "empty" audio inputs to None so the rest of the function has a single check.
    if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
        audio_features_embed = None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None

    batch_size, sequence_length, embed_dim = inputs_embeds.shape

    target_device = inputs_embeds.device
    if left_padding is None:
        # Infer the padding side: a masked first column means at least one row is left-padded.
        left_padding = torch.any(attention_mask[:, 0] == 0)

    audio_in_token_mask = input_ids == audio_in_token_idx
    audio_out_token_mask = input_ids == audio_out_token_idx
    text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)

    # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
    # Every token starts out occupying one slot; placeholders are overwritten with their audio length.
    token_placeholder_num = torch.ones_like(input_ids)

    if audio_features_embed is not None:
        num_audios, max_audio_tokens, _ = audio_features_embed.shape
        # True for the first `audio_features_length[i]` frames of audio i, False for its padding.
        audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
            audio_features_length.device
        ) < audio_features_length.unsqueeze(1)
        # Flatten the valid frames of all audios into (total_valid_frames, embed_dim).
        masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
        token_placeholder_num[audio_in_token_mask] = audio_features_length.long()

    if audio_in_embed is not None:
        # Per-audio code length = distance between consecutive start offsets; the last audio
        # runs to the end of `audio_in_embed`.
        audio_in_codes_length = torch.concat(
            [
                audio_in_ids_start[1:] - audio_in_ids_start[:-1],
                torch.tensor(
                    [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
                    device=audio_in_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        if audio_features_embed is not None:
            # An <|AUDIO|> placeholder holds the continuous features AND the discrete codes.
            token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
        else:
            token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()

    if audio_out_embed is not None:
        # Same length computation as above, for the audio-out stream.
        audio_out_codes_length = torch.concat(
            [
                audio_out_ids_start[1:] - audio_out_ids_start[:-1],
                torch.tensor(
                    [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
                    device=audio_out_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()

    # For each original token, the index of the LAST merged position it occupies.
    new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
    max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
    # Number of unused trailing positions per row in the merged layout.
    nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]

    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding

    # 2. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        (batch_size, max_token_num, embed_dim),
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
    final_attention_mask = torch.zeros(
        (batch_size, max_token_num),
        dtype=attention_mask.dtype,
        device=inputs_embeds.device,
    )
    final_input_ids = torch.full(
        (batch_size, max_token_num),
        pad_token_id,
        dtype=input_ids.dtype,
        device=inputs_embeds.device,
    )
    if skip_labels:
        final_labels = None
    else:
        final_labels = torch.full(
            (batch_size, max_token_num),
            ignore_index,
            dtype=label_ids.dtype,
            device=inputs_embeds.device,
        )

    final_audio_in_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_in_discrete_codes_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_out_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )

    # 3. Get the audio-in token positions and audio-out token positions
    batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
    audio_in_batch_id = batch_id[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_batch_id = batch_id[audio_out_token_mask]  # Shape (num_audio_out,)
    audio_features_token_ends = new_token_positions[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_embed_ends = new_token_positions[audio_out_token_mask]  # Shape (num_audio_out,)

    if audio_in_embed is not None:
        # Fill in the audio-in embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_in_ids_start.shape[0], max_token_num)
        )
        # The discrete codes occupy the tail of each <|AUDIO|> span: [end - length + 1, end].
        audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
        # torch.where here returns (audio index, column) pairs; the audio index is mapped to
        # its batch row on the next line.
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_in_embed
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
        final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
        # Move the span end back so the continuous features (if any) land right before the codes.
        audio_features_token_ends = audio_features_token_ends - audio_in_codes_length

    if audio_features_embed is not None:
        # Fill in the audio features
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_features_embed.shape[0], max_token_num)
        )
        audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_features_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = masked_audio_in_features
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True

    if audio_out_embed is not None:
        # Fill in the audio-out embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_out_ids_start.shape[0], max_token_num)
        )
        audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
        )
        batch_indices = audio_out_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_out_embed
        final_input_ids[batch_indices, col_indices] = audio_out_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_out_mask[batch_indices, col_indices] = True

    # Fill in the original text embeddings and labels
    batch_indices, non_audio_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    # Every position that received an audio embedding is attended to.
    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask

    # Trim the tensor if there are redundant padding tokens
    if left_padding:
        # First column that is attended to by at least one row, rounded down to a multiple of round_to.
        # NOTE(review): `.nonzero()[0]` raises if the merged attention mask is all zeros — confirm callers never pass that.
        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        # We have done right padding, so we need to trim the mask
        # One past the last attended column, rounded up to a multiple of round_to.
        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]

    # Positions count attended tokens from 0; masked-out positions are assigned position id 1.
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )
def is_deepspeed_ulysses_enabled():
    """Check if sequence parallelism is enabled.

    Returns:
        ``False`` when DeepSpeed is not available (``deepspeed_groups`` is
        ``None``); otherwise whether the DeepSpeed sequence parallel world size
        is greater than one.
    """
    # The docstring must be the first statement of the function to be picked up as __doc__.
    if deepspeed_groups is None:
        return False
    return deepspeed_groups._get_sequence_parallel_world_size() > 1
def support_deepspeed_ulysses(module):
    """Class decorator that exposes sequence parallel info on a PyTorch module.

    Adds three lazily-evaluated, cached read-only properties to ``module``:

    - ``sp_size``: size of the sequence parallel group (1 when Ulysses is off)
    - ``sp_rank``: rank inside the sequence parallel group (0 when Ulysses is off)
    - ``sp_group``: the sequence parallel group (``None`` when Ulysses is off)

    The decorated object is modified in place and returned.
    """
    # Cache slots; ``None`` means "not resolved yet".
    for cache_attr in ("_sp_size", "_sp_rank", "_sp_group"):
        setattr(module, cache_attr, None)

    def sp_size(self):
        if self._sp_size is not None:
            return self._sp_size
        self._sp_size = 1
        if is_deepspeed_ulysses_enabled():
            self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
        return self._sp_size

    def sp_rank(self):
        if self._sp_rank is not None:
            return self._sp_rank
        self._sp_rank = 0
        if is_deepspeed_ulysses_enabled():
            self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
        return self._sp_rank

    def sp_group(self):
        # Unlike size/rank there is no default to cache, so the lookup is retried while disabled.
        if self._sp_group is None and is_deepspeed_ulysses_enabled():
            self._sp_group = deepspeed_groups._get_sequence_parallel_group()
        return self._sp_group

    module.sp_size = property(sp_size)
    module.sp_rank = property(sp_rank)
    module.sp_group = property(sp_group)
    return module
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
    """Perform all-to-all before and after the attention function.

    Returns a decorator. When DeepSpeed Ulysses is enabled, the first three
    positional arguments of the decorated function are passed through
    ``_SeqAllToAll`` (scatter on ``head_dim``, gather on ``seq_dim``) before
    the call, and its output is passed through the inverse all-to-all
    (scatter on ``seq_dim``, gather on ``head_dim``) afterwards. When Ulysses
    is disabled the decorated function is called unchanged.

    Args:
        seq_dim: Index of the sequence-length dimension of the tensors.
        head_dim: Index of the attention-heads dimension of the tensors.
    """

    def attention_decorator(attn_func=None):
        # Keep the wrapped function's name/docstring/signature visible to callers and tooling.
        @wraps(attn_func)
        def wrapped(*args, **kwargs):
            # Evaluate once so the pre- and post-attention all-to-all are always paired,
            # even if the enabled state changes while attn_func runs.
            ulysses_enabled = is_deepspeed_ulysses_enabled()
            batch_dim_idx = 0
            if ulysses_enabled:
                sp_group = deepspeed_groups._get_sequence_parallel_group()
                scatter_idx = head_dim  # Scatter on num_heads dimension
                gather_idx = seq_dim  # Gather on seq_len dimension
                args = list(args)
                args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
                args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
                args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
                args = tuple(args)

            attn_output = attn_func(*args, **kwargs)

            if ulysses_enabled:
                scatter_idx = seq_dim  # Scatter back on seq_len dimension
                gather_idx = head_dim  # Gather on num_heads dimension
                attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)

            return attn_output

        return wrapped

    return attention_decorator
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
    """Slice the corresponding cos and sin chunks for rope.

    Returns a decorator. When DeepSpeed Ulysses is enabled, positional
    arguments 2 and 3 of the decorated function (the cos / sin tensors) are
    narrowed along ``trig_seq_dim`` to the chunk owned by the current sequence
    parallel rank. The chunk length is the size of positional argument 0 along
    ``state_seq_dim``. When Ulysses is disabled the arguments are passed
    through unchanged.

    Args:
        state_seq_dim: Sequence dimension of the state tensor (positional argument 0).
        trig_seq_dim: Sequence dimension of the cos / sin tensors.
    """

    def rope_decorator(rope_func=None):
        # Keep the wrapped function's name/docstring/signature visible to callers and tooling.
        @wraps(rope_func)
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_rank = deepspeed_groups._get_sequence_parallel_rank()
                args = list(args)
                seq_chunk_size = args[0].size(state_seq_dim)
                chunk_start = sp_rank * seq_chunk_size
                args[2] = torch.narrow(args[2], trig_seq_dim, chunk_start, seq_chunk_size)
                args[3] = torch.narrow(args[3], trig_seq_dim, chunk_start, seq_chunk_size)
                args = tuple(args)

            return rope_func(*args, **kwargs)

        return wrapped

    return rope_decorator
def _gather_tensors(input_, group=None):
"""Gather tensors and concatenate them along a dimension."""
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(group)
if world_size == 1:
return input_
tensor_shapes = [
torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
]
input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
torch.distributed.all_gather(tensor_shapes, input_size, group=group)
gathered_buffers = [
torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
]
torch.distributed.all_gather(gathered_buffers, input_, group=group)
return gathered_buffers
def _scatter_tensors(input_, group=None):
"""Scatter tensors."""
world_size = torch.distributed.get_world_size(group)
if world_size == 1:
return input_
rank = torch.distributed.get_rank(group)
return input_[rank]
class _GatherTensors(torch.autograd.Function):
    """Autograd function that all-gathers tensors among the ranks.

    Forward gathers one tensor per rank and packs them into a jagged nested
    tensor; backward hands each rank back the gradient slice at its own rank
    index.
    """

    @staticmethod
    def symbolic(graph, input_, group):
        return _gather_tensors(input_, group)

    @staticmethod
    def forward(ctx, input_, group):
        # Remember the group so backward can look up this rank's gradient.
        ctx.group = group
        per_rank_tensors = _gather_tensors(input_, group)
        return torch.nested.as_nested_tensor(per_rank_tensors, layout=torch.jagged)

    @staticmethod
    def backward(ctx, grad_output):
        local_grad = _scatter_tensors(grad_output, ctx.group)
        # No gradient flows to the ``group`` argument.
        return local_grad, None
def all_gather_tensors(input_, size=None, dim=0, group=None):
    """Gather ``input_`` from all ranks of ``group`` and concatenate along ``dim``.

    Args:
        input_: The local tensor.
        size: Optional per-rank split sizes. When given, each gathered tensor is
            split with ``torch.split(tensor, s.tolist())`` and the pieces are
            interleaved across ranks (piece 0 of every rank, then piece 1 of
            every rank, ...) before concatenation.
            Presumably an iterable of integer tensors, one per rank — TODO confirm against callers.
        dim: Dimension along which the gathered pieces are concatenated.
        group: Process group to gather over.

    Returns:
        ``input_`` unchanged when the group has a single rank, otherwise the
        contiguous concatenated tensor.
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_

    gathered = _GatherTensors.apply(input_, group)
    if size:
        pieces_per_rank = [torch.split(rank_tensor, s.tolist()) for s, rank_tensor in zip(size, gathered)]
        # Transpose rank-major pieces into piece-major order, then flatten.
        gathered = [piece for same_index_pieces in zip(*pieces_per_rank) for piece in same_index_pieces]

    return torch.cat(gathered, dim).contiguous()
def get_sequence_data_parallel_world_size():
    """Return the world size of the default ``torch.distributed`` process group."""
    world_size = torch.distributed.get_world_size()
    return world_size
def get_sequence_data_parallel_rank():
    """Return this process's rank in the default ``torch.distributed`` process group."""
    rank = torch.distributed.get_rank()
    return rank
def get_sequence_data_parallel_group():
    """Return the ``torch.distributed`` WORLD group handle."""
    world_group = torch.distributed.group.WORLD
    return world_group
# When DeepSpeed is installed, point its sequence-data-parallel accessors at the
# helpers above, which report the default (WORLD) torch.distributed group.
if is_deepspeed_available():
    for _accessor in (
        get_sequence_data_parallel_world_size,
        get_sequence_data_parallel_rank,
        get_sequence_data_parallel_group,
    ):
        # deepspeed_groups exposes each accessor under the same name with a leading underscore.
        setattr(deepspeed_groups, f"_{_accessor.__name__}", _accessor)
def _gather_tokens(input_, dim=0, group=None):
"""Gather tensors and concatenate them along a dimension"""
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(group)
if world_size == 1:
return input_
gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
if dim == 0:
shape = list(input_.size())
shape[0] = shape[0] * world_size
output = gather_buffer.view(shape)
else:
tensor_list = [
gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
]
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
def _drop_tokens(input_, dim=0, group=None):
"""Divide a tensor among the sequence parallel ranks"""
world_size = torch.distributed.get_world_size(group)
if world_size == 1:
return input_
this_rank = torch.distributed.get_rank(group)
assert input_.shape[dim] % world_size == 0, (
f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
)
chunk_size = input_.shape[dim] // world_size
return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the sequence parallel ranks.

    Forward keeps this rank's chunk of the input along ``dim``; backward
    gathers the gradient chunks from all ranks along ``dim`` and divides the
    result by ``grad_scale`` when it is not 1.
    """

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # Scale out of place: on a single-rank group the gather can return grad_output
            # itself, and the incoming gradient must never be modified in place.
            grad_input = grad_input / ctx.grad_scale
        # Only the tensor input receives a gradient; dim, group and grad_scale do not.
        return grad_input, None, None, None
class _GatherTokens(torch.autograd.Function):
    """Gather tokens among the sequence parallel ranks.

    Forward gathers the input from all ranks along ``dim``; backward keeps this
    rank's chunk of the gradient along ``dim`` and multiplies it by
    ``grad_scale`` when it is not 1.
    """

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # _drop_tokens returns a view of grad_output (or grad_output itself), so scale
            # out of place: the incoming gradient must never be modified in place.
            grad_input = grad_input * ctx.grad_scale
        # Only the tensor input receives a gradient; dim, group and grad_scale do not.
        return grad_input, None, None, None
def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    """Differentiable split of ``input_`` along ``dim`` across the ranks of ``group``.

    Returns ``input_`` unchanged when the group has a single rank; otherwise
    applies ``_DropTokens`` so gradients are gathered (and divided by
    ``grad_scale``) in the backward pass.
    """
    sequence_parallel = torch.distributed.get_world_size(group) != 1
    if not sequence_parallel:
        return input_
    return _DropTokens.apply(input_, dim, group, grad_scale)
def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    """Differentiable gather of ``input_`` along ``dim`` across the ranks of ``group``.

    Returns ``input_`` unchanged when the group has a single rank; otherwise
    applies ``_GatherTokens`` so gradients are split (and multiplied by
    ``grad_scale``) in the backward pass.
    """
    sequence_parallel = torch.distributed.get_world_size(group) != 1
    if not sequence_parallel:
        return input_
    return _GatherTokens.apply(input_, dim, group, grad_scale)
def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs to create chunks per the sequence parallel rank. This is used for the context parallel training.

    All tensors in ``args`` must have the same size along ``dim``, and that
    size must be divisible by ``sp_size``. Each tensor is narrowed to the
    ``sp_rank``-th contiguous chunk along ``dim``.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        dim (`int`):
            The dimension to slice

    Returns:
        A single tensor when one tensor was passed, otherwise a tuple of
        tensors. When ``sp_size`` is 1 the inputs are returned untouched.
    """
    if sp_size == 1:
        return args[0] if len(args) == 1 else args

    seq_length = args[0].size(dim)
    for arg in args[1:]:
        assert arg.size(dim) == seq_length, (
            f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    chunk_len = seq_length // sp_size
    chunk_start = sp_rank * chunk_len
    chunks = tuple(torch.narrow(tensor, dim, chunk_start, chunk_len) for tensor in args)
    return chunks if len(chunks) > 1 else chunks[0]
@contextmanager
def disable_deepspeed_ulysses():
    """Disable deepspeed ulysses (sequence parallelism) if it is enabled

    While the context is active, ``deepspeed_groups._get_sequence_parallel_world_size``
    is replaced by a function returning 1; the original function is restored on
    exit, including when the body raises. When Ulysses is not enabled the
    context does nothing.
    """
    if not is_deepspeed_ulysses_enabled():
        with contextlib.nullcontext():
            yield
        return

    saved_getter = deepspeed_groups._get_sequence_parallel_world_size

    def _single_rank_world_size():
        return 1

    deepspeed_groups._get_sequence_parallel_world_size = _single_rank_world_size
    try:
        yield
    finally:
        deepspeed_groups._get_sequence_parallel_world_size = saved_getter
|