| | """ |
| | Test script to verify 250K context length support |
| | Tests RoPE scaling and long context handling |
| | """ |
| |
|
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig |
| | import logging |
| | from typing import Optional |
| | import time |
| |
|
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class LongContextTester:
    """Test long context capabilities of Helion-OSC"""

    def __init__(self, model_path: str = "./inference"):
        """
        Initialize tester

        Args:
            model_path: Path to model inference directory
        """
        self.model_path = model_path
        logger.info("Loading model configuration...")

        self.config = AutoConfig.from_pretrained(model_path)

        # Check the configured context window
        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")

        if max_pos < 250000:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")

        # Report RoPE settings if the config exposes them
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)

        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")

    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that tokenizer supports long sequences"""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 1: Tokenizer Capacity")
        logger.info("=" * 80)

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")

            if max_length >= 250000:
                logger.info("✓ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"✗ Tokenizer max length only {max_length:,}")

            # Build a synthetic long document; "Hello world! " is a few
            # tokens per repetition, so this yields roughly 10K tokens
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)

            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]

            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
            logger.info("✓ Tokenization test passed")

            return True

        except Exception as e:
            logger.error(f"✗ Tokenization test failed: {e}")
            return False

    def test_position_embeddings(self):
        """Test position embedding capacity"""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 2: Position Embeddings")
        logger.info("=" * 80)

        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size

        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")

        if getattr(self.config, 'rope_theta', None) is not None:
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("✓ RoPE scales efficiently to long contexts")
            logger.info(f"RoPE Theta: {self.config.rope_theta:,}")

            # rope_scaling is often present but set to None, so guard the
            # attribute before calling .get(); newer configs use 'rope_type'
            scaling = getattr(self.config, 'rope_scaling', None)
            if scaling:
                logger.info("RoPE Scaling Configuration:")
                logger.info(f"  Type: {scaling.get('rope_type', scaling.get('type', 'N/A'))}")
                logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")

                if scaling.get('factor', 0) >= 32:
                    logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
                else:
                    logger.warning("✗ RoPE scaling factor may be insufficient")

            return True
        else:
            # Learned absolute embeddings: a max_pos x hidden_size table
            pos_emb_size = max_pos * hidden_size * 2  # 2 bytes per fp16 value
            pos_emb_gb = pos_emb_size / (1024 ** 3)
            logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")

            if max_pos >= 250000:
                logger.info("✓ Sufficient position embeddings for 250K context")
                return True
            else:
                logger.warning("✗ Insufficient position embeddings")
                return False

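    def effective_context_estimate(self) -> Optional[int]:
        """
        Illustrative helper (an addition, not part of the original test plan):
        estimate the effective context window implied by the RoPE scaling
        config. Assumes the common Hugging Face convention where the extended
        window is original_max_position_embeddings * factor; treat the result
        as a rough sanity check, not a guarantee.
        """
        scaling = getattr(self.config, 'rope_scaling', None)
        if not scaling:
            return None
        base = scaling.get('original_max_position_embeddings',
                           self.config.max_position_embeddings)
        return int(base * scaling.get('factor', 1.0))
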
    def test_attention_computation(self, sequence_lengths: Optional[list] = None):
        """Test attention computation at various lengths"""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 3: Attention Computation Scaling")
        logger.info("=" * 80)

        # Avoid a mutable default argument
        if sequence_lengths is None:
            sequence_lengths = [1024, 8192, 32768, 131072]

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")

        for seq_len in sequence_lengths:
            # Full attention materialises a seq_len x seq_len score matrix
            # per head (batch 1, fp16 = 2 bytes per value)
            attn_size = 1 * num_heads * seq_len * seq_len * 2
            attn_gb = attn_size / (1024 ** 3)

            logger.info(f"\nSequence length: {seq_len:,} tokens")
            logger.info(f"  Attention matrix: {attn_gb:.2f} GB")

            if seq_len <= 32768:
                logger.info("  ✓ Manageable size")
            elif seq_len <= 131072:
                logger.info("  ⚠ Large - may need Flash Attention")
            else:
                logger.info("  ⚠ Very large - requires optimizations")

        use_flash = getattr(self.config, 'use_flash_attention_2', False)
        if use_flash:
            logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts")
        else:
            logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")

        return True

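    def flash_attention_memory_estimate(self, seq_len: int) -> float:
        """
        Sketch for comparison (an addition, not from the original script):
        Flash Attention computes scores block-wise and never materialises the
        full seq_len x seq_len matrix, so attention activation memory grows
        linearly with sequence length. This estimates just the Q, K, V and
        output tensors at fp16 for batch size 1.
        """
        hidden_size = self.config.hidden_size
        # 4 tensors (Q, K, V, O) of shape [seq_len, hidden_size], 2 bytes each
        qkvo_bytes = 4 * seq_len * hidden_size * 2
        return qkvo_bytes / (1024 ** 3)
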
    def test_memory_requirements(self):
        """Calculate memory requirements for 250K context"""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 4: Memory Requirements")
        logger.info("=" * 80)

        context_length = 250000
        batch_size = 1
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers

        logger.info("Configuration:")
        logger.info(f"  Context: {context_length:,} tokens")
        logger.info(f"  Batch size: {batch_size}")
        logger.info(f"  Hidden size: {hidden_size:,}")
        logger.info(f"  Layers: {num_layers}")

        # Hidden states at fp16 (2 bytes per value)
        hidden_states_size = batch_size * context_length * hidden_size * 2
        hidden_states_gb = hidden_states_size / (1024 ** 3)

        # Rough x2 per layer for attention and MLP intermediates
        layer_memory_gb = hidden_states_gb * 2
        total_activation_gb = layer_memory_gb * num_layers

        logger.info("\nMemory estimates:")
        logger.info(f"  Hidden states per layer: {hidden_states_gb:.2f} GB")
        logger.info(f"  Total activation memory: {total_activation_gb:.2f} GB")
        logger.info("  Model weights: ~349 GB")
        logger.info(f"  Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")

        logger.info("\nRecommendations:")
        if total_activation_gb < 50:
            logger.info("  ✓ Should fit on 8x A100 (80GB) GPUs")
        elif total_activation_gb < 100:
            logger.info("  ⚠ May need gradient checkpointing")
        else:
            logger.info("  ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")

        return True

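    def kv_cache_estimate(self, context_length: int = 250000, batch_size: int = 1) -> float:
        """
        Supplementary estimate (an addition to the original tests): size of
        the KV cache for autoregressive inference at the given context length.
        Assumes an fp16 cache and grouped-query attention when the config
        exposes num_key_value_heads, falling back to full multi-head attention.
        """
        num_kv_heads = getattr(self.config, 'num_key_value_heads',
                               self.config.num_attention_heads)
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        # 2 tensors (K and V) per layer, 2 bytes per fp16 value
        cache_bytes = (batch_size * context_length * self.config.num_hidden_layers
                       * 2 * num_kv_heads * head_dim * 2)
        return cache_bytes / (1024 ** 3)
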
    def test_rope_frequencies(self):
        """Test RoPE frequency calculations for long context"""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 5: RoPE Frequency Analysis")
        logger.info("=" * 80)

        rope_theta = getattr(self.config, 'rope_theta', 10000)
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")

        # RoPE frequencies: theta^(-2i / head_dim) for i in [0, head_dim / 2)
        min_freq = rope_theta ** (-2 * (head_dim // 2 - 1) / head_dim)
        max_freq = rope_theta ** 0
        logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")

        # Sample wavelengths (2 * pi / frequency), including the longest one
        # so the 250K check below tests the true maximum
        indices = list(range(0, head_dim // 2, head_dim // 8)) + [head_dim // 2 - 1]
        wavelengths = [2 * math.pi * (rope_theta ** (2 * i / head_dim))
                       for i in indices]

        logger.info("\nWavelengths (in tokens):")
        for i, wl in zip(indices, wavelengths):
            logger.info(f"  Frequency {i}: {wl:,.0f} tokens")

        max_wavelength = max(wavelengths)
        if max_wavelength >= 250000:
            logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
        else:
            logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")

        return True

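    def rope_inverse_frequencies(self) -> torch.Tensor:
        """
        Illustration (a sketch added for reference, assuming the standard
        Llama-style RoPE formulation used by most Hugging Face models):
        inv_freq[i] = theta^(-2i / head_dim) for pair index i in
        [0, head_dim / 2). The reciprocal of each entry, times 2 * pi, is the
        wavelength in tokens analysed above.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        theta = getattr(self.config, 'rope_theta', 10000)
        return 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
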
    def run_all_tests(self):
        """Run all context length tests"""
        logger.info("\n" + "=" * 80)
        logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        logger.info("=" * 80)

        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies()
        }

        logger.info("\n" + "=" * 80)
        logger.info("TEST SUMMARY")
        logger.info("=" * 80)

        for test_name, passed in results.items():
            status = "✓ PASS" if passed else "✗ FAIL"
            logger.info(f"{test_name}: {status}")

        all_passed = all(results.values())

        if all_passed:
            logger.info("\n✓ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")

        return all_passed


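def smoke_test_generation(model_path: str, prompt_tokens: int = 100000) -> None:
    """
    Optional end-to-end smoke test (a sketch added here, not wired into the
    CLI below): actually loading Helion-OSC takes hundreds of GB of GPU
    memory, so this is only worth running on a large multi-GPU node.
    device_map="auto" shards the weights across available GPUs via accelerate,
    and the prompt_tokens target is approximate ("Hello world! " is a few
    tokens per repetition).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )
    prompt = "Hello world! " * (prompt_tokens // 3)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=32)
    new_tokens = output[0][inputs['input_ids'].shape[1]:]
    logger.info(f"Generated continuation: {tokenizer.decode(new_tokens)}")

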
def main():
    """Main test script"""
    import argparse

    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory"
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run"
    )

    args = parser.parse_args()

    tester = LongContextTester(args.model_path)

    if args.test == "all":
        tester.run_all_tests()
    elif args.test == "tokenization":
        tester.test_tokenization_capacity()
    elif args.test == "position":
        tester.test_position_embeddings()
    elif args.test == "attention":
        tester.test_attention_computation()
    elif args.test == "memory":
        tester.test_memory_requirements()
    elif args.test == "rope":
        tester.test_rope_frequencies()


if __name__ == "__main__":
    main()