from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
import torch
class SpladeEncoder:
    batch_size = 8
    model_id = "naver/splade-v3"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
        # Pick the best available device: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)

    @torch.no_grad()
    def forward(self, inputs: BatchEncoding) -> Tensor:
        output = self.model(**inputs.to(self.device))
        logits: Tensor = output.logits
        mask: Tensor = inputs.attention_mask
        # SPLADE activation: log(1 + relu(logits)), with padding positions masked out,
        # then max-pooled over the sequence dimension to get one vocabulary-sized vector per text.
        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
        return vec.max(dim=1)[0].squeeze()

    def encode(self, texts: list[str]) -> Tensor:
        """Encode texts into SPLADE sparse vocabulary-sized vectors.

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            One vector per input text, stacked row-wise.
        """
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"):  # type: ignore
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            vec = self.forward(inputs=tokens)
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
        """Cosine-similarity re-ranking of documents against a query.

        Parameters
        ----------
        query : str
            Retrieval query
        documents : list[str]
            Retrieved documents

        Returns
        -------
        list[float]
            Cosine similarity between the query and each document.
        """
        vec = self.encode([query, *documents])
        # L2-normalise query and document vectors, then take the dot product,
        # which equals the cosine similarity.
        xQ = F.normalize(vec[:1], dim=-1, p=2.)
        xD = F.normalize(vec[1:], dim=-1, p=2.)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> dict[str, float]:
        """Sparse lexical token expansion.

        Parameters
        ----------
        query : str
            Retrieval query

        Returns
        -------
        dict[str, float]
            Mapping of vocabulary tokens to their SPLADE weights, sorted by weight in descending order.
        """
        vec = self.encode([query]).squeeze()
        # Keep only tokens with non-zero weight and map vocabulary indices back to token strings.
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
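

# Usage sketch: a minimal, hypothetical example of how the encoder above might be called.
# The query and document strings below are illustrative placeholders, not part of the original file.
if __name__ == "__main__":
    encoder = SpladeEncoder()

    query = "what causes rainbows"
    documents = [
        "Rainbows appear when sunlight is refracted and dispersed by water droplets.",
        "The stock market closed higher today after strong earnings reports.",
    ]

    # Re-rank retrieved documents by cosine similarity to the query.
    print(encoder.query_reranking(query, documents))

    # Inspect the sparse lexical expansion of the query (token -> weight).
    print(encoder.token_expand(query))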