File size: 2,458 Bytes
1a4d7a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel
from init_model import get_tokenizer
from model import EvoTransformerArabic

# Shared AraBERT tokenizer (project helper) and the pretrained AraBERT
# encoder, loaded once at import time. The encoder is used below as a
# frozen feature extractor (all calls are wrapped in torch.no_grad()).
tokenizer = get_tokenizer()
bert = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")

class ArabicTaskDataset(Dataset):
    """Two-option Arabic QA examples read from a JSONL file.

    Each line of the file is a JSON object expected to contain
    'question', 'option1', 'option2', and an integer 'label'
    (presumably 0/1 indicating the correct option — confirm against
    the data file).
    """

    def __init__(self, file_path):
        # One JSON object per line; every line is parsed (a blank line
        # would raise, matching the original behavior).
        with open(file_path, 'r', encoding='utf-8') as handle:
            self.data = [json.loads(raw.strip()) for raw in handle]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        question = record['question']
        # Pair the question with each candidate answer and tokenize both
        # to a fixed length of 128 tokens. return_tensors='pt' yields
        # tensors shaped [1, 128]; the training loop strips that dim.
        pair_texts = (
            question + " " + record['option1'],
            question + " " + record['option2'],
        )
        encoded_a, encoded_b = (
            tokenizer(text, padding='max_length', truncation=True,
                      max_length=128, return_tensors='pt')
            for text in pair_texts
        )
        return encoded_a, encoded_b, torch.tensor(record['label'])

def train():
    """Fine-tune EvoTransformerArabic on a two-option Arabic QA task.

    Loads examples from ``arabic_tasks.jsonl``, embeds each
    question+option pair with the frozen AraBERT encoder ([CLS] vector),
    scores both options with the Evo model, and trains for 3 epochs with
    cross-entropy over the two per-option confidence logits. Saves the
    trained weights to ``trained_model.pt``.
    """
    dataset = ArabicTaskDataset('arabic_tasks.jsonl')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    model = EvoTransformerArabic()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(3):
        total_loss = 0.0
        for encoded1, encoded2, label in dataloader:
            # The dataset returns [1, 128] tensors, so after DataLoader
            # collation each value is [batch, 1, 128]. Squeeze dim 1 —
            # NOT dim 0, which is a no-op whenever batch > 1 and would
            # feed BERT a 3-D input_ids tensor.
            with torch.no_grad():  # AraBERT stays frozen; only Evo trains.
                out1 = bert(**{k: v.squeeze(1) for k, v in encoded1.items()}).last_hidden_state[:, 0, :]
                out2 = bert(**{k: v.squeeze(1) for k, v in encoded2.items()}).last_hidden_state[:, 0, :]

            # Score each option independently and take the class-1 logit as
            # that option's confidence; cross-entropy over the two scores
            # trains the model to rank the correct option higher.
            # (A previous stacked-input forward pass here was dead code —
            # its result was unused and it cost an extra model call per step.)
            logits1 = model(out1)
            logits2 = model(out2)
            preds = torch.stack([logits1[:, 1], logits2[:, 1]], dim=1)
            loss = criterion(preds, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

    torch.save(model.state_dict(), "trained_model.pt")
    print("✅ Model saved as trained_model.pt")

# Script entry point: run the full training loop when executed directly.
if __name__ == "__main__":
    train()