arabic-gov-copilot / inference.py
HemanM's picture
Update inference.py
67209e5 verified
raw
history blame contribute delete
912 Bytes
import torch
from init_model import get_tokenizer, get_base_model
from model import EvoTransformerArabic
# Load tokenizer and base encoder (Arabic BERT).
tokenizer = get_tokenizer()
bert = get_base_model()
# Put the encoder in inference mode explicitly (disables dropout and similar
# training-only behaviour), matching what is done for the Evo model below.
# The mode returned by get_base_model() is not visible from this file.
bert.eval()

# Load the Evo model and its trained weights, mapped onto the CPU.
# NOTE(review): "trained_model.pt" is resolved against the current working
# directory, not this file's directory. torch.load unpickles the file, so only
# load trusted checkpoints (weights_only=True is available on newer torch).
model = EvoTransformerArabic()
model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
model.eval()
def evo_suggest(question, option1, option2):
    """Return whichever of two candidate answers the Evo model prefers.

    Each option is appended to the question (separated by a single space),
    encoded with the module-level ``tokenizer`` (padded/truncated to 128
    tokens), run through the ``bert`` encoder, and the first-token ([CLS])
    embedding is scored by ``model``. The option with the higher class-1
    probability wins.

    Args:
        question: The question text.
        option1: First candidate answer.
        option2: Second candidate answer.

    Returns:
        ``option1`` if its class-1 score is strictly higher than that of
        ``option2``; otherwise ``option2`` (so ties go to ``option2``).
    """
    scores = []
    # Keep the whole forward pass (encoder AND Evo model) out of autograd;
    # nothing here needs gradients.
    with torch.no_grad():
        # Options are scored one at a time rather than batched, because it is
        # not visible here whether `model` treats dim 0 as independent samples.
        for option in (option1, option2):
            encoded = tokenizer(
                question + " " + option,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=128,
            )
            # First position of the last hidden state: the [CLS] embedding.
            cls_embedding = bert(**encoded).last_hidden_state[:, 0, :]
            logits = model(cls_embedding)
            # Normalize across classes so the score is an actual class-1
            # confidence. The raw class-1 logit alone ignores the other
            # classes' logits and can rank two inputs differently.
            # Assumes `model` returns unnormalized per-class scores shaped
            # (batch, num_classes). TODO confirm against model.py.
            probs = torch.softmax(logits, dim=-1)
            scores.append(probs[0, 1].item())
    return option1 if scores[0] > scores[1] else option2