Update modeling_mimo.py
#7 by chengfeng17 · opened

modeling_mimo.py  (+6 -6)  CHANGED
```diff
@@ -27,10 +27,10 @@ class MiMoMTPLayers(nn.Module):
                 hidden_states,
                 attention_mask,
                 position_ids,
-
+                past_key_value: Optional[Cache]=None,
                 output_attentions: Optional[bool]=False,
                 use_cache: Optional[bool]=False,
-
+                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                 cache_position=None,
                 **kwargs):
         input_embeds = self.token_layernorm(input_embeds)
@@ -38,15 +38,15 @@ class MiMoMTPLayers(nn.Module):
         hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(hidden_states,
                                        attention_mask=attention_mask,
                                        position_ids=position_ids,
-
+                                       past_key_value=past_key_value,
                                        output_attentions=output_attentions,
                                        use_cache=use_cache,
                                        cache_position=cache_position,
-
-                                       **kwargs)
+                                       position_embeddings=position_embeddings,
+                                       **kwargs)[0]
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
```
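For context, the patched method can be pieced together roughly as in the sketch below. Only the lines visible in the two hunks above come from this PR; the leading arguments of `forward`, the line that produces `previous_hidden_states`, the class's `__init__`, and the MLP tail after `post_attention_layernorm` are assumptions added here for illustration, not taken from modeling_mimo.py.

```python
# Hedged sketch of MiMoMTPLayers.forward as it plausibly reads after this patch.
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.cache_utils import Cache  # type used by the new past_key_value argument


class MiMoMTPLayers(nn.Module):
    # __init__ (which builds token_layernorm, input_proj, input_layernorm,
    # self_attn, post_attention_layernorm, and the assumed hidden_layernorm
    # and mlp sub-modules) is not part of the diff and is omitted here.

    def forward(self,
                input_embeds,                 # assumed leading arguments (not in the hunk)
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value: Optional[Cache] = None,   # added by this PR
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False,
                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # added by this PR
                cache_position=None,
                **kwargs):
        # Normalize the incoming token embeddings and the main model's hidden
        # states, then project their concatenation back to the hidden size.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)  # assumed: defined just above the hunk
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # Pre-norm attention block. The PR threads the KV cache and the rotary
        # position embeddings through to self_attn and keeps only the first
        # element of its returned tuple (the attention output).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states,
                                       attention_mask=attention_mask,
                                       position_ids=position_ids,
                                       past_key_value=past_key_value,
                                       output_attentions=output_attentions,
                                       use_cache=use_cache,
                                       cache_position=cache_position,
                                       position_embeddings=position_embeddings,
                                       **kwargs)[0]
        hidden_states = residual + hidden_states

        # Pre-norm MLP block; everything from here on is assumed to mirror a
        # standard decoder layer and is not shown in the diff.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
```

Presumably the point of the change: recent Hugging Face attention modules accept `past_key_value` and `position_embeddings` keywords and return a tuple, so without the two extra parameters and the trailing `[0]` the residual addition on the next line would receive a tuple rather than a tensor.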