Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Maternal Health RAG Query Engine | |
Integrates vector store with LangChain for intelligent medical query processing | |
""" | |
import json
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun

from vector_store_manager import MaternalHealthVectorStore, SearchResult
# Logging setup: root logger at INFO via basicConfig, plus a logger named
# after this module for everything below.
# NOTE(review): basicConfig runs at import time, so it also configures root
# logging for any program that imports this module; consider moving it under
# the __main__ guard.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
@dataclass
class QueryResponse:
    """Container for the result of a single RAG query.

    Generated as a dataclass so it can be built with keyword arguments, as
    ``MaternalHealthRAG.query`` does.

    Attributes:
        query: The question as submitted.
        answer: Text produced by the LLM chain, or an apology/error message.
        sources: Search results used to build the context (empty on error).
        confidence: Heuristic retrieval-quality score; 0.0 on error.
        response_time: Elapsed seconds spent handling the query.
        metadata: Summary statistics about the sources, or ``{'error': ...}``
            when the query failed.
    """
    query: str
    answer: str
    # Quoted so the class can be defined without importing the project type.
    sources: List["SearchResult"]
    confidence: float
    response_time: float
    metadata: Dict[str, Any]
class MockLLM(LLM):
    """Offline stand-in LLM for exercising the RAG pipeline without API calls.

    Returns canned guideline sentences chosen by keywords found in the
    question part of the prompt, always followed by a referral to a
    healthcare professional.
    """

    @property
    def _llm_type(self) -> str:
        """Model-type identifier; LangChain reads this as an attribute."""
        return "mock"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return a canned answer for ``prompt``.

        Args:
            prompt: Fully rendered prompt. Keyword matching happens only when
                it contains both ``Context:`` and ``Question:``; otherwise a
                generic disclaimer is returned.
            stop: Accepted for interface compatibility; ignored.
            run_manager: Accepted for interface compatibility; ignored.
            **kwargs: Ignored.

        Returns:
            One canned sentence per medical keyword found in the question
            section (in the keyword table's order), wrapped in a referral
            message; a referral-only message when nothing matches.
        """
        if "Context:" not in prompt or "Question:" not in prompt:
            return "I can provide information based on maternal health guidelines, but specific medical decisions should always be made in consultation with healthcare professionals."

        # The RAG template places "Question:" after the retrieved context, so
        # taking the text after the LAST marker keeps a "Question:" inside a
        # guideline excerpt from being mistaken for the user's question.
        # Cut at the template's "Instructions:"/"Answer:" so boilerplate text
        # is not keyword-scanned.
        question_section = prompt.rpartition("Question:")[2]
        question_section = question_section.split("Instructions:")[0].split("Answer:")[0]
        question_lower = question_section.strip().lower()

        # Canned responses keyed by a lowercase substring of the question.
        medical_keywords = {
            'magnesium': 'Magnesium sulfate is administered for seizure prevention in preeclampsia.',
            'hemorrhage': 'Postpartum hemorrhage requires immediate assessment and management with uterotonics.',
            'sepsis': 'Puerperal sepsis is diagnosed based on fever, tachycardia, and other systemic signs.',
            'fetal': 'Fetal heart rate monitoring is essential during labor to assess fetal well-being.',
            'labor': 'Normal labor management involves monitoring progress and maternal-fetal well-being.',
            'preeclampsia': 'Preeclampsia management includes blood pressure control and seizure prevention.',
            'oxytocin': 'Oxytocin is used for labor induction and augmentation with careful monitoring.',
            'cesarean': 'Cesarean section indications include fetal distress and failure to progress.',
            'diabetes': 'Gestational diabetes requires blood glucose monitoring and dietary management.',
            'hypertension': 'Pregnancy-induced hypertension requires close monitoring and treatment.'
        }

        response_parts = [
            text for keyword, text in medical_keywords.items()
            if keyword in question_lower
        ]
        if response_parts:
            base_response = " ".join(response_parts)
            return f"Based on the maternal health guidelines: {base_response} Please consult with a healthcare professional for specific medical advice."
        return "Based on the available maternal health guidelines, this appears to be a clinical question that requires professional medical evaluation. Please consult with a qualified healthcare provider."
class MaternalHealthRAG:
    """Retrieval-augmented question answering over maternal health guidelines.

    Connects a ``MaternalHealthVectorStore`` (retrieval) to an LLM through a
    LangChain ``LLMChain`` (generation). Construction loads the persisted
    vector index and builds the chain eagerly, and raises if either step fails.
    """

    def __init__(self,
                 vector_store_dir: str = "vector_store",
                 chunks_dir: str = "comprehensive_chunks",
                 use_mock_llm: bool = True):
        """Store configuration, then initialize the full RAG system.

        Args:
            vector_store_dir: Directory holding the persisted vector index.
            chunks_dir: Directory holding the source document chunks.
            use_mock_llm: Use ``MockLLM``. ``False`` is not implemented yet:
                it logs a warning and still uses ``MockLLM``.

        Raises:
            FileNotFoundError: If the vector index file does not exist.
            RuntimeError: If the vector index exists but fails to load.
        """
        self.vector_store_dir = Path(vector_store_dir)
        self.chunks_dir = Path(chunks_dir)
        self.use_mock_llm = use_mock_llm

        # Populated by initialize_rag_system().
        self.vector_store = None
        self.llm = None
        self.rag_chain = None

        # Retrieval / prompt-building parameters.
        self.default_k = 5                # chunks retrieved per query
        self.min_relevance_score = 0.3    # score floor for plain searches
        self.max_context_length = 3000    # character budget for the context

        self.initialize_rag_system()

    def initialize_rag_system(self):
        """Load the vector store, select the LLM, and build the RAG chain.

        Raises:
            FileNotFoundError: If the vector index file is missing.
            RuntimeError: If the vector index cannot be loaded.
            Exception: Any other construction error is logged and re-raised.
        """
        logger.info("π Initializing Maternal Health RAG System...")
        try:
            self.vector_store = MaternalHealthVectorStore(
                vector_store_dir=self.vector_store_dir,
                chunks_dir=self.chunks_dir
            )

            # Only a previously built index is supported; building one is a
            # separate step.
            if not self.vector_store.index_file.exists():
                logger.error("Vector store not found. Please create it first.")
                raise FileNotFoundError("Vector store not found")
            if not self.vector_store.load_existing_index():
                logger.error("Failed to load vector store")
                raise RuntimeError("Vector store initialization failed")

            if self.use_mock_llm:
                self.llm = MockLLM()
                logger.info("β Using Mock LLM for testing")
            else:
                # TODO: initialize a real LLM (OpenAI, Hugging Face, etc.).
                logger.warning("External LLM not implemented yet, using Mock LLM")
                self.llm = MockLLM()

            self.rag_chain = self.create_rag_chain()
            logger.info("β RAG system initialized successfully")
        except Exception as e:
            logger.error(f"β Failed to initialize RAG system: {e}")
            raise

    def create_rag_chain(self) -> "LLMChain":
        """Build the LLMChain with the medical prompt template.

        The template exposes ``context`` and ``question`` input variables and
        ends with ``Answer:`` for the model to complete.
        """
        template = """You are a medical information assistant specializing in maternal health guidelines.
Use the provided context from Sri Lankan maternal health guidelines to answer questions accurately and safely.
Context:
{context}
Question: {question}
Instructions:
1. Answer based ONLY on the provided context from maternal health guidelines
2. If the context doesn't contain sufficient information, clearly state this
3. Always include relevant clinical details when available (dosages, procedures, contraindications)
4. Mention when professional medical consultation is recommended
5. Be precise and avoid generalizations
Answer:"""
        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )
        return LLMChain(
            llm=self.llm,
            prompt=prompt,
            verbose=False
        )

    def query(self,
              question: str,
              k: Optional[int] = None,
              min_score: Optional[float] = None,
              content_types: Optional[List[str]] = None,
              min_importance: float = 0.5) -> "QueryResponse":
        """Answer ``question`` from retrieved guideline context.

        Args:
            question: Free-text clinical question.
            k: Number of chunks to retrieve; defaults to ``self.default_k``.
            min_score: Relevance floor for the plain similarity search;
                defaults to ``self.min_relevance_score``. Not forwarded when
                ``content_types`` is given.
            content_types: If set, retrieve with ``search_by_medical_context``
                restricted to these chunk types.
            min_importance: Clinical-importance floor; only forwarded together
                with ``content_types``.

        Returns:
            A ``QueryResponse``. Failures are not raised. They are logged with
            their traceback and returned as a response with no sources,
            confidence 0.0, and the error text in ``answer`` and
            ``metadata['error']``.
        """
        if k is None:
            k = self.default_k
        if min_score is None:
            min_score = self.min_relevance_score

        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # (or negate) the measured latency.
        start_time = time.perf_counter()
        logger.info(f"π Processing query: {question}")

        try:
            if content_types:
                search_results = self.vector_store.search_by_medical_context(
                    question,
                    content_types=content_types,
                    min_importance=min_importance,
                    k=k
                )
            else:
                search_results = self.vector_store.search(
                    question,
                    k=k,
                    min_score=min_score
                )

            context = self.prepare_context(search_results)
            response = self.rag_chain.run(
                context=context,
                question=question
            )

            response_time = time.perf_counter() - start_time
            confidence = self.calculate_confidence(search_results)

            query_response = QueryResponse(
                query=question,
                answer=response,
                sources=search_results,
                confidence=confidence,
                response_time=response_time,
                metadata={
                    'num_sources': len(search_results),
                    'avg_relevance': sum(r.score for r in search_results) / len(search_results) if search_results else 0,
                    # dict.fromkeys de-duplicates while keeping first-seen
                    # order, so the list is stable across runs (set order is not).
                    'content_types': list(dict.fromkeys(r.chunk_type for r in search_results)),
                    'high_importance_sources': sum(1 for r in search_results if r.clinical_importance >= 0.8)
                }
            )
            logger.info(f"β Query processed in {response_time:.2f}s with {len(search_results)} sources")
            return query_response
        except Exception as e:
            # Boundary handler: report the failure as a response rather than
            # raising; logger.exception also records the traceback.
            logger.exception(f"β Query processing failed: {e}")
            response_time = time.perf_counter() - start_time
            return QueryResponse(
                query=question,
                answer=f"I apologize, but I encountered an error while processing your query: {str(e)}",
                sources=[],
                confidence=0.0,
                response_time=response_time,
                metadata={'error': str(e)}
            )

    def prepare_context(self, search_results: List["SearchResult"]) -> str:
        """Format search results into one context string within the budget.

        Each result becomes a header line (1-based index, relevance to three
        decimals, chunk type) followed by its content. Parts are joined with
        newlines. Results are taken in order until the next one would push
        the joined string past ``self.max_context_length`` characters. The
        first result is always included, even if it alone exceeds the budget.
        """
        if not search_results:
            return "No relevant information found in the maternal health guidelines."

        separator = "\n"
        context_parts = []
        current_length = 0
        for index, result in enumerate(search_results, 1):
            source_info = f"Source {index} (Relevance: {result.score:.3f}, Type: {result.chunk_type}):"
            content = f"{source_info}\n{result.content}\n"
            # Count the separator join() inserts before every part after the
            # first, so the budget bounds the string actually returned.
            added = len(content) + (len(separator) if context_parts else 0)
            if context_parts and current_length + added > self.max_context_length:
                break
            context_parts.append(content)
            current_length += added
        return separator.join(context_parts)

    def calculate_confidence(self, search_results: List["SearchResult"]) -> float:
        """Heuristic confidence score derived from retrieval quality.

        A weighted sum of three saturating components:
          * 0.5 x mean relevance score, reaching full credit at 0.8;
          * 0.3 x number of results with score >= 0.7, full credit at 3;
          * 0.2 x number of results with clinical_importance >= 0.8, full
            credit at 2.
        The result is clamped to [0, 1]. Negative scores would otherwise make
        it negative. Returns 0.0 when there are no results.
        """
        if not search_results:
            return 0.0

        avg_relevance = sum(r.score for r in search_results) / len(search_results)
        high_relevance_count = sum(1 for r in search_results if r.score >= 0.7)
        high_importance_count = sum(1 for r in search_results if r.clinical_importance >= 0.8)

        relevance_weight = 0.5
        coverage_weight = 0.3
        importance_weight = 0.2

        relevance_score = min(avg_relevance / 0.8, 1.0)
        coverage_score = min(high_relevance_count / 3, 1.0)
        importance_score = min(high_importance_count / 2, 1.0)

        confidence = (
            relevance_weight * relevance_score +
            coverage_weight * coverage_score +
            importance_weight * importance_score
        )
        return max(0.0, min(confidence, 1.0))

    def batch_query(self, questions: List[str]) -> List["QueryResponse"]:
        """Run ``query`` with default parameters on each question, in order.

        Questions are processed sequentially. A failing question produces an
        error ``QueryResponse`` rather than stopping the batch.
        """
        logger.info(f"π Processing {len(questions)} queries in batch...")
        responses = []
        for i, question in enumerate(questions, 1):
            logger.info(f"Processing query {i}/{len(questions)}")
            responses.append(self.query(question))
        logger.info(f"β Batch processing complete: {len(responses)} responses generated")
        return responses

    def get_system_stats(self) -> Dict[str, Any]:
        """Summarize vector-store statistics, RAG parameters, and status.

        ``'vector_store'`` holds whatever the vector store's
        ``get_statistics()`` returns, or ``{}`` when no vector store is set.
        """
        vector_stats = (
            self.vector_store.get_statistics()
            if self.vector_store is not None
            else {}
        )
        return {
            'vector_store': vector_stats,
            'rag_config': {
                'default_k': self.default_k,
                'min_relevance_score': self.min_relevance_score,
                'max_context_length': self.max_context_length,
                'llm_type': self.llm._llm_type if self.llm else 'None'
            },
            'status': 'initialized' if self.vector_store and self.llm else 'not_initialized'
        }
def main():
    """Smoke-test the RAG system on sample questions and log the results.

    Builds a ``MaternalHealthRAG`` with default settings, runs a fixed list
    of questions, and logs each one's timing and confidence, then either its
    source statistics and an answer preview or its error. Finally logs
    overall system statistics.
    """
    logger.info("π Testing Maternal Health RAG System...")

    # Raises if the persisted vector store is missing or cannot be loaded.
    rag_system = MaternalHealthRAG()

    test_queries = [
        "What is the recommended dosage of magnesium sulfate for preeclampsia?",
        "How should postpartum hemorrhage be managed in emergency situations?",
        "What are the signs and symptoms of puerperal sepsis?",
        "What is the normal fetal heart rate range during labor?",
        "When is cesarean section indicated during delivery?"
    ]

    logger.info("\nπ Testing RAG Query Processing...")
    for i, query in enumerate(test_queries, 1):
        logger.info(f"\nπ Query {i}: {query}")
        response = rag_system.query(query)
        logger.info(f"β±οΈ Response time: {response.response_time:.2f}s")
        logger.info(f"π― Confidence: {response.confidence:.3f}")
        # query() reports failures as a response whose metadata holds only
        # 'error'; the statistics keys below would raise KeyError for it.
        if 'error' in response.metadata:
            logger.error(f"Query {i} failed: {response.metadata['error']}")
            continue
        logger.info(f"π Sources: {response.metadata['num_sources']}")
        logger.info(f"π Avg relevance: {response.metadata['avg_relevance']:.3f}")
        logger.info(f"π‘ Answer: {response.answer[:200]}...")

    stats = rag_system.get_system_stats()
    logger.info("\nπ RAG System Statistics:")
    logger.info(f" Vector store chunks: {stats['vector_store']['total_chunks']}")
    logger.info(f" LLM type: {stats['rag_config']['llm_type']}")
    logger.info(f" System status: {stats['status']}")
    logger.info("\nβ RAG system testing complete!")
# Script entry point: run the smoke test only when executed directly,
# never on import.
if __name__ == "__main__":
    main()