import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_i3 import I3Config
from i3_architecture import i3Model


class I3ForCausalLM(PreTrainedModel):
    """Hugging Face causal-LM wrapper around the custom i3 backbone."""

    config_class = I3Config

    def __init__(self, config):
        super().__init__(config)
        # Backbone defined in i3_architecture.py; hyperparameters are read
        # from the I3Config instance.
        self.model = i3Model(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            max_seq_len=config.max_seq_len,
            rank=config.rank,
            d_state=config.d_state,
        )
        # Untied projection from hidden states to vocabulary logits.
        self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Runs Hugging Face weight initialization and post-init hooks.
        self.post_init()

    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        # The backbone consumes only input_ids; attention_mask is accepted
        # for API compatibility but is not forwarded to the model.
        outputs = self.model(input_ids)
        logits = self.lm_head(outputs)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is returned from forward, so generation re-feeds the
        # full sequence at every decoding step.
        return {"input_ids": input_ids}
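

# A minimal usage sketch, assuming I3Config accepts the fields referenced in
# __init__ as keyword arguments; the values below are illustrative
# placeholders, not the model's actual hyperparameters.
if __name__ == "__main__":
    config = I3Config(
        vocab_size=32000,
        d_model=512,
        n_layers=8,
        n_heads=8,
        max_seq_len=2048,
        rank=64,
        d_state=16,
    )
    model = I3ForCausalLM(config)

    # Dummy batch: compute next-token loss on random token ids.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids)
    print(out.loss.item(), out.logits.shape)  # scalar loss, (2, 16, vocab_size)

    # Greedy decoding sketch; assumes PreTrainedModel still provides the
    # generation mixin (recent transformers versions may require inheriting
    # GenerationMixin explicitly).
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, do_sample=False)
    print(generated.shape)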