gpaasch committed
Commit 5e4e457 · 1 parent: 72e4962

checkpoint 2: audio transcription is working, the Gradio page looks good, local LLMs are working, and the app now returns the consultation, the JSON payload that will be sent to the MCP client, and a debugging panel.

app.py CHANGED
@@ -1,287 +1,218 @@
  import gradio as gr
  from utils.model_configuration_utils import select_best_model, ensure_model
  from services.llm import build_llm
  from services.embeddings import configure_embeddings
- from services.indexing import build_symptom_index
- from utils.voice_input_utils import enhanced_process_speech, format_response_for_user, get_asr_pipeline
  import torch
  import torchaudio.transforms as T
  import json

- # 1) Model selection & download
  MODEL_NAME, REPO_ID = select_best_model()
  model_path = ensure_model()
- print(f"Using model: {MODEL_NAME} from {REPO_ID}")
- print(f"Model path: {model_path}")
- print(f"Model requirements: {MODEL_NAME} requires at least 4GB VRAM and 8GB RAM.")

- # 2) LLM and embeddings config
  llm = build_llm(model_path)
- configure_embeddings()
- print(f"LLM configured with model: {model_path}")
- print("Embeddings configured successfully.")
-
- # 3) Index setup
- symptom_index = build_symptom_index()
- print("Symptom index built successfully.")
- print("Ready for queries.")

- # --- System prompt ---
- SYSTEM_PROMPT = """
- You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
- At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
- or, if you have enough info, output a final JSON with fields:
- {"diagnoses":[…], "confidences":[…]}.
- """
-
- # Build enhanced Gradio interface
  with gr.Blocks(theme="default") as demo:
      gr.Markdown("""
      # 🏥 Medical Symptom to ICD-10 Code Assistant
-
-     ## About
-     This application is part of the Agents+MCP Hackathon. It helps medical professionals
-     and patients understand potential diagnoses based on described symptoms.
-
-     ### How it works:
-     1. Either click the record button and describe your symptoms or type them into the textbox
-     2. The AI will analyze your description and suggest possible diagnoses
-     3. Answer follow-up questions to refine the diagnosis
-     """)
-
      with gr.Row():
          with gr.Column(scale=2):
-             # Add text input above microphone
-             with gr.Row():
-                 text_input = gr.Textbox(
-                     label="Type your symptoms",
-                     placeholder="Or type your symptoms here...",
-                     lines=3
-                 )
-                 submit_btn = gr.Button("Submit", variant="primary")
-
-             # Existing microphone row
-             with gr.Row():
-                 microphone = gr.Audio(
-                     sources=["microphone"],
-                     streaming=True,
-                     type="numpy",
-                     label="Describe your symptoms"
-                 )
-                 transcript_box = gr.Textbox(
-                     label="Transcribed Text",
-                     interactive=False,
-                     show_label=True
-                 )
-                 clear_btn = gr.Button("Clear Chat", variant="secondary")
-
              chatbot = gr.Chatbot(
                  label="Medical Consultation",
                  height=500,
-                 container=True,
-                 type="messages"  # This is now properly supported by our message format
              )
-
          with gr.Column(scale=1):
-             with gr.Accordion("Enter an API Key to give it more power!", open=False):
-                 api_key = gr.Textbox(
-                     label="OpenAI API Key (optional)",
-                     type="password",
-                     placeholder="sk-..."
                  )
-
-             with gr.Row():
-                 with gr.Column():
-                     modal_key = gr.Textbox(
-                         label="Modal Labs API Key",
-                         type="password",
-                         placeholder="mk-..."
-                     )
-                     anthropic_key = gr.Textbox(
-                         label="Anthropic API Key",
-                         type="password",
-                         placeholder="sk-ant-..."
-                     )
-                     mistral_key = gr.Textbox(
-                         label="MistralAI API Key",
-                         type="password",
-                         placeholder="..."
-                     )
-
-                 with gr.Column():
-                     nebius_key = gr.Textbox(
-                         label="Nebius API Key",
-                         type="password",
-                         placeholder="..."
-                     )
-                     hyperbolic_key = gr.Textbox(
-                         label="Hyperbolic Labs API Key",
-                         type="password",
-                         placeholder="hyp-..."
-                     )
-                     sambanova_key = gr.Textbox(
-                         label="SambaNova API Key",
-                         type="password",
-                         placeholder="..."
-                     )
-
-             with gr.Row():
-                 model_selector = gr.Dropdown(
-                     choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
-                     value="OpenAI",
-                     label="Model Provider"
-                 )
-                 temperature = gr.Slider(
-                     minimum=0,
-                     maximum=1,
-                     value=0.7,
-                     label="Temperature"
-                 )
-     # self promotion at bottom of page
-     gr.Markdown("""
-     ---
-     ### 👋 About the Creator
-
-     Hi! I'm Graham Paasch, an experienced technology professional!
-
-     🎥 **Check out my YouTube channel** for more tech content:
-     [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
-
-     💼 **Looking for a skilled developer?**
-     I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
-
-     ⭐ If you found this tool helpful, please consider:
-     - Subscribing to my YouTube channel
-     - Connecting on LinkedIn
-     - Sharing this tool with others in healthcare tech
-     """)

-     # Event handlers
-     clear_btn.click(lambda: None, None, chatbot, queue=False)
-
-     microphone.stream(
-         fn=enhanced_process_speech,
-         inputs=[microphone, chatbot, api_key, model_selector, temperature],
-         outputs=chatbot,
-         show_progress="hidden",
-         api_name=False,
-         queue=True  # Enable queuing for better stream handling
      )
-
-     def process_audio(audio_array, sample_rate):
-         """Pre-process audio for Whisper."""
-         if audio_array.ndim > 1:
-             audio_array = audio_array.mean(axis=1)
-
-         # Convert to tensor for resampling
-         audio_tensor = torch.FloatTensor(audio_array)
-
-         # Resample to 16kHz if needed
-         if sample_rate != 16000:
-             resampler = T.Resample(sample_rate, 16000)
-             audio_tensor = resampler(audio_tensor)
-
-         # Normalize
-         audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
-
-         # Convert back to numpy array and return in correct format
-         return {
-             "raw": audio_tensor.numpy(),  # Key must be "raw"
-             "sampling_rate": 16000  # Key must be "sampling_rate"
-         }
-
-     # Update transcription handler
-     def update_live_transcription(audio):
-         """Real-time transcription updates."""
-         if not audio or not isinstance(audio, tuple):
-             return ""
-
-         try:
-             sample_rate, audio_array = audio
-             features = process_audio(audio_array, sample_rate)
-
-             asr = get_asr_pipeline()
-             result = asr(features)
-
-             return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
-         except Exception as e:
-             print(f"Transcription error: {str(e)}")
-             return ""
-
      microphone.stream(
          fn=update_live_transcription,
          inputs=[microphone],
-         outputs=transcript_box,
-         show_progress="hidden",
          queue=True
      )
-
-     clear_btn.click(
-         fn=lambda: (None, "", ""),
-         outputs=[chatbot, transcript_box, text_input],
-         queue=False
-     )
-
-     def cleanup_memory():
-         """Release unused memory (placeholder for future memory management)."""
-         import gc
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
-     def process_text_input(text, history):
-         """Process text input with memory management."""
-
-         print("process_text_input received:", text)

-         if not text:
-             return history, ""  # Return tuple to clear input

-         # Process the symptoms using the configured LLM
-         prompt = f"""Given these symptoms: '{text}'
-         Please provide:
-         1. Most likely ICD-10 codes
-         2. Confidence levels for each diagnosis
-         3. Key follow-up questions

-         Format as JSON with diagnoses, confidences, and follow_up fields."""
-
-         response = llm.complete(prompt)
-
-         try:
-             # Try to parse as JSON first
-             result = json.loads(response.text)
-         except json.JSONDecodeError:
-             # If not JSON, wrap in our format
-             result = {
-                 "diagnoses": [],
-                 "confidences": [],
-                 "follow_up": str(response.text)[:1000]  # Limit response length
-             }

-         new_history = history + [
-             {"role": "user", "content": text},
-             {"role": "assistant", "content": format_response_for_user(result)}
-         ]
-         return new_history, ""  # Return empty string to clear input

-     # Update the submit button handler
-     submit_btn.click(
-         fn=process_text_input,
-         inputs=[text_input, chatbot],
-         outputs=[chatbot, text_input],
-         queue=True
-     ).success(  # Changed from .then to .success for better error handling
-         fn=cleanup_memory,
-         inputs=None,
-         outputs=None,
-         queue=False
      )

  if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True,  # Enable sharing via Gradio's temporary URLs
-         show_api=True  # Shows the API documentation
-     )
 
  import gradio as gr
  from utils.model_configuration_utils import select_best_model, ensure_model
  from services.llm import build_llm
+ from utils.voice_input_utils import update_live_transcription, format_response_for_user
  from services.embeddings import configure_embeddings
+ from services.indexing import create_symptom_index
  import torch
+ import torchaudio
  import torchaudio.transforms as T
  import json
+ import re

+ # ========== Model setup ==========
  MODEL_NAME, REPO_ID = select_best_model()
  model_path = ensure_model()
+ print(f"Using model: {MODEL_NAME} from {REPO_ID}", flush=True)
+ print(f"Model path: {model_path}", flush=True)

+ # ========== LLM initialization ==========
+ print("\n<<< before build_llm: ", flush=True)
  llm = build_llm(model_path)
+ print(">>> after build_llm", flush=True)

+ # ========== Embeddings & index setup ==========
+ print("\n<<< before configure_embeddings: ", flush=True)
+ configure_embeddings()
+ print(">>> after configure_embeddings", flush=True)
+ print("Embeddings configured and ready", flush=True)
+
+ print("\n<<< before create_symptom_index: ", flush=True)
+ symptom_index = create_symptom_index()
+ print(">>> after create_symptom_index", flush=True)
+ print("Symptom index built successfully. Ready for queries.", flush=True)
+
+ # ========== Prompt template ==========
+ SYSTEM_PROMPT = (
+     "You are a medical assistant helping a user narrow down to the most likely ICD-10 code. "
+     "At each turn, either ask one focused clarifying question (e.g. 'Is your cough dry or productive?') "
+     "or if you have enough information, provide a final JSON with fields: {\"diagnoses\": [...], "
+     "\"confidences\": [...], \"follow_up\": [...]}. Output must be valid JSON with no trailing commas. Your output MUST be strictly valid JSON, starting with '{' and ending with '}', with no extra text outside the JSON."
+ )
+
+ # ========== Generator handler ==========
+ def on_submit(symptoms_text, history):
+     log = []
+     print("on_submit called", flush=True)
+
+     # Placeholder
+     msg = "🔍 Received input"
+     log.append(msg)
+     print(msg, flush=True)
+     history = history + [{"role": "assistant", "content": "Processing your request..."}]
+     yield history, None, "\n".join(log)
+
+     # Validate
+     if not symptoms_text.strip():
+         msg = "❌ No symptoms provided"
+         log.append(msg)
+         print(msg, flush=True)
+         result = {"error": "No input provided", "diagnoses": [], "confidences": [], "follow_up": []}
+         yield history, result, "\n".join(log)
+         return
+
+     # Clean input
+     cleaned = symptoms_text.strip()
+     msg = f"🔄 Cleaned text: {cleaned}"
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, None, "\n".join(log)
+
+     # Semantic query
+     msg = "🔍 Running semantic query"
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, None, "\n".join(log)
+
+     qe = symptom_index.as_query_engine(retriever_kwargs={"similarity_top_k": 5})
+     hits = qe.query(cleaned)
+     msg = f"🔍 Retrieved context entries"
+     log.append(msg)
+     print(msg, flush=True)
+     history = history + [{"role": "assistant", "content": msg}]
+     yield history, None, "\n".join(log)
+
+     # Build prompt with minimal context
+     context_list = []
+     for node in getattr(hits, 'source_nodes', [])[:3]:
+         md = getattr(node, 'metadata', {}) or {}
+         context_list.append(f"{md.get('code','')}: {md.get('description','')}")
+     context_text = "\n".join(context_list)
+     prompt = (
+         f"{SYSTEM_PROMPT}\n\n"
+         f"User symptoms: '{cleaned}'\n\n"
+         f"Relevant ICD-10 context:\n{context_text}\n\n"
+         "Respond with valid JSON."
+     )
+     msg = "✏️ Prompt built"
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, None, "\n".join(log)
+
+     # Call LLM
+     # Use constrained decoding to enforce JSON-only output
+     response = llm.complete(prompt, stop=["}"])  # stop after closing brace
+     raw = getattr(response, 'text', str(response))
+     # Truncate extra content after the final JSON object
+     if not raw.strip().endswith('}'):
+         end_idx = raw.rfind('}')
+         if end_idx != -1:
+             raw = raw[:end_idx+1]
+     msg = "📡 Raw LLM response received"
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, None, "\n".join(log)
+
+     # Parse JSON
+     cleaned_raw = re.sub(r",\s*([}\]])", r"\1", raw)
+     try:
+         parsed = json.loads(cleaned_raw)
+         msg = "✅ JSON parsed"
+     except Exception as e:
+         msg = f"❌ JSON parse error: {e}"
+         parsed = {"error": str(e), "raw": raw}
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, parsed, "\n".join(log)
+
+     # Final assistant message
+     assistant_msg = format_response_for_user(parsed)
+     history = history + [{"role": "assistant", "content": assistant_msg}]
+     msg = "✅ Final response appended"
+     log.append(msg)
+     print(msg, flush=True)
+     yield history, parsed, "\n".join(log)
+
+ # ========== Gradio UI ==========
  with gr.Blocks(theme="default") as demo:
      gr.Markdown("""
      # 🏥 Medical Symptom to ICD-10 Code Assistant
+     ## Describe symptoms by typing or speaking.
+     Debug log updates live below.
+     """
+     )
      with gr.Row():
          with gr.Column(scale=2):
+             text_input = gr.Textbox(
+                 label="Type your symptoms",
+                 placeholder="I'm feeling under the weather...",
+                 lines=3
+             )
+             microphone = gr.Audio(
+                 sources=["microphone"],
+                 streaming=True,
+                 type="numpy",
+                 label="Or speak your symptoms..."
+             )
+             submit_btn = gr.Button("Submit", variant="primary")
+             clear_btn = gr.Button("Clear Chat", variant="secondary")
              chatbot = gr.Chatbot(
                  label="Medical Consultation",
                  height=500,
+                 type="messages"
              )
+             json_output = gr.JSON(label="Diagnosis JSON")
+             debug_box = gr.Textbox(label="Debug log", lines=10)
          with gr.Column(scale=1):
+             with gr.Accordion("API Keys (optional)", open=False):
+                 api_key = gr.Textbox(label="OpenAI Key", type="password")
+             model_selector = gr.Dropdown(
+                 choices=["OpenAI","Modal","Anthropic","MistralAI","Nebius","Hyperbolic","SambaNova"],
+                 value="OpenAI",
+                 label="Model Provider"
              )
+             temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")

+     # Bindings
+     submit_btn.click(
+         fn=on_submit,
+         inputs=[text_input, chatbot],
+         outputs=[chatbot, json_output, debug_box],
+         queue=True
+     )
+     clear_btn.click(
+         lambda: (None, {}, ""),
+         None,
+         [chatbot, json_output, debug_box],
+         queue=False
      )
      microphone.stream(
          fn=update_live_transcription,
          inputs=[microphone],
+         outputs=[text_input],
+         show_progress=False,
          queue=True
      )

+     # --- About the Creator ---
+     gr.Markdown("""
+     ---
+     ### 👋 About the Creator

+     Hi! I'm Graham Paasch, an experienced technology professional!

+     🎥 **Check out my YouTube channel** for more tech content:
+     [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

+     💼 **Looking for a skilled developer?**
+     I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

+     If you found this tool helpful, please consider:
+     - Subscribing to my YouTube channel
+     - Connecting on LinkedIn
+     - Sharing this tool with others in healthcare tech
+     """
      )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_api=True)
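
The new `on_submit` handler asks the model for strict JSON, stops generation at the first `}`, and strips trailing commas before `json.loads`. Because most stop-sequence implementations omit the stop string from the returned text, a slightly more defensive parser can also restore a missing closing brace and ignore chatter around the object. A minimal standalone sketch (not part of this commit; the helper name is hypothetical):

```python
import json
import re

def extract_json_object(raw: str) -> dict:
    """Best-effort extraction of the first balanced JSON object from LLM output.

    Hypothetical helper, not part of the commit: re-appends a closing brace if the
    stop sequence swallowed it, trims text outside the outermost braces, and
    removes trailing commas before parsing.
    """
    text = raw.strip()
    # If generation stopped on "}", the brace may be missing from the text.
    if "{" in text and text.count("{") > text.count("}"):
        text += "}"
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1 or end < start:
        return {"error": "no JSON object found", "raw": raw}
    text = re.sub(r",\s*([}\]])", r"\1", text[start:end + 1])  # drop trailing commas
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        return {"error": str(e), "raw": raw}

# Example:
# extract_json_object('Sure! {"diagnoses": ["J45.909",], "confidences": [0.7]')
# -> {'diagnoses': ['J45.909'], 'confidences': [0.7]}
```
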
services/__init__.py ADDED
File without changes
services/embeddings.py CHANGED
@@ -3,4 +3,5 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding

  def configure_embeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"):
+     print("configure_embeddings: using ", model_name)
      Settings.embed_model = HuggingFaceEmbedding(model_name=model_name)
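
`configure_embeddings` swaps the global llama_index embedding model for a local HuggingFace one. A quick smoke test (a sketch, not in this commit) is to embed a sample symptom string and check the vector dimension:

```python
# Minimal sketch, assuming services/embeddings.py is importable as shown above.
from llama_index.core import Settings
from services.embeddings import configure_embeddings

configure_embeddings()  # defaults to sentence-transformers/all-MiniLM-L6-v2
vec = Settings.embed_model.get_text_embedding("persistent dry cough and mild fever")
print(len(vec))  # all-MiniLM-L6-v2 produces 384-dimensional vectors
```
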
services/indexing.py CHANGED
@@ -1,4 +1,20 @@
- from src.parse_tabular import create_symptom_index

- def build_symptom_index():
-     return create_symptom_index()
+ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
+ from llama_index.core import Settings

+ def create_symptom_index():
+     """Create and return symptom index from ICD-10 data."""
+     print("build_symptom_index: Loading documents from data directory...")
+     documents = SimpleDirectoryReader(
+         input_dir="data",
+         filename_as_id=True
+     ).load_data()
+
+     print(f"build_symptom_index: Creating vector index from {len(documents)} documents...")
+     symptom_index = VectorStoreIndex.from_documents(
+         documents,
+         show_progress=True
+     )
+
+     print("build_symptom_index: Symptom index created successfully")
+
+     return symptom_index
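
`create_symptom_index` re-reads and re-embeds everything under `data/` on every startup. A hedged variant (not in this commit) that persists the index to disk and reloads it on later runs could look like the sketch below, assuming the default llama_index storage backend and a hypothetical `storage` directory:

```python
# Sketch only: persist the vector index so it is rebuilt only when needed.
import os
from llama_index.core import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)

def create_symptom_index(persist_dir="storage"):
    # Reload a previously built index if one exists on disk.
    if os.path.isdir(persist_dir):
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        return load_index_from_storage(storage_context)
    # Otherwise build it from the data directory and persist for next time.
    documents = SimpleDirectoryReader(input_dir="data", filename_as_id=True).load_data()
    index = VectorStoreIndex.from_documents(documents, show_progress=True)
    index.storage_context.persist(persist_dir=persist_dir)
    return index
```
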
services/llm.py CHANGED
@@ -2,6 +2,7 @@ from llama_index.core import Settings
  from llama_index.llms.llama_cpp import LlamaCPP

  def build_llm(model_path, temperature=0.7, max_tokens=256, context_window=2048):
+     print("build_llm: loading model from", model_path)
      llm = LlamaCPP(
          model_path=model_path,
          temperature=temperature,
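
For reference, this is how the rest of the app consumes the object `build_llm` returns. A minimal usage sketch (not from the repo; the model path is hypothetical and the return value is assumed to be the constructed `LlamaCPP` instance):

```python
from services.llm import build_llm

llm = build_llm("models/your-model.gguf", temperature=0.2)  # hypothetical local path
response = llm.complete("List three ICD-10 codes commonly associated with a productive cough.")
print(response.text)  # app.py reads the generated text from the .text attribute
```
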
src/parse_tabular.py CHANGED
@@ -2,8 +2,7 @@ import xml.etree.ElementTree as ET
  import json
  import sys
  import os
- from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
- from llama_index.core import Settings
+ from ..services.indexing import create_symptom_index
  import logging

  logging.basicConfig(level=logging.INFO)
@@ -56,28 +55,6 @@ def main(xml_path=DEFAULT_XML_PATH):

  print(f"Wrote {len(icd_to_description)} code entries to {out_path}")

- def create_symptom_index():
-     """Create and return symptom index from ICD-10 data."""
-     try:
-         logger.info("Loading documents from data directory...")
-         documents = SimpleDirectoryReader(
-             input_dir="data",
-             filename_as_id=True
-         ).load_data()
-
-         logger.info(f"Creating vector index from {len(documents)} documents...")
-         index = VectorStoreIndex.from_documents(
-             documents,
-             show_progress=True
-         )
-
-         logger.info("Symptom index created successfully")
-         return index
-
-     except Exception as e:
-         logger.error(f"Failed to create symptom index: {str(e)}")
-         raise
-
  # Move this outside the main() function
  symptom_index = None
 
utils/voice_input_utils.py CHANGED
@@ -10,6 +10,53 @@ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base
  tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
  processor = WhisperProcessor(feature_extractor, tokenizer)

+ # Update transcription handler
+ def update_live_transcription(audio):
+     """Real-time transcription updates."""
+
+     print("update_live_transcription called with:", type(audio))
+
+     if not audio or not isinstance(audio, tuple):
+         return ""
+
+     try:
+         sample_rate, audio_array = audio
+
+         print(f"got audio tuple – sample_rate={sample_rate}, shape={audio_array.shape}")
+
+         def process_audio(audio_array, sample_rate):
+             """Pre-process audio for Whisper."""
+             if audio_array.ndim > 1:
+                 audio_array = audio_array.mean(axis=1)
+
+             # Convert to tensor for resampling
+             audio_tensor = torch.FloatTensor(audio_array)
+
+             # Resample to 16kHz if needed
+             if sample_rate != 16000:
+                 resampler = T.Resample(sample_rate, 16000)
+                 audio_tensor = resampler(audio_tensor)
+
+             # Normalize
+             audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
+
+             # Convert back to numpy array and return in correct format
+             return {
+                 "raw": audio_tensor.numpy(),  # Key must be "raw"
+                 "sampling_rate": 16000  # Key must be "sampling_rate"
+             }
+
+         features = process_audio(audio_array, sample_rate)
+
+         asr = get_asr_pipeline()
+         result = asr(features)
+
+         return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
+
+     except Exception as e:
+         print(f"Transcription error: {str(e)}")
+         return ""
+
  def get_asr_pipeline():
      """Lazy load ASR pipeline with proper configuration."""
      global transcriber
@@ -24,28 +71,6 @@ def get_asr_pipeline():
      )
      return transcriber

- def process_audio(audio_array, sample_rate):
-     """Pre-process audio for Whisper."""
-     if audio_array.ndim > 1:
-         audio_array = audio_array.mean(axis=1)
-
-     # Convert to tensor for resampling
-     audio_tensor = torch.FloatTensor(audio_array)
-
-     # Resample to 16kHz if needed
-     if sample_rate != 16000:
-         resampler = T.Resample(sample_rate, 16000)
-         audio_tensor = resampler(audio_tensor)
-
-     # Normalize
-     audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
-
-     # Convert back to numpy array and return in correct format
-     return {
-         "raw": audio_tensor.numpy(),  # Key must be "raw"
-         "sampling_rate": 16000  # Key must be "sampling_rate"
-     }
-
  def process_speech(audio_data, symptom_index):
      """Process speech input and convert to text."""
      if not audio_data:
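
`update_live_transcription` receives exactly what a streaming `gr.Audio(type="numpy")` component emits: a `(sample_rate, numpy_array)` tuple, which it resamples to 16 kHz and hands to the Whisper pipeline as `{"raw": ..., "sampling_rate": 16000}`. A standalone sketch (not part of the commit) for exercising the handler without a browser:

```python
# Sketch: feed the handler a synthetic audio chunk shaped like a Gradio stream event.
import numpy as np
from utils.voice_input_utils import update_live_transcription

sample_rate = 44100
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
fake_chunk = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # 1 s of a 440 Hz tone

text = update_live_transcription((sample_rate, fake_chunk))
print(repr(text))  # real speech would yield words; a pure tone yields empty or noisy text
```
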