Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from transformers import BertTokenizer | |
from evo_model import EvoTransformerForClassification | |
from openai import OpenAI | |
# === Load EvoTransformer ===
# Classifier weights are loaded from the local "trained_model" directory at import time.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference only; this module never trains the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# === Tokenizer ===
# Same directory as the model so the vocabulary matches the trained weights.
tokenizer = BertTokenizer.from_pretrained("trained_model")
# === GPT-3.5 Client ===
# OpenAI(api_key=None) raises at construction when OPENAI_API_KEY is unset, which
# would abort the whole import. Leave `client` as None instead so the app can
# still start; GPT calls then fail at request time rather than at import.
_openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=_openai_api_key) if _openai_api_key else None
def query_gpt35(prompt, max_tokens=60, temperature=0.3, model_name="gpt-3.5-turbo"):
    """Send a single-turn prompt to the OpenAI chat API and return the reply text.

    Never raises for API problems: every failure is reported as a string that
    starts with "[GPT-3.5 Error]", so callers can show it directly.

    Args:
        prompt: User message sent as the only chat message.
        max_tokens: Completion length cap. The default of 60 matches the
            original behavior, but it can truncate longer explanations.
        temperature: Sampling temperature (default 0.3, as before).
        model_name: Chat model to query (default "gpt-3.5-turbo", as before).

    Returns:
        The stripped reply text, or an "[GPT-3.5 Error] ..." string.
    """
    # The module-level client may be None when OPENAI_API_KEY is not configured.
    if client is None:
        return "[GPT-3.5 Error] OPENAI_API_KEY is not set"
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:  # API boundary: report the error instead of crashing the app
        return f"[GPT-3.5 Error] {str(e)}"
    # message.content is Optional; report an empty reply explicitly rather than
    # failing with an opaque AttributeError on .strip().
    if content is None:
        return "[GPT-3.5 Error] Empty response"
    return content.strip()
def get_logits(output):
    """Pull the logits out of a model forward() result.

    Accepted result shapes, checked in this order:
      * a tuple: element 1 is returned (treated as ``(loss, logits, ...)``);
      * an object exposing a ``logits`` attribute: that attribute;
      * anything else: returned unchanged (assumed to already be logits).

    NOTE(review): index 1 assumes the tuple always carries a leading loss
    entry. If the model omits loss when no labels are passed, logits would
    sit at index 0. Confirm against EvoTransformer's forward().
    """
    if not isinstance(output, tuple):
        # Output objects expose .logits; a bare tensor falls through unchanged.
        return getattr(output, "logits", output)
    return output[1]  # (loss, logits)
def _evo_score(goal, option):
    """Return EvoTransformer's log-probability of class 1 for ``"{goal} [SEP] {option}"``.

    Raises:
        ValueError: if the model's logits have fewer than 2 classes.
    """
    text = f"{goal} [SEP] {option}"
    enc = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    # Remove token_type_ids to avoid crash in EvoTransformer
    enc.pop("token_type_ids", None)
    enc = {name: tensor.to(device) for name, tensor in enc.items()}
    with torch.no_grad():
        logits = get_logits(model(**enc))
    if logits.shape[-1] < 2:
        raise ValueError("Model did not return 2-class logits.")
    # Softmax is invariant to adding a constant to one input's logits, so a raw
    # class-1 logit is not comparable between two different inputs. Normalize
    # per input and compare log P(class 1) instead.
    # Assumes class 1 is the "option fits the goal" label, as the original scoring did.
    return torch.log_softmax(logits, dim=-1)[0][1].item()


def generate_response(goal, option1, option2):
    """Pick the better of two options for a goal with EvoTransformer and GPT-3.5.

    Args:
        goal: The goal / question text.
        option1: First candidate answer.
        option2: Second candidate answer.

    Returns:
        dict with:
          "evo_suggestion": the option EvoTransformer scores higher (ties go to
              option2), or an "[Evo Error] ..." string if scoring failed;
          "gpt_suggestion": the GPT-3.5 helper's reply text or its error string.
    """
    try:
        score1 = _evo_score(goal, option1)
        score2 = _evo_score(goal, option2)
        evo_result = option1 if score1 > score2 else option2
    except Exception as e:  # keep the GPT comparison available even if Evo fails
        evo_result = f"[Evo Error] {str(e)}"
    # GPT-3.5 comparison
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)
    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }