# EvoTransformer-v2.1 / inference.py
# Origin: Hugging Face Space (author: HemanM, commit e0e3bb1, "Update inference.py")
import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI
# === Runtime setup: EvoTransformer model, tokenizer, GPT-3.5 client ===

# Pick the compute device first so the model can be placed on it right away.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned EvoTransformer classifier from the local checkpoint
# directory and freeze it for inference.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.to(device)
model.eval()

# Tokenizer saved alongside the model checkpoint.
tokenizer = BertTokenizer.from_pretrained("trained_model")

# OpenAI client for the GPT-3.5 comparison baseline; the API key is read
# from the environment (None if unset — calls will then fail at request time).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def query_gpt35(prompt):
    """Ask GPT-3.5-turbo for a short reply to *prompt*.

    Returns the stripped reply text. Any failure (network, auth, quota,
    malformed response) is reported best-effort as an error-tagged string
    rather than raised, so callers never have to handle exceptions.
    """
    try:
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=60,
            temperature=0.3,
        )
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as e:
        # Deliberate broad catch: surface the failure inline as text.
        return f"[GPT-3.5 Error] {str(e)}"
def get_logits(output):
    """Extract the logits from a model forward-pass output.

    Handles the three return conventions this script may see:
    a ``(loss, logits)`` tuple, an object exposing a ``.logits``
    attribute (HF-style ModelOutput), or the raw logits tensor itself.

    Fix: a 1-tuple ``(logits,)`` — e.g. a forward pass run without
    labels, so no loss is returned — no longer raises IndexError;
    its single element is returned. Behavior for 2-tuples is unchanged.
    """
    if isinstance(output, tuple):
        # (loss, logits) when labels were supplied; (logits,) otherwise.
        return output[1] if len(output) > 1 else output[0]
    if hasattr(output, "logits"):
        return output.logits
    return output  # already raw logits
def generate_response(goal, option1, option2):
    """Score two candidate options for *goal* with EvoTransformer and GPT-3.5.

    Each option is paired with the goal as "goal [SEP] option", run through
    the 2-class Evo classifier, and the option with the higher index-1
    (positive-class) logit wins. GPT-3.5 is always queried as an independent
    comparison, even when the Evo pass fails.

    Returns a dict with "evo_suggestion" and "gpt_suggestion" strings;
    Evo failures are reported inline as "[Evo Error] ..." text.

    Fix: the original moved tensors with ``for k in enc1: enc1[k]=...;
    enc2[k]=...`` — indexing *both* encodings by enc1's keys, a latent
    KeyError if the two encodings ever carry different keys. Each encoding
    is now tokenized and moved to the device independently.
    """
    def _encode(text):
        # Tokenize one "goal [SEP] option" pair and place it on the device.
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        # EvoTransformer's forward() does not accept token_type_ids.
        enc.pop("token_type_ids", None)
        return {k: v.to(device) for k, v in enc.items()}

    try:
        enc1 = _encode(f"{goal} [SEP] {option1}")
        enc2 = _encode(f"{goal} [SEP] {option2}")

        with torch.no_grad():
            logits1 = get_logits(model(**enc1))
            logits2 = get_logits(model(**enc2))

        if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
            raise ValueError("Model did not return 2-class logits.")

        # Compare positive-class (index 1) scores for the two options.
        score1 = logits1[0][1].item()
        score2 = logits2[0][1].item()
        evo_result = option1 if score1 > score2 else option2
    except Exception as e:
        # Best-effort: report the Evo failure as text so GPT still runs.
        evo_result = f"[Evo Error] {str(e)}"

    # GPT-3.5 comparison (independent of Evo's success or failure).
    prompt = (
        f"Goal: {goal}\n"
        f"Option 1: {option1}\n"
        f"Option 2: {option2}\n"
        "Which is better and why?"
    )
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result
    }