---
base_model: unsloth/llama-3.2-1b-instruct-bnb-4bit
library_name: peft
license: mit
datasets:
- axondendriteplus/context-relevance-classifier-dataset
language:
- en
pipeline_tag: text-classification
tags:
- legal
---
# llama-3.2-1b-context-relevance-classifier
This is a supervised fine-tuned version of `unsloth/llama-3.2-1b-instruct-bnb-4bit`, trained to classify whether an answer to a legal question is derived from a given context.
- **Task**: Binary classification
- **Input**: question, answer, context
- **Output**: YES (answer is from context) or NO (answer is not from context)
- **Domain**: Legal Q&A context relevance detection
## Dataset
Fine-tuned on: [axondendriteplus/context-relevance-classifier-dataset](https://huggingface.co/datasets/axondendriteplus/context-relevance-classifier-dataset)
Each sample contains:
```json
{
"question": "What is the legal definition of contract?",
"answer": "A contract is a legally binding agreement between two parties.",
"context": "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations.",
"label": 1
}
```
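During training, the integer `label` maps onto the YES/NO target used in the prompt format (1 = derived from context, 0 = not). A minimal sketch of that mapping (the helper name is ours, not part of the dataset):

```python
# Map the dataset's integer label to the YES/NO target string (hypothetical helper).
LABEL_TO_TARGET = {1: "YES", 0: "NO"}

def target_for(sample: dict) -> str:
    """Return the text target the model is trained to emit for this sample."""
    return LABEL_TO_TARGET[sample["label"]]

sample = {
    "question": "What is the legal definition of contract?",
    "answer": "A contract is a legally binding agreement between two parties.",
    "context": "Contract law defines a contract as an agreement between two "
               "or more parties that creates legally enforceable obligations.",
    "label": 1,
}
# target_for(sample) yields the string the assistant turn should contain
```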
## Training Configuration
- **Base model**: `unsloth/llama-3.2-1b-instruct-bnb-4bit`
- **Training method**: **LoRA fine-tuning using Unsloth**
- **Sequence length**: **2048**
- **Epochs**: **3**
- **Batch size**: **2** (gradient accumulation: **4**)
- **Optimizer**: **AdamW (8-bit)**
- **Learning rate**: **2e-4**
- **Weight decay**: **0.01**
- **Warmup steps**: **50**
- **LoRA config**: **r=16**, **alpha=16**, **dropout=0**
- **Target modules**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`
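The hyperparameters above can be reproduced with Unsloth's PEFT helper and standard `transformers` training arguments. This is a configuration sketch only (it assumes a CUDA environment and is not the exact training script):

```python
from unsloth import FastLanguageModel
from transformers import TrainingArguments

# Load the 4-bit base model (config sketch; requires a CUDA GPU)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3.2-1b-instruct-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters with the configuration listed above
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# Training arguments matching the listed hyperparameters
args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=50,
    optim="adamw_8bit",
    output_dir="outputs",
)
```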
## Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Answer: {answer}
Context: {context}
Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{YES/NO}
```
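The template above can be filled in programmatically before tokenization. A minimal sketch (the helper name `build_prompt` is ours):

```python
# System message from the prompt format above
SYSTEM_MSG = (
    'You are a context relevance classifier. Given a question, answer, and context, '
    'determine if the answer was generated from the given context. Respond with either '
    '"YES" if the answer is derived from the context, or "NO" if it is not.'
)

def build_prompt(question: str, answer: str, context: str) -> str:
    """Fill the Llama 3 chat template used at training time, ending at the
    assistant header so the model generates the YES/NO label."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{SYSTEM_MSG}\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"Question: {question}\n"
        f"Answer: {answer}\n"
        f"Context: {context}\n"
        "Was this answer generated from the given context? Respond with YES or NO only.\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )
```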
### Installation
```bash
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
```
### Usage
Basic inference (a ready-to-run script, `inference_SFT.py`, is also available in the repository files):
```python
from unsloth import FastLanguageModel
import torch

# Load the fine-tuned model (4-bit, as trained)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/llama-3.2-1b-context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Enable inference mode
FastLanguageModel.for_inference(model)

def classify_answer(question, answer, context):
    """Return True if the model judges the answer to be derived from the context."""
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Answer: {answer}
Context: {context}
Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")  # GPU required
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,          # greedy decoding for a deterministic label
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The generated label follows the final "assistant" header in the decoded text
    prediction = response.split("assistant")[-1].strip()
    return "YES" in prediction.upper()

# Example usage
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."
result = classify_answer(question, answer, context)
print(f"Classification result: {'Relevant' if result else 'Not Relevant'}")
```
### Usage Examples
Example 1: Relevant Answer
```python
question = "What are the elements of a valid contract?"
answer = "A valid contract requires offer, acceptance, and consideration."
context = "For a contract to be legally binding, it must contain three essential elements: an offer, acceptance of that offer, and consideration."
# Expected output: YES (Relevant)
```
Example 2: Irrelevant Answer
```python
question = "What is negligence in tort law?"
answer = "A contract is a legally binding agreement."
context = "Negligence is the failure to exercise reasonable care that results in harm to another person."
# Expected output: NO (Not Relevant)
```
Example 3: Partially Relevant
```python
question = "What is the statute of limitations?"
answer = "It's a time limit for filing lawsuits, typically 2-3 years."
context = "The statute of limitations is a law that sets the maximum time after an event within which legal proceedings may be initiated."
# Expected output: YES (Relevant)
```
The model is trained to distinguish between:
- **Positive examples**: answers that are derived from or supported by the given context
- **Negative examples**: answers that are not supported by the context (generated using GPT-4o-mini or by context mismatching)
### Training dataset composition
- **Total examples**: ~3.2K
- **Positive examples**: ~1.6K (answers generated from context)
- **Negative examples**: ~1.6K (wrong answers or mismatched contexts)