sushil3125 committed on
Commit
60cacbb
·
1 Parent(s): b3e42d1

duplicate issue

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -7,6 +7,8 @@ import torch
7
  from functools import lru_cache
8
  import logging
9
  from datetime import datetime
 
 
10
 
11
  # 🔧 Configure logging
12
  logging.basicConfig(level=logging.INFO)
@@ -70,12 +72,23 @@ def encode_splade_cached(text: str) -> SparseVector:
70
  vocab_indices = nonzero[:, 1]
71
  values = relu_log[nonzero[:, 0], nonzero[:, 1]]
72
 
73
- logger.info(f"SPLADE encoding complete with {len(vocab_indices)} dimensions")
 
 
 
 
 
 
 
 
 
 
74
  return SparseVector(
75
- indices=vocab_indices.cpu().numpy().tolist(),
76
- values=values.cpu().numpy().tolist()
77
  )
78
 
 
79
  # 🚀 Main endpoint
80
  @app.post("/get-embedding/")
81
  async def get_embedding(input: TextInput):
 
7
  from functools import lru_cache
8
  import logging
9
  from datetime import datetime
10
+ from collections import defaultdict
11
+
12
 
13
  # 🔧 Configure logging
14
  logging.basicConfig(level=logging.INFO)
 
72
  vocab_indices = nonzero[:, 1]
73
  values = relu_log[nonzero[:, 0], nonzero[:, 1]]
74
 
75
+ vocab_indices_list = vocab_indices.cpu().numpy().tolist()
76
+ values_list = values.cpu().numpy().tolist()
77
+
78
+ index_to_value = defaultdict(float)
79
+ for idx, val in zip(vocab_indices_list, values_list):
80
+ index_to_value[idx] += val
81
+
82
+ deduped_indices = list(index_to_value.keys())
83
+ deduped_values = list(index_to_value.values())
84
+
85
+ logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
86
  return SparseVector(
87
+ indices=deduped_indices,
88
+ values=deduped_values
89
  )
90
 
91
+
92
  # 🚀 Main endpoint
93
  @app.post("/get-embedding/")
94
  async def get_embedding(input: TextInput):