import os
import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
from openai import OpenAI

# === Load EvoTransformer ===
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# === Tokenizer ===
tokenizer = BertTokenizer.from_pretrained("trained_model")

# === GPT-3.5 Client ===
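# Assumes OPENAI_API_KEY is exported in the environment before launch.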
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def query_gpt35(prompt):
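    """Send the prompt to GPT-3.5 and return its reply text,
    or an error string instead of raising."""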
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=60,
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[GPT-3.5 Error] {str(e)}"

def get_logits(output):
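    """Extract logits whether the model returns a (loss, logits) tuple,
    an object with a .logits attribute, or raw logits."""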
    if isinstance(output, tuple):
        return output[1]  # (loss, logits)
    elif hasattr(output, 'logits'):
        return output.logits
    else:
        return output  # raw logits

def generate_response(goal, option1, option2):
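    """Score each (goal, option) pair with the Evo classifier, pick the
    option with the higher class-1 logit, then ask GPT-3.5 the same
    question for a side-by-side comparison."""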
    try:
        # Build one "goal [SEP] option" input per candidate
        text1 = f"{goal} [SEP] {option1}"
        text2 = f"{goal} [SEP] {option2}"

        enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
        enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)

        # Remove token_type_ids to avoid crash in EvoTransformer
        enc1.pop("token_type_ids", None)
        enc2.pop("token_type_ids", None)

        # Move all tensors to the target device in one call
        enc1 = enc1.to(device)
        enc2 = enc2.to(device)

        with torch.no_grad():
            out1 = model(**enc1)
            out2 = model(**enc2)

        logits1 = get_logits(out1)
        logits2 = get_logits(out2)

        if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
            raise ValueError("Model did not return 2-class logits.")

        score1 = logits1[0][1].item()
        score2 = logits2[0][1].item()
        evo_result = option1 if score1 > score2 else option2

    except Exception as e:
        evo_result = f"[Evo Error] {str(e)}"

    # GPT-3.5 comparison
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result
    }
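
# === Example usage (illustrative sketch) ===
# The goal/option strings below are invented; they only show the call shape.
# Running this requires the "trained_model" directory from training and
# OPENAI_API_KEY set in the environment.
if __name__ == "__main__":
    result = generate_response(
        goal="Boil water quickly",
        option1="Use an electric kettle",
        option2="Leave the pot out in the sun",
    )
    print("Evo suggests:", result["evo_suggestion"])
    print("GPT-3.5 says:", result["gpt_suggestion"])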