from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
import torch


class SpladeEncoder:
    batch_size = 8
    model_id = "naver/splade-v3"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        # Reverse vocabulary lookup (token id -> token string), used by token_expand().
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
        # Pick the best available device: CUDA, then Apple Silicon (MPS), then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)

    @torch.no_grad()
    def forward(self, inputs: BatchEncoding) -> Tensor:
        output = self.model(**inputs.to(self.device))
        logits: Tensor = output.logits
        mask: Tensor = inputs.attention_mask
        # SPLADE pooling: log(1 + ReLU(logits)), with padding positions masked out,
        # then max-pooled over the sequence dimension to get one vector per text.
        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
        return vec.max(dim=1)[0].squeeze()

    def encode(self, texts: list[str]) -> Tensor:
        """Forward pass to get sparse vocabulary-sized vectors.

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            Sparse vectors, one row per input text.
        """
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"):  # type: ignore
            tokens = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
            vec = self.forward(inputs=tokens)
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
        """Cosine-similarity re-ranking of documents against a query.

        Parameters
        ----------
        query : str
            Retrieval query.
        documents : list[str]
            Retrieved documents.

        Returns
        -------
        list[float]
            Cosine similarity scores, one per document.
        """
        vec = self.encode([query, *documents])
        # After L2 normalisation, the dot product equals cosine similarity.
        xQ = F.normalize(vec[:1], dim=-1, p=2.0)
        xD = F.normalize(vec[1:], dim=-1, p=2.0)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> dict[str, float]:
        """Sparse lexical token expansion.

        Parameters
        ----------
        query : str
            Retrieval query.

        Returns
        -------
        dict[str, float]
            Expanded tokens mapped to their weights, sorted by weight descending.
        """
        vec = self.encode([query]).squeeze()
        # Keep only the non-zero dimensions and map them back to vocabulary tokens.
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
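
# --- Usage sketch ---
# A minimal example of how the encoder above might be exercised: re-ranking a
# few documents against a query, then inspecting the query's lexical expansion.
# This block is not part of the original module; the sample query and documents
# are made-up placeholders for illustration.
if __name__ == "__main__":
    encoder = SpladeEncoder()

    query = "what causes rainbows"  # hypothetical example query
    docs = [
        "Rainbows appear when sunlight is refracted by water droplets.",
        "The stock market closed higher on Friday.",
    ]

    # Re-rank documents by cosine similarity to the query (highest first).
    scores = encoder.query_reranking(query, docs)
    for doc, score in sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True):
        print(f"{score:.3f}  {doc}")

    # Show the top-10 weighted tokens in SPLADE's expansion of the query.
    expansion = encoder.token_expand(query)
    print(list(expansion.items())[:10])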