train-modle / src /medical /medical_datasets.py
fokan's picture
Initial clean commit: Multi-Modal Knowledge Distillation Platform
ab4e093
"""
Medical Dataset Manager for handling specialized medical datasets
Optimized for memory-constrained environments with streaming support
"""
import os
import logging
import asyncio
from typing import Dict, Any, List, Optional, Iterator, Tuple
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset as HFDataset
import numpy as np
from PIL import Image
import json
from ..core.memory_manager import AdvancedMemoryManager
logger = logging.getLogger(__name__)
class MedicalDatasetManager:
    """
    Manager for medical datasets with memory-efficient streaming.

    Wraps Hugging Face ``load_dataset`` for a fixed catalogue of medical
    datasets (``SUPPORTED_DATASETS``), builds a batching loader with a
    medical-aware collate function, and keeps track of loaded datasets.
    """

    # Supported medical datasets configuration.
    # Static metadata: sizes and sample counts are hard-coded here, not queried
    # from the Hub. ``size_gb`` drives the memory warning in full-load mode.
    SUPPORTED_DATASETS = {
        'roco_v2': {
            'name': 'ROCOv2 Radiology',
            'repo_id': 'eltorio/ROCOv2-radiology',
            # "Radiographic images with detailed medical reports"
            'description': 'صور شعاعية مع تقارير طبية مفصلة',
            'modalities': ['radiology', 'text'],
            'size_gb': 8.5,
            'num_samples': 81000,
            'languages': ['en', 'ar'],
            'medical_specialties': ['radiology', 'general'],
            'data_format': 'image_text_pairs',
            'streaming_supported': True
        },
        'ct_rate': {
            'name': 'CT-RATE',
            'repo_id': 'ibrahimhamamci/CT-RATE',
            # "CT images with assessments and diagnoses"
            'description': 'صور CT مع تقييمات وتشخيصات',
            'modalities': ['ct_scan', 'text'],
            'size_gb': 12.3,
            'num_samples': 50000,
            'languages': ['en'],
            'medical_specialties': ['radiology', 'emergency', 'internal_medicine'],
            'data_format': 'image_text_pairs',
            'streaming_supported': True
        },
        'umie_datasets': {
            'name': 'UMIE Medical Datasets',
            'repo_id': 'lion-ai/umie_datasets',
            # "Diverse, multimodal medical data"
            'description': 'بيانات طبية متنوعة ومتعددة الوسائط',
            'modalities': ['multimodal', 'text', 'imaging'],
            'size_gb': 15.7,
            'num_samples': 120000,
            'languages': ['en', 'ar', 'fr'],
            'medical_specialties': ['general', 'cardiology', 'neurology', 'oncology'],
            'data_format': 'multimodal',
            'streaming_supported': True
        }
    }

    def __init__(self, memory_manager: AdvancedMemoryManager,
                 cache_dir: str = "cache/medical_datasets"):
        """
        Initialize the medical dataset manager.

        Args:
            memory_manager: Memory manager used for memory contexts, memory
                queries and cleanup.
            cache_dir: Directory for the Hugging Face dataset cache; created
                (with parents) if missing.
        """
        self.memory_manager = memory_manager
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # dataset_name -> result dict returned by load_dataset().
        self.loaded_datasets = {}
        # Not populated anywhere in this class; only cleared in cleanup_datasets().
        self.streaming_datasets = {}
        logger.info("Medical Dataset Manager initialized")

    async def load_dataset(self, dataset_name: str,
                           streaming: bool = True,
                           subset: Optional[str] = None,
                           split: str = 'train',
                           **kwargs) -> Dict[str, Any]:
        """
        Load a supported medical dataset with memory optimization.

        Args:
            dataset_name: Key into ``SUPPORTED_DATASETS``.
            streaming: Request streaming mode. Honoured only when the dataset's
                config has ``streaming_supported``; otherwise the full split
                is loaded.
            subset: Dataset configuration name, forwarded to Hugging Face
                ``load_dataset`` as ``name`` (``None`` selects the default).
            split: Dataset split to load.
            **kwargs: Extra options. ``token`` overrides the ``HF_TOKEN``
                environment variable; ``batch_size`` and ``shuffle`` are read
                when building the data loader.

        Returns:
            Dict with the dataset, its data loader, the catalogue config, the
            streaming mode actually used, the split, the subset and the
            estimated size in GB. The dict is also stored in
            ``self.loaded_datasets[dataset_name]``.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
            Exception: Errors from loading are logged and re-raised unchanged.

        NOTE: the underlying Hugging Face ``load_dataset`` call is synchronous,
        so awaiting this coroutine blocks the event loop while data is fetched.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        dataset_config = self.SUPPORTED_DATASETS[dataset_name]
        # Record the mode actually used, not merely the one requested.
        use_streaming = bool(streaming and dataset_config['streaming_supported'])

        with self.memory_manager.memory_context(f"load_dataset_{dataset_name}"):
            logger.info(f"Loading medical dataset: {dataset_config['name']}")
            try:
                # An explicit token takes precedence over the environment.
                hf_token = kwargs.get('token') or os.getenv('HF_TOKEN')
                if use_streaming:
                    dataset = await self._load_streaming_dataset(
                        dataset_config, split, hf_token, subset=subset, **kwargs
                    )
                else:
                    # Full load (memory is checked and a warning logged if tight).
                    dataset = await self._load_full_dataset(
                        dataset_config, split, hf_token, subset=subset, **kwargs
                    )

                data_loader = await self._create_medical_dataloader(
                    dataset, dataset_config, **kwargs
                )

                result = {
                    'dataset': dataset,
                    'data_loader': data_loader,
                    'config': dataset_config,
                    'streaming': use_streaming,
                    'split': split,
                    'subset': subset,
                    'estimated_size_gb': dataset_config['size_gb']
                }
                self.loaded_datasets[dataset_name] = result
                return result
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")
                raise

    async def _load_streaming_dataset(self, dataset_config: Dict[str, Any],
                                      split: str, hf_token: Optional[str],
                                      subset: Optional[str] = None,
                                      **kwargs) -> HFDataset:
        """
        Load a dataset lazily with ``streaming=True``.

        Args:
            dataset_config: Catalogue entry (uses ``repo_id`` and ``name``).
            split: Split to load.
            hf_token: Hugging Face token, or ``None``.
            subset: Configuration name passed to ``load_dataset`` as ``name``.
            **kwargs: Accepted for call-site compatibility; unused.

        Returns:
            The object returned by Hugging Face ``load_dataset`` in streaming
            mode (an iterable dataset).
        """
        logger.info(f"Loading {dataset_config['name']} in streaming mode")
        try:
            # Module-level Hugging Face load_dataset, not this class's method.
            dataset = load_dataset(
                dataset_config['repo_id'],
                name=subset,
                split=split,
                streaming=True,
                token=hf_token,
                cache_dir=str(self.cache_dir)
            )
        except Exception as e:
            logger.error(f"Failed to load streaming dataset: {e}")
            raise
        logger.info(f"Successfully loaded streaming dataset: {dataset_config['name']}")
        return dataset

    async def _load_full_dataset(self, dataset_config: Dict[str, Any],
                                 split: str, hf_token: Optional[str],
                                 subset: Optional[str] = None,
                                 **kwargs) -> HFDataset:
        """
        Load the whole split as a map-style dataset.

        Logs a warning (but still loads) when ``size_gb`` plus 50% overhead
        exceeds the available system memory reported by the memory manager.

        Args:
            dataset_config: Catalogue entry (uses ``repo_id``, ``name``, ``size_gb``).
            split: Split to load.
            hf_token: Hugging Face token, or ``None``.
            subset: Configuration name passed to ``load_dataset`` as ``name``.
            **kwargs: Accepted for call-site compatibility; unused.
        """
        logger.info(f"Loading {dataset_config['name']} in full mode")
        memory_info = self.memory_manager.get_memory_info()
        estimated_memory_needed_gb = dataset_config['size_gb'] * 1.5  # 50% overhead
        if estimated_memory_needed_gb > memory_info['system_memory_available_gb']:
            logger.warning("Dataset may exceed available memory. Consider streaming mode.")
        try:
            # Module-level Hugging Face load_dataset, not this class's method.
            dataset = load_dataset(
                dataset_config['repo_id'],
                name=subset,
                split=split,
                streaming=False,
                token=hf_token,
                cache_dir=str(self.cache_dir)
            )
        except Exception as e:
            logger.error(f"Failed to load full dataset: {e}")
            raise
        logger.info(f"Successfully loaded full dataset: {dataset_config['name']}")
        return dataset

    async def _create_medical_dataloader(self, dataset: HFDataset,
                                         dataset_config: Dict[str, Any],
                                         **kwargs) -> DataLoader:
        """
        Create a batching loader suited to the dataset type.

        Iterable (streaming) datasets get a ``MedicalStreamingDataLoader``.
        Map-style datasets get a torch ``DataLoader`` with shuffling (unless
        ``shuffle=False``), up to 2 workers and ``drop_last=True``.

        Args:
            dataset: Dataset returned by one of the loading helpers.
            dataset_config: Catalogue entry, used to build the collate function.
            **kwargs: ``batch_size`` (default 4) and ``shuffle`` (default True).
        """
        from datasets import IterableDataset as HFIterableDataset

        batch_size = kwargs.get('batch_size', 4)  # Small batch for memory efficiency
        # os.cpu_count() may return None; fall back to 1 so the division is valid.
        num_workers = min(2, (os.cpu_count() or 1) // 2)  # Conservative worker count

        # Shrink batches when less than 4 GB of system memory is available.
        memory_info = self.memory_manager.get_memory_info()
        if memory_info['system_memory_available_gb'] < 4:
            batch_size = min(batch_size, 2)

        collate_fn = self._create_medical_collate_fn(dataset_config)

        # Detect streaming datasets by type. Map-style HF ``Dataset`` objects
        # also expose an ``iter`` method, so ``hasattr(dataset, 'iter')`` would
        # misroute them here and silently drop shuffling/workers.
        if isinstance(dataset, (HFIterableDataset, torch.utils.data.IterableDataset)):
            return MedicalStreamingDataLoader(
                dataset, batch_size, collate_fn, self.memory_manager
            )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=kwargs.get('shuffle', True),
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,  # CPU only
            drop_last=True
        )

    def _create_medical_collate_fn(self, dataset_config: Dict[str, Any]):
        """
        Build the collate function for a dataset.

        For ``image_text_pairs`` datasets the returned function produces
        ``{'images': Tensor or None, 'texts': List[str], 'batch_size': int}``.
        For any other format it passes the raw items through as
        ``{'data': batch, 'batch_size': int}``. On any exception it logs and
        returns the raw items plus an ``'error'`` message instead of raising.
        """
        def medical_collate_fn(batch):
            """Custom collate function for medical datasets"""
            try:
                if dataset_config['data_format'] == 'image_text_pairs':
                    images = []
                    texts = []
                    for item in batch:
                        # Only PIL images are converted; other image values are skipped.
                        if 'image' in item:
                            image = item['image']
                            if isinstance(image, Image.Image):
                                # Array (H, W[, C]) -> tensor (C, H, W), scaled by 1/255.
                                # NOTE(review): /255 assumes 8-bit pixels; 16-bit
                                # medical images would exceed 1.0 — confirm inputs.
                                image_array = np.array(image)
                                if len(image_array.shape) == 3:
                                    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).float() / 255.0
                                else:
                                    image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() / 255.0
                                images.append(image_tensor)
                        # Text priority: 'text', then 'caption', then 'report'.
                        if 'text' in item or 'caption' in item or 'report' in item:
                            text = item.get('text', item.get('caption', item.get('report', '')))
                            texts.append(str(text))
                    # NOTE(review): torch.stack needs identical shapes, so a batch
                    # mixing image sizes/channel counts falls into the except
                    # branch. images and texts are gathered independently and
                    # can differ in length.
                    return {
                        'images': torch.stack(images) if images else None,
                        'texts': texts,
                        'batch_size': len(batch)
                    }
                # Generic multimodal handling: pass items through untouched.
                return {
                    'data': batch,
                    'batch_size': len(batch)
                }
            except Exception as e:
                logger.error(f"Error in collate function: {e}")
                # Return minimal batch on error
                return {
                    'data': batch,
                    'batch_size': len(batch),
                    'error': str(e)
                }
        return medical_collate_fn

    def get_dataset_info(self, dataset_name: str) -> Dict[str, Any]:
        """
        Return a copy of the catalogue entry for a supported dataset.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        import copy
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        # Deep copy: a shallow copy would share nested lists (e.g. 'modalities')
        # with the class-level catalogue, letting callers mutate it.
        return copy.deepcopy(self.SUPPORTED_DATASETS[dataset_name])

    def list_supported_datasets(self) -> List[Dict[str, Any]]:
        """List all supported medical datasets as copies with an added 'key' field."""
        import copy
        return [
            {'key': key, **copy.deepcopy(config)}
            for key, config in self.SUPPORTED_DATASETS.items()
        ]

    async def preprocess_medical_batch(self, batch: Dict[str, Any],
                                       dataset_config: Dict[str, Any],
                                       max_image_size: int = 512,
                                       max_text_length: int = 512) -> Dict[str, Any]:
        """
        Downsize images and truncate texts in a collated batch.

        Args:
            batch: Output of the medical collate function.
            dataset_config: Catalogue entry; currently unused.
            max_image_size: If either spatial dimension (last two axes)
                exceeds this, images are bilinearly resized to a
                ``max_image_size`` x ``max_image_size`` square.
            max_text_length: Texts longer than this are cut to this many
                characters and suffixed with ``"..."``.

        Returns:
            Dict with ``'images'`` (if present), ``'texts'`` (if present) and
            ``'batch_size'`` (0 when missing from ``batch``).
        """
        processed_batch = {}

        images = batch.get('images') if 'images' in batch else None
        if images is not None:
            # interpolate with mode='bilinear' expects (N, C, H, W), as produced by collate.
            if images.shape[-1] > max_image_size or images.shape[-2] > max_image_size:
                images = torch.nn.functional.interpolate(
                    images, size=(max_image_size, max_image_size),
                    mode='bilinear', align_corners=False
                )
            processed_batch['images'] = images

        if 'texts' in batch:
            processed_batch['texts'] = [
                text[:max_text_length] + "..." if len(text) > max_text_length else text
                for text in batch['texts']
            ]

        processed_batch['batch_size'] = batch.get('batch_size', 0)
        return processed_batch

    def cleanup_datasets(self):
        """Drop references to loaded datasets and force a garbage collection."""
        import gc
        logger.info("Cleaning up medical datasets")
        self.loaded_datasets.clear()
        self.streaming_datasets.clear()
        gc.collect()
        logger.info("Medical datasets cleanup completed")
class MedicalStreamingDataLoader:
    """
    Batch an iterable (streaming) dataset without materializing it.

    Items are grouped into lists of ``batch_size`` and handed to
    ``collate_fn``. Just before each full batch is emitted, the memory manager
    is polled and told to clean up when it reports a 'critical' or
    'emergency' state. A trailing short batch, if any, is emitted without
    that check.
    """

    def __init__(self, dataset, batch_size: int, collate_fn, memory_manager):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.memory_manager = memory_manager

    def _relieve_memory_pressure(self):
        """Force a cleanup if the memory manager reports a critical state."""
        current = self.memory_manager.check_memory_status()
        if current in ('critical', 'emergency'):
            self.memory_manager.force_cleanup()

    def __iter__(self):
        buffer = []
        for record in self.dataset:
            buffer.append(record)
            batch_full = len(buffer) >= self.batch_size
            if not batch_full:
                continue
            self._relieve_memory_pressure()
            ready, buffer = buffer, []
            yield self.collate_fn(ready)
        # Flush whatever did not fill a complete batch.
        if buffer:
            yield self.collate_fn(buffer)