# """
# Retriever module for TrialGPT
# """
# import json
# import os
# import numpy as np
# import torch
# from nltk import word_tokenize
# from rank_bm25 import BM25Okapi
# import faiss
# from transformers import AutoTokenizer, AutoModel
# def get_bm25_corpus_index(corpus="sigir"):
# """Get BM25 corpus index for the specified corpus."""
# corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json")
# if os.path.exists(corpus_path):
# corpus_data = json.load(open(corpus_path))
# tokenized_corpus = corpus_data["tokenized_corpus"]
# corpus_nctids = corpus_data["corpus_nctids"]
# else:
# # If the pre-built index doesn't exist, we'll need to build it
# # For now, return None to indicate the index needs to be built
# return None, None
# bm25 = BM25Okapi(tokenized_corpus)
# return bm25, corpus_nctids
# def get_medcpt_corpus_index(corpus="sigir"):
# """Get MedCPT corpus index for the specified corpus."""
# corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
# nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"
# if os.path.exists(corpus_path):
# embeds = np.load(corpus_path)
# corpus_nctids = json.load(open(nctids_path))
# else:
# # If the pre-built index doesn't exist, return None
# return None, None
# index = faiss.IndexFlatIP(768)
# index.add(embeds)
# return index, corpus_nctids
# def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
# """
# Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval.
# Args:
# conditions (list): List of condition strings to search for
# corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
# top_k (int): Number of top trials to return
# bm25_weight (int): Weight for BM25 scores
# medcpt_weight (int): Weight for MedCPT scores
# Returns:
# list: List of NCT IDs for the top matching trials
# """
# # Get the retrieval indices
# bm25, bm25_nctids = get_bm25_corpus_index(corpus)
# medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)
# if bm25 is None or medcpt is None:
# print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
# print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
# return []
# if len(conditions) == 0:
# return []
# # BM25 retrieval
# bm25_condition_top_nctids = []
# if bm25_weight > 0:
# for condition in conditions:
# tokens = word_tokenize(condition.lower())
# top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2) # Get more for fusion
# bm25_condition_top_nctids.append(top_nctids)
# # MedCPT retrieval
# medcpt_condition_top_nctids = []
# if medcpt_weight > 0:
# try:
# model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
# tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
# with torch.no_grad():
# encoded = tokenizer(
# conditions,
# truncation=True,
# padding=True,
# return_tensors='pt',
# max_length=256,
# )
# # encode the queries
# embeds = model(**encoded).last_hidden_state[:, 0, :].numpy()
# # search the Faiss index
# scores, inds = medcpt.search(embeds, k=top_k*2) # Get more for fusion
# for ind_list in inds:
# top_nctids = [medcpt_nctids[ind] for ind in ind_list]
# medcpt_condition_top_nctids.append(top_nctids)
# except Exception as e:
# print(f"Warning: MedCPT retrieval failed: {e}")
# medcpt_weight = 0
# # Fusion of results
# nctid2score = {}
# for condition_idx, condition in enumerate(conditions):
# # BM25 scoring
# if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
# for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
# if nctid not in nctid2score:
# nctid2score[nctid] = 0
# nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
# # MedCPT scoring
# if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
# for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
# if nctid not in nctid2score:
# nctid2score[nctid] = 0
# nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
# # Sort by score and return top_k
# nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1])
# top_nctids = [nctid for nctid, _ in nctid2score[:top_k]]
# return top_nctids
"""
GPU-accelerated Retriever module for TrialGPT
"""
import json
import os
import numpy as np
import torch
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
import faiss
from transformers import AutoTokenizer, AutoModel

class GPUTrialRetriever:
    """GPU-accelerated trial retriever with model caching."""

    def __init__(self, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self._model_loaded = False

    def _load_medcpt_model(self):
        """Load the MedCPT query encoder once and cache it for reuse."""
        if not self._model_loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            self.model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
            # Move the model to GPU if available
            if self.device != 'cpu':
                self.model = self.model.to(self.device)
            # Evaluation mode disables dropout for deterministic inference
            self.model.eval()
            self._model_loaded = True
            print(f"MedCPT model loaded on {self.device}")

def get_bm25_corpus_index(corpus="sigir"):
    """Get the BM25 corpus index for the specified corpus."""
    corpus_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"
    if not os.path.exists(corpus_path):
        # The pre-built index is missing; the caller is expected to build it first
        return None, None
    with open(corpus_path) as f:
        corpus_data = json.load(f)
    tokenized_corpus = corpus_data["tokenized_corpus"]
    corpus_nctids = corpus_data["corpus_nctids"]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids
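
# A minimal sketch of how the bm25_corpus_{corpus}.json file loaded above could
# be produced. The canonical builder is hybrid_fusion_retrieval.py, so this
# helper (build_bm25_corpus) and its `trial_texts` input format are illustrative
# assumptions, not the actual pipeline; only the output schema is taken from
# get_bm25_corpus_index.
def build_bm25_corpus(trial_texts, corpus="sigir"):
    """Tokenize trial texts and save them in the schema get_bm25_corpus_index expects.

    Args:
        trial_texts (dict): Mapping of NCT ID -> free-text trial description (assumed format).
    """
    corpus_nctids = list(trial_texts.keys())
    tokenized_corpus = [
        word_tokenize(trial_texts[nctid].lower()) for nctid in corpus_nctids
    ]
    out_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(
            {"tokenized_corpus": tokenized_corpus, "corpus_nctids": corpus_nctids}, f
        )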

def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True):
    """Get the MedCPT corpus index, on GPU if available and requested."""
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"
    if not os.path.exists(corpus_path):
        # Pre-built embeddings are missing; the caller is expected to build them first
        return None, None
    embeds = np.load(corpus_path).astype(np.float32)  # FAISS requires float32
    with open(nctids_path) as f:
        corpus_nctids = json.load(f)
    # Build the CPU index once; it doubles as the fallback if the GPU path fails
    index = faiss.IndexFlatIP(768)  # inner-product index over 768-d MedCPT embeddings
    index.add(embeds)
    if use_gpu and torch.cuda.is_available():
        try:
            # Move the flat index onto GPU 0
            res = faiss.StandardGpuResources()
            gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
            return gpu_index, corpus_nctids
        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")
    return index, corpus_nctids
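
# A minimal sketch of how the {corpus}_embeds.npy / {corpus}_nctids.json pair
# loaded above could be produced. The canonical builder is
# hybrid_fusion_retrieval.py, so this helper and its `trial_texts` input are
# illustrative assumptions. It encodes trial texts with the MedCPT article
# encoder, whose 768-d [CLS] vectors pair with the query encoder used below.
def build_medcpt_corpus(trial_texts, corpus="sigir", batch_size=32):
    """Encode trial texts and save embeddings plus NCT IDs to disk (assumed pipeline)."""
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model.eval()
    corpus_nctids = list(trial_texts.keys())
    texts = [trial_texts[nctid] for nctid in corpus_nctids]
    all_embeds = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[i:i + batch_size],
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=512,
            )
            # [CLS] token embedding, matching the query-side encoding below
            all_embeds.append(model(**encoded).last_hidden_state[:, 0, :])
    embeds = torch.cat(all_embeds, dim=0).numpy().astype(np.float32)
    np.save(f"trialgpt_retrieval/{corpus}_embeds.npy", embeds)
    with open(f"trialgpt_retrieval/{corpus}_nctids.json", "w") as f:
        json.dump(corpus_nctids, f)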

def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1,
                        use_gpu=True, batch_size=32, retriever=None):
    """
    GPU-accelerated clinical trial retrieval with optimized batching.

    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return
        bm25_weight (int): Weight for BM25 scores
        medcpt_weight (int): Weight for MedCPT scores
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT encoding
        retriever (GPUTrialRetriever): Cached retriever instance

    Returns:
        list: List of NCT IDs for the top matching trials
    """
    if len(conditions) == 0:
        return []

    # Get the retrieval indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)
    if bm25 is None or medcpt is None:
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []

    # Initialize a retriever if the caller did not supply a cached one
    if retriever is None:
        device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
        retriever = GPUTrialRetriever(device)

    # BM25 retrieval (CPU-bound, hard to optimize further)
    bm25_condition_top_nctids = []
    if bm25_weight > 0:
        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            # Retrieve extra candidates so the fusion step has more to work with
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k * 2)
            bm25_condition_top_nctids.append(top_nctids)

    # MedCPT retrieval with GPU acceleration
    medcpt_condition_top_nctids = []
    if medcpt_weight > 0:
        try:
            retriever._load_medcpt_model()
            # Process conditions in batches for better GPU utilization
            all_embeds = []
            for i in range(0, len(conditions), batch_size):
                batch_conditions = conditions[i:i + batch_size]
                with torch.no_grad():
                    # Tokenize the batch
                    encoded = retriever.tokenizer(
                        batch_conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    )
                    # Move input tensors to the GPU
                    if retriever.device != 'cpu':
                        encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                    # [CLS] token embedding of each condition
                    batch_embeds = retriever.model(**encoded).last_hidden_state[:, 0, :]
                    # Always move to CPU before appending; CUDA tensors cannot be
                    # converted to NumPy directly, which FAISS search requires
                    all_embeds.append(batch_embeds.cpu())
            # Concatenate all embeddings and convert to NumPy
            embeds = torch.cat(all_embeds, dim=0).numpy()
            # Search the FAISS index, again over-retrieving for fusion
            scores, inds = medcpt.search(embeds, k=top_k * 2)
            # Convert row indices back to NCT IDs
            for ind_list in inds:
                top_nctids = [medcpt_nctids[ind] for ind in ind_list]
                medcpt_condition_top_nctids.append(top_nctids)
        except Exception as e:
            print(f"Warning: MedCPT retrieval failed: {e}")
            medcpt_weight = 0

    # Fuse the two ranked lists with reciprocal-rank weighting: earlier ranks
    # and earlier (more important) conditions contribute more to the score
    nctid2score = {}
    for condition_idx, condition in enumerate(conditions):
        # BM25 scoring
        if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
            for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
        # MedCPT scoring
        if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
            for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))

    # Sort by descending score and return the top_k NCT IDs
    ranked = sorted(nctid2score.items(), key=lambda x: -x[1])
    return [nctid for nctid, _ in ranked[:top_k]]

# Convenience function for backward compatibility
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Original interface with GPU acceleration."""
    return retrieve_trials_gpu(conditions, corpus, top_k, bm25_weight, medcpt_weight)

# Persistent retriever for sessions with multiple queries
def create_retriever_session(use_gpu=True):
    """Create a retriever session for multiple queries to avoid model reloading."""
    device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
    return GPUTrialRetriever(device)
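
# A minimal usage sketch. The patient conditions below are made-up examples,
# and the pre-built "sigir" indices must already exist on disk.
if __name__ == "__main__":
    session = create_retriever_session(use_gpu=True)
    conditions = ["type 2 diabetes mellitus", "chronic kidney disease"]
    # Reusing the same session across calls avoids reloading the MedCPT model
    nctids = retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, retriever=session)
    print(nctids)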