Spaces:
Running
Running
import xml.etree.ElementTree as ET | |
import json | |
import sys | |
import os | |
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex | |
from llama_index.core import Settings | |
import logging | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Path constants. BASE_DIR is the parent of the directory containing this
# file (presumably the project root -- confirm against the repo layout).
# abspath() makes every derived path independent of the current working
# directory even when __file__ is relative (e.g. "script.py", whose double
# dirname would otherwise collapse to "").
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, "data")
ICD_DIR = os.path.join(DATA_DIR, "icd10cm_tabular_2025")
DEFAULT_XML_PATH = os.path.join(ICD_DIR, "icd10cm_tabular_2025.xml")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
def _parse_icd_descriptions(root):
    """Collect a flat ``{code: description}`` mapping from an ICD-10-CM XML tree.

    Every ``<diag>`` element anywhere under *root* is visited recursively,
    including nested sub-code ``<diag>`` children. Each one contributes an
    entry when it has both a ``<name>`` (the ICD-10 code) and a ``<desc>``
    (the human-readable description) child with non-blank text. Surrounding
    whitespace is stripped. If a code appears more than once, the last
    occurrence wins.
    """
    icd_to_description = {}
    for diag in root.iter("diag"):
        name_elem = diag.find("name")
        desc_elem = diag.find("desc")
        if name_elem is None or desc_elem is None:
            continue
        # Empty elements such as <name/> have text None; treat them as "".
        code = (name_elem.text or "").strip()
        description = (desc_elem.text or "").strip()
        # Only store non-empty strings.
        if code and description:
            icd_to_description[code] = description
    return icd_to_description


def main(xml_path=DEFAULT_XML_PATH):
    """Convert the ICD-10-CM tabular XML into a flat code -> description JSON file.

    Reads *xml_path* and extracts every ``<diag>`` code/description pair.
    The pairs are written to ``icd_to_description.json`` inside
    ``PROCESSED_DIR``, and a one-line summary is printed to stdout.

    Args:
        xml_path: Path to the tabular XML file. Defaults to ``DEFAULT_XML_PATH``.

    Returns:
        dict: The ``{code: description}`` mapping that was written.

    Raises:
        SystemExit: With status 1 if *xml_path* does not exist or is not
            well-formed XML.
    """
    if not os.path.isfile(xml_path):
        # Errors go to stderr so stdout only carries the success summary.
        print(f"ERROR: cannot find tabular XML at '{xml_path}'", file=sys.stderr)
        sys.exit(1)
    try:
        root = ET.parse(xml_path).getroot()
    except ET.ParseError as err:
        print(f"ERROR: cannot parse tabular XML at '{xml_path}': {err}", file=sys.stderr)
        sys.exit(1)

    icd_to_description = _parse_icd_descriptions(root)

    # Create the output directory only after the input was validated and parsed.
    os.makedirs(PROCESSED_DIR, exist_ok=True)
    out_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(icd_to_description, fp, indent=2, ensure_ascii=False)
    print(f"Wrote {len(icd_to_description)} code entries to {out_path}")
    return icd_to_description
def create_symptom_index(input_dir=None, recursive=False):
    """Create and return a symptom vector index from the ICD-10 data directory.

    Args:
        input_dir: Directory whose files are loaded as documents. Defaults to
            ``DATA_DIR``, the same base directory ``main()`` writes under, so
            the result no longer depends on the current working directory.
        recursive: Forwarded to ``SimpleDirectoryReader``. When False (the
            default) only files directly inside *input_dir* are loaded.

    Returns:
        The ``VectorStoreIndex`` built from the loaded documents.

    Raises:
        Exception: Any error raised while loading or indexing is logged with
            its traceback and re-raised unchanged.
    """
    if input_dir is None:
        input_dir = DATA_DIR
    # NOTE(review): main() writes its JSON to DATA_DIR/processed, a
    # subdirectory that is only read when recursive=True -- confirm which
    # files are meant to be indexed.
    try:
        logger.info("Loading documents from data directory...")
        documents = SimpleDirectoryReader(
            input_dir=input_dir,
            filename_as_id=True,
            recursive=recursive,
        ).load_data()
        logger.info("Creating vector index from %d documents...", len(documents))
        index = VectorStoreIndex.from_documents(documents, show_progress=True)
        logger.info("Symptom index created successfully")
        return index
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Failed to create symptom index: %s", e)
        raise
# Module-level handle to the symptom index. Only the
# ``if __name__ == "__main__":`` block assigns it; importing this file leaves it None.
symptom_index = None
if __name__ == "__main__":
    # An optional first CLI argument overrides the tabular XML path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()  # Use default path
    symptom_index = create_symptom_index()

    # Smoke-test a few free-text symptom queries against the index.
    test_queries = [
        "persistent cough with fever",
        "severe headache with nausea",
        "lower back pain",
        "difficulty breathing",
    ]
    print("\nTesting symptom matching:")
    print("-" * 50)
    # The query engine does not depend on the query, so build it once.
    query_engine = symptom_index.as_query_engine()
    for query in test_queries:
        response = query_engine.query(query)
        print(f"\nQuery: {query}")
        print("Relevant ICD-10 codes:")
        print(str(response))
        print("-" * 50)