Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Test Suite for Comprehensive Medical Document Chunking | |
Validates clinical context preservation and chunk quality | |
""" | |
import json | |
import pytest | |
from pathlib import Path | |
from typing import Dict, List, Any | |
import logging | |
# Module-wide logging: basicConfig so standalone runs print INFO-level
# progress; named logger keeps output attributable when imported elsewhere.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ChunkingQualityValidator:
    """Validates the quality of medical document chunking.

    Each ``test_*`` method checks one aspect of a previously generated
    chunking run (read from ``chunks_dir``), records a pass/fail flag in
    ``self.test_results`` and returns it.
    """

    def __init__(self, chunks_dir: Path = Path("comprehensive_chunks")):
        # Directory containing the chunking report and per-document chunk files.
        self.chunks_dir = chunks_dir
        # Maps test key -> pass/fail, filled in by the individual tests.
        self.test_results: Dict[str, bool] = {}

    def load_chunking_report(self) -> Dict[str, Any]:
        """Load the comprehensive chunking report.

        Returns:
            The parsed report dictionary.

        Raises:
            FileNotFoundError: if the report JSON does not exist.
        """
        report_file = self.chunks_dir / "comprehensive_chunking_report.json"
        if not report_file.exists():
            raise FileNotFoundError(f"Chunking report not found: {report_file}")
        with open(report_file, encoding="utf-8") as f:
            return json.load(f)

    def load_sample_chunks(self, doc_name: str, limit: int = 5) -> List[Dict]:
        """Load up to ``limit`` chunks from a document; [] if the file is missing."""
        doc_chunks_file = self.chunks_dir / doc_name / "comprehensive_chunks.json"
        if not doc_chunks_file.exists():
            return []
        with open(doc_chunks_file, encoding="utf-8") as f:
            chunks = json.load(f)
        return chunks[:limit]

    def test_basic_statistics(self, report: Dict[str, Any]) -> bool:
        """Test basic chunking statistics (counts and distributions)."""
        logger.info("Testing basic chunking statistics...")
        try:
            # Test that we have a reasonable number of chunks.
            total_chunks = report['total_chunks']
            total_docs = report['total_documents']
            assert total_chunks > 0, "No chunks were created"
            assert total_docs > 0, "No documents were processed"
            assert total_chunks >= total_docs, "Too few chunks per document"
            # Test chunk distribution.
            chunk_types = report['chunk_type_distribution']
            assert len(chunk_types) > 0, "No chunk types identified"
            # Test importance distribution: at least some critical/high chunks.
            importance_dist = report['clinical_importance_distribution']
            high_importance = importance_dist.get('critical', 0) + importance_dist.get('high', 0)
            assert high_importance > 0, "No high importance chunks found"
            logger.info(f"β Basic statistics: {total_chunks} chunks from {total_docs} documents")
            self.test_results['basic_statistics'] = True
            return True
        # KeyError covers malformed reports missing required top-level keys.
        except (AssertionError, KeyError) as e:
            logger.error(f"β Basic statistics test failed: {e}")
            self.test_results['basic_statistics'] = False
            return False

    def test_clinical_content_recognition(self, report: Dict[str, Any]) -> bool:
        """Test that clinical content (maternal, dosage, tables) is recognized."""
        logger.info("Testing clinical content recognition...")
        try:
            processing_summary = report['processing_summary']
            # Test for maternal health content.
            maternal_chunks = processing_summary.get('maternal_chunks', 0)
            assert maternal_chunks > 0, "No maternal health content identified"
            # Test for dosage information.
            dosage_chunks = processing_summary.get('dosage_chunks', 0)
            assert dosage_chunks > 0, "No dosage information identified"
            # NOTE: emergency content ('emergency_chunks') is optional and
            # deliberately not asserted on.
            # Test for table preservation.
            table_chunks = processing_summary.get('chunks_with_tables', 0)
            assert table_chunks > 0, "No table content preserved"
            logger.info(f"β Clinical content: {maternal_chunks} maternal, {dosage_chunks} dosage, {table_chunks} with tables")
            self.test_results['clinical_content'] = True
            return True
        except (AssertionError, KeyError) as e:
            logger.error(f"β Clinical content test failed: {e}")
            self.test_results['clinical_content'] = False
            return False

    @staticmethod
    def _chunk_passes_quality(chunk: Dict) -> bool:
        """Return True when a chunk has all required fields and sane values.

        A chunk is valid only if it carries the required fields, has
        non-trivial content, a normalized importance score, and a non-empty
        medical-context mapping.
        """
        required_fields = ['content', 'chunk_type', 'clinical_importance', 'medical_context']
        if not all(field in chunk for field in required_fields):
            return False
        # Reasonable content length.
        if len(chunk['content'].strip()) <= 50:
            return False
        # Clinical importance must be a normalized [0, 1] score.
        if not 0 <= chunk['clinical_importance'] <= 1:
            return False
        # Medical context must be a non-empty mapping.
        context = chunk['medical_context']
        return isinstance(context, dict) and len(context) > 0

    def test_chunk_quality(self, report: Dict[str, Any]) -> bool:
        """Test individual chunk quality on a sample of documents.

        Fix over the original implementation: the content-length,
        importance-range and medical-context checks previously had no effect
        (dead branches ending in ``continue``); they now gate validity.
        """
        logger.info("Testing chunk quality...")
        try:
            # Load sample chunks from different documents.
            doc_names = list(report['document_statistics'].keys())
            sample_count = 0
            valid_chunks = 0
            for doc_name in doc_names[:3]:  # Test first 3 documents
                for chunk in self.load_sample_chunks(doc_name, limit=3):
                    sample_count += 1
                    if self._chunk_passes_quality(chunk):
                        valid_chunks += 1
            chunk_quality_ratio = valid_chunks / sample_count if sample_count > 0 else 0
            assert chunk_quality_ratio >= 0.8, f"Chunk quality too low: {chunk_quality_ratio:.2f}"
            logger.info(f"β Chunk quality: {valid_chunks}/{sample_count} chunks passed quality checks")
            self.test_results['chunk_quality'] = True
            return True
        except AssertionError as e:
            logger.error(f"β Chunk quality test failed: {e}")
            self.test_results['chunk_quality'] = False
            return False
        except Exception as e:
            logger.error(f"β Chunk quality test error: {e}")
            self.test_results['chunk_quality'] = False
            return False

    def test_medical_context_preservation(self) -> bool:
        """Test that medical context metadata survives into LangChain docs."""
        logger.info("Testing medical context preservation...")
        try:
            # Load LangChain documents.
            langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
            if not langchain_file.exists():
                raise FileNotFoundError("LangChain documents not found")
            with open(langchain_file, encoding="utf-8") as f:
                langchain_docs = json.load(f)
            # Test a sample of documents.
            medical_context_count = 0
            total_tested = 0
            for doc in langchain_docs[:20]:  # Test first 20 documents
                total_tested += 1
                metadata = doc.get('metadata', {})
                # Check for any of the medical context fields.
                medical_fields = [
                    'chunk_type', 'clinical_importance', 'keywords',
                    'has_clinical_protocols', 'has_dosage_info', 'is_maternal_specific'
                ]
                if any(field in metadata for field in medical_fields):
                    medical_context_count += 1
            context_ratio = medical_context_count / total_tested if total_tested > 0 else 0
            assert context_ratio >= 0.8, f"Medical context preservation too low: {context_ratio:.2f}"
            logger.info(f"β Medical context: {medical_context_count}/{total_tested} documents have medical context")
            self.test_results['medical_context'] = True
            return True
        except AssertionError as e:
            logger.error(f"β Medical context test failed: {e}")
            self.test_results['medical_context'] = False
            return False
        except Exception as e:
            logger.error(f"β Medical context test error: {e}")
            self.test_results['medical_context'] = False
            return False

    def test_document_coverage(self, report: Dict[str, Any]) -> bool:
        """Test that enough documents were processed and each produced chunks."""
        logger.info("Testing document coverage...")
        try:
            doc_stats = report['document_statistics']
            processed_docs = len(doc_stats)
            # We should have processed all 15 maternal health documents.
            expected_min_docs = 10  # Minimum expected
            assert processed_docs >= expected_min_docs, f"Too few documents processed: {processed_docs}"
            # Check that each document has at least one chunk.
            docs_with_good_coverage = sum(
                1 for stats in doc_stats.values() if stats['total_chunks'] > 0
            )
            # Safe: processed_docs >= expected_min_docs > 0 was asserted above.
            coverage_ratio = docs_with_good_coverage / processed_docs
            assert coverage_ratio >= 0.9, f"Document coverage too low: {coverage_ratio:.2f}"
            logger.info(f"β Document coverage: {docs_with_good_coverage}/{processed_docs} documents well covered")
            self.test_results['document_coverage'] = True
            return True
        except (AssertionError, KeyError) as e:
            logger.error(f"β Document coverage test failed: {e}")
            self.test_results['document_coverage'] = False
            return False

    def test_clinical_importance_distribution(self, report: Dict[str, Any]) -> bool:
        """Test that enough chunks are rated critical/high importance.

        Fix over the original implementation: an empty distribution no longer
        raises an uncaught ZeroDivisionError — it now fails the test cleanly.
        """
        logger.info("Testing clinical importance distribution...")
        try:
            importance_dist = report['clinical_importance_distribution']
            total = sum(importance_dist.values())
            # Guard the division below: an empty distribution is a failure,
            # not a crash.
            assert total > 0, "Empty clinical importance distribution"
            critical_ratio = importance_dist.get('critical', 0) / total
            high_ratio = importance_dist.get('high', 0) / total
            # We expect a good amount of high-importance content for medical guidelines.
            high_importance_ratio = critical_ratio + high_ratio
            assert high_importance_ratio >= 0.3, f"Too little high-importance content: {high_importance_ratio:.2f}"
            logger.info(f"β Clinical importance: {high_importance_ratio:.1%} high-importance chunks")
            self.test_results['clinical_importance'] = True
            return True
        except (AssertionError, KeyError) as e:
            logger.error(f"β Clinical importance test failed: {e}")
            self.test_results['clinical_importance'] = False
            return False

    def run_all_tests(self) -> Dict[str, bool]:
        """Run all quality validation tests.

        Returns:
            Mapping of human-readable test name -> pass/fail. Empty dict when
            the chunking report itself cannot be loaded.
        """
        logger.info("=" * 80)
        logger.info("STARTING COMPREHENSIVE CHUNKING QUALITY VALIDATION")
        logger.info("=" * 80)
        try:
            # Load the chunking report once and share it across tests.
            report = self.load_chunking_report()
            # (name, callable) pairs; lambdas close over the shared report.
            tests = [
                ('Basic Statistics', lambda: self.test_basic_statistics(report)),
                ('Clinical Content Recognition', lambda: self.test_clinical_content_recognition(report)),
                ('Chunk Quality', lambda: self.test_chunk_quality(report)),
                ('Medical Context Preservation', lambda: self.test_medical_context_preservation()),
                ('Document Coverage', lambda: self.test_document_coverage(report)),
                ('Clinical Importance Distribution', lambda: self.test_clinical_importance_distribution(report))
            ]
            results: Dict[str, bool] = {}
            passed_tests = 0
            for test_name, test_func in tests:
                logger.info(f"\nπ§ͺ Running: {test_name}")
                try:
                    result = test_func()
                    results[test_name] = result
                    if result:
                        passed_tests += 1
                except Exception as e:
                    # A crashing test counts as a failure, not an abort.
                    logger.error(f"β {test_name} failed with error: {e}")
                    results[test_name] = False
            # Summary.
            logger.info("\n" + "=" * 80)
            logger.info("CHUNKING QUALITY VALIDATION SUMMARY")
            logger.info("=" * 80)
            logger.info(f"β Tests Passed: {passed_tests}/{len(tests)}")
            for test_name, result in results.items():
                status = "β PASS" if result else "β FAIL"
                logger.info(f"{status}: {test_name}")
            # Overall verdict requires an 80% pass rate.
            overall_success = passed_tests >= (len(tests) * 0.8)
            if overall_success:
                logger.info("\nπ OVERALL RESULT: CHUNKING QUALITY VALIDATION PASSED!")
            else:
                logger.info("\nβ οΈ OVERALL RESULT: CHUNKING QUALITY VALIDATION NEEDS IMPROVEMENT")
            logger.info("=" * 80)
            return results
        except Exception as e:
            logger.error(f"β Validation failed with error: {e}")
            return {}
def main():
    """Run the full validation suite and persist the results as JSON."""
    validator = ChunkingQualityValidator()
    results = validator.run_all_tests()
    # Save test results next to the chunks they describe. Use the validator's
    # configured directory instead of hard-coding the path a second time, and
    # make sure it exists (run_all_tests may have failed before creating it).
    test_results_file = validator.chunks_dir / "quality_validation_results.json"
    test_results_file.parent.mkdir(parents=True, exist_ok=True)
    passed = sum(results.values())
    with open(test_results_file, "w", encoding="utf-8") as f:
        json.dump({
            'test_results': results,
            'summary': {
                'total_tests': len(results),
                'passed_tests': passed,
                # Guard against division by zero when no tests ran.
                'pass_rate': passed / len(results) if results else 0
            }
        }, f, indent=2)
    logger.info(f"π Test results saved to: {test_results_file}")
# Script entry point: run the validation suite only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()