# Training script for EvoTransformerArabic (originally hosted as a Hugging Face Space).
"""Train EvoTransformerArabic on pairwise Arabic multiple-choice tasks.

A frozen AraBERT encoder embeds each (question + option) text; the Evo
model scores the two embeddings and is trained with cross-entropy on the
index of the correct option.
"""
import json

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel

from init_model import get_tokenizer
from model import EvoTransformerArabic

# Shared AraBERT tokenizer/encoder. BERT is used frozen — see the
# torch.no_grad() wrapper in train().
tokenizer = get_tokenizer()
bert = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
class ArabicTaskDataset(Dataset):
    """Pairwise multiple-choice dataset loaded from a JSONL file.

    Each non-blank line must be a JSON object with keys 'question',
    'option1', 'option2' and an integer 'label' (0 or 1 — the index of
    the correct option).
    """

    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines (e.g. trailing newline)
                    self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text1 = item['question'] + " " + item['option1']
        text2 = item['question'] + " " + item['option2']
        encoded1 = tokenizer(text1, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        encoded2 = tokenizer(text2, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        # BUG FIX: return_tensors='pt' yields [1, seq] tensors; without
        # dropping that dim, DataLoader collation produces [batch, 1, seq]
        # inputs that BERT rejects. Squeeze here so batches are [batch, seq].
        encoded1 = {k: v.squeeze(0) for k, v in encoded1.items()}
        encoded2 = {k: v.squeeze(0) for k, v in encoded2.items()}
        return encoded1, encoded2, torch.tensor(item['label'])
def train():
    """Train the Evo scorer for 3 epochs on arabic_tasks.jsonl and save weights.

    Side effects: reads 'arabic_tasks.jsonl', prints per-epoch loss, and
    writes the trained state dict to 'trained_model.pt'.
    """
    dataset = ArabicTaskDataset('arabic_tasks.jsonl')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    model = EvoTransformerArabic()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(3):
        total_loss = 0.0
        for encoded1, encoded2, label in dataloader:
            # Frozen AraBERT: take the [CLS] (position 0) embedding of each
            # option text. BUG FIX: the tokenizer's per-item batch dim makes
            # collated tensors [B, 1, L]; squeeze dim 1 (not dim 0, which is
            # the batch of size B and was never removed by the original
            # squeeze(0)). The dim()==3 guard makes this a no-op if the
            # dataset already returns [B, L] tensors.
            with torch.no_grad():
                out1 = bert(**{k: v.squeeze(1) if v.dim() == 3 else v
                               for k, v in encoded1.items()}).last_hidden_state[:, 0, :]
                out2 = bert(**{k: v.squeeze(1) if v.dim() == 3 else v
                               for k, v in encoded2.items()}).last_hidden_state[:, 0, :]
            # Score each option independently; the class-1 logit is used as
            # the confidence that the option is correct. (A dead extra
            # forward pass over a stacked representation was removed — its
            # result was never used.)
            logits1 = model(out1)
            logits2 = model(out2)
            preds = torch.stack([logits1[:, 1], logits2[:, 1]], dim=1)
            loss = criterion(preds, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")
    torch.save(model.state_dict(), "trained_model.pt")
    print("✅ Model saved as trained_model.pt")
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    train()