Batching
* small batch size processing for SPLADE re-ranking to work within CPU limitations (usage sketch after the sparse_lexical.py diff below)
* reduce default number of context documents to 5 (usage sketch after the elastic.py diff below)
ask_candid/retrieval/elastic.py CHANGED
@@ -299,7 +299,7 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
 def reranker(
     query_results: Iterable[ElasticHitsResult],
     search_text: Optional[str] = None,
-    max_num_results: int =
+    max_num_results: int = 5
 ) -> Iterator[ElasticHitsResult]:
     """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
     This will shuffle results
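With the new default in place, callers that omit `max_num_results` now get at most five reranked hits. A minimal calling sketch, assuming the package is importable; the empty `hits` list is only a stand-in for `ElasticHitsResult` objects collected upstream (how they are fetched is outside this diff):

from ask_candid.retrieval.elastic import reranker

# Stand-in: real code would collect ElasticHitsResult objects from
# one or more Elasticsearch indices/queries before reranking.
hits = []

# Omitting max_num_results now caps the output at 5 context documents.
top_hits = list(reranker(hits, search_text="youth education grants"))
assert len(top_hits) <= 5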
ask_candid/retrieval/sparse_lexical.py CHANGED
@@ -1,11 +1,15 @@
 from typing import List, Dict
 
+from tqdm.auto import tqdm
+
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from torch.utils.data import DataLoader
 from torch.nn import functional as F
 import torch
 
 
 class SpladeEncoder:
+    batch_size = 4
 
     def __init__(self):
         model_id = "naver/splade-v3"
@@ -16,13 +20,16 @@ class SpladeEncoder:
 
     @torch.no_grad()
     def forward(self, texts: List[str]):
-        …
+        vectors = []
+        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Re-ranking"):
+            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
+            output = self.model(**tokens)
+            vec = torch.max(
+                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
+                dim=1
+            )[0].squeeze()
+            vectors.append(vec)
+        return torch.vstack(vectors)
 
     def query_reranking(self, query: str, documents: List[str]):
         vec = self.forward([query, *documents])
@@ -31,7 +38,7 @@ class SpladeEncoder:
         return (xQ * xD).sum(dim=-1).cpu().tolist()
 
     def token_expand(self, query: str) -> Dict[str, float]:
-        vec = self.forward([query])
+        vec = self.forward([query]).squeeze()
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
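The rewritten `forward` streams texts through the model in batches of `batch_size = 4` rather than tokenizing the whole list at once, which bounds peak memory on a CPU-only Space; the `log(1 + relu(logits))` max-pooling over the sequence dimension is the usual SPLADE sparse activation and is unchanged. A quick usage sketch, with made-up document texts (instantiating the encoder downloads `naver/splade-v3` on first use):

from ask_candid.retrieval.sparse_lexical import SpladeEncoder

encoder = SpladeEncoder()

docs = [
    "Grants for rural health clinics",
    "Scholarships for first-generation college students",
    "Community arts funding programs",
]

# One batched forward pass over [query, *docs] (a single batch of 4 here),
# then sparse dot products between the query vector and each document vector.
scores = encoder.query_reranking("health care funding", docs)
print(scores)  # one similarity score per document

# Query expansion: nonzero SPLADE dimensions mapped back to vocabulary terms.
expansion = encoder.token_expand("health care funding")
print(expansion)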