"""Compare two candidate options for a goal using EvoTransformer and GPT-3.5.

Importing this module has side effects: it loads the EvoTransformer
classifier and the BERT tokenizer from ``trained_model``, moves the model to
CUDA when available (CPU otherwise), and constructs an OpenAI client using
the ``OPENAI_API_KEY`` environment variable.
"""

import os

import torch
from openai import OpenAI
from transformers import BertTokenizer

from evo_model import EvoTransformerForClassification

# === Load EvoTransformer ===
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference only (disables training-time behaviour such as dropout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# === Tokenizer ===
tokenizer = BertTokenizer.from_pretrained("trained_model")

# === GPT-3.5 Client ===
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Fixed sequence length every "goal [SEP] option" input is padded/truncated to.
_MAX_LENGTH = 128


def query_gpt35(prompt):
    """Send *prompt* to ``gpt-3.5-turbo`` and return the stripped reply text.

    Never raises: any exception is converted into a string of the form
    ``"[GPT-3.5 Error] <message>"`` so callers always receive text.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=60,  # short cap: longer explanations will be cut off
            temperature=0.3,
        )
        # The SDK types message.content as Optional[str]; treat a missing
        # body as an empty reply instead of failing on None.strip().
        content = response.choices[0].message.content
        return (content or "").strip()
    except Exception as e:  # deliberate best-effort: report failure as text
        return f"[GPT-3.5 Error] {e}"


def get_logits(output):
    """Extract the logits from a model forward-pass result.

    Handles three shapes of *output*:

    * a tuple - treated as ``(loss, logits, ...)``; a 1-tuple is ``(logits,)``;
    * an object exposing a ``.logits`` attribute;
    * anything else, which is assumed to already be the logits tensor.
    """
    if isinstance(output, tuple):
        # NOTE(review): index 1 assumes EvoTransformer returns (loss, logits);
        # confirm against evo_model. HF-style models called without labels
        # return (logits, ...) instead. A 1-tuple can only hold the logits.
        return output[1] if len(output) > 1 else output[0]
    if hasattr(output, "logits"):
        return output.logits
    return output  # raw logits


def _score_option(goal, option):
    """Return EvoTransformer's class-1 logit for the pair (*goal*, *option*).

    Raises:
        ValueError: if the model's logits have fewer than 2 classes.
    """
    # "[SEP]" in the text presumably mirrors the training-time input format;
    # BertTokenizer maps it to its separator token. TODO confirm.
    enc = tokenizer(
        f"{goal} [SEP] {option}",
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=_MAX_LENGTH,
    )
    # Remove token_type_ids to avoid crash in EvoTransformer
    enc.pop("token_type_ids", None)
    inputs = {name: tensor.to(device) for name, tensor in enc.items()}

    with torch.no_grad():
        logits = get_logits(model(**inputs))

    if logits.shape[-1] < 2:
        raise ValueError("Model did not return 2-class logits.")
    return logits[0][1].item()


def generate_response(goal, option1, option2):
    """Ask EvoTransformer and GPT-3.5 which of two options better serves *goal*.

    Returns:
        A dict with two keys:

        * ``"evo_suggestion"`` - whichever option got the higher class-1 logit
          (ties go to *option2*), or an ``"[Evo Error] ..."`` string if
          tokenization or scoring failed;
        * ``"gpt_suggestion"`` - GPT-3.5's free-text answer, or its
          ``"[GPT-3.5 Error] ..."`` string.
    """
    try:
        # NOTE(review): raw logits from two independent forward passes are
        # compared directly; consider whether softmax probabilities (or the
        # class-1 minus class-0 margin) would be a fairer comparison.
        score1 = _score_option(goal, option1)
        score2 = _score_option(goal, option2)
        evo_result = option1 if score1 > score2 else option2
    except Exception as e:  # deliberate best-effort: surface error in result
        evo_result = f"[Evo Error] {e}"

    # GPT-3.5 comparison
    prompt = (
        f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\n"
        "Which is better and why?"
    )
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }