# """ # Retriever module for TrialGPT # """ # import json # import os # import numpy as np # import torch # from nltk import word_tokenize # from rank_bm25 import BM25Okapi # import faiss # from transformers import AutoTokenizer, AutoModel # def get_bm25_corpus_index(corpus="sigir"): # """Get BM25 corpus index for the specified corpus.""" # corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json") # if os.path.exists(corpus_path): # corpus_data = json.load(open(corpus_path)) # tokenized_corpus = corpus_data["tokenized_corpus"] # corpus_nctids = corpus_data["corpus_nctids"] # else: # # If the pre-built index doesn't exist, we'll need to build it # # For now, return None to indicate the index needs to be built # return None, None # bm25 = BM25Okapi(tokenized_corpus) # return bm25, corpus_nctids # def get_medcpt_corpus_index(corpus="sigir"): # """Get MedCPT corpus index for the specified corpus.""" # corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy" # nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json" # if os.path.exists(corpus_path): # embeds = np.load(corpus_path) # corpus_nctids = json.load(open(nctids_path)) # else: # # If the pre-built index doesn't exist, return None # return None, None # index = faiss.IndexFlatIP(768) # index.add(embeds) # return index, corpus_nctids # def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1): # """ # Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval. # Args: # conditions (list): List of condition strings to search for # corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir") # top_k (int): Number of top trials to return # bm25_weight (int): Weight for BM25 scores # medcpt_weight (int): Weight for MedCPT scores # Returns: # list: List of NCT IDs for the top matching trials # """ # # Get the retrieval indices # bm25, bm25_nctids = get_bm25_corpus_index(corpus) # medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus) # if bm25 is None or medcpt is None: # print(f"Warning: Pre-built indices for corpus '{corpus}' not found.") # print("Please run the hybrid_fusion_retrieval.py script first to build the indices.") # return [] # if len(conditions) == 0: # return [] # # BM25 retrieval # bm25_condition_top_nctids = [] # if bm25_weight > 0: # for condition in conditions: # tokens = word_tokenize(condition.lower()) # top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2) # Get more for fusion # bm25_condition_top_nctids.append(top_nctids) # # MedCPT retrieval # medcpt_condition_top_nctids = [] # if medcpt_weight > 0: # try: # model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder") # tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder") # with torch.no_grad(): # encoded = tokenizer( # conditions, # truncation=True, # padding=True, # return_tensors='pt', # max_length=256, # ) # # encode the queries # embeds = model(**encoded).last_hidden_state[:, 0, :].numpy() # # search the Faiss index # scores, inds = medcpt.search(embeds, k=top_k*2) # Get more for fusion # for ind_list in inds: # top_nctids = [medcpt_nctids[ind] for ind in ind_list] # medcpt_condition_top_nctids.append(top_nctids) # except Exception as e: # print(f"Warning: MedCPT retrieval failed: {e}") # medcpt_weight = 0 # # Fusion of results # nctid2score = {} # for condition_idx, condition in enumerate(conditions): # # BM25 scoring # if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids): # for rank, nctid in 
"""
GPU-accelerated Retriever module for TrialGPT
"""

import json
import os

import numpy as np
import torch
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
import faiss
from transformers import AutoTokenizer, AutoModel


class GPUTrialRetriever:
    """GPU-accelerated trial retriever with model caching."""

    def __init__(self, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self._model_loaded = False

    def _load_medcpt_model(self):
        """Load the MedCPT query encoder once and cache it."""
        if not self._model_loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            self.model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
            # Move the model to the GPU if available
            if self.device != 'cpu':
                self.model = self.model.to(self.device)
            # Set to evaluation mode for inference
            self.model.eval()
            self._model_loaded = True
            print(f"MedCPT model loaded on {self.device}")


def get_bm25_corpus_index(corpus="sigir"):
    """Get the BM25 corpus index for the specified corpus."""
    corpus_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"

    if not os.path.exists(corpus_path):
        # The pre-built index does not exist; it must be built first.
        return None, None

    with open(corpus_path) as f:
        corpus_data = json.load(f)
    tokenized_corpus = corpus_data["tokenized_corpus"]
    corpus_nctids = corpus_data["corpus_nctids"]

    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids


def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True):
    """Get a GPU-accelerated MedCPT corpus index."""
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

    if not os.path.exists(corpus_path):
        # The pre-built embeddings do not exist; they must be built first.
        return None, None

    embeds = np.load(corpus_path).astype(np.float32)  # FAISS requires float32
    with open(nctids_path) as f:
        corpus_nctids = json.load(f)

    # Use a GPU FAISS index if available and requested
    if use_gpu and torch.cuda.is_available():
        try:
            # Create GPU resources, build the index on the CPU, then move it over
            res = faiss.StandardGpuResources()
            cpu_index = faiss.IndexFlatIP(768)
            cpu_index.add(embeds)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
            return gpu_index, corpus_nctids
        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")

    # CPU index (default path, also the GPU fallback)
    index = faiss.IndexFlatIP(768)
    index.add(embeds)
    return index, corpus_nctids
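
# The {corpus}_embeds.npy and {corpus}_nctids.json artifacts loaded above are
# expected to be produced by hybrid_fusion_retrieval.py. The helper below is
# only an illustrative sketch of how such corpus embeddings could be
# precomputed; it assumes each trial has already been flattened to a single
# text string, and it uses the companion MedCPT article encoder
# (ncbi/MedCPT-Article-Encoder) rather than the query encoder used for
# retrieval. The function name and max_length are assumptions, not part of
# the original pipeline.
def build_medcpt_corpus_embeddings(trial_texts, nctids, corpus="sigir"):
    """Illustrative sketch: precompute corpus embeddings and NCT IDs for FAISS."""
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            trial_texts,
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=512,
        )
        # [CLS] token embeddings, matching the query-encoding convention below
        embeds = model(**encoded).last_hidden_state[:, 0, :].numpy()
    np.save(f"trialgpt_retrieval/{corpus}_embeds.npy", embeds.astype(np.float32))
    with open(f"trialgpt_retrieval/{corpus}_nctids.json", "w") as f:
        json.dump(nctids, f)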
def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1,
                        medcpt_weight=1, use_gpu=True, batch_size=32,
                        retriever=None):
    """
    GPU-accelerated clinical trial retrieval with optimized batching.

    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return
        bm25_weight (int): Weight for BM25 scores
        medcpt_weight (int): Weight for MedCPT scores
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT encoding
        retriever (GPUTrialRetriever): Cached retriever instance

    Returns:
        list: List of NCT IDs for the top matching trials
    """
    if len(conditions) == 0:
        return []

    # Get the retrieval indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)

    if bm25 is None or medcpt is None:
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []

    # Initialize a retriever if one was not provided
    if retriever is None:
        device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
        retriever = GPUTrialRetriever(device)

    # BM25 retrieval (CPU-bound, hard to optimize further)
    bm25_condition_top_nctids = []
    if bm25_weight > 0:
        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            # Over-retrieve so the fusion step has more candidates to rank
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k * 2)
            bm25_condition_top_nctids.append(top_nctids)

    # MedCPT retrieval with GPU acceleration
    medcpt_condition_top_nctids = []
    if medcpt_weight > 0:
        try:
            retriever._load_medcpt_model()

            # Process conditions in batches for better GPU utilization
            all_embeds = []
            for i in range(0, len(conditions), batch_size):
                batch_conditions = conditions[i:i + batch_size]
                with torch.no_grad():
                    # Tokenize the batch
                    encoded = retriever.tokenizer(
                        batch_conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    )
                    # Move the input tensors to the GPU
                    if retriever.device != 'cpu':
                        encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                    # [CLS] token embeddings for the batch
                    batch_embeds = retriever.model(**encoded).last_hidden_state[:, 0, :]
                    # Always move to CPU before appending: FAISS needs a NumPy
                    # array, which cannot be built from CUDA tensors
                    all_embeds.append(batch_embeds.cpu())

            # Concatenate all embeddings and convert to NumPy
            embeds = torch.cat(all_embeds, dim=0).numpy()

            # Search the FAISS index, again over-retrieving for fusion
            scores, inds = medcpt.search(embeds, k=top_k * 2)

            # Convert FAISS indices to NCT IDs
            for ind_list in inds:
                top_nctids = [medcpt_nctids[ind] for ind in ind_list]
                medcpt_condition_top_nctids.append(top_nctids)
        except Exception as e:
            print(f"Warning: MedCPT retrieval failed: {e}")
            medcpt_weight = 0

    # Reciprocal-rank fusion of the BM25 and MedCPT result lists: each hit
    # contributes weight * 1/(rank + 1), discounted by 1/(condition_idx + 1)
    # so that earlier (higher-priority) conditions count more.
    nctid2score = {}
    for condition_idx in range(len(conditions)):
        # BM25 scoring
        if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
            for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))

        # MedCPT scoring
        if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
            for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))

    # Sort by fused score and return the top_k NCT IDs
    ranked = sorted(nctid2score.items(), key=lambda x: -x[1])
    top_nctids = [nctid for nctid, _ in ranked[:top_k]]
    return top_nctids
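
# Worked example of the fusion arithmetic above, with illustrative numbers:
# with both weights at 1, a trial ranked 1st by BM25 (rank 0) and 3rd by
# MedCPT (rank 2) for the first condition (condition_idx 0) receives
#     1 * 1/(0 + 1) * 1/(0 + 1) + 1 * 1/(2 + 1) * 1/(0 + 1) = 1 + 1/3 ≈ 1.33,
# while the same ranks under the second condition (condition_idx 1) would
# contribute exactly half of that, since each term is discounted by 1/(1 + 1).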

# Convenience function for backward compatibility
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Original interface with GPU acceleration."""
    return retrieve_trials_gpu(conditions, corpus, top_k, bm25_weight, medcpt_weight)


# Persistent retriever session for multiple calls (see the usage sketch below)
def create_retriever_session(use_gpu=True):
    """Create a retriever session for multiple queries to avoid model reloading."""
    device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
    return GPUTrialRetriever(device)
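
# A minimal usage sketch, assuming the "sigir" index files have already been
# built by hybrid_fusion_retrieval.py. Reusing one session across queries
# avoids reloading the MedCPT query encoder on every call. The patient
# conditions below are hypothetical.
if __name__ == "__main__":
    session = create_retriever_session(use_gpu=True)
    example_queries = [
        ["type 2 diabetes", "chronic kidney disease"],
        ["non-small cell lung cancer"],
    ]
    for conditions in example_queries:
        nctids = retrieve_trials_gpu(
            conditions, corpus="sigir", top_k=5, retriever=session
        )
        print(conditions, "->", nctids)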