import os
from pathlib import Path
from huggingface_hub import hf_hub_download
import gradio as gr
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from .parse_tabular import create_symptom_index  # Use relative import
import json
import psutil
from typing import Dict, Optional, Tuple
import torch
from gtts import gTTS
import io
import base64
import numpy as np
from transformers.pipelines import pipeline  # Changed from `from transformers import pipeline`
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
import torchaudio
import torchaudio.transforms as T
# Model options mapped to their requirements
MODEL_OPTIONS = {
    "tiny": {
        "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
        "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
        "vram_req": 2,  # GB
        "ram_req": 4    # GB
    },
    "small": {
        "name": "phi-2.Q4_K_M.gguf",
        "repo": "TheBloke/phi-2-GGUF",
        "vram_req": 4,
        "ram_req": 8
    },
    "medium": {
        "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
        "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        "vram_req": 6,
        "ram_req": 16
    }
}
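# Note: "vram_req" and "ram_req" are rough minimums (in GB) for the Q4_K_M
# quantized GGUF files above. select_best_model() below currently uses its own
# hard-coded thresholds rather than reading these two fields.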
# Initialize Whisper components globally (these are lightweight)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)
def get_asr_pipeline():
    """Lazy load ASR pipeline with proper configuration."""
    global transcriber
    if "transcriber" not in globals():
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base.en",
            chunk_length_s=30,
            stride_length_s=5,
            device="cpu",
            torch_dtype=torch.float32
        )
    return transcriber
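# Usage note: the pipeline is always called here with a dict of the form
# {"raw": <float32 numpy array>, "sampling_rate": 16000}, one of the input
# formats the Hugging Face ASR pipeline accepts; process_audio() below
# produces exactly that shape.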
# Audio preprocessing function
def process_audio(audio_array, sample_rate):
    """Pre-process audio for Whisper."""
    # Mix down to mono if the input has multiple channels
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Convert to tensor for resampling
    audio_tensor = torch.FloatTensor(audio_array)

    # Resample to 16kHz if needed
    if sample_rate != 16000:
        resampler = T.Resample(sample_rate, 16000)
        audio_tensor = resampler(audio_tensor)

    # Normalize (guard against silent input to avoid division by zero)
    peak = torch.max(torch.abs(audio_tensor))
    if peak > 0:
        audio_tensor = audio_tensor / peak

    # Convert back to numpy array and return in correct format
    return {
        "raw": audio_tensor.numpy(),  # Key must be "raw"
        "sampling_rate": 16000        # Key must be "sampling_rate"
    }
def get_system_specs() -> Dict[str, float]:
    """Get system specifications."""
    # Get RAM
    ram_gb = psutil.virtual_memory().total / (1024**3)

    # Get GPU info if available
    gpu_vram_gb = 0
    if torch.cuda.is_available():
        try:
            # Query GPU memory in bytes and convert to GB
            gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception as e:
            print(f"Warning: Could not get GPU memory: {e}")

    return {
        "ram_gb": ram_gb,
        "gpu_vram_gb": gpu_vram_gb
    }
def select_best_model() -> Tuple[str, str]:
    """Select the best model based on system specifications."""
    specs = get_system_specs()
    print("\nSystem specifications:")
    print(f"RAM: {specs['ram_gb']:.1f} GB")
    print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")

    # Prioritize GPU if available
    if specs['gpu_vram_gb'] >= 4:  # e.g. a 6 GB card comfortably clears this bar
        model_tier = "small"       # phi-2 runs well on a mid-range GPU such as an RTX 2060
    elif specs['ram_gb'] >= 8:
        model_tier = "small"
    else:
        model_tier = "tiny"

    selected = MODEL_OPTIONS[model_tier]
    print(f"\nSelected model tier: {model_tier}")
    print(f"Model: {selected['name']}")

    return selected['name'], selected['repo']
# Set up model paths
MODEL_NAME, REPO_ID = select_best_model()
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
    """Ensures the model is available, downloading only if needed."""
    # Determine environment and set cache directory
    if os.path.exists("/home/user"):
        # HF Space environment
        cache_dir = "/home/user/.cache/models"
    else:
        # Local development environment
        cache_dir = os.path.join(BASE_DIR, "models")

    # Create cache directory if it doesn't exist
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as e:
        print(f"Warning: Could not create cache directory {cache_dir}: {e}")
        # Fall back to a temporary directory if needed
        cache_dir = os.path.join("/tmp", "models")
        os.makedirs(cache_dir, exist_ok=True)

    # Get model details, defaulting to the small model
    if not model_name or not repo_id:
        model_option = MODEL_OPTIONS["small"]
        model_name = model_option["name"]
        repo_id = model_option["repo"]

    # Ensure model_name and repo_id are not None (also satisfies static type checkers)
    if model_name is None:
        raise ValueError("model_name cannot be None")
    if repo_id is None:
        raise ValueError("repo_id cannot be None")

    # Check if the model already exists in the cache
    model_path = os.path.join(cache_dir, model_name)
    if os.path.exists(model_path):
        print(f"\nUsing cached model: {model_path}")
        return model_path

    print(f"\nDownloading model {model_name} from {repo_id}...")
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=model_name,
            cache_dir=cache_dir,
            local_dir=cache_dir
        )
        print(f"Model downloaded successfully to {model_path}")
        return model_path
    except Exception as e:
        print(f"Error downloading model: {str(e)}")
        raise
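# Illustrative usage (not executed here): download a specific tier explicitly
# instead of relying on the auto-selected default.
#   tiny = MODEL_OPTIONS["tiny"]
#   model_path = ensure_model(model_name=tiny["name"], repo_id=tiny["repo"])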
# Ensure model is downloaded
model_path = ensure_model()

# Configure local LLM with LlamaCPP
print("\nInitializing LLM...")
llm = LlamaCPP(
    model_path=model_path,
    temperature=0.7,
    max_new_tokens=256,
    context_window=2048,
    verbose=False  # Reduce logging
    # n_batch and n_threads are not valid parameters for LlamaCPP and should not be used.
    # If you encounter segmentation faults, try reducing context_window or check your system resources.
)
print("LLM initialized successfully")

# Configure global settings
print("\nConfiguring settings...")
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
print("Settings configured")

# Create the index at startup
print("\nCreating symptom index...")
symptom_index = create_symptom_index()
print("Index created successfully")
print("Loaded symptom_index:", type(symptom_index))
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
def process_speech(audio_data, history):
    """Process speech input and convert to text."""
    try:
        if not audio_data:
            return []

        if isinstance(audio_data, tuple) and len(audio_data) == 2:
            sample_rate, audio_array = audio_data

            # Audio preprocessing
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            audio_array = audio_array.astype(np.float32)
            peak = np.max(np.abs(audio_array))
            if peak > 0:
                audio_array /= peak

            # Ensure correct sampling rate
            if sample_rate != 16000:
                resampler = T.Resample(sample_rate, 16000)
                audio_tensor = torch.FloatTensor(audio_array)
                audio_tensor = resampler(audio_tensor)
                audio_array = audio_tensor.numpy()
                sample_rate = 16000

            # Transcribe with error handling.
            # Format the dictionary correctly with the required keys.
            input_features = {
                "raw": audio_array,
                "sampling_rate": sample_rate
            }
            transcriber = get_asr_pipeline()  # lazy-load the ASR pipeline before use
            result = transcriber(input_features)

            # Handle different result types
            if isinstance(result, dict) and "text" in result:
                transcript = result["text"].strip()
            elif isinstance(result, str):
                transcript = result.strip()
            else:
                print(f"Unexpected transcriber result type: {type(result)}")
                return []

            if not transcript:
                print("No transcription generated")
                return []

            # Query symptoms with transcribed text
            diagnosis_query = f"""
            Given these symptoms: '{transcript}'
            Identify the most likely ICD-10 diagnoses and key questions.
            Focus on clinical implications.
            """
            response = symptom_index.as_query_engine().query(diagnosis_query)

            return [
                {"role": "user", "content": transcript},
                {"role": "assistant", "content": json.dumps({
                    "diagnoses": [],
                    "confidences": [],
                    "follow_up": str(response)
                })}
            ]
        else:
            print(f"Invalid audio format: {type(audio_data)}")
            return []

    except Exception as e:
        print(f"Processing error: {str(e)}")
        return []
# Build enhanced Gradio interface
with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
    # 🏥 Medical Symptom to ICD-10 Code Assistant

    ## About
    This application is part of the Agents+MCP Hackathon. It helps medical professionals
    and patients understand potential diagnoses based on described symptoms.

    ### How it works:
    1. Either click the record button and describe your symptoms, or type them into the textbox
    2. The AI will analyze your description and suggest possible diagnoses
    3. Answer follow-up questions to refine the diagnosis
    """)

    with gr.Row():
        with gr.Column(scale=2):
            # Text input above the microphone
            with gr.Row():
                text_input = gr.Textbox(
                    label="Type your symptoms",
                    placeholder="Or type your symptoms here...",
                    lines=3
                )
                submit_btn = gr.Button("Submit", variant="primary")

            # Microphone row
            with gr.Row():
                microphone = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Describe your symptoms"
                )
                transcript_box = gr.Textbox(
                    label="Transcribed Text",
                    interactive=False,
                    show_label=True
                )

            clear_btn = gr.Button("Clear Chat", variant="secondary")

            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                container=True,
                type="messages"  # This is properly supported by our message format
            )
        with gr.Column(scale=1):
            with gr.Accordion("Advanced Settings", open=False):
                api_key = gr.Textbox(
                    label="OpenAI API Key (optional)",
                    type="password",
                    placeholder="sk-..."
                )
                with gr.Row():
                    with gr.Column():
                        modal_key = gr.Textbox(
                            label="Modal Labs API Key",
                            type="password",
                            placeholder="mk-..."
                        )
                        anthropic_key = gr.Textbox(
                            label="Anthropic API Key",
                            type="password",
                            placeholder="sk-ant-..."
                        )
                        mistral_key = gr.Textbox(
                            label="MistralAI API Key",
                            type="password",
                            placeholder="..."
                        )
                    with gr.Column():
                        nebius_key = gr.Textbox(
                            label="Nebius API Key",
                            type="password",
                            placeholder="..."
                        )
                        hyperbolic_key = gr.Textbox(
                            label="Hyperbolic Labs API Key",
                            type="password",
                            placeholder="hyp-..."
                        )
                        sambanova_key = gr.Textbox(
                            label="SambaNova API Key",
                            type="password",
                            placeholder="..."
                        )
                with gr.Row():
                    model_selector = gr.Dropdown(
                        choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                        value="OpenAI",
                        label="Model Provider"
                    )
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.7,
                        label="Temperature"
                    )
    # Self-promotion at the bottom of the page
    gr.Markdown("""
    ---
    ### 👋 About the Creator
    Hi! I'm Graham Paasch, an experienced technology professional!

    🎥 **Check out my YouTube channel** for more tech content:
    [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

    💼 **Looking for a skilled developer?**
    I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

    ⭐ If you found this tool helpful, please consider:
    - Subscribing to my YouTube channel
    - Connecting on LinkedIn
    - Sharing this tool with others in healthcare tech
    """)
    # Event handlers
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    def format_response_for_user(response_dict):
        """Format the assistant's response dictionary into a user-friendly string."""
        diagnoses = response_dict.get("diagnoses", [])
        confidences = response_dict.get("confidences", [])
        follow_up = response_dict.get("follow_up", "")
        result = ""
        if diagnoses:
            result += "Possible Diagnoses:\n"
            for i, diag in enumerate(diagnoses):
                conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
                result += f"- {diag}{conf}\n"
        if follow_up:
            result += f"\nFollow-up: {follow_up}"
        return result.strip()
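    # Illustrative example (values invented for demonstration):
    #   format_response_for_user({"diagnoses": ["J20.9"], "confidences": [0.8], "follow_up": "How long have you had the cough?"})
    #   -> "Possible Diagnoses:\n- J20.9 (80.0%)\n\nFollow-up: How long have you had the cough?"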
    def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7):
        """Handle streaming speech processing and chat updates."""
        transcriber = get_asr_pipeline()
        if not audio_path:
            return history

        try:
            if isinstance(audio_path, tuple) and len(audio_path) == 2:
                sample_rate, audio_array = audio_path

                # Audio preprocessing
                if audio_array.ndim > 1:
                    audio_array = audio_array.mean(axis=1)
                audio_array = audio_array.astype(np.float32)
                peak = np.max(np.abs(audio_array))
                if peak > 0:
                    audio_array /= peak

                # Ensure correct sampling rate
                if sample_rate != 16000:
                    resampler = T.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    )
                    audio_tensor = torch.FloatTensor(audio_array)
                    audio_tensor = resampler(audio_tensor)
                    audio_array = audio_tensor.numpy()
                    sample_rate = 16000

                # Format input dictionary exactly as required
                transcriber_input = {
                    "raw": audio_array,
                    "sampling_rate": sample_rate
                }

                # Get transcription from Whisper
                result = transcriber(transcriber_input)

                # Extract text from result
                transcript = ""
                if isinstance(result, dict):
                    transcript = result.get("text", "").strip()
                elif isinstance(result, str):
                    transcript = result.strip()

                if not transcript:
                    return history

                # Process the symptoms
                diagnosis_query = f"""
                Based on these symptoms: '{transcript}'
                Provide relevant ICD-10 codes and diagnostic questions.
                """
                response = symptom_index.as_query_engine().query(diagnosis_query)

                # Format and return chat messages
                return history + [
                    {"role": "user", "content": transcript},
                    {"role": "assistant", "content": format_response_for_user({
                        "diagnoses": [],
                        "confidences": [],
                        "follow_up": str(response)
                    })}
                ]

            # Unexpected audio format: leave history unchanged
            return history
        except Exception as e:
            print(f"Streaming error: {str(e)}")
            return history
    microphone.stream(
        fn=enhanced_process_speech,
        inputs=[microphone, chatbot, api_key, model_selector, temperature],
        outputs=chatbot,
        show_progress="hidden",
        api_name=False,
        queue=True  # Enable queuing for better stream handling
    )
    # Live transcription handler
    def update_live_transcription(audio):
        """Real-time transcription updates."""
        if not audio or not isinstance(audio, tuple):
            return ""
        try:
            sample_rate, audio_array = audio
            features = process_audio(audio_array, sample_rate)
            asr = get_asr_pipeline()
            result = asr(features)
            return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return ""
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=transcript_box,
        show_progress="hidden",
        queue=True
    )

    clear_btn.click(
        fn=lambda: (None, "", ""),
        outputs=[chatbot, transcript_box, text_input],
        queue=False
    )
    def cleanup_memory():
        """Release unused memory (placeholder for future memory management)."""
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
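    # cleanup_memory() runs after each successful text submission (see the
    # submit_btn.click(...).success(...) chain below) to keep memory use bounded.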
    def process_text_input(text, history):
        """Process text input with memory management."""
        print("process_text_input received:", text)
        if not text:
            return history, ""  # Return a tuple to clear the input

        try:
            # Process the symptoms using the configured LLM
            prompt = f"""Given these symptoms: '{text}'
            Please provide:
            1. Most likely ICD-10 codes
            2. Confidence levels for each diagnosis
            3. Key follow-up questions

            Format as JSON with diagnoses, confidences, and follow_up fields."""

            response = llm.complete(prompt)

            try:
                # Try to parse as JSON first
                result = json.loads(response.text)
            except json.JSONDecodeError:
                # If not JSON, wrap in our format
                result = {
                    "diagnoses": [],
                    "confidences": [],
                    "follow_up": str(response.text)[:1000]  # Limit response length
                }

            new_history = history + [
                {"role": "user", "content": text},
                {"role": "assistant", "content": format_response_for_user(result)}
            ]
            return new_history, ""  # Return empty string to clear the input

        except Exception as e:
            print(f"Error processing text: {str(e)}")
            return history, text  # Keep the text on error
    # Submit button handler
    submit_btn.click(
        fn=process_text_input,
        inputs=[text_input, chatbot],
        outputs=[chatbot, text_input],
        queue=True
    ).success(  # .success instead of .then, so cleanup only runs when processing succeeds
        fn=cleanup_memory,
        inputs=None,
        outputs=None,
        queue=False
    )