__author__ = "qiao"

"""
generate the search keywords for each patient
"""

import json
import os
import sys
import time

from openai import AzureOpenAI

client = AzureOpenAI(
    api_version="2023-09-01-preview",
    azure_endpoint=os.getenv("OPENAI_ENDPOINT"),
    api_key=os.getenv("OPENAI_API_KEY"),
)


def get_keyword_generation_messages(note):
    system = 'You are a helpful assistant and your task is to help search relevant clinical trials for a given patient description. Please first summarize the main medical problems of the patient. Then generate up to 32 key conditions for searching relevant clinical trials for this patient. The key condition list should be ranked by priority. Please output only a JSON dict formatted as Dict{{"summary": Str(summary), "conditions": List[Str(condition)]}}.'

    prompt = f"Here is the patient description: \n{note}\n\nJSON output:"

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]

    return messages


def is_valid_json(json_str):
    """Check if a string is valid JSON"""
    try:
        json.loads(json_str)
        return True
    except (json.JSONDecodeError, TypeError):
        return False


def fix_truncated_json(json_str):
    """Attempt to fix common JSON truncation issues"""
    if not json_str or not isinstance(json_str, str):
        return None

    # Strip trailing whitespace before inspecting the end of the string
    json_str = json_str.rstrip()

    # If it ends with a comma, remove it
    if json_str.endswith(','):
        json_str = json_str[:-1]

    # If it ends with an incomplete string, try to close it
    if json_str.count('"') % 2 != 0:
        # Find the last quote and see if we can close the string
        last_quote_pos = json_str.rfind('"')
        if last_quote_pos != -1:
            # Check if this is an opening quote (no closing quote after it)
            remaining = json_str[last_quote_pos + 1:]
            if '"' not in remaining:
                # Try to close the string and the JSON
                json_str += '"]}'

    # If it ends with an incomplete array, try to close it
    if json_str.endswith('['):
        json_str += ']}'
    elif json_str.endswith('["'):
        json_str += ']}'

    # If it ends with an incomplete object, try to close it
    if json_str.endswith('{'):
        json_str += '}'

    return json_str


if __name__ == "__main__":
    # the corpus: trec_2021, trec_2022, or sigir
    corpus = sys.argv[1]

    # the model index to use
    model = sys.argv[2]

    outputs = {}

    with open(f"dataset/{corpus}/queries.jsonl", "r") as f:
        for line in f.readlines():
            entry = json.loads(line)
            messages = get_keyword_generation_messages(entry["text"])

            max_retries = 3
            retry_count = 0

            while retry_count < max_retries:
                try:
                    response = client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=0,
                        max_tokens=2000,  # Increase max tokens to avoid truncation
                    )

                    output = response.choices[0].message.content

                    # Check if output is None or empty
                    if output is None or output.strip() == "":
                        print(f"Warning: Empty response for entry {entry['_id']}, retry {retry_count + 1}/{max_retries}")
                        retry_count += 1
                        if retry_count < max_retries:
                            time.sleep(2)  # Wait before retrying
                            continue
                        else:
                            print(f"Error: Failed to get response for entry {entry['_id']} after {max_retries} retries, skipping...")
                            break

                    # Strip markdown code fences (```json ... ```) that may wrap the output
                    output = output.strip("`").strip("json")

                    # Try to fix truncated JSON
                    if not is_valid_json(output):
                        fixed_output = fix_truncated_json(output)
                        if fixed_output and is_valid_json(fixed_output):
                            output = fixed_output
                            print(f"Info: Fixed truncated JSON for entry {entry['_id']}")
                        else:
                            print(f"Warning: Invalid JSON for entry {entry['_id']}, retry {retry_count + 1}/{max_retries}")
                            print(f"Raw output: {output[:200]}...")  # Show first 200 chars
                            retry_count += 1
                            if retry_count < max_retries:
                                time.sleep(2)  # Wait before retrying
                                continue
                            else:
                                print(f"Error: Failed to parse JSON for entry {entry['_id']} after {max_retries} retries, skipping...")
                                break

                    try:
                        parsed_output = json.loads(output)
                        outputs[entry["_id"]] = parsed_output
                        print(f"Success: Processed entry {entry['_id']}")
                        break  # Success, exit retry loop
                    except json.JSONDecodeError as e:
                        print(f"Warning: Failed to parse JSON for entry {entry['_id']}: {e}")
                        print(f"Raw output: {output[:200]}...")  # Show first 200 chars
                        retry_count += 1
                        if retry_count < max_retries:
                            time.sleep(2)  # Wait before retrying
                            continue
                        else:
                            print(f"Error: Failed to parse JSON for entry {entry['_id']} after {max_retries} retries, skipping...")
                            break

                except Exception as e:
                    print(f"Error processing entry {entry['_id']}: {e}")
                    retry_count += 1
                    if retry_count < max_retries:
                        time.sleep(2)  # Wait before retrying
                        continue
                    else:
                        print(f"Error: Failed to process entry {entry['_id']} after {max_retries} retries, skipping...")
                        break

            # Save progress after each entry (separate handle to avoid shadowing the queries file)
            with open(f"results/retrieval_keywords_{model}_{corpus}.json", "w") as out_f:
                json.dump(outputs, out_f, indent=4)
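
# A minimal usage sketch (the script filename and deployment name below are
# illustrative, not confirmed by the source): the script expects the corpus name
# and the Azure OpenAI deployment name as positional arguments, with
# OPENAI_ENDPOINT and OPENAI_API_KEY set in the environment, e.g.
#
#   export OPENAI_ENDPOINT="https://<your-resource>.openai.azure.com/"
#   export OPENAI_API_KEY="<your-key>"
#   python keyword_generation.py trec_2021 gpt-4-turbo
#
# Outputs are written incrementally to results/retrieval_keywords_<model>_<corpus>.json.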