darpanaswal committed on
Commit
b4bffe5
·
verified ·
1 Parent(s): 29de29c

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +52 -140
cross_encoder_reranking_train.py CHANGED
@@ -13,8 +13,8 @@ from sklearn.metrics.pairwise import cosine_similarity
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
- embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
17
-
18
 
19
  def embed_text_list(texts):
20
  return embedder.encode(texts, convert_to_tensor=False, device=device)
@@ -61,28 +61,6 @@ def process_single_patent(patent_dict):
61
  "features": rank_by_centrality(top_features),
62
  }
63
 
64
- def refined_process_single_patent(patent_dict, top_n=10):
65
- abstract = patent_dict.get("pa01", "")
66
- title = patent_dict.get("title", "")
67
- context = f"{title} {abstract}"
68
- context_emb = embed_text_list([context])[0]
69
-
70
- claims = [v for k, v in patent_dict.items() if k.startswith("c-en")]
71
- paragraphs = [v for k, v in patent_dict.items() if k.startswith("p")]
72
- features = [v for k, v in patent_dict.get("features", {}).items()]
73
-
74
- def semantic_rank(items, context_emb):
75
- embeddings = embed_text_list(items)
76
- scores = cosine_similarity([context_emb], embeddings)[0]
77
- ranked_items = [item for item, _ in sorted(zip(items, scores), key=lambda x: x[1], reverse=True)]
78
- return ranked_items
79
-
80
- return {
81
- "claims": semantic_rank(claims, context_emb)[:top_n],
82
- "paragraphs": semantic_rank(paragraphs, context_emb)[:top_n],
83
- "features": semantic_rank(features, context_emb)[:top_n],
84
- }
85
-
86
  def load_json_file(file_path):
87
  """Load JSON data from a file"""
88
  with open(file_path, 'r') as f:
@@ -174,22 +152,6 @@ def extract_text(content_dict, text_type="full"):
174
 
175
  return " ".join(all_text)
176
 
177
- elif text_type == "smart2":
178
- filtered_dict = refined_process_single_patent(content_dict)
179
- all_text = []
180
- # Context with title and abstract
181
- if "title" in content_dict:
182
- all_text.append(content_dict["title"])
183
- if "pa01" in content_dict:
184
- all_text.append(content_dict["pa01"])
185
-
186
- # Add claims, paragraphs, and features
187
- all_text.extend(filtered_dict["claims"])
188
- all_text.extend(filtered_dict["paragraphs"])
189
- all_text.extend(filtered_dict["features"])
190
-
191
- return " ".join(all_text)
192
-
193
 
194
  return ""
195
 
@@ -203,118 +165,67 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso
203
  batch_size = last_hidden_states.shape[0]
204
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
205
 
206
- # def get_detailed_instruct(task_description: str, query: str) -> str:
207
- # """Create an instruction-formatted query"""
208
- # return f'Instruct: {task_description}\nQuery: {query}'
209
-
210
  def get_detailed_instruct(task_description: str, query: str) -> str:
211
- return (
212
- f"Instruct: Evaluate the semantic and technical similarity between two patent documents."
213
- f" Prioritize highly similar claims, technical implementations, and shared functionalities."
214
- f"\nQuery: {query}"
215
- )
216
-
217
- def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.8, weight_semantic=0.2):
218
- return (weight_cross * cross_encoder_score) + (weight_semantic * semantic_score)
219
-
220
-
221
- # def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
222
- # """
223
- # Rerank document texts based on query text using cross-encoder model
224
-
225
- # Parameters:
226
- # query_text (str): The query text
227
- # doc_texts (list): List of document texts
228
- # model: The cross-encoder model
229
- # tokenizer: The tokenizer for the model
230
- # batch_size (int): Batch size for processing
231
- # max_length (int): Maximum sequence length
232
-
233
- # Returns:
234
- # list: Indices of documents sorted by relevance score (descending)
235
- # """
236
- # device = next(model.parameters()).device
237
- # scores = []
238
 
239
- # # Format query with instruction
240
- # task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
241
 
242
- # instructed_query = get_detailed_instruct(task_description, query_text)
243
 
244
- # # Process in batches to avoid OOM
245
- # for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
246
- # batch_docs = doc_texts[i:i+batch_size]
247
 
248
- # # Prepare input pairs for the batch
249
- # input_texts = [instructed_query] + batch_docs
250
 
251
- # # Tokenize
252
- # with torch.no_grad():
253
- # batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
254
- # truncation=True, return_tensors='pt').to(device)
255
 
256
- # # Get embeddings
257
- # outputs = model(**batch_dict)
258
- # embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
259
 
260
- # # Normalize embeddings
261
- # embeddings = F.normalize(embeddings, p=2, dim=1)
262
 
263
- # # Calculate similarity scores between query and documents
264
- # batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
265
- # scores.extend(batch_scores.cpu().tolist())
266
-
267
- # # Create list of (index, score) tuples for sorting
268
- # indexed_scores = list(enumerate(scores))
269
 
270
- # # Sort by score in descending order
271
- # indexed_scores.sort(key=lambda x: x[1], reverse=True)
272
 
273
- # # Return sorted indices
274
- # return [idx for idx, _ in indexed_scores]
275
-
276
- def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
277
- device = next(model.parameters()).device
278
- cross_scores = []
279
- query_emb = embed_text_list([query_text])[0] # Move embedder to CPU
280
-
281
- instructed_query = get_detailed_instruct("", query_text)
282
-
283
- # Pre-create all input pairs (concatenation-based cross-encoder setup)
284
- input_texts = [f"{instructed_query} {doc}" for doc in doc_texts]
285
-
286
- for i in tqdm(range(0, len(input_texts), batch_size), desc="Scoring documents", leave=False):
287
- batch_input_texts = input_texts[i:i+batch_size]
288
-
289
- with torch.no_grad():
290
- batch_dict = tokenizer(batch_input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
291
-
292
- # Mixed precision for faster inference and lower memory
293
- with torch.cuda.amp.autocast():
294
- outputs = model(**batch_dict)
295
- embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
296
- embeddings = F.normalize(embeddings, p=2, dim=1)
297
-
298
- # Since queries are repeated in each pair, compare to instructed query embedding (first one)
299
- query_vector = embeddings[0].unsqueeze(0) # Use first as query
300
- batch_cross_scores = (query_vector @ embeddings.T).squeeze(0).cpu().numpy()[1:] # Exclude self-comparison
301
- cross_scores.extend(batch_cross_scores)
302
-
303
- # Semantic scores
304
- doc_embeddings = embed_text_list(doc_texts)
305
- semantic_scores = cosine_similarity([query_emb], doc_embeddings)[0]
306
-
307
- # Hybrid scores
308
- hybrid_scores = [hybrid_score(c, s) for c, s in zip(cross_scores, semantic_scores)]
309
-
310
- indexed_scores = list(enumerate(hybrid_scores))
311
  indexed_scores.sort(key=lambda x: x[1], reverse=True)
312
-
 
313
  return [idx for idx, _ in indexed_scores]
314
 
315
  def main():
316
  base_directory = os.getcwd()
317
- base_directory += "/Patent_Retrieval"
318
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
319
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
320
  help='Path to pre-ranking JSON file')
@@ -326,11 +237,12 @@ def main():
326
  parser.add_argument('--documents_content', type=str,
327
  default='./documents_content_with_features.json',
328
  help='Path to documents content JSON file')
329
- # Change here from train to test
 
330
  parser.add_argument('--queries_list', type=str, default='test_queries.json',
331
  help='Path to training queries JSON file')
332
  parser.add_argument('--text_type', type=str, default='TA',
333
- choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart', 'smart2'],
334
  help='Type of text to use for scoring')
335
  parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
336
  help='Name of the cross-encoder model')
@@ -341,7 +253,7 @@ def main():
341
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
342
  help='Device to use (cuda/cpu)')
343
  parser.add_argument('--base_dir', type=str,
344
- default=f'{base_directory}/datasets',
345
  help='Base directory for data files')
346
 
347
  args = parser.parse_args()
@@ -460,4 +372,4 @@ def main():
460
  print(f"Information about missing FANs saved to {missing_info_path}")
461
 
462
  if __name__ == "__main__":
463
- main()
 
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
17
+ embedder = embedder.to(device)
18
 
19
  def embed_text_list(texts):
20
  return embedder.encode(texts, convert_to_tensor=False, device=device)
 
61
  "features": rank_by_centrality(top_features),
62
  }
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def load_json_file(file_path):
65
  """Load JSON data from a file"""
66
  with open(file_path, 'r') as f:
 
152
 
153
  return " ".join(all_text)
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  return ""
157
 
 
165
  batch_size = last_hidden_states.shape[0]
166
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
167
 
 
 
 
 
168
  def get_detailed_instruct(task_description: str, query: str) -> str:
169
+ """Create an instruction-formatted query"""
170
+ return f'Instruct: {task_description}\nQuery: {query}'
171
+
172
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
    """
    Rerank document texts by embedding similarity to an instructed query.

    Despite the name, the query and the documents are encoded separately by
    the same model; each document is scored by the dot product of its
    L2-normalised, last-token-pooled embedding with the query's (i.e. cosine
    similarity), scaled by 100. No query/document pair is scored jointly.

    Parameters:
    query_text (str): The query text
    doc_texts (list): List of document texts
    model: Encoder whose output exposes ``last_hidden_state``
    tokenizer: The tokenizer for the model
    batch_size (int): Number of documents encoded per forward pass
    max_length (int): Maximum sequence length; longer inputs are truncated

    Returns:
    list: Indices of documents sorted by relevance score (descending);
          documents with equal scores keep their original relative order.
          An empty ``doc_texts`` yields an empty list.
    """
    # Nothing to rank: skip the (comparatively expensive) query forward pass.
    if not doc_texts:
        return []

    device = next(model.parameters()).device

    def _embed(texts):
        """Tokenize ``texts`` and return their L2-normalised pooled embeddings."""
        batch_dict = tokenizer(texts, max_length=max_length, padding=True,
                               truncation=True, return_tensors='pt').to(device)
        outputs = model(**batch_dict)
        pooled = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        return F.normalize(pooled, p=2, dim=1)

    # Format query with instruction
    task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
    instructed_query = get_detailed_instruct(task_description, query_text)

    scores = []
    # Inference only: keep every forward pass under no_grad so no autograd
    # graph is built or retained.
    with torch.no_grad():
        # The query is identical for every batch, so encode it once instead
        # of prepending it to each batch of documents.
        query_embedding = _embed([instructed_query])

        # Process in batches to avoid OOM
        for start in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
            doc_embeddings = _embed(doc_texts[start:start + batch_size])

            # (1, dim) @ (dim, n_docs) -> (1, n_docs); drop the query axis.
            batch_scores = (query_embedding @ doc_embeddings.T).squeeze(0) * 100
            scores.extend(batch_scores.cpu().tolist())

    # Stable descending sort of document indices by score.
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
226
 
227
  def main():
228
  base_directory = os.getcwd()
 
229
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
230
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
231
  help='Path to pre-ranking JSON file')
 
237
  parser.add_argument('--documents_content', type=str,
238
  default='./documents_content_with_features.json',
239
  help='Path to documents content JSON file')
240
+
241
+ # Change here for test or train
242
  parser.add_argument('--queries_list', type=str, default='test_queries.json',
243
  help='Path to training queries JSON file')
244
  parser.add_argument('--text_type', type=str, default='TA',
245
+ choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart'],
246
  help='Type of text to use for scoring')
247
  parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
248
  help='Name of the cross-encoder model')
 
253
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
254
  help='Device to use (cuda/cpu)')
255
  parser.add_argument('--base_dir', type=str,
256
+ default=f'{base_directory}/Patent_Retrieval/datasets',
257
  help='Base directory for data files')
258
 
259
  args = parser.parse_args()
 
372
  print(f"Information about missing FANs saved to {missing_info_path}")
373
 
374
  if __name__ == "__main__":
375
+ main()