|
""" |
|
classifier.py |
|
------------- |
|
|
|
This module defines utilities for classifying the relationship between a |
|
claim and candidate sentences. It tries to use a transformers NLI |
|
cross-encoder; if that fails, it falls back to a lightweight heuristic. |
|
|
|
Labels: |
|
- "support" (entailment) |
|
- "contradict" (contradiction) |
|
- "neutral" |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import logging |
|
from typing import Iterable, List |
|
|
|
import numpy as np |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lazily populated NLI state. `_load_nli_model` assigns the model and
# tokenizer on success and resets both to None on failure; `classify` also
# resets them when an NLI inference call raises.
_nli_model = None
_nli_tokenizer = None
# Set to True only after a successful model load; False means `classify`
# uses the heuristic fallback.
_use_transformers = False
|
|
|
|
|
def _load_nli_model(model_name: str = "cross-encoder/nli-roberta-base"):
    """Populate the module-level NLI model and tokenizer on first use.

    Does nothing when a model and tokenizer are already cached and the
    transformers path is enabled. Any exception raised while importing
    ``transformers`` or loading the checkpoint is logged as a warning and
    leaves the module in heuristic-fallback state: model and tokenizer set to
    ``None`` and ``_use_transformers`` set to ``False``.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
    """
    global _nli_model, _nli_tokenizer, _use_transformers

    cached = _nli_model is not None and _nli_tokenizer is not None
    if cached and _use_transformers:
        return

    try:
        # Imported lazily so the module can be imported (and the heuristic
        # used) even when transformers is not installed.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        _nli_tokenizer = AutoTokenizer.from_pretrained(model_name)
        _nli_model = AutoModelForSequenceClassification.from_pretrained(model_name)
        _nli_model.eval()
        _use_transformers = True
    except Exception as err:
        logger.warning(
            "Failed to load NLI model '%s'. Falling back to heuristic: %s",
            model_name,
            err,
        )
        # Drop any half-initialised state so callers see a clean fallback.
        _nli_model = _nli_tokenizer = None
        _use_transformers = False
|
|
|
|
|
def _nli_label_map(model) -> dict:
    """Return a ``{class_index: label}`` table for *model*.

    The table is derived from ``model.config.id2label`` when every entry there
    can be recognised as entailment, contradiction or neutral and all three
    are present. Otherwise the fixed order used by the default
    ``cross-encoder/nli-roberta-base`` checkpoint is returned.
    """
    default = {0: "contradict", 1: "support", 2: "neutral"}
    config_labels = getattr(getattr(model, "config", None), "id2label", None)
    if not isinstance(config_labels, dict):
        return default

    prefixes = (
        ("entail", "support"),
        ("contradict", "contradict"),
        ("neutral", "neutral"),
    )
    resolved = {}
    for index, name in config_labels.items():
        lowered = str(name).lower()
        label = next(
            (ours for prefix, ours in prefixes if lowered.startswith(prefix)),
            None,
        )
        if label is None:
            # Generic names such as "LABEL_0" carry no semantics; trust the
            # default order rather than guess.
            return default
        try:
            resolved[int(index)] = label
        except (TypeError, ValueError):
            return default

    if set(resolved.values()) != {"support", "contradict", "neutral"}:
        return default
    return resolved


def _classify_with_nli(claim: str, sentences: List[str], batch_size: int = 16) -> List[str]:
    """Classify each sentence against *claim* with the loaded NLI cross-encoder.

    Args:
        claim: The claim text; paired with every sentence as the first segment.
        sentences: Candidate sentences to label.
        batch_size: Number of (claim, sentence) pairs per forward pass.

    Returns:
        One of "support", "contradict" or "neutral" per input sentence, in
        input order.

    Raises:
        RuntimeError: If the module-level model or tokenizer is not loaded.
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if _nli_model is None or _nli_tokenizer is None:
        raise RuntimeError("NLI model is not loaded; call _load_nli_model() first")
    # A negative step would make range() empty and silently drop every
    # sentence; zero would raise an opaque range() error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if not sentences:
        return []

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _nli_model.to(device)

    # Class order differs between NLI checkpoints, so prefer the model's own
    # mapping over a hard-coded one.
    id2label = _nli_label_map(_nli_model)

    labels_out: List[str] = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start : start + batch_size]
        enc = _nli_tokenizer(
            [claim] * len(batch),
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(device)
        with torch.no_grad():
            logits = _nli_model(**enc).logits.cpu().numpy()
        preds = logits.argmax(axis=1)
        # Unknown indices fall back to "neutral", as before.
        labels_out.extend(id2label.get(int(p), "neutral") for p in preds)
    return labels_out
|
|
|
|
|
def _heuristic_classify(claim: str, sentences: List[str]) -> List[str]:
    """Label sentences with a simple lexical-overlap + negation heuristic.

    A sentence that shares no word with the claim is "neutral". A sentence
    that shares at least one word is "contradict" when it contains a negation
    (a negation word or an ``n't`` contraction) and "support" otherwise.
    Matching is case-insensitive.

    Args:
        claim: The claim text.
        sentences: Candidate sentences to label.

    Returns:
        One label per input sentence, in input order.
    """
    import re

    word_re = re.compile(r"\b\w+\b")
    # ``\w+`` never matches an apostrophe, so "isn't" tokenizes to "isn" and
    # "t" and an "n't" entry in the word set can never match. Contractions are
    # therefore detected on the raw text (straight or typographic apostrophe).
    contraction_re = re.compile(r"\w+n['\u2019]t\b")
    negation_words = {"not", "no", "never", "none", "cannot"}

    claim_tokens = set(word_re.findall(claim.lower()))
    labels: List[str] = []
    for sentence in sentences:
        lowered = sentence.lower()
        tokens = set(word_re.findall(lowered))
        if claim_tokens.isdisjoint(tokens):
            labels.append("neutral")
            continue
        negated = (
            not negation_words.isdisjoint(tokens)
            or contraction_re.search(lowered) is not None
        )
        labels.append("contradict" if negated else "support")
    return labels
|
|
|
|
|
def classify(claim: str, sentences: Iterable[str], batch_size: int = 16) -> List[str]:
    """Label every sentence as "support", "contradict" or "neutral" for *claim*.

    The NLI cross-encoder is loaded on demand and used when available. If it
    cannot be loaded, or if inference raises, the lexical heuristic is used
    instead. An inference failure is logged and clears the cached model state.

    Args:
        claim: The claim text.
        sentences: Candidate sentences; consumed once and materialised.
        batch_size: Batch size forwarded to the NLI classifier.

    Returns:
        One label per sentence, in input order; ``[]`` for no sentences.
    """
    global _nli_model, _nli_tokenizer, _use_transformers

    candidates = list(sentences)
    if len(candidates) == 0:
        return []

    # First use (or a previous failure): try to bring the model up.
    if _nli_model is None or _nli_tokenizer is None:
        _load_nli_model()

    nli_ready = (
        _use_transformers
        and _nli_model is not None
        and _nli_tokenizer is not None
    )
    if not nli_ready:
        return _heuristic_classify(claim, candidates)

    try:
        return _classify_with_nli(claim, candidates, batch_size=batch_size)
    except Exception as err:
        logger.warning(
            "NLI classification failed; switching to heuristic. Error: %s",
            err,
        )
        # Discard the model state so this call degrades to the heuristic.
        _use_transformers = False
        _nli_model = None
        _nli_tokenizer = None

    return _heuristic_classify(claim, candidates)
|
|