# modeling_i3.py
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_i3 import I3Config
from i3_architecture import i3Model  # your actual i3 implementation


class I3ForCausalLM(PreTrainedModel):
    config_class = I3Config

    def __init__(self, config):
        super().__init__(config)
        self.model = i3Model(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            max_seq_len=config.max_seq_len,
            rank=config.rank,
            d_state=config.d_state,
        )
        self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        # attention_mask is accepted so Trainer/generate can pass it, but the
        # underlying i3Model does not consume it here.
        outputs = self.model(input_ids)
        logits = self.lm_head(outputs)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No cache or recurrent state is carried between decoding steps, so
        # the full sequence is re-encoded at every step of generation.
        return {"input_ids": input_ids}
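
# --- Minimal smoke test (a sketch, not part of the original module) ---
# This assumes I3Config (the PretrainedConfig subclass in configuration_i3.py)
# accepts vocab_size, d_model, n_layers, n_heads, max_seq_len, rank, and
# d_state as constructor kwargs; the hyperparameter values below are
# placeholders. Note that model.generate() would route through
# prepare_inputs_for_generation above, but only if GenerationMixin is
# available on the class (mixed into PreTrainedModel before transformers
# v4.50; in newer versions inherit it explicitly alongside PreTrainedModel).
if __name__ == "__main__":
    config = I3Config(
        vocab_size=32000,
        d_model=512,
        n_layers=8,
        n_heads=8,
        max_seq_len=2048,
        rank=16,
        d_state=64,
    )
    model = I3ForCausalLM(config)

    # Forward pass with labels: the forward shifts logits/labels internally,
    # so labels can simply be a copy of input_ids (HF convention).
    input_ids = torch.randint(0, config.vocab_size, (2, 128))
    out = model(input_ids, labels=input_ids)
    print(out.loss, out.logits.shape)  # scalar loss, (2, 128, vocab_size)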