File size: 19,761 Bytes
e7b895b
36a5fc5
 
 
e7b895b
 
 
 
36a5fc5
 
 
 
e7b895b
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
d7cde9b
e7b895b
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60e2fdb
e7b895b
 
60e2fdb
 
 
 
 
 
e7b895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
36a5fc5
e7b895b
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
36a5fc5
e7b895b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
"""
Configuration module for Enhanced SPG compression.
Contains all research constants, configuration classes, and validation logic.
STRICT COMPLIANCE: No hardcoding, all parameters from config.
"""

import json
import hashlib
import logging
import sys
import os
import platform
from dataclasses import dataclass, field, asdict
from typing import List, Optional, NamedTuple, Any
from enum import Enum
from datetime import datetime
import torch
import transformers

# Root logging setup: INFO level, records rendered as "<time> - <level> - <message>".
# basicConfig is a no-op when the root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-level logger shared by the configuration classes below.
logger = logging.getLogger(__name__)


class CompressionType(Enum):
    """Identifiers for the selectable compression methods (RocketKV-enhanced SPG variants).

    The enum itself performs no validation. Each member's value is the
    lowercase string that ``CompressionConfig.to_json`` writes out for the
    ``compression_type`` field.
    """
    NONE = "none"                        # no compression selected
    SPG = "spg"
    ADAPTIVE_SPG = "adaptive_spg"
    ENHANCED_SPG = "enhanced_spg"        # default used by CompressionConfig
    PROGRESSIVE_SPG = "progressive_spg"


class PrecisionLevel(NamedTuple):
    """One tier of a precision schedule: a threshold paired with a bit width and a label.

    This is a plain record; it does not validate its fields. The default
    schedules in ``ResearchConstants`` list tiers from the highest threshold
    down to 0.0.
    """
    threshold: float      # cut-off for this tier; the defaults use values in [0, 1]
    bits: Optional[int]   # bit width of the tier (8 / 4 in the defaults); None for the "fp16" tier
    name: str             # label for the tier, e.g. "fp16", "int8", "int4"


@dataclass
class ResearchConstants:
    """Default thresholds, bounds and precision schedules shared by the config classes.

    This is an ordinary (mutable) dataclass. ``EnhancedSPGConfig`` and
    ``CompressionConfig`` create a fresh instance whenever they need a value,
    so editing one instance does not affect them. The two precision schedules
    use ``default_factory`` so every instance gets its own list.
    """
    # Magnitude-based importance thresholds, selected by
    # EnhancedSPGConfig.magnitude_threshold_mode.
    MAGNITUDE_THRESHOLD_CONSERVATIVE: float = 0.99   # Top 1%
    MAGNITUDE_THRESHOLD_AGGRESSIVE: float = 0.995    # Top 0.5%
    MAGNITUDE_THRESHOLD_EXTREME: float = 0.999       # Top 0.1%
    
    # Layer-specific retention bounds
    EARLY_LAYER_MAX_RETENTION: float = 0.02   # 2% max for early layers
    LATE_LAYER_MAX_RETENTION: float = 0.035   # 3.5% max for late layers
    
    # RocketKV-style compression parameters; the two head-retention values are
    # selected by EnhancedSPGConfig.head_retention_mode.
    HEAD_RETENTION_AGGRESSIVE: float = 0.35   # Keep 35% of heads
    HEAD_RETENTION_CONSERVATIVE: float = 0.6  # Keep 60% of heads
    POSITION_BOOST_SINK: float = 3.0          # 3x boost for sink tokens
    POSITION_BOOST_RECENT: float = 2.0        # 2x boost for recent tokens
    
    # Adaptive decomposition parameters
    SPARSE_STAGE1_POWER: float = 0.75         # More compression in Stage 1
    BALANCED_STAGE1_POWER: float = 0.5        # Balanced split
    DENSE_STAGE1_POWER: float = 0.25          # Less compression in Stage 1
    SPARSITY_HIGH_THRESHOLD: float = 0.8      # Threshold for highly sparse
    SPARSITY_MEDIUM_THRESHOLD: float = 0.5    # Threshold for moderately sparse
    
    # Attention sparsity estimation
    ATTENTION_SPARSITY_THRESHOLD: float = 0.1  # Threshold for near-zero weights
    
    # Quality monitoring window sizes
    QUALITY_HISTORY_MAX_SIZE: int = 50
    PROGRESSIVE_QUALITY_WINDOW: int = 10
    PROGRESSIVE_RECENT_WINDOW: int = 5
    
    # Memory overhead figures, in bytes
    METADATA_OVERHEAD_BYTES: int = 256
    INDEX_SIZE_BYTES: int = 4  # int32 per index
    INT2_METADATA_BYTES: int = 24  # Overhead for INT2 packing
    
    # Compression ratio bounds.
    # NOTE(review): EnhancedSPGConfig.stage_compression_max defaults to 500.0,
    # not the 150.0 below - confirm which bound consumers are meant to use.
    STAGE_COMPRESSION_MIN: float = 2.0         # Minimum stage compression
    STAGE_COMPRESSION_MAX: float = 150.0       # Maximum stage compression
    
    # Stability parameters.
    # NOTE(review): the same three defaults are repeated as fields on
    # EnhancedSPGConfig (min_tokens_for_stability, recent_boost_factor,
    # progressive_min_ratio); keep them in sync.
    MIN_TOKENS_FOR_STABILITY: int = 4          # Minimum tokens for seq_budget
    RECENT_BOOST_FACTOR: float = 0.1           # Boost factor for recent tokens
    PROGRESSIVE_MIN_RATIO: float = 0.0001      # Minimum ratio to prevent division by zero
    
    # Sequence-length boundaries read by EnhancedSPGConfig.get_adaptive_kernel_size
    KERNEL_SIZE_SMALL_THRESHOLD: int = 1024    # Small sequence threshold
    KERNEL_SIZE_MEDIUM_THRESHOLD: int = 4096   # Medium sequence threshold
    KERNEL_SIZE_LARGE_THRESHOLD: int = 16384   # Large sequence threshold
    
    # Precision schedule used when EnhancedSPGConfig selects aggressive precision.
    # Tiers run from the highest threshold down to 0.0.
    DEFAULT_PRECISION_LEVELS_AGGRESSIVE: List[PrecisionLevel] = field(default_factory=lambda: [
        PrecisionLevel(0.99999, None, "fp16"),   # Ultra-selective FP16 (0.001%)
        PrecisionLevel(0.9995, 8, "int8"),       # High importance INT8 (0.049%)
        PrecisionLevel(0.996, 4, "int4"),        # Medium importance INT4 (0.35%)
        PrecisionLevel(0.0, 4, "int4")           # INT4 floor: everything else is kept at 4 bits, not discarded
    ])
    
    # Precision schedule used otherwise; same shape with one extra INT4 tier.
    DEFAULT_PRECISION_LEVELS_STANDARD: List[PrecisionLevel] = field(default_factory=lambda: [
        PrecisionLevel(0.99995, None, "fp16"),   # Ultra-selective FP16
        PrecisionLevel(0.9999, 8, "int8"),       # High importance INT8
        PrecisionLevel(0.999, 4, "int4"),        # Medium importance INT4
        PrecisionLevel(0.995, 4, "int4"),        # INT4 tier
        PrecisionLevel(0.0, 4, "int4")           # INT4 floor: everything else is kept at 4 bits, not discarded
    ])
    
    # Validation bounds. In the code visible here, the sequence-length,
    # eval-sample and compression-ratio pairs are read by the config classes'
    # __post_init__ methods; MIN_LAYERS / MAX_LAYERS are not read in this view.
    MIN_LAYERS: int = 1
    MAX_LAYERS: int = 200
    MIN_SEQUENCE_LENGTH: int = 16
    MAX_SEQUENCE_LENGTH: int = 32768
    MIN_EVAL_SAMPLES: int = 1
    MAX_EVAL_SAMPLES: int = 1000
    MIN_COMPRESSION_RATIO: float = 1.0
    MAX_COMPRESSION_RATIO: float = 1000.0


@dataclass
class EnhancedSPGConfig:
    """Tunable parameters for Enhanced SPG compression, validated on construction.

    Construction runs ``__post_init__``, which raises ``ValueError`` for any
    out-of-range field and fills ``precision_levels`` with one of the default
    schedules from ``ResearchConstants`` when the caller supplied none.
    Validation happens only at construction time; fields reassigned afterwards
    are not re-checked.
    """
    # Private thresholds used by __post_init__. They carry no annotation, so
    # the dataclass machinery does not treat them as fields (they do not
    # appear in __init__, asdict() or comparisons).
    _EXTREME_RATIO_WARNING = 500.0        # warn when the target ratio exceeds this
    _AGGRESSIVE_PRECISION_RATIO = 400.0   # targets at/above this force the aggressive schedule
    _MAGNITUDE_MODES = ("conservative", "aggressive", "extreme")
    _HEAD_RETENTION_MODES = ("aggressive", "conservative")

    # Core SPG parameters with validation
    base_decay_rate: float = 0.95
    decay_normalization: int = 64
    sink_tokens: int = 0
    recent_window: int = 24             # Keep last 24 tokens uncompressed for stability
    recent_min_precision: float = 1.0   # Full precision for recent tokens

    # Multi-stage parameters
    enable_two_stage: bool = True
    stage1_compression_ratio: float = 20.0
    stage2_compression_ratio: float = 20.0

    # RocketKV-style parameters for 450x compression
    target_compression_ratio: float = 450.0  # Target 450x compression
    use_adaptive_decomposition: bool = True   # Adaptive stage splitting
    use_hybrid_sparse_attention: bool = True  # HSA for Stage 2
    use_snapkv_plus_plus: bool = True         # SnapKV++ for Stage 1

    # Multi-dimensional compression
    enable_head_compression: bool = True
    sequence_compression_ratio: float = 0.00015  # 0.015%
    head_compression_ratio: float = 0.00015      # 0.015%
    head_retention_mode: str = "aggressive"      # aggressive/conservative
    head_fp16_reserve: int = 2                   # Reserve top 2 heads per layer at FP16

    # Magnitude-based parameters
    magnitude_page_size: int = 64
    magnitude_threshold_mode: str = "extreme"   # conservative/aggressive/extreme

    # Progressive compression
    enable_progressive: bool = False
    initial_compression_ratio: float = 100.0  # Starting ratio
    max_compression_ratio: float = 450.0      # Target compression
    quality_threshold: float = 0.01           # 1% degradation threshold
    progression_steps: int = 6                # Number of progression steps
    progression_factor: float = 1.15          # 15% increase per step
    quality_feedback_frequency: int = 16      # Quality feedback frequency

    # Hardware optimization flags
    page_aligned_storage: bool = True
    use_custom_kernels: bool = False  # Disabled until implemented
    memory_layout_optimization: bool = True

    # Precision levels; left empty to have __post_init__ pick a default schedule
    precision_levels: List[PrecisionLevel] = field(default_factory=list)
    use_aggressive_precision: bool = True  # Use the aggressive default schedule

    # Adaptive parameters
    enable_adaptive: bool = False
    target_perplexity_delta: float = 1.8
    decay_adjustment_rate: float = 0.015
    per_layer_decay: bool = True

    # Performance optimization
    vectorized: bool = True
    block_size: int = 64

    # Kernel sizes returned by get_adaptive_kernel_size
    kernel_size_small_seq: int = 4     # For seq_len < small threshold
    kernel_size_medium_seq: int = 8    # For seq_len < medium threshold
    kernel_size_large_seq: int = 16    # For seq_len < large threshold
    kernel_size_xlarge_seq: int = 32   # For seq_len >= large threshold

    # Stability and boost parameters
    min_tokens_for_stability: int = 4      # Minimum tokens for seq_budget
    recent_boost_factor: float = 0.1       # Boost factor for recent tokens
    progressive_min_ratio: float = 0.0001  # Minimum ratio to prevent division by zero

    # Compression bounds
    stage_compression_min: float = 2.0    # Minimum stage compression ratio
    stage_compression_max: float = 500.0  # Maximum stage compression ratio

    def __post_init__(self):
        """Validate the fields and pick a default precision schedule.

        Checks run in a fixed order and the first violation is reported.

        Raises:
            ValueError: if any checked field is outside its allowed range.
        """
        constants = ResearchConstants()

        if not 0.5 <= self.base_decay_rate <= 0.99:
            raise ValueError(f"base_decay_rate must be in [0.5, 0.99], got {self.base_decay_rate}")
        if self.decay_normalization <= 0:
            raise ValueError(f"decay_normalization must be positive, got {self.decay_normalization}")
        if self.sink_tokens < 0:
            raise ValueError(f"sink_tokens must be non-negative, got {self.sink_tokens}")
        if self.recent_window < 0:
            raise ValueError(f"recent_window must be non-negative, got {self.recent_window}")
        if not 0.0 <= self.recent_min_precision <= 1.0:
            raise ValueError(f"recent_min_precision must be in [0,1], got {self.recent_min_precision}")

        if self.stage1_compression_ratio <= 1.0:
            raise ValueError(f"stage1_compression_ratio must be > 1.0, got {self.stage1_compression_ratio}")
        if self.stage2_compression_ratio <= 1.0:
            raise ValueError(f"stage2_compression_ratio must be > 1.0, got {self.stage2_compression_ratio}")

        # RocketKV validation
        if not constants.MIN_COMPRESSION_RATIO <= self.target_compression_ratio <= constants.MAX_COMPRESSION_RATIO:
            raise ValueError(f"target_compression_ratio must be in [{constants.MIN_COMPRESSION_RATIO}, {constants.MAX_COMPRESSION_RATIO}], got {self.target_compression_ratio}")
        if self.target_compression_ratio > self._EXTREME_RATIO_WARNING:
            # Allowed, but flagged: only a warning, never an error.
            logger.warning(
                "target_compression_ratio %s is extremely high - quality may degrade",
                self.target_compression_ratio,
            )

        if not 0.0 < self.sequence_compression_ratio <= 1.0:
            raise ValueError(f"sequence_compression_ratio must be in (0,1], got {self.sequence_compression_ratio}")
        if not 0.0 < self.head_compression_ratio <= 1.0:
            raise ValueError(f"head_compression_ratio must be in (0,1], got {self.head_compression_ratio}")

        if self.magnitude_threshold_mode not in self._MAGNITUDE_MODES:
            raise ValueError(f"magnitude_threshold_mode must be conservative/aggressive/extreme, got {self.magnitude_threshold_mode}")

        if self.head_retention_mode not in self._HEAD_RETENTION_MODES:
            raise ValueError(f"head_retention_mode must be aggressive/conservative, got {self.head_retention_mode}")

        # Validate configurable parameters
        if self.quality_feedback_frequency <= 0:
            raise ValueError(f"quality_feedback_frequency must be positive, got {self.quality_feedback_frequency}")
        if self.min_tokens_for_stability <= 0:
            raise ValueError(f"min_tokens_for_stability must be positive, got {self.min_tokens_for_stability}")
        if not 0.0 <= self.recent_boost_factor <= 1.0:
            raise ValueError(f"recent_boost_factor must be in [0,1], got {self.recent_boost_factor}")
        if self.progressive_min_ratio <= 0:
            raise ValueError(f"progressive_min_ratio must be positive, got {self.progressive_min_ratio}")

        # Fields that were previously accepted unchecked. They run after the
        # original checks so existing error messages keep their precedence.
        if self.head_fp16_reserve < 0:
            raise ValueError(f"head_fp16_reserve must be non-negative, got {self.head_fp16_reserve}")
        for size_field in (
            "magnitude_page_size",
            "kernel_size_small_seq",
            "kernel_size_medium_seq",
            "kernel_size_large_seq",
            "kernel_size_xlarge_seq",
        ):
            size_value = getattr(self, size_field)
            if size_value <= 0:
                raise ValueError(f"{size_field} must be positive, got {size_value}")
        if not 0.0 < self.stage_compression_min <= self.stage_compression_max:
            raise ValueError(
                "stage_compression_min must be positive and not exceed stage_compression_max, "
                f"got min={self.stage_compression_min}, max={self.stage_compression_max}"
            )

        # Set precision levels based on compression aggressiveness; a
        # caller-supplied non-empty list is left untouched.
        if not self.precision_levels:
            if self.use_aggressive_precision or self.target_compression_ratio >= self._AGGRESSIVE_PRECISION_RATIO:
                self.precision_levels = list(constants.DEFAULT_PRECISION_LEVELS_AGGRESSIVE)
                logger.info("Using aggressive precision levels for high compression")
            else:
                self.precision_levels = list(constants.DEFAULT_PRECISION_LEVELS_STANDARD)
                logger.info("Using standard precision levels")

        logger.info(
            "Enhanced SPG config validated successfully (target: %sx)",
            self.target_compression_ratio,
        )

    def get_magnitude_threshold(self) -> float:
        """Return the magnitude threshold for the current ``magnitude_threshold_mode``.

        Raises:
            KeyError: if the mode was changed to an unknown value after construction.
        """
        # Scalar defaults are readable from the class, so no instance is built.
        thresholds = {
            "conservative": ResearchConstants.MAGNITUDE_THRESHOLD_CONSERVATIVE,
            "aggressive": ResearchConstants.MAGNITUDE_THRESHOLD_AGGRESSIVE,
            "extreme": ResearchConstants.MAGNITUDE_THRESHOLD_EXTREME,
        }
        return thresholds[self.magnitude_threshold_mode]

    def get_head_retention_ratio(self) -> float:
        """Return the head retention ratio for the current ``head_retention_mode``.

        Raises:
            KeyError: if the mode was changed to an unknown value after construction.
        """
        ratios = {
            "aggressive": ResearchConstants.HEAD_RETENTION_AGGRESSIVE,
            "conservative": ResearchConstants.HEAD_RETENTION_CONSERVATIVE,
        }
        return ratios[self.head_retention_mode]

    def get_adaptive_kernel_size(self, seq_len: int) -> int:
        """Return the configured kernel size for a sequence of length ``seq_len``.

        The length is compared against the small / medium / large thresholds
        from ``ResearchConstants`` (upper bounds are exclusive); anything at
        or beyond the large threshold gets ``kernel_size_xlarge_seq``.
        """
        if seq_len < ResearchConstants.KERNEL_SIZE_SMALL_THRESHOLD:
            return self.kernel_size_small_seq
        if seq_len < ResearchConstants.KERNEL_SIZE_MEDIUM_THRESHOLD:
            return self.kernel_size_medium_seq
        if seq_len < ResearchConstants.KERNEL_SIZE_LARGE_THRESHOLD:
            return self.kernel_size_large_seq
        return self.kernel_size_xlarge_seq


@dataclass
class ProvingConfig:
    """Settings for proof generation and verification, range-checked on construction."""
    enabled: bool = True
    numeric_tolerance: float = 0.01     # must lie strictly between 0 and 1
    time_tolerance_ms: float = 0.5      # timing tolerance in milliseconds
    ppl_tolerance: float = 0.1          # relative perplexity tolerance (0.1 = 10%)
    comp_ratio_floor: float = 0.90      # minimum fraction of the target ratio that must be achieved
    require_cuda: bool = True           # counterpart of CompressionConfig.fail_on_cpu_fallback
    verify_recompute: bool = True       # recompute the summary from records and compare
    export_per_sample: bool = True      # export detailed per-sample records
    export_fingerprints: bool = True    # export KV cache fingerprints
    
    def __post_init__(self):
        """Range-check the tolerance fields; the first violation raises ValueError."""
        # (field name, acceptance test, rule text) - evaluated lazily, in this order.
        rules = (
            ("numeric_tolerance", lambda v: 0 < v < 1, "numeric_tolerance must be in (0, 1)"),
            ("comp_ratio_floor", lambda v: 0 < v <= 1, "comp_ratio_floor must be in (0, 1]"),
            ("time_tolerance_ms", lambda v: not v <= 0, "time_tolerance_ms must be positive"),
            ("ppl_tolerance", lambda v: 0 < v < 1, "ppl_tolerance must be in (0, 1)"),
        )
        for attr, accepts, rule in rules:
            current = getattr(self, attr)
            if not accepts(current):
                raise ValueError(f"{rule}, got {current}")


@dataclass
class CompressionConfig:
    """Top-level experiment configuration for RocketKV-enhanced SPG runs.

    Bundles the compression method, the nested ``EnhancedSPGConfig`` and
    ``ProvingConfig`` (each validated by its own ``__post_init__`` when the
    default factories build them), evaluation and dataset settings, and a
    snapshot of the runtime environment captured at construction time.
    """
    # Core settings
    compression_type: CompressionType = CompressionType.ENHANCED_SPG
    seed: int = 42

    # Enhanced SPG configuration
    enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)

    # Proving configuration
    proving: ProvingConfig = field(default_factory=ProvingConfig)

    # Evaluation settings
    eval_samples: int = 50
    prefill_length: int = 512
    generation_length: int = 64
    batch_size: int = 1
    warmup_steps: int = 3
    n_seeds: int = 3

    # Statistical validation
    n_bootstrap: int = 500
    confidence_level: float = 0.95

    # Dataset configuration
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    dataset_split: str = "test"

    # Model configuration
    model_name: str = "gpt2"
    test_sequence_lengths: List[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384])
    downstream_tasks: List[str] = field(default_factory=lambda: ["perplexity", "gsm8k", "mmlu"])
    baseline_methods: List[str] = field(default_factory=lambda: ["h2o", "streamingllm", "snapkv"])

    # Memory and system settings
    clear_cache_between_runs: bool = True
    use_memory_snapshot: bool = True
    fail_on_cpu_fallback: bool = True

    # Output settings
    generate_latex: bool = True
    save_intermediate_results: bool = True

    # System info, captured when the instance is created
    torch_version: str = field(default_factory=lambda: torch.__version__)
    transformers_version: str = field(default_factory=lambda: transformers.__version__)
    cuda_version: str = field(default_factory=lambda: torch.version.cuda if torch.cuda.is_available() else "cpu")
    device_name: str = field(default_factory=lambda: torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu")
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def __post_init__(self):
        """Validate the top-level fields.

        Hard errors are raised for ``seed``, ``generation_length`` and
        ``confidence_level``; values that are merely outside a recommended
        range (``eval_samples``, ``prefill_length``, ``n_seeds``,
        ``n_bootstrap``) only log a warning.

        Raises:
            ValueError: if ``seed`` is not a non-negative int,
                ``generation_length`` is not positive, or ``confidence_level``
                is outside [0.5, 1.0).
        """
        constants = ResearchConstants()

        # Validate core parameters
        if not isinstance(self.seed, int) or self.seed < 0:
            raise ValueError(f"seed must be non-negative integer, got {self.seed}")

        # Validate evaluation parameters
        if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
            logger.warning(
                "eval_samples %s outside recommended range [%s, %s]",
                self.eval_samples, constants.MIN_EVAL_SAMPLES, constants.MAX_EVAL_SAMPLES,
            )

        if not constants.MIN_SEQUENCE_LENGTH <= self.prefill_length <= constants.MAX_SEQUENCE_LENGTH:
            logger.warning(
                "prefill_length %s outside range [%s, %s]",
                self.prefill_length, constants.MIN_SEQUENCE_LENGTH, constants.MAX_SEQUENCE_LENGTH,
            )

        if self.generation_length <= 0:
            raise ValueError(f"generation_length must be positive, got {self.generation_length}")

        if not 1 <= self.n_seeds <= 10:
            logger.warning("n_seeds %s outside recommended range [1, 10]", self.n_seeds)

        # Validate statistical parameters
        if not 0.5 <= self.confidence_level < 1.0:
            raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")

        if not 100 <= self.n_bootstrap <= 10000:
            logger.warning("n_bootstrap %s outside recommended range [100, 10000]", self.n_bootstrap)

        logger.info("RocketKV-enhanced SPG config validated successfully")

    def _serializable_dict(self) -> dict:
        """Return ``asdict(self)`` with ``compression_type`` replaced by its string value."""
        config_dict = asdict(self)
        config_dict['compression_type'] = self.compression_type.value
        return config_dict

    def to_json(self) -> str:
        """Export the full config, including system info and timestamp, as indented JSON.

        Values that ``json`` cannot encode natively are stringified via ``default=str``.
        """
        return json.dumps(self._serializable_dict(), indent=2, default=str)

    def get_hash(self) -> str:
        """Return an 8-hex-digit digest identifying this configuration, for caching.

        The creation ``timestamp`` is left out of the hashed payload: it
        differs for every instance, so including it would give two otherwise
        identical configurations different hashes and no cache entry could
        ever be reused. Versions and device name stay in, so the digest still
        changes with the environment. Digests therefore differ from those
        computed while the timestamp was part of the payload.

        MD5 is used only as a short fingerprint, not for security.
        """
        payload = self._serializable_dict()
        payload.pop("timestamp", None)
        encoded = json.dumps(payload, indent=2, default=str).encode()
        return hashlib.md5(encoded).hexdigest()[:8]