File size: 11,790 Bytes
ab4e093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Advanced Chunk Loader for large models with memory constraints
Optimized for CPU-only training on 16GB RAM systems
"""

import asyncio
import gc
import logging
import math
import mmap
import os
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import AutoConfig, AutoModel, AutoTokenizer

from .memory_manager import AdvancedMemoryManager

logger = logging.getLogger(__name__)

class ModelChunk:
    """Represents a chunk of a large model.

    Attributes:
        chunk_id: Identifier of the chunk (e.g. ``"chunk_0"``).
        parameters: Mapping of parameter name -> tensor; replaced by an empty
            dict when the chunk is unloaded.
        metadata: Free-form metadata supplied by the loader.
        is_loaded: ``True`` until ``unload`` (or finalization) releases the
            tensors.
        memory_size_mb: Size of the tensors' data at construction time, in MiB.
            ``unload`` does not reset it.
    """

    def __init__(self, chunk_id: str, parameters: Dict[str, torch.Tensor],
                 metadata: Dict[str, Any]):
        self.chunk_id = chunk_id
        self.parameters = parameters
        self.metadata = metadata
        self.is_loaded = True
        # numel() * element_size() is the byte size of each tensor's data.
        self.memory_size_mb = sum(
            p.numel() * p.element_size() for p in parameters.values()
        ) / 1024**2

    def _release(self) -> bool:
        """Drop this chunk's references to its parameter tensors.

        Returns:
            True if tensors were released, False if the chunk was already
            unloaded.
        """
        if not self.is_loaded:
            return False
        self.parameters = {}
        self.is_loaded = False
        return True

    def unload(self, collect: bool = True) -> None:
        """Unload chunk from memory.

        Args:
            collect: Run ``gc.collect()`` after releasing the tensors. Callers
                unloading many chunks in a row can pass ``False`` and collect
                once at the end instead of once per chunk.
        """
        if self._release():
            if collect:
                gc.collect()
            logger.debug(f"Unloaded chunk {self.chunk_id}")

    def __del__(self):
        # Finalizers can run in the middle of a garbage collection or during
        # interpreter shutdown (when module globals such as ``logger`` may
        # already be torn down), so only drop the references here instead of
        # forcing another full collection and logging.
        if getattr(self, 'is_loaded', False):
            self._release()

class AdvancedChunkLoader:
    """
    Advanced chunk loader for handling large models with memory constraints.

    Models whose estimated size is at most twice ``chunk_size_mb`` are loaded
    whole via ``AutoModel``. Larger models are described as contiguous
    layer-range chunks that are materialized on demand and kept in a small
    least-recently-used cache (``loaded_chunks``).
    """

    def __init__(self, memory_manager: AdvancedMemoryManager,
                 chunk_size_mb: float = 500.0, max_cached_chunks: int = 3):
        """
        Initialize chunk loader

        Args:
            memory_manager: Memory manager instance; ``_cleanup_chunks`` is
                registered with it as a cleanup callback.
            chunk_size_mb: Target size for each chunk in MB. Must be positive.
            max_cached_chunks: Maximum number of chunks kept resident at once.
                Must be at least 1 so a freshly loaded chunk is never evicted
                before it is returned to the caller.

        Raises:
            ValueError: If ``chunk_size_mb`` <= 0 or ``max_cached_chunks`` < 1.
        """
        if chunk_size_mb <= 0:
            raise ValueError(f"chunk_size_mb must be positive, got {chunk_size_mb}")
        if max_cached_chunks < 1:
            raise ValueError(
                f"max_cached_chunks must be at least 1, got {max_cached_chunks}"
            )

        self.memory_manager = memory_manager
        self.chunk_size_mb = chunk_size_mb
        self.chunk_size_bytes = chunk_size_mb * 1024**2
        # Insertion order doubles as recency order: the first key is the least
        # recently used chunk and is the first to be evicted.
        self.loaded_chunks: Dict[str, ModelChunk] = {}
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.chunk_cache = {}
        self.max_cached_chunks = max_cached_chunks

        # Register cleanup callback
        self.memory_manager.register_cleanup_callback(self._cleanup_chunks)

        logger.info(f"Chunk loader initialized with {chunk_size_mb}MB chunks")

    async def load_model_in_chunks(self, model_path: str, **kwargs) -> Dict[str, Any]:
        """
        Load a large model in chunks

        Args:
            model_path: Path to model (local or HF repo)
            **kwargs: Additional loading parameters; ``token`` and
                ``trust_remote_code`` are read by the loaders below.

        Returns:
            Model metadata and chunk information. The ``is_chunked`` key tells
            which of the two result layouts was produced.
        """
        with self.memory_manager.memory_context("load_model_in_chunks"):
            logger.info(f"Loading model in chunks: {model_path}")

            # First, get model config and size estimation
            config = await self._load_model_config(model_path, **kwargs)
            estimated_size_mb = self._estimate_model_size(config)

            logger.info(f"Estimated model size: {estimated_size_mb:.1f}MB")

            if estimated_size_mb <= self.chunk_size_mb * 2:
                # Small model, load normally
                return await self._load_small_model(model_path, config, **kwargs)
            # Large model, use chunking
            return await self._load_large_model_chunked(model_path, config, **kwargs)

    async def _load_model_config(self, model_path: str, **kwargs) -> AutoConfig:
        """Load model configuration.

        Uses ``kwargs['token']`` (falling back to the ``HF_TOKEN`` environment
        variable) and ``kwargs['trust_remote_code']`` (default ``False``).

        Raises:
            Exception: Whatever ``AutoConfig.from_pretrained`` raises; it is
                logged and re-raised.
        """
        hf_token = kwargs.get('token') or os.getenv('HF_TOKEN')
        trust_remote_code = kwargs.get('trust_remote_code', False)
        try:
            # NOTE: from_pretrained is synchronous; it runs on the event-loop
            # thread and blocks other tasks while it executes.
            return AutoConfig.from_pretrained(
                model_path,
                trust_remote_code=trust_remote_code,
                token=hf_token,
                # NOTE(review): confirm the installed transformers version
                # actually honors `timeout` here.
                timeout=30
            )
        except Exception as e:
            logger.error(f"Failed to load config for {model_path}: {e}")
            raise

    @staticmethod
    def _get_num_layers(config: AutoConfig) -> int:
        """Return ``num_hidden_layers``, else ``num_layers``, else 12."""
        return getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', 12))

    def _estimate_model_size(self, config: AutoConfig) -> float:
        """Estimate model size in MB (float32 weights, rough transformer formula).

        Returns at least 100, or 2000 if the estimate cannot be computed.
        """
        try:
            # Get basic parameters
            hidden_size = getattr(config, 'hidden_size', 768)
            num_layers = self._get_num_layers(config)
            vocab_size = getattr(config, 'vocab_size', 50000)

            # Rough estimation for transformer models
            embedding_params = vocab_size * hidden_size
            layer_params = num_layers * (hidden_size * hidden_size * 4)  # Simplified
            total_params = embedding_params + layer_params

            # Convert to MB (4 bytes per parameter for float32)
            size_mb = (total_params * 4) / (1024 ** 2)

            return max(size_mb, 100)  # Minimum 100MB
        except Exception:
            return 2000  # Default 2GB if estimation fails

    async def _load_small_model(self, model_path: str, config: AutoConfig,
                               **kwargs) -> Dict[str, Any]:
        """Load a small model in one piece on CPU, plus its tokenizer if available.

        Returns:
            Dict with ``model``, ``tokenizer`` (``None`` if it could not be
            loaded), ``config``, ``is_chunked=False``, ``source`` and
            ``estimated_size_mb``.

        Raises:
            Exception: Whatever model loading raises; it is logged and re-raised.
        """
        logger.info(f"Loading small model normally: {model_path}")

        hf_token = kwargs.get('token') or os.getenv('HF_TOKEN')
        trust_remote_code = kwargs.get('trust_remote_code', False)

        try:
            # Load model with CPU optimization
            model = AutoModel.from_pretrained(
                model_path,
                config=config,
                torch_dtype=torch.float32,
                trust_remote_code=trust_remote_code,
                token=hf_token,
                low_cpu_mem_usage=True,
                device_map='cpu'
            )

            # The tokenizer is optional: not every model ships one.
            tokenizer = None
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    token=hf_token,
                    trust_remote_code=trust_remote_code
                )
            except Exception as e:
                logger.warning(f"Could not load tokenizer for {model_path}: {e}")

            return {
                'model': model,
                'tokenizer': tokenizer,
                'config': config,
                'is_chunked': False,
                'source': model_path,
                'estimated_size_mb': self._estimate_model_size(config)
            }

        except Exception as e:
            logger.error(f"Failed to load small model {model_path}: {e}")
            raise

    async def _load_large_model_chunked(self, model_path: str, config: AutoConfig,
                                       **kwargs) -> Dict[str, Any]:
        """Describe a large model as chunks and eagerly load the first one.

        Returns:
            Dict with ``model=None``, ``chunks_info``, ``first_chunk``,
            ``config``, ``is_chunked=True``, ``source``, ``total_chunks`` and
            ``estimated_size_mb``.
        """
        logger.info(f"Loading large model with chunking: {model_path}")

        # Create chunks metadata
        chunks_info = await self._create_chunks_metadata(model_path, config, **kwargs)

        # `config` cannot also be in kwargs here (it would already have clashed
        # with the positional parameter), so forwarding it explicitly is safe.
        first_chunk = await self._load_chunk(model_path, chunks_info[0],
                                             config=config, **kwargs)

        return {
            'model': None,  # No single model object for chunked models
            'chunks_info': chunks_info,
            'first_chunk': first_chunk,
            'config': config,
            'is_chunked': True,
            'source': model_path,
            'total_chunks': len(chunks_info),
            'estimated_size_mb': self._estimate_model_size(config)
        }

    async def _create_chunks_metadata(self, model_path: str, config: AutoConfig,
                                     **kwargs) -> List[Dict[str, Any]]:
        """Split the model's layers into contiguous chunk descriptors.

        This is a simplified strategy based on the size estimate; the number
        of chunks is the smallest count that keeps each chunk at or below
        ``chunk_size_mb``, capped at the number of layers.

        Returns:
            A non-empty list of dicts with ``chunk_id``, ``start_layer``,
            ``end_layer`` (exclusive), ``estimated_size_mb`` and an empty
            ``parameters`` list.
        """
        estimated_size_mb = self._estimate_model_size(config)
        num_layers = self._get_num_layers(config)

        # Ceil (not floor) so chunks do not exceed the target size.
        num_chunks = max(1, math.ceil(estimated_size_mb / self.chunk_size_mb))
        if num_layers > 0:
            # Every chunk needs at least one layer; more chunks than layers
            # would only produce empty ranges.
            num_chunks = min(num_chunks, num_layers)

        per_chunk_mb = estimated_size_mb / num_chunks
        return [
            {
                'chunk_id': f"chunk_{i}",
                # Integer partition: ranges are contiguous, together cover
                # [0, num_layers), and differ in length by at most one layer.
                'start_layer': i * num_layers // num_chunks,
                'end_layer': (i + 1) * num_layers // num_chunks,
                'estimated_size_mb': per_chunk_mb,
                'parameters': []  # Will be populated during loading
            }
            for i in range(num_chunks)
        ]

    async def _load_chunk(self, model_path: str, chunk_info: Dict[str, Any],
                         **kwargs) -> ModelChunk:
        """Load a specific chunk of the model.

        A chunk that is already resident is returned from the cache (and
        marked most recently used) instead of being materialized twice.

        Args:
            model_path: Model source; not used by the current placeholder loader.
            chunk_info: One descriptor produced by ``_create_chunks_metadata``.
            **kwargs: ``config`` (optional) supplies ``hidden_size``.

        Returns:
            The ``ModelChunk``, registered in ``loaded_chunks``.
        """
        chunk_id = chunk_info['chunk_id']

        cached = self.loaded_chunks.pop(chunk_id, None)
        if cached is not None and cached.is_loaded:
            # Re-insert at the end so the chunk becomes most recently used.
            self.loaded_chunks[chunk_id] = cached
            logger.debug(f"Reusing cached chunk {chunk_id}")
            return cached

        with self.memory_manager.memory_context(f"load_chunk_{chunk_id}"):
            logger.debug(f"Loading chunk {chunk_id}")

            # TODO: placeholder — replace with real layer-wise loading of the
            # layers in chunk_info['start_layer']:chunk_info['end_layer'].
            hidden_size = getattr(kwargs.get('config', {}), 'hidden_size', 768)
            parameters = {
                f'{chunk_id}_weight': torch.randn(hidden_size, hidden_size) * 0.02
            }

            metadata = {
                'chunk_id': chunk_id,
                'layer_range': (chunk_info['start_layer'], chunk_info['end_layer']),
                'parameter_count': sum(p.numel() for p in parameters.values())
            }

            chunk = ModelChunk(chunk_id, parameters, metadata)
            self.loaded_chunks[chunk_id] = chunk

            # Evict least recently used chunks beyond the cache limit.
            await self._manage_chunk_cache()

            return chunk

    async def _manage_chunk_cache(self):
        """Unload the least recently used chunks beyond ``max_cached_chunks``.

        Evicted chunks are unloaded even if a caller still holds a reference.
        """
        excess = len(self.loaded_chunks) - self.max_cached_chunks
        if excess <= 0:
            return
        for chunk_id in list(self.loaded_chunks)[:excess]:
            chunk = self.loaded_chunks.pop(chunk_id)
            chunk.unload()
            logger.debug(f"Removed chunk {chunk_id} from cache")

    def _cleanup_chunks(self):
        """Cleanup callback for memory manager: unload and forget every chunk."""
        logger.info("Cleaning up loaded chunks")
        for chunk in self.loaded_chunks.values():
            chunk.unload()
        self.loaded_chunks.clear()
        gc.collect()

    async def get_chunk_iterator(
        self, model_info: Dict[str, Any]
    ) -> AsyncIterator[Union[ModelChunk, nn.Module]]:
        """Asynchronously iterate over a loaded model's pieces.

        For a non-chunked result, yields ``model_info['model']`` once.
        Otherwise it loads and yields each chunk in order. Chunks older than
        the last ``max_cached_chunks`` are unloaded as later ones load, so
        consumers should be done with a chunk before advancing that far.
        """
        if not model_info.get('is_chunked', False):
            yield model_info['model']
            return

        model_path = model_info['source']
        config = model_info.get('config')
        for chunk_info in model_info['chunks_info']:
            yield await self._load_chunk(model_path, chunk_info, config=config)

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage of loaded chunks (sizes in MiB)."""
        total_memory_mb = sum(chunk.memory_size_mb for chunk in self.loaded_chunks.values())

        return {
            'total_chunks_memory_mb': total_memory_mb,
            'loaded_chunks_count': len(self.loaded_chunks),
            'average_chunk_size_mb': total_memory_mb / len(self.loaded_chunks) if self.loaded_chunks else 0
        }