from unsloth import FastLanguageModel
import torch
# Load the fine-tuned model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,          # None = auto-detect (float16 / bfloat16)
    load_in_4bit=True,   # 4-bit quantization to cut GPU memory use
)
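# NOTE: Unsloth requires a CUDA-capable GPU; 4-bit loading relies on bitsandbytes.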
# Enable Unsloth's optimized inference mode for faster generation
FastLanguageModel.for_inference(model)
def classify_answer(question, answer, context):
    """Classify whether the answer was generated from the given context."""
    # Llama 3 chat template, written out manually to match the SFT training format
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Answer: {answer}
Context: {context}
Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,         # a YES/NO label needs only a few tokens
            use_cache=True,
            do_sample=False,          # greedy decoding for a deterministic label
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Response: {response}")
    # The decoded text still contains the prompt; the model's reply follows the
    # final "assistant" header text, which survives skip_special_tokens
    prediction = response.split("assistant")[-1].strip()
    print(f"Prediction: {prediction}")
    return "YES" in prediction.upper()
# Test the model
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."
result = classify_answer(question, answer, context)
print(f"result : {result}")