ClinicalTrial / trialgpt_retrieval /keyword_generation.py
Salma Hassan
files
50e583f
__author__ = "qiao"
"""
generate the search keywords for each patient
"""
import json
import os
import time
from openai import AzureOpenAI
import sys
# Module-level Azure OpenAI client shared by every request this script makes.
# The endpoint and API key are read from the OPENAI_ENDPOINT and
# OPENAI_API_KEY environment variables when the module is imported.
client = AzureOpenAI(
    azure_endpoint=os.environ.get("OPENAI_ENDPOINT"),
    api_key=os.environ.get("OPENAI_API_KEY"),
    api_version="2023-09-01-preview",
)
def get_keyword_generation_messages(note):
    """Build the chat messages that ask the model for trial-search keywords.

    Args:
        note: Patient description; it is formatted into the user turn as-is.

    Returns:
        list[dict]: ``[system_message, user_message]``, each a dict with
        ``"role"`` and ``"content"`` keys.
    """
    # NOTE: this is a plain (non-f) string, so the doubled braces are sent to
    # the model literally as "{{" and "}}". Kept verbatim to preserve the prompt.
    instructions = 'You are a helpful assistant and your task is to help search relevant clinical trials for a given patient description. Please first summarize the main medical problems of the patient. Then generate up to 32 key conditions for searching relevant clinical trials for this patient. The key condition list should be ranked by priority. Please output only a JSON dict formatted as Dict{{"summary": Str(summary), "conditions": List[Str(condition)]}}.'
    user_turn = "Here is the patient description: \n{}\n\nJSON output:".format(note)
    return [
        dict(role="system", content=instructions),
        dict(role="user", content=user_turn),
    ]
def is_valid_json(json_str):
    """Report whether ``json_str`` can be decoded by ``json.loads``.

    Decoding errors and ``TypeError`` (e.g. for ``None`` or other
    non-string inputs) yield ``False`` instead of propagating.

    Args:
        json_str: Candidate JSON text.

    Returns:
        bool: ``True`` if decoding succeeds, ``False`` otherwise.
    """
    try:
        json.loads(json_str)
    except (TypeError, json.JSONDecodeError):
        return False
    else:
        return True
def fix_truncated_json(json_str):
    """Try to repair JSON text that was cut off mid-stream (e.g. by a token limit).

    The text is scanned once. The scan tracks whether it is inside a string
    (honouring backslash escapes) and which ``[`` / ``{`` containers are still
    open. The repair then does three things:

    * closes an unterminated string, dropping a dangling escape backslash;
    * drops a trailing comma left after the last complete element;
    * appends the closing ``]`` / ``}`` for every still-open container, in
      reverse order of opening.

    Truncation in other positions is not repaired, for example right after a
    key or ``:``, or inside a number/literal. Callers should still validate
    the result, e.g. with ``is_valid_json``.

    Args:
        json_str: Possibly truncated JSON text.

    Returns:
        The repaired text, or ``None`` if ``json_str`` is empty or not a str.
    """
    if not json_str or not isinstance(json_str, str):
        return None

    text = json_str.rstrip()
    closers = []  # expected closing bracket for each currently open container
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False  # this char is escaped; it cannot end the string
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in "}]" and closers and closers[-1] == ch:
            closers.pop()

    if in_string:
        if escaped:
            # A lone trailing backslash would escape our closing quote.
            text = text[:-1]
        text += '"'

    # A comma after the last complete element is not valid before a closer.
    text = text.rstrip()
    if text.endswith(","):
        text = text[:-1].rstrip()

    return text + "".join(reversed(closers))
def strip_code_fence(text):
    """Remove Markdown code-fence decoration from a model reply.

    Strips surrounding whitespace and backticks, then a leading ``json``
    language tag (case-insensitive). Unlike ``str.strip("json")``, this never
    removes the letters j/s/o/n from the *end* of the text. A reply that was
    truncated mid-word (e.g. ``..."diabetes``) therefore keeps its final
    characters.

    Args:
        text: Raw reply text.

    Returns:
        The reply with fence markers removed and whitespace trimmed.
    """
    cleaned = text.strip().strip("`").strip()
    if cleaned[:4].lower() == "json":
        cleaned = cleaned[4:]
    return cleaned.strip()


def _parse_keyword_output(output, entry_id):
    """Decode one model reply into the keyword dict, or return None if unusable.

    The reply is un-fenced. If it is not valid JSON, it is repaired with
    ``fix_truncated_json``. It is accepted only if it decodes to a dict whose
    ``"conditions"`` value is a list.
    """
    text = strip_code_fence(output)
    if not is_valid_json(text):
        repaired = fix_truncated_json(text)
        if not (repaired and is_valid_json(repaired)):
            return None
        text = repaired
        print(f"Info: Fixed truncated JSON for entry {entry_id}")
    parsed = json.loads(text)
    if not isinstance(parsed, dict) or not isinstance(parsed.get("conditions"), list):
        return None
    return parsed


def _generate_keywords(entry, model, max_retries=3, retry_delay=2):
    """Ask the model for one entry's keywords, retrying on failures.

    Retries on API errors, empty replies, and unusable JSON, sleeping
    ``retry_delay`` seconds between attempts. Returns the parsed dict, or
    ``None`` after ``max_retries`` failed attempts.
    """
    entry_id = entry["_id"]
    messages = get_keyword_generation_messages(entry["text"])
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
                max_tokens=2000,  # generous limit to make truncation unlikely
            )
            output = response.choices[0].message.content
        except Exception as e:  # API/network failure: log and retry
            print(f"Error processing entry {entry_id}: {e}")
        else:
            if output is None or not output.strip():
                print(f"Warning: Empty response for entry {entry_id}, retry {attempt}/{max_retries}")
            else:
                parsed = _parse_keyword_output(output, entry_id)
                if parsed is not None:
                    return parsed
                print(f"Warning: Invalid or incomplete JSON for entry {entry_id}, retry {attempt}/{max_retries}")
                print(f"Raw output: {output[:200]}...")  # show first 200 chars
        if attempt < max_retries:
            time.sleep(retry_delay)
    print(f"Error: Failed to process entry {entry_id} after {max_retries} retries, skipping...")
    return None


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit(f"Usage: python {sys.argv[0]} <corpus: trec_2021|trec_2022|sigir> <model>")
    # the corpus: trec_2021, trec_2022, or sigir
    corpus = sys.argv[1]
    # the model (Azure deployment) name passed to chat.completions.create
    model = sys.argv[2]

    output_path = f"results/retrieval_keywords_{model}_{corpus}.json"
    os.makedirs("results", exist_ok=True)

    outputs = {}
    with open(f"dataset/{corpus}/queries.jsonl", "r", encoding="utf-8") as query_file:
        for line in query_file:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            entry = json.loads(line)
            parsed = _generate_keywords(entry, model)
            if parsed is not None:
                outputs[entry["_id"]] = parsed
                print(f"Success: Processed entry {entry['_id']}")
            # Save progress after each entry so an interrupted run keeps finished results.
            with open(output_path, "w", encoding="utf-8") as out_file:
                json.dump(outputs, out_file, indent=4)