# MedCodeMCP / src/app.py
# Commit 82d84c7: Add multi-backend LLM support and audio-driven medical agent pipeline
import os

import gradio as gr
import torch
from transformers import pipeline
from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
from llama_index.llm_predictor import HuggingFaceLLMPredictor, LLMPredictor
# Optional OpenAI import remains for default predictor
import openai
# --- Whisper ASR setup ---
# BUG FIX: the original hard-coded device=0, which requires a CUDA GPU and
# makes the whole app crash at import time on CPU-only hosts (e.g. free
# HF Spaces hardware). Fall back to CPU (-1) when CUDA is unavailable.
_ASR_DEVICE = 0 if torch.cuda.is_available() else -1
# Whisper-small ASR pipeline; chunk_length_s=30 splits long recordings into
# 30-second windows so arbitrarily long audio can be transcribed.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=_ASR_DEVICE,
    chunk_length_s=30,
)
# --- LlamaIndex utils import ---
from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms
# --- System prompt ---
# Instructs the LLM to behave as a turn-by-turn triage assistant: it must
# either ask exactly one clarifying question per turn, or finish by emitting
# a JSON object. transcribe_and_respond() detects the final answer by
# checking whether the reply starts with "{", so the JSON shape promised
# here is load-bearing.
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. “Is your cough dry or productive?”)
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
def transcribe_and_respond(audio, history):
    """Handle one conversational turn of the audio diagnosis chat.

    Transcribes the user's audio with Whisper, builds a flat text prompt from
    the system prompt plus the running conversation, and queries the selected
    LLM predictor. A reply that starts with "{" is treated as the final
    JSON diagnosis; anything else is relayed as a follow-up question.

    Args:
        audio: Audio value from the Gradio Microphone component (format
            depends on the component's configuration — TODO confirm).
        history: Conversation so far as a list of (role, text) tuples, or
            None/empty on the first turn.

    Returns:
        A 3-tuple (cleared mic value, updated history for the Chatbot,
        updated history for the State) — one value per wired output.
    """
    # 1) Transcribe audio → text
    user_text = asr(audio)["text"]
    history = history or []
    history.append(("user", user_text))
    # 2) Build unified prompt for LLM
    messages = [("system", SYSTEM_PROMPT)] + history
    prompt = "\n".join(f"{role.capitalize()}: {text}" for role, text in messages)
    prompt += "\nAssistant:"
    # 3) Select predictor (OpenAI or Mistral/local)
    predictor = get_llm_predictor()
    resp = predictor.predict(prompt)
    # 4) If JSON-style output, treat as final
    if resp.strip().startswith("{"):
        result = query_symptoms(resp)
        history.append(("assistant", f"Here is your diagnosis: {result}"))
        # BUG FIX: the submit handler declares three outputs
        # [mic, chatbot, state], but the original returned only two values,
        # which makes Gradio error out and leaves the State stale on the
        # first turn (history or [] creates a fresh list the State never
        # sees). Return the history for the State slot as well.
        return "", history, history
    # 5) Otherwise, it's a follow-up question
    history.append(("assistant", resp))
    return "", history, history
# --- Build Gradio app ---
# UI layout: a chat transcript, a microphone input, and a hidden State that
# carries the (role, text) history between turns.
with gr.Blocks() as demo:
    gr.Markdown("## Symptom to ICD-10 Diagnoser (audio & chat)")
    chatbot = gr.Chatbot(label="Conversation")
    # NOTE(review): transcribe_and_respond builds ("user"/"assistant", text)
    # tuples, but gr.Chatbot's tuple format expects (user_msg, bot_msg)
    # pairs — verify the transcript renders as intended.
    mic = gr.Microphone(label="Describe your symptoms")
    state = gr.State([])
    # NOTE(review): gr.Microphone may not expose a .submit event (audio
    # components typically use .stop_recording/.change) — confirm against
    # the pinned Gradio version.
    mic.submit(
        fn=transcribe_and_respond,
        inputs=[mic, state],
        outputs=[mic, chatbot, state]
    )
# Script entry point: serve the app on all interfaces (required for
# containerized hosts such as HF Spaces) on the conventional port 7860.
# mcp_server=True additionally exposes the app as an MCP server — requires
# a Gradio version with MCP support.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True
    )