File size: 5,222 Bytes
50e583f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
__author__ = "qiao"

"""
generate the search keywords for each patient
"""

import json
import os
import time
from openai import AzureOpenAI

import sys

# Shared Azure OpenAI client used for every request in this script.
# The endpoint and key are read from the environment when the module loads;
# an unset variable is passed through as None.
client = AzureOpenAI(
	api_key=os.environ.get("OPENAI_API_KEY"),
	azure_endpoint=os.environ.get("OPENAI_ENDPOINT"),
	api_version="2023-09-01-preview",
)


def get_keyword_generation_messages(note):
	"""Build the chat messages asking the model for search keywords.

	Args:
		note: The patient description; it is interpolated verbatim into
			the user message.

	Returns:
		A two-element list of ``{"role", "content"}`` dicts: the fixed
		system instruction followed by the user prompt containing ``note``.
	"""
	system_instruction = 'You are a helpful assistant and your task is to help search relevant clinical trials for a given patient description. Please first summarize the main medical problems of the patient. Then generate up to 32 key conditions for searching relevant clinical trials for this patient. The key condition list should be ranked by priority. Please output only a JSON dict formatted as Dict{{"summary": Str(summary), "conditions": List[Str(condition)]}}.'
	user_prompt = f"Here is the patient description: \n{note}\n\nJSON output:"

	return [
		dict(role="system", content=system_instruction),
		dict(role="user", content=user_prompt),
	]


def is_valid_json(json_str):
	"""Report whether ``json_str`` can be decoded by ``json.loads``.

	Returns ``False`` (rather than raising) both for malformed JSON text
	and for inputs of a type ``json.loads`` rejects, such as ``None``.
	"""
	try:
		json.loads(json_str)
	except (json.JSONDecodeError, TypeError):
		return False
	else:
		return True


def fix_truncated_json(json_str):
	"""Best-effort repair of JSON text that was cut off mid-stream.

	The text is scanned once, tracking whether the cursor is inside a string
	literal (honouring backslash escapes) and which ``{`` / ``[`` containers
	are still open. The repair then:

	* closes an unterminated string (dropping a dangling ``\\`` first, since
	  it would otherwise escape the closing quote);
	* removes a trailing comma that sits outside any string;
	* appends the closing brackets/braces for every open container,
	  innermost first.

	Args:
		json_str: Possibly truncated JSON text.

	Returns:
		The repaired text, or ``None`` when ``json_str`` is empty or not a
		``str``. The result is NOT guaranteed to be valid JSON (for example
		when the text ends right after a key or a colon); callers should
		validate it before use.
	"""
	if not json_str or not isinstance(json_str, str):
		return None

	text = json_str.rstrip()

	pending_closers = []  # closers for containers still open, outermost first
	in_string = False
	escaped = False
	for ch in text:
		if in_string:
			if escaped:
				# Character following a backslash is consumed literally.
				escaped = False
			elif ch == "\\":
				escaped = True
			elif ch == '"':
				in_string = False
		elif ch == '"':
			in_string = True
		elif ch == "{":
			pending_closers.append("}")
		elif ch == "[":
			pending_closers.append("]")
		elif ch in ("}", "]"):
			# Only pop on a matching closer; mismatches are left for the
			# caller's validation to reject.
			if pending_closers and pending_closers[-1] == ch:
				pending_closers.pop()

	if in_string:
		if escaped:
			# A lone trailing backslash would turn our closing quote into
			# an escaped quote, leaving the string open.
			text = text[:-1]
		text += '"'
	elif text.endswith(","):
		# A comma inside a string is content; only strip structural ones.
		text = text[:-1].rstrip()

	return text + "".join(reversed(pending_closers))


# Number of API attempts per query and the pause (seconds) between attempts.
_MAX_RETRIES = 3
_RETRY_DELAY_SECONDS = 2


def _strip_code_fence(output):
	"""Remove a surrounding Markdown code fence (``` or ```json) from model output."""
	text = output.strip().strip("`")
	# Drop the language tag as a prefix; str.strip("json") would instead strip
	# any of the characters j/s/o/n from both ends.
	if text.startswith("json"):
		text = text[len("json"):]
	return text.strip()


def _request_keywords_once(entry_id, messages, model, attempt, max_retries):
	"""Make one chat-completion call and parse its JSON reply.

	Args:
		entry_id: Query id, used only in log messages.
		messages: Chat messages to send.
		model: Model (deployment) name passed to the API.
		attempt: 1-based attempt number, used only in log messages.
		max_retries: Total number of attempts, used only in log messages.

	Returns:
		``(True, parsed_json)`` on success, otherwise ``(False, label)`` where
		``label`` names what failed and is used in the final error message.
	"""
	try:
		response = client.chat.completions.create(
			model=model,
			messages=messages,
			temperature=0,
			max_tokens=2000,  # Increase max tokens to avoid truncation
		)
		output = response.choices[0].message.content
	except Exception as e:
		# Deliberately broad: any failure of the remote call is treated as
		# retryable so that one bad request does not abort the whole run.
		print(f"Error processing entry {entry_id}: {e}")
		return False, "process entry"

	if output is None or output.strip() == "":
		print(f"Warning: Empty response for entry {entry_id}, retry {attempt}/{max_retries}")
		return False, "get response for entry"

	output = _strip_code_fence(output)

	if not is_valid_json(output):
		fixed_output = fix_truncated_json(output)
		if not (fixed_output and is_valid_json(fixed_output)):
			print(f"Warning: Invalid JSON for entry {entry_id}, retry {attempt}/{max_retries}")
			print(f"Raw output: {output[:200]}...")  # Show first 200 chars
			return False, "parse JSON for entry"
		output = fixed_output
		print(f"Info: Fixed truncated JSON for entry {entry_id}")

	# The text was validated above, so this parse is not expected to fail.
	return True, json.loads(output)


if __name__ == "__main__":
	if len(sys.argv) < 3:
		sys.exit(f"Usage: {sys.argv[0]} <corpus> <model>")

	# the corpus: trec_2021, trec_2022, or sigir
	corpus = sys.argv[1]

	# the model index to use
	model = sys.argv[2]

	output_path = f"results/retrieval_keywords_{model}_{corpus}.json"
	# Create the output directory up front so the first save cannot fail
	# after an API call has already been made.
	os.makedirs(os.path.dirname(output_path), exist_ok=True)

	outputs = {}

	with open(f"dataset/{corpus}/queries.jsonl", "r", encoding="utf-8") as queries_file:
		for line in queries_file:
			# Tolerate blank lines (e.g. a trailing newline at end of file).
			if not line.strip():
				continue

			entry = json.loads(line)
			entry_id = entry["_id"]
			messages = get_keyword_generation_messages(entry["text"])

			for attempt in range(1, _MAX_RETRIES + 1):
				succeeded, result = _request_keywords_once(
					entry_id, messages, model, attempt, _MAX_RETRIES
				)
				if succeeded:
					outputs[entry_id] = result
					print(f"Success: Processed entry {entry_id}")
					break

				if attempt < _MAX_RETRIES:
					time.sleep(_RETRY_DELAY_SECONDS)  # Wait before retrying
				else:
					# On failure ``result`` holds the label of what failed.
					print(f"Error: Failed to {result} {entry_id} after {_MAX_RETRIES} retries, skipping...")

			# Save progress after each entry
			with open(output_path, "w") as results_file:
				json.dump(outputs, results_file, indent=4)