"""
BREAKTHROUGH BitTransformerLM Training Script
=============================================

Using the ACTUAL BitTransformerLM model and training infrastructure,
configured for the Fixed RL Adafactor breakthrough results.
"""

import sys
import os
import logging
import ast  # for safely parsing stringified bit sequences from the dataset
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import login

# Make the local BitTransformerLM checkout importable.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    train_loop,
    save_model,
    load_model,
    set_dropout
)
from BTLM_Extensions import configure_adafactor_optimizer
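# NOTE (assumption): BTLM_Extensions is expected to be importable from one of
# the paths added above, alongside the BitTransformerLM checkout; adjust
# sys.path if it is installed elsewhere.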

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('breakthrough_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_and_prepare_dataset():
    """Load the HF dataset and convert it to bit tensors."""
    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

    # Authenticate with the Hub if a token is available.
    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
    else:
        logger.warning("HF_TOKEN environment variable not set")

    dataset = load_dataset("WCNegentropy/BitTransformerLM")
    train_data = dataset['train']

    logger.info(f"Dataset loaded: {len(train_data)} samples")

    # Collect one bit sequence per sample, preferring the precomputed
    # 'bit_sequence' field and falling back to encoding the raw text.
    bit_sequences = []
    for sample in train_data:
        if 'bit_sequence' in sample and sample['bit_sequence'] is not None:
            bits = sample['bit_sequence']
            if isinstance(bits, str):
                # Some samples store the bit list as its string repr.
                try:
                    bits = ast.literal_eval(bits)
                except (ValueError, SyntaxError):
                    bits = None
            if isinstance(bits, list) and len(bits) > 0:
                bit_sequences.append(bits)
            else:
                text = sample.get('original_text', '')
                if text:
                    bits = text_to_bits(text)
                    bit_sequences.append(bits)
        else:
            text = sample.get('text', '') or sample.get('original_text', '')
            if text:
                bits = text_to_bits(text)
                bit_sequences.append(bits)

    logger.info(f"Processed {len(bit_sequences)} bit sequences")

    max_len = 512
    training_sequences = []

    for bits in bit_sequences:
        for i in range(0, len(bits) - max_len + 1, max_len // 2):
            seq = bits[i:i + max_len]
            if len(seq) == max_len:
                training_sequences.append(seq)
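
    # With max_len = 512 and a stride of max_len // 2 = 256, windows overlap by
    # 50%: a 1280-bit sequence, for example, yields windows starting at
    # 0, 256, 512 and 768 -- four 512-bit training samples.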

    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
    logger.info(f"Created training tensor: {data_tensor.shape}")

    return data_tensor


def create_breakthrough_model():
    """Create the EXACT breakthrough BitTransformerLM configuration."""
    logger.info("Creating breakthrough BitTransformerLM model...")

    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=True,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    logger.info(f"Target: ~16M parameters - {'✅' if 15_000_000 <= total_params <= 17_000_000 else '❌'}")

    return model


def main():
    """Main training function."""
    logger.info("STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
    logger.info("Using ACTUAL BitTransformerLM model and train_loop")

    data = load_and_prepare_dataset()

    model = create_breakthrough_model()

logger.info("Configuring Fixed RL Adafactor optimizer...") |
|
|
optimizer, scheduler = configure_adafactor_optimizer( |
|
|
model, |
|
|
lr=1e-3, |
|
|
weight_decay=0.01, |
|
|
total_steps=5000 |
|
|
) |
|
|
logger.info("Fixed RL Adafactor configured with LR=0.001") |

    training_config = {
        'epochs': 20,
        'batch_size': 4,
        'accum_steps': 4,
        'amp': True,
        'log': True,
        'compress_prob': 0.0,
        'optimizer': optimizer,
        'scheduler': scheduler
    }
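    # Effective batch size is batch_size * accum_steps = 16 sequences of
    # 512 bits per optimizer step, with AMP enabled and compression disabled.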

    logger.info(f"Training configuration: {training_config}")
    logger.info("Starting training loop...")

    metrics = train_loop(
        model=model,
        data=data,
        **training_config
    )

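    # train_loop is expected to return a list of per-interval metric dicts;
    # the 'raw_loss' and 'raw_acc' keys are read from the last entry below.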
    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
    checkpoint_dir.mkdir(exist_ok=True)

    model_path = checkpoint_dir / 'breakthrough_model.pt'
    save_model(model, model_path)
    logger.info(f"Model saved to: {model_path}")

    if metrics:
        final_metrics = metrics[-1]
        logger.info("TRAINING COMPLETED!")
        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")

        if final_metrics['raw_loss'] < 3.0:
            logger.info("BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

    logger.info("Breakthrough training completed successfully!")


if __name__ == "__main__":
    main()