# """
# Retriever module for TrialGPT
# """

# import json
# import os
# import numpy as np
# import torch
# from nltk import word_tokenize
# from rank_bm25 import BM25Okapi
# import faiss
# from transformers import AutoTokenizer, AutoModel

# def get_bm25_corpus_index(corpus="sigir"):
#     """Get BM25 corpus index for the specified corpus."""
#     corpus_path = os.path.join(f"trialgpt_retrieval/bm25_corpus_{corpus}.json")

#     if os.path.exists(corpus_path):
#         corpus_data = json.load(open(corpus_path))
#         tokenized_corpus = corpus_data["tokenized_corpus"]
#         corpus_nctids = corpus_data["corpus_nctids"]
#     else:
#         # If the pre-built index doesn't exist, we'll need to build it
#         # For now, return None to indicate the index needs to be built
#         return None, None
    
#     bm25 = BM25Okapi(tokenized_corpus)
#     return bm25, corpus_nctids

# def get_medcpt_corpus_index(corpus="sigir"):
#     """Get MedCPT corpus index for the specified corpus."""
#     corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy" 
#     nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

#     if os.path.exists(corpus_path):
#         embeds = np.load(corpus_path)
#         corpus_nctids = json.load(open(nctids_path))
#     else:
#         # If the pre-built index doesn't exist, return None
#         return None, None

#     index = faiss.IndexFlatIP(768)
#     index.add(embeds)
    
#     return index, corpus_nctids

# def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
#     """
#     Retrieve clinical trials based on conditions using hybrid BM25 + MedCPT retrieval.
    
#     Args:
#         conditions (list): List of condition strings to search for
#         corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
#         top_k (int): Number of top trials to return
#         bm25_weight (int): Weight for BM25 scores
#         medcpt_weight (int): Weight for MedCPT scores
    
#     Returns:
#         list: List of NCT IDs for the top matching trials
#     """
#     # Get the retrieval indices
#     bm25, bm25_nctids = get_bm25_corpus_index(corpus)
#     medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)
    
#     if bm25 is None or medcpt is None:
#         print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
#         print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
#         return []
    
#     if len(conditions) == 0:
#         return []
    
#     # BM25 retrieval
#     bm25_condition_top_nctids = []
#     if bm25_weight > 0:
#         for condition in conditions:
#             tokens = word_tokenize(condition.lower())
#             top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2)  # Get more for fusion
#             bm25_condition_top_nctids.append(top_nctids)
    
#     # MedCPT retrieval
#     medcpt_condition_top_nctids = []
#     if medcpt_weight > 0:
#         try:
#             model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
#             tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            
#             with torch.no_grad():
#                 encoded = tokenizer(
#                     conditions, 
#                     truncation=True, 
#                     padding=True, 
#                     return_tensors='pt', 
#                     max_length=256,
#                 )

#                 # encode the queries
#                 embeds = model(**encoded).last_hidden_state[:, 0, :].numpy()

#                 # search the Faiss index
#                 scores, inds = medcpt.search(embeds, k=top_k*2)  # Get more for fusion

#             for ind_list in inds:
#                 top_nctids = [medcpt_nctids[ind] for ind in ind_list]
#                 medcpt_condition_top_nctids.append(top_nctids)
                
#         except Exception as e:
#             print(f"Warning: MedCPT retrieval failed: {e}")
#             medcpt_weight = 0
    
#     # Fusion of results
#     nctid2score = {}
    
#     for condition_idx, condition in enumerate(conditions):
#         # BM25 scoring
#         if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
#             for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
#                 if nctid not in nctid2score:
#                     nctid2score[nctid] = 0
#                 nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
        
#         # MedCPT scoring
#         if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
#             for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
#                 if nctid not in nctid2score:
#                     nctid2score[nctid] = 0
#                 nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
    
#     # Sort by score and return top_k
#     nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1])
#     top_nctids = [nctid for nctid, _ in nctid2score[:top_k]]
    
#     return top_nctids 

"""
GPU-accelerated Retriever module for TrialGPT
"""

import json
import os
import numpy as np
import torch
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
import faiss
from transformers import AutoTokenizer, AutoModel

class GPUTrialRetriever:
    """GPU-accelerated trial retriever with model caching."""
    
    def __init__(self, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self._model_loaded = False
        
    def _load_medcpt_model(self):
        """Load MedCPT model once and cache it."""
        if not self._model_loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
            self.model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
            
            # Move model to GPU if available
            if self.device != 'cpu':
                self.model = self.model.to(self.device)
            
            # Set to evaluation mode for inference
            self.model.eval()
            self._model_loaded = True
            print(f"MedCPT model loaded on {self.device}")

def get_bm25_corpus_index(corpus="sigir"):
    """Get the BM25 corpus index for the specified corpus."""
    corpus_path = os.path.join("trialgpt_retrieval", f"bm25_corpus_{corpus}.json")

    if not os.path.exists(corpus_path):
        # The pre-built index is required; signal the caller to build it first.
        return None, None

    with open(corpus_path) as f:
        corpus_data = json.load(f)
    tokenized_corpus = corpus_data["tokenized_corpus"]
    corpus_nctids = corpus_data["corpus_nctids"]

    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids
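
# Illustrative lookup (a sketch; assumes the pre-built JSON exists and that
# NLTK's "punkt" tokenizer data is installed):
#
#   bm25, nctids = get_bm25_corpus_index("sigir")
#   hits = bm25.get_top_n(word_tokenize("type 2 diabetes".lower()), nctids, n=5)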

def get_medcpt_corpus_index_gpu(corpus="sigir", use_gpu=True):
    """Get a FAISS index over the MedCPT corpus embeddings, on GPU if possible."""
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

    if not os.path.exists(corpus_path):
        return None, None

    embeds = np.load(corpus_path).astype(np.float32)  # FAISS requires float32
    with open(nctids_path) as f:
        corpus_nctids = json.load(f)

    dim = embeds.shape[1]  # 768 for MedCPT embeddings

    # Use a GPU FAISS index if available and requested
    if use_gpu and torch.cuda.is_available():
        try:
            # Build the index on CPU first, then move it to GPU 0
            res = faiss.StandardGpuResources()
            cpu_index = faiss.IndexFlatIP(dim)
            cpu_index.add(embeds)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
            print(f"FAISS index moved to GPU with {len(corpus_nctids)} embeddings")
            return gpu_index, corpus_nctids
        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")

    # CPU index (GPU unavailable, not requested, or the GPU path failed)
    index = faiss.IndexFlatIP(dim)
    index.add(embeds)
    return index, corpus_nctids
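
# Illustrative query against the returned index (a sketch; the random vector
# below is a stand-in for a real MedCPT query embedding):
#
#   index, nctids = get_medcpt_corpus_index_gpu("sigir")
#   query = np.random.rand(1, 768).astype(np.float32)
#   scores, inds = index.search(query, k=5)
#   top5 = [nctids[i] for i in inds[0]]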

def retrieve_trials_gpu(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1, 
                       use_gpu=True, batch_size=32, retriever=None):
    """
    GPU-accelerated clinical trial retrieval with optimized batching.
    
    Args:
        conditions (list): List of condition strings to search for
        corpus (str): Corpus to search in ("trec_2021", "trec_2022", "sigir")
        top_k (int): Number of top trials to return
        bm25_weight (int): Weight for BM25 scores
        medcpt_weight (int): Weight for MedCPT scores
        use_gpu (bool): Whether to use GPU acceleration
        batch_size (int): Batch size for MedCPT encoding
        retriever (GPUTrialRetriever): Cached retriever instance
    
    Returns:
        list: List of NCT IDs for the top matching trials
    """
    if len(conditions) == 0:
        return []
    
    # Get the retrieval indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index_gpu(corpus, use_gpu)
    
    if bm25 is None or medcpt is None:
        print(f"Warning: Pre-built indices for corpus '{corpus}' not found.")
        print("Please run the hybrid_fusion_retrieval.py script first to build the indices.")
        return []
    
    # Initialize retriever if not provided
    if retriever is None:
        device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
        retriever = GPUTrialRetriever(device)
    
    # BM25 retrieval (CPU-bound, hard to optimize further)
    bm25_condition_top_nctids = []
    if bm25_weight > 0:
        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=top_k*2)
            bm25_condition_top_nctids.append(top_nctids)
    
    # MedCPT retrieval with GPU acceleration
    medcpt_condition_top_nctids = []
    if medcpt_weight > 0:
        try:
            retriever._load_medcpt_model()
            
            # Process conditions in batches for better GPU utilization
            all_embeds = []
            
            for i in range(0, len(conditions), batch_size):
                batch_conditions = conditions[i:i + batch_size]
                
                with torch.no_grad():
                    # Tokenize batch
                    encoded = retriever.tokenizer(
                        batch_conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    )
                    
                    # Move tensors to GPU
                    if retriever.device != 'cpu':
                        encoded = {k: v.to(retriever.device) for k, v in encoded.items()}
                    
                    # Get embeddings
                    batch_embeds = retriever.model(**encoded).last_hidden_state[:, 0, :]
                    
                    # Move to CPU before collecting; torch.Tensor.numpy()
                    # requires host-memory tensors
                    batch_embeds = batch_embeds.cpu()
                    all_embeds.append(batch_embeds)
            
            # Concatenate all embeddings and convert to numpy
            embeds = torch.cat(all_embeds, dim=0).numpy()
            
            # Search the FAISS index
            scores, inds = medcpt.search(embeds, k=top_k*2)
            
            # Convert indices to NCT IDs
            for ind_list in inds:
                top_nctids = [medcpt_nctids[ind] for ind in ind_list]
                medcpt_condition_top_nctids.append(top_nctids)
                
        except Exception as e:
            print(f"Warning: MedCPT retrieval failed: {e}")
            medcpt_weight = 0
    
    # Fusion of results (weighted reciprocal-rank fusion)
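    # Each retrieved trial accumulates
    #     weight * 1/(rank + 1) * 1/(condition_idx + 1)
    # per retriever, so higher-ranked hits and earlier-listed conditions count
    # more. For example, the rank-0 BM25 hit for the first condition adds
    # bm25_weight * 1.0, while the rank-1 hit for the second condition adds
    # bm25_weight * 0.25.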
    nctid2score = {}
    
    for condition_idx in range(len(conditions)):
        # BM25 scoring
        if bm25_weight > 0 and condition_idx < len(bm25_condition_top_nctids):
            for rank, nctid in enumerate(bm25_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += bm25_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
        
        # MedCPT scoring
        if medcpt_weight > 0 and condition_idx < len(medcpt_condition_top_nctids):
            for rank, nctid in enumerate(medcpt_condition_top_nctids[condition_idx]):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += medcpt_weight * (1 / (rank + 1)) * (1 / (condition_idx + 1))
    
    # Sort by score and return top_k
    ranked = sorted(nctid2score.items(), key=lambda x: -x[1])
    top_nctids = [nctid for nctid, _ in ranked[:top_k]]
    
    return top_nctids

# Convenience function for backward compatibility
def retrieve_trials(conditions, corpus="sigir", top_k=5, bm25_weight=1, medcpt_weight=1):
    """Original interface with GPU acceleration."""
    return retrieve_trials_gpu(conditions, corpus, top_k, bm25_weight, medcpt_weight)

# Example usage with persistent retriever for multiple calls
def create_retriever_session(use_gpu=True):
    """Create a retriever session for multiple queries to avoid model reloading."""
    device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
    return GPUTrialRetriever(device)
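
# Illustrative end-to-end usage (a minimal sketch, assuming the pre-built
# indices exist under trialgpt_retrieval/ and NLTK's "punkt" data is
# installed; the condition strings below are hypothetical examples):
if __name__ == "__main__":
    # Reuse one retriever session so the MedCPT model is loaded only once.
    session = create_retriever_session(use_gpu=True)
    conditions = ["type 2 diabetes", "chronic kidney disease"]
    nctids = retrieve_trials_gpu(
        conditions, corpus="sigir", top_k=5, retriever=session
    )
    print(nctids)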