VisoLearn committed on
Commit
196c072
·
verified ·
1 Parent(s): 0504af3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -46
app.py CHANGED
@@ -3,7 +3,6 @@ import spaces
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  import torch
5
  from threading import Thread
6
- import re
7
 
8
  phi4_model_path = "Intelligent-Internet/II-Medical-8B"
9
 
@@ -16,30 +15,23 @@ phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
16
  def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
17
  if not user_message.strip():
18
  return history_state, history_state
19
-
20
  model = phi4_model
21
  tokenizer = phi4_tokenizer
22
  start_tag = "<|im_start|>"
23
  sep_tag = "<|im_sep|>"
24
  end_tag = "<|im_end|>"
25
 
26
- system_message = """You are a highly knowledgeable and thoughtful AI medical assistant. Your primary role is to assist with diagnostic reasoning by evaluating patient symptoms, medical history, and relevant clinical context.
27
 
28
- Structure your response into two main sections using the following format: <think> {Thought section} </think> {Solution section}.
29
 
30
- In the <think> section, use structured clinical reasoning to:
31
- - Identify possible differential diagnoses based on the given symptoms.
32
- - Consider risk factors, medical history, duration, and severity of symptoms.
33
- - Use step-by-step logic to rule in or rule out conditions.
34
- - Reflect on diagnostic uncertainty and suggest further assessments if needed.
35
 
36
- In the <solution> section, provide your most likely diagnosis or clinical assessment along with the rationale. Include brief suggestions for potential next steps like labs, imaging, or referrals if appropriate.
37
 
38
- IMPORTANT: When referencing lab values or pathophysiological mechanisms, use LaTeX formatting for clarity. Use $...$ for inline and $$...$$ for block-level expressions.
39
 
40
- Now, please analyze and respond to the following case:
41
- """
42
-
43
  prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
44
  for message in history_state:
45
  if message["role"] == "user":
@@ -51,7 +43,6 @@ Now, please analyze and respond to the following case:
51
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
52
 
53
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
54
-
55
  generation_kwargs = {
56
  "input_ids": inputs["input_ids"],
57
  "attention_mask": inputs["attention_mask"],
@@ -81,15 +72,14 @@ Now, please analyze and respond to the following case:
81
 
82
  yield new_history, new_history
83
 
84
- # Updated example cases for medical diagnostics
85
  example_messages = {
86
- "Chest Pain": "A 58-year-old man presents with chest pain that started 20 minutes ago while climbing stairs. He describes it as a heavy pressure in the center of his chest, radiating to his left arm. He has a history of hypertension and smoking. What is the likely diagnosis?",
87
- "Shortness of Breath": "A 34-year-old woman presents with 3 days of worsening shortness of breath, low-grade fever, and a dry cough. She denies chest pain or recent travel. Pulse oximetry is 91% on room air.",
88
- "Abdominal Pain": "A 22-year-old female presents with lower right quadrant abdominal pain, nausea, and fever. The pain started around the umbilicus and migrated to the right lower quadrant over the past 12 hours.",
89
- "Pediatric Fever": "A 2-year-old child has a fever of 39.5°C, irritability, and a rash on the trunk and arms. The child received all standard vaccinations and has no sick contacts. What should be considered in the differential diagnosis?"
90
  }
91
 
92
- # Custom CSS
93
  css = """
94
  .markdown-body .katex {
95
  font-size: 1.2em;
@@ -102,12 +92,7 @@ css = """
102
  """
103
 
104
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
105
- gr.Markdown(
106
- """
107
- # Medical Diagnosis Assistant
108
- This AI assistant uses structured reasoning to evaluate clinical cases and assist with diagnostic decision-making. Includes LaTeX support for medical calculations.
109
- """
110
- )
111
 
112
  gr.HTML("""
113
  <script>
@@ -126,7 +111,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
126
  messageStyle: 'none'
127
  };
128
  }
129
-
130
  function rerender() {
131
  if (window.MathJax && window.MathJax.Hub) {
132
  window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
@@ -147,21 +131,21 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
147
  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
148
  top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
149
  repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
150
-
151
  with gr.Column(scale=4):
152
- chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", elem_id="chatbot", show_copy_button=True)
153
  with gr.Row():
154
- user_input = gr.Textbox(label="Describe patient symptoms...", placeholder="Type a clinical case here...", scale=3)
155
  submit_button = gr.Button("Send", variant="primary", scale=1)
156
  clear_button = gr.Button("Clear", scale=1)
157
- gr.Markdown("**Try these example cases:**")
158
  with gr.Row():
159
- example1_button = gr.Button("Chest Pain")
160
- example2_button = gr.Button("Shortness of Breath")
161
- example3_button = gr.Button("Abdominal Pain")
162
- example4_button = gr.Button("Pediatric Fever")
163
 
164
- submit_button.click(
165
  fn=generate_response,
166
  inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
167
  outputs=[chatbot, history_state]
@@ -171,15 +155,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
171
  outputs=user_input
172
  )
173
 
174
- clear_button.click(
175
- fn=lambda: ([], []),
176
- inputs=None,
177
- outputs=[chatbot, history_state]
178
- )
179
 
180
- example1_button.click(fn=lambda: gr.update(value=example_messages["Chest Pain"]), inputs=None, outputs=user_input)
181
- example2_button.click(fn=lambda: gr.update(value=example_messages["Shortness of Breath"]), inputs=None, outputs=user_input)
182
- example3_button.click(fn=lambda: gr.update(value=example_messages["Abdominal Pain"]), inputs=None, outputs=user_input)
183
- example4_button.click(fn=lambda: gr.update(value=example_messages["Pediatric Fever"]), inputs=None, outputs=user_input)
184
 
185
  demo.launch(ssr_mode=False)
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  import torch
5
  from threading import Thread
 
6
 
7
  phi4_model_path = "Intelligent-Internet/II-Medical-8B"
8
 
 
15
  def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
16
  if not user_message.strip():
17
  return history_state, history_state
18
+
19
  model = phi4_model
20
  tokenizer = phi4_tokenizer
21
  start_tag = "<|im_start|>"
22
  sep_tag = "<|im_sep|>"
23
  end_tag = "<|im_end|>"
24
 
25
+ system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections:
26
 
27
+ <think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think>
28
 
29
+ In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.
 
 
 
 
30
 
31
+ Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).
32
 
33
+ Now, analyze the following case:"""
34
 
 
 
 
35
  prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
36
  for message in history_state:
37
  if message["role"] == "user":
 
43
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
44
 
45
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
46
  generation_kwargs = {
47
  "input_ids": inputs["input_ids"],
48
  "attention_mask": inputs["attention_mask"],
 
72
 
73
  yield new_history, new_history
74
 
75
+
76
  example_messages = {
77
+ "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
78
+ "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
79
+ "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
80
+ "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
81
  }
82
 
 
83
  css = """
84
  .markdown-body .katex {
85
  font-size: 1.2em;
 
92
  """
93
 
94
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
95
+ gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
 
 
 
 
 
96
 
97
  gr.HTML("""
98
  <script>
 
111
  messageStyle: 'none'
112
  };
113
  }
 
114
  function rerender() {
115
  if (window.MathJax && window.MathJax.Hub) {
116
  window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]);
 
131
  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
132
  top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
133
  repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
134
+
135
  with gr.Column(scale=4):
136
+ chatbot = gr.Chatbot(label="Chat", render_markdown=True, type="messages", show_copy_button=True)
137
  with gr.Row():
138
+ user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
139
  submit_button = gr.Button("Send", variant="primary", scale=1)
140
  clear_button = gr.Button("Clear", scale=1)
141
+ gr.Markdown("**Try these examples:**")
142
  with gr.Row():
143
+ example1 = gr.Button("Headache case")
144
+ example2 = gr.Button("Chest pain")
145
+ example3 = gr.Button("Abdominal pain")
146
+ example4 = gr.Button("BMI calculation")
147
 
148
+ submit_button.stream(
149
  fn=generate_response,
150
  inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
151
  outputs=[chatbot, history_state]
 
155
  outputs=user_input
156
  )
157
 
158
+ clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history_state])
 
 
 
 
159
 
160
+ example1.click(lambda: gr.update(value=example_messages["Headache case"]), None, user_input)
161
+ example2.click(lambda: gr.update(value=example_messages["Chest pain"]), None, user_input)
162
+ example3.click(lambda: gr.update(value=example_messages["Abdominal pain"]), None, user_input)
163
+ example4.click(lambda: gr.update(value=example_messages["BMI calculation"]), None, user_input)
164
 
165
  demo.launch(ssr_mode=False)