"""Convert the ICD-10-CM tabular XML into a flat code -> description JSON map.

When run as a script, the map is written into ``data/processed``. A symptom
index is then built via ``create_symptom_index`` and a few sample queries are
run against it.
"""
import xml.etree.ElementTree as ET
import json
import sys
import os
# NOTE(review): this is a relative import, so the script must be run as a
# module (``python -m <package>.<module>``), not by file path — confirm the
# intended invocation.
from ..services.indexing import create_symptom_index
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Path constants. BASE_DIR is the parent of the directory containing this
# file; abspath() makes it independent of the current working directory when
# __file__ is a relative path.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, "data")
ICD_DIR = os.path.join(DATA_DIR, "icd10cm_tabular_2025")
DEFAULT_XML_PATH = os.path.join(ICD_DIR, "icd10cm_tabular_2025.xml")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")


def _parse_icd_descriptions(xml_path):
    """Return a ``{code: description}`` dict parsed from a tabular XML file.

    Every ``<diag>`` element anywhere in the tree is visited. ``Element.iter``
    is recursive, so nested sub-code ``<diag>`` children are included. For
    each one, the text of its direct ``<name>`` child (the code) and
    ``<desc>`` child (the description) is stripped. Entries where either
    child is missing or its text is empty are skipped. If a code appears more
    than once, the last occurrence in document order wins.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
    """
    root = ET.parse(xml_path).getroot()
    icd_to_description = {}
    for diag in root.iter("diag"):
        # findtext() returns None when the child is absent and "" when it
        # has no text, so `or ""` folds both cases into the emptiness check.
        code = (diag.findtext("name") or "").strip()
        description = (diag.findtext("desc") or "").strip()
        if code and description:
            icd_to_description[code] = description
    return icd_to_description


def main(xml_path=DEFAULT_XML_PATH):
    """Write ``icd_to_description.json`` into PROCESSED_DIR from *xml_path*.

    Args:
        xml_path: Path to the ICD-10-CM tabular XML file.
            Defaults to DEFAULT_XML_PATH.

    Exits the process with status 1 if *xml_path* is not an existing file or
    is not well-formed XML. PROCESSED_DIR is created if it does not exist.
    """
    if not os.path.isfile(xml_path):
        logger.error("cannot find tabular XML at '%s'", xml_path)
        sys.exit(1)

    try:
        icd_to_description = _parse_icd_descriptions(xml_path)
    except ET.ParseError as err:
        logger.error("failed to parse tabular XML at '%s': %s", xml_path, err)
        sys.exit(1)

    # Create the output directory only once there is something to write.
    os.makedirs(PROCESSED_DIR, exist_ok=True)
    out_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(icd_to_description, fp, indent=2, ensure_ascii=False)

    logger.info("Wrote %d code entries to %s", len(icd_to_description), out_path)


def _run_sample_queries(index, queries):
    """Print the query-engine response from *index* for each of *queries*."""
    # Build the query engine once instead of once per query.
    query_engine = index.as_query_engine()
    print("\nTesting symptom matching:")
    print("-" * 50)
    for query in queries:
        response = query_engine.query(query)
        print(f"\nQuery: {query}")
        print("Relevant ICD-10 codes:")
        print(str(response))
        print("-" * 50)


# Assigned only by the __main__ block below; remains None when this module
# is imported.
symptom_index = None

if __name__ == "__main__":
    # Optional first CLI argument overrides the default XML path.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_XML_PATH)
    symptom_index = create_symptom_index()
    _run_sample_queries(
        symptom_index,
        [
            "persistent cough with fever",
            "severe headache with nausea",
            "lower back pain",
            "difficulty breathing",
        ],
    )