Spaces:
Running
Running
""" | |
Medical Dataset Manager for handling specialized medical datasets | |
Optimized for memory-constrained environments with streaming support | |
""" | |
import os | |
import logging | |
import asyncio | |
from typing import Dict, Any, List, Optional, Iterator, Tuple | |
from pathlib import Path | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from datasets import load_dataset, Dataset as HFDataset | |
import numpy as np | |
from PIL import Image | |
import json | |
from ..core.memory_manager import AdvancedMemoryManager | |
logger = logging.getLogger(__name__) | |
class MedicalDatasetManager: | |
""" | |
Manager for medical datasets with memory-efficient streaming | |
""" | |
# Supported medical datasets configuration | |
SUPPORTED_DATASETS = { | |
'roco_v2': { | |
'name': 'ROCOv2 Radiology', | |
'repo_id': 'eltorio/ROCOv2-radiology', | |
'description': 'صور شعاعية مع تقارير طبية مفصلة', | |
'modalities': ['radiology', 'text'], | |
'size_gb': 8.5, | |
'num_samples': 81000, | |
'languages': ['en', 'ar'], | |
'medical_specialties': ['radiology', 'general'], | |
'data_format': 'image_text_pairs', | |
'streaming_supported': True | |
}, | |
'ct_rate': { | |
'name': 'CT-RATE', | |
'repo_id': 'ibrahimhamamci/CT-RATE', | |
'description': 'صور CT مع تقييمات وتشخيصات', | |
'modalities': ['ct_scan', 'text'], | |
'size_gb': 12.3, | |
'num_samples': 50000, | |
'languages': ['en'], | |
'medical_specialties': ['radiology', 'emergency', 'internal_medicine'], | |
'data_format': 'image_text_pairs', | |
'streaming_supported': True | |
}, | |
'umie_datasets': { | |
'name': 'UMIE Medical Datasets', | |
'repo_id': 'lion-ai/umie_datasets', | |
'description': 'بيانات طبية متنوعة ومتعددة الوسائط', | |
'modalities': ['multimodal', 'text', 'imaging'], | |
'size_gb': 15.7, | |
'num_samples': 120000, | |
'languages': ['en', 'ar', 'fr'], | |
'medical_specialties': ['general', 'cardiology', 'neurology', 'oncology'], | |
'data_format': 'multimodal', | |
'streaming_supported': True | |
} | |
} | |
def __init__(self, memory_manager: AdvancedMemoryManager, | |
cache_dir: str = "cache/medical_datasets"): | |
""" | |
Initialize medical dataset manager | |
Args: | |
memory_manager: Memory manager instance | |
cache_dir: Directory for caching datasets | |
""" | |
self.memory_manager = memory_manager | |
self.cache_dir = Path(cache_dir) | |
self.cache_dir.mkdir(parents=True, exist_ok=True) | |
self.loaded_datasets = {} | |
self.streaming_datasets = {} | |
logger.info("Medical Dataset Manager initialized") | |
async def load_dataset(self, dataset_name: str, | |
streaming: bool = True, | |
subset: Optional[str] = None, | |
split: str = 'train', | |
**kwargs) -> Dict[str, Any]: | |
""" | |
Load medical dataset with memory optimization | |
Args: | |
dataset_name: Name of dataset to load | |
streaming: Whether to use streaming mode | |
subset: Specific subset to load | |
split: Dataset split to load | |
**kwargs: Additional loading parameters | |
Returns: | |
Dataset information and loader | |
""" | |
if dataset_name not in self.SUPPORTED_DATASETS: | |
raise ValueError(f"Unsupported dataset: {dataset_name}") | |
dataset_config = self.SUPPORTED_DATASETS[dataset_name] | |
with self.memory_manager.memory_context(f"load_dataset_{dataset_name}"): | |
logger.info(f"Loading medical dataset: {dataset_config['name']}") | |
try: | |
# Get HF token | |
hf_token = kwargs.get('token') or os.getenv('HF_TOKEN') | |
if streaming and dataset_config['streaming_supported']: | |
# Load in streaming mode | |
dataset = await self._load_streaming_dataset( | |
dataset_config, split, hf_token, **kwargs | |
) | |
else: | |
# Load full dataset (with memory management) | |
dataset = await self._load_full_dataset( | |
dataset_config, split, hf_token, **kwargs | |
) | |
# Create data loader | |
data_loader = await self._create_medical_dataloader( | |
dataset, dataset_config, **kwargs | |
) | |
result = { | |
'dataset': dataset, | |
'data_loader': data_loader, | |
'config': dataset_config, | |
'streaming': streaming, | |
'split': split, | |
'estimated_size_gb': dataset_config['size_gb'] | |
} | |
self.loaded_datasets[dataset_name] = result | |
return result | |
except Exception as e: | |
logger.error(f"Failed to load dataset {dataset_name}: {e}") | |
raise | |
async def _load_streaming_dataset(self, dataset_config: Dict[str, Any], | |
split: str, hf_token: Optional[str], | |
**kwargs) -> HFDataset: | |
"""Load dataset in streaming mode""" | |
logger.info(f"Loading {dataset_config['name']} in streaming mode") | |
try: | |
dataset = load_dataset( | |
dataset_config['repo_id'], | |
split=split, | |
streaming=True, | |
token=hf_token, | |
cache_dir=str(self.cache_dir) | |
) | |
logger.info(f"Successfully loaded streaming dataset: {dataset_config['name']}") | |
return dataset | |
except Exception as e: | |
logger.error(f"Failed to load streaming dataset: {e}") | |
raise | |
async def _load_full_dataset(self, dataset_config: Dict[str, Any], | |
split: str, hf_token: Optional[str], | |
**kwargs) -> HFDataset: | |
"""Load full dataset with memory management""" | |
logger.info(f"Loading {dataset_config['name']} in full mode") | |
# Check available memory | |
memory_info = self.memory_manager.get_memory_info() | |
estimated_memory_needed_gb = dataset_config['size_gb'] * 1.5 # 50% overhead | |
if estimated_memory_needed_gb > memory_info['system_memory_available_gb']: | |
logger.warning(f"Dataset may exceed available memory. Consider streaming mode.") | |
try: | |
dataset = load_dataset( | |
dataset_config['repo_id'], | |
split=split, | |
streaming=False, | |
token=hf_token, | |
cache_dir=str(self.cache_dir) | |
) | |
logger.info(f"Successfully loaded full dataset: {dataset_config['name']}") | |
return dataset | |
except Exception as e: | |
logger.error(f"Failed to load full dataset: {e}") | |
raise | |
async def _create_medical_dataloader(self, dataset: HFDataset, | |
dataset_config: Dict[str, Any], | |
**kwargs) -> DataLoader: | |
"""Create optimized DataLoader for medical data""" | |
batch_size = kwargs.get('batch_size', 4) # Small batch for memory efficiency | |
num_workers = min(2, os.cpu_count() // 2) # Conservative worker count | |
# Optimize batch size based on available memory | |
memory_info = self.memory_manager.get_memory_info() | |
if memory_info['system_memory_available_gb'] < 4: | |
batch_size = min(batch_size, 2) | |
# Create custom collate function for medical data | |
collate_fn = self._create_medical_collate_fn(dataset_config) | |
# For streaming datasets, we need a different approach | |
if hasattr(dataset, 'iter'): | |
# Streaming dataset | |
return MedicalStreamingDataLoader( | |
dataset, batch_size, collate_fn, self.memory_manager | |
) | |
else: | |
# Regular dataset | |
return DataLoader( | |
dataset, | |
batch_size=batch_size, | |
shuffle=kwargs.get('shuffle', True), | |
num_workers=num_workers, | |
collate_fn=collate_fn, | |
pin_memory=False, # CPU only | |
drop_last=True | |
) | |
def _create_medical_collate_fn(self, dataset_config: Dict[str, Any]): | |
"""Create collate function for medical data""" | |
def medical_collate_fn(batch): | |
"""Custom collate function for medical datasets""" | |
try: | |
if dataset_config['data_format'] == 'image_text_pairs': | |
images = [] | |
texts = [] | |
for item in batch: | |
# Handle image data | |
if 'image' in item: | |
image = item['image'] | |
if isinstance(image, Image.Image): | |
# Convert PIL image to tensor | |
image_array = np.array(image) | |
if len(image_array.shape) == 3: | |
image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).float() / 255.0 | |
else: | |
image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() / 255.0 | |
images.append(image_tensor) | |
# Handle text data | |
if 'text' in item or 'caption' in item or 'report' in item: | |
text = item.get('text', item.get('caption', item.get('report', ''))) | |
texts.append(str(text)) | |
return { | |
'images': torch.stack(images) if images else None, | |
'texts': texts, | |
'batch_size': len(batch) | |
} | |
else: | |
# Generic multimodal handling | |
return { | |
'data': batch, | |
'batch_size': len(batch) | |
} | |
except Exception as e: | |
logger.error(f"Error in collate function: {e}") | |
# Return minimal batch on error | |
return { | |
'data': batch, | |
'batch_size': len(batch), | |
'error': str(e) | |
} | |
return medical_collate_fn | |
def get_dataset_info(self, dataset_name: str) -> Dict[str, Any]: | |
"""Get information about a supported dataset""" | |
if dataset_name not in self.SUPPORTED_DATASETS: | |
raise ValueError(f"Unsupported dataset: {dataset_name}") | |
return self.SUPPORTED_DATASETS[dataset_name].copy() | |
def list_supported_datasets(self) -> List[Dict[str, Any]]: | |
"""List all supported medical datasets""" | |
return [ | |
{ | |
'key': key, | |
**config | |
} | |
for key, config in self.SUPPORTED_DATASETS.items() | |
] | |
async def preprocess_medical_batch(self, batch: Dict[str, Any], | |
dataset_config: Dict[str, Any]) -> Dict[str, Any]: | |
"""Preprocess medical data batch""" | |
processed_batch = {} | |
# Handle images | |
if 'images' in batch and batch['images'] is not None: | |
images = batch['images'] | |
# Resize images to standard size for memory efficiency | |
if images.shape[-1] > 512 or images.shape[-2] > 512: | |
images = torch.nn.functional.interpolate( | |
images, size=(512, 512), mode='bilinear', align_corners=False | |
) | |
processed_batch['images'] = images | |
# Handle texts | |
if 'texts' in batch: | |
texts = batch['texts'] | |
# Truncate long texts to save memory | |
max_length = 512 | |
truncated_texts = [] | |
for text in texts: | |
if len(text) > max_length: | |
text = text[:max_length] + "..." | |
truncated_texts.append(text) | |
processed_batch['texts'] = truncated_texts | |
processed_batch['batch_size'] = batch.get('batch_size', 0) | |
return processed_batch | |
def cleanup_datasets(self): | |
"""Cleanup loaded datasets to free memory""" | |
logger.info("Cleaning up medical datasets") | |
for dataset_name in list(self.loaded_datasets.keys()): | |
del self.loaded_datasets[dataset_name] | |
self.loaded_datasets.clear() | |
self.streaming_datasets.clear() | |
# Force garbage collection | |
import gc | |
gc.collect() | |
logger.info("Medical datasets cleanup completed") | |
class MedicalStreamingDataLoader: | |
"""Custom streaming data loader for medical datasets""" | |
def __init__(self, dataset, batch_size: int, collate_fn, memory_manager): | |
self.dataset = dataset | |
self.batch_size = batch_size | |
self.collate_fn = collate_fn | |
self.memory_manager = memory_manager | |
def __iter__(self): | |
batch = [] | |
for item in self.dataset: | |
batch.append(item) | |
if len(batch) >= self.batch_size: | |
# Check memory before yielding batch | |
status = self.memory_manager.check_memory_status() | |
if status in ['critical', 'emergency']: | |
self.memory_manager.force_cleanup() | |
yield self.collate_fn(batch) | |
batch = [] | |
# Yield remaining items | |
if batch: | |
yield self.collate_fn(batch) | |