gpaasch committed
Commit 82d84c7 · Parent: fd7fa97

Add multi-backend LLM support and audio-driven medical agent pipeline

- Integrate Whisper ASR for speech-to-text symptom input
- Unify agent logic in `transcribe_and_respond()` using `get_llm_predictor()` (OpenAI, Mistral, or a local pipeline)
- Add environment flags `USE_LOCAL_GPU` and `USE_MISTRAL` to switch models dynamically (see the sketch after this list)
- Update Gradio `app.py` to launch the audio/chat interface with MCP support
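
For context on the new flags, a minimal sketch of the precedence they imply: `USE_LOCAL_GPU=1` routes generation to a local `transformers` pipeline, and `USE_MISTRAL=1` then chooses a Mistral checkpoint over the locally cached `gpt2-medium` (which this commit adds to `.gitignore`). The helper names and the exact Mistral model id below are illustrative assumptions; the real selection logic lives in `get_llm_predictor()` in `utils/llama_index_utils.py`.

import os
from transformers import pipeline

def _pick_local_model():
    # Hypothetical helper: USE_MISTRAL=1 selects a Mistral checkpoint
    # (exact model id assumed for illustration); otherwise fall back to
    # the locally cached gpt2-medium.
    if os.environ.get("USE_MISTRAL") == "1":
        return "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
    return "gpt2-medium"

def _maybe_local_pipeline():
    # USE_LOCAL_GPU=1 -> run generation on GPU 0 via transformers;
    # unset -> get_llm_predictor() falls back to the OpenAI predictor.
    if os.environ.get("USE_LOCAL_GPU") == "1":
        return pipeline("text-generation", model=_pick_local_model(), device=0)
    return None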

Files changed (5):
  1. .gitignore +3 -1
  2. app.py +1 -1
  3. requirements.txt +4 -1
  4. src/app.py +69 -35
  5. utils/llama_index_utils.py +7 -1
.gitignore CHANGED
@@ -1,2 +1,4 @@
 venv
-.venv
+.venv
+__pycache__
+gpt2-medium
app.py CHANGED
@@ -2,4 +2,4 @@
 from src.app import demo
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 gradio[full]
+gradio[mcp]
 llama-index==0.6.9
 openai==0.27.0
-transformers
+transformers
+torch
+accelerate
src/app.py CHANGED
@@ -1,40 +1,74 @@
-# app.py
-import json
+import os
 import gradio as gr
+from transformers import pipeline
+from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
+from llama_index.llm_predictor import HuggingFaceLLMPredictor, LLMPredictor
 
-# Load the merged knowledge base
-with open("data/knowledge_base.json", encoding="utf-8") as f:
-    kb = json.load(f)
-
-symptom_to_icd = kb["symptom_to_icd"]
-icd_to_description = kb["icd_to_description"]
-
-def map_symptoms(raw_input):
-    terms = [t.strip().lower() for t in raw_input.split(",") if t.strip()]
-    icd_counts = {}
-    for term in terms:
-        for code in symptom_to_icd.get(term, []):
-            icd_counts[code] = icd_counts.get(code, 0) + 1
-    if not icd_counts:
-        return {"diagnoses": [], "confidences": []}
-    total = sum(icd_counts.values())
-    # sort codes by frequency descending
-    sorted_items = sorted(icd_counts.items(), key=lambda x: x[1], reverse=True)
-    diagnoses = []
-    confidences = []
-    for code, count in sorted_items:
-        desc = icd_to_description.get(code, "Unknown")
-        diagnoses.append(f"{code}: {desc}")
-        confidences.append(round(count / total, 2))
-    return {"diagnoses": diagnoses, "confidences": confidences}
-
-# Use Blocks so that mcp_server=True is accepted
+# Optional OpenAI import remains for the default predictor
+import openai
+
+# --- Whisper ASR setup ---
+asr = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-small",
+    device=0,
+    chunk_length_s=30,
+)
+
+# --- LlamaIndex utils import ---
+from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms
+
+# --- System prompt ---
+SYSTEM_PROMPT = """
+You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
+At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
+or, if you have enough info, output a final JSON with fields:
+{"diagnoses": […], "confidences": […]}.
+"""
+
+
+def transcribe_and_respond(audio, history):
+    # 1) Transcribe audio → text
+    user_text = asr(audio)["text"]
+    history = history or []
+    history.append(("user", user_text))
+
+    # 2) Build a unified prompt for the LLM
+    messages = [("system", SYSTEM_PROMPT)] + history
+    prompt = "\n".join(f"{role.capitalize()}: {text}" for role, text in messages)
+    prompt += "\nAssistant:"
+
+    # 3) Select predictor (OpenAI or Mistral/local)
+    predictor = get_llm_predictor()
+    resp = predictor.predict(prompt)
+
+    # 4) JSON-style output is treated as the final answer
+    if resp.strip().startswith("{"):
+        result = query_symptoms(resp)
+        history.append(("assistant", f"Here is your diagnosis: {result}"))
+        return None, history, history  # clear mic, update chat, persist state
+
+    # 5) Otherwise it is a follow-up question
+    history.append(("assistant", resp))
+    return None, history, history
+
+
+# --- Build Gradio app ---
 with gr.Blocks() as demo:
-    gr.Markdown("## Symptom to ICD10 Code Lookup")
-    inp = gr.Textbox(label="Enter symptoms (comma-separated)")
-    out = gr.JSON(label="Result")
-    # Wire the submit event
-    inp.submit(fn=map_symptoms, inputs=inp, outputs=out)
+    gr.Markdown("## Symptom to ICD-10 Diagnoser (audio & chat)")
+    chatbot = gr.Chatbot(label="Conversation")
+    mic = gr.Microphone(label="Describe your symptoms")
+    state = gr.State([])
+
+    # Audio components emit stop_recording (not submit); the handler's three
+    # return values line up with the three outputs below.
+    mic.stop_recording(
+        fn=transcribe_and_respond,
+        inputs=[mic, state],
+        outputs=[mic, chatbot, state]
+    )
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        mcp_server=True
+    )
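
For reviewers who want to exercise the new handler outside the UI, a hypothetical smoke test: `sample.wav` is a placeholder recording, and either an OpenAI key or the local-GPU flags are assumed to be configured.

from src.app import transcribe_and_respond

# "sample.wav" is a placeholder; any short speech recording works.
mic_value, chat, saved_state = transcribe_and_respond("sample.wav", [])
for role, text in chat:
    print(f"{role}: {text}")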
utils/llama_index_utils.py CHANGED
@@ -1,10 +1,15 @@
 import os
-
+import json
 from transformers import pipeline
 from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex, LLMPredictor, OpenAI
 
 _index = None
 
+def query_symptoms_tool(prompt_json: str):
+    # Parse "prompt_json" into a Python dict and call the existing query_symptoms()
+    data = json.loads(prompt_json)
+    return query_symptoms(data["raw_input"])
+
 def get_llm_predictor():
     """
     Return an LLMPredictor configured for local GPU (transformers) if USE_LOCAL_GPU=1,
@@ -41,3 +46,4 @@ def query_symptoms(prompt: str, top_k: int = 5):
     predictor = get_llm_predictor()
     query_engine = idx.as_query_engine(similarity_top_k=top_k, llm_predictor=predictor)
     return query_engine.query(prompt)
+
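
And a minimal usage sketch for the new `query_symptoms_tool()` wrapper: the symptom text is illustrative, and a built index plus a configured predictor are assumed.

import json
from utils.llama_index_utils import query_symptoms_tool

# The wrapper expects a JSON string with a "raw_input" field.
payload = json.dumps({"raw_input": "persistent dry cough, mild fever"})
print(query_symptoms_tool(payload))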