---
license: apache-2.0
base_model: bert-base-uncased
tags:
- text-classification
- multi-label-classification
- bert
- conversational-qa
- educational-ai
language: en
metrics:
- f1
- precision
- recall
---

# Enhanced BERT Multi-Label Classifier for Conversational QA

## Performance

- **Micro F1**: 0.7779
- **Macro F1**: 0.7098 (micro vs. macro averaging is illustrated in the sketch after this list)
- **Optimal Threshold**: 0.20 (critical: use this instead of the default 0.5; see Usage below)
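
Micro F1 pools every label decision into one count before computing F1, while macro F1 averages the per-label F1 scores. The sketch below shows how such numbers can be computed with scikit-learn; `y_true` and `y_pred` are hypothetical placeholders, not the actual evaluation set.

```python
# Sketch: computing the card's metric types with scikit-learn.
# y_true / y_pred are illustrative placeholders, not the real eval set.
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

# Rows are examples; columns are the four labels.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 1, 1]])
y_pred = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 1]])

print("micro F1:", f1_score(y_true, y_pred, average="micro"))  # pools all label decisions
print("macro F1:", f1_score(y_true, y_pred, average="macro"))  # averages per-label F1
print("micro precision:", precision_score(y_true, y_pred, average="micro"))
print("micro recall:", recall_score(y_true, y_pred, average="micro"))
```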

## Class Performance

| Label | F1 Score | Status |
|-------|----------|--------|
| questioning | 0.933 | Excellent |
| responsive | 0.756 | Good |
| interactive | 0.613 | Good (breakthrough!) |
| collaborative | 0.537 | Acceptable |

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained("IIchukissII/enhanced-bert-coqa-multilabel-classifier")
tokenizer = AutoTokenizer.from_pretrained("IIchukissII/enhanced-bert-coqa-multilabel-classifier")

# Input format: context, question, and answer joined by [SEP]
text = "context [SEP] question [SEP] answer"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits)

# IMPORTANT: use threshold 0.20, not the default 0.5
LABELS = ["interactive", "responsive", "questioning", "collaborative"]
predictions = {label: int(probs[0, i] >= 0.20) for i, label in enumerate(LABELS)}
```
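
Continuing the example, the assigned labels are the entries of `predictions` whose flag is 1:

```python
# Collect the labels whose probability cleared the 0.20 threshold
active = [label for label, flag in predictions.items() if flag]
print(active)  # e.g. ['questioning', 'responsive']
```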

## Architecture

- Enhanced BERT with a 3-layer classifier head (a sketch of such a head follows this list)
- Layer normalization and L2 regularization
- Optimized for multi-label classification
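
The exact head definition is not shipped with this card, so the following is only a minimal sketch of what a 3-layer classifier head with layer normalization on top of `bert-base-uncased` could look like. The class name, layer widths, dropout rate, and activation are illustrative assumptions; the L2 regularization mentioned above would typically enter through the optimizer's weight decay.

```python
# Minimal sketch (assumed, not the exact training code): a 3-layer
# classifier head with layer normalization on top of bert-base-uncased.
import torch
import torch.nn as nn
from transformers import AutoModel

class EnhancedBertClassifier(nn.Module):  # hypothetical name
    def __init__(self, num_labels=4, hidden=768):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.head = nn.Sequential(            # three linear layers
            nn.Linear(hidden, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, num_labels),       # one logit per label
        )

    def forward(self, input_ids, attention_mask=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]     # [CLS] token representation
        return self.head(cls)                 # raw logits; apply sigmoid downstream

# L2 regularization is usually applied as weight decay in the optimizer, e.g.:
# torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
```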

## Training

- 7 epochs on an expanded dataset
- Breakthrough in interactive detection (+15.7% F1)
- Threshold optimization found 0.20 optimal (a sweep sketch follows this list)
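
How the 0.20 threshold was selected is not documented here; a common procedure, sketched below as an assumption, is to sweep candidate thresholds over held-out probabilities and keep the one that maximizes micro F1. `val_probs` and `val_labels` are hypothetical placeholders.

```python
# Sketch (assumed procedure): sweep decision thresholds on a validation
# split and keep the one with the best micro F1.
import numpy as np
from sklearn.metrics import f1_score

val_probs = np.random.rand(100, 4)                        # placeholder sigmoid outputs (N, num_labels)
val_labels = (np.random.rand(100, 4) > 0.5).astype(int)   # placeholder ground truth

best_t, best_f1 = 0.5, 0.0
for t in np.arange(0.05, 0.95, 0.05):
    f1 = f1_score(val_labels, (val_probs >= t).astype(int), average="micro")
    if f1 > best_f1:
        best_t, best_f1 = t, f1

print(f"best threshold: {best_t:.2f} (micro F1 = {best_f1:.4f})")
```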