# ClinicalTrial/trialgpt_ranking/run_aggregation.py
__author__ = "qiao"
"""
Using GPT to aggregate the scores by itself.
"""
import json
import os
import sys
import time

from beir.datasets.data_loader import GenericDataLoader
from nltk.tokenize import sent_tokenize

from TrialGPT import trialgpt_aggregation
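# Assumption: trialgpt_aggregation(patient, trial_results, trial_info, model) performs the
# LLM call for a single patient-trial pair and returns a JSON-serializable aggregation
# result; see the TrialGPT module for the actual prompt construction and parsing.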
if __name__ == "__main__":
    corpus = sys.argv[1]
    model = sys.argv[2]

    # the path of the matching results
    matching_results_path = sys.argv[3]
    results = json.load(open(matching_results_path))
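    # Assumed structure of the matching results, mirroring the nested loops below:
    #   results[patient_id][label][trial_id] -> criterion-level matching output (a dict)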
    # loading the trial2info dict
    trial2info = json.load(open("dataset/trial_info.json"))

    # loading the patient info
    _, queries, _ = GenericDataLoader(data_folder=f"dataset/{corpus}/").load(split="test")
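    # GenericDataLoader expects the standard BEIR layout under dataset/{corpus}/
    # (corpus.jsonl, queries.jsonl, qrels/test.tsv); queries maps patient_id -> note text.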
    # output file path
    output_path = f"results/aggregation_results_{corpus}_{model}.json"

    if os.path.exists(output_path):
        output = json.load(open(output_path))
    else:
        output = {}
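    # Reloading an existing output file makes the run resumable: patient-trial pairs that
    # were already aggregated are skipped by the cache check below.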
    # patient-level
    for patient_id, info in results.items():
        # get the patient note
        patient = queries[patient_id]
        sents = sent_tokenize(patient)
        sents.append("The patient will provide informed consent, and will comply with the trial protocol without any practical issues.")
        sents = [f"{idx}. {sent}" for idx, sent in enumerate(sents)]
        patient = "\n".join(sents)
        if patient_id not in output:
            output[patient_id] = {}
        # label-level, 3 labels per patient
        for label, trials in info.items():

            # trial-level
            for trial_id, trial_results in trials.items():
                # skip already cached results
                if trial_id in output[patient_id]:
                    continue
                if type(trial_results) is not dict:
                    output[patient_id][trial_id] = "matching result error"

                    with open(output_path, "w") as f:
                        json.dump(output, f, indent=4)

                    continue
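                # Malformed matching results are stored as an error string, so they count
                # as cached and are not retried on later runs.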
                # specific trial information
                trial_info = trial2info[trial_id]

                try:
                    result = trialgpt_aggregation(patient, trial_results, trial_info, model)
                    output[patient_id][trial_id] = result

                    with open(output_path, "w") as f:
                        json.dump(output, f, indent=4)
                except Exception:
                    continue
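                # Failed aggregation calls are skipped without being cached, so they are
                # retried automatically the next time the script is run.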