Files changed (1)
  1. modeling_mimo.py +6 -6
modeling_mimo.py CHANGED
@@ -27,10 +27,10 @@ class MiMoMTPLayers(nn.Module):
                  hidden_states,
                  attention_mask,
                  position_ids,
-                 past_key_values: Optional[Cache]=None,
+                 past_key_value: Optional[Cache]=None,
                  output_attentions: Optional[bool]=False,
                  use_cache: Optional[bool]=False,
-                 position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                  cache_position=None,
                  **kwargs):
         input_embeds = self.token_layernorm(input_embeds)
@@ -38,15 +38,15 @@ class MiMoMTPLayers(nn.Module):
         hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states,
+        hidden_states = self.self_attn(hidden_states,
                                           attention_mask=attention_mask,
                                           position_ids=position_ids,
-                                          past_key_values=past_key_values,
+                                          past_key_value=past_key_value,
                                           output_attentions=output_attentions,
                                           use_cache=use_cache,
                                           cache_position=cache_position,
-                                          position_embedding=position_embedding,
-                                          **kwargs)
+                                          position_embeddings=position_embeddings,
+                                          **kwargs)[0]
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
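
The renamed keywords match the per-layer attention interface of recent `transformers` releases for Qwen2-style models, where `self_attn` declares `past_key_value` (singular) and `position_embeddings` (a `(cos, sin)` rotary-embedding tuple); with the old names, both arguments would likely be absorbed by the attention module's `**kwargs` and the cache and RoPE values would never take effect. Indexing `[0]` instead of unpacking `hidden_states, _ = ...` also tolerates attention modules that return either two or three elements, since two-target unpacking raises `ValueError` on a three-element tuple. Below is a minimal runnable sketch of the corrected call pattern; the `MTPLayerSketch` class, its constructor, the `RMSNorm` choice, and the stub attention are illustrative assumptions, not the upstream `MiMoMTPLayers` implementation.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class MTPLayerSketch(nn.Module):
    """Sketch of the attention call fixed by this diff (not the full MiMoMTPLayers)."""

    def __init__(self, hidden_size: int, self_attn: nn.Module):
        super().__init__()
        self.self_attn = self_attn
        # Assumption: an RMSNorm pre-attention norm, as in Qwen2-style blocks.
        self.input_layernorm = nn.RMSNorm(hidden_size)

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value=None,  # singular: the name self_attn declares
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False,
                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (cos, sin)
                cache_position=None,
                **kwargs) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # [0] picks the attention output whether self_attn returns
        # (output, weights) or (output, weights, cache); `out, _ = ...`
        # would raise on the three-element form.
        hidden_states = self.self_attn(hidden_states,
                                       attention_mask=attention_mask,
                                       position_ids=position_ids,
                                       past_key_value=past_key_value,
                                       output_attentions=output_attentions,
                                       use_cache=use_cache,
                                       cache_position=cache_position,
                                       position_embeddings=position_embeddings,
                                       **kwargs)[0]
        return residual + hidden_states


class _StubAttention(nn.Module):
    """Stand-in that mimics a Qwen2-style return of (attn_output, attn_weights)."""

    def forward(self, hidden_states, **kwargs):
        return hidden_states, None


layer = MTPLayerSketch(hidden_size=16, self_attn=_StubAttention())
out = layer(torch.randn(1, 4, 16))
print(out.shape)  # torch.Size([1, 4, 16])
```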