from typing import Dict, List

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer


class SpladeEncoder:
    batch_size = 4

    def __init__(self):
        model_id = "naver/splade-v3"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
        # Reverse vocabulary lookup: token id -> token string.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    @torch.no_grad()
    def forward(self, texts: List[str]) -> torch.Tensor:
        """Encode texts into sparse SPLADE vectors of vocabulary size."""
        vectors = []
        for batch in tqdm(
            DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size),
            desc="Encoding",
        ):
            tokens = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
            output = self.model(**tokens)
            # SPLADE pooling: log(1 + ReLU(logits)), masked by attention,
            # then max over the sequence dimension.
            vec = torch.max(
                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
                dim=1,
            )[0].squeeze()
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: List[str]) -> List[float]:
        """Score documents against the query via cosine similarity of SPLADE vectors."""
        vec = self.forward([query, *documents])
        xQ = F.normalize(vec[:1], p=2.0, dim=-1)
        xD = F.normalize(vec[1:], p=2.0, dim=-1)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> Dict[str, float]:
        """Return the query's non-zero expansion tokens and weights, sorted descending."""
        vec = self.forward([query]).squeeze()
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {
            self.idx2token[idx]: round(weight, 3)
            for idx, weight in zip(cols, weights)
            if weight > 0
        }
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
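

# A minimal usage sketch, assuming the class above is used as-is; the query and
# documents below are made-up examples and not part of the original snippet.
if __name__ == "__main__":
    encoder = SpladeEncoder()

    query = "what is a sparse retrieval model?"
    documents = [
        "SPLADE is a sparse neural retriever built on a masked language model head.",
        "Dense retrievers embed queries and documents into low-dimensional vectors.",
    ]

    # Cosine similarities between the query and each document, in input order.
    print(encoder.query_reranking(query, documents))

    # Query expansion: non-zero vocabulary tokens and their weights, highest first.
    print(encoder.token_expand(query))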