""" | |
Medical Data Preprocessing for AI training | |
Optimized for medical images and text with memory constraints | |
""" | |
import logging | |
import numpy as np | |
from typing import Dict, Any, List, Optional, Tuple | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image, ImageEnhance, ImageFilter | |
import cv2 | |
import re | |
logger = logging.getLogger(__name__) | |

class MedicalPreprocessor:
    """
    Medical data preprocessor with memory optimization
    """

    def __init__(self, target_size: Tuple[int, int] = (512, 512),
                 normalize_images: bool = True):
        """
        Initialize medical preprocessor

        Args:
            target_size: Target size for image resizing
            normalize_images: Whether to normalize images
        """
        self.target_size = target_size
        self.normalize_images = normalize_images

        # Medical text preprocessing patterns
        self.medical_patterns = {
            'measurements': r'\d+\.?\d*\s*(mm|cm|m|ml|l|kg|g|mg)',
            'dates': r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}',
            'times': r'\d{1,2}:\d{2}(?::\d{2})?',
            'medical_codes': r'[A-Z]\d{2}\.?\d*',
            'dosages': r'\d+\.?\d*\s*(mg|g|ml|units?)',
        }

        # Common medical abbreviations
        self.medical_abbreviations = {
            'pt': 'patient',
            'pts': 'patients',
            'dx': 'diagnosis',
            'tx': 'treatment',
            'hx': 'history',
            'sx': 'symptoms',
            'rx': 'prescription',
            'w/': 'with',
            'w/o': 'without',
            'c/o': 'complains of',
            'r/o': 'rule out',
            's/p': 'status post',
            'nkda': 'no known drug allergies',
            'sob': 'shortness of breath',
            'cp': 'chest pain',
            'abd': 'abdomen',
            'ext': 'extremities'
        }

        logger.info(f"Medical Preprocessor initialized with target size {target_size}")

    def preprocess_medical_image(self, image: torch.Tensor,
                                 modality: str = 'unknown',
                                 enhance_contrast: bool = True) -> torch.Tensor:
        """
        Preprocess medical image with modality-specific optimizations

        Args:
            image: Input image tensor
            modality: Medical imaging modality (CT, MRI, X-ray, etc.)
            enhance_contrast: Whether to enhance contrast

        Returns:
            Preprocessed image tensor
        """
        try:
            # Ensure image is a float tensor
            if image.dtype != torch.float32:
                image = image.float()

            # Handle different input shapes
            if len(image.shape) == 2:
                image = image.unsqueeze(0)  # Add channel dimension
            elif len(image.shape) == 4:
                image = image.squeeze(0)  # Remove batch dimension if present

            # Resize to target size
            if image.shape[-2:] != self.target_size:
                image = F.interpolate(
                    image.unsqueeze(0),
                    size=self.target_size,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)

            # Apply modality-specific preprocessing
            image = self._apply_modality_specific_processing(image, modality)

            # Enhance contrast if requested
            if enhance_contrast:
                image = self._enhance_medical_image_contrast(image)

            # Normalize if requested
            if self.normalize_images:
                image = self._normalize_medical_image(image)

            # Ensure proper range [0, 1]
            image = torch.clamp(image, 0.0, 1.0)

            return image

        except Exception as e:
            logger.error(f"Error preprocessing medical image: {e}")
            # Return dummy image on error
            return torch.zeros(1, *self.target_size)

    def _apply_modality_specific_processing(self, image: torch.Tensor,
                                            modality: str) -> torch.Tensor:
        """Apply modality-specific image processing"""
        modality_lower = modality.lower()
        try:
            if 'ct' in modality_lower:
                # CT scan specific processing
                image = self._process_ct_image(image)
            elif 'mri' in modality_lower:
                # MRI specific processing
                image = self._process_mri_image(image)
            elif 'xray' in modality_lower or 'x-ray' in modality_lower:
                # X-ray specific processing
                image = self._process_xray_image(image)
            elif 'ultrasound' in modality_lower:
                # Ultrasound specific processing
                image = self._process_ultrasound_image(image)
            return image
        except Exception as e:
            logger.warning(f"Error in modality-specific processing for {modality}: {e}")
            return image
def _process_ct_image(self, image: torch.Tensor) -> torch.Tensor: | |
"""Process CT scan images""" | |
# CT images often need windowing adjustments | |
# Apply soft tissue window as default | |
image = torch.clamp(image, 0.0, 1.0) | |
# Enhance contrast for better tissue differentiation | |
image = self._apply_gamma_correction(image, gamma=0.8) | |
return image | |
def _process_mri_image(self, image: torch.Tensor) -> torch.Tensor: | |
"""Process MRI images""" | |
# MRI images often have good contrast already | |
# Apply mild enhancement | |
image = self._apply_gamma_correction(image, gamma=0.9) | |
return image | |
def _process_xray_image(self, image: torch.Tensor) -> torch.Tensor: | |
"""Process X-ray images""" | |
# X-rays often need contrast enhancement | |
image = self._enhance_medical_image_contrast(image, factor=1.2) | |
# Apply histogram equalization equivalent | |
image = self._apply_histogram_equalization(image) | |
return image | |
def _process_ultrasound_image(self, image: torch.Tensor) -> torch.Tensor: | |
"""Process ultrasound images""" | |
# Ultrasound images often need noise reduction | |
image = self._apply_noise_reduction(image) | |
return image | |

    def _enhance_medical_image_contrast(self, image: torch.Tensor,
                                        factor: float = 1.1) -> torch.Tensor:
        """Enhance contrast of medical images"""
        try:
            # Stretch values around the mean to increase contrast
            mean_val = torch.mean(image)
            enhanced = (image - mean_val) * factor + mean_val
            return torch.clamp(enhanced, 0.0, 1.0)
        except Exception as e:
            logger.warning(f"Error enhancing contrast: {e}")
            return image

    def _apply_gamma_correction(self, image: torch.Tensor,
                                gamma: float = 1.0) -> torch.Tensor:
        """Apply gamma correction to image"""
        try:
            # Clamp to non-negative values so fractional exponents stay defined
            return torch.pow(torch.clamp(image, min=0.0), gamma)
        except Exception as e:
            logger.warning(f"Error applying gamma correction: {e}")
            return image
def _apply_histogram_equalization(self, image: torch.Tensor) -> torch.Tensor: | |
"""Apply histogram equalization equivalent""" | |
try: | |
# Convert to numpy for processing | |
image_np = image.squeeze().numpy() | |
# Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) | |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
# Convert to uint8 for CLAHE | |
image_uint8 = (image_np * 255).astype(np.uint8) | |
equalized = clahe.apply(image_uint8) | |
# Convert back to tensor | |
result = torch.from_numpy(equalized.astype(np.float32) / 255.0) | |
# Restore original shape | |
if len(image.shape) == 3: | |
result = result.unsqueeze(0) | |
return result | |
except Exception as e: | |
logger.warning(f"Error applying histogram equalization: {e}") | |
return image | |
def _apply_noise_reduction(self, image: torch.Tensor) -> torch.Tensor: | |
"""Apply noise reduction to image""" | |
try: | |
# Simple Gaussian blur for noise reduction | |
kernel_size = 3 | |
sigma = 0.5 | |
# Create Gaussian kernel | |
kernel = self._create_gaussian_kernel(kernel_size, sigma) | |
kernel = kernel.unsqueeze(0).unsqueeze(0) # Add batch and channel dims | |
# Apply convolution | |
if len(image.shape) == 3: | |
image_input = image.unsqueeze(0) # Add batch dimension | |
else: | |
image_input = image | |
filtered = F.conv2d(image_input, kernel, padding=kernel_size//2) | |
# Remove batch dimension if added | |
if len(image.shape) == 3: | |
filtered = filtered.squeeze(0) | |
return filtered | |
except Exception as e: | |
logger.warning(f"Error applying noise reduction: {e}") | |
return image | |
def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor: | |
"""Create Gaussian kernel for filtering""" | |
coords = torch.arange(kernel_size, dtype=torch.float32) | |
coords -= kernel_size // 2 | |
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) | |
g /= g.sum() | |
# Create 2D kernel | |
kernel = g[:, None] * g[None, :] | |
return kernel | |
def _normalize_medical_image(self, image: torch.Tensor) -> torch.Tensor: | |
"""Normalize medical image""" | |
try: | |
# Z-score normalization per image | |
mean_val = torch.mean(image) | |
std_val = torch.std(image) | |
if std_val > 0: | |
normalized = (image - mean_val) / std_val | |
# Scale to [0, 1] range | |
normalized = (normalized - normalized.min()) / (normalized.max() - normalized.min()) | |
else: | |
normalized = image | |
return normalized | |
except Exception as e: | |
logger.warning(f"Error normalizing image: {e}") | |
return image | |

    def preprocess_medical_text(self, text: str,
                                expand_abbreviations: bool = True,
                                remove_phi: bool = True) -> str:
        """
        Preprocess medical text

        Args:
            text: Input medical text
            expand_abbreviations: Whether to expand medical abbreviations
            remove_phi: Whether to remove potential PHI (Protected Health Information)

        Returns:
            Preprocessed text
        """
        try:
            if not isinstance(text, str):
                text = str(text)

            # Remove potential PHI first, while the original casing is intact
            # (the name heuristic relies on capitalization)
            processed_text = text
            if remove_phi:
                processed_text = self._remove_phi(processed_text)

            # Convert to lowercase for further processing
            processed_text = processed_text.lower()

            # Expand medical abbreviations
            if expand_abbreviations:
                processed_text = self._expand_medical_abbreviations(processed_text)

            # Clean up text
            processed_text = self._clean_medical_text(processed_text)

            # Limit length to prevent memory issues
            max_length = 2048
            if len(processed_text) > max_length:
                processed_text = processed_text[:max_length] + "..."

            return processed_text

        except Exception as e:
            logger.error(f"Error preprocessing medical text: {e}")
            return text  # Return original text on error
def _remove_phi(self, text: str) -> str: | |
"""Remove potential Protected Health Information""" | |
# Remove dates | |
text = re.sub(self.medical_patterns['dates'], '[DATE]', text) | |
# Remove times | |
text = re.sub(self.medical_patterns['times'], '[TIME]', text) | |
# Remove phone numbers | |
text = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', text) | |
# Remove email addresses | |
text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]', text) | |
# Remove potential names (very basic - would need more sophisticated NER in practice) | |
text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text) | |
return text | |
def _expand_medical_abbreviations(self, text: str) -> str: | |
"""Expand common medical abbreviations""" | |
for abbrev, expansion in self.medical_abbreviations.items(): | |
# Use word boundaries to avoid partial matches | |
pattern = r'\b' + re.escape(abbrev) + r'\b' | |
text = re.sub(pattern, expansion, text, flags=re.IGNORECASE) | |
return text | |
def _clean_medical_text(self, text: str) -> str: | |
"""Clean and normalize medical text""" | |
# Remove extra whitespace | |
text = re.sub(r'\s+', ' ', text) | |
# Remove special characters but keep medical-relevant ones | |
text = re.sub(r'[^\w\s\-\.\,\:\;\(\)\/\%]', '', text) | |
# Strip leading/trailing whitespace | |
text = text.strip() | |
return text | |
def batch_preprocess_medical_data(self, batch: Dict[str, Any]) -> Dict[str, Any]: | |
"""Preprocess a batch of medical data""" | |
processed_batch = {} | |
try: | |
# Process images if present | |
if 'images' in batch and batch['images'] is not None: | |
images = batch['images'] | |
processed_images = [] | |
for i, image in enumerate(images): | |
# Get modality if available | |
modality = 'unknown' | |
if 'modalities' in batch and i < len(batch['modalities']): | |
modality = batch['modalities'][i] | |
processed_image = self.preprocess_medical_image(image, modality) | |
processed_images.append(processed_image) | |
processed_batch['images'] = torch.stack(processed_images) | |
# Process texts if present | |
if 'texts' in batch: | |
texts = batch['texts'] | |
processed_texts = [] | |
for text in texts: | |
processed_text = self.preprocess_medical_text(text) | |
processed_texts.append(processed_text) | |
processed_batch['texts'] = processed_texts | |
# Copy other fields | |
for key, value in batch.items(): | |
if key not in ['images', 'texts']: | |
processed_batch[key] = value | |
return processed_batch | |
except Exception as e: | |
logger.error(f"Error in batch preprocessing: {e}") | |
return batch # Return original batch on error | |
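

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): exercises the preprocessor on a
    # random single-channel image and a short synthetic note. The tensor shapes,
    # modality strings, and example text below are hypothetical placeholders,
    # not data from any real dataset.
    logging.basicConfig(level=logging.INFO)

    preprocessor = MedicalPreprocessor(target_size=(512, 512))

    # Fake 256x256 grayscale image, resized and normalized to (1, 512, 512)
    dummy_image = torch.rand(1, 256, 256)
    processed_image = preprocessor.preprocess_medical_image(dummy_image, modality='CT')
    print(f"Processed image shape: {tuple(processed_image.shape)}")

    # Synthetic clinical note with abbreviations and PHI-like fragments
    dummy_note = "Pt c/o SOB and CP since 01/02/2023, r/o MI. NKDA."
    processed_note = preprocessor.preprocess_medical_text(dummy_note)
    print(f"Processed text: {processed_note}")

    # Batch interface: images and texts are preprocessed, other fields pass through
    batch = {
        'images': [torch.rand(1, 256, 256), torch.rand(1, 300, 300)],
        'modalities': ['CT', 'X-ray'],
        'texts': [dummy_note, "Hx of HTN, s/p appendectomy, abd soft."],
        'labels': [0, 1],
    }
    processed_batch = preprocessor.batch_preprocess_medical_data(batch)
    print(f"Batch image tensor: {tuple(processed_batch['images'].shape)}")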