---
base_model: unsloth/llama-3.2-1b-instruct-bnb-4bit
library_name: peft
license: mit
datasets:
- axondendriteplus/context-relevance-classifier-dataset
language:
- en
pipeline_tag: text-classification
tags:
- legal
---

# llama-3.2-1b-context-relevance-classifier

This is a supervised fine-tuned version of `unsloth/llama-3.2-1b-instruct-bnb-4bit`, trained to classify whether an answer to a legal question is derived from a given context.

- **Task**: Binary classification
- **Input**: question, answer, context  
- **Output**: YES (answer is from context) or NO (answer is not from context)
- **Domain**: Legal Q&A context relevance detection

## Dataset

Fine-tuned on: [axondendriteplus/context-relevance-classifier-dataset](https://huggingface.co/datasets/axondendriteplus/context-relevance-classifier-dataset)

Each sample contains:
```json
{
 "question": "What is the legal definition of contract?",
 "answer": "A contract is a legally binding agreement between two parties.",
 "context": "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations.",
 "label": 1
}
```

## Training Configuration

- **Base model**: `unsloth/llama-3.2-1b-instruct-bnb-4bit`  
- **Training method**: **LoRA fine-tuning using Unsloth**  
- **Sequence length**: **2048**  
- **Epochs**: **3**  
- **Batch size**: **2** (gradient accumulation: **4**)  
- **Optimizer**: **AdamW (8-bit)**  
- **Learning rate**: **2e-4**  
- **Weight decay**: **0.01**  
- **Warmup steps**: **50**  
- **LoRA config**: **r=16**, **alpha=16**, **dropout=0**  
- **Target modules**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`
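
As a rough sanity check on the settings above, the effective batch size and the approximate number of optimizer steps can be worked out as follows. This is only a sketch: it assumes a single device and that all ~3.2K examples (the figure quoted under the training dataset composition) are used for training. Neither assumption is confirmed by this card, so treat the step counts as estimates.

```python
import math

# Values taken from the training configuration listed above
per_device_batch_size = 2
gradient_accumulation_steps = 4
num_epochs = 3
warmup_steps = 50

# Assumptions (not stated in this card)
num_examples = 3200  # ~3.2K examples, all used for training
num_devices = 1      # single GPU

# Examples consumed per optimizer update
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices

# Optimizer updates per epoch and over the whole run
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs

# Share of training spent in learning-rate warmup
warmup_fraction = warmup_steps / total_steps

print(f"effective batch size: {effective_batch_size}")
print(f"steps per epoch:      {steps_per_epoch}")
print(f"total steps:          {total_steps}")
print(f"warmup fraction:      {warmup_fraction:.1%}")
```

Under these assumptions, warmup covers only a small share of the run, so most updates happen at or near the full learning rate of 2e-4.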

## Prompt Format
```text
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Answer: {answer}

Context: {context}

Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{YES/NO}
```
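
The template above can be filled in programmatically. The helper below is a minimal sketch: `build_prompt` and `SYSTEM_PROMPT` are illustrative names rather than part of this repository, and mapping `label == 1` to `YES` is an assumption based on the dataset sample shown earlier. Whether training appended a closing `<|eot_id|>` after the label is not stated in this card, so the sketch leaves it out.

```python
from typing import Optional

# System instruction copied from the prompt format above
SYSTEM_PROMPT = (
    "You are a context relevance classifier. Given a question, answer, and context, "
    "determine if the answer was generated from the given context. "
    'Respond with either "YES" if the answer is derived from the context, '
    'or "NO" if it is not.'
)


def build_prompt(question: str, answer: str, context: str,
                 label: Optional[int] = None) -> str:
    """Fill in the prompt template.

    With label=None the prompt ends at the assistant header (inference).
    With a label, the target word is appended (training-style text).
    """
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{SYSTEM_PROMPT}\n\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"Question: {question}\n\n"
        f"Answer: {answer}\n\n"
        f"Context: {context}\n\n"
        "Was this answer generated from the given context? "
        "Respond with YES or NO only.\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )
    if label is not None:
        # Assumption: label 1 means "answer is from context" (YES), anything else NO
        prompt += "YES" if label == 1 else "NO"
    return prompt


sample = {
    "question": "What is the legal definition of contract?",
    "answer": "A contract is a legally binding agreement between two parties.",
    "context": "Contract law defines a contract as an agreement between two or more "
               "parties that creates legally enforceable obligations.",
    "label": 1,
}
print(build_prompt(**sample))
```

Keeping the template in one function helps ensure that inference prompts match the training text exactly.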

### Installation
```bash
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
```

### Usage
Basic inference. You can also use `inference_SFT.py` from the repository files.

```python
from unsloth import FastLanguageModel
import torch

# Load the fine-tuned model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/llama-3.2-1b-context-relevance-classifier",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Enable inference mode
FastLanguageModel.for_inference(model)

def classify_answer(question, answer, context):
    """Classify if answer is generated from context"""
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Answer: {answer}

Context: {context}

Was this answer generated from the given context? Respond with YES or NO only.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    
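    # Decode only the newly generated tokens so that text in the prompt
    # (e.g. the word "assistant" inside the context) cannot affect parsing
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()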
    
    return "YES" in prediction.upper()

# Example usage
question = "What is the legal definition of contract?"
answer = "A contract is a legally binding agreement between two parties."
context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."

result = classify_answer(question, answer, context)
print(f"Classification result: {'Relevant' if result else 'Not Relevant'}")
```

### Usage Examples
Example 1: Relevant Answer
```python
question = "What are the elements of a valid contract?"
answer = "A valid contract requires offer, acceptance, and consideration."
context = "For a contract to be legally binding, it must contain three essential elements: an offer, acceptance of that offer, and consideration."
# Expected output: YES (Relevant)
```
Example 2: Irrelevant Answer
```python
question = "What is negligence in tort law?"
answer = "A contract is a legally binding agreement."
context = "Negligence is the failure to exercise reasonable care that results in harm to another person."
# Expected output: NO (Not Relevant)
```
Example 3: Partially Relevant
```python
question = "What is the statute of limitations?"
answer = "It's a time limit for filing lawsuits, typically 2-3 years."
context = "The statute of limitations is a law that sets the maximum time after an event within which legal proceedings may be initiated."
# Expected output: YES (Relevant)
```
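
To run several examples like these in one pass, a small harness can compare predictions with the expected labels. The sketch below accepts any callable with the same signature as `classify_answer` from the Usage section. The `evaluate` helper and the always-YES baseline are hypothetical, included only so that the snippet runs without a GPU or the model.

```python
from typing import Callable, Dict, List

Example = Dict[str, object]


def evaluate(classify: Callable[[str, str, str], bool],
             examples: List[Example]) -> float:
    """Return the fraction of examples whose prediction matches `expected`."""
    if not examples:
        raise ValueError("examples must not be empty")
    correct = 0
    for ex in examples:
        predicted = classify(ex["question"], ex["answer"], ex["context"])
        correct += int(predicted == ex["expected"])
    return correct / len(examples)


# The three examples above, with their expected outputs as booleans
examples = [
    {
        "question": "What are the elements of a valid contract?",
        "answer": "A valid contract requires offer, acceptance, and consideration.",
        "context": "For a contract to be legally binding, it must contain three "
                   "essential elements: an offer, acceptance of that offer, and "
                   "consideration.",
        "expected": True,
    },
    {
        "question": "What is negligence in tort law?",
        "answer": "A contract is a legally binding agreement.",
        "context": "Negligence is the failure to exercise reasonable care that "
                   "results in harm to another person.",
        "expected": False,
    },
    {
        "question": "What is the statute of limitations?",
        "answer": "It's a time limit for filing lawsuits, typically 2-3 years.",
        "context": "The statute of limitations is a law that sets the maximum time "
                   "after an event within which legal proceedings may be initiated.",
        "expected": True,
    },
]


def always_yes(question: str, answer: str, context: str) -> bool:
    # Hypothetical placeholder baseline, NOT the model: always predicts YES
    return True


print(f"baseline accuracy: {evaluate(always_yes, examples):.2f}")
```

With the model loaded, passing `classify_answer` in place of `always_yes` gives the model's accuracy on the same examples.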

The model is trained to distinguish between:

- **Positive examples**: Answers that are derived from or supported by the given context
- **Negative examples**: Answers that are not supported by the context (wrong answers generated using GPT-4o-mini, or answers paired with a mismatched context)

### Training dataset composition

- **Total examples**: ~3.2K
- **Positive examples**: ~1.6K (answers generated from context)
- **Negative examples**: ~1.6K (wrong answers or mismatched contexts)