#!/usr/bin/env python3 """ BREAKTHROUGH BitTransformerLM Training Script =========================================== Using the ACTUAL BitTransformerLM model and training infrastructure, configured for the Fixed RL Adafactor breakthrough results. """ import sys import os import logging from pathlib import Path import torch from datasets import load_dataset from huggingface_hub import login # Add paths for imports sys.path.append('/data') sys.path.append('/data/BitTransformerLM') from bit_transformer import ( BitTransformerLM, text_to_bits, train_loop, save_model, load_model, set_dropout ) from BTLM_Extensions import configure_adafactor_optimizer # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('breakthrough_training.log'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) def load_and_prepare_dataset(): """Load HF dataset and convert to bit tensors.""" logger.info("Loading WCNegentropy/BitTransformerLM dataset...") # Login to HuggingFace hf_token = os.getenv('HF_TOKEN') if hf_token: login(token=hf_token) else: print("Warning: HF_TOKEN environment variable not set") # Load dataset dataset = load_dataset("WCNegentropy/BitTransformerLM") train_data = dataset['train'] logger.info(f"Dataset loaded: {len(train_data)} samples") # Process dataset - the HF dataset already has bit_sequence field! bit_sequences = [] for sample in train_data: if 'bit_sequence' in sample and sample['bit_sequence'] is not None: # The bit_sequence might already be a list bits = sample['bit_sequence'] if isinstance(bits, str): try: bits = eval(bits) # Convert string representation to list except: bits = None if isinstance(bits, list) and len(bits) > 0: bit_sequences.append(bits) else: # Fallback: convert original_text to bits text = sample.get('original_text', '') if text: bits = text_to_bits(text) bit_sequences.append(bits) else: # Fallback: convert text to bits text = sample.get('text', '') or sample.get('original_text', '') if text: bits = text_to_bits(text) bit_sequences.append(bits) logger.info(f"Processed {len(bit_sequences)} bit sequences") # Create training tensors with proper sequence length max_len = 512 # BitTransformerLM default max_seq_len training_sequences = [] for bits in bit_sequences: # Split long sequences into chunks for i in range(0, len(bits) - max_len + 1, max_len // 2): seq = bits[i:i + max_len] if len(seq) == max_len: # Only use full-length sequences training_sequences.append(seq) # Convert to tensor data_tensor = torch.tensor(training_sequences, dtype=torch.long) logger.info(f"Created training tensor: {data_tensor.shape}") return data_tensor def create_breakthrough_model(): """Create the EXACT breakthrough BitTransformerLM configuration.""" logger.info("Creating breakthrough BitTransformerLM model...") # EXACT breakthrough configuration using ACTUAL BitTransformerLM parameters model = BitTransformerLM( d_model=512, # Breakthrough config nhead=16, # 16 attention heads num_layers=8, # 8 layers for ~16M params dim_feedforward=1024, # 2x d_model max_seq_len=512, # Match data preparation reversible=True, # Memory efficiency use_checkpoint=True, # Gradient checkpointing use_autocast=True, # Mixed precision use_act=True, # Adaptive Computation Time act_threshold=0.9, lambda_K=0.05, # Safety telemetry weights lambda_C=0.05, lambda_S=0.05 ) # Calculate parameter count total_params = sum(p.numel() for p in model.parameters()) logger.info(f"Model created: {total_params:,} parameters") 
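
# --- Illustrative fallback (assumption, not part of BTLM_Extensions) ---------
# The `configure_adafactor_optimizer` helper imported above lives in an external
# BTLM_Extensions module that is not shown here. As a minimal sketch of the
# "fixed learning rate Adafactor" setup this script relies on, the hypothetical
# helper below mirrors that call signature using HuggingFace's
# `transformers.Adafactor` and a constant schedule. It is not the actual
# implementation and is included only for reference.
def _fallback_configure_adafactor_optimizer(model, lr=1e-3, weight_decay=0.01,
                                            total_steps=5000):
    """Hypothetical stand-in mirroring BTLM_Extensions.configure_adafactor_optimizer."""
    from transformers.optimization import Adafactor
    from torch.optim.lr_scheduler import LambdaLR

    optimizer = Adafactor(
        model.parameters(),
        lr=lr,                    # fixed LR requires disabling the relative-step schedule
        weight_decay=weight_decay,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    )
    _ = total_steps  # retained only to match the assumed call signature
    # Constant multiplier keeps the learning rate fixed for the whole run.
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    return optimizer, scheduler
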
logger.info(f"Target: ~16M parameters - {'✓' if 15_000_000 <= total_params <= 17_000_000 else '✗'}") return model def main(): """Main training function.""" logger.info("🚀 STARTING BREAKTHROUGH BITRANSFORMERLM TRAINING!") logger.info("Using ACTUAL BitTransformerLM model and train_loop") # Load dataset data = load_and_prepare_dataset() # Create model model = create_breakthrough_model() # CRITICAL: Use Fixed RL Adafactor (the breakthrough secret!) logger.info("Configuring Fixed RL Adafactor optimizer...") optimizer, scheduler = configure_adafactor_optimizer( model, lr=1e-3, # FIXED learning rate - key to breakthrough! weight_decay=0.01, total_steps=5000 # Estimated total steps ) logger.info("Fixed RL Adafactor configured with LR=0.001") # Training configuration training_config = { 'epochs': 20, # Reasonable number of epochs 'batch_size': 4, # Adjust based on memory 'accum_steps': 4, # Gradient accumulation 'amp': True, # Mixed precision 'log': True, # Enable logging 'compress_prob': 0.0, # Start with no compression 'optimizer': optimizer, 'scheduler': scheduler } logger.info(f"Training configuration: {training_config}") logger.info("Starting training loop...") # Use the ACTUAL BitTransformerLM train_loop function! metrics = train_loop( model=model, data=data, **training_config ) # Save the trained model checkpoint_dir = Path('/data/BitTransformerLM/checkpoints') checkpoint_dir.mkdir(exist_ok=True) model_path = checkpoint_dir / 'breakthrough_model.pt' save_model(model, model_path) logger.info(f"Model saved to: {model_path}") # Log final metrics if metrics: final_metrics = metrics[-1] logger.info("🎉 TRAINING COMPLETED!") logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}") logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}") # Check for breakthrough performance if final_metrics['raw_loss'] < 3.0: logger.info("🚀 BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!") logger.info("Breakthrough training completed successfully!") if __name__ == "__main__": main()