gpaasch committed
Commit 8a3861b · 1 Parent(s): f663376

no need for gradio live

Files changed (2)
  1. app.py +5 -1
  2. src/app.py +51 -82
app.py CHANGED
@@ -2,4 +2,8 @@
 from src.app import demo

 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_api=True  # Shows the API documentation
+    )
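Note: with show_api=True the app keeps its auto-generated endpoint documentation visible, so the Space can also be driven programmatically. A minimal sketch using gradio_client; the api_name here is an assumption, check the app's "Use via API" page for the real route:

    from gradio_client import Client

    client = Client("http://localhost:7860")
    # Hypothetical endpoint name; confirm against the generated API docs.
    result = client.predict("persistent dry cough for two weeks", api_name="/predict")
    print(result)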
src/app.py CHANGED
@@ -66,43 +66,22 @@ def process_audio(audio_array, sample_rate):
     if audio_array.ndim > 1:
         audio_array = audio_array.mean(axis=1)

-    # Normalize audio
-    audio_array = audio_array.astype(np.float32)
-    audio_array /= np.max(np.abs(audio_array))
+    # Convert to tensor for resampling
+    audio_tensor = torch.FloatTensor(audio_array)

     # Resample to 16kHz if needed
     if sample_rate != 16000:
-        resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
-        audio_tensor = torch.FloatTensor(audio_array)
+        resampler = T.Resample(sample_rate, 16000)
         audio_tensor = resampler(audio_tensor)
-        audio_array = audio_tensor.numpy()

-    # Process with correct input format
-    inputs = processor(
-        audio_array,
-        sampling_rate=16000,
-        return_tensors="pt"
-    )
+    # Normalize
+    audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))

+    # Convert back to numpy array and return in correct format
     return {
-        "input_features": inputs.input_features,
-        "attention_mask": inputs.attention_mask
-    }
-
-    # Update transcriber configuration
-    transcriber = pipeline(
-        "automatic-speech-recognition",
-        model="openai/whisper-base.en",
-        chunk_length_s=30,
-        stride_length_s=5,
-        device="cpu",
-        torch_dtype=torch.float32,
-        feature_extractor=feature_extractor,
-        generate_kwargs={
-            "use_cache": True,
-            "return_timestamps": True
-        }
-    )
+        "raw": audio_tensor.numpy(),  # Key must be "raw"
+        "sampling_rate": 16000  # Key must be "sampling_rate"
+    }

 def get_system_specs() -> Dict[str, float]:
     """Get system specifications."""
@@ -312,14 +291,6 @@ def process_speech(audio_data, history):
         print(f"Processing error: {str(e)}")
         return []

-    def update_transcription(audio_path):
-        """Update transcription box with speech recognition results."""
-        if not audio_path:
-            return ""
-        # Extract transcription from audio result
-        transcript = audio_path[1] if isinstance(audio_path, tuple) else audio_path
-        return transcript
-
 # Build enhanced Gradio interface
 with gr.Blocks(
     theme="default",
@@ -332,7 +303,9 @@ with gr.Blocks(
         font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
             'Liberation Mono', 'Courier New', monospace;
     }
-    """
+    """,
+    analytics_enabled=True,
+    title="MedCode MCP",
 ) as demo:
     gr.Markdown("""
     # 🏥 Medical Symptom to ICD-10 Code Assistant
@@ -509,16 +482,10 @@ with gr.Blocks(
         # Normalize
         audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))

-        # Use feature extractor with correct sampling rate
-        features = feature_extractor(
-            audio_tensor.numpy(),
-            sampling_rate=16000,  # Always use 16kHz
-            return_tensors="pt"
-        )
-
+        # Convert back to numpy array and return in correct format
         return {
-            "input_features": features.input_features,
-            "sampling_rate": 16000  # Return resampled rate
+            "raw": audio_tensor.numpy(),  # Key must be "raw"
+            "sampling_rate": 16000  # Key must be "sampling_rate"
         }

     # Update transcription handler
@@ -527,17 +494,16 @@ with gr.Blocks(
         if not audio or not isinstance(audio, tuple):
             return ""

-        sample_rate, audio_array = audio
-        features = process_audio(audio_array, sample_rate)
-
-        # Get pipeline and transcribe
-        asr = get_asr_pipeline()
-        result = asr(features)
-
-        if isinstance(result, dict):
-            return result.get("text", "").strip()
-        elif isinstance(result, str):
-            return result.strip()
+        try:
+            sample_rate, audio_array = audio
+            features = process_audio(audio_array, sample_rate)
+
+            asr = get_asr_pipeline()
+            result = asr(features)
+
+            return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
+        except Exception as e:
+            print(f"Transcription error: {str(e)}")
         return ""

     microphone.stream(
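The handler now wraps the whole transcription path in try/except, so a malformed chunk degrades to an empty string instead of crashing the stream. For reference, a sketch of how the streaming wiring typically looks in Gradio Blocks; the component and handler names are assumptions, since the diff cuts off at microphone.stream(:

    # Hypothetical wiring following the standard Gradio streaming pattern.
    microphone = gr.Audio(sources=["microphone"], streaming=True, type="numpy")
    transcript = gr.Textbox(label="Transcription")

    microphone.stream(
        fn=update_live_transcription,  # assumed handler name
        inputs=[microphone],
        outputs=[transcript],
    )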
@@ -566,30 +532,30 @@ with gr.Blocks(
         if not text:
             return history

-        # Limit input length
-        if len(text) > 500:
-            text = text[:500] + "..."
-
-        # Process the symptoms
-        diagnosis_query = f"""
-        Based on these symptoms: '{text}'
-        Provide relevant ICD-10 codes and diagnostic questions.
-        Focus on clinical implications.
-        Limit response to 1000 characters.
-        """
-        response = symptom_index.as_query_engine().query(diagnosis_query)
-
-        # Clean up memory
-        cleanup_memory()
-
-        return history + [
-            {"role": "user", "content": text},
-            {"role": "assistant", "content": format_response_for_user({
-                "diagnoses": [],
-                "confidences": [],
-                "follow_up": str(response)[:1000]  # Limit response length
-            })}
-        ]
+        # Limit input length
+        if len(text) > 500:
+            text = text[:500] + "..."
+
+        # Process the symptoms
+        diagnosis_query = f"""
+        Based on these symptoms: '{text}'
+        Provide relevant ICD-10 codes and diagnostic questions.
+        Focus on clinical implications.
+        Limit response to 1000 characters.
+        """
+        response = symptom_index.as_query_engine().query(diagnosis_query)
+
+        # Clean up memory
+        cleanup_memory()
+
+        return history + [
+            {"role": "user", "content": text},
+            {"role": "assistant", "content": format_response_for_user({
+                "diagnoses": [],
+                "confidences": [],
+                "follow_up": str(response)[:1000]  # Limit response length
+            })}
+        ]

     submit_btn.click(
         fn=process_text_input,
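The retained process_text_input path queries a llama_index index and appends role/content dicts to the history, the format gr.Chatbot(type="messages") expects. A sketch of the index side under that assumption (the document text is illustrative, not from this repo):

    from llama_index.core import Document, VectorStoreIndex

    symptom_index = VectorStoreIndex.from_documents(
        [Document(text="R05: Cough. R50.9: Fever, unspecified.")]
    )
    response = symptom_index.as_query_engine().query(
        "Based on these symptoms: 'dry cough and fever' provide relevant ICD-10 codes."
    )
    print(str(response))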
@@ -617,10 +583,13 @@ with gr.Blocks(
     - Sharing this tool with others in healthcare tech
     """)

-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        mcp_server=True,
-        allowed_paths=["*"]
-    )
+def process_symptoms(symptoms: str):
+    """Convert symptoms to ICD codes using the configured LLM"""
+    try:
+        # Use the configured LLM to process symptoms
+        response = llm.complete(
+            f"Convert these symptoms to ICD-10 codes: {symptoms}"
+        )
+        return {"icd_codes": response.text, "status": "success"}
+    except Exception as e:
+        return {"error": str(e), "status": "error"}
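The new process_symptoms relies on a module-level llm object whose .complete() returns a response with a .text attribute, which matches the llama_index LLM interface. A sketch of the assumed dependency (the model choice is illustrative, not from this commit):

    from llama_index.llms.openai import OpenAI

    llm = OpenAI(model="gpt-4o-mini")  # assumed; any llama_index LLM with .complete() works

    print(process_symptoms("persistent dry cough and mild fever"))
    # -> {"icd_codes": "...", "status": "success"} on success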
 
 
 
 