File size: 1,942 Bytes
fd1679b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from unsloth import FastLanguageModel
import torch

# Load the fine-tuned context-relevance classifier and its tokenizer once at
# module level; classify_answer() reads both as globals.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,  # NOTE(review): presumably lets Unsloth pick the dtype -- confirm
    load_in_4bit=True,  # NOTE(review): presumably loads 4-bit quantized weights -- confirm
)

# Switch the loaded model into Unsloth's inference mode before any
# generate() call is made.
FastLanguageModel.for_inference(model)

def classify_answer(question, answer, context):
    """Classify whether ``answer`` was generated from ``context``.

    Formats the three inputs into a chat-style prompt (system instruction,
    user turn, empty assistant turn), lets the module-level ``model``
    generate up to five new tokens greedily, and looks for "YES" in the
    generated reply.

    Args:
        question: The question that was asked.
        answer: The answer whose provenance is being checked.
        context: The passage the answer may have been derived from.

    Returns:
        bool: True if the generated reply contains "YES" (case-insensitive),
        False otherwise (including when the reply is empty).

    Side effects:
        Prints the full decoded sequence ("Response: ...") and the extracted
        reply ("Prediction: ...") to stdout.
    """
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Answer: {answer}

Context: {context}

Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

    # NOTE(review): the template already starts with <|begin_of_text|>; if the
    # tokenizer also prepends a BOS token the prompt begins with two -- confirm
    # this matches how the model was fine-tuned before changing it.
    # Send the inputs to wherever the model lives instead of assuming "cuda".
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Full sequence (prompt + continuation), kept for debugging output only.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Response: {response}")

    # Decode only the newly generated tokens. Splitting the full text on the
    # word "assistant" is brittle: if the model runs past its turn and emits
    # another role header, the text after the last "assistant" is empty and a
    # "YES" reply would be reported as False.
    generated_ids = outputs[0][prompt_length:]
    prediction = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Prediction: {prediction}")

    return "YES" in prediction.upper()

# Smoke test: classify one hand-written question/answer/context triple and
# print the boolean verdict. There is no __main__ guard, so this also runs
# whenever the module is imported.
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."

# True means the model's reply contained "YES".
result = classify_answer(question, answer, context)
print(f"result : {result}")