"""
Conduct the first stage retrieval by the hybrid retriever
"""
__author__ = "qiao"

from beir.datasets.data_loader import GenericDataLoader
import faiss
import json
from nltk import word_tokenize
import numpy as np
import os
from rank_bm25 import BM25Okapi
import sys
import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from beir import util, LoggingHandler

# Device detection - use CUDA if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _load_json(path):
    """Read and parse the JSON file at *path*, closing the handle afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_bm25_corpus_index(corpus):
    """Build (or load from cache) the BM25 index for a corpus.

    The tokenized corpus is cached in
    ``trialgpt_retrieval/bm25_corpus_{corpus}.json``; when that file is
    missing it is built from ``dataset/{corpus}/corpus.jsonl`` and written
    back to the cache.

    Args:
        corpus: Corpus name used to build the dataset and cache paths.

    Returns:
        A ``(bm25, corpus_nctids)`` tuple: the ``BM25Okapi`` index and the
        list of document ids, in the same order as the indexed documents.
    """
    cache_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"

    # if already cached then load, otherwise build
    if os.path.exists(cache_path):
        corpus_data = _load_json(cache_path)
        tokenized_corpus = corpus_data["tokenized_corpus"]
        corpus_nctids = corpus_data["corpus_nctids"]
    else:
        tokenized_corpus = []
        corpus_nctids = []

        with open(f"dataset/{corpus}/corpus.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                # weighting: 3 * title, 2 * condition, 1 * text
                # (implemented by repeating the token lists)
                tokens = word_tokenize(entry["title"].lower()) * 3
                for disease in entry["metadata"]["diseases_list"]:
                    tokens += word_tokenize(disease.lower()) * 2
                tokens += word_tokenize(entry["text"].lower())

                tokenized_corpus.append(tokens)

        corpus_data = {
            "tokenized_corpus": tokenized_corpus,
            "corpus_nctids": corpus_nctids,
        }

        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(corpus_data, f, indent=4)

    bm25 = BM25Okapi(tokenized_corpus)

    return bm25, corpus_nctids


def get_medcpt_corpus_index(corpus):
    """Build (or load from cache) the MedCPT dense index for a corpus.

    Embeddings are cached in ``trialgpt_retrieval/{corpus}_embeds.npy`` and
    the matching ids in ``trialgpt_retrieval/{corpus}_nctids.json``. Both
    files must be present for the cache to be used; otherwise the corpus is
    re-encoded with ``ncbi/MedCPT-Article-Encoder`` and both are rewritten.

    Args:
        corpus: Corpus name used to build the dataset and cache paths.

    Returns:
        An ``(index, corpus_nctids)`` tuple: a ``faiss.IndexFlatIP`` holding
        the embeddings and the list of document ids in index order.
    """
    embeds_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

    # Only trust the cache when both halves exist; an interrupted run can
    # leave the embeddings on disk without the id list.
    if os.path.exists(embeds_path) and os.path.exists(nctids_path):
        embeds = np.load(embeds_path)
        corpus_nctids = _load_json(nctids_path)
    else:
        embeds = []
        corpus_nctids = []

        model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").to(device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

        with open(f"dataset/{corpus}/corpus.jsonl", "r", encoding="utf-8") as f:
            print("Encoding the corpus")
            # readlines() so that tqdm knows the total and can show progress
            for line in tqdm.tqdm(f.readlines()):
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                title = entry["title"]
                text = entry["text"]

                with torch.no_grad():
                    # tokenize the articles
                    encoded = tokenizer(
                        [[title, text]],
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=512,
                    ).to(device)

                    # use the first-token last hidden state as the embedding
                    embed = model(**encoded).last_hidden_state[:, 0, :]

                    embeds.append(embed[0].cpu().numpy())

        embeds = np.array(embeds)

        np.save(embeds_path, embeds)
        with open(nctids_path, "w", encoding="utf-8") as f:
            json.dump(corpus_nctids, f, indent=4)

    # Take the dimensionality from the data rather than hard-coding it.
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(embeds)

    return index, corpus_nctids


def _get_conditions(id2queries, qid, q_type):
    """Return the list of query strings for *qid* under the given *q_type*.

    Raises:
        ValueError: If *q_type* matches none of the supported query types.
    """
    if q_type in ["raw", "human_summary"]:
        return [id2queries[qid][q_type]]
    if "turbo" in q_type:
        return id2queries[qid][q_type]["conditions"]
    if "Clinician" in q_type:
        return id2queries[qid].get(q_type, [])
    raise ValueError(f"Unsupported query type: {q_type!r}")


def _fuse_rankings(bm25_rankings, medcpt_rankings, k, use_bm25, use_medcpt):
    """Fuse per-condition rankings into one ``{nctid: score}`` dict.

    Each appearance of a document contributes
    ``1 / (rank + k) * 1 / (condition_idx + 1)``, where ``rank`` is its
    0-based position in the ranking and ``condition_idx`` the 0-based
    position of the condition, so earlier conditions weigh more.
    """
    nctid2score = {}

    for condition_idx, (bm25_top, medcpt_top) in enumerate(
        zip(bm25_rankings, medcpt_rankings)
    ):
        condition_wt = 1 / (condition_idx + 1)

        if use_bm25:
            for rank, nctid in enumerate(bm25_top):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += (1 / (rank + k)) * condition_wt

        if use_medcpt:
            for rank, nctid in enumerate(medcpt_top):
                if nctid not in nctid2score:
                    nctid2score[nctid] = 0
                nctid2score[nctid] += (1 / (rank + k)) * condition_wt

    return nctid2score


if __name__ == "__main__":
    if len(sys.argv) < 6:
        sys.exit(
            f"usage: {sys.argv[0]} <corpus> <q_type> <k> <bm25_wt> <medcpt_wt>"
        )

    # different corpora, "trec_2021", "trec_2022", "sigir"
    corpus = sys.argv[1]

    # query type
    q_type = sys.argv[2]

    # different k for fusion
    k = int(sys.argv[3])
    if k <= 0:
        # rank starts at 0, so a non-positive k would divide by zero
        sys.exit("k must be a positive integer")

    # bm25 weight (only used as an on/off switch: > 0 enables BM25)
    bm25_wt = int(sys.argv[4])

    # medcpt weight (only used as an on/off switch: > 0 enables MedCPT)
    medcpt_wt = int(sys.argv[5])

    # how many to rank
    N = 2000

    # loading the qrels
    _, _, qrels = GenericDataLoader(data_folder=f"dataset/{corpus}/").load(split="test")

    # loading all types of queries
    id2queries = _load_json(f"dataset/{corpus}/id2queries.json")

    # loading the indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)

    # loading the query encoder for MedCPT
    model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(device)
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

    # then conduct the searches, saving the top N per query
    output_path = f"results/qid2nctids_results_{q_type}_{corpus}_k{k}_bm25wt{bm25_wt}_medcptwt{medcpt_wt}_N{N}.json"
    # make sure the output directory exists before doing the expensive work
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    qid2nctids = {}
    # per-query recall of the fused top N; collected but not reported here
    recalls = []

    with open(f"dataset/{corpus}/queries.jsonl", "r", encoding="utf-8") as f:
        for line in tqdm.tqdm(f.readlines()):
            entry = json.loads(line)
            query = entry["text"]
            qid = entry["_id"]

            if qid not in qrels:
                continue

            truth_sum = sum(qrels[qid].values())

            # get the keyword list
            conditions = _get_conditions(id2queries, qid, q_type)

            if len(conditions) == 0:
                nctid2score = {}
            else:
                # a list of nctid lists for the bm25 retriever
                bm25_condition_top_nctids = []

                for condition in conditions:
                    tokens = word_tokenize(condition.lower())
                    top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=N)
                    bm25_condition_top_nctids.append(top_nctids)

                # doing MedCPT retrieval
                with torch.no_grad():
                    encoded = tokenizer(
                        conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    ).to(device)

                    # encode the queries (use the [CLS] last hidden states as the representations)
                    embeds = model(**encoded).last_hidden_state[:, 0, :].cpu().numpy()

                    # search the Faiss index
                    scores, inds = medcpt.search(embeds, k=N)

                medcpt_condition_top_nctids = []
                for ind_list in inds:
                    # Faiss pads with -1 when fewer than N vectors are
                    # indexed; skip those so they do not alias the last id.
                    top_nctids = [medcpt_nctids[ind] for ind in ind_list if ind >= 0]
                    medcpt_condition_top_nctids.append(top_nctids)

                nctid2score = _fuse_rankings(
                    bm25_condition_top_nctids,
                    medcpt_condition_top_nctids,
                    k,
                    use_bm25=bm25_wt > 0,
                    use_medcpt=medcpt_wt > 0,
                )

            nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1])
            top_nctids = [nctid for nctid, _ in nctid2score[:N]]
            qid2nctids[qid] = top_nctids

            actual_sum = sum([qrels[qid].get(nctid, 0) for nctid in top_nctids])
            # recall is undefined when the query has no relevant documents
            if truth_sum > 0:
                recalls.append(actual_sum / truth_sum)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(qid2nctids, f, indent=4)