# """ | |
# Retriever module for TrialGPT | |
# """ | |
# import json | |
# import os | |
# import numpy as np | |
# import torch | |
# from nltk import word_tokenize | |
# from rank_bm25 import BM25Okapi | |
# import faiss | |
# from transformers import AutoTokenizer, AutoModel | |
# def get_bm25_corpus_index(corpus="sigir"): | |
# """Get BM25 corpus index for the specified corpus.""" | |
# corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json") | |
# if os.path.exists(corpus_path): | |
# corpus_data = json.load(open(corpus_path)) | |
# tokenized_corpus = corpus_data["tokenized_corpus"] | |
# corpus_nctids = corpus_data["corpus_nctids"] | |
# else: | |
# # If the pre-built index doesn't exist, we'll need to build it | |
# # For now, return None to indicate the index needs to be built | |
# return None, None | |
# bm25 = BM25Okapi(tokenized_corpus) | |
# return bm25, corpus_nctids | |
# def get_medcpt_corpus_index(corpus="sigir"): | |
# """Get MedCPT corpus index for the specified corpus.""" | |
# corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy" | |
# nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json" | |
# if os.path.exists(corpus_path): | |
# embeds = np.load(corpus_path) | |
# corpus_nctids = json.load(open(nctids_path)) | |
# else: | |
# # If the pre-built index doesn't exist, return None | |
# return None, None | |
# index = faiss.IndexFlatIP(768) | |
# index.add(embeds) | |
# return index, corpus_nctids | |
# def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1): | |
# """ | |
# Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval. | |
# Args: | |
# conditions (list): List of condition strings to search for | |
# corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir") | |
# top_k (int): Number of top trials to return | |
# bm25_weight (int): Weight for BM25 scores | |
# medcpt_weight (int): Weight for MedCPT scores | |
# Returns: | |
# list: List of NCT IDs for the top matching trials | |
# """ | |
# # Get the retrieval indices | |
# bm25, bm25_nctids = get_bm25_corpus_index(corpus) | |
# medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus) | |
# if bm25 is None or medcpt is None: | |
# print(f"Warning: Pre-built indices for corpus '{corpus}' not found.") | |
# print("Please run the hybrid_fusion_retrieval.py script first to build the indices.") | |
# return [] | |
# if len(conditions) == 0: | |
# return [] | |
# # BM25 retrieval | |
# bm25_condition_top_nctids = [] | |
# if bm25_weight > 0: | |
# for condition in conditions: | |
# tokens = word_tokenize(condition.lower()) | |
# top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2) # Get more for fusion | |
# bm25_condition_top_nctids.append(top_nctids) | |
# # MedCPT retrieval | |
# medcpt_condition_top_nctids = [] | |
# if medcpt_weight > 0: | |
# try: | |
# model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder") | |
# tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder") | |
# with torch.no_grad(): | |
# encoded = tokenizer( | |
# conditions, | |
# truncation=True, | |
# padding=True, | |
# return_tensors='pt', | |
# max_length=256, | |
# ) | |
# # encode the queries | |
# embeds = model(**encoded).last_hidden_state[:, 0, :].numpy() | |
# # search the Faiss index | |
# scores, inds = medcpt.search(embeds, k=top_k*2) # Get more for fusion | |
# for ind_list in inds: | |
# top_nctids = [medcpt_nctids[ind] for ind in ind_list] | |
# medcpt_condition_top_nctids.append(top_nctids) | |
# except Exception as e: | |
# print(f"Warning: MedCPT retrieval failed: {e}") | |
# medcpt_weight = 0 | |
# # Fusion of results | |
# nctid2score = {} | |
# for condition_idx, condition in enumerate(conditions): | |
# # BM25 scoring | |
# if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids): | |
# for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]): | |
# if nctid not in nctid2score: | |
# nctid2score[nctid] = 0 | |
# nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1)) | |
# # MedCPT scoring | |
# if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids): | |
# for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]): | |
# if nctid not in nctid2score: | |
# nctid2score[nctid] = 0 | |
# nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1)) | |
# # Sort by score and return top_k | |
# nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1]) | |
# top_nctids = [nctid for nctid, _ in nctid2score[:top_k]] | |
# return top_nctids | |
""" | |
GPU-accelerated Retriever module for TrialGPT | |
""" | |
import json | |
import os | |
import numpy as np | |
import torch | |
from nltk import word_tokenize | |
from rank_bm25 import BM25Okapi | |
import faiss | |
from transformers import AutoTokenizer, AutoModel | |
class GPUTrialRetriever:
    """GPU-accelerated trial retriever with model caching."""

    def __init__(self, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self._model_loaded = False

    def _load_medcpt_model(self):
        """Load the MedCPT query encoder once and cache it."""
        if not self._model_loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            self.model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
            # Move the model to GPU if available
            if self.device != 'cpu':
                self.model = self.model.to(self.device)
            # Evaluation mode disables dropout for deterministic inference
            self.model.eval()
            self._model_loaded = True
            print(f"MedCPT model loaded on {self.device}")

def get_bm25_corpus_index(corpus="sigir"):
    """Get the BM25 corpus index for the specified corpus."""
    corpus_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"
    if not os.path.exists(corpus_path):
        # Pre-built index is missing; callers handle the None case
        return None, None
    with open(corpus_path) as f:
        corpus_data = json.load(f)
    bm25 = BM25Okapi(corpus_data["tokenized_corpus"])
    return bm25, corpus_data["corpus_nctids"]
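
# The JSON above is normally produced offline by hybrid_fusion_retrieval.py.
# The helper below is a hypothetical sketch of that step (not part of the
# original pipeline): it assumes trial documents are available as an
# {nctid: text} mapping and writes the two keys get_bm25_corpus_index() reads.
def build_bm25_corpus_index(nctid2text, corpus="sigir"):
    """Hypothetical builder: tokenize trial texts and persist the BM25 corpus."""
    corpus_nctids = list(nctid2text.keys())
    tokenized_corpus = [word_tokenize(nctid2text[nctid].lower()) for nctid in corpus_nctids]
    os.makedirs("trialgpt_retrieval", exist_ok=True)
    with open(f"trialgpt_retrieval/bm25_corpus_{corpus}.json", "w") as f:
        json.dump({"tokenized_corpus": tokenized_corpus, "corpus_nctids": corpus_nctids}, f)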
def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True):
    """Get the MedCPT corpus index, on GPU if available and requested."""
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"
    if not os.path.exists(corpus_path):
        # Pre-built embeddings are missing; callers handle the None case
        return None, None
    embeds = np.load(corpus_path).astype(np.float32)  # FAISS requires float32
    with open(nctids_path) as f:
        corpus_nctids = json.load(f)
    # Build the CPU index first; 768 is the MedCPT embedding dimension
    index = faiss.IndexFlatIP(768)
    index.add(embeds)
    # Move the index to GPU if requested and available
    if use_gpu and torch.cuda.is_available():
        try:
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")
    return index, corpus_nctids
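
# The {corpus}_embeds.npy / {corpus}_nctids.json pair is normally produced
# offline by hybrid_fusion_retrieval.py. Below is a hypothetical sketch of
# such a step (not the actual script): it assumes trial documents are
# available as an {nctid: text} mapping and that ncbi/MedCPT-Article-Encoder
# is the document-side encoder.
def build_medcpt_corpus_index(nctid2text, corpus="sigir", batch_size=32):
    """Hypothetical builder: embed trial texts and persist the MedCPT corpus."""
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
    model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").eval()
    corpus_nctids = list(nctid2text.keys())
    all_embeds = []
    with torch.no_grad():
        for i in range(0, len(corpus_nctids), batch_size):
            batch = [nctid2text[nctid] for nctid in corpus_nctids[i:i + batch_size]]
            encoded = tokenizer(batch, truncation=True, padding=True,
                                return_tensors='pt', max_length=512)
            # [CLS] pooling, mirroring the query encoder used at search time
            all_embeds.append(model(**encoded).last_hidden_state[:, 0, :])
    embeds = torch.cat(all_embeds, dim=0).numpy().astype(np.float32)
    np.save(f"trialgpt_retrieval/{corpus}_embeds.npy", embeds)
    with open(f"trialgpt_retrieval/{corpus}_nctids.json", "w") as f:
        json.dump(corpus_nctids, f)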
def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1,
                        use_gpu=True, batch_size=32, retriever=None):
    """
    GPU-accelerated clinical trial retrieval with optimized batching.

    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return
        bm25_weight (int): Weight for BM25 scores
        medcpt_weight (int): Weight for MedCPT scores
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT query encoding
        retriever (GPUTrialRetriever): Cached retriever instance to reuse across calls

    Returns:
        list: NCT IDs of the top-k matching trials
    """
    if len(conditions) == 0:
        return []

    # Get the retrieval indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)
    if bm25 is None or medcpt is None:
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []

    # Initialize a retriever if the caller did not supply a cached one
    if retriever is None:
        device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
        retriever = GPUTrialRetriever(device)

    # BM25 retrieval (CPU-bound, hard to optimize further)
    bm25_condition_top_nctids = []
    if bm25_weight > 0:
        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            # Retrieve extra candidates so the fusion step has more to work with
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k * 2)
            bm25_condition_top_nctids.append(top_nctids)
    # MedCPT retrieval with GPU acceleration
    medcpt_condition_top_nctids = []
    if medcpt_weight > 0:
        try:
            retriever._load_medcpt_model()
            # Encode the conditions in batches for better GPU utilization
            all_embeds = []
            for i in range(0, len(conditions), batch_size):
                batch_conditions = conditions[i:i + batch_size]
                with torch.no_grad():
                    # Tokenize the batch
                    encoded = retriever.tokenizer(
                        batch_conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    )
                    # Move input tensors to the model's device
                    if retriever.device != 'cpu':
                        encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                    # Use the [CLS] token embedding as the query vector
                    batch_embeds = retriever.model(**encoded).last_hidden_state[:, 0, :]
                    # Move to CPU before collecting: FAISS expects NumPy arrays,
                    # and .numpy() fails on CUDA tensors
                    all_embeds.append(batch_embeds.cpu())
            # Concatenate all batches and convert to NumPy for FAISS
            embeds = torch.cat(all_embeds, dim=0).numpy()
            # Search the FAISS index, again fetching extra candidates for fusion
            scores, inds = medcpt.search(embeds, k=top_k * 2)
            # Map FAISS row indices back to NCT IDs
            for ind_list in inds:
                medcpt_condition_top_nctids.append([medcpt_nctids[ind] for ind in ind_list])
        except Exception as e:
            print(f"Warning: MedCPT retrieval failed: {e}")
            medcpt_weight = 0
    # Fusion (same as the original CPU version): weighted reciprocal-rank
    # scoring, where each hit contributes weight * 1/(rank+1) * 1/(condition_idx+1),
    # so higher-ranked hits and earlier-listed conditions count more.
    nctid2score = {}
    for condition_idx in range(len(conditions)):
        # BM25 scoring
        if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
            for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
        # MedCPT scoring
        if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
            for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))

    # Sort by descending score and return the top_k NCT IDs
    ranked = sorted(nctid2score.items(), key=lambda x: -x[1])
    return [nctid for nctid, _ in ranked[:top_k]]

# Convenience wrapper preserving the original interface
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Original interface, now backed by the GPU-accelerated implementation."""
    return retrieve_trials_gpu(conditions, corpus, top_k, bm25_weight, medcpt_weight)


def create_retriever_session(use_gpu=True):
    """Create a persistent retriever for multiple queries to avoid model reloading."""
    device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
    return GPUTrialRetriever(device)
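
# Minimal usage sketch, runnable as a script. The corpus name matches the
# default above; the example conditions are made up for illustration.
if __name__ == "__main__":
    session = create_retriever_session(use_gpu=True)
    patients = [
        ["type 2 diabetes", "chronic kidney disease"],  # hypothetical conditions
        ["metastatic breast cancer"],
    ]
    for conditions in patients:
        # Reusing the session avoids reloading the MedCPT encoder on each call
        nctids = retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, retriever=session)
        print(conditions, "->", nctids)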