---
base_model: unsloth/llama-3.2-1b-instruct-bnb-4bit
library_name: peft
license: mit
datasets:
- axondendriteplus/context-relevance-classifier-dataset
language:
- en
pipeline_tag: text-classification
tags:
- legal
---

# llama-3.2-1b-context-relevance-classifier

This model is a supervised fine-tune (LoRA) of `unsloth/llama-3.2-1b-instruct-bnb-4bit`, trained to classify whether the answer to a legal question is derived from a given context.

- **Task**: Binary classification
- **Input**: question, answer, context
- **Output**: `YES` (the answer is derived from the context) or `NO` (it is not)
- **Domain**: Legal Q&A context relevance detection

## Dataset

Fine-tuned on: [axondendriteplus/context-relevance-classifier-dataset](https://huggingface.co/datasets/axondendriteplus/context-relevance-classifier-dataset)

Each sample contains:

```json
{
  "question": "What is the legal definition of contract?",
  "answer": "A contract is a legally binding agreement between two parties.",
  "context": "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations.",
  "label": 1
}
```

## Training Configuration

- **Base model**: `unsloth/llama-3.2-1b-instruct-bnb-4bit`
- **Training method**: **LoRA fine-tuning using Unsloth**
- **Sequence length**: **2048**
- **Epochs**: **3**
- **Batch size**: **2** (gradient accumulation: **4**)
- **Optimizer**: **AdamW (8-bit)**
- **Learning rate**: **2e-4**
- **Weight decay**: **0.01**
- **Warmup steps**: **50**
- **LoRA config**: **r=16**, **alpha=16**, **dropout=0**
- **Target modules**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`

## Prompt Format

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Answer: {answer}

Context: {context}

Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{YES/NO}
```

## Installation

```bash
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
```
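
Because the pinned versions above are fairly old and Unsloth expects a CUDA GPU, it can help to confirm the environment before loading the model. This is a minimal sketch using only the standard library plus an optional `torch` check; the function name `check_environment` is hypothetical.

```python
import importlib.util

REQUIRED = ["unsloth", "torch", "transformers", "peft", "trl",
            "accelerate", "bitsandbytes", "xformers"]


def check_environment(packages=REQUIRED):
    """Report which packages are importable and whether a CUDA GPU is visible."""
    # find_spec locates a top-level package without importing it
    report = {name: importlib.util.find_spec(name) is not None for name in packages}
    cuda = False
    if report.get("torch"):
        import torch  # only imported if it is installed
        cuda = torch.cuda.is_available()
    report["cuda_available"] = cuda
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:15s} {'ok' if ok else 'MISSING'}")
```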

## Usage

Basic inference (a standalone script, `inference_SFT.py`, is also provided in the repository files):

```python
from unsloth import FastLanguageModel
import torch

# Load the fine-tuned model in 4-bit, as during training
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/llama-3.2-1b-context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,  # auto-detect: bfloat16 where supported, otherwise float16
    load_in_4bit=True,
)

# Enable Unsloth's faster inference mode
FastLanguageModel.for_inference(model)


def classify_answer(question, answer, context):
    """Return True if the model judges the answer to be derived from the context."""
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Answer: {answer}

Context: {context}

Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,  # greedy decoding for a deterministic label
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, so text in the prompt
    # (e.g. the word "assistant" inside a context) cannot leak into the prediction
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return "YES" in prediction.upper()


# Example usage
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."

result = classify_answer(question, answer, context)
print(f"Classification result: {'Relevant' if result else 'Not Relevant'}")
```
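
`classify_answer` treats any output that does not contain `YES` as a `NO`, which silently hides malformed generations. A stricter parser, sketched below as a hypothetical helper (`parse_label` is not part of the repository), returns `None` for anything that is not a clear YES/NO so those cases can be logged or retried.

```python
import re
from typing import Optional


def parse_label(generated: str) -> Optional[bool]:
    """Map generated text to True (YES), False (NO), or None (unclear)."""
    # Look only at the first alphabetic run, ignoring punctuation like "YES." or "**NO**"
    match = re.search(r"[A-Za-z]+", generated)
    if match is None:
        return None
    word = match.group(0).upper()
    if word == "YES":
        return True
    if word == "NO":
        return False
    return None  # e.g. "Maybe", "NOT sure", "yesterday"


for text in [" YES", "No.", "**NO**", "Maybe", ""]:
    print(repr(text), "->", parse_label(text))
```

Inside `classify_answer`, `return parse_label(prediction)` could replace the substring test, at the cost of callers having to handle `None`.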

### Usage Examples

Example 1: Relevant answer

```python
question = "What are the elements of a valid contract?"
answer = "A valid contract requires offer, acceptance, and consideration."
context = "For a contract to be legally binding, it must contain three essential elements: an offer, acceptance of that offer, and consideration."
print(classify_answer(question, answer, context))  # Expected output: True (YES, relevant)
```

Example 2: Irrelevant answer

```python
question = "What is negligence in tort law?"
answer = "A contract is a legally binding agreement."
context = "Negligence is the failure to exercise reasonable care that results in harm to another person."
print(classify_answer(question, answer, context))  # Expected output: False (NO, not relevant)
```

Example 3: Partially relevant answer

```python
question = "What is the statute of limitations?"
answer = "It's a time limit for filing lawsuits, typically 2-3 years."
context = "The statute of limitations is a law that sets the maximum time after an event within which legal proceedings may be initiated."
print(classify_answer(question, answer, context))  # Expected output: True (YES, relevant)
```
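
To check the model against labelled cases like the ones above, a small accuracy loop is enough. The sketch accepts any classifier callable with the same signature as `classify_answer`; the `evaluate` name and the `always_yes` stub are hypothetical, used only so the loop can run without a GPU.

```python
from typing import Callable, Iterable, Tuple

# (question, answer, context, expected_is_relevant)
Example = Tuple[str, str, str, bool]


def evaluate(classify: Callable[[str, str, str], bool],
             examples: Iterable[Example]) -> dict:
    """Count correct predictions and collect mismatches for inspection."""
    correct, total, errors = 0, 0, []
    for question, answer, context, expected in examples:
        predicted = classify(question, answer, context)
        total += 1
        if predicted == expected:
            correct += 1
        else:
            errors.append((question, expected, predicted))
    accuracy = correct / total if total else 0.0
    return {"accuracy": accuracy, "correct": correct, "total": total, "errors": errors}


examples = [
    ("What are the elements of a valid contract?",
     "A valid contract requires offer, acceptance, and consideration.",
     "For a contract to be legally binding, it must contain three essential elements: "
     "an offer, acceptance of that offer, and consideration.",
     True),
    ("What is negligence in tort law?",
     "A contract is a legally binding agreement.",
     "Negligence is the failure to exercise reasonable care that results in harm to another person.",
     False),
]


# Hypothetical stub for demonstration; pass classify_answer to evaluate the real model
def always_yes(question: str, answer: str, context: str) -> bool:
    return True


print(evaluate(always_yes, examples)["accuracy"])  # 0.5
```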

The model is trained to distinguish between:

- **Positive examples**: answers that are derived from, or supported by, the given context
- **Negative examples**: answers that are not supported by the context, either generated with GPT-4o-mini or created by pairing an answer with a mismatched context

### Training dataset composition

- **Total examples**: ~3.2K
- **Positive examples**: ~1.6K (answers generated from the context)
- **Negative examples**: ~1.6K (wrong answers or mismatched contexts)