|
from unsloth import FastLanguageModel |
|
import torch |
|
|
|
|
|
# Keyword arguments for FastLanguageModel.from_pretrained. dtype=None leaves
# dtype selection to Unsloth; load_in_4bit=True requests a 4-bit quantized load.
_LOAD_OPTIONS = dict(
    model_name="axondendriteplus/context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Load the fine-tuned classifier and its tokenizer once, at import time; both
# are used as module-level globals by classify_answer().
model, tokenizer = FastLanguageModel.from_pretrained(**_LOAD_OPTIONS)

# Put the model into Unsloth's inference mode before any generate() call.
FastLanguageModel.for_inference(model)
|
|
|
def classify_answer(question, answer, context, *, verbose=True):
    """Classify whether *answer* was generated from *context*.

    Builds the classifier's hand-written chat prompt, runs greedy decoding
    with the module-level ``model`` / ``tokenizer`` and interprets the reply.

    Args:
        question: The question that was asked.
        answer: The candidate answer to check.
        context: The context passage the answer may have been derived from.
        verbose: When true (the default, as before), print the generated
            text and the parsed prediction. Note that "Response" now shows
            only the newly generated text, not the echoed prompt.

    Returns:
        True if the generated reply contains "YES" (case-insensitive),
        otherwise False.
    """
    # NOTE(review): this prompt must match the template used during
    # fine-tuning exactly. Its whitespace was reconstructed from a mangled
    # copy of this script (every line had been doubled and suffixed with
    # " |"); verify it against the training template.
    # NOTE(review): the prompt starts with <|begin_of_text|>, and the
    # tokenizer call below may prepend its own BOS token as well
    # (add_special_tokens defaults to True). Left as-is because the
    # fine-tuning pipeline may have tokenized the same way -- confirm.
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        "You are a context relevance classifier. Given a question, answer, "
        "and context, determine if the answer was generated from the given "
        'context. Respond with either "YES" if the answer is derived from '
        'the context, or "NO" if it is not.\n'
        "\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"Question: {question}\n"
        "\n"
        f"Answer: {answer}\n"
        "\n"
        f"Context: {context}\n"
        "\n"
        "Was this answer generated from the given context? "
        "Respond with YES or NO only.\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )

    # Send the inputs to wherever the model actually lives instead of
    # hard-coding "cuda".
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # The prompt ends turns with <|eot_id|>. If that token is not a stop
    # token, the model can run on into a new "<|start_header_id|>assistant"
    # header within the token budget, so stop on it as well as on EOS.
    stop_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        stop_ids.append(eot_id)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=stop_ids,
        )

    # Decode only the newly generated tokens. The old approach decoded the
    # prompt as well and split on the word "assistant"; any "assistant" text
    # inside the generation left an empty prediction, which returned False.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    prediction = response.strip()

    if verbose:
        print(f"Response: {response}")
        print(f"Prediction: {prediction}")

    return "YES" in prediction.upper()
|
|
|
|
|
if __name__ == "__main__":
    # Example run. This is guarded so that importing the module does not
    # trigger an inference call (the model itself is still loaded on import).
    question = "What is the legal definition of contract?"
    answer = "A contract is a legally binding agreement between two parties."
    context = (
        "Contract law defines a contract as an agreement between two or more "
        "parties that creates legally enforceable obligations."
    )

    result = classify_answer(question, answer, context)
    print(f"result : {result}")