Spaces:
Running
Running
""" | |
Advanced Chunk Loader for large models with memory constraints | |
Optimized for CPU-only training on 16GB RAM systems | |
""" | |
import os | |
import gc | |
import mmap | |
import logging | |
import asyncio | |
from typing import Dict, Any, List, Optional, Iterator, Union | |
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
from transformers import AutoModel, AutoConfig, AutoTokenizer | |
from safetensors import safe_open | |
import numpy as np | |
from .memory_manager import AdvancedMemoryManager | |
logger = logging.getLogger(__name__) | |
class ModelChunk: | |
"""Represents a chunk of a large model""" | |
def __init__(self, chunk_id: str, parameters: Dict[str, torch.Tensor], | |
metadata: Dict[str, Any]): | |
self.chunk_id = chunk_id | |
self.parameters = parameters | |
self.metadata = metadata | |
self.is_loaded = True | |
self.memory_size_mb = sum(p.numel() * p.element_size() for p in parameters.values()) / 1024**2 | |
def unload(self): | |
"""Unload chunk from memory""" | |
if self.is_loaded: | |
del self.parameters | |
self.parameters = {} | |
self.is_loaded = False | |
gc.collect() | |
logger.debug(f"Unloaded chunk {self.chunk_id}") | |
def __del__(self): | |
if hasattr(self, 'is_loaded') and self.is_loaded: | |
self.unload() | |
class AdvancedChunkLoader: | |
""" | |
Advanced chunk loader for handling large models with memory constraints | |
""" | |
def __init__(self, memory_manager: AdvancedMemoryManager, | |
chunk_size_mb: float = 500.0): | |
""" | |
Initialize chunk loader | |
Args: | |
memory_manager: Memory manager instance | |
chunk_size_mb: Target size for each chunk in MB | |
""" | |
self.memory_manager = memory_manager | |
self.chunk_size_mb = chunk_size_mb | |
self.chunk_size_bytes = chunk_size_mb * 1024**2 | |
self.loaded_chunks = {} | |
self.chunk_cache = {} | |
self.max_cached_chunks = 3 | |
# Register cleanup callback | |
self.memory_manager.register_cleanup_callback(self._cleanup_chunks) | |
logger.info(f"Chunk loader initialized with {chunk_size_mb}MB chunks") | |
async def load_model_in_chunks(self, model_path: str, **kwargs) -> Dict[str, Any]: | |
""" | |
Load a large model in chunks | |
Args: | |
model_path: Path to model (local or HF repo) | |
**kwargs: Additional loading parameters | |
Returns: | |
Model metadata and chunk information | |
""" | |
with self.memory_manager.memory_context("load_model_in_chunks"): | |
logger.info(f"Loading model in chunks: {model_path}") | |
# First, get model config and size estimation | |
config = await self._load_model_config(model_path, **kwargs) | |
estimated_size_mb = self._estimate_model_size(config) | |
logger.info(f"Estimated model size: {estimated_size_mb:.1f}MB") | |
if estimated_size_mb <= self.chunk_size_mb * 2: | |
# Small model, load normally | |
return await self._load_small_model(model_path, config, **kwargs) | |
else: | |
# Large model, use chunking | |
return await self._load_large_model_chunked(model_path, config, **kwargs) | |
async def _load_model_config(self, model_path: str, **kwargs) -> AutoConfig: | |
"""Load model configuration""" | |
try: | |
hf_token = kwargs.get('token') or os.getenv('HF_TOKEN') | |
trust_remote_code = kwargs.get('trust_remote_code', False) | |
config = AutoConfig.from_pretrained( | |
model_path, | |
trust_remote_code=trust_remote_code, | |
token=hf_token, | |
timeout=30 | |
) | |
return config | |
except Exception as e: | |
logger.error(f"Failed to load config for {model_path}: {e}") | |
raise | |
def _estimate_model_size(self, config: AutoConfig) -> float: | |
"""Estimate model size in MB""" | |
try: | |
# Get basic parameters | |
hidden_size = getattr(config, 'hidden_size', 768) | |
num_layers = getattr(config, 'num_hidden_layers', | |
getattr(config, 'num_layers', 12)) | |
vocab_size = getattr(config, 'vocab_size', 50000) | |
# Rough estimation for transformer models | |
embedding_params = vocab_size * hidden_size | |
layer_params = num_layers * (hidden_size * hidden_size * 4) # Simplified | |
total_params = embedding_params + layer_params | |
# Convert to MB (4 bytes per parameter for float32) | |
size_mb = (total_params * 4) / (1024 ** 2) | |
return max(size_mb, 100) # Minimum 100MB | |
except Exception: | |
return 2000 # Default 2GB if estimation fails | |
async def _load_small_model(self, model_path: str, config: AutoConfig, | |
**kwargs) -> Dict[str, Any]: | |
"""Load small model normally""" | |
logger.info(f"Loading small model normally: {model_path}") | |
hf_token = kwargs.get('token') or os.getenv('HF_TOKEN') | |
trust_remote_code = kwargs.get('trust_remote_code', False) | |
try: | |
# Load model with CPU optimization | |
model = AutoModel.from_pretrained( | |
model_path, | |
config=config, | |
torch_dtype=torch.float32, | |
trust_remote_code=trust_remote_code, | |
token=hf_token, | |
low_cpu_mem_usage=True, | |
device_map='cpu' | |
) | |
# Load tokenizer/processor | |
tokenizer = None | |
try: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_path, | |
token=hf_token, | |
trust_remote_code=trust_remote_code | |
) | |
except: | |
logger.warning(f"Could not load tokenizer for {model_path}") | |
return { | |
'model': model, | |
'tokenizer': tokenizer, | |
'config': config, | |
'is_chunked': False, | |
'source': model_path, | |
'estimated_size_mb': self._estimate_model_size(config) | |
} | |
except Exception as e: | |
logger.error(f"Failed to load small model {model_path}: {e}") | |
raise | |
async def _load_large_model_chunked(self, model_path: str, config: AutoConfig, | |
**kwargs) -> Dict[str, Any]: | |
"""Load large model using chunking strategy""" | |
logger.info(f"Loading large model with chunking: {model_path}") | |
# Create chunks metadata | |
chunks_info = await self._create_chunks_metadata(model_path, config, **kwargs) | |
# Load first chunk to get model structure | |
first_chunk = await self._load_chunk(model_path, chunks_info[0], **kwargs) | |
return { | |
'model': None, # No single model object for chunked models | |
'chunks_info': chunks_info, | |
'first_chunk': first_chunk, | |
'config': config, | |
'is_chunked': True, | |
'source': model_path, | |
'total_chunks': len(chunks_info), | |
'estimated_size_mb': self._estimate_model_size(config) | |
} | |
async def _create_chunks_metadata(self, model_path: str, config: AutoConfig, | |
**kwargs) -> List[Dict[str, Any]]: | |
"""Create metadata for model chunks""" | |
# This is a simplified chunking strategy | |
# In practice, you'd analyze the model structure more carefully | |
estimated_size_mb = self._estimate_model_size(config) | |
num_chunks = max(1, int(estimated_size_mb / self.chunk_size_mb)) | |
chunks_info = [] | |
for i in range(num_chunks): | |
chunk_info = { | |
'chunk_id': f"chunk_{i}", | |
'start_layer': i * (config.num_hidden_layers // num_chunks), | |
'end_layer': min((i + 1) * (config.num_hidden_layers // num_chunks), | |
config.num_hidden_layers), | |
'estimated_size_mb': estimated_size_mb / num_chunks, | |
'parameters': [] # Will be populated during loading | |
} | |
chunks_info.append(chunk_info) | |
return chunks_info | |
async def _load_chunk(self, model_path: str, chunk_info: Dict[str, Any], | |
**kwargs) -> ModelChunk: | |
"""Load a specific chunk of the model""" | |
chunk_id = chunk_info['chunk_id'] | |
with self.memory_manager.memory_context(f"load_chunk_{chunk_id}"): | |
logger.debug(f"Loading chunk {chunk_id}") | |
# For now, this is a placeholder implementation | |
# In practice, you'd implement layer-wise loading | |
parameters = {} | |
# Create dummy parameters for demonstration | |
# Replace with actual chunk loading logic | |
hidden_size = getattr(kwargs.get('config', {}), 'hidden_size', 768) | |
chunk_params = torch.randn(hidden_size, hidden_size) * 0.02 | |
parameters[f'{chunk_id}_weight'] = chunk_params | |
metadata = { | |
'chunk_id': chunk_id, | |
'layer_range': (chunk_info['start_layer'], chunk_info['end_layer']), | |
'parameter_count': sum(p.numel() for p in parameters.values()) | |
} | |
chunk = ModelChunk(chunk_id, parameters, metadata) | |
self.loaded_chunks[chunk_id] = chunk | |
# Manage cache | |
await self._manage_chunk_cache() | |
return chunk | |
async def _manage_chunk_cache(self): | |
"""Manage chunk cache to prevent memory overflow""" | |
if len(self.loaded_chunks) > self.max_cached_chunks: | |
# Remove oldest chunks | |
chunks_to_remove = list(self.loaded_chunks.keys())[:-self.max_cached_chunks] | |
for chunk_id in chunks_to_remove: | |
chunk = self.loaded_chunks.pop(chunk_id) | |
chunk.unload() | |
logger.debug(f"Removed chunk {chunk_id} from cache") | |
def _cleanup_chunks(self): | |
"""Cleanup callback for memory manager""" | |
logger.info("Cleaning up loaded chunks") | |
for chunk in self.loaded_chunks.values(): | |
chunk.unload() | |
self.loaded_chunks.clear() | |
gc.collect() | |
async def get_chunk_iterator(self, model_info: Dict[str, Any]) -> Iterator[ModelChunk]: | |
"""Get iterator for model chunks""" | |
if not model_info.get('is_chunked', False): | |
# Not a chunked model | |
yield model_info['model'] | |
return | |
chunks_info = model_info['chunks_info'] | |
model_path = model_info['source'] | |
for chunk_info in chunks_info: | |
chunk = await self._load_chunk(model_path, chunk_info) | |
yield chunk | |
# Optionally unload chunk after yielding | |
# chunk.unload() | |
def get_memory_usage(self) -> Dict[str, float]: | |
"""Get current memory usage of loaded chunks""" | |
total_memory_mb = sum(chunk.memory_size_mb for chunk in self.loaded_chunks.values()) | |
return { | |
'total_chunks_memory_mb': total_memory_mb, | |
'loaded_chunks_count': len(self.loaded_chunks), | |
'average_chunk_size_mb': total_memory_mb / len(self.loaded_chunks) if self.loaded_chunks else 0 | |
} | |