FlameF0X committed on
Commit
bb3c416
·
verified ·
1 Parent(s): 0b018af

Update modeling_i3.py

Browse files
Files changed (1) hide show
  1. modeling_i3.py +20 -12
modeling_i3.py CHANGED
@@ -1,6 +1,6 @@
1
  # modeling_i3.py
2
- import torch
3
  from transformers import PreTrainedModel
 
4
  from configuration_i3 import I3Config
5
  from i3_architecture import i3Model # your actual i3 implementation
6
 
@@ -18,18 +18,26 @@ class I3ForCausalLM(PreTrainedModel):
18
  rank=config.rank,
19
  d_state=config.d_state,
20
  )
 
21
  self.post_init()
22
 
23
- def forward(self, input_ids, labels=None):
24
- logits, loss = self.model(input_ids, labels)
25
- return {"loss": loss, "logits": logits}
 
 
 
 
 
 
 
 
 
26
 
27
- @torch.no_grad()
28
- def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
29
- return self.model.generate(input_ids, max_new_tokens, temperature, top_k)
30
-
31
- # AutoClass registration (optional but recommended)
32
- from transformers import AutoConfig, AutoModelForCausalLM
33
 
34
- AutoConfig.register("i3", I3Config)
35
- AutoModelForCausalLM.register(I3Config, I3ForCausalLM)
 
1
  # modeling_i3.py
 
2
  from transformers import PreTrainedModel
3
+ from transformers.modeling_outputs import CausalLMOutputWithPast
4
  from configuration_i3 import I3Config
5
  from i3_architecture import i3Model # your actual i3 implementation
6
 
 
18
  rank=config.rank,
19
  d_state=config.d_state,
20
  )
21
+ self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size, bias=False)
22
  self.post_init()
23
 
24
+ def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
25
+ outputs = self.model(input_ids)
26
+ logits = self.lm_head(outputs)
27
+
28
+ loss = None
29
+ if labels is not None:
30
+ # Shift so that tokens < n predict n
31
+ shift_logits = logits[..., :-1, :].contiguous()
32
+ shift_labels = labels[..., 1:].contiguous()
33
+ loss_fct = torch.nn.CrossEntropyLoss()
34
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
35
+ shift_labels.view(-1))
36
 
37
+ return CausalLMOutputWithPast(
38
+ loss=loss,
39
+ logits=logits,
40
+ )
 
 
41
 
42
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
43
+ return {"input_ids": input_ids}