# Salma Hassan
# files
# 50e583f
# """
# Retriever module for TrialGPT
# """
# import json
# import os
# import numpy as np
# import torch
# from nltk import word_tokenize
# from rank_bm25 import BM25Okapi
# import faiss
# from transformers import AutoTokenizer, AutoModel
# def get_bm25_corpus_index(corpus="sigir"):
# """Get BM25 corpus index for the specified corpus."""
# corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json")
# if os.path.exists(corpus_path):
# corpus_data = json.load(open(corpus_path))
# tokenized_corpus = corpus_data["tokenized_corpus"]
# corpus_nctids = corpus_data["corpus_nctids"]
# else:
# # If the pre-built index doesn't exist, we'll need to build it
# # For now, return None to indicate the index needs to be built
# return None, None
# bm25 = BM25Okapi(tokenized_corpus)
# return bm25, corpus_nctids
# def get_medcpt_corpus_index(corpus="sigir"):
# """Get MedCPT corpus index for the specified corpus."""
# corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
# nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"
# if os.path.exists(corpus_path):
# embeds = np.load(corpus_path)
# corpus_nctids = json.load(open(nctids_path))
# else:
# # If the pre-built index doesn't exist, return None
# return None, None
# index = faiss.IndexFlatIP(768)
# index.add(embeds)
# return index, corpus_nctids
# def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
# """
# Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval.
# Args:
# conditions (list): List of condition strings to search for
# corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
# top_k (int): Number of top trials to return
# bm25_weight (int): Weight for BM25 scores
# medcpt_weight (int): Weight for MedCPT scores
# Returns:
# list: List of NCT IDs for the top matching trials
# """
# # Get the retrieval indices
# bm25, bm25_nctids = get_bm25_corpus_index(corpus)
# medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)
# if bm25 is None or medcpt is None:
# print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
# print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
# return []
# if len(conditions) == 0:
# return []
# # BM25 retrieval
# bm25_condition_top_nctids = []
# if bm25_weight > 0:
# for condition in conditions:
# tokens = word_tokenize(condition.lower())
# top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2) # Get more for fusion
# bm25_condition_top_nctids.append(top_nctids)
# # MedCPT retrieval
# medcpt_condition_top_nctids = []
# if medcpt_weight > 0:
# try:
# model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
# tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
# with torch.no_grad():
# encoded = tokenizer(
# conditions,
# truncation=True,
# padding=True,
# return_tensors='pt',
# max_length=256,
# )
# # encode the queries
# embeds = model(**encoded).last_hidden_state[:, 0, :].numpy()
# # search the Faiss index
# scores, inds = medcpt.search(embeds, k=top_k*2) # Get more for fusion
# for ind_list in inds:
# top_nctids = [medcpt_nctids[ind] for ind in ind_list]
# medcpt_condition_top_nctids.append(top_nctids)
# except Exception as e:
# print(f"Warning: MedCPT retrieval failed: {e}")
# medcpt_weight = 0
# # Fusion of results
# nctid2score = {}
# for condition_idx, condition in enumerate(conditions):
# # BM25 scoring
# if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
# for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
# if nctid not in nctid2score:
# nctid2score[nctid] = 0
# nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
# # MedCPT scoring
# if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
# for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
# if nctid not in nctid2score:
# nctid2score[nctid] = 0
# nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
# # Sort by score and return top_k
# nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1])
# top_nctids = [nctid for nctid, _ in nctid2score[:top_k]]
# return top_nctids
"""
GPU-accelerated Retriever module for TrialGPT
"""
import json
import os
import numpy as np
import torch
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
import faiss
from transformers import AutoTokenizer, AutoModel
class GPUTrialRetriever:
    """Holder for a lazily loaded MedCPT query encoder that can be reused.

    The tokenizer and model are fetched the first time ``_load_medcpt_model``
    runs; every later call returns immediately, so one instance can serve
    many retrieval calls without reloading weights.
    """

    def __init__(self, device=None):
        """Record the compute device.

        Args:
            device: Device to run the encoder on. When falsy, CUDA is chosen if
                ``torch.cuda.is_available()`` reports it, otherwise ``'cpu'``.
        """
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.tokenizer = None
        self.model = None
        self._model_loaded = False

    def _load_medcpt_model(self):
        """Load the MedCPT query tokenizer and encoder once and cache them."""
        if self._model_loaded:
            return
        checkpoint = "ncbi/MedCPT-Query-Encoder"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # Weights load on the CPU; only relocate them when another device was chosen.
        if self.device != 'cpu':
            self.model = self.model.to(self.device)
        # Inference only: put the encoder in evaluation mode.
        self.model.eval()
        self._model_loaded = True
        print(f"MedCPT model loaded on {self.device}")
def get_bm25_corpus_index(corpus="sigir", base_dir="trialgpt_retrieval"):
    """Build a BM25 index from the pre-tokenized corpus file for *corpus*.

    Args:
        corpus (str): Corpus name. The file read is
            ``{base_dir}/bm25_corpus_{corpus}.json``, which must contain the
            keys ``"tokenized_corpus"`` and ``"corpus_nctids"``.
        base_dir (str): Directory holding the pre-built retrieval files.

    Returns:
        tuple: ``(BM25Okapi, list)``, the index and the NCT IDs in corpus
        order, or ``(None, None)`` when the corpus file does not exist.
    """
    corpus_path = os.path.join(base_dir, f"bm25_corpus_{corpus}.json")
    if not os.path.exists(corpus_path):
        return None, None
    # Use a context manager so the handle is closed deterministically.
    with open(corpus_path, encoding="utf-8") as f:
        corpus_data = json.load(f)
    tokenized_corpus = corpus_data["tokenized_corpus"]
    corpus_nctids = corpus_data["corpus_nctids"]
    # NOTE(review): the BM25 index is rebuilt on every call; callers issuing
    # many queries may want to cache the returned pair.
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids
def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True, base_dir="trialgpt_retrieval"):
    """Load pre-computed MedCPT embeddings for *corpus* into a FAISS index.

    Args:
        corpus (str): Corpus name. Reads ``{base_dir}/{corpus}_embeds.npy`` and
            ``{base_dir}/{corpus}_nctids.json``.
        use_gpu (bool): Try to place the index on GPU 0 when CUDA is available.
            On any failure it falls back to the CPU index.
        base_dir (str): Directory holding the pre-built retrieval files.

    Returns:
        tuple: ``(faiss index, list of NCT IDs)``, or ``(None, None)`` when
        either file is missing.

    Raises:
        ValueError: If the stored embeddings are not a 2-D array.
    """
    corpus_path = os.path.join(base_dir, f"{corpus}_embeds.npy")
    nctids_path = os.path.join(base_dir, f"{corpus}_nctids.json")
    # Both files are required. Without this check a missing ID file raised
    # FileNotFoundError instead of being reported as "index not built".
    if not (os.path.exists(corpus_path) and os.path.exists(nctids_path)):
        return None, None

    # FAISS needs contiguous float32 input.
    embeds = np.ascontiguousarray(np.load(corpus_path), dtype=np.float32)
    if embeds.ndim != 2:
        raise ValueError(
            f"Expected a 2-D embeddings array in {corpus_path}, got shape {embeds.shape}"
        )
    with open(nctids_path, encoding="utf-8") as f:
        corpus_nctids = json.load(f)

    # Size the inner-product index from the data instead of assuming 768 dims.
    cpu_index = faiss.IndexFlatIP(embeds.shape[1])
    cpu_index.add(embeds)

    if use_gpu and torch.cuda.is_available():
        try:
            res = faiss.StandardGpuResources()
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
            return gpu_index, corpus_nctids
        except Exception as e:
            # Covers CPU-only faiss builds (no StandardGpuResources) and GPU errors.
            print(f"GPU FAISS failed, falling back to CPU: {e}")
    return cpu_index, corpus_nctids
def _bm25_rankings(bm25, corpus_nctids, conditions, n):
    """Return, for each condition, its top-*n* NCT IDs under BM25."""
    return [
        bm25.get_top_n(word_tokenize(condition.lower()), corpus_nctids, n=n)
        for condition in conditions
    ]


def _medcpt_rankings(retriever, index, corpus_nctids, conditions, n, batch_size):
    """Return, for each condition, its top-*n* NCT IDs under MedCPT dense search.

    Any failure (model loading, device transfer, FAISS search) is printed as a
    warning and yields ``[]``, so the caller can continue with BM25 alone.
    """
    try:
        retriever._load_medcpt_model()
        batches = []
        with torch.no_grad():
            # Encode in batches for better GPU utilization.
            for start in range(0, len(conditions), batch_size):
                encoded = retriever.tokenizer(
                    conditions[start:start + batch_size],
                    truncation=True,
                    padding=True,
                    return_tensors='pt',
                    max_length=256,
                )
                if retriever.device != 'cpu':
                    encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                # First-token hidden state (the [CLS] position for BERT-style
                # encoders). Moved to CPU so torch.cat()/numpy() work for CUDA
                # tensors too.
                first_token = retriever.model(**encoded).last_hidden_state[:, 0, :]
                batches.append(first_token.cpu())
        embeds = torch.cat(batches, dim=0).numpy()
        _, inds = index.search(embeds, k=n)
        # FAISS pads rows with -1 when the index holds fewer than n vectors.
        # Drop those; Python's negative indexing would otherwise return the
        # last NCT ID as a bogus hit.
        return [[corpus_nctids[i] for i in row if i >= 0] for row in inds]
    except Exception as e:
        print(f"Warning: MedCPT retrieval failed: {e}")
        return []


def _fuse_rankings(sources, num_conditions, top_k):
    """Merge per-condition rankings by weighted reciprocal-rank scoring.

    ``sources`` is a list of ``(weight, rankings)`` pairs, where
    ``rankings[condition_idx]`` is an ordered list of NCT IDs. Each hit adds
    ``weight * 1/(rank+1) * 1/(condition_idx+1)``, so higher ranks and earlier
    conditions count more. Sources with a non-positive weight are ignored.
    Returns the *top_k* NCT IDs by descending score. Ties keep
    first-encountered order.
    """
    nctid2score = {}
    for condition_idx in range(num_conditions):
        for weight, rankings in sources:
            if weight > 0 and condition_idx < len(rankings):
                for rank, nctid in enumerate(rankings[condition_idx]):
                    contribution = weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
                    nctid2score[nctid] = nctid2score.get(nctid, 0) + contribution
    ranked = sorted(nctid2score.items(), key=lambda item: -item[1])
    return [nctid for nctid, _ in ranked[:top_k]]


def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1,
                        use_gpu=True, batch_size=32, retriever=None):
    """
    GPU-accelerated clinical trial retrieval with optimized batching.

    Only the indices whose weight is positive are loaded. If a required
    pre-built index is missing, a warning is printed and ``[]`` is returned.
    If MedCPT retrieval fails at runtime, the result falls back to BM25 only.

    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return; non-positive returns []
        bm25_weight (int): Weight for BM25 scores (<= 0 disables BM25)
        medcpt_weight (int): Weight for MedCPT scores (<= 0 disables MedCPT)
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT encoding
        retriever (GPUTrialRetriever): Cached retriever instance; a new one
            (which reloads the model) is created when None
    Returns:
        list: List of NCT IDs for the top matching trials
    """
    if len(conditions) == 0 or top_k <= 0:
        return []

    use_bm25 = bm25_weight > 0
    use_medcpt = medcpt_weight > 0
    # Over-fetch per retriever so fusion has more candidates to merge.
    candidates = top_k * 2

    # Load only what will be used: building BM25 or moving FAISS to the GPU is costly.
    bm25 = bm25_nctids = medcpt = medcpt_nctids = None
    if use_bm25:
        bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    if use_medcpt:
        medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)
    if (use_bm25 and bm25 is None) or (use_medcpt and medcpt is None):
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []

    bm25_rankings = []
    if use_bm25:
        # CPU-bound; BM25 has no GPU path.
        bm25_rankings = _bm25_rankings(bm25, bm25_nctids, conditions, candidates)

    medcpt_rankings = []
    if use_medcpt:
        if retriever is None:
            device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
            retriever = GPUTrialRetriever(device)
        medcpt_rankings = _medcpt_rankings(
            retriever, medcpt, medcpt_nctids, conditions, candidates, batch_size
        )

    return _fuse_rankings(
        [(bm25_weight, bm25_rankings), (medcpt_weight, medcpt_rankings)],
        len(conditions),
        top_k,
    )
# Convenience function for backward compatibility
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Backward-compatible entry point that forwards to ``retrieve_trials_gpu``.

    GPU use, batch size and retriever caching take ``retrieve_trials_gpu``'s
    defaults.
    """
    return retrieve_trials_gpu(
        conditions,
        corpus=corpus,
        top_k=top_k,
        bm25_weight=bm25_weight,
        medcpt_weight=medcpt_weight,
    )
# Factory for a persistent retriever to reuse across multiple calls (avoids reloading the model)
def create_retriever_session(use_gpu=True):
    """Return a ``GPUTrialRetriever`` to pass as ``retriever=`` across queries.

    Reusing one instance keeps the MedCPT model loaded between calls. CUDA is
    selected only when *use_gpu* is truthy and torch reports CUDA available.
    """
    want_cuda = use_gpu and torch.cuda.is_available()
    return GPUTrialRetriever('cuda' if want_cuda else 'cpu')