# Spaces: Sleeping / Sleeping (page-status text captured alongside this file)
"""
Simple Medical Chatbot Interface v2.0
Beautiful Gradio interface for the simplified medical RAG system
"""
import gradio as gr | |
import time | |
import json | |
from datetime import datetime | |
from typing import List, Tuple, Dict, Any | |
# Import our simplified medical RAG system | |
from simple_medical_rag import SimpleMedicalRAG, MedicalResponse | |
class SimpleMedicalChatbot:
    """Professional medical chatbot interface using simplified RAG system.

    Wraps a ``SimpleMedicalRAG`` instance, turns its answers into Markdown
    for a chat window, and keeps running statistics about the queries this
    instance has processed.
    """

    def __init__(self):
        """Initialize the medical chatbot and try to start the RAG backend."""
        # Assigned by _initialize_rag_system(); stays None if start-up fails.
        self.rag_system = None
        # Only reset by clear_chat(); no method of this class appends to it.
        self.chat_history = []
        self.session_stats = {
            "queries_processed": 0,
            "total_response_time": 0,
            "avg_confidence": 0,
            "session_start": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Initialize RAG system
        self._initialize_rag_system()

    def _initialize_rag_system(self):
        """Create the RAG backend, leaving ``rag_system`` as None on failure.

        Any exception is reported on stdout instead of being raised, so the
        object can still be constructed; the other methods then answer with
        a "not initialized" message.
        """
        try:
            print("π Initializing Medical RAG System...")
            self.rag_system = SimpleMedicalRAG()
            print("β Medical RAG System initialized successfully!")
        except Exception as e:
            print(f"β Error initializing RAG system: {e}")
            self.rag_system = None

    def process_query(self, query: str, history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
        """Process medical query and return the updated chat history.

        Args:
            query: Question text typed by the user.
            history: Existing ``(user_message, bot_message)`` pairs. ``None``
                is accepted and treated as an empty conversation.

        Returns:
            A ``(history, "")`` tuple: a new list holding the previous pairs
            plus the new exchange (if any), and an empty string that callers
            use to clear the input box.
        """
        # Work on a copy: the caller's list is never mutated, and a missing
        # (None) history no longer crashes on .append().
        history = list(history) if history else []

        if not self.rag_system:
            error_msg = "β **System Error**: Medical RAG system not initialized. Please refresh and try again."
            history.append((query, error_msg))
            return history, ""

        # Nothing to ask: leave the conversation as it is.
        if not query or not query.strip():
            return history, ""

        start_time = time.time()
        try:
            # Process query with RAG system
            response = self.rag_system.query(query, k=5)

            # Format response for display
            formatted_response = self._format_response_for_display(response)

            # Update session statistics
            query_time = time.time() - start_time
            self._update_session_stats(query_time, response.confidence)

            # Add to chat history
            history.append((query, formatted_response))
        except Exception as e:
            # Report the failure inside the chat instead of breaking the UI.
            error_msg = f"β **Error processing query**: {e}\n\nβ οΈ Please try rephrasing your question or contact support."
            history.append((query, error_msg))

        return history, ""

    def _format_response_for_display(self, response: "MedicalResponse") -> str:
        """Format a medical response as Markdown for the chat window.

        Args:
            response: Object exposing ``answer``, ``confidence``,
                ``response_type``, ``sources`` (a sequence of mappings with
                ``'document'`` and ``'relevance_score'`` keys) and
                ``medical_disclaimer``.

        Returns:
            Markdown text with the answer, a confidence line, up to three
            sources and the disclaimer.
        """
        # Confidence level indicator
        if response.confidence > 0.7:
            confidence_emoji = "π’"
        elif response.confidence > 0.5:
            confidence_emoji = "π‘"
        else:
            confidence_emoji = "π΄"
        confidence_text = f"{confidence_emoji} **Confidence: {response.confidence:.1%}**"

        # Response type indicator
        if "dosage" in response.response_type:
            type_emoji = "π"
        elif "emergency" in response.response_type:
            type_emoji = "π¨"
        else:
            type_emoji = "π₯"

        # Main response. The blank line before '---' is required: a '---'
        # line directly under text is a Markdown setext-heading underline and
        # would turn the answer into a heading instead of drawing a rule.
        parts = [f"""
{type_emoji} **Medical Information**
{response.answer}

---
π **Response Details**
{confidence_text}
π **Sources**: {len(response.sources)} documents referenced
"""]

        # Add top sources
        if response.sources:
            parts.append("π **Primary Sources**:\n")
            for i, source in enumerate(response.sources[:3], 1):
                doc_name = source['document'].replace('.pdf', '').replace('-', ' ').title()
                parts.append(f"{i}. {doc_name} (Relevance: {source['relevance_score']:.1%})\n")
            parts.append("\n")

        # Add medical disclaimer
        parts.append(f"""
---
{response.medical_disclaimer}
π **Note**: This response is based on Sri Lankan maternal health guidelines and should be used in conjunction with current clinical protocols.
""")
        return "".join(parts)

    def _update_session_stats(self, query_time: float, confidence: float):
        """Fold one answered query into the running statistics.

        Args:
            query_time: Seconds spent answering the query.
            confidence: Confidence value reported for the answer.
        """
        stats = self.session_stats
        stats["queries_processed"] += 1
        stats["total_response_time"] += query_time

        # Incremental mean: rebuild the previous sum from the old average,
        # add the new value, divide by the new count.
        queries = stats["queries_processed"]
        stats["avg_confidence"] = ((stats["avg_confidence"] * (queries - 1)) + confidence) / queries

    def get_system_info(self) -> str:
        """Get system information for display.

        Returns:
            Markdown describing the knowledge base, a "not initialized"
            message when no RAG system is available, or an error message if
            the statistics cannot be read.
        """
        if not self.rag_system:
            return "β **System Status**: Not initialized"

        try:
            stats = self.rag_system.get_system_stats()
            store = stats['vector_store']
            total_chunks = store['total_chunks']

            system_info = f"""
π₯ **Sri Lankan Maternal Health Assistant v2.0**
π **System Status**: {stats['status'].upper()} β
**Knowledge Base**:
β’ π Total Documents: {total_chunks:,} medical chunks
β’ π§ Embedding Model: {store['embedding_model']}
β’ πΎ Vector Store Size: {store['vector_store_size_mb']} MB
β’ β‘ Approach: Simplified document-based retrieval
**Content Distribution**:
"""
            # Add content distribution
            lines = []
            for content_type, count in store['content_type_distribution'].items():
                # An empty store reports 0.0% instead of dividing by zero.
                percentage = (count / total_chunks) * 100 if total_chunks else 0.0
                content_info = content_type.replace('_', ' ').title()
                lines.append(f"β’ {content_info}: {count:,} chunks ({percentage:.1f}%)\n")
            return system_info + "".join(lines)
        except Exception as e:
            return f"β **Error retrieving system info**: {e}"

    def get_session_stats(self) -> str:
        """Get session statistics for display.

        Returns:
            Markdown with the session start time, query count, average
            response time and average confidence, or a placeholder line when
            no query has been processed yet.
        """
        stats = self.session_stats
        if stats["queries_processed"] == 0:
            return "π **Session Statistics**: No queries processed yet"

        avg_response_time = stats["total_response_time"] / stats["queries_processed"]
        return f"""
π **Session Statistics**
π **Session Started**: {stats["session_start"]}
π **Queries Processed**: {stats["queries_processed"]}
β‘ **Avg Response Time**: {avg_response_time:.2f} seconds
π― **Avg Confidence**: {stats["avg_confidence"]:.1%}
"""

    def clear_chat(self) -> Tuple[List, str]:
        """Clear chat history.

        Returns:
            ``([], "")``: an empty history and an empty input-box value.
            The session statistics are left untouched.
        """
        self.chat_history = []
        return [], ""

    def get_example_queries(self) -> List[str]:
        """Get example medical queries."""
        return [
            "What is the dosage of magnesium sulfate for preeclampsia?",
            "How to manage postpartum hemorrhage emergency?",
            "Normal fetal heart rate during labor monitoring?",
            "Management protocol for breech delivery?",
            "Antenatal care schedule for high-risk pregnancies?",
            "Signs and symptoms of preeclampsia?",
            "When to perform cesarean delivery?",
            "Postpartum care guidelines for new mothers?",
        ]
def create_medical_chatbot_interface():
    """Build the Gradio Blocks app for the medical chatbot and return it.

    One ``SimpleMedicalChatbot`` is created here and captured by every event
    handler, so its statistics are shared by everything served from the
    returned interface. The page consists of a header, a chat column (chat
    window, query box, action buttons), a side column (system info, session
    statistics, four example-query buttons) and a disclaimer.
    """
    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from statement order - confirm against the running UI.
    chatbot = SimpleMedicalChatbot()

    # Custom CSS for medical theme
    css = """
.gradio-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.medical-header {
background: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.chat-container {
background: white;
border-radius: 15px;
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
}
.medical-disclaimer {
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
color: #856404;
}
.example-queries {
background: #e8f5e8;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
}
"""

    header_markdown = """
# π₯ Sri Lankan Maternal Health Assistant v2.0
### Simplified Document-Based Medical RAG System
**Professional medical guidance based on Sri Lankan maternal health guidelines**
"""

    disclaimer_markdown = """
## β οΈ Important Medical Disclaimer
This system provides information from Sri Lankan maternal health guidelines for **educational purposes only**.
**Always consult qualified healthcare providers for**:
- Medical decisions and patient care
- Emergency medical situations
- Clinical diagnosis and treatment
- Medication administration
This tool is designed to **supplement**, not replace, professional medical judgment.
"""

    with gr.Blocks(css=css, title="Sri Lankan Maternal Health Assistant", theme=gr.themes.Soft()) as interface:
        # Page header
        gr.Markdown(header_markdown, elem_classes=["medical-header"])

        with gr.Row():
            # Left: the conversation itself
            with gr.Column(scale=2):
                with gr.Group(elem_classes=["chat-container"]):
                    gr.Markdown("## π¬ Medical Query Interface")
                    chatbot_display = gr.Chatbot(
                        label="Medical Assistant",
                        height=500,
                        show_label=False,
                        container=True,
                        bubble_full_width=False,
                    )
                    with gr.Row():
                        query_input = gr.Textbox(
                            placeholder="Ask a medical question about maternal health...",
                            label="Your Medical Query",
                            lines=2,
                            scale=4,
                        )
                        submit_btn = gr.Button("π Ask", variant="primary", scale=1)
                    with gr.Row():
                        clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")
                        refresh_btn = gr.Button("π Refresh System", variant="secondary")

            # Right: status panels and canned questions
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## π System Information")
                    system_info_display = gr.Markdown(
                        chatbot.get_system_info(),
                        label="System Status",
                    )
                with gr.Group():
                    gr.Markdown("## π Session Statistics")
                    session_stats_display = gr.Markdown(
                        chatbot.get_session_stats(),
                        label="Current Session",
                    )
                with gr.Group(elem_classes=["example-queries"]):
                    gr.Markdown("## π‘ Example Queries")
                    for example in chatbot.get_example_queries()[:4]:
                        example_btn = gr.Button(
                            f"π {example}",
                            variant="secondary",
                            size="sm",
                        )
                        # The default argument freezes this iteration's text;
                        # a plain closure would see only the last example.
                        example_btn.click(
                            fn=lambda x=example: x,
                            outputs=query_input,
                        )

        # Medical disclaimer
        gr.Markdown(disclaimer_markdown, elem_classes=["medical-disclaimer"])

        # Event handlers
        def submit_query(query, history):
            """Answer the query, empty the input box, refresh the stats panel."""
            updated_history = chatbot.process_query(query, history)[0]
            return updated_history, "", chatbot.get_session_stats()

        def refresh_system():
            """Re-read the system information and the session statistics."""
            return chatbot.get_system_info(), chatbot.get_session_stats()

        def clear_chat_handler():
            """Empty the chat window and the input box, refresh the stats panel."""
            emptied_history = chatbot.clear_chat()[0]
            return emptied_history, "", chatbot.get_session_stats()

        # Connect event handlers. The Ask button and pressing Enter in the
        # textbox do the same thing, so both are wired with the same arguments.
        chat_outputs = [chatbot_display, query_input, session_stats_display]
        for register in (submit_btn.click, query_input.submit):
            register(
                fn=submit_query,
                inputs=[query_input, chatbot_display],
                outputs=chat_outputs,
            )
        clear_btn.click(
            fn=clear_chat_handler,
            inputs=[],
            outputs=chat_outputs,
        )
        refresh_btn.click(
            fn=refresh_system,
            inputs=[],
            outputs=[system_info_display, session_stats_display],
        )

    return interface
def main():
    """Print a start-up banner, build the interface and launch the server."""
    print("π Launching Sri Lankan Maternal Health Assistant v2.0")
    print("=" * 60)

    # Launch settings, gathered in one place.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,  # Enable public sharing
        "show_error": True,
        "inbrowser": True,
        "debug": True,
    }

    # Create and launch interface
    interface = create_medical_chatbot_interface()
    interface.launch(**launch_options)


if __name__ == "__main__":
    main()