Spaces:
Sleeping
Sleeping
"""Generate the search keywords for each patient.

For every query in ``dataset/<corpus>/queries.jsonl`` the script asks an
Azure OpenAI chat model for a summary of the patient and a ranked list of key
conditions, and writes the results to
``results/retrieval_keywords_<model>_<corpus>.json``.

Usage: python <this script> CORPUS MODEL
"""
# The docstring must be the first statement of the module to become __doc__,
# so the author tag follows it.
__author__ = "qiao"
import json | |
import os | |
import time | |
from openai import AzureOpenAI | |
import sys | |
# Azure OpenAI client shared by the whole script. The endpoint and key are
# read from the OPENAI_ENDPOINT / OPENAI_API_KEY environment variables
# (os.getenv yields None when a variable is unset).
client = AzureOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    azure_endpoint=os.getenv("OPENAI_ENDPOINT"),
    api_version="2023-09-01-preview",
)
def get_keyword_generation_messages(note):
    """Build the chat messages asking the model for trial-search keywords.

    Args:
        note: Free-text patient description.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts: a
        system instruction followed by the user prompt containing *note*.
    """
    # This is a plain string literal (not an f-string / str.format template),
    # so braces are written singly; doubled ``{{ }}`` would reach the model
    # verbatim.
    system = (
        "You are a helpful assistant and your task is to help search relevant "
        "clinical trials for a given patient description. "
        "Please first summarize the main medical problems of the patient. "
        "Then generate up to 32 key conditions for searching relevant clinical "
        "trials for this patient. "
        "The key condition list should be ranked by priority. "
        'Please output only a JSON dict formatted as '
        'Dict{"summary": Str(summary), "conditions": List[Str(condition)]}.'
    )
    prompt = f"Here is the patient description: \n{note}\n\nJSON output:"
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
def is_valid_json(json_str):
    """Report whether *json_str* can be decoded by ``json.loads``.

    Non-string inputs such as ``None`` (which make ``json.loads`` raise
    ``TypeError``) are reported as invalid rather than raising.
    """
    try:
        json.loads(json_str)
    except (TypeError, json.JSONDecodeError):
        return False
    else:
        return True
def fix_truncated_json(json_str):
    """Best-effort repair of a JSON document that was cut off mid-stream.

    The text is scanned once while tracking whether the cut happened inside a
    string literal (honouring backslash escapes) and which ``{`` / ``[``
    containers are still open. Then, in order:

    - a dangling string is closed;
    - trailing commas are dropped;
    - a dangling ``:`` gets a ``null`` value;
    - the open containers are closed innermost-first.

    Args:
        json_str: Raw, possibly truncated, JSON text.

    Returns:
        The repaired text, or ``None`` when *json_str* is empty or not a
        ``str``. The result is NOT guaranteed to be valid JSON (e.g. a
        dangling object key or a cut-off ``true``/number cannot be repaired),
        so callers should re-validate it.
    """
    if not json_str or not isinstance(json_str, str):
        return None

    json_str = json_str.rstrip()
    # Drop a single trailing comma up front (it may sit inside a cut string).
    if json_str.endswith(','):
        json_str = json_str[:-1]

    closers = []  # closing characters still owed, innermost last
    in_string = False
    escaped = False
    for ch in json_str:
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == '{':
            closers.append('}')
        elif ch == '[':
            closers.append(']')
        elif ch in '}]' and closers and closers[-1] == ch:
            closers.pop()

    if in_string:
        if escaped:
            # A lone trailing backslash would escape the quote we append.
            json_str = json_str[:-1]
        json_str += '"'
    else:
        json_str = json_str.rstrip()
        while json_str.endswith(','):
            json_str = json_str[:-1].rstrip()
        if json_str.endswith(':'):
            # The value was cut off entirely; null keeps the pair well-formed.
            json_str += 'null'

    return json_str + ''.join(reversed(closers))
if __name__ == "__main__": | |
# the corpus: trec_2021, trec_2022, or sigir | |
corpus = sys.argv[1] | |
# the model index to use | |
model = sys.argv[2] | |
outputs = {} | |
with open(f"dataset/{corpus}/queries.jsonl", "r") as f: | |
for line in f.readlines(): | |
entry = json.loads(line) | |
messages = get_keyword_generation_messages(entry["text"]) | |
max_retries = 3 | |
retry_count = 0 | |
while retry_count < max_retries: | |
try: | |
response = client.chat.completions.create( | |
model=model, | |
messages=messages, | |
temperature=0, | |
max_tokens=2000, # Increase max tokens to avoid truncation | |
) | |
output = response.choices[0].message.content | |
# Check if output is None or empty | |
if output is None or output.strip() == "": | |
print(f"Warning: Empty response for entry {entry['_id']}, retry {retry_count + 1}/{max_retries}") | |
retry_count += 1 | |
if retry_count < max_retries: | |
time.sleep(2) # Wait before retrying | |
continue | |
else: | |
print(f"Error: Failed to get response for entry {entry['_id']} after {max_retries} retries, skipping...") | |
break | |
output = output.strip("`").strip("json") | |
# Try to fix truncated JSON | |
if not is_valid_json(output): | |
fixed_output = fix_truncated_json(output) | |
if fixed_output and is_valid_json(fixed_output): | |
output = fixed_output | |
print(f"Info: Fixed truncated JSON for entry {entry['_id']}") | |
else: | |
print(f"Warning: Invalid JSON for entry {entry['_id']}, retry {retry_count + 1}/{max_retries}") | |
print(f"Raw output: {output[:200]}...") # Show first 200 chars | |
retry_count += 1 | |
if retry_count < max_retries: | |
time.sleep(2) # Wait before retrying | |
continue | |
else: | |
print(f"Error: Failed to parse JSON for entry {entry['_id']} after {max_retries} retries, skipping...") | |
break | |
try: | |
parsed_output = json.loads(output) | |
outputs[entry["_id"]] = parsed_output | |
print(f"Success: Processed entry {entry['_id']}") | |
break # Success, exit retry loop | |
except json.JSONDecodeError as e: | |
print(f"Warning: Failed to parse JSON for entry {entry['_id']}: {e}") | |
print(f"Raw output: {output[:200]}...") # Show first 200 chars | |
retry_count += 1 | |
if retry_count < max_retries: | |
time.sleep(2) # Wait before retrying | |
continue | |
else: | |
print(f"Error: Failed to parse JSON for entry {entry['_id']} after {max_retries} retries, skipping...") | |
break | |
except Exception as e: | |
print(f"Error processing entry {entry['_id']}: {e}") | |
retry_count += 1 | |
if retry_count < max_retries: | |
time.sleep(2) # Wait before retrying | |
continue | |
else: | |
print(f"Error: Failed to process entry {entry['_id']} after {max_retries} retries, skipping...") | |
break | |
# Save progress after each entry | |
with open(f"results/retrieval_keywords_{model}_{corpus}.json", "w") as f: | |
json.dump(outputs, f, indent=4) | |