# Source: arabic-gov-copilot / train_arabic.py
# Uploaded by HemanM
# Commit 1a4d7a4 (verified): "Update train_arabic.py"
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel
from init_model import get_tokenizer
from model import EvoTransformerArabic
# Shared module-level resources, created once at import time and reused by
# the dataset (tokenizer) and by the training loop (bert).
_ARABERT_CHECKPOINT = "aubmindlab/bert-base-arabertv2"

tokenizer = get_tokenizer()
bert = AutoModel.from_pretrained(_ARABERT_CHECKPOINT)
class ArabicTaskDataset(Dataset):
    """Two-option question dataset backed by a JSON Lines file.

    Every non-blank line of the file is parsed as one JSON object. The keys
    read from each object are ``question``, ``option1``, ``option2`` and
    ``label``.
    """

    def __init__(self, file_path, max_length=128):
        """Load all records of *file_path* into memory.

        Args:
            file_path: Path to a UTF-8 encoded JSON Lines file.
            max_length: Length every tokenized text is padded/truncated to.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.max_length = max_length
        with open(file_path, 'r', encoding='utf-8') as f:
            # Blank lines (most commonly a trailing newline at the end of the
            # file) carry no record; feeding them to json.loads would raise.
            stripped_lines = (line.strip() for line in f)
            self.data = [json.loads(line) for line in stripped_lines if line]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize *text* with the module-level tokenizer at fixed length."""
        return tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return ``(encoded1, encoded2, label)`` for the record at *idx*.

        ``encoded1`` / ``encoded2`` are the tokenizer outputs for the question
        joined (with a single space) to ``option1`` / ``option2``. ``label``
        is a 0-dim integer tensor built from the record's ``label`` field.
        """
        item = self.data[idx]
        question = item['question']
        encoded1 = self._encode(question + " " + item['option1'])
        encoded2 = self._encode(question + " " + item['option2'])
        # Force an integer dtype: a label stored as 1.0 in the JSON would
        # otherwise become a float tensor, which is not a valid class index.
        label = torch.tensor(item['label'], dtype=torch.long)
        return encoded1, encoded2, label
def _first_token_embedding(encoded):
    """Return the frozen encoder's first-token embedding for a collated batch.

    Args:
        encoded: Mapping of tokenizer outputs as yielded by the DataLoader.

    Returns:
        ``last_hidden_state[:, 0, :]`` of the module-level ``bert`` model,
        computed without gradient tracking.
    """
    # The dataset tokenizes one text at a time with return_tensors='pt', so
    # each item is expected to carry a leading singleton axis and the collated
    # batch to be [batch, 1, seq_len]. Flatten everything after the batch axis
    # so the encoder gets [batch, seq_len]; this is a no-op if the input is
    # already 2-D. (Squeezing dim 0 would only be right for batch size 1.)
    inputs = {k: v.reshape(v.size(0), -1) for k, v in encoded.items()}
    with torch.no_grad():
        return bert(**inputs).last_hidden_state[:, 0, :]


def train(*, data_path='arabic_tasks.jsonl', epochs=3, batch_size=4,
          lr=3e-4, output_path="trained_model.pt"):
    """Train ``EvoTransformerArabic`` on two-option questions and save it.

    For every batch, both "question + option" texts are embedded with the
    frozen module-level ``bert`` encoder, each embedding is scored by the
    model, and cross-entropy is applied over the two option scores.

    Args:
        data_path: JSON Lines file passed to ``ArabicTaskDataset``.
        epochs: Number of passes over the dataset.
        batch_size: Batch size for the DataLoader.
        lr: Learning rate for the Adam optimizer.
        output_path: File the trained ``state_dict`` is written to.
    """
    dataset = ArabicTaskDataset(data_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = EvoTransformerArabic()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    # The encoder is only used as a fixed feature extractor (its forward pass
    # runs under no_grad), so keep it in inference mode.
    bert.eval()
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        for encoded1, encoded2, label in dataloader:
            out1 = _first_token_embedding(encoded1)
            out2 = _first_token_embedding(encoded2)

            # The model is called on one embedding batch at a time; column 1
            # of its output is taken as that option's confidence score.
            logits1 = model(out1)
            logits2 = model(out2)
            preds = torch.stack([logits1[:, 1], logits2[:, 1]], dim=1)

            # CrossEntropyLoss needs integer class indices as targets.
            loss = criterion(preds, label.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Reported value is the sum of per-batch losses for the epoch.
        print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

    torch.save(model.state_dict(), output_path)
    print(f"✅ Model saved as {output_path}")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    train()