# TRUTHLENS / src/classifier.py
# Last commit: "Update src/classifier.py" by akaafridi (9381a8f, verified)
"""
classifier.py
-------------
This module defines utilities for classifying the relationship between a
claim and candidate sentences. It tries to use a transformers NLI
cross-encoder; if that fails, it falls back to a lightweight heuristic.
Labels:
- "support" (entailment)
- "contradict" (contradiction)
- "neutral"
"""
from __future__ import annotations
import logging
from typing import Iterable, List
import numpy as np
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lazily initialised NLI state. All three names start out "empty", are filled
# in by _load_nli_model() on a successful load, and are reset to these values
# by _load_nli_model() / classify() when loading or inference fails.
_nli_model = None # type: ignore
_nli_tokenizer = None # type: ignore
_use_transformers = False # whether NLI model is successfully loaded
def _load_nli_model(model_name: str = "cross-encoder/nli-roberta-base"):
    """Populate the module-level NLI model/tokenizer on first use.

    Does nothing when a model and tokenizer are already loaded and flagged
    usable. Otherwise tries to load ``model_name`` through ``transformers``.
    Any exception is logged as a warning and leaves the module state cleared
    (model/tokenizer ``None``, ``_use_transformers`` ``False``) so that the
    caller can fall back to the heuristic classifier.
    """
    global _nli_model, _nli_tokenizer, _use_transformers

    already_loaded = (
        _use_transformers
        and _nli_model is not None
        and _nli_tokenizer is not None
    )
    if already_loaded:
        return

    try:
        # Imported lazily so the module stays importable without transformers.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model.eval()  # inference only
    except Exception as exc:
        logger.warning(
            "Failed to load NLI model '%s'. Falling back to heuristic: %s",
            model_name,
            exc,
        )
        _nli_model, _nli_tokenizer, _use_transformers = None, None, False
    else:
        # Commit the new state only once every loading step has succeeded.
        _nli_tokenizer, _nli_model, _use_transformers = tokenizer, model, True
def _resolve_nli_id2label(model) -> dict:
    """Return a ``{class_index: label}`` map for ``model``'s output logits.

    Uses ``model.config.id2label`` when it names all three NLI classes
    (contradiction / entailment / neutral, case-insensitive); otherwise falls
    back to the order used by ``cross-encoder/nli-roberta-base``:
    ``[contradiction, entailment, neutral]``.
    """
    default = {0: "contradict", 1: "support", 2: "neutral"}
    aliases = {
        "contradiction": "contradict",
        "entailment": "support",
        "neutral": "neutral",
    }
    raw = getattr(getattr(model, "config", None), "id2label", None)
    if not isinstance(raw, dict):
        return default
    resolved = {}
    for index, name in raw.items():
        label = aliases.get(str(name).strip().lower())
        if label is None:
            # Generic names such as "LABEL_0" carry no information; keep default.
            return default
        try:
            resolved[int(index)] = label
        except (TypeError, ValueError):
            return default
    if set(resolved.values()) != set(default.values()):
        return default
    return resolved


def _classify_with_nli(claim: str, sentences: List[str], batch_size: int = 16) -> List[str]:
    """Classify each sentence against ``claim`` with the loaded NLI cross-encoder.

    Args:
        claim: The claim, used as the first element of every input pair.
        sentences: Candidate sentences, used as the second element of each pair.
        batch_size: Number of (claim, sentence) pairs per forward pass; must be >= 1.

    Returns:
        One of "support", "contradict" or "neutral" per sentence, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (a negative step would
            otherwise make the batching loop empty and return too few labels).
        RuntimeError: If the module-level model/tokenizer have not been loaded.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Explicit check instead of ``assert`` so it also holds under ``python -O``.
    if _nli_model is None or _nli_tokenizer is None:
        raise RuntimeError("NLI model is not loaded; call _load_nli_model() first")

    import torch  # type: ignore

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _nli_model.to(device)

    # Prefer the label order declared by the model; default order otherwise.
    id2label = _resolve_nli_id2label(_nli_model)

    labels_out: List[str] = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start : start + batch_size]
        enc = _nli_tokenizer(
            [claim] * len(batch),
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)
        with torch.no_grad():
            logits = _nli_model(**enc).logits.cpu().numpy()
        preds = logits.argmax(axis=1)
        # Unknown class indices are treated as "neutral".
        labels_out.extend(id2label.get(int(p), "neutral") for p in preds)
    return labels_out
def _heuristic_classify(claim: str, sentences: List[str]) -> List[str]:
    """Label sentences with a simple lexical-overlap + negation heuristic.

    A sentence that shares at least one word with ``claim`` is labelled
    "contradict" when the sentence contains a negation cue and "support"
    otherwise; a sentence with no shared word is "neutral". Only the
    sentence is inspected for negation, not the claim.

    Args:
        claim: The claim text.
        sentences: Candidate sentences to label.

    Returns:
        One label per sentence, in input order.
    """
    import re

    word_re = re.compile(r"\b\w+\b")
    # "n't" contractions (isn't, can't, ...) must be matched on the raw text:
    # word tokenisation splits them at the apostrophe ("isn", "t"), so a
    # token-set lookup can never see them. Handles ASCII and curly apostrophes.
    contraction_re = re.compile(r"n['\u2019]t\b")
    negation_words = {"not", "no", "never", "none", "cannot"}

    claim_tokens = set(word_re.findall(claim.lower()))
    out: List[str] = []
    for s in sentences:
        lowered = s.lower()
        s_tokens = set(word_re.findall(lowered))
        if not (claim_tokens & s_tokens):
            out.append("neutral")
            continue
        has_neg = bool(negation_words & s_tokens) or bool(contraction_re.search(lowered))
        out.append("contradict" if has_neg else "support")
    return out
# Set once the NLI path has failed (load or inference) so that later calls do
# not retry the expensive model load on every invocation.
_nli_disabled = False


def classify(claim: str, sentences: Iterable[str], batch_size: int = 16) -> List[str]:
    """Return a label for each sentence relative to the claim.

    Uses the NLI cross-encoder when it can be loaded and run; otherwise falls
    back to the lexical heuristic. After the first load or inference failure
    the NLI path is not retried by this function, so subsequent calls go
    straight to the heuristic.

    Args:
        claim: The claim text.
        sentences: Candidate sentences (any iterable; it is materialised once).
        batch_size: Batch size for NLI inference; must be >= 1.

    Returns:
        One of "support", "contradict" or "neutral" per sentence, in input
        order. An empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Declared up front because the failure paths below rebind these names.
    global _nli_model, _nli_tokenizer, _use_transformers, _nli_disabled

    # Reject caller errors before touching the model, so a bad argument is
    # never mistaken for a model failure.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    sents = list(sentences)
    if not sents:
        return []

    # Attempt the lazy load only while the NLI path has not failed before.
    if (_nli_model is None or _nli_tokenizer is None) and not _nli_disabled:
        _load_nli_model()
        if not _use_transformers:
            _nli_disabled = True

    if _use_transformers and _nli_model is not None and _nli_tokenizer is not None:
        try:
            return _classify_with_nli(claim, sents, batch_size=batch_size)
        except Exception as exc:
            logger.warning(
                "NLI classification failed; switching to heuristic. Error: %s",
                exc,
            )
            # Mark as unusable so subsequent calls go straight to heuristic
            _use_transformers = False
            _nli_model = None
            _nli_tokenizer = None
            _nli_disabled = True

    # Fallback
    return _heuristic_classify(claim, sents)