# train-modle / src / medical / dicom_handler.py
# fokan's picture
# Initial clean commit: Multi-Modal Knowledge Distillation Platform
# ab4e093
"""
DICOM Handler for medical image processing
Optimized for memory-constrained environments
"""
import os
import logging
import numpy as np
from typing import Dict, Any, Optional, Tuple, List
from pathlib import Path
import torch
from PIL import Image
import cv2
logger = logging.getLogger(__name__)

# Optional medical-imaging backends.  Each import is attempted once at module
# load; the *_AVAILABLE flags record the outcome so callers can branch on them
# instead of re-attempting the import.
try:
    import pydicom
    PYDICOM_AVAILABLE = True
except ImportError:
    # Keep the name bound so an accidental use fails with a clear
    # AttributeError on None rather than a NameError.
    pydicom = None
    PYDICOM_AVAILABLE = False
    logger.warning("pydicom not available - DICOM support limited")

try:
    import SimpleITK as sitk
    SIMPLEITK_AVAILABLE = True
except ImportError:
    sitk = None
    SIMPLEITK_AVAILABLE = False
    logger.warning("SimpleITK not available - advanced medical image processing limited")
class DicomHandler:
    """
    DICOM file handler with memory optimization.

    Reads DICOM files with pydicom (and SimpleITK, when installed, for files
    above the configured size limit), converts the pixel data to a normalized
    single-channel ``torch.Tensor`` and offers helpers for batch reading,
    export to standard image formats and summary statistics.
    """

    def __init__(self, memory_limit_mb: float = 1000.0):
        """
        Initialize DICOM handler

        Args:
            memory_limit_mb: Memory limit for DICOM processing in MB
        """
        self.memory_limit_mb = memory_limit_mb
        self.memory_limit_bytes = memory_limit_mb * 1024**2

        # Default DICOM processing settings
        self.default_window_center = 40
        self.default_window_width = 400
        # Interpreted as (height, width), matching numpy's (rows, cols) order.
        self.default_output_size = (512, 512)

        logger.info(f"DICOM Handler initialized with {memory_limit_mb}MB limit")
        logger.info(f"pydicom available: {PYDICOM_AVAILABLE}")
        logger.info(f"SimpleITK available: {SIMPLEITK_AVAILABLE}")

    def read_dicom_file(self, file_path: str) -> Optional[Dict[str, Any]]:
        """
        Read DICOM file and extract image data and metadata

        Files larger than ``memory_limit_mb`` are delegated to
        ``_read_large_dicom_file``.

        Args:
            file_path: Path to DICOM file

        Returns:
            Dictionary with the keys ``image`` (processed tensor),
            ``metadata``, ``original_shape``, ``file_path`` and
            ``file_size_mb``; ``None`` if pydicom is missing, the file does
            not exist, or reading fails (the error is logged).
        """
        if not PYDICOM_AVAILABLE:
            logger.error("pydicom not available - cannot read DICOM files")
            return None

        try:
            file_path = Path(file_path)
            if not file_path.exists():
                logger.error(f"DICOM file not found: {file_path}")
                return None

            # Check file size
            file_size_mb = file_path.stat().st_size / (1024**2)
            if file_size_mb > self.memory_limit_mb:
                logger.warning(f"DICOM file too large: {file_size_mb:.1f}MB > {self.memory_limit_mb}MB")
                result = self._read_large_dicom_file(file_path)
                if result is not None:
                    # Keep the result schema consistent with the normal path
                    # so size statistics also account for large files.
                    result.setdefault('file_size_mb', file_size_mb)
                return result

            dicom_data = pydicom.dcmread(str(file_path))
            image_array = dicom_data.pixel_array
            metadata = self._extract_dicom_metadata(dicom_data)
            processed_image = self._process_dicom_image(image_array, metadata)

            return {
                'image': processed_image,
                'metadata': metadata,
                'original_shape': image_array.shape,
                'file_path': str(file_path),
                'file_size_mb': file_size_mb
            }
        except Exception as e:
            logger.error(f"Error reading DICOM file {file_path}: {e}")
            return None

    def _read_large_dicom_file(self, file_path: Path) -> Optional[Dict[str, Any]]:
        """
        Read a DICOM file whose size exceeds the configured limit.

        The header is read first (without pixel data) to obtain metadata.
        Pixel data is then read via SimpleITK when available; otherwise it is
        read with pydicom and shrunk with ``cv2.resize`` if the array is
        larger than ``memory_limit_bytes``.

        Args:
            file_path: Path to the DICOM file

        Returns:
            Result dictionary (see ``read_dicom_file``), or ``None`` on error.
        """
        try:
            # Read only metadata first
            header = pydicom.dcmread(str(file_path), stop_before_pixels=True)
            metadata = self._extract_dicom_metadata(header)

            if SIMPLEITK_AVAILABLE:
                return self._read_dicom_with_sitk(file_path, metadata)

            # Fallback: full pydicom read, then reduce resolution if needed
            dicom_data = pydicom.dcmread(str(file_path))
            image_array = dicom_data.pixel_array
            original_shape = image_array.shape

            # Only the middle slice of a 3D array is used downstream, so pick
            # it before resizing; cv2.resize would otherwise treat the last
            # axis as channels and shape[0]/shape[1] would not be rows/cols.
            if image_array.ndim == 3:
                image_array = image_array[image_array.shape[0] // 2]

            # Downsample if too large
            if image_array.nbytes > self.memory_limit_bytes:
                scale_factor = np.sqrt(self.memory_limit_bytes / image_array.nbytes)
                new_shape = (max(1, int(image_array.shape[0] * scale_factor)),
                             max(1, int(image_array.shape[1] * scale_factor)))
                # cv2.resize expects dsize as (width, height), i.e. the
                # reverse of numpy's (rows, cols).
                image_array = cv2.resize(image_array, (new_shape[1], new_shape[0]))
                logger.info(f"Downsampled DICOM image to {new_shape}")

            processed_image = self._process_dicom_image(image_array, metadata)
            return {
                'image': processed_image,
                'metadata': metadata,
                'original_shape': original_shape,
                'file_path': str(file_path),
                'downsampled': True
            }
        except Exception as e:
            logger.error(f"Error reading large DICOM file: {e}")
            return None

    def _read_dicom_with_sitk(self, file_path: Path, metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Read DICOM pixel data using SimpleITK.

        Args:
            file_path: Path to the DICOM file
            metadata: Metadata previously extracted from the file header

        Returns:
            Result dictionary with ``reader`` set to ``'SimpleITK'``, or
            ``None`` on error.
        """
        try:
            image = sitk.ReadImage(str(file_path))
            image_array = sitk.GetArrayFromImage(image)

            # NOTE(review): SimpleITK's GDCM-based DICOM reader is expected to
            # return values with rescale slope/intercept already applied, so
            # neutralize them for processing to avoid rescaling twice. The
            # caller still receives the untouched metadata.
            processing_metadata = dict(metadata, rescale_slope=1.0, rescale_intercept=0.0)
            processed_image = self._process_dicom_image(image_array, processing_metadata)

            return {
                'image': processed_image,
                'metadata': metadata,
                'original_shape': image_array.shape,
                'file_path': str(file_path),
                'reader': 'SimpleITK'
            }
        except Exception as e:
            logger.error(f"Error reading DICOM with SimpleITK: {e}")
            return None

    @staticmethod
    def _first_float(value: Any, default: Any) -> Any:
        """Return ``value`` (or its first item if it is a sequence) as float, else ``default``."""
        if value is None:
            return default
        # Duck-typed sequence check: pydicom multi-valued elements are
        # sequences but not list/tuple subclasses.
        is_sequence = (not isinstance(value, (str, bytes))
                       and hasattr(value, '__getitem__')
                       and hasattr(value, '__len__'))
        if is_sequence:
            if len(value) == 0:
                return default
            value = value[0]
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _extract_dicom_metadata(self, dicom_data) -> Dict[str, Any]:
        """
        Extract relevant metadata from DICOM data

        Missing attributes fall back to ``'Unknown'``, ``0`` or the handler
        defaults. Window and rescale values are reduced to scalars.

        Args:
            dicom_data: Object exposing DICOM keywords as attributes

        Returns:
            Metadata dictionary; may be partially filled if extraction fails
            midway (a warning is logged).
        """
        metadata = {}
        try:
            # Patient information
            metadata['patient_id'] = getattr(dicom_data, 'PatientID', 'Unknown')
            metadata['patient_age'] = getattr(dicom_data, 'PatientAge', 'Unknown')
            metadata['patient_sex'] = getattr(dicom_data, 'PatientSex', 'Unknown')

            # Study information
            metadata['study_date'] = getattr(dicom_data, 'StudyDate', 'Unknown')
            metadata['study_description'] = getattr(dicom_data, 'StudyDescription', 'Unknown')
            metadata['modality'] = getattr(dicom_data, 'Modality', 'Unknown')

            # Image information
            metadata['rows'] = getattr(dicom_data, 'Rows', 0)
            metadata['columns'] = getattr(dicom_data, 'Columns', 0)
            metadata['pixel_spacing'] = getattr(dicom_data, 'PixelSpacing', [1.0, 1.0])
            metadata['slice_thickness'] = getattr(dicom_data, 'SliceThickness', 1.0)

            # Window/Level information for display (first value if several)
            metadata['window_center'] = self._first_float(
                getattr(dicom_data, 'WindowCenter', None), self.default_window_center)
            metadata['window_width'] = self._first_float(
                getattr(dicom_data, 'WindowWidth', None), self.default_window_width)

            # Linear rescale from stored values to the units the window
            # center/width refer to
            metadata['rescale_slope'] = self._first_float(
                getattr(dicom_data, 'RescaleSlope', None), 1.0)
            metadata['rescale_intercept'] = self._first_float(
                getattr(dicom_data, 'RescaleIntercept', None), 0.0)
        except Exception as e:
            logger.warning(f"Error extracting DICOM metadata: {e}")
        return metadata

    def _process_dicom_image(self, image_array: np.ndarray,
                             metadata: Dict[str, Any]) -> torch.Tensor:
        """
        Process DICOM image array to tensor

        Steps: pick the middle slice of a 3D array, apply the rescale
        slope/intercept from ``metadata`` (if any), apply windowing,
        normalize to 0-1, resize to ``default_output_size`` and add a
        leading channel dimension.

        Args:
            image_array: 2D image or 3D array indexed by slice on axis 0
            metadata: Metadata dict; ``window_center``, ``window_width``,
                ``rescale_slope`` and ``rescale_intercept`` are read

        Returns:
            Float tensor; on any processing error a zero tensor of shape
            ``(1, *default_output_size)`` is returned and the error logged.
        """
        try:
            # 3D volume - take middle slice for 2D processing
            if image_array.ndim == 3:
                image_array = image_array[image_array.shape[0] // 2]

            # Window values are defined on rescaled values, not on the raw
            # stored pixel values, so rescale before windowing.
            slope = metadata.get('rescale_slope', 1.0)
            intercept = metadata.get('rescale_intercept', 0.0)
            if slope != 1.0 or intercept != 0.0:
                image_array = image_array.astype(np.float32) * slope + intercept

            # Apply windowing for better contrast
            window_center = metadata.get('window_center', self.default_window_center)
            window_width = metadata.get('window_width', self.default_window_width)
            image_array = self._apply_windowing(image_array, window_center, window_width)

            # Normalize to 0-1 range
            image_array = self._normalize_image(image_array)

            # Resize to standard size; cv2.resize takes (width, height)
            target_height, target_width = self.default_output_size
            if image_array.shape != (target_height, target_width):
                image_array = cv2.resize(image_array, (target_width, target_height))

            image_tensor = torch.from_numpy(image_array).float()
            if image_tensor.dim() == 2:
                image_tensor = image_tensor.unsqueeze(0)  # Add channel dimension
            return image_tensor
        except Exception as e:
            logger.error(f"Error processing DICOM image: {e}")
            # Return dummy tensor on error
            return torch.zeros(1, *self.default_output_size)

    def _apply_windowing(self, image_array: np.ndarray,
                         window_center: float, window_width: float) -> np.ndarray:
        """
        Apply windowing to DICOM image for better contrast

        Values are clipped to
        ``[center - width / 2, center + width / 2]``.

        Args:
            image_array: Image values
            window_center: Center of the window
            window_width: Full width of the window

        Returns:
            Clipped array; the input array unchanged if clipping fails.
        """
        try:
            half_width = window_width / 2
            return np.clip(image_array, window_center - half_width, window_center + half_width)
        except Exception as e:
            logger.warning(f"Error applying windowing: {e}")
            return image_array

    def _normalize_image(self, image_array: np.ndarray) -> np.ndarray:
        """
        Normalize image to 0-1 range

        ``uint8`` and ``uint16`` arrays are divided by their type maximum;
        every other dtype is min-max scaled. A constant image maps to zeros.

        Args:
            image_array: Image values

        Returns:
            ``float32`` array; on error the input cast to ``float32``.
        """
        try:
            if image_array.dtype == np.uint8:
                return image_array.astype(np.float32) / 255.0
            if image_array.dtype == np.uint16:
                return image_array.astype(np.float32) / 65535.0

            # Promote before subtracting: (x - min) and (max - min) overflow
            # when computed in a narrow integer dtype such as int16.
            values = image_array.astype(np.float32, copy=False)
            img_min = values.min()
            img_max = values.max()
            if img_max > img_min:
                return ((values - img_min) / (img_max - img_min)).astype(np.float32, copy=False)
            return np.zeros_like(image_array, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error normalizing image: {e}")
            return image_array.astype(np.float32)

    def batch_process_dicom_files(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Process multiple DICOM files with memory management

        Files that cannot be read are skipped. A garbage collection pass is
        triggered after every 10 files.

        Args:
            file_paths: Paths of the DICOM files to read

        Returns:
            Results of ``read_dicom_file`` for every file read successfully,
            in input order.
        """
        import gc

        results = []
        total = len(file_paths)
        for index, file_path in enumerate(file_paths, start=1):
            logger.info(f"Processing DICOM file {index}/{total}: {file_path}")
            result = self.read_dicom_file(file_path)
            if result:
                results.append(result)

            # Memory cleanup every 10 files
            if index % 10 == 0:
                gc.collect()
                logger.debug(f"Memory cleanup after {index} files")
        return results

    def convert_dicom_to_standard_format(self, dicom_result: Dict[str, Any],
                                         output_format: str = 'png') -> Optional[str]:
        """
        Convert processed DICOM to standard image format

        The image is written next to the source file, with the file suffix
        replaced by ``output_format``.

        Args:
            dicom_result: Result dictionary providing ``image`` (values in
                the 0-1 range) and ``file_path``
            output_format: Target extension, with or without a leading dot

        Returns:
            Path of the written file, or ``None`` on error.
        """
        try:
            image_tensor = dicom_result['image']
            if isinstance(image_tensor, torch.Tensor):
                # detach/cpu so tensors tracked by autograd or living on an
                # accelerator can be converted as well
                image_array = image_tensor.detach().cpu().squeeze().numpy()
            else:
                image_array = np.asarray(image_tensor)

            # Convert to 8-bit; clip first so out-of-range values saturate
            # instead of wrapping around in the uint8 cast
            image_8bit = (np.clip(image_array, 0.0, 1.0) * 255).astype(np.uint8)

            # A 2D uint8 array is interpreted as grayscale ('L'), so an
            # explicit mode argument is unnecessary
            pil_image = Image.fromarray(image_8bit)

            suffix = output_format if output_format.startswith('.') else f'.{output_format}'
            output_path = Path(dicom_result['file_path']).with_suffix(suffix)

            pil_image.save(output_path)
            logger.info(f"Converted DICOM to {output_format}: {output_path}")
            return str(output_path)
        except Exception as e:
            logger.error(f"Error converting DICOM to {output_format}: {e}")
            return None

    def get_dicom_statistics(self, dicom_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get statistics from processed DICOM files

        Args:
            dicom_results: Result dictionaries as returned by
                ``read_dicom_file``; a missing ``file_size_mb`` counts as 0

        Returns:
            Dictionary with ``total_files``, ``modalities`` (unique, in
            first-seen order), ``modality_counts``, ``total_size_mb``,
            ``average_size_mb`` and ``size_range_mb``; empty dict for empty
            input or on error.
        """
        from collections import Counter

        if not dicom_results:
            return {}

        try:
            modality_counts = Counter(
                r['metadata'].get('modality', 'Unknown') for r in dicom_results
            )
            file_sizes = [r.get('file_size_mb', 0) for r in dicom_results]

            return {
                'total_files': len(dicom_results),
                'modalities': list(modality_counts),
                'modality_counts': dict(modality_counts),
                'total_size_mb': sum(file_sizes),
                # plain float so the result stays JSON-serializable
                'average_size_mb': float(np.mean(file_sizes)),
                'size_range_mb': (min(file_sizes), max(file_sizes))
            }
        except Exception as e:
            logger.error(f"Error calculating DICOM statistics: {e}")
            return {}