""" | |
Advanced Memory Manager for CPU-only training with 16GB RAM constraint | |
Optimized for Hugging Face Spaces free tier | |
""" | |
import os | |
import gc | |
import psutil | |
import logging | |
import threading | |
import time | |
from typing import Dict, Any, Optional, List, Callable | |
from pathlib import Path | |
import torch | |
import numpy as np | |
from contextlib import contextmanager | |
logger = logging.getLogger(__name__) | |
class AdvancedMemoryManager:
    """
    Advanced memory management for CPU-only training with strict memory constraints.
    """

    def __init__(self, max_memory_gb: float = 14.0):
        """
        Initialize the memory manager.

        Args:
            max_memory_gb: Maximum memory usage in GB (default 14GB for 16GB systems)
        """
        self.max_memory_bytes = max_memory_gb * 1024**3
        self.current_memory_usage = 0
        self.memory_threshold_warning = 0.8     # 80% warning
        self.memory_threshold_critical = 0.9    # 90% critical
        self.memory_threshold_emergency = 0.95  # 95% emergency cleanup

        # Memory tracking
        self.allocated_objects = {}
        self.memory_history = []
        self.cleanup_callbacks = []

        # Threading for monitoring
        self.monitoring_active = False
        self.monitor_thread = None

        # CPU optimization (os.cpu_count() can return None, so fall back to 1)
        self.cpu_count = os.cpu_count() or 1
        torch.set_num_threads(min(self.cpu_count, 8))  # Limit threads for stability

        logger.info(f"Memory Manager initialized with {max_memory_gb}GB limit")
        logger.info(f"CPU threads set to: {torch.get_num_threads()}")
    def get_memory_info(self) -> Dict[str, Any]:
        """Get current memory information"""
        process = psutil.Process()
        memory_info = process.memory_info()
        system_memory = psutil.virtual_memory()

        return {
            'process_memory_mb': memory_info.rss / 1024**2,
            'process_memory_percent': (memory_info.rss / system_memory.total) * 100,
            'system_memory_total_gb': system_memory.total / 1024**3,
            'system_memory_available_gb': system_memory.available / 1024**3,
            'system_memory_percent': system_memory.percent,
            'max_allowed_gb': self.max_memory_bytes / 1024**3,
            'torch_allocated_mb': torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0,
            'torch_cached_mb': torch.cuda.memory_reserved() / 1024**2 if torch.cuda.is_available() else 0,
        }
    def check_memory_status(self) -> str:
        """Check current memory status"""
        memory_info = self.get_memory_info()
        usage_ratio = memory_info['process_memory_mb'] * 1024**2 / self.max_memory_bytes

        if usage_ratio >= self.memory_threshold_emergency:
            return 'emergency'
        elif usage_ratio >= self.memory_threshold_critical:
            return 'critical'
        elif usage_ratio >= self.memory_threshold_warning:
            return 'warning'
        else:
            return 'normal'
    def force_cleanup(self):
        """Force aggressive memory cleanup"""
        logger.warning("Performing emergency memory cleanup")

        # Clear Python garbage
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear PyTorch cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Run cleanup callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Cleanup callback failed: {e}")

        # Force another garbage collection
        gc.collect()

        memory_info = self.get_memory_info()
        logger.info(f"Memory after cleanup: {memory_info['process_memory_mb']:.1f}MB")
    @contextmanager
    def memory_context(self, operation_name: str, expected_memory_mb: float = 0):
        """Context manager for memory-aware operations"""
        start_memory = self.get_memory_info()
        logger.debug(f"Starting {operation_name}, memory: {start_memory['process_memory_mb']:.1f}MB")

        # Check whether enough memory is likely to be available
        if expected_memory_mb > 0:
            available_mb = (self.max_memory_bytes / 1024**2) - start_memory['process_memory_mb']
            if expected_memory_mb > available_mb * 0.8:  # 80% safety margin
                logger.warning(f"Operation {operation_name} may exceed memory limit")
                self.force_cleanup()

        try:
            yield self
        finally:
            end_memory = self.get_memory_info()
            memory_diff = end_memory['process_memory_mb'] - start_memory['process_memory_mb']
            logger.debug(f"Completed {operation_name}, memory change: {memory_diff:+.1f}MB")

            # Check if cleanup is needed
            status = self.check_memory_status()
            if status in ['critical', 'emergency']:
                self.force_cleanup()
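    # Typical usage of memory_context (a minimal sketch; "manager" and
    # "load_dataset_chunk" are hypothetical names, not defined in this module):
    #
    #     with manager.memory_context("load_dataset_chunk", expected_memory_mb=512):
    #         chunk = load_dataset_chunk()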
    def register_cleanup_callback(self, callback: Callable):
        """Register a cleanup callback function"""
        self.cleanup_callbacks.append(callback)
    def start_monitoring(self, interval_seconds: float = 30.0):
        """Start memory monitoring thread"""
        if self.monitoring_active:
            return

        self.monitoring_active = True
        self.monitor_thread = threading.Thread(
            target=self._monitor_memory,
            args=(interval_seconds,),
            daemon=True
        )
        self.monitor_thread.start()
        logger.info("Memory monitoring started")
    def stop_monitoring(self):
        """Stop memory monitoring"""
        self.monitoring_active = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=5.0)
        logger.info("Memory monitoring stopped")
    def _monitor_memory(self, interval_seconds: float):
        """Internal memory monitoring loop"""
        while self.monitoring_active:
            try:
                memory_info = self.get_memory_info()
                status = self.check_memory_status()

                # Log memory status
                if status != 'normal':
                    logger.warning(f"Memory status: {status}, usage: {memory_info['process_memory_mb']:.1f}MB")

                # Auto cleanup if needed
                if status == 'emergency':
                    self.force_cleanup()
                elif status == 'critical':
                    gc.collect()

                # Store history
                self.memory_history.append({
                    'timestamp': time.time(),
                    'memory_mb': memory_info['process_memory_mb'],
                    'status': status
                })

                # Keep only the last 100 entries
                if len(self.memory_history) > 100:
                    self.memory_history = self.memory_history[-100:]

                time.sleep(interval_seconds)
            except Exception as e:
                logger.error(f"Memory monitoring error: {e}")
                time.sleep(interval_seconds)
    def get_memory_recommendations(self) -> List[str]:
        """Get memory optimization recommendations"""
        memory_info = self.get_memory_info()
        recommendations = []

        if memory_info['process_memory_mb'] > 8000:  # > 8GB
            recommendations.append("Consider using smaller batch sizes")
            recommendations.append("Enable gradient checkpointing")
            recommendations.append("Use model sharding for large models")

        if memory_info['system_memory_percent'] > 80:
            recommendations.append("Close unnecessary applications")
            recommendations.append("Consider using swap memory")

        if len(self.memory_history) > 10:
            recent_growth = self.memory_history[-1]['memory_mb'] - self.memory_history[-10]['memory_mb']
            if recent_growth > 1000:  # > 1GB growth
                recommendations.append("Memory usage is growing rapidly - check for memory leaks")

        return recommendations
    def optimize_torch_settings(self):
        """Optimize PyTorch settings for CPU and memory efficiency"""
        # Set optimal thread count
        torch.set_num_threads(min(self.cpu_count, 8))

        # Prefer memory-efficient attention backends; these toggles are harmless
        # on CPU-only builds but may be missing in older PyTorch versions
        try:
            torch.backends.cuda.enable_flash_sdp(False)  # Disable for CPU
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
        except Exception:
            pass

        # Set memory allocation strategy (only relevant if CUDA is ever used)
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

        logger.info("PyTorch settings optimized for CPU and memory efficiency")
    def __enter__(self):
        self.start_monitoring()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_monitoring()
        self.force_cleanup()
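
# A minimal usage sketch, assuming the module is run as a standalone script.
# The scratch tensor and cache below are illustrative only and not part of any
# training pipeline in this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with AdvancedMemoryManager(max_memory_gb=14.0) as manager:
        manager.optimize_torch_settings()

        # Register a hypothetical cache-clearing callback for emergency cleanups
        scratch_cache = {}
        manager.register_cleanup_callback(scratch_cache.clear)

        # Wrap a memory-heavy step so its footprint is logged and cleaned up
        with manager.memory_context("allocate_scratch_tensor", expected_memory_mb=64):
            scratch_cache["batch"] = torch.zeros(8, 1024, 1024)  # ~32MB of float32

        print(manager.get_memory_info())
        print(manager.get_memory_recommendations())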