File size: 11,852 Bytes
19aaa42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#!/usr/bin/env python3
"""
Test Suite for Maternal Health RAG System
Validates complete RAG pipeline including vector retrieval and response generation
"""

import unittest
import time
from typing import List, Dict, Any
from pathlib import Path

from maternal_health_rag import MaternalHealthRAG, QueryResponse

class TestMaternalHealthRAG(unittest.TestCase):
    """End-to-end validation of the maternal-health RAG pipeline.

    A single shared system (with a mock LLM) is built once for the whole
    suite; each test exercises one facet: initialization, querying,
    metadata, confidence, performance, batching, error handling, and
    safety disclaimers.
    """

    @classmethod
    def setUpClass(cls):
        """Build one shared RAG instance for every test in the class."""
        print("πŸš€ Initializing RAG system for testing...")
        cls.rag_system = MaternalHealthRAG(use_mock_llm=True)

    def test_rag_system_initialization(self):
        """Core components exist and the system reports a healthy status."""
        for component in (self.rag_system.vector_store,
                          self.rag_system.llm,
                          self.rag_system.rag_chain):
            self.assertIsNotNone(component)

        # The stats endpoint should confirm a populated vector store.
        stats = self.rag_system.get_system_stats()
        self.assertEqual(stats['status'], 'initialized')
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

    def test_basic_query_processing(self):
        """A simple question yields a well-formed QueryResponse."""
        question = "What is magnesium sulfate used for?"
        result = self.rag_system.query(question)

        self.assertIsInstance(result, QueryResponse)
        self.assertEqual(result.query, question)
        self.assertIsInstance(result.answer, str)
        self.assertGreater(len(result.answer), 0)
        # Confidence behaves like a probability: bounded to [0, 1].
        self.assertGreaterEqual(result.confidence, 0.0)
        self.assertLessEqual(result.confidence, 1.0)
        self.assertGreater(result.response_time, 0.0)

    def test_medical_context_queries(self):
        """Domain-specific queries surface relevant sources and keywords."""
        # (question, content-type filters, keywords, confidence floor)
        scenarios = [
            ('What is the dosage for magnesium sulfate in preeclampsia?',
             ['dosage', 'emergency'],
             ['magnesium', 'dosage', 'preeclampsia'],
             0.3),
            ('How to manage postpartum hemorrhage emergency?',
             ['emergency', 'maternal'],
             ['hemorrhage', 'postpartum', 'emergency'],
             0.3),
            ('Normal fetal heart rate monitoring procedures',
             ['procedure', 'maternal'],
             ['fetal', 'heart', 'rate', 'monitoring'],
             0.3),
        ]

        for question, type_filter, keywords, floor in scenarios:
            with self.subTest(query=question):
                result = self.rag_system.query(
                    question,
                    content_types=type_filter
                )

                # Quality gates: at least one source, confidence above floor.
                self.assertGreater(len(result.sources), 0)
                self.assertGreaterEqual(result.confidence, floor)

                # Keywords may appear in the answer or in any source text.
                haystack = result.answer.lower()
                if result.sources:
                    haystack += ' ' + ' '.join(
                        src.content.lower() for src in result.sources)

                self.assertTrue(
                    any(kw in haystack for kw in keywords),
                    f"No expected keywords found for query: {question}")

    def test_response_metadata(self):
        """Metadata carries the expected keys with the expected types."""
        result = self.rag_system.query("What are the signs of preeclampsia?")
        meta = result.metadata

        for key in ('num_sources', 'avg_relevance',
                    'content_types', 'high_importance_sources'):
            self.assertIn(key, meta)

        self.assertIsInstance(meta['num_sources'], int)
        self.assertGreaterEqual(meta['num_sources'], 0)
        self.assertIsInstance(meta['avg_relevance'], float)
        self.assertIsInstance(meta['content_types'], list)
        self.assertIsInstance(meta['high_importance_sources'], int)

    def test_confidence_scoring(self):
        """Scores stay in [0, 1]; a specific medical query clears 0.3."""
        # A focused clinical query should match the corpus well...
        specific = self.rag_system.query("magnesium sulfate for preeclampsia")
        # ...while a vague one is still scored, just not compared strictly
        # (avoids flakiness from retrieval noise).
        generic = self.rag_system.query("medical procedures in general")

        self.assertGreaterEqual(specific.confidence, 0.3)

        for scored in (specific, generic):
            self.assertGreaterEqual(scored.confidence, 0.0)
            self.assertLessEqual(scored.confidence, 1.0)

    def test_performance_metrics(self):
        """Mock-LLM responses return quickly with a sane number of sources."""
        started = time.time()
        result = self.rag_system.query("What is normal labor management?")
        elapsed = time.time() - started

        # Self-reported latency < 2s; wall-clock gets a little slack.
        self.assertLess(result.response_time, 2.0)
        self.assertLess(elapsed, 3.0)

        self.assertGreater(len(result.sources), 0)
        self.assertLessEqual(len(result.sources), 10)  # Should be reasonable number

    def test_batch_query_processing(self):
        """batch_query returns one valid response per input, in order."""
        questions = [
            "What is magnesium sulfate used for?",
            "How to manage labor complications?",
            "Normal fetal heart rate ranges"
        ]

        results = self.rag_system.batch_query(questions)

        self.assertEqual(len(results), len(questions))
        for question, result in zip(questions, results):
            self.assertIsInstance(result, QueryResponse)
            self.assertEqual(result.query, question)
            self.assertGreater(len(result.answer), 0)

    def test_context_preparation(self):
        """When sources exist, the answer is substantive and cites safety."""
        result = self.rag_system.query("preeclampsia management guidelines")

        if result.sources:
            self.assertGreater(len(result.answer), 20)  # Reasonable answer length
            # The mock LLM bakes a safety disclaimer into every answer.
            self.assertIn("healthcare", result.answer.lower())

    def test_error_handling(self):
        """Edge-case inputs produce a QueryResponse rather than raising."""
        # Empty query
        blank = self.rag_system.query("")
        self.assertIsInstance(blank, QueryResponse)
        self.assertIsInstance(blank.answer, str)

        # Very long query
        verbose = self.rag_system.query(
            "What is the management protocol for " + "complicated " * 100 + "pregnancy cases?")
        self.assertIsInstance(verbose, QueryResponse)

        # Special characters query
        symbolic = self.rag_system.query(
            "What is the dosage for mg++ and other electrolytes?")
        self.assertIsInstance(symbolic, QueryResponse)

    def test_medical_safety_responses(self):
        """Every clinical question gets a safety disclaimer in its answer."""
        for question in (
            "What medication should I take for preeclampsia?",
            "How much magnesium sulfate should I give?",
            "What should I do for emergency bleeding?",
        ):
            with self.subTest(query=question):
                answer_text = self.rag_system.query(question).answer.lower()

                # Any one of these terms counts as safety language.
                disclaimer_present = any(
                    term in answer_text
                    for term in ('consult', 'healthcare', 'professional', 'medical'))
                self.assertTrue(
                    disclaimer_present,
                    f"No safety disclaimer found in response to: {question}")

    def test_system_statistics(self):
        """get_system_stats exposes vector-store and RAG configuration info."""
        stats = self.rag_system.get_system_stats()

        for key in ('vector_store', 'rag_config', 'status'):
            self.assertIn(key, stats)

        # Vector store must be populated.
        self.assertIn('total_chunks', stats['vector_store'])
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

        # RAG config should reflect the mock-LLM setup from setUpClass.
        self.assertIn('default_k', stats['rag_config'])
        self.assertIn('llm_type', stats['rag_config'])
        self.assertEqual(stats['rag_config']['llm_type'], 'mock')

def _print_issue_summary(label, issues, fallback):
    """Print one `test: last-traceback-line` entry per recorded issue.

    Args:
        label: Section heading ("Failures" or "Errors").
        issues: List of (test, traceback_text) pairs as produced by
            unittest's TestResult (``result.failures`` / ``result.errors``).
        fallback: Text shown when a traceback is empty.
    """
    print(f"\n{label}:")
    for test, traceback in issues:
        lines = traceback.strip().split('\n')
        # The last traceback line is the assertion/exception message.
        error_line = lines[-1] if lines else fallback
        print(f"  - {test}: {error_line}")

def run_comprehensive_rag_tests():
    """Run the full RAG test suite with detailed console reporting.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("πŸ§ͺ Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Create test suite
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthRAG)

    # Run tests with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Print summary
    print("\n" + "=" * 60)
    print("πŸ“Š RAG TEST SUMMARY:")
    print(f"  Tests run: {result.testsRun}")
    print(f"  Failures: {len(result.failures)}")
    print(f"  Errors: {len(result.errors)}")

    if result.wasSuccessful():
        print("βœ… ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\nπŸŽ‰ Key Validations Completed:")
        print("  βœ… RAG system initialization")
        print("  βœ… Query processing with medical context")
        print("  βœ… Confidence scoring and metadata")
        print("  βœ… Performance under 2 seconds")
        print("  βœ… Batch query processing")
        print("  βœ… Medical safety disclaimers")
        print("  βœ… Error handling and edge cases")
    else:
        print("❌ Some RAG tests failed. Check output above for details.")

        # Failures and errors share one reporting path; only the heading
        # and empty-traceback fallback text differ.
        if result.failures:
            _print_issue_summary("Failures", result.failures, "Unknown failure")
        if result.errors:
            _print_issue_summary("Errors", result.errors, "Unknown error")

    return result.wasSuccessful()

if __name__ == "__main__":
    success = run_comprehensive_rag_tests()
    # Propagate a conventional exit status to the shell (0 = all passed).
    # SystemExit is raised directly instead of calling the site-module
    # exit() helper, which is absent when running under `python -S`.
    raise SystemExit(0 if success else 1)