gpaasch committed on
Commit 00bcf43 · 1 Parent(s): 95321db

1. Removed unsafe `.get()` calls
2. Added proper type checking
3. Better handling of different result formats
4. More robust error handling
5. Cleaner string handling
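
In practice, items 1-3 come down to branching on the transcriber's result type before touching it, instead of calling `.get()` on an object of unknown shape. A minimal sketch of the pattern used in the new `process_speech` (the standalone helper name here is illustrative, not part of the commit):

```python
def extract_transcript(result) -> str:
    """Pull plain text out of a transcriber result, whatever shape it arrives in."""
    if isinstance(result, dict) and "text" in result:
        return result["text"].strip()
    if isinstance(result, str):
        return result.strip()
    if isinstance(result, (list, tuple)) and result:
        return str(result[0]).strip()
    # Unexpected shape: report it and fall back to an empty transcript.
    print(f"Unexpected transcriber result type: {type(result)}")
    return ""
```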

Files changed (2)
  1. requirements.txt +5 -1
  2. src/app.py +174 -55
requirements.txt CHANGED
@@ -1,6 +1,9 @@
  # Core dependencies
  gradio[full]>=5.33.0
- gradio[mcp]>=5.33.0
+ transformers>=4.37.0
+ torch>=2.2.0
+ torchaudio>=2.2.0
+ numpy>=1.24.0

  # LLM and embeddings
  llama-index>=0.9.0
@@ -10,6 +13,7 @@ sentence-transformers>=2.2.0

  # Audio processing
  ffmpeg-python
+ librosa>=0.10.1

  # System utilities
  psutil
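
The added packages back the new Whisper transcription path in src/app.py below (transformers/torch/torchaudio for the ASR pipeline, numpy for the raw audio buffers, librosa for audio handling). A quick, optional sanity check that an environment resolves them (not part of the commit):

```python
import importlib

# Optional check: confirm the new dependencies import cleanly and report their versions.
for name in ("transformers", "torch", "torchaudio", "numpy", "librosa"):
    module = importlib.import_module(name)
    print(f"{name} {getattr(module, '__version__', 'unknown')}")
```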
src/app.py CHANGED
@@ -13,6 +13,9 @@ import torch
  from gtts import gTTS
  import io
  import base64
+ import numpy as np
+ from transformers.pipelines import pipeline  # Changed from transformers import pipeline
+ from transformers import WhisperProcessor

  # Model options mapped to their requirements
  MODEL_OPTIONS = {
@@ -36,6 +39,26 @@ MODEL_OPTIONS = {
      }
  }

+ # Initialize Whisper with proper configuration
+ transcriber = pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-base.en",
+     chunk_length_s=30,
+     stride_length_s=5,
+     return_timestamps=True,
+     device="cpu",  # Explicitly set to CPU since we're seeing GPU warnings
+     torch_dtype=torch.float32,
+     generate_kwargs={
+         "task": "transcribe",
+         "language": "en",
+         "use_cache": True,
+         "return_timestamps": True
+     }
+ )
+
+ # Create processor for proper attention mask
+ processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
+
  def get_system_specs() -> Dict[str, float]:
      """Get system specifications."""
      # Get RAM
@@ -169,35 +192,68 @@ or, if you have enough info, output a final JSON with fields:
  {"diagnoses":[…], "confidences":[…]}.
  """

- def process_speech(audio_path, history):
+ def process_speech(audio_data, history):
      """Process speech input and convert to text."""
      try:
-         if not audio_path:
+         if not audio_data:
              return []

-         # Extract just the transcribed text if it's a tuple
-         transcript = audio_path[1] if isinstance(audio_path, tuple) else audio_path
+         if isinstance(audio_data, tuple) and len(audio_data) == 2:
+             sample_rate, audio_array = audio_data
+
+             # Audio preprocessing
+             if audio_array.ndim > 1:
+                 audio_array = audio_array.mean(axis=1)
+             audio_array = audio_array.astype(np.float32)
+             audio_array /= np.max(np.abs(audio_array))
+
+             # Transcribe with error handling
+             try:
+                 result = transcriber(
+                     {"sampling_rate": sample_rate, "raw": audio_array},
+                     batch_size=8
+                 )
+
+                 # Handle different result types
+                 if isinstance(result, dict) and "text" in result:
+                     transcript = result["text"].strip()
+                 elif isinstance(result, str):
+                     transcript = result.strip()
+                 else:
+                     print(f"Unexpected transcriber result type: {type(result)}")
+                     return []
+
+                 if not transcript:
+                     print("No transcription generated")
+                     return []
+
+                 # Query symptoms with transcribed text
+                 diagnosis_query = f"""
+                 Given these symptoms: '{transcript}'
+                 Identify the most likely ICD-10 diagnoses and key questions.
+                 Focus on clinical implications.
+                 """
+
+                 response = symptom_index.as_query_engine().query(diagnosis_query)
+
+                 return [
+                     {"role": "user", "content": transcript},
+                     {"role": "assistant", "content": json.dumps({
+                         "diagnoses": [],
+                         "confidences": [],
+                         "follow_up": str(response)
+                     })}
+                 ]
+
+             except Exception as e:
+                 print(f"Transcription error: {str(e)}")
+                 return []
+         else:
+             print(f"Invalid audio format: {type(audio_data)}")
+             return []

-         # Query the symptom index
-         diagnosis_query = f"""
-         Given these symptoms: '{transcript}'
-         Identify the most likely ICD-10 diagnoses and key questions.
-         Focus on clinical implications.
-         """
-
-         response = symptom_index.as_query_engine().query(diagnosis_query)
-
-         return [
-             {"role": "user", "content": transcript},
-             {"role": "assistant", "content": json.dumps({
-                 "diagnoses": [],
-                 "confidences": [],
-                 "follow_up": str(response)
-             })}
-         ]
-
      except Exception as e:
-         print(f"Error processing speech: {e}")
+         print(f"Processing error: {str(e)}")
          return []

  def update_transcription(audio_path):
@@ -240,8 +296,10 @@ with gr.Blocks(
      # Moved microphone row above chatbot
      with gr.Row():
          microphone = gr.Audio(
-             label="Describe your symptoms",
-             streaming=True
+             sources=["microphone"],
+             streaming=True,
+             type="numpy",
+             label="Describe your symptoms"
          )
          transcript_box = gr.Textbox(
              label="Transcribed Text",
@@ -296,49 +354,110 @@ with gr.Blocks(
          return result.strip()

      def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7):
-         """Handle speech processing and chat formatting."""
+         """Handle streaming speech processing and chat updates."""
          if not audio_path:
              return history
-
-         # Process the new audio input
-         new_messages = process_speech(audio_path, history)
-         if not new_messages:
-             return history
-
+
          try:
-             # Format last assistant response
-             assistant_response = new_messages[-1]["content"]
-             response_dict = json.loads(assistant_response)
-             formatted_text = format_response_for_user(response_dict)
-
-             # Add to history with proper message format
-             return history + [
-                 {"role": "user", "content": new_messages[0]["content"]},
-                 {"role": "assistant", "content": formatted_text}
-             ]
-
+             # Process audio stream
+             if isinstance(audio_path, tuple) and len(audio_path) == 2:
+                 sample_rate, audio_array = audio_path
+
+                 # Audio preprocessing
+                 if audio_array.ndim > 1:
+                     audio_array = audio_array.mean(axis=1)
+                 audio_array = audio_array.astype(np.float32)
+                 audio_array /= np.max(np.abs(audio_array))
+
+                 # Get transcription from Whisper
+                 result = transcriber(
+                     {"sampling_rate": sample_rate, "raw": audio_array},
+                     batch_size=8,
+                     return_timestamps=True
+                 )
+
+                 # Handle different result types
+                 transcript = ""
+                 if isinstance(result, dict):
+                     transcript = result.get("text", "")
+                 elif isinstance(result, str):
+                     transcript = result
+                 elif isinstance(result, (list, tuple)) and len(result) > 0:
+                     transcript = str(result[0])
+                 else:
+                     print(f"Unexpected transcriber result type: {type(result)}")
+                     return history
+
+                 transcript = transcript.strip()
+                 if not transcript:
+                     return history
+
+                 # Process the symptoms
+                 diagnosis_query = f"""
+                 Based on these symptoms: '{transcript}'
+                 Provide relevant ICD-10 codes and diagnostic questions.
+                 """
+                 response = symptom_index.as_query_engine().query(diagnosis_query)
+
+                 # Format and return chat messages
+                 return history + [
+                     {"role": "user", "content": transcript},
+                     {"role": "assistant", "content": format_response_for_user({
+                         "diagnoses": [],
+                         "confidences": [],
+                         "follow_up": str(response)
+                     })}
+                 ]
+
          except Exception as e:
-             print(f"Error formatting response: {e}")
+             print(f"Streaming error: {str(e)}")
              return history

      microphone.stream(
          fn=enhanced_process_speech,
-         inputs=[
-             microphone,
-             chatbot,
-             api_key,
-             model_selector,
-             temperature
-         ],
+         inputs=[microphone, chatbot, api_key, model_selector, temperature],
          outputs=chatbot,
-         show_progress="hidden"
+         show_progress="hidden",
+         api_name=False,
+         queue=True  # Enable queuing for better stream handling
      )

-     microphone.stream(  # Add real-time transcription updates
-         fn=update_transcription,
+     # Update transcription handler
+     def update_live_transcription(audio):
+         """Real-time transcription updates."""
+         if not audio or not isinstance(audio, tuple):
+             return ""
+
+         try:
+             sample_rate, audio_array = audio
+             if audio_array.ndim > 1:
+                 audio_array = audio_array.mean(axis=1)
+             audio_array = audio_array.astype(np.float32)
+             audio_array /= np.max(np.abs(audio_array))
+
+             result = transcriber(
+                 {"sampling_rate": sample_rate, "raw": audio_array}
+             )
+
+             # Handle different result types
+             if isinstance(result, dict):
+                 return result.get("text", "").strip()
+             elif isinstance(result, str):
+                 return result.strip()
+             elif isinstance(result, (list, tuple)) and len(result) > 0:
+                 return str(result[0]).strip()
+             return ""
+
+         except Exception as e:
+             print(f"Transcription error: {str(e)}")
+             return ""
+
+     microphone.stream(
+         fn=update_live_transcription,
          inputs=[microphone],
          outputs=transcript_box,
-         show_progress="hidden"
+         show_progress="hidden",
+         queue=True
      )

      clear_btn.click(
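
For reference, the new transcriber call pattern can be exercised outside Gradio. A minimal sketch, assuming the dependencies above are installed; the synthetic tone is only a stand-in for a microphone chunk and will transcribe to little or nothing:

```python
import numpy as np
from transformers.pipelines import pipeline

# Same pipeline style as the commit, minus the streaming-specific options.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    device="cpu",
)

# Stand-in for a microphone chunk: one second of a 440 Hz tone at 16 kHz.
sample_rate = 16_000
t = np.linspace(0, 1, sample_rate, endpoint=False)
audio_array = np.sin(2 * np.pi * 440 * t).astype(np.float32)

# Mirror the preprocessing in process_speech: mono, float32, peak-normalized.
if audio_array.ndim > 1:
    audio_array = audio_array.mean(axis=1)
audio_array /= np.max(np.abs(audio_array))

result = transcriber({"sampling_rate": sample_rate, "raw": audio_array})
print(result)  # a dict with a "text" field for this pipeline
```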