Changes in modelling_RW.py to be able to handle past_key_values for faster model generations
#85
opened by puru22
modelling_RW.py +69 -30
modelling_RW.py CHANGED
@@ -87,10 +87,18 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
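Note: the rotary change above builds the cos/sin tables for the full length (cached tokens plus new tokens) and then slices off the first past_seq_length rows, so that newly fed tokens are rotated with their true absolute positions instead of restarting at position 0. A minimal standalone sketch of that indexing, outside the model code (the helper name and shapes below are illustrative only, not part of the patch):

import torch

def rotary_tables(seq_len, head_dim, base=10000.0):
    # standard rotary cos/sin tables for positions 0..seq_len-1
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

past_len, new_len, head_dim = 5, 2, 8
cos, sin = rotary_tables(past_len + new_len, head_dim)
# with 5 cached tokens, the 2 new tokens use rows 5 and 6 of the tables
cos_new, sin_new = cos[past_len:], sin[past_len:]
print(cos_new.shape)  # torch.Size([2, 8])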
@@ -100,10 +108,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
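Note the flipped mask semantics in this hunk: the mask now marks positions that may be attended to with True (including every cached position), which is the convention torch.nn.functional.scaled_dot_product_attention expects for a boolean attn_mask, so it can be passed straight into the attention call further down. A small standalone copy of the patched construction, just to show what the mask looks like for 2 cached and 3 new tokens:

import torch

def causal_bool_mask(target_length, past_key_values_length, device="cpu"):
    # True = "may attend", matching F.scaled_dot_product_attention's boolean attn_mask
    mask = torch.empty((target_length, target_length + past_key_values_length),
                       dtype=torch.bool, device=device)
    seq_ids = torch.arange(target_length, device=device)
    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = True  # new tokens see all cached tokens
    return mask

print(causal_bool_mask(3, 2).long())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])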
@@ -248,6 +256,7 @@ class Attention(nn.Module):
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        layer_number=None,
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
 
@@ -264,31 +273,43 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
             # - key: [batch_size * self.num_heads, head_dim, kv_length]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            if attention_mask is not None:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                )
+            else:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
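One detail worth noting in the attention hunk: the cache keeps keys in [batch_size * self.num_heads, head_dim, kv_length] layout (as the existing comment states, and as past_key.shape[2] implies), while attention consumes keys as [batch_size * self.num_heads, kv_length, head_dim]. The patch therefore permutes the cached keys before concatenating and permutes back before storing present. A toy shape check of that round trip (illustrative tensors only):

import torch

bh, head_dim = 2 * 4, 64          # batch_size * num_heads, head dimension
past_kv_len, q_len = 7, 1         # cached tokens, newly fed tokens

past_key = torch.randn(bh, head_dim, past_kv_len)   # layout stored in the cache
new_key = torch.randn(bh, q_len, head_dim)          # layout produced for new tokens

# bring the cached keys into [bh, kv_length, head_dim] before concatenating
key_layer = torch.cat((past_key.permute(0, 2, 1), new_key), dim=1)
print(key_layer.shape)            # torch.Size([8, 8, 64])

# permute back so the cache layout stays [bh, head_dim, kv_length]
present_key = key_layer.permute(0, 2, 1)
print(present_key.shape)          # torch.Size([8, 64, 8])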
@@ -385,6 +406,7 @@ class DecoderLayer(nn.Module):
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        layer_number=None,
     ):
 
         ln_attn = self.ln_attn(hidden_states)
@@ -401,6 +423,7 @@ class DecoderLayer(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            layer_number=layer_number,
         )
 
         attention_output = attn_outputs[0]
@@ -528,10 +551,10 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -651,15 +674,28 @@ class RWModel(RWPreTrainedModel):
                     head_mask[i],
                 )
             else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                )
+                if i == 0:
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=causal_mask,
+                        head_mask=head_mask[i],
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                        alibi=alibi,
+                        layer_number=0,
+                    )
+                else:
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=causal_mask,
+                        head_mask=head_mask[i],
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                        alibi=alibi,
+                    )
 
             hidden_states = outputs[0]
             if use_cache is True:
@@ -710,16 +746,19 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        if kwargs.get("past_key_values", None):
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
             # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            # past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            # past_key_values = kwargs["past_key_values"]
+        else:
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
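With these changes the model reuses past_key_values across decoding steps instead of re-encoding the whole prefix at every token, which is where the generation speedup comes from. A rough usage sketch, assuming a Falcon-style checkpoint that ships this modelling_RW.py and is loaded with trust_remote_code=True (the model name below is only a placeholder, not taken from the PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-org/your-falcon-style-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,   # picks up the repo's modelling_RW.py
    device_map="auto",
)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
# use_cache=True makes generate() feed past_key_values back into the model,
# which is the code path this PR is meant to speed up
output_ids = model.generate(**inputs, max_new_tokens=50, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))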