File size: 6,127 Bytes
8810a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from mai_dx import MaiDxOrchestrator
from loguru import logger

# Model identifier passed to every orchestrator variant created by this demo.
MODEL_NAME = "gemini/gemini-2.5-flash"

# Lowest accuracy score (printed as "<score>/5.0") that the summary table
# reports as a match.
MATCH_THRESHOLD = 4.0


def match_label(accuracy_score):
    """Return the summary-table label for *accuracy_score*.

    Scores at or above ``MATCH_THRESHOLD`` are labelled as a match; anything
    lower is labelled as no match.
    """
    if accuracy_score >= MATCH_THRESHOLD:
        return "✓ Match"
    return "✗ No Match"


def format_summary_row(variant_name, result):
    """Format one left-aligned row of the results summary table.

    Args:
        variant_name: Label printed in the first column.
        result: Object exposing ``accuracy_score``, ``total_cost`` and
            ``iterations`` attributes (the objects returned by the
            orchestrator's ``run`` / ``run_ensemble`` calls in ``main``).

    Returns:
        The row text; column widths line up with the header printed by
        ``main``.
    """
    return (
        f"{variant_name:<15} {match_label(result.accuracy_score):<15} "
        f"{result.accuracy_score:<8.1f} ${result.total_cost:<11,} "
        f"{result.iterations:<12}"
    )


def main():
    """Run the MAI-DxO demo on a single example case and print a comparison.

    Three variants ("no_budget", "budgeted", "question_only") are run in turn,
    followed by an "ensemble" run; each result is printed and a summary table
    is shown at the end. Any exception raised along the way is logged through
    loguru and reported on stdout instead of being propagated.
    """
    # Example case inspired by the paper's Figure 1
    initial_info = (
        "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
        "and bleeding. Symptoms did not abate with antimicrobial therapy."
    )

    full_case = """
    Patient: 29-year-old female.
    History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
    No fevers, headaches, or gastrointestinal symptoms. Past medical history is unremarkable. No history of smoking or significant alcohol use.
    Physical Exam: Right peritonsillar mass, displacing the uvula. No other significant findings.
    Initial Labs: FBC, clotting studies normal.
    MRI Neck: Showed a large, enhancing mass in the right peritonsillar space.
    Biopsy (H&E): Infiltrative round-cell neoplasm with high nuclear-to-cytoplasmic ratio and frequent mitotic figures.
    Biopsy (Immunohistochemistry for Carcinoma): CD31, D2-40, CD34, ERG, GLUT-1, pan-cytokeratin, CD45, CD20, CD3 all negative. Ki-67: 60% nuclear positivity.
    Biopsy (Immunohistochemistry for Rhabdomyosarcoma): Desmin and MyoD1 diffusely positive. Myogenin multifocally positive.
    Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
    Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
    """

    ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"

    # --- Demonstrate Different MAI-DxO Variants ---
    try:
        print("\n" + "=" * 80)
        print(
            "    MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK"
        )
        print(
            "                    Implementation based on the NEJM Research Paper"
        )
        print("=" * 80)

        # (variant name, description, extra create_variant keyword arguments).
        # Keeping the extras next to the description keeps the "$3000 limit"
        # text and the actual budget value side by side.
        variants_to_test = [
            (
                "no_budget",
                "Standard MAI-DxO with no budget constraints",
                {},
            ),
            (
                "budgeted",
                "Budget-constrained MAI-DxO ($3000 limit)",
                {"budget": 3000},
            ),
            (
                "question_only",
                "Question-only variant (no diagnostic tests)",
                {},
            ),
        ]

        results = {}

        for variant_name, description, extra_kwargs in variants_to_test:
            print(f"\n{'='*60}")
            print(f"Testing Variant: {variant_name.upper()}")
            print(f"Description: {description}")
            print("=" * 60)

            # Create the variant
            orchestrator = MaiDxOrchestrator.create_variant(
                variant_name,
                model_name=MODEL_NAME,
                max_iterations=5,
                **extra_kwargs,
            )

            # Run the diagnostic process
            result = orchestrator.run(
                initial_case_info=initial_info,
                full_case_details=full_case,
                ground_truth_diagnosis=ground_truth,
            )

            results[variant_name] = result

            # Display results
            print(f"\n🚀 Final Diagnosis: {result.final_diagnosis}")
            print(f"🎯 Ground Truth: {result.ground_truth}")
            print(f"⭐ Accuracy Score: {result.accuracy_score}/5.0")
            print(f"   Reasoning: {result.accuracy_reasoning}")
            print(f"💰 Total Cost: ${result.total_cost:,}")
            print(f"🔄 Iterations: {result.iterations}")
            print(f"⏱️  Mode: {orchestrator.mode}")

        # Demonstrate ensemble approach
        print(f"\n{'='*60}")
        print("Testing Variant: ENSEMBLE")
        print(
            "Description: Multiple independent runs with consensus aggregation"
        )
        print("=" * 60)

        ensemble_orchestrator = MaiDxOrchestrator.create_variant(
            "ensemble",
            model_name=MODEL_NAME,
            max_iterations=3,  # Shorter iterations for ensemble
        )

        ensemble_result = ensemble_orchestrator.run_ensemble(
            initial_case_info=initial_info,
            full_case_details=full_case,
            ground_truth_diagnosis=ground_truth,
            num_runs=2,  # Reduced for demo
        )

        results["ensemble"] = ensemble_result

        print(
            f"\n🚀 Ensemble Diagnosis: {ensemble_result.final_diagnosis}"
        )
        print(f"🎯 Ground Truth: {ensemble_result.ground_truth}")
        print(
            f"⭐ Ensemble Score: {ensemble_result.accuracy_score}/5.0"
        )
        print(
            f"💰 Total Ensemble Cost: ${ensemble_result.total_cost:,}"
        )

        # --- Summary Comparison ---
        print(f"\n{'='*80}")
        print("                           RESULTS SUMMARY")
        print("=" * 80)
        print(
            f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}"
        )
        print("-" * 80)

        for variant_name, result in results.items():
            print(format_summary_row(variant_name, result))

        print(f"\n{'='*80}")
        print(
            "Implementation successfully demonstrates the MAI-DxO framework"
        )
        print(
            "as described in 'Sequential Diagnosis with Language Models' paper"
        )
        print("=" * 80)

    except Exception as e:
        # Top-level boundary of the demo: record the traceback via loguru and
        # show a short hint on stdout instead of letting the error propagate.
        logger.exception(
            f"An error occurred during the diagnostic session: {e}"
        )
        print(f"\n❌ Error occurred: {e}")
        print("Please check your model configuration and API keys.")


if __name__ == "__main__":
    main()