import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI
# === Load EvoTransformer ===
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# === Tokenizer ===
tokenizer = BertTokenizer.from_pretrained("trained_model")

# === GPT-3.5 Client ===
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # requires OPENAI_API_KEY in the environment

def query_gpt35(prompt):
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=60,
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[GPT-3.5 Error] {e}"

def get_logits(output):
    # Normalize the model's return value to raw logits
    if isinstance(output, tuple):
        return output[1]  # (loss, logits)
    elif hasattr(output, "logits"):
        return output.logits
    else:
        return output  # already raw logits

def generate_response(goal, option1, option2):
    try:
        # Format each option as a "goal [SEP] option" pair
        text1 = f"{goal} [SEP] {option1}"
        text2 = f"{goal} [SEP] {option2}"
        enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
        enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)

        # Remove token_type_ids to avoid a crash in EvoTransformer
        enc1.pop("token_type_ids", None)
        enc2.pop("token_type_ids", None)

        # Move tensors to the model's device
        for k in enc1:
            enc1[k] = enc1[k].to(device)
            enc2[k] = enc2[k].to(device)

        with torch.no_grad():
            out1 = model(**enc1)
            out2 = model(**enc2)

        logits1 = get_logits(out1)
        logits2 = get_logits(out2)

        if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
            raise ValueError("Model did not return 2-class logits.")

        # Score each option by its class-1 logit and pick the higher one
        score1 = logits1[0][1].item()
        score2 = logits2[0][1].item()
        evo_result = option1 if score1 > score2 else option2
    except Exception as e:
        evo_result = f"[Evo Error] {e}"

    # GPT-3.5 comparison on the same goal and options
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }
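
# --- Example usage (illustrative sketch: the goal/option strings below are
# hypothetical, and running this assumes OPENAI_API_KEY is set and a local
# "trained_model" checkpoint exists) ---
if __name__ == "__main__":
    demo = generate_response(
        goal="Boil water quickly",
        option1="Use an electric kettle",
        option2="Leave the pot out in the sun",
    )
    print("Evo suggestion:", demo["evo_suggestion"])
    print("GPT-3.5 suggestion:", demo["gpt_suggestion"])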