gpaasch committed
Commit 3a7b207 · 1 parent: 5c7ca93

Revert "1. Structure responses with both diagnoses and follow-up questions"


This reverts commit f21d279f5305b3325d8c12144d6035d652b3f769.

Files changed (1)
  1. src/app.py +55 -64
src/app.py CHANGED
@@ -102,17 +102,15 @@ def get_system_specs() -> Dict[str, float]:
         "gpu_vram_gb": gpu_vram_gb
     }
 
-def select_best_model():
+def select_best_model() -> Tuple[str, str]:
     """Select the best model based on system specifications."""
     specs = get_system_specs()
+    print(f"\nSystem specifications:")
+    print(f"RAM: {specs['ram_gb']:.1f} GB")
+    print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")
 
-    # Prioritize Mistral if we have API key or sufficient resources
-    if any(k.startswith("mk-") for k in [api_key.value]):  # Check for Mistral API key
-        return "mistral-7b-instruct-v0.1.Q4_K_M.gguf", "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
-    elif specs['gpu_vram_gb'] >= 6 or specs['ram_gb'] >= 16:
-        return MODEL_OPTIONS["medium"]["name"], MODEL_OPTIONS["medium"]["repo"]
     # Prioritize GPU if available
-    elif specs['gpu_vram_gb'] >= 4:  # You have 6GB, so this should work
+    if specs['gpu_vram_gb'] >= 4:  # You have 6GB, so this should work
         model_tier = "small"  # phi-2 should work well on RTX 2060
     elif specs['ram_gb'] >= 8:
         model_tier = "small"
@@ -214,15 +212,12 @@ symptom_index = create_symptom_index()
     print("Index created successfully")
 
 # --- System prompt ---
-SYSTEM_PROMPT = """You are a medical assistant using the Mistral model to analyze symptoms and determine ICD-10 codes.
-Your responses should ALWAYS be in this format:
-{
-    "diagnoses": ["ICD10 code - description"],
-    "confidences": [confidence score between 0-1],
-    "follow_up": "ONE specific follow-up question to refine the diagnosis",
-    "explanation": "Brief explanation of why you're asking this question"
-}
-Keep responses focused and clinical."""
+SYSTEM_PROMPT = """
+You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
+At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
+or, if you have enough info, output a final JSON with fields:
+{"diagnoses":[…], "confidences":[]}.
+"""
 
 def process_speech(audio_data, history):
     """Process speech input and convert to text."""
@@ -526,84 +521,80 @@ with gr.Blocks(theme="default") as demo:
             features = process_audio(audio_array, sample_rate)
 
             asr = get_asr_pipeline()
+            result = asr(features)
+
             return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
         except Exception as e:
-            print(f"Transcription error: {str(e)}")f isinstance(result, dict) else str(result).strip()
-            return ""ion as e:
             print(f"Transcription error: {str(e)}")
+            return ""
+
     microphone.stream(
         fn=update_live_transcription,
         inputs=[microphone],
-        outputs=transcript_box,ption,
+        outputs=transcript_box,
         show_progress="hidden",
-        queue=Trueanscript_box,
-    ) show_progress="hidden",
         queue=True
+    )
+
     clear_btn.click(
         fn=lambda: (None, "", ""),
         outputs=[chatbot, transcript_box, text_input],
-        queue=False(None, "", ""),
-    ) outputs=[chatbot, transcript_box, text_input],
         queue=False
+    )
+
     def cleanup_memory():
         """Release unused memory (placeholder for future memory management)."""
-        import gcemory():
-        gc.collect()nused memory (placeholder for future memory management)."""
+        import gc
+        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        if torch.cuda.is_available():
+
     def process_text_input(text, history):
-        """Process text input with interactive follow-up."""
-        if not text:_input(text, history):
-            return history, ""with interactive follow-up."""
+        """Process text input with memory management."""
         if not text:
-        try:return history, ""
-            # Add context from history
-            context = "\n".join([m["content"] for m in history if m["role"] == "user"]) if history else ""
-            # Add context from history
-            prompt = f"""{SYSTEM_PROMPT}ent"] for m in history if m["role"] == "user"]) if history else ""
-            Previous context: {context}
-            Current symptoms: {text}"{SYSTEM_PROMPT}
-            Analyze and respond with likely diagnoses and ONE key follow-up question."""
-            toms: {text}
-            response = llm.complete(prompt)nd ONE key follow-up question."""
+            return history, ""  # Return tuple to clear input
+
+        try:
+            # Process the symptoms using the configured LLM
+            prompt = f"""Given these symptoms: '{text}'
+            Please provide:
+            1. Most likely ICD-10 codes
+            2. Confidence levels for each diagnosis
+            3. Key follow-up questions
+
+            Format as JSON with diagnoses, confidences, and follow_up fields."""
 
-            try:onse = llm.complete(prompt)
+            response = llm.complete(prompt)
+
+            try:
+                # Try to parse as JSON first
                 result = json.loads(response.text)
             except json.JSONDecodeError:
-                result = {son.loads(response.text)
-                    "diagnoses": ["R69 - Illness, unspecified"],
-                    "confidences": [0.5],
-                    "follow_up": str(response.text)[:200],ied"],
-                    "explanation": "Could not parse response"
-                } "follow_up": str(response.text)[:200],
-                "explanation": "Could not parse response"
-            formatted_response = f"""Possible Diagnoses:
-{''.join(f'- {d} ({c*100:.0f}%)\n' for d, c in zip(result['diagnoses'], result['confidences']))}
-            formatted_response = f"""Possible Diagnoses:
-            Follow-up Question: {result['follow_up']} c in zip(result['diagnoses'], result['confidences']))}
-            ({result['explanation']})"""
-            Follow-up Question: {result['follow_up']}
+                # If not JSON, wrap in our format
+                result = {
+                    "diagnoses": [],
+                    "confidences": [],
+                    "follow_up": str(response.text)[:1000]  # Limit response length
+                }
+
             new_history = history + [
                 {"role": "user", "content": text},
-                {"role": "assistant", "content": formatted_response}
-            ] {"role": "user", "content": text},
-            return new_history, ""t", "content": formatted_response}
+                {"role": "assistant", "content": format_response_for_user(result)}
+            ]
+            return new_history, ""  # Return empty string to clear input
         except Exception as e:
             print(f"Error processing text: {str(e)}")
-            return history, text
-            print(f"Error processing text: {str(e)}")
+            return history, text  # Keep text on error
+
     # Update the submit button handler
     submit_btn.click(
-        fn=process_text_input, handler
+        fn=process_text_input,
         inputs=[text_input, chatbot],
         outputs=[chatbot, text_input],
-        queue=Truext_input, chatbot],
+        queue=True
     ).success(  # Changed from .then to .success for better error handling
         fn=cleanup_memory,
-        inputs=None,anged from .then to .success for better error handling
-        outputs=None,mory,
-        queue=False,
-    ) outputs=None,
+        inputs=None,
+        outputs=None,
         queue=False
     )
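
The removed side of this hunk is visibly corrupted (lines spliced together mid-token), which is presumably what the revert is cleaning up. One loose end on the restored side: process_text_input calls format_response_for_user(result), which this hunk does not define. Assuming it renders the same fields as the inline f-string the revert removes, a minimal sketch could look like:

    def format_response_for_user(result: dict) -> str:
        # Hypothetical rendering, modeled on the removed inline f-string;
        # the real helper is defined elsewhere in app.py.
        parts = []
        scored = list(zip(result.get("diagnoses", []), result.get("confidences", [])))
        if scored:
            parts.append("Possible Diagnoses:")
            parts.extend(f"- {d} ({c * 100:.0f}%)" for d, c in scored)
        if result.get("follow_up"):
            parts.append(f"Follow-up Question: {result['follow_up']}")
        return "\n".join(parts)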
 
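A behavioral detail that survives the revert: the submit handler chains cleanup via .success() rather than .then(), so cleanup_memory runs only when process_text_input finishes without raising. A stripped-down sketch of that wiring, assuming a recent Gradio where Chatbot accepts type="messages" (the handlers are stubs; component names follow the diff):

    import gc
    import gradio as gr

    def cleanup_memory():
        gc.collect()  # torch.cuda.empty_cache() would follow when CUDA is available

    with gr.Blocks(theme="default") as demo:
        chatbot = gr.Chatbot(type="messages")
        text_input = gr.Textbox(label="Symptoms")
        submit_btn = gr.Button("Submit")

        def process_text_input(text, history):
            return (history or []) + [{"role": "user", "content": text}], ""

        submit_btn.click(
            fn=process_text_input,
            inputs=[text_input, chatbot],
            outputs=[chatbot, text_input],
            queue=True,
        ).success(  # .then() would also fire after an error; .success() will not
            fn=cleanup_memory,
            inputs=None,
            outputs=None,
            queue=False,
        )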