Spaces:
Sleeping
Sleeping
__author__ = "qiao" | |
""" | |
Conduct the first stage retrieval by the hybrid retriever | |
""" | |
from beir.datasets.data_loader import GenericDataLoader | |
import faiss | |
import json | |
from nltk import word_tokenize | |
import numpy as np | |
import os | |
from rank_bm25 import BM25Okapi | |
import sys | |
import tqdm | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
from beir import util, LoggingHandler | |
# Select the compute device once, when the module is imported:
# a CUDA GPU if torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
def get_bm25_corpus_index(corpus, data_dir="dataset", cache_dir="trialgpt_retrieval"):
    """Build (or load from cache) a BM25 index over a clinical-trial corpus.

    Each document is tokenized with field weighting by repetition: the
    lower-cased title tokens appear 3 times, the tokens of every entry of
    ``metadata.diseases_list`` twice, and the free-text tokens once.

    Args:
        corpus: Corpus name; documents are read from
            ``{data_dir}/{corpus}/corpus.jsonl`` (one JSON object per line
            with ``_id``, ``title``, ``text`` and ``metadata.diseases_list``).
        data_dir: Root directory holding the corpora.
        cache_dir: Directory for the tokenized-corpus cache
            ``bm25_corpus_{corpus}.json``; created if missing. The cache is
            keyed only by corpus name, so delete it after changing the data.

    Returns:
        ``(bm25, corpus_nctids)``: a ``BM25Okapi`` index and the list of
        document ids, where ``corpus_nctids[i]`` is the i-th indexed document.

    Raises:
        ValueError: if the corpus file contains no documents.
    """
    corpus_path = os.path.join(cache_dir, f"bm25_corpus_{corpus}.json")

    # if already cached then load, otherwise build
    if os.path.exists(corpus_path):
        with open(corpus_path, "r", encoding="utf-8") as f:
            corpus_data = json.load(f)
        tokenized_corpus = corpus_data["tokenized_corpus"]
        corpus_nctids = corpus_data["corpus_nctids"]
    else:
        source_path = os.path.join(data_dir, corpus, "corpus.jsonl")
        tokenized_corpus = []
        corpus_nctids = []
        with open(source_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():  # tolerate blank lines, e.g. a trailing newline
                    continue
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                # weighting by repetition: 3 * title, 2 * each condition, 1 * text
                tokens = word_tokenize(entry["title"].lower()) * 3
                for disease in entry["metadata"]["diseases_list"]:
                    tokens += word_tokenize(disease.lower()) * 2
                tokens += word_tokenize(entry["text"].lower())
                tokenized_corpus.append(tokens)

        # BM25Okapi cannot be built from zero documents; fail clearly and
        # before an empty cache is written that later runs would trust.
        if not tokenized_corpus:
            raise ValueError(f"No documents found in {source_path}")

        corpus_data = {
            "tokenized_corpus": tokenized_corpus,
            "corpus_nctids": corpus_nctids,
        }
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Write to a temporary file and rename, so an interrupted run never
        # leaves a truncated cache behind.
        tmp_path = corpus_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(corpus_data, f, indent=4)
        os.replace(tmp_path, corpus_path)

    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids
def get_medcpt_corpus_index(corpus, data_dir="dataset", cache_dir="trialgpt_retrieval"):
    """Build (or load from cache) a Faiss inner-product index of MedCPT article embeddings.

    Each document is encoded as a ``[title, text]`` pair with the
    ``ncbi/MedCPT-Article-Encoder`` model (truncated to 512 tokens). The
    ``[CLS]`` (position 0) last hidden state is used as its embedding.

    Args:
        corpus: Corpus name; documents are read from
            ``{data_dir}/{corpus}/corpus.jsonl`` (one JSON object per line
            with ``_id``, ``title`` and ``text``).
        data_dir: Root directory holding the corpora.
        cache_dir: Directory holding the caches ``{corpus}_embeds.npy`` and
            ``{corpus}_nctids.json``; created if missing. The cache is keyed
            only by corpus name, so delete it after changing the data.

    Returns:
        ``(index, corpus_nctids)``: a ``faiss.IndexFlatIP`` whose i-th vector
        is the embedding of document ``corpus_nctids[i]``.

    Raises:
        ValueError: if the corpus file contains no documents.
    """
    corpus_path = os.path.join(cache_dir, f"{corpus}_embeds.npy")
    nctids_path = os.path.join(cache_dir, f"{corpus}_nctids.json")

    # Only trust the cache when BOTH files are present; otherwise rebuild.
    if os.path.exists(corpus_path) and os.path.exists(nctids_path):
        embeds = np.load(corpus_path)
        with open(nctids_path, "r", encoding="utf-8") as f:
            corpus_nctids = json.load(f)
    else:
        source_path = os.path.join(data_dir, corpus, "corpus.jsonl")
        embeds = []
        corpus_nctids = []
        model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").to(device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

        with open(source_path, "r", encoding="utf-8") as f:
            # materialize the lines so tqdm can show a total
            lines = [line for line in f if line.strip()]

        print("Encoding the corpus")
        with torch.no_grad():
            for line in tqdm.tqdm(lines):
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])
                # tokenize the article as a (title, text) pair
                encoded = tokenizer(
                    [[entry["title"], entry["text"]]],
                    truncation=True,
                    padding=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(device)
                # the [CLS] last hidden state represents the article
                embed = model(**encoded).last_hidden_state[:, 0, :]
                embeds.append(embed[0].cpu().numpy())

        if not embeds:
            raise ValueError(f"No documents found in {source_path}")
        embeds = np.array(embeds)

        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Write to temporary files and rename, so an interrupted run never
        # leaves a truncated cache behind. Saving through a file object keeps
        # np.save from appending another ".npy" suffix to the tmp name.
        with open(corpus_path + ".tmp", "wb") as f:
            np.save(f, embeds)
        os.replace(corpus_path + ".tmp", corpus_path)
        with open(nctids_path + ".tmp", "w", encoding="utf-8") as f:
            json.dump(corpus_nctids, f, indent=4)
        os.replace(nctids_path + ".tmp", nctids_path)

    # Faiss expects a C-contiguous float32 matrix of shape (n_docs, dim); take
    # the dimension from the data rather than assuming the encoder's width.
    embeds = np.ascontiguousarray(embeds, dtype=np.float32)
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(embeds)
    return index, corpus_nctids
def select_conditions(queries, q_type):
    """Return the list of query strings of type ``q_type`` for one patient.

    Args:
        queries: The per-patient entry of ``id2queries.json`` (a dict keyed
            by query type).
        q_type: ``"raw"`` or ``"human_summary"`` (the stored string is wrapped
            in a one-element list); a type containing ``"turbo"`` (its
            ``"conditions"`` list is used); or a type containing
            ``"Clinician"`` (the stored list, or ``[]`` when absent).

    Raises:
        ValueError: for any other ``q_type``.
    """
    if q_type in ("raw", "human_summary"):
        return [queries[q_type]]
    if "turbo" in q_type:
        return queries[q_type]["conditions"]
    if "Clinician" in q_type:
        return queries.get(q_type, [])
    raise ValueError(f"Unsupported query type: {q_type!r}")


def bm25_search(conditions, bm25, nctids, n):
    """Return, for each condition, the top-``n`` trial ids ranked by BM25.

    Conditions are lower-cased and tokenized with NLTK ``word_tokenize``.
    """
    return [
        bm25.get_top_n(word_tokenize(condition.lower()), nctids, n=n)
        for condition in conditions
    ]


def medcpt_search(conditions, model, tokenizer, index, nctids, n):
    """Return, for each condition, the top-``n`` trial ids by MedCPT dense retrieval.

    All conditions are encoded in one batch (truncated to 256 tokens). The
    ``[CLS]`` last hidden state is searched in the Faiss ``index``, whose
    i-th vector belongs to ``nctids[i]``. Faiss pads results with label -1
    when the index holds fewer than ``n`` vectors; those labels are dropped.
    Indexing ``nctids[-1]`` would otherwise silently return the last trial.
    """
    with torch.no_grad():
        encoded = tokenizer(
            conditions,
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=256,
        ).to(device)
        # encode the queries (use the [CLS] last hidden states as the representations)
        embeds = model(**encoded).last_hidden_state[:, 0, :].cpu().numpy()
    _, inds = index.search(np.ascontiguousarray(embeds, dtype=np.float32), k=n)
    return [[nctids[ind] for ind in ind_list if ind >= 0] for ind_list in inds]


def fuse_ranked_lists(bm25_lists, medcpt_lists, k, bm25_wt=1, medcpt_wt=1):
    """Fuse per-condition BM25 and MedCPT rankings by weighted reciprocal rank fusion.

    A trial at 0-based ``rank`` in a retriever's list for the ``i``-th
    condition gains ``weight * 1 / (rank + k) * 1 / (i + 1)``. Contributions
    are summed over conditions and retrievers. A weight <= 0 disables that
    retriever entirely.

    Args:
        bm25_lists: One ranked list of trial ids per condition (BM25).
        medcpt_lists: One ranked list of trial ids per condition (MedCPT).
        k: The RRF smoothing constant; must be positive.
        bm25_wt: Multiplier for BM25 contributions.
        medcpt_wt: Multiplier for MedCPT contributions.

    Returns:
        ``(nctid, score)`` pairs sorted by descending score; ties keep the
        order in which the trials were first encountered.

    Raises:
        ValueError: if ``k`` is not positive (``1 / (rank + k)`` would divide
            by zero).
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    nctid2score = {}
    for condition_idx, (bm25_top, medcpt_top) in enumerate(zip(bm25_lists, medcpt_lists)):
        # contributions from the i-th condition are down-weighted by 1 / (i + 1)
        condition_wt = 1 / (condition_idx + 1)
        for weight, ranked in ((bm25_wt, bm25_top), (medcpt_wt, medcpt_top)):
            if weight <= 0:
                continue
            for rank, nctid in enumerate(ranked):
                nctid2score[nctid] = (
                    nctid2score.get(nctid, 0) + weight * (1 / (rank + k)) * condition_wt
                )
    return sorted(nctid2score.items(), key=lambda x: -x[1])


def weighted_recall(retrieved, judgments):
    """Return the share of the total relevance grade in ``judgments`` covered by ``retrieved``.

    ``judgments`` maps trial id -> relevance grade (the qrels of one query).
    Returns ``None`` when the grades sum to zero, where recall is undefined.
    """
    truth_sum = sum(judgments.values())
    if truth_sum == 0:
        return None
    return sum(judgments.get(nctid, 0) for nctid in retrieved) / truth_sum


if __name__ == "__main__":
    if len(sys.argv) < 6:
        sys.exit(
            f"usage: {sys.argv[0]} CORPUS Q_TYPE K BM25_WT MEDCPT_WT\n"
            "  CORPUS is one of trec_2021, trec_2022, sigir"
        )
    # different corpora, "trec_2021", "trec_2022", "sigir"
    corpus = sys.argv[1]
    # query type: "raw", "human_summary", or a "*turbo*" / "*Clinician*" type
    q_type = sys.argv[2]
    # the k constant of reciprocal rank fusion
    k = int(sys.argv[3])
    # retriever weights (RRF multipliers); a weight <= 0 disables that retriever
    bm25_wt = int(sys.argv[4])
    medcpt_wt = int(sys.argv[5])
    # how many to rank
    N = 2000

    # loading the qrels
    _, _, qrels = GenericDataLoader(data_folder=f"dataset/{corpus}/").load(split="test")

    # loading all types of queries
    with open(f"dataset/{corpus}/id2queries.json", "r", encoding="utf-8") as f:
        id2queries = json.load(f)

    # loading only the indices / encoder that will contribute to the fusion
    if bm25_wt > 0:
        bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    if medcpt_wt > 0:
        medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)
        # loading the query encoder for MedCPT
        model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

    # then conduct the searches, saving the top N per query; create the output
    # directory up front so a long run does not fail only at the very end
    output_path = (
        f"results/qid2nctids_results_{q_type}_{corpus}_k{k}"
        f"_bm25wt{bm25_wt}_medcptwt{medcpt_wt}_N{N}.json"
    )
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    qid2nctids = {}
    recalls = []

    with open(f"dataset/{corpus}/queries.jsonl", "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm.tqdm(lines):
        entry = json.loads(line)
        qid = entry["_id"]
        # only queries with relevance judgments are retrieved and evaluated
        if qid not in qrels:
            continue

        conditions = select_conditions(id2queries[qid], q_type)
        if conditions:
            no_hits = [[] for _ in conditions]
            bm25_lists = (
                bm25_search(conditions, bm25, bm25_nctids, N) if bm25_wt > 0 else no_hits
            )
            medcpt_lists = (
                medcpt_search(conditions, model, tokenizer, medcpt, medcpt_nctids, N)
                if medcpt_wt > 0
                else no_hits
            )
            ranked = fuse_ranked_lists(bm25_lists, medcpt_lists, k, bm25_wt, medcpt_wt)
        else:
            ranked = []

        top_nctids = [nctid for nctid, _ in ranked[:N]]
        qid2nctids[qid] = top_nctids

        recall = weighted_recall(top_nctids, qrels[qid])
        if recall is not None:
            recalls.append(recall)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(qid2nctids, f, indent=4)

    if recalls:
        print(
            f"Mean weighted recall@{N} over {len(recalls)} queries: "
            f"{sum(recalls) / len(recalls):.4f}"
        )