"""
Medical Dataset Manager for handling specialized medical datasets.

Optimized for memory-constrained environments with streaming support.
"""

import copy
import functools
import gc
import os
import logging
import asyncio
from typing import Dict, Any, List, Optional, Iterator, Tuple, Union
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset as TorchIterableDataset
from datasets import load_dataset, Dataset as HFDataset, IterableDataset as HFIterableDataset
import numpy as np
from PIL import Image
import json

from ..core.memory_manager import AdvancedMemoryManager

logger = logging.getLogger(__name__)


def _pil_image_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float tensor divided by 255.

    A 3-D pixel array (H, W, C) is permuted to (C, H, W); a 2-D array (H, W)
    gets a leading channel axis, giving (1, H, W).
    """
    image_array = np.array(image)
    # NOTE(review): dividing by 255 assumes 8-bit pixel data; higher bit-depth
    # PIL modes (e.g. 'I;16') would not land in [0, 1] — confirm for each dataset.
    if len(image_array.shape) == 3:
        return torch.from_numpy(image_array).permute(2, 0, 1).float() / 255.0
    return torch.from_numpy(image_array).unsqueeze(0).float() / 255.0


def _medical_collate(batch: List[Dict[str, Any]], data_format: str) -> Dict[str, Any]:
    """Collate a list of dataset items into one batch dict.

    Defined at module level (and bound with ``functools.partial``) so it can be
    pickled when ``DataLoader`` uses worker processes.

    Args:
        batch: Items produced by the dataset; presumably dicts of features.
        data_format: The dataset's ``data_format`` configuration value.

    Returns:
        For ``'image_text_pairs'``: ``{'images', 'texts', 'batch_size'}`` where
        ``images`` is the stacked image tensor (or ``None`` when the batch has
        no PIL images). For any other format: ``{'data', 'batch_size'}`` with
        the raw items. If anything raises (for example ``torch.stack`` on images
        of differing sizes), the error is logged and
        ``{'data', 'batch_size', 'error'}`` is returned instead.
    """
    try:
        if data_format == 'image_text_pairs':
            images = []
            texts = []

            for item in batch:
                # Handle image data; only PIL images are converted.
                if 'image' in item:
                    image = item['image']
                    if isinstance(image, Image.Image):
                        images.append(_pil_image_to_tensor(image))

                # Handle text data: first of 'text', 'caption', 'report'.
                # NOTE(review): items lacking an image or a text field are skipped
                # in that list only, so images/texts can fall out of alignment.
                if 'text' in item or 'caption' in item or 'report' in item:
                    text = item.get('text', item.get('caption', item.get('report', '')))
                    texts.append(str(text))

            return {
                'images': torch.stack(images) if images else None,
                'texts': texts,
                'batch_size': len(batch)
            }

        # Generic multimodal handling: pass items through untouched.
        return {
            'data': batch,
            'batch_size': len(batch)
        }

    except Exception as e:
        logger.error(f"Error in collate function: {e}")
        # Return minimal batch on error
        return {
            'data': batch,
            'batch_size': len(batch),
            'error': str(e)
        }


class MedicalDatasetManager:
    """
    Manager for medical datasets with memory-efficient streaming.

    Loads Hugging Face datasets listed in ``SUPPORTED_DATASETS`` in streaming
    or full mode and wraps them in a batching loader.
    """

    # Supported medical datasets configuration
    SUPPORTED_DATASETS = {
        'roco_v2': {
            'name': 'ROCOv2 Radiology',
            'repo_id': 'eltorio/ROCOv2-radiology',
            # "Radiographic images with detailed medical reports"
            'description': 'صور شعاعية مع تقارير طبية مفصلة',
            'modalities': ['radiology', 'text'],
            'size_gb': 8.5,
            'num_samples': 81000,
            'languages': ['en', 'ar'],
            'medical_specialties': ['radiology', 'general'],
            'data_format': 'image_text_pairs',
            'streaming_supported': True
        },
        'ct_rate': {
            'name': 'CT-RATE',
            'repo_id': 'ibrahimhamamci/CT-RATE',
            # "CT images with assessments and diagnoses"
            'description': 'صور CT مع تقييمات وتشخيصات',
            'modalities': ['ct_scan', 'text'],
            'size_gb': 12.3,
            'num_samples': 50000,
            'languages': ['en'],
            'medical_specialties': ['radiology', 'emergency', 'internal_medicine'],
            'data_format': 'image_text_pairs',
            'streaming_supported': True
        },
        'umie_datasets': {
            'name': 'UMIE Medical Datasets',
            'repo_id': 'lion-ai/umie_datasets',
            # "Diverse, multimodal medical data"
            'description': 'بيانات طبية متنوعة ومتعددة الوسائط',
            'modalities': ['multimodal', 'text', 'imaging'],
            'size_gb': 15.7,
            'num_samples': 120000,
            'languages': ['en', 'ar', 'fr'],
            'medical_specialties': ['general', 'cardiology', 'neurology', 'oncology'],
            'data_format': 'multimodal',
            'streaming_supported': True
        }
    }

    def __init__(self, memory_manager: AdvancedMemoryManager,
                 cache_dir: str = "cache/medical_datasets"):
        """
        Initialize the medical dataset manager.

        Args:
            memory_manager: Memory manager instance used for memory checks.
            cache_dir: Directory for caching datasets; created if missing.
        """
        self.memory_manager = memory_manager
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.loaded_datasets = {}
        self.streaming_datasets = {}

        logger.info("Medical Dataset Manager initialized")

    async def load_dataset(self, dataset_name: str, streaming: bool = True,
                           subset: Optional[str] = None, split: str = 'train',
                           **kwargs) -> Dict[str, Any]:
        """
        Load a supported medical dataset and wrap it in a data loader.

        Inside the other methods of this class, the bare name ``load_dataset``
        resolves to the module-level ``datasets.load_dataset`` import, not to
        this method.

        Args:
            dataset_name: Key into ``SUPPORTED_DATASETS``.
            streaming: Request streaming mode. It is honoured only when the
                dataset's config sets ``streaming_supported``.
            subset: Optional subset/configuration name, forwarded to Hugging
                Face ``load_dataset`` as ``name``.
            split: Dataset split to load.
            **kwargs: Extra options. ``token`` takes precedence over the
                ``HF_TOKEN`` environment variable. ``batch_size`` and
                ``shuffle`` are read when building the data loader.

        Returns:
            Dict with ``dataset``, ``data_loader``, ``config`` (a deep copy of
            the dataset configuration), ``streaming`` (the mode actually used),
            ``split`` and ``estimated_size_gb``.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
            Exception: Any loading error is logged and re-raised.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        dataset_config = self.SUPPORTED_DATASETS[dataset_name]
        # Fall back to full mode when streaming is requested but unsupported.
        use_streaming = streaming and dataset_config['streaming_supported']

        with self.memory_manager.memory_context(f"load_dataset_{dataset_name}"):
            logger.info(f"Loading medical dataset: {dataset_config['name']}")

            try:
                # Get HF token
                hf_token = kwargs.get('token') or os.getenv('HF_TOKEN')

                # NOTE(review): the HF load calls below are synchronous and block
                # the event loop while they run.
                if use_streaming:
                    dataset = await self._load_streaming_dataset(
                        dataset_config, split, hf_token, subset=subset, **kwargs
                    )
                else:
                    dataset = await self._load_full_dataset(
                        dataset_config, split, hf_token, subset=subset, **kwargs
                    )

                data_loader = await self._create_medical_dataloader(
                    dataset, dataset_config, **kwargs
                )

                result = {
                    'dataset': dataset,
                    'data_loader': data_loader,
                    # Copy so callers cannot mutate the shared class-level config.
                    'config': copy.deepcopy(dataset_config),
                    'streaming': use_streaming,
                    'split': split,
                    'estimated_size_gb': dataset_config['size_gb']
                }

                self.loaded_datasets[dataset_name] = result
                return result

            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")
                raise

    def _call_hf_load_dataset(self, dataset_config: Dict[str, Any], split: str,
                              hf_token: Optional[str], subset: Optional[str],
                              streaming: bool):
        """Call Hugging Face ``load_dataset`` for a configured repository."""
        return load_dataset(
            dataset_config['repo_id'],
            name=subset,
            split=split,
            streaming=streaming,
            token=hf_token,
            cache_dir=str(self.cache_dir)
        )

    async def _load_streaming_dataset(self, dataset_config: Dict[str, Any],
                                      split: str, hf_token: Optional[str],
                                      *, subset: Optional[str] = None,
                                      **kwargs) -> HFIterableDataset:
        """Load a dataset in streaming mode; errors are logged and re-raised."""
        logger.info(f"Loading {dataset_config['name']} in streaming mode")

        try:
            dataset = self._call_hf_load_dataset(
                dataset_config, split, hf_token, subset, streaming=True
            )
        except Exception as e:
            logger.error(f"Failed to load streaming dataset: {e}")
            raise

        logger.info(f"Successfully loaded streaming dataset: {dataset_config['name']}")
        return dataset

    async def _load_full_dataset(self, dataset_config: Dict[str, Any],
                                 split: str, hf_token: Optional[str],
                                 *, subset: Optional[str] = None,
                                 **kwargs) -> HFDataset:
        """Load a full dataset, warning first if it may not fit in memory.

        Errors are logged and re-raised.
        """
        logger.info(f"Loading {dataset_config['name']} in full mode")

        # Check available memory (estimate: on-disk size plus 50% overhead).
        memory_info = self.memory_manager.get_memory_info()
        estimated_memory_needed_gb = dataset_config['size_gb'] * 1.5
        available_gb = memory_info['system_memory_available_gb']

        if estimated_memory_needed_gb > available_gb:
            logger.warning(
                "Dataset may exceed available memory (estimated %.1f GB needed, "
                "%.1f GB available). Consider streaming mode.",
                estimated_memory_needed_gb, available_gb
            )

        try:
            dataset = self._call_hf_load_dataset(
                dataset_config, split, hf_token, subset, streaming=False
            )
        except Exception as e:
            logger.error(f"Failed to load full dataset: {e}")
            raise

        logger.info(f"Successfully loaded full dataset: {dataset_config['name']}")
        return dataset

    async def _create_medical_dataloader(
            self, dataset, dataset_config: Dict[str, Any], **kwargs
    ) -> Union[DataLoader, 'MedicalStreamingDataLoader']:
        """Create a batching loader for a loaded dataset.

        Iterable (streaming) datasets get a ``MedicalStreamingDataLoader``.
        Map-style datasets get a torch ``DataLoader`` with ``drop_last=True``.

        Args:
            dataset: A Hugging Face ``Dataset`` or ``IterableDataset``.
            dataset_config: Entry from ``SUPPORTED_DATASETS``.
            **kwargs: ``batch_size`` (default 4) and ``shuffle`` (default True,
                map-style datasets only).
        """
        batch_size = kwargs.get('batch_size', 4)  # Small batch for memory efficiency
        # os.cpu_count() may return None; treat that as a single CPU.
        num_workers = min(2, (os.cpu_count() or 1) // 2)  # Conservative worker count

        # Shrink the batch size when little system memory is available.
        memory_info = self.memory_manager.get_memory_info()
        if memory_info['system_memory_available_gb'] < 4:
            batch_size = min(batch_size, 2)

        collate_fn = self._create_medical_collate_fn(dataset_config)

        # Detect streaming datasets by type. Map-style HF datasets also expose
        # an `.iter()` method, so a hasattr check cannot distinguish them.
        if isinstance(dataset, (HFIterableDataset, TorchIterableDataset)):
            return MedicalStreamingDataLoader(
                dataset, batch_size, collate_fn, self.memory_manager
            )

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=kwargs.get('shuffle', True),
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,  # CPU only
            drop_last=True
        )

    def _create_medical_collate_fn(self, dataset_config: Dict[str, Any]):
        """Return a picklable collate callable for the dataset's data format.

        Raises:
            KeyError: If ``dataset_config`` has no ``'data_format'`` entry.
        """
        return functools.partial(
            _medical_collate, data_format=dataset_config['data_format']
        )

    def get_dataset_info(self, dataset_name: str) -> Dict[str, Any]:
        """Return a deep copy of a supported dataset's configuration.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        return copy.deepcopy(self.SUPPORTED_DATASETS[dataset_name])

    def list_supported_datasets(self) -> List[Dict[str, Any]]:
        """List all supported datasets, each a config copy plus its ``'key'``."""
        return [
            {'key': key, **copy.deepcopy(config)}
            for key, config in self.SUPPORTED_DATASETS.items()
        ]

    async def preprocess_medical_batch(self, batch: Dict[str, Any],
                                       dataset_config: Dict[str, Any],
                                       max_image_size: int = 512,
                                       max_text_length: int = 512) -> Dict[str, Any]:
        """Downscale images and truncate texts in a collated batch.

        Args:
            batch: Output of the collate function (``images``/``texts`` keys).
            dataset_config: Dataset configuration (currently unused).
            max_image_size: If either spatial dim exceeds this, images are
                bilinearly resized to ``(max_image_size, max_image_size)``.
            max_text_length: Texts longer than this are cut to this many
                characters and suffixed with ``"..."``.

        Returns:
            Dict with ``images`` (when present and not None), ``texts`` (when
            present) and ``batch_size``.
        """
        processed_batch = {}

        # Handle images; interpolate expects a 4-D (N, C, H, W) tensor, which is
        # what the collate function stacks.
        if 'images' in batch and batch['images'] is not None:
            images = batch['images']
            if images.shape[-1] > max_image_size or images.shape[-2] > max_image_size:
                images = torch.nn.functional.interpolate(
                    images, size=(max_image_size, max_image_size),
                    mode='bilinear', align_corners=False
                )
            processed_batch['images'] = images

        # Handle texts: truncate long texts to save memory.
        if 'texts' in batch:
            processed_batch['texts'] = [
                text[:max_text_length] + "..." if len(text) > max_text_length else text
                for text in batch['texts']
            ]

        processed_batch['batch_size'] = batch.get('batch_size', 0)

        return processed_batch

    def cleanup_datasets(self):
        """Drop references to loaded datasets and run garbage collection."""
        logger.info("Cleaning up medical datasets")

        self.loaded_datasets.clear()
        self.streaming_datasets.clear()

        gc.collect()

        logger.info("Medical datasets cleanup completed")


class MedicalStreamingDataLoader:
    """Batching iterator over a streaming (iterable) medical dataset.

    Accumulates ``batch_size`` items and, before yielding each full batch,
    asks the memory manager for its status, forcing a cleanup when it reports
    'critical' or 'emergency'. A trailing partial batch is also yielded.
    """

    def __init__(self, dataset, batch_size: int, collate_fn, memory_manager):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.memory_manager = memory_manager

    def __iter__(self):
        batch = []

        for item in self.dataset:
            batch.append(item)

            if len(batch) >= self.batch_size:
                # Check memory before yielding batch
                status = self.memory_manager.check_memory_status()
                if status in ['critical', 'emergency']:
                    self.memory_manager.force_cleanup()

                yield self.collate_fn(batch)
                batch = []

        # Yield remaining items
        if batch:
            yield self.collate_fn(batch)