#!/usr/bin/env python3
"""
Maternal Health RAG Query Engine
Integrates vector store with LangChain for intelligent medical query processing
"""
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
from dataclasses import dataclass
from datetime import datetime
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from vector_store_manager import MaternalHealthVectorStore, SearchResult
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class QueryResponse:
"""Container for RAG query responses"""
query: str
answer: str
sources: List[SearchResult]
confidence: float
response_time: float
metadata: Dict[str, Any]
class MockLLM(LLM):
"""Mock LLM for testing RAG pipeline without external API calls"""
@property
def _llm_type(self) -> str:
return "mock"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Generate a mock medical response based on the prompt"""
        # Extract the question from the prompt; the mock ignores the retrieved context.
        if "Context:" in prompt and "Question:" in prompt:
            # Split on "Instructions:" rather than "Answer:" so keyword matching
            # sees only the user's question, not the template's instruction list.
            question_section = prompt.split("Question:")[1].split("Instructions:")[0].strip()
# Generate mock response based on medical keywords
medical_keywords = {
'magnesium': 'Magnesium sulfate is administered for seizure prevention in preeclampsia.',
'hemorrhage': 'Postpartum hemorrhage requires immediate assessment and management with uterotonics.',
'sepsis': 'Puerperal sepsis is diagnosed based on fever, tachycardia, and other systemic signs.',
'fetal': 'Fetal heart rate monitoring is essential during labor to assess fetal well-being.',
'labor': 'Normal labor management involves monitoring progress and maternal-fetal well-being.',
'preeclampsia': 'Preeclampsia management includes blood pressure control and seizure prevention.',
'oxytocin': 'Oxytocin is used for labor induction and augmentation with careful monitoring.',
'cesarean': 'Cesarean section indications include fetal distress and failure to progress.',
'diabetes': 'Gestational diabetes requires blood glucose monitoring and dietary management.',
'hypertension': 'Pregnancy-induced hypertension requires close monitoring and treatment.'
}
# Find relevant keywords and build response
response_parts = []
question_lower = question_section.lower()
for keyword, response in medical_keywords.items():
if keyword in question_lower:
response_parts.append(response)
if response_parts:
base_response = " ".join(response_parts)
return f"Based on the maternal health guidelines: {base_response} Please consult with a healthcare professional for specific medical advice."
else:
return "Based on the available maternal health guidelines, this appears to be a clinical question that requires professional medical evaluation. Please consult with a qualified healthcare provider."
return "I can provide information based on maternal health guidelines, but specific medical decisions should always be made in consultation with healthcare professionals."
class MaternalHealthRAG:
"""RAG system for maternal health queries"""
def __init__(self,
vector_store_dir: str = "vector_store",
chunks_dir: str = "comprehensive_chunks",
use_mock_llm: bool = True):
self.vector_store_dir = Path(vector_store_dir)
self.chunks_dir = Path(chunks_dir)
self.use_mock_llm = use_mock_llm
# Initialize components
self.vector_store = None
self.llm = None
self.rag_chain = None
# Query parameters
self.default_k = 5
self.min_relevance_score = 0.3
self.max_context_length = 3000
# Initialize RAG system
self.initialize_rag_system()
def initialize_rag_system(self):
"""Initialize the complete RAG system"""
logger.info("πŸš€ Initializing Maternal Health RAG System...")
try:
# Initialize vector store
self.vector_store = MaternalHealthVectorStore(
vector_store_dir=self.vector_store_dir,
chunks_dir=self.chunks_dir
)
# Load existing vector store
if self.vector_store.index_file.exists():
success = self.vector_store.load_existing_index()
if not success:
logger.error("Failed to load vector store")
raise RuntimeError("Vector store initialization failed")
else:
logger.error("Vector store not found. Please create it first.")
raise FileNotFoundError("Vector store not found")
# Initialize LLM
if self.use_mock_llm:
self.llm = MockLLM()
logger.info("βœ… Using Mock LLM for testing")
else:
# Future: Initialize actual LLM (OpenAI, Hugging Face, etc.)
logger.warning("External LLM not implemented yet, using Mock LLM")
self.llm = MockLLM()
# Create RAG chain
self.rag_chain = self.create_rag_chain()
logger.info("βœ… RAG system initialized successfully")
except Exception as e:
logger.error(f"❌ Failed to initialize RAG system: {e}")
raise
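    # Sketch (assumption, not part of this module): a real LLM could replace MockLLM
    # above using one of langchain's wrappers, e.g.:
    #
    #     from langchain.llms import OpenAI  # needs OPENAI_API_KEY in the environment
    #     self.llm = OpenAI(temperature=0.0)
    #
    # Any replacement only needs to satisfy the LLM interface consumed by LLMChain.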
def create_rag_chain(self) -> LLMChain:
"""Create the RAG chain with medical prompt template"""
# Medical-focused prompt template
template = """You are a medical information assistant specializing in maternal health guidelines.
Use the provided context from Sri Lankan maternal health guidelines to answer questions accurately and safely.
Context:
{context}
Question: {question}
Instructions:
1. Answer based ONLY on the provided context from maternal health guidelines
2. If the context doesn't contain sufficient information, clearly state this
3. Always include relevant clinical details when available (dosages, procedures, contraindications)
4. Mention when professional medical consultation is recommended
5. Be precise and avoid generalizations
Answer:"""
prompt = PromptTemplate(
template=template,
input_variables=["context", "question"]
)
chain = LLMChain(
llm=self.llm,
prompt=prompt,
verbose=False
)
return chain
    def query(self,
              question: str,
              k: Optional[int] = None,
              min_score: Optional[float] = None,
              content_types: Optional[List[str]] = None,
              min_importance: float = 0.5) -> QueryResponse:
"""Process a medical query and return comprehensive response"""
if k is None:
k = self.default_k
if min_score is None:
min_score = self.min_relevance_score
start_time = datetime.now()
logger.info(f"πŸ” Processing query: {question}")
try:
# Retrieve relevant context
if content_types:
search_results = self.vector_store.search_by_medical_context(
question,
content_types=content_types,
min_importance=min_importance,
k=k
)
else:
search_results = self.vector_store.search(
question,
k=k,
min_score=min_score
)
# Prepare context
context = self.prepare_context(search_results)
# Generate response
response = self.rag_chain.run(
context=context,
question=question
)
# Calculate response time
end_time = datetime.now()
response_time = (end_time - start_time).total_seconds()
# Calculate confidence based on relevance scores
confidence = self.calculate_confidence(search_results)
# Create response object
query_response = QueryResponse(
query=question,
answer=response,
sources=search_results,
confidence=confidence,
response_time=response_time,
metadata={
'num_sources': len(search_results),
'avg_relevance': sum(r.score for r in search_results) / len(search_results) if search_results else 0,
'content_types': list(set(r.chunk_type for r in search_results)),
'high_importance_sources': sum(1 for r in search_results if r.clinical_importance >= 0.8)
}
)
logger.info(f"βœ… Query processed in {response_time:.2f}s with {len(search_results)} sources")
return query_response
except Exception as e:
logger.error(f"❌ Query processing failed: {e}")
# Return error response
end_time = datetime.now()
response_time = (end_time - start_time).total_seconds()
return QueryResponse(
query=question,
answer=f"I apologize, but I encountered an error while processing your query: {str(e)}",
sources=[],
confidence=0.0,
response_time=response_time,
metadata={'error': str(e)}
)
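    # Illustrative call (the content-type labels are assumptions; the real values
    # depend on how chunks were tagged in vector_store_manager):
    #
    #     rag = MaternalHealthRAG()
    #     rag.query("Loading dose of magnesium sulfate?",
    #               content_types=["dosage", "procedure"], min_importance=0.7)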
def prepare_context(self, search_results: List[SearchResult]) -> str:
"""Prepare context from search results for LLM"""
if not search_results:
return "No relevant information found in the maternal health guidelines."
context_parts = []
current_length = 0
for i, result in enumerate(search_results):
# Add source information
source_info = f"Source {i+1} (Relevance: {result.score:.3f}, Type: {result.chunk_type}):"
content = f"{source_info}\n{result.content}\n"
# Check if adding this would exceed max length
if current_length + len(content) > self.max_context_length and context_parts:
break
context_parts.append(content)
current_length += len(content)
return "\n".join(context_parts)
def calculate_confidence(self, search_results: List[SearchResult]) -> float:
"""Calculate confidence score based on search results"""
if not search_results:
return 0.0
# Factors for confidence calculation
avg_relevance = sum(r.score for r in search_results) / len(search_results)
high_relevance_count = sum(1 for r in search_results if r.score >= 0.7)
high_importance_count = sum(1 for r in search_results if r.clinical_importance >= 0.8)
# Weighted confidence score
relevance_weight = 0.5
coverage_weight = 0.3
importance_weight = 0.2
relevance_score = min(avg_relevance / 0.8, 1.0) # Normalize to 0.8 as max
coverage_score = min(high_relevance_count / 3, 1.0) # 3+ high relevance results = full score
importance_score = min(high_importance_count / 2, 1.0) # 2+ high importance = full score
confidence = (
relevance_weight * relevance_score +
coverage_weight * coverage_score +
importance_weight * importance_score
)
return min(confidence, 1.0)
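    # Worked example: avg relevance 0.6, two results with score >= 0.7, one source
    # with clinical_importance >= 0.8:
    #   relevance_score  = min(0.6 / 0.8, 1.0) = 0.75
    #   coverage_score   = min(2 / 3, 1.0)     β‰ˆ 0.667
    #   importance_score = min(1 / 2, 1.0)     = 0.50
    #   confidence = 0.5*0.75 + 0.3*0.667 + 0.2*0.50 β‰ˆ 0.675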
def batch_query(self, questions: List[str]) -> List[QueryResponse]:
"""Process multiple queries efficiently"""
logger.info(f"πŸ“‹ Processing {len(questions)} queries in batch...")
responses = []
for i, question in enumerate(questions, 1):
logger.info(f"Processing query {i}/{len(questions)}")
response = self.query(question)
responses.append(response)
logger.info(f"βœ… Batch processing complete: {len(responses)} responses generated")
return responses
def get_system_stats(self) -> Dict[str, Any]:
"""Get RAG system statistics"""
vector_stats = self.vector_store.get_statistics()
return {
'vector_store': vector_stats,
'rag_config': {
'default_k': self.default_k,
'min_relevance_score': self.min_relevance_score,
'max_context_length': self.max_context_length,
'llm_type': self.llm._llm_type if self.llm else 'None'
},
'status': 'initialized' if self.vector_store and self.llm else 'not_initialized'
}
def main():
"""Main function to test RAG system"""
logger.info("πŸš€ Testing Maternal Health RAG System...")
# Initialize RAG system
rag_system = MaternalHealthRAG()
# Test queries
test_queries = [
"What is the recommended dosage of magnesium sulfate for preeclampsia?",
"How should postpartum hemorrhage be managed in emergency situations?",
"What are the signs and symptoms of puerperal sepsis?",
"What is the normal fetal heart rate range during labor?",
"When is cesarean section indicated during delivery?"
]
logger.info("\nπŸ” Testing RAG Query Processing...")
for i, query in enumerate(test_queries, 1):
logger.info(f"\nπŸ“ Query {i}: {query}")
response = rag_system.query(query)
logger.info(f"⏱️ Response time: {response.response_time:.2f}s")
logger.info(f"🎯 Confidence: {response.confidence:.3f}")
logger.info(f"πŸ“š Sources: {response.metadata['num_sources']}")
logger.info(f"πŸ“Š Avg relevance: {response.metadata['avg_relevance']:.3f}")
logger.info(f"πŸ’‘ Answer: {response.answer[:200]}...")
# Get system statistics
stats = rag_system.get_system_stats()
logger.info(f"\nπŸ“Š RAG System Statistics:")
logger.info(f" Vector store chunks: {stats['vector_store']['total_chunks']}")
logger.info(f" LLM type: {stats['rag_config']['llm_type']}")
logger.info(f" System status: {stats['status']}")
logger.info("\nβœ… RAG system testing complete!")
if __name__ == "__main__":
main()