brainsqueeze committed (verified)
Commit 2744d22 · Parent: 2fcef4a

* process SPLADE re-ranking in small batches so it stays within CPU memory limits (see the sketch after this list)
* reduce the default number of context documents from 10 to 5
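For context, a rough sketch of the memory arithmetic behind the first bullet. The sequence length and vocabulary size below are assumptions for a BERT-style backbone such as the ones SPLADE models use, not figures measured from this repository; the point is only that the MLM logits tensor is the dominant allocation and its size scales linearly with batch size.

# Illustrative peak-memory estimate for the SPLADE MLM logits (float32).
# Logits shape: (batch, seq_len, vocab). seq_len and vocab are assumed
# values for a BERT-style backbone, not measured from this repo.
seq_len, vocab, bytes_per_float = 512, 30_522, 4

def logits_mib(batch: int) -> float:
    return batch * seq_len * vocab * bytes_per_float / 2**20

print(f"all 32 texts at once: {logits_mib(32):,.0f} MiB")
print(f"batches of 4:         {logits_mib(4):,.0f} MiB per forward pass")

Trading one large forward pass for several small ones caps the working set at a level a CPU-only host can tolerate, at the cost of a few extra passes.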

ask_candid/retrieval/elastic.py CHANGED
@@ -299,7 +299,7 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
 def reranker(
     query_results: Iterable[ElasticHitsResult],
     search_text: Optional[str] = None,
-    max_num_results: int = 10
+    max_num_results: int = 5
 ) -> Iterator[ElasticHitsResult]:
     """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
     This will shuffle results
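The docstring above motivates the reranking step: hits from different indices carry scores on incompatible scales. Below is a self-contained sketch of that idea; the per-index min-max normalization is an assumption for illustration only, not the repository's implementation (the actual reranker presumably scores hits with the SPLADE encoder from sparse_lexical.py below).

from typing import Dict, List, Tuple

def rerank_sketch(per_index: Dict[str, List[Tuple[str, float]]], max_num_results: int = 5) -> List[str]:
    """Merge hits whose scores live on different scales, keep the top few."""
    merged: List[Tuple[str, float]] = []
    for hits in per_index.values():
        scores = [score for _, score in hits]
        lo, hi = min(scores), max(scores)
        span = (hi - lo) or 1.0
        # map each index's scores onto [0, 1] so they are comparable
        merged += [(doc_id, (score - lo) / span) for doc_id, score in hits]
    merged.sort(key=lambda pair: pair[1], reverse=True)
    return [doc_id for doc_id, _ in merged[:max_num_results]]

# two indices with wildly different score ranges; at most 5 ids survive
print(rerank_sketch({
    "news": [("n1", 14.2), ("n2", 9.8), ("n3", 3.1), ("n4", 1.0)],
    "orgs": [("o1", 0.93), ("o2", 0.71), ("o3", 0.08)],
}))

Lowering the default from 10 to 5 halves the number of context documents passed downstream, which also shrinks the set of texts the SPLADE re-ranker has to score.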
ask_candid/retrieval/sparse_lexical.py CHANGED
@@ -1,11 +1,15 @@
 from typing import List, Dict
 
+from tqdm.auto import tqdm
+
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from torch.utils.data import DataLoader
 from torch.nn import functional as F
 import torch
 
 
 class SpladeEncoder:
+    batch_size = 4
 
     def __init__(self):
         model_id = "naver/splade-v3"
@@ -16,13 +20,16 @@ class SpladeEncoder:
 
     @torch.no_grad()
     def forward(self, texts: List[str]):
-        tokens = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
-        output = self.model(**tokens)
-        vec = torch.max(
-            torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
-            dim=1
-        )[0].squeeze()
-        return vec
+        vectors = []
+        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Re-ranking"):
+            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
+            output = self.model(**tokens)
+            vec = torch.max(
+                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
+                dim=1
+            )[0].squeeze()
+            vectors.append(vec)
+        return torch.vstack(vectors)
 
     def query_reranking(self, query: str, documents: List[str]):
         vec = self.forward([query, *documents])
@@ -31,7 +38,7 @@ class SpladeEncoder:
         return (xQ * xD).sum(dim=-1).cpu().tolist()
 
     def token_expand(self, query: str) -> Dict[str, float]:
-        vec = self.forward([query])
+        vec = self.forward([query]).squeeze()
         cols = vec.nonzero().squeeze().cpu().tolist()
         weights = vec[cols].cpu().tolist()
 
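Finally, a hedged usage sketch of the class after this change; the example texts are illustrative and the import path follows the file's location in the repo. With batch_size = 4, the five texts below (one query plus four documents) are encoded in two batches, tracked by the tqdm bar labelled "Re-ranking":

from ask_candid.retrieval.sparse_lexical import SpladeEncoder

encoder = SpladeEncoder()  # downloads naver/splade-v3 on first use; runs on CPU

query = "funding for after-school programs"
documents = [
    "Grants supporting after-school enrichment programs.",
    "Annual report on museum attendance.",
    "Youth education funding opportunities.",
    "Hurricane relief fund disbursements.",
]

# one similarity score per document; 1 query + 4 docs = 5 texts,
# processed as a batch of 4 plus a batch of 1
scores = encoder.query_reranking(query, documents)
ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

# sparse expansion of the query: token -> weight, nonzero entries only
expansion = encoder.token_expand(query)

Note the added .squeeze() in token_expand: forward now always returns a stacked 2-D tensor, even for a single query, so the row must be flattened before its nonzero entries are read off.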