# MedCodeMCP/src/app.py
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
import gradio as gr
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from .parse_tabular import create_symptom_index # Use relative import
import json
import psutil
from typing import Dict, Optional, Tuple
import torch
from gtts import gTTS
import io
import base64
import numpy as np
from transformers.pipelines import pipeline # Changed from transformers import pipeline
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
import torchaudio
import torchaudio.transforms as T
# Model options mapped to their requirements
MODEL_OPTIONS = {
"tiny": {
"name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
"repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
"vram_req": 2, # GB
"ram_req": 4 # GB
},
"small": {
"name": "phi-2.Q4_K_M.gguf",
"repo": "TheBloke/phi-2-GGUF",
"vram_req": 4,
"ram_req": 8
},
"medium": {
"name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
"repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"vram_req": 6,
"ram_req": 16
}
}
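# select_best_model() below chooses one of these tiers at startup based on the detected RAM/VRAM.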
# Initialize Whisper components globally (these are lightweight)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)
transcriber = None  # created lazily on first use by get_asr_pipeline()
def get_asr_pipeline():
    """Lazy load the ASR pipeline with proper configuration."""
    global transcriber
    if transcriber is None:
transcriber = pipeline(
"automatic-speech-recognition",
model="openai/whisper-base.en",
chunk_length_s=30,
stride_length_s=5,
device="cpu",
torch_dtype=torch.float32
)
return transcriber
# Audio preprocessing function
def process_audio(audio_array, sample_rate):
"""Pre-process audio for Whisper."""
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
# Convert to tensor for resampling
audio_tensor = torch.FloatTensor(audio_array)
# Resample to 16kHz if needed
if sample_rate != 16000:
resampler = T.Resample(sample_rate, 16000)
audio_tensor = resampler(audio_tensor)
    # Normalize to [-1, 1]; skip silent input to avoid division by zero
    peak = torch.max(torch.abs(audio_tensor))
    if peak > 0:
        audio_tensor = audio_tensor / peak
# Convert back to numpy array and return in correct format
return {
"raw": audio_tensor.numpy(), # Key must be "raw"
"sampling_rate": 16000 # Key must be "sampling_rate"
}
def get_system_specs() -> Dict[str, float]:
"""Get system specifications."""
# Get RAM
ram_gb = psutil.virtual_memory().total / (1024**3)
# Get GPU info if available
gpu_vram_gb = 0
if torch.cuda.is_available():
try:
# Query GPU memory in bytes and convert to GB
gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
except Exception as e:
print(f"Warning: Could not get GPU memory: {e}")
return {
"ram_gb": ram_gb,
"gpu_vram_gb": gpu_vram_gb
}
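# get_system_specs() returns, e.g., {"ram_gb": 15.8, "gpu_vram_gb": 6.0} (illustrative values).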
def select_best_model() -> Tuple[str, str]:
"""Select the best model based on system specifications."""
specs = get_system_specs()
print(f"\nSystem specifications:")
print(f"RAM: {specs['ram_gb']:.1f} GB")
print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")
# Prioritize GPU if available
    if specs['gpu_vram_gb'] >= 4:  # enough VRAM for the 4-bit quantized phi-2 build
        model_tier = "small"
elif specs['ram_gb'] >= 8:
model_tier = "small"
else:
model_tier = "tiny"
selected = MODEL_OPTIONS[model_tier]
print(f"\nSelected model tier: {model_tier}")
print(f"Model: {selected['name']}")
return selected['name'], selected['repo']
# Set up model paths
MODEL_NAME, REPO_ID = select_best_model()
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
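# NOTE: MODEL_PATH is the local-development default; ensure_model() below resolves the actual cached path at runtime.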
def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
"""Ensures model is available, downloading only if needed."""
# Determine environment and set cache directory
if os.path.exists("/home/user"):
# HF Space environment
cache_dir = "/home/user/.cache/models"
else:
# Local development environment
cache_dir = os.path.join(BASE_DIR, "models")
# Create cache directory if it doesn't exist
try:
os.makedirs(cache_dir, exist_ok=True)
except Exception as e:
print(f"Warning: Could not create cache directory {cache_dir}: {e}")
# Fall back to temporary directory if needed
cache_dir = os.path.join("/tmp", "models")
os.makedirs(cache_dir, exist_ok=True)
# Get model details
if not model_name or not repo_id:
model_option = MODEL_OPTIONS["small"] # default to small model
model_name = model_option["name"]
repo_id = model_option["repo"]
# Ensure model_name and repo_id are not None
if model_name is None:
raise ValueError("model_name cannot be None")
if repo_id is None:
raise ValueError("repo_id cannot be None")
# Check if model already exists in cache
model_path = os.path.join(cache_dir, model_name)
if os.path.exists(model_path):
print(f"\nUsing cached model: {model_path}")
return model_path
print(f"\nDownloading model {model_name} from {repo_id}...")
try:
model_path = hf_hub_download(
repo_id=repo_id,
filename=model_name,
cache_dir=cache_dir,
local_dir=cache_dir
)
print(f"Model downloaded successfully to {model_path}")
return model_path
except Exception as e:
print(f"Error downloading model: {str(e)}")
raise
# Ensure model is downloaded
model_path = ensure_model(MODEL_NAME, REPO_ID)  # use the tier chosen by select_best_model()
# Configure local LLM with LlamaCPP
print("\nInitializing LLM...")
llm = LlamaCPP(
model_path=model_path,
temperature=0.7,
max_new_tokens=256,
context_window=2048,
verbose=False # Reduce logging
# n_batch and n_threads are not valid parameters for LlamaCPP and should not be used.
# If you encounter segmentation faults, try reducing context_window or check your system resources.
)
print("LLM initialized successfully")
# Configure global settings
print("\nConfiguring settings...")
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
print("Settings configured")
# Create the index at startup
print("\nCreating symptom index...")
symptom_index = create_symptom_index()
print("Index created successfully")
print("Loaded symptom_index:", type(symptom_index))
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
def process_speech(audio_data, history):
"""Process speech input and convert to text."""
try:
if not audio_data:
return []
if isinstance(audio_data, tuple) and len(audio_data) == 2:
sample_rate, audio_array = audio_data
# Audio preprocessing
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
            audio_array = audio_array.astype(np.float32)
            peak = np.max(np.abs(audio_array))
            if peak > 0:
                audio_array /= peak  # avoid divide-by-zero on silent chunks
# Ensure correct sampling rate
if sample_rate != 16000:
resampler = T.Resample(sample_rate, 16000)
audio_tensor = torch.FloatTensor(audio_array)
audio_tensor = resampler(audio_tensor)
audio_array = audio_tensor.numpy()
sample_rate = 16000
# Transcribe with error handling
# Format dictionary correctly with required keys
input_features = {
"raw": audio_array,
"sampling_rate": sample_rate
}
            result = get_asr_pipeline()(input_features)  # lazily load the shared Whisper pipeline
# Handle different result types
if isinstance(result, dict) and "text" in result:
transcript = result["text"].strip()
elif isinstance(result, str):
transcript = result.strip()
else:
print(f"Unexpected transcriber result type: {type(result)}")
return []
if not transcript:
print("No transcription generated")
return []
# Query symptoms with transcribed text
diagnosis_query = f"""
Given these symptoms: '{transcript}'
Identify the most likely ICD-10 diagnoses and key questions.
Focus on clinical implications.
"""
response = symptom_index.as_query_engine().query(diagnosis_query)
return [
{"role": "user", "content": transcript},
{"role": "assistant", "content": json.dumps({
"diagnoses": [],
"confidences": [],
"follow_up": str(response)
})}
]
else:
print(f"Invalid audio format: {type(audio_data)}")
return []
except Exception as e:
print(f"Processing error: {str(e)}")
return []
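# process_speech is not bound to a UI event below; the streaming handler enhanced_process_speech serves the microphone instead.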
# Build enhanced Gradio interface
with gr.Blocks(theme="default") as demo:
gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## About
This application is part of the Agents+MCP Hackathon. It helps medical professionals
and patients understand potential diagnoses based on described symptoms.
### How it works:
1. Either click the record button and describe your symptoms or type them into the textbox
2. The AI will analyze your description and suggest possible diagnoses
3. Answer follow-up questions to refine the diagnosis
""")
with gr.Row():
with gr.Column(scale=2):
# Add text input above microphone
with gr.Row():
text_input = gr.Textbox(
label="Type your symptoms",
placeholder="Or type your symptoms here...",
lines=3
)
submit_btn = gr.Button("Submit", variant="primary")
# Existing microphone row
with gr.Row():
microphone = gr.Audio(
sources=["microphone"],
streaming=True,
type="numpy",
label="Describe your symptoms"
)
transcript_box = gr.Textbox(
label="Transcribed Text",
interactive=False,
show_label=True
)
clear_btn = gr.Button("Clear Chat", variant="secondary")
chatbot = gr.Chatbot(
label="Medical Consultation",
height=500,
container=True,
type="messages" # This is now properly supported by our message format
)
with gr.Column(scale=1):
with gr.Accordion("Advanced Settings", open=False):
api_key = gr.Textbox(
label="OpenAI API Key (optional)",
type="password",
placeholder="sk-..."
)
with gr.Row():
with gr.Column():
modal_key = gr.Textbox(
label="Modal Labs API Key",
type="password",
placeholder="mk-..."
)
anthropic_key = gr.Textbox(
label="Anthropic API Key",
type="password",
placeholder="sk-ant-..."
)
mistral_key = gr.Textbox(
label="MistralAI API Key",
type="password",
placeholder="..."
)
with gr.Column():
nebius_key = gr.Textbox(
label="Nebius API Key",
type="password",
placeholder="..."
)
hyperbolic_key = gr.Textbox(
label="Hyperbolic Labs API Key",
type="password",
placeholder="hyp-..."
)
sambanova_key = gr.Textbox(
label="SambaNova API Key",
type="password",
placeholder="..."
)
with gr.Row():
model_selector = gr.Dropdown(
choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
value="OpenAI",
label="Model Provider"
)
temperature = gr.Slider(
minimum=0,
maximum=1,
value=0.7,
label="Temperature"
)
# self promotion at bottom of page
gr.Markdown("""
---
### 👋 About the Creator
Hi! I'm Graham Paasch, an experienced technology professional!
🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
# Event handlers
clear_btn.click(lambda: None, None, chatbot, queue=False)
def format_response_for_user(response_dict):
"""Format the assistant's response dictionary into a user-friendly string."""
diagnoses = response_dict.get("diagnoses", [])
confidences = response_dict.get("confidences", [])
follow_up = response_dict.get("follow_up", "")
result = ""
if diagnoses:
result += "Possible Diagnoses:\n"
for i, diag in enumerate(diagnoses):
conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
result += f"- {diag}{conf}\n"
if follow_up:
result += f"\nFollow-up: {follow_up}"
return result.strip()
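    # Example with illustrative values:
    #   format_response_for_user({"diagnoses": ["J20.9 Acute bronchitis"], "confidences": [0.8], "follow_up": "Is the cough dry?"})
    #   -> "Possible Diagnoses:\n- J20.9 Acute bronchitis (80.0%)\n\nFollow-up: Is the cough dry?"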
    def enhanced_process_speech(audio_data, history, api_key=None, model_tier="small", temp=0.7):
        """Handle streaming speech processing and chat updates."""
        transcriber = get_asr_pipeline()
        if not audio_data:
            return history
        try:
            if isinstance(audio_data, tuple) and len(audio_data) == 2:
                sample_rate, audio_array = audio_data
# Audio preprocessing
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
                audio_array = audio_array.astype(np.float32)
                peak = np.max(np.abs(audio_array))
                if peak > 0:
                    audio_array /= peak  # avoid divide-by-zero on silent chunks
# Ensure correct sampling rate
if sample_rate != 16000:
resampler = T.Resample(
orig_freq=sample_rate,
new_freq=16000
)
audio_tensor = torch.FloatTensor(audio_array)
audio_tensor = resampler(audio_tensor)
audio_array = audio_tensor.numpy()
sample_rate = 16000
# Format input dictionary exactly as required
transcriber_input = {
"raw": audio_array,
"sampling_rate": sample_rate
}
# Get transcription from Whisper
result = transcriber(transcriber_input)
# Extract text from result
transcript = ""
if isinstance(result, dict):
transcript = result.get("text", "").strip()
elif isinstance(result, str):
transcript = result.strip()
if not transcript:
return history
# Process the symptoms
diagnosis_query = f"""
Based on these symptoms: '{transcript}'
Provide relevant ICD-10 codes and diagnostic questions.
"""
response = symptom_index.as_query_engine().query(diagnosis_query)
# Format and return chat messages
return history + [
{"role": "user", "content": transcript},
{"role": "assistant", "content": format_response_for_user({
"diagnoses": [],
"confidences": [],
"follow_up": str(response)
})}
]
except Exception as e:
print(f"Streaming error: {str(e)}")
return history
microphone.stream(
fn=enhanced_process_speech,
inputs=[microphone, chatbot, api_key, model_selector, temperature],
outputs=chatbot,
show_progress="hidden",
api_name=False,
queue=True # Enable queuing for better stream handling
)
# Update transcription handler
def update_live_transcription(audio):
"""Real-time transcription updates."""
if not audio or not isinstance(audio, tuple):
return ""
try:
sample_rate, audio_array = audio
features = process_audio(audio_array, sample_rate)
asr = get_asr_pipeline()
result = asr(features)
return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
except Exception as e:
print(f"Transcription error: {str(e)}")
return ""
microphone.stream(
fn=update_live_transcription,
inputs=[microphone],
outputs=transcript_box,
show_progress="hidden",
queue=True
)
clear_btn.click(
fn=lambda: (None, "", ""),
outputs=[chatbot, transcript_box, text_input],
queue=False
)
    def cleanup_memory():
        """Release unused Python and CUDA memory between requests."""
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def process_text_input(text, history):
"""Process text input with memory management."""
print("process_text_input received:", text)
if not text:
return history, "" # Return tuple to clear input
try:
# Process the symptoms using the configured LLM
prompt = f"""Given these symptoms: '{text}'
Please provide:
1. Most likely ICD-10 codes
2. Confidence levels for each diagnosis
3. Key follow-up questions
Format as JSON with diagnoses, confidences, and follow_up fields."""
response = llm.complete(prompt)
try:
# Try to parse as JSON first
result = json.loads(response.text)
except json.JSONDecodeError:
# If not JSON, wrap in our format
result = {
"diagnoses": [],
"confidences": [],
"follow_up": str(response.text)[:1000] # Limit response length
}
new_history = history + [
{"role": "user", "content": text},
{"role": "assistant", "content": format_response_for_user(result)}
]
return new_history, "" # Return empty string to clear input
except Exception as e:
print(f"Error processing text: {str(e)}")
return history, text # Keep text on error
# Update the submit button handler
submit_btn.click(
fn=process_text_input,
inputs=[text_input, chatbot],
outputs=[chatbot, text_input],
queue=True
).success( # Changed from .then to .success for better error handling
fn=cleanup_memory,
inputs=None,
outputs=None,
queue=False
)