import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding


class SpladeEncoder:
    """SPLADE sparse lexical encoder for re-ranking and query term expansion."""

    batch_size = 8
    model_id = "naver/splade-v3"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        # Reverse vocabulary lookup: token id -> token string.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

        # Pick the best available device: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def forward(self, inputs: BatchEncoding) -> Tensor:
        output = self.model(**inputs.to(self.device))

        logits: Tensor = output.logits
        mask: Tensor = inputs.attention_mask

        # SPLADE aggregation: log(1 + ReLU(logits)), zeroed on padding
        # positions, then max-pooled over the sequence dimension.
        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
        return vec.max(dim=1).values.squeeze()

    def encode(self, texts: list[str]) -> Tensor:
        """Encode texts into SPLADE vocabulary-space vectors.

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            One vocabulary-sized sparse lexical vector per input text.
        """

        vectors = []
        loader = DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size)  # type: ignore
        for batch in tqdm(loader, desc="Encoding"):
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            vec = self.forward(inputs=tokens)
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
        """Cosine-similarity re-ranking of documents against a query.

        Parameters
        ----------
        query : str
            Retrieval query
        documents : list[str]
            Retrieved documents

        Returns
        -------
        list[float]
            Cosine similarity of the query to each document.
        """

        # Encode query and documents together; the dot product of
        # L2-normalised vectors equals cosine similarity.
        vec = self.encode([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.0)
        xD = F.normalize(vec[1:], dim=-1, p=2.0)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> dict[str, float]:
        """Sparse lexical token expansion of a query.

        Parameters
        ----------
        query : str
            Retrieval query

        Returns
        -------
        dict[str, float]
            Expansion tokens mapped to their weights, sorted by weight in
            descending order.
        """

        vec = self.encode([query]).squeeze()
        # Keep only the non-zero vocabulary entries and their weights.
        cols = vec.nonzero().squeeze(dim=-1).cpu().tolist()
        weights = vec[cols].cpu().tolist()

        sparse_dict_tokens = {
            self.idx2token[idx]: round(weight, 3)
            for idx, weight in zip(cols, weights)
            if weight > 0
        }
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
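

if __name__ == "__main__":
    # Minimal usage sketch: the query and documents below are illustrative
    # placeholders, and instantiating SpladeEncoder downloads the
    # naver/splade-v3 weights from the Hugging Face Hub on first run.
    encoder = SpladeEncoder()

    docs = [
        "SPLADE builds sparse lexical representations for retrieval.",
        "Dense bi-encoders map text into a low-dimensional embedding space.",
    ]

    # Re-rank the documents by cosine similarity to the query.
    print(encoder.query_reranking(query="what is splade?", documents=docs))

    # Inspect the expansion terms and weights SPLADE assigns to the query.
    print(encoder.token_expand(query="what is splade?"))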