File size: 3,456 Bytes
0d38280
 
 
 
f9e956a
 
a487eb3
 
 
 
3f1fdcf
357828f
 
 
 
 
 
0d38280
357828f
 
 
 
0d38280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357828f
0d38280
 
 
 
 
357828f
f9e956a
 
a487eb3
f9e956a
 
 
 
 
a487eb3
f9e956a
 
 
357828f
f9e956a
a487eb3
f9e956a
357828f
f9e956a
a487eb3
f9e956a
357828f
 
 
0d38280
 
357828f
 
 
 
 
0ef172c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import xml.etree.ElementTree as ET
import json
import sys
import os
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import Settings
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Filesystem layout, all derived from this file's location:
#   <BASE_DIR>/data/icd10cm_tabular_2025/icd10cm_tabular_2025.xml  (input)
#   <BASE_DIR>/data/processed/                                     (output)
# __file__ is made absolute first so the paths do not depend on the current
# working directory when __file__ happens to be a relative path.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, "data")
ICD_DIR = os.path.join(DATA_DIR, "icd10cm_tabular_2025")
DEFAULT_XML_PATH = os.path.join(ICD_DIR, "icd10cm_tabular_2025.xml")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")

def _collect_icd_descriptions(root):
    """Return a ``{code: description}`` dict for every usable <diag> under *root*."""
    icd_to_description = {}
    # root.iter() walks the whole tree, so nested <diag> sub-codes are visited
    # as well; findtext() only looks at the direct children of each <diag>.
    for diag in root.iter("diag"):
        # findtext() gives None for a missing element and "" for an element
        # without text; both collapse to "" here and are skipped below.
        code = (diag.findtext("name") or "").strip()
        description = (diag.findtext("desc") or "").strip()
        # Only store non-empty strings. A repeated code keeps the last
        # description seen.
        if code and description:
            icd_to_description[code] = description
    return icd_to_description


def main(xml_path=DEFAULT_XML_PATH):
    """Convert the ICD-10-CM tabular XML into a flat code -> description JSON file.

    Every <diag> element in the XML contributes one entry, keyed by the text
    of its <name> child and valued by the text of its <desc> child. The
    mapping is written to ``icd_to_description.json`` inside ``PROCESSED_DIR``.

    Args:
        xml_path: Path to the tabular XML file. Defaults to
            ``DEFAULT_XML_PATH``.

    Raises:
        SystemExit: With exit status 1 when *xml_path* is not an existing
            file or cannot be parsed as XML. The reason is printed to stderr.
    """
    # Validate the input before touching the filesystem, so a failed run does
    # not leave an empty output directory behind.
    if not os.path.isfile(xml_path):
        print(f"ERROR: cannot find tabular XML at '{xml_path}'", file=sys.stderr)
        sys.exit(1)

    try:
        root = ET.parse(xml_path).getroot()
    except ET.ParseError as exc:
        print(
            f"ERROR: cannot parse tabular XML at '{xml_path}': {exc}",
            file=sys.stderr,
        )
        sys.exit(1)

    icd_to_description = _collect_icd_descriptions(root)

    # Create processed directory if it doesn't exist
    os.makedirs(PROCESSED_DIR, exist_ok=True)

    # Write out a flat JSON mapping code → description
    out_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(icd_to_description, fp, indent=2, ensure_ascii=False)

    print(f"Wrote {len(icd_to_description)} code entries to {out_path}")

def create_symptom_index(input_dir=None):
    """Build and return a vector index over the documents in *input_dir*.

    Args:
        input_dir: Directory passed to ``SimpleDirectoryReader``. When omitted
            it defaults to the module-level ``DATA_DIR``, so the result does
            not depend on the current working directory.

    Returns:
        The ``VectorStoreIndex`` built from the loaded documents.

    Raises:
        Exception: Any error raised while loading or indexing is logged and
            then re-raised unchanged.
    """
    if input_dir is None:
        input_dir = DATA_DIR

    try:
        logger.info("Loading documents from data directory...")
        documents = SimpleDirectoryReader(
            input_dir=input_dir,
            filename_as_id=True,
        ).load_data()

        logger.info("Creating vector index from %d documents...", len(documents))
        index = VectorStoreIndex.from_documents(
            documents,
            show_progress=True,
        )

        logger.info("Symptom index created successfully")
        return index

    except Exception as e:
        # Log for context, then let the caller decide how to handle it.
        logger.error("Failed to create symptom index: %s", e)
        raise

# Module-level handle to the symptom index. It stays None on import and is
# only populated when this file is run as a script (see below).
symptom_index = None

if __name__ == "__main__":
    # An optional first command-line argument overrides the default XML path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()  # Use default path
    symptom_index = create_symptom_index()

    # Test multiple queries
    test_queries = [
        "persistent cough with fever",
        "severe headache with nausea",
        "lower back pain",
        "difficulty breathing",
    ]

    print("\nTesting symptom matching:")
    print("-" * 50)

    # Build the query engine once; it does not depend on the query text.
    query_engine = symptom_index.as_query_engine()
    for query in test_queries:
        response = query_engine.query(query)
        print(f"\nQuery: {query}")
        print("Relevant ICD-10 codes:")
        print(str(response))
        print("-" * 50)