Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files- cross_encoder_reranking_train.py +52 -140
cross_encoder_reranking_train.py
CHANGED
@@ -13,8 +13,8 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
-
embedder = SentenceTransformer("
|
17 |
-
|
18 |
|
19 |
def embed_text_list(texts):
|
20 |
return embedder.encode(texts, convert_to_tensor=False, device=device)
|
@@ -61,28 +61,6 @@ def process_single_patent(patent_dict):
|
|
61 |
"features": rank_by_centrality(top_features),
|
62 |
}
|
63 |
|
64 |
-
def refined_process_single_patent(patent_dict, top_n=10):
|
65 |
-
abstract = patent_dict.get("pa01", "")
|
66 |
-
title = patent_dict.get("title", "")
|
67 |
-
context = f"{title} {abstract}"
|
68 |
-
context_emb = embed_text_list([context])[0]
|
69 |
-
|
70 |
-
claims = [v for k, v in patent_dict.items() if k.startswith("c-en")]
|
71 |
-
paragraphs = [v for k, v in patent_dict.items() if k.startswith("p")]
|
72 |
-
features = [v for k, v in patent_dict.get("features", {}).items()]
|
73 |
-
|
74 |
-
def semantic_rank(items, context_emb):
|
75 |
-
embeddings = embed_text_list(items)
|
76 |
-
scores = cosine_similarity([context_emb], embeddings)[0]
|
77 |
-
ranked_items = [item for item, _ in sorted(zip(items, scores), key=lambda x: x[1], reverse=True)]
|
78 |
-
return ranked_items
|
79 |
-
|
80 |
-
return {
|
81 |
-
"claims": semantic_rank(claims, context_emb)[:top_n],
|
82 |
-
"paragraphs": semantic_rank(paragraphs, context_emb)[:top_n],
|
83 |
-
"features": semantic_rank(features, context_emb)[:top_n],
|
84 |
-
}
|
85 |
-
|
86 |
def load_json_file(file_path):
|
87 |
"""Load JSON data from a file"""
|
88 |
with open(file_path, 'r') as f:
|
@@ -174,22 +152,6 @@ def extract_text(content_dict, text_type="full"):
|
|
174 |
|
175 |
return " ".join(all_text)
|
176 |
|
177 |
-
elif text_type == "smart2":
|
178 |
-
filtered_dict = refined_process_single_patent(content_dict)
|
179 |
-
all_text = []
|
180 |
-
# Context with title and abstract
|
181 |
-
if "title" in content_dict:
|
182 |
-
all_text.append(content_dict["title"])
|
183 |
-
if "pa01" in content_dict:
|
184 |
-
all_text.append(content_dict["pa01"])
|
185 |
-
|
186 |
-
# Add claims, paragraphs, and features
|
187 |
-
all_text.extend(filtered_dict["claims"])
|
188 |
-
all_text.extend(filtered_dict["paragraphs"])
|
189 |
-
all_text.extend(filtered_dict["features"])
|
190 |
-
|
191 |
-
return " ".join(all_text)
|
192 |
-
|
193 |
|
194 |
return ""
|
195 |
|
@@ -203,118 +165,67 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso
|
|
203 |
batch_size = last_hidden_states.shape[0]
|
204 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
205 |
|
206 |
-
# def get_detailed_instruct(task_description: str, query: str) -> str:
|
207 |
-
# """Create an instruction-formatted query"""
|
208 |
-
# return f'Instruct: {task_description}\nQuery: {query}'
|
209 |
-
|
210 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
# max_length (int): Maximum sequence length
|
232 |
-
|
233 |
-
# Returns:
|
234 |
-
# list: Indices of documents sorted by relevance score (descending)
|
235 |
-
# """
|
236 |
-
# device = next(model.parameters()).device
|
237 |
-
# scores = []
|
238 |
|
239 |
-
#
|
240 |
-
|
241 |
|
242 |
-
|
243 |
|
244 |
-
#
|
245 |
-
|
246 |
-
|
247 |
|
248 |
-
#
|
249 |
-
|
250 |
|
251 |
-
#
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
|
256 |
-
#
|
257 |
-
|
258 |
-
|
259 |
|
260 |
-
#
|
261 |
-
|
262 |
|
263 |
-
#
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
# # Create list of (index, score) tuples for sorting
|
268 |
-
# indexed_scores = list(enumerate(scores))
|
269 |
|
270 |
-
#
|
271 |
-
|
272 |
|
273 |
-
#
|
274 |
-
# return [idx for idx, _ in indexed_scores]
|
275 |
-
|
276 |
-
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
|
277 |
-
device = next(model.parameters()).device
|
278 |
-
cross_scores = []
|
279 |
-
query_emb = embed_text_list([query_text])[0] # Move embedder to CPU
|
280 |
-
|
281 |
-
instructed_query = get_detailed_instruct("", query_text)
|
282 |
-
|
283 |
-
# Pre-create all input pairs (concatenation-based cross-encoder setup)
|
284 |
-
input_texts = [f"{instructed_query} {doc}" for doc in doc_texts]
|
285 |
-
|
286 |
-
for i in tqdm(range(0, len(input_texts), batch_size), desc="Scoring documents", leave=False):
|
287 |
-
batch_input_texts = input_texts[i:i+batch_size]
|
288 |
-
|
289 |
-
with torch.no_grad():
|
290 |
-
batch_dict = tokenizer(batch_input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
|
291 |
-
|
292 |
-
# Mixed precision for faster inference and lower memory
|
293 |
-
with torch.cuda.amp.autocast():
|
294 |
-
outputs = model(**batch_dict)
|
295 |
-
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
296 |
-
embeddings = F.normalize(embeddings, p=2, dim=1)
|
297 |
-
|
298 |
-
# Since queries are repeated in each pair, compare to instructed query embedding (first one)
|
299 |
-
query_vector = embeddings[0].unsqueeze(0) # Use first as query
|
300 |
-
batch_cross_scores = (query_vector @ embeddings.T).squeeze(0).cpu().numpy()[1:] # Exclude self-comparison
|
301 |
-
cross_scores.extend(batch_cross_scores)
|
302 |
-
|
303 |
-
# Semantic scores
|
304 |
-
doc_embeddings = embed_text_list(doc_texts)
|
305 |
-
semantic_scores = cosine_similarity([query_emb], doc_embeddings)[0]
|
306 |
-
|
307 |
-
# Hybrid scores
|
308 |
-
hybrid_scores = [hybrid_score(c, s) for c, s in zip(cross_scores, semantic_scores)]
|
309 |
-
|
310 |
-
indexed_scores = list(enumerate(hybrid_scores))
|
311 |
indexed_scores.sort(key=lambda x: x[1], reverse=True)
|
312 |
-
|
|
|
313 |
return [idx for idx, _ in indexed_scores]
|
314 |
|
315 |
def main():
|
316 |
base_directory = os.getcwd()
|
317 |
-
base_directory += "/Patent_Retrieval"
|
318 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
319 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
320 |
help='Path to pre-ranking JSON file')
|
@@ -326,11 +237,12 @@ def main():
|
|
326 |
parser.add_argument('--documents_content', type=str,
|
327 |
default='./documents_content_with_features.json',
|
328 |
help='Path to documents content JSON file')
|
329 |
-
|
|
|
330 |
parser.add_argument('--queries_list', type=str, default='test_queries.json',
|
331 |
help='Path to training queries JSON file')
|
332 |
parser.add_argument('--text_type', type=str, default='TA',
|
333 |
-
choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart'
|
334 |
help='Type of text to use for scoring')
|
335 |
parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
|
336 |
help='Name of the cross-encoder model')
|
@@ -341,7 +253,7 @@ def main():
|
|
341 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
342 |
help='Device to use (cuda/cpu)')
|
343 |
parser.add_argument('--base_dir', type=str,
|
344 |
-
default=f'{base_directory}/datasets',
|
345 |
help='Base directory for data files')
|
346 |
|
347 |
args = parser.parse_args()
|
@@ -460,4 +372,4 @@ def main():
|
|
460 |
print(f"Information about missing FANs saved to {missing_info_path}")
|
461 |
|
462 |
if __name__ == "__main__":
|
463 |
-
main()
|
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
+
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
17 |
+
embedder = embedder.to(device)
|
18 |
|
19 |
def embed_text_list(texts):
|
20 |
return embedder.encode(texts, convert_to_tensor=False, device=device)
|
|
|
61 |
"features": rank_by_centrality(top_features),
|
62 |
}
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def load_json_file(file_path):
|
65 |
"""Load JSON data from a file"""
|
66 |
with open(file_path, 'r') as f:
|
|
|
152 |
|
153 |
return " ".join(all_text)
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
return ""
|
157 |
|
|
|
165 |
batch_size = last_hidden_states.shape[0]
|
166 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
167 |
|
|
|
|
|
|
|
|
|
168 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
169 |
+
"""Create an instruction-formatted query"""
|
170 |
+
return f'Instruct: {task_description}\nQuery: {query}'
|
171 |
+
|
172 |
+
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
|
173 |
+
"""
|
174 |
+
Rerank document texts based on query text using cross-encoder model
|
175 |
+
|
176 |
+
Parameters:
|
177 |
+
query_text (str): The query text
|
178 |
+
doc_texts (list): List of document texts
|
179 |
+
model: The cross-encoder model
|
180 |
+
tokenizer: The tokenizer for the model
|
181 |
+
batch_size (int): Batch size for processing
|
182 |
+
max_length (int): Maximum sequence length
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
list: Indices of documents sorted by relevance score (descending)
|
186 |
+
"""
|
187 |
+
device = next(model.parameters()).device
|
188 |
+
scores = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
# Format query with instruction
|
191 |
+
task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
|
192 |
|
193 |
+
instructed_query = get_detailed_instruct(task_description, query_text)
|
194 |
|
195 |
+
# Process in batches to avoid OOM
|
196 |
+
for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
|
197 |
+
batch_docs = doc_texts[i:i+batch_size]
|
198 |
|
199 |
+
# Prepare input pairs for the batch
|
200 |
+
input_texts = [instructed_query] + batch_docs
|
201 |
|
202 |
+
# Tokenize
|
203 |
+
with torch.no_grad():
|
204 |
+
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
|
205 |
+
truncation=True, return_tensors='pt').to(device)
|
206 |
|
207 |
+
# Get embeddings
|
208 |
+
outputs = model(**batch_dict)
|
209 |
+
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
210 |
|
211 |
+
# Normalize embeddings
|
212 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
213 |
|
214 |
+
# Calculate similarity scores between query and documents
|
215 |
+
batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
|
216 |
+
scores.extend(batch_scores.cpu().tolist())
|
|
|
|
|
|
|
217 |
|
218 |
+
# Create list of (index, score) tuples for sorting
|
219 |
+
indexed_scores = list(enumerate(scores))
|
220 |
|
221 |
+
# Sort by score in descending order
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
indexed_scores.sort(key=lambda x: x[1], reverse=True)
|
223 |
+
|
224 |
+
# Return sorted indices
|
225 |
return [idx for idx, _ in indexed_scores]
|
226 |
|
227 |
def main():
|
228 |
base_directory = os.getcwd()
|
|
|
229 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
230 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
231 |
help='Path to pre-ranking JSON file')
|
|
|
237 |
parser.add_argument('--documents_content', type=str,
|
238 |
default='./documents_content_with_features.json',
|
239 |
help='Path to documents content JSON file')
|
240 |
+
|
241 |
+
# Change here for test or train
|
242 |
parser.add_argument('--queries_list', type=str, default='test_queries.json',
|
243 |
help='Path to training queries JSON file')
|
244 |
parser.add_argument('--text_type', type=str, default='TA',
|
245 |
+
choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart'],
|
246 |
help='Type of text to use for scoring')
|
247 |
parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
|
248 |
help='Name of the cross-encoder model')
|
|
|
253 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
254 |
help='Device to use (cuda/cpu)')
|
255 |
parser.add_argument('--base_dir', type=str,
|
256 |
+
default=f'{base_directory}/Patent_Retrieval/datasets',
|
257 |
help='Base directory for data files')
|
258 |
|
259 |
args = parser.parse_args()
|
|
|
372 |
print(f"Information about missing FANs saved to {missing_info_path}")
|
373 |
|
374 |
if __name__ == "__main__":
|
375 |
+
main()
|