from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
import torch


class SpladeEncoder:
    batch_size = 8
    model_id = "naver/splade-v3"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        # Reverse vocabulary lookup: token id -> token string
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
        # Pick the best available device: CUDA GPU, Apple Silicon (MPS), or CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()

    def forward(self, inputs: BatchEncoding) -> Tensor:
        with torch.no_grad():
            output = self.model(**inputs.to(self.device))
        logits: Tensor = output.logits
        mask: Tensor = inputs.attention_mask
        # SPLADE activation: log(1 + ReLU(logits)), zeroed on padding tokens,
        # then max-pooled over the sequence dimension to get one vector per text.
        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
        return vec.max(dim=1)[0].squeeze()

    def encode(self, texts: list[str]) -> Tensor:
        """Encode texts into vocabulary-sized SPLADE vectors.

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            Sparse lexical vectors, one row per text, stored as a dense tensor.
        """
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"):  # type: ignore
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            vec = self.forward(inputs=tokens)
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
        """Cosine-similarity re-ranking of documents against a query.

        Parameters
        ----------
        query : str
            Retrieval query.
        documents : list[str]
            Retrieved documents.

        Returns
        -------
        list[float]
            Cosine similarity between the query and each document.
        """
        vec = self.encode([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.)
        xD = F.normalize(vec[1:], dim=-1, p=2.)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> dict[str, float]:
        """Sparse lexical token expansion of a query.

        Parameters
        ----------
        query : str
            Retrieval query.

        Returns
        -------
        dict[str, float]
            Expansion tokens mapped to their weights, sorted by weight (descending).
        """
        vec = self.encode([query]).squeeze()
        # Indices of non-zero vocabulary entries and their corresponding weights
        cols = vec.nonzero(as_tuple=True)[0].cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
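

# Hedged usage sketch (not part of the original file): shows how the class above
# could be exercised for re-ranking and query expansion. The query and documents
# below are made-up examples for illustration only.
if __name__ == "__main__":
    encoder = SpladeEncoder()

    query = "capital of france"
    documents = [
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
    ]

    # Cosine-similarity scores between the query and each document
    print(encoder.query_reranking(query=query, documents=documents))

    # Sparse lexical expansion of the query: token -> weight, sorted descending
    print(encoder.token_expand(query))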