---
base_model: unsloth/llama-3.2-1b-instruct-bnb-4bit
library_name: peft
license: mit
datasets:
- axondendriteplus/context-relevance-classifier-dataset
language:
- en
pipeline_tag: text-classification
tags:
- legal
---
# llama-3.2-1b-context-relevance-classifier
This is a supervised fine-tuned version of `unsloth/llama-3.2-1b-instruct-bnb-4bit`, trained to classify whether an answer to a legal question is derived from a given context.
- **Task**: Binary classification
- **Input**: question, answer, context
- **Output**: YES (answer is from context) or NO (answer is not from context)
- **Domain**: Legal Q&A context relevance detection
## Dataset
Fine-tuned on: [axondendriteplus/context-relevance-classifier-dataset](https://huggingface.co/datasets/axondendriteplus/context-relevance-classifier-dataset)
Each sample contains:
```json
{
  "question": "What is the legal definition of contract?",
  "answer": "A contract is a legally binding agreement between two parties.",
  "context": "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations.",
  "label": 1
}
```
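To inspect the data, it can be loaded with the Hugging Face `datasets` library. The snippet below is a minimal sketch (not part of the original card) and assumes the dataset exposes a `train` split with the fields shown above.
```python
# Minimal sketch: load and inspect the dataset.
# The "train" split name is an assumption.
from datasets import load_dataset

ds = load_dataset("axondendriteplus/context-relevance-classifier-dataset")
print(ds)              # available splits and columns
print(ds["train"][0])  # {"question": ..., "answer": ..., "context": ..., "label": ...}
```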
## Training Configuration
- **Base model**: `unsloth/llama-3.2-1b-instruct-bnb-4bit`
- **Training method**: **LoRA fine-tuning using Unsloth**
- **Sequence length**: **2048**
- **Epochs**: **3**
- **Batch size**: **2** (gradient accumulation: **4**)
- **Optimizer**: **AdamW (8-bit)**
- **Learning rate**: **2e-4**
- **Weight decay**: **0.01**
- **Warmup steps**: **50**
- **LoRA config**: **r=16**, **alpha=16**, **dropout=0**
- **Target modules**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`
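For reference, the LoRA and optimizer settings above map onto an Unsloth setup roughly like the sketch below. This is an illustrative reconstruction from the listed hyperparameters, not the exact training script; the resulting `TrainingArguments` would then be passed to `trl`'s `SFTTrainer` together with the prompt-formatted dataset (see the Prompt Format section).
```python
# Illustrative sketch of the listed hyperparameters; not the exact training script.
from unsloth import FastLanguageModel
from transformers import TrainingArguments

# Load the 4-bit base model at the training sequence length
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3.2-1b-instruct-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters: r=16, alpha=16, dropout=0 on the listed projection modules
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# Optimizer / schedule settings from the list above
training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=50,
    optim="adamw_8bit",
    output_dir="outputs",
)
# training_args would then be passed to trl's SFTTrainer along with a dataset
# whose text column follows the prompt template shown in the next section.
```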
## Prompt Format
```text
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Answer: {answer}
Context: {context}
Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{YES/NO}
```
### Installation
```bash
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
```
### Usage
Basic inference (a standalone `inference_SFT.py` script is also available in the repository files):
```python
from unsloth import FastLanguageModel
import torch

# Load the fine-tuned model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/llama-3.2-1b-context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Enable inference mode
FastLanguageModel.for_inference(model)

def classify_answer(question, answer, context):
    """Classify if answer is generated from context"""
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Answer: {answer}
Context: {context}
Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prediction = response.split("assistant")[-1].strip()
    return "YES" in prediction.upper()

# Example usage
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."
result = classify_answer(question, answer, context)
print(f"Classification result: {'Relevant' if result else 'Not Relevant'}")
```
### Usage Examples
**Example 1: Relevant Answer**
```python
question = "What are the elements of a valid contract?"
answer = "A valid contract requires offer, acceptance, and consideration."
context = "For a contract to be legally binding, it must contain three essential elements: an offer, acceptance of that offer, and consideration."
# Expected output: YES (Relevant)
```
**Example 2: Irrelevant Answer**
```python
question = "What is negligence in tort law?"
answer = "A contract is a legally binding agreement."
context = "Negligence is the failure to exercise reasonable care that results in harm to another person."
# Expected output: NO (Not Relevant)
```
**Example 3: Partially Relevant**
```python
question = "What is the statute of limitations?"
answer = "It's a time limit for filing lawsuits, typically 2-3 years."
context = "The statute of limitations is a law that sets the maximum time after an event within which legal proceedings may be initiated."
# Expected output: YES (Relevant)
```
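To run all three examples at once, a short loop like the sketch below can reuse the `classify_answer` function defined in the Usage section.
```python
# Sketch: check the three examples above with classify_answer() from the Usage section.
examples = [
    ("What are the elements of a valid contract?",
     "A valid contract requires offer, acceptance, and consideration.",
     "For a contract to be legally binding, it must contain three essential elements: an offer, acceptance of that offer, and consideration."),
    ("What is negligence in tort law?",
     "A contract is a legally binding agreement.",
     "Negligence is the failure to exercise reasonable care that results in harm to another person."),
    ("What is the statute of limitations?",
     "It's a time limit for filing lawsuits, typically 2-3 years.",
     "The statute of limitations is a law that sets the maximum time after an event within which legal proceedings may be initiated."),
]

for question, answer, context in examples:
    relevant = classify_answer(question, answer, context)
    print(f"{question} -> {'YES' if relevant else 'NO'}")
```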
The model is trained to distinguish between:
- **Positive examples**: answers that are derived from or supported by the given context
- **Negative examples**: answers that are not supported by the context (generated with GPT-4o-mini or by pairing questions with mismatched contexts)
### Training Dataset Composition
- **Total examples**: ~3.2K
- **Positive examples**: ~1.6K (answers generated from context)
- **Negative examples**: ~1.6K (wrong answers or mismatched contexts)
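The class balance above can be checked directly against the published dataset; the snippet below is a sketch that assumes a `train` split with an integer `label` column (1 = positive).
```python
# Sketch: verify the roughly 50/50 label split of the published dataset.
# Assumes a "train" split and an integer "label" column (1 = positive).
from collections import Counter
from datasets import load_dataset

ds = load_dataset("axondendriteplus/context-relevance-classifier-dataset")["train"]
print(Counter(ds["label"]))  # expected: roughly 1.6K examples per class
```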