gpaasch commited on
Commit
7e29701
·
1 Parent(s): 00f30b1

switching from openai whisper to gradio asr: https://www.gradio.app/guides/real-time-speech-recognition

Browse files
Files changed (1) hide show
  1. src/app.py +25 -45
src/app.py CHANGED
@@ -1,77 +1,57 @@
1
  import os
2
  import gradio as gr
3
- import openai
4
- from transformers import pipeline
5
  from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
6
  from llama_index import HuggingFaceLLMPredictor
7
  from src.parse_tabular import symptom_index
8
 
9
- # --- Whisper ASR setup ---
10
- asr = pipeline(
11
- "automatic-speech-recognition",
12
- model="openai/whisper-small",
13
- device=0,
14
- chunk_length_s=30,
15
- )
16
-
17
  # --- LlamaIndex utils import ---
18
  from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms
19
 
20
  # --- System prompt ---
21
  SYSTEM_PROMPT = """
22
  You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
23
- At each turn, EITHER ask one focused clarifying question (e.g. Is your cough dry or productive?”)
24
  or, if you have enough info, output a final JSON with fields:
25
  {"diagnoses":[…], "confidences":[…]}.
26
  """
27
 
28
- def transcribe_and_respond(audio_chunk, state):
29
- # Transcribe audio chunk
30
- result = asr(audio_chunk)
31
- text = result.get('text', '').strip()
32
- if not text:
33
- return state, []
34
-
35
- # Append user message
36
- state.append(("user", text))
37
-
38
- # Build LLM predictor (you can swap OpenAI / HuggingFace here)
39
  llm_predictor = HuggingFaceLLMPredictor(model_name_or_path=os.getenv("HF_MODEL", "gpt2-medium"))
40
-
41
  # Query index with conversation
42
- # (Assuming `symptom_index` is your GPTVectorStoreIndex)
43
- # Prepare combined prompt from state
44
- prompt = "\n".join([f"{role}: {msg}" for role, msg in state])
45
  response = symptom_index.as_query_engine(
46
  llm_predictor=llm_predictor
47
  ).query(prompt)
48
- reply = response.response
49
-
50
- # Append assistant message
51
- state.append(("assistant", reply))
52
-
53
- # Return updated state to chatbot
54
- return state, state
55
 
56
  # Build Gradio interface
57
  demo = gr.Blocks()
58
  with demo:
59
  gr.Markdown("# Symptom to ICD-10 Code Lookup (Audio Input)")
60
  chatbot = gr.Chatbot(label="Conversation")
61
- state = gr.State([])
62
- # Use streaming audio input for real-time transcription
63
- mic = gr.Audio(source="microphone", type="filepath", streaming=True, label="Describe your symptoms")
64
-
65
- mic.stream(
66
- fn=transcribe_and_respond,
67
- inputs=[mic, state],
68
- outputs=[chatbot, state],
69
- time_limit=60,
70
- stream_every=5,
71
- concurrency_limit=1
72
  )
73
 
74
  if __name__ == "__main__":
75
  demo.launch(
76
- server_name="0.0.0.0", server_port=7860, mcp_server=True
 
 
77
  )
 
1
  import os
2
  import gradio as gr
 
 
3
  from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
4
  from llama_index import HuggingFaceLLMPredictor
5
  from src.parse_tabular import symptom_index
6
 
 
 
 
 
 
 
 
 
7
  # --- LlamaIndex utils import ---
8
  from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms
9
 
10
# --- System prompt ---
# Instruction preamble steering the assistant: one focused clarifying
# question per turn, or a final JSON verdict once confident.
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
17
 
18
def process_speech(new_transcript, history):
    """Handle one streamed transcript chunk and return the updated chat.

    Parameters
    ----------
    new_transcript : str
        Latest transcribed text from the streaming audio input.
    history : list[tuple[str, str]] | None
        Chatbot history as (user_message, assistant_message) pairs.

    Returns
    -------
    list[tuple[str, str]]
        The updated history, suitable for the gr.Chatbot output.
    """
    # gr.Chatbot may hand us None before the first exchange.
    history = history or []

    # Skip if no new transcript
    if not new_transcript:
        return history

    # Build LLM predictor (model overridable via the HF_MODEL env var)
    llm_predictor = HuggingFaceLLMPredictor(
        model_name_or_path=os.getenv("HF_MODEL", "gpt2-medium")
    )

    # Flatten the conversation into a role-tagged prompt.
    # BUG FIX: history entries are (user_msg, assistant_msg) pairs — not
    # (role, msg) — so each side of the pair must be labelled explicitly;
    # the old loop used the user's message as the role tag.
    # Also include SYSTEM_PROMPT, which was defined but never used.
    turns = [SYSTEM_PROMPT.strip()]
    for user_msg, assistant_msg in history:
        turns.append(f"user: {user_msg}")
        turns.append(f"assistant: {assistant_msg}")
    turns.append(f"user: {new_transcript}")
    prompt = "\n".join(turns)

    response = symptom_index.as_query_engine(
        llm_predictor=llm_predictor
    ).query(prompt)

    # Append the new exchange to history
    history.append((new_transcript, response.response))
    return history
 
 
 
37
 
38
# Build Gradio interface
demo = gr.Blocks()
with demo:
    gr.Markdown("# Symptom to ICD-10 Code Lookup (Audio Input)")
    chatbot = gr.Chatbot(label="Conversation")
    # BUG FIX: gr.Audio has no type="text" — per the Gradio docs the only
    # valid values are "numpy" and "filepath"; "text" raises at build time.
    # NOTE(review): process_speech expects a *transcript string*, but a
    # streaming gr.Audio delivers audio data (filepath/numpy chunks), not
    # text — an ASR step between the mic and the handler is still needed;
    # confirm against the Gradio real-time speech-recognition guide.
    audio = gr.Audio(
        source="microphone",
        type="filepath",
        streaming=True,
        label="Describe your symptoms",
    )

    # Stream mic chunks into the handler; chatbot doubles as state.
    audio.stream(
        process_speech,
        inputs=[audio, chatbot],
        outputs=chatbot,
        show_progress="hidden",
    )

if __name__ == "__main__":
    # Bind all interfaces so the app is reachable inside a container/Space.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
    )