#!/usr/bin/env python3
"""
Test Suite for Maternal Health RAG System

Validates the complete RAG pipeline, including vector retrieval and response
generation, against a MaternalHealthRAG instance built with a mock LLM.
Run directly to get a summarized report; the process exit status is 0 when
every test passes and 1 otherwise.
"""

import sys
import time
import unittest
from pathlib import Path
from typing import Any, Dict, List

from maternal_health_rag import MaternalHealthRAG, QueryResponse


class TestMaternalHealthRAG(unittest.TestCase):
    """Test suite for RAG system functionality.

    A single MaternalHealthRAG instance (mock LLM) is created once in
    setUpClass and shared by every test method in this class.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared RAG system once for the whole test class."""
        print("🚀 Initializing RAG system for testing...")
        cls.rag_system = MaternalHealthRAG(use_mock_llm=True)

    def _assert_confidence_in_range(self, response):
        """Assert that ``response.confidence`` lies in the closed range [0, 1]."""
        self.assertGreaterEqual(response.confidence, 0.0)
        self.assertLessEqual(response.confidence, 1.0)

    def test_rag_system_initialization(self):
        """Test the RAG system exposes its components and reports ready stats."""
        self.assertIsNotNone(self.rag_system.vector_store)
        self.assertIsNotNone(self.rag_system.llm)
        self.assertIsNotNone(self.rag_system.rag_chain)

        # Check system status
        stats = self.rag_system.get_system_stats()
        self.assertEqual(stats['status'], 'initialized')
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

    def test_basic_query_processing(self):
        """Test a plain query returns a well-formed QueryResponse."""
        query = "What is magnesium sulfate used for?"
        response = self.rag_system.query(query)

        # Basic response validation
        self.assertIsInstance(response, QueryResponse)
        self.assertEqual(response.query, query)
        self.assertIsInstance(response.answer, str)
        self.assertGreater(len(response.answer), 0)
        self._assert_confidence_in_range(response)
        self.assertGreater(response.response_time, 0.0)

    def test_medical_context_queries(self):
        """Test content-type-filtered queries retrieve relevant material.

        For each case the response must cite at least one source, meet the
        case's minimum confidence, and mention at least one expected keyword
        in either the generated answer or the retrieved source text.
        """
        test_cases = [
            {
                'query': 'What is the dosage for magnesium sulfate in preeclampsia?',
                'content_types': ['dosage', 'emergency'],
                'expected_keywords': ['magnesium', 'dosage', 'preeclampsia'],
                'min_confidence': 0.3
            },
            {
                'query': 'How to manage postpartum hemorrhage emergency?',
                'content_types': ['emergency', 'maternal'],
                'expected_keywords': ['hemorrhage', 'postpartum', 'emergency'],
                'min_confidence': 0.3
            },
            {
                'query': 'Normal fetal heart rate monitoring procedures',
                'content_types': ['procedure', 'maternal'],
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_confidence': 0.3
            }
        ]

        for case in test_cases:
            with self.subTest(query=case['query']):
                response = self.rag_system.query(
                    case['query'],
                    content_types=case['content_types']
                )

                # Response quality checks
                self.assertGreater(len(response.sources), 0)
                self.assertGreaterEqual(response.confidence, case['min_confidence'])

                # A keyword may appear only in the retrieved chunks rather
                # than in the generated answer, so search both together.
                combined_text = ' '.join(
                    [response.answer.lower()]
                    + [source.content.lower() for source in response.sources]
                )
                keyword_found = any(
                    keyword in combined_text for keyword in case['expected_keywords']
                )
                self.assertTrue(keyword_found, f"No expected keywords found for query: {case['query']}")

    def test_response_metadata(self):
        """Test response metadata fields are present and correctly typed."""
        query = "What are the signs of preeclampsia?"
        response = self.rag_system.query(query)

        # Check metadata fields
        required_fields = ['num_sources', 'avg_relevance', 'content_types', 'high_importance_sources']
        for field in required_fields:
            self.assertIn(field, response.metadata)

        # Validate metadata values
        self.assertIsInstance(response.metadata['num_sources'], int)
        self.assertGreaterEqual(response.metadata['num_sources'], 0)
        self.assertIsInstance(response.metadata['avg_relevance'], float)
        self.assertIsInstance(response.metadata['content_types'], list)
        self.assertIsInstance(response.metadata['high_importance_sources'], int)

    def test_confidence_scoring(self):
        """Test confidence scores for a specific and a vague query."""
        # Specific query expected to match the indexed material well
        high_conf_query = "magnesium sulfate for preeclampsia"
        high_response = self.rag_system.query(high_conf_query)

        # Vague, less specific query
        low_conf_query = "medical procedures in general"
        low_response = self.rag_system.query(low_conf_query)

        # Only the specific query is held to the 0.3 floor used by the other
        # tests; no ordering between the two scores is asserted.
        self.assertGreaterEqual(high_response.confidence, 0.3)

        # Both should have valid confidence scores
        self._assert_confidence_in_range(high_response)
        self._assert_confidence_in_range(low_response)

    def test_performance_metrics(self):
        """Test query latency and the number of returned sources."""
        query = "What is normal labor management?"

        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        response = self.rag_system.query(query)
        actual_time = time.perf_counter() - start_time

        # Response should be fast (under 2 seconds for mock LLM)
        self.assertLess(response.response_time, 2.0)
        self.assertLess(actual_time, 3.0)

        # Should return relevant sources
        self.assertGreater(len(response.sources), 0)
        self.assertLessEqual(len(response.sources), 10)  # Should be reasonable number

    def test_batch_query_processing(self):
        """Test batch_query returns one valid response per query, in order."""
        queries = [
            "What is magnesium sulfate used for?",
            "How to manage labor complications?",
            "Normal fetal heart rate ranges"
        ]

        responses = self.rag_system.batch_query(queries)

        # Should return same number of responses
        self.assertEqual(len(responses), len(queries))

        # Each response should be valid and correspond to its query by position
        for query, response in zip(queries, responses):
            with self.subTest(query=query):
                self.assertIsInstance(response, QueryResponse)
                self.assertEqual(response.query, query)
                self.assertGreater(len(response.answer), 0)

    def test_context_preparation(self):
        """Test an answer is built from retrieved context.

        The test is reported as skipped (rather than silently passing) when
        retrieval returns no sources, since nothing is then exercised.
        """
        query = "preeclampsia management guidelines"
        response = self.rag_system.query(query)

        if not response.sources:
            self.skipTest("no sources retrieved; context preparation not exercised")

        # Context should be prepared from these sources
        self.assertGreater(len(response.answer), 20)  # Reasonable answer length

        # Mock LLM should include safety disclaimer
        self.assertIn("healthcare", response.answer.lower())

    def test_error_handling(self):
        """Test empty, very long, and special-character queries do not fail."""
        # Empty query
        empty_response = self.rag_system.query("")
        self.assertIsInstance(empty_response, QueryResponse)
        self.assertIsInstance(empty_response.answer, str)

        # Very long query
        long_query = "What is the management protocol for " + "complicated " * 100 + "pregnancy cases?"
        long_response = self.rag_system.query(long_query)
        self.assertIsInstance(long_response, QueryResponse)

        # Special characters query
        special_query = "What is the dosage for mg++ and other electrolytes?"
        special_response = self.rag_system.query(special_query)
        self.assertIsInstance(special_response, QueryResponse)

    def test_medical_safety_responses(self):
        """Test that responses include appropriate medical safety disclaimers"""
        queries = [
            "What medication should I take for preeclampsia?",
            "How much magnesium sulfate should I give?",
            "What should I do for emergency bleeding?"
        ]

        for query in queries:
            with self.subTest(query=query):
                response = self.rag_system.query(query)

                # Should include medical safety language
                safety_terms = ['consult', 'healthcare', 'professional', 'medical']
                answer_lower = response.answer.lower()
                safety_found = any(term in answer_lower for term in safety_terms)
                self.assertTrue(safety_found, f"No safety disclaimer found in response to: {query}")

    def test_system_statistics(self):
        """Test get_system_stats reports vector-store and RAG config details."""
        stats = self.rag_system.get_system_stats()

        # Check required fields
        required_fields = ['vector_store', 'rag_config', 'status']
        for field in required_fields:
            self.assertIn(field, stats)

        # Check vector store stats
        self.assertIn('total_chunks', stats['vector_store'])
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

        # Check RAG config
        self.assertIn('default_k', stats['rag_config'])
        self.assertIn('llm_type', stats['rag_config'])
        self.assertEqual(stats['rag_config']['llm_type'], 'mock')


def _print_problem_summary(title, problems, fallback):
    """Print a titled list of (test, traceback text) pairs, one line each.

    Each line shows the test and the final line of its traceback text, or
    *fallback* when that text is empty. Prints nothing when *problems* is empty.
    """
    if not problems:
        return
    print(f"\n{title}:")
    for test, tb_text in problems:
        # splitlines() returns [] for empty text, so the fallback is reachable;
        # str.split('\n') would always yield at least one element.
        tb_lines = tb_text.strip().splitlines()
        last_line = tb_lines[-1] if tb_lines else fallback
        print(f" - {test}: {last_line}")


def run_comprehensive_rag_tests():
    """Run all RAG tests with detailed reporting.

    Returns:
        True when every test passed (skips do not count as failures),
        False otherwise.
    """
    print("🧪 Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Create test suite
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthRAG)

    # Run tests with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Print summary
    print("\n" + "=" * 60)
    print("📊 RAG TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    print(f" Skipped: {len(result.skipped)}")

    if result.wasSuccessful():
        print("✅ ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\n🎉 Key Validations Completed:")
        print(" ✅ RAG system initialization")
        print(" ✅ Query processing with medical context")
        print(" ✅ Confidence scoring and metadata")
        print(" ✅ Performance under 2 seconds")
        print(" ✅ Batch query processing")
        print(" ✅ Medical safety disclaimers")
        print(" ✅ Error handling and edge cases")
    else:
        print("❌ Some RAG tests failed. Check output above for details.")
        _print_problem_summary("Failures", result.failures, "Unknown failure")
        _print_problem_summary("Errors", result.errors, "Unknown error")

    return result.wasSuccessful()


if __name__ == "__main__":
    success = run_comprehensive_rag_tests()
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is missing when Python runs without `site`.
    sys.exit(0 if success else 1)