import json

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel

from init_model import get_tokenizer
from model import EvoTransformerArabic

tokenizer = get_tokenizer()
# Frozen AraBERT encoder used purely as a feature extractor (never trained here).
bert = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
bert.eval()  # disable dropout so the frozen features are deterministic


class ArabicTaskDataset(Dataset):
    """JSONL dataset of binary-choice Arabic questions.

    Each non-empty line must be a JSON object with keys 'question',
    'option1', 'option2', and an integer 'label' (0 or 1, the index of
    the correct option).
    """

    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines (e.g. a trailing newline) instead of
                # crashing in json.loads("").
                if line:
                    self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Pair the question with each candidate answer; the model scores both.
        text1 = item['question'] + " " + item['option1']
        text2 = item['question'] + " " + item['option2']
        encoded1 = tokenizer(text1, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        encoded2 = tokenizer(text2, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        return encoded1, encoded2, torch.tensor(item['label'])


def train():
    """Train EvoTransformerArabic on frozen AraBERT [CLS] features.

    Runs 3 epochs over 'arabic_tasks.jsonl' and saves the trained weights
    to 'trained_model.pt'.
    """
    dataset = ArabicTaskDataset('arabic_tasks.jsonl')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    model = EvoTransformerArabic()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(3):
        total_loss = 0.0
        for encoded1, encoded2, label in dataloader:
            # tokenizer(..., return_tensors='pt') yields [1, seq_len] tensors,
            # so default collation stacks them into [batch, 1, seq_len].
            # Squeeze dim 1 (NOT dim 0 — that is the batch dim, size > 1 and
            # therefore not removed by squeeze) to get the [batch, seq_len]
            # inputs BERT expects. Take the [CLS] (position 0) hidden state
            # as the sentence representation.
            with torch.no_grad():
                out1 = bert(**{k: v.squeeze(1) for k, v in encoded1.items()}
                            ).last_hidden_state[:, 0, :]
                out2 = bert(**{k: v.squeeze(1) for k, v in encoded2.items()}
                            ).last_hidden_state[:, 0, :]

            # Score each option independently; an option's confidence is its
            # class-1 logit. CrossEntropyLoss over the two stacked scores then
            # trains the model to rank the correct option higher.
            logits1 = model(out1)
            logits2 = model(out2)
            preds = torch.stack([logits1[:, 1], logits2[:, 1]], dim=1)

            loss = criterion(preds, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

    torch.save(model.state_dict(), "trained_model.pt")
    print("✅ Model saved as trained_model.pt")


if __name__ == "__main__":
    train()