import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import os
import sys
import json
import argparse
import time
import math
import glob
from typing import Dict, List
from tqdm import tqdm
import numpy as np
import gc
from collections import defaultdict
import multiprocessing
# Import custom modules
try:
from model_slm import MixtureOfRecursions, count_parameters, TextGenerator
from custom_tokenizer import TechnicalTokenizer
except ImportError as e:
print(f"Import error: {e}")
sys.exit(1)
class FastTechnicalTextDataset(Dataset):
"""Ultra-fast dataset with aggressive optimizations for 4-5hr training"""
def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 128, max_examples: int = 50000):
self.tokenizer = tokenizer
self.max_length = max_length
self.pad_token_id = tokenizer.vocab.get('<pad>', 0)
self.max_examples = max_examples
print(f"FAST DATASET LOADING")
print(f"Data file: {data_file}")
print(f"Max sequence length: {max_length}")
print(f"Max examples: {max_examples}")
start_time = time.time()
self.examples = []
self._fast_load_data(data_file)
load_time = time.time() - start_time
print(f" Loaded {len(self.examples)} examples in {load_time:.1f}s")
self._tensorize_data()
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
def _fast_load_data(self, data_file: str):
print("🔍 Fast reading file...")
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
print(f"File has {len(lines)} lines")
good_examples = []
seen_hashes = set()
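# Heuristic quality filter: keep lines of 50-400 characters with at least 9 words,
# skip lines that look like bare URLs, and cap the dot density at 10% to drop
# boilerplate/navigation text. Duplicates are detected cheaply by hashing only
# the first 100 characters of each line.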
for line in lines[:self.max_examples * 3]:
line = line.strip()
if (50 <= len(line) <= 400 and
line.count(' ') >= 8 and
not line.lower().startswith(('http', 'www', 'ftp')) and
line.count('.') <= len(line) * 0.1):
line_hash = hash(line[:100])
if line_hash not in seen_hashes:
seen_hashes.add(line_hash)
good_examples.append(line)
if len(good_examples) >= self.max_examples:
break
print(f"After fast filtering: {len(good_examples)} quality examples")
batch_size = 1000
for i in range(0, len(good_examples), batch_size):
batch = good_examples[i:i+batch_size]
for line in batch:
try:
if not line.endswith('<|endoftext|>'):
line += ' <|endoftext|>'
tokens = self.tokenizer.encode_ids(line, add_special_tokens=True)
if 30 <= len(tokens) <= self.max_length:
if len(tokens) < self.max_length:
tokens = tokens + [self.pad_token_id] * (self.max_length - len(tokens))
self.examples.append(tokens)
except Exception:
continue
if i % 5000 == 0:
print(f"Processed {len(self.examples)} examples...")
print(f"Final dataset: {len(self.examples)} examples")
def _tensorize_data(self):
print("Pre-tensorizing data for maximum speed...")
seq_len = self.max_length - 1
tensorized_examples = []
for tokens in self.examples:
if len(tokens) < self.max_length:
continue
input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
targets = torch.tensor(tokens[1:], dtype=torch.long)
original_len = next((i for i, x in enumerate(tokens) if x == self.pad_token_id), self.max_length)
mask_len = min(original_len, seq_len)
attention_mask = torch.zeros(seq_len, dtype=torch.long)
attention_mask[:mask_len] = 1
tensorized_examples.append({
'input_ids': input_ids,
'targets': targets,
'attention_mask': attention_mask
})
self.examples = tensorized_examples
print("All data pre-tensorized")
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
return self.examples[idx]
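# Example (illustrative sketch, not executed here): building the dataset and a
# DataLoader by hand. The file path and tokenizer directory below are placeholders.
#
#   tokenizer = TechnicalTokenizer()
#   tokenizer.load("tokenizer")
#   dataset = FastTechnicalTextDataset("data/my_train.txt", tokenizer,
#                                      max_length=128, max_examples=50000)
#   loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
#   batch = next(iter(loader))
#   # batch['input_ids'] and batch['targets'] are (16, 127) LongTensors;
#   # batch['attention_mask'] marks the non-padding positions.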
class FastCosineScheduler:
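# Linear warmup over the first warmup_ratio fraction of steps, then cosine decay
# to zero; step() is expected to be called once per optimizer update.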
def __init__(self, optimizer, total_steps: int, warmup_ratio: float = 0.05):
self.optimizer = optimizer
self.total_steps = total_steps
self.warmup_steps = int(total_steps * warmup_ratio)
self.base_lr = optimizer.param_groups[0]['lr']
self.step_count = 0
def step(self):
self.step_count += 1
if self.step_count <= self.warmup_steps:
lr = self.base_lr * self.step_count / self.warmup_steps
else:
progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * progress))
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
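# Worked example (illustrative numbers): with base_lr=5e-4, total_steps=10_000 and
# warmup_ratio=0.05, warmup_steps=500. At step 250 the LR is halfway through warmup
# (2.5e-4); at step 500 it peaks at 5e-4; at step 5_250 the cosine term is
# cos(pi * 0.5) = 0, giving 2.5e-4; by step 10_000 it has decayed to ~0.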
class UltraFastTrainer:
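# Mixed-precision (AMP) trainer geared toward a short wall-clock budget: cosine LR
# schedule with warmup, optional gradient accumulation, gradient clipping, and
# periodic CUDA cache clearing.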
def __init__(self, model, tokenizer, train_dataset, val_dataset=None, config=None):
self.model = model
self.tokenizer = tokenizer
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.config = config or {}
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self._fast_init_weights()
self._setup_fast_optimizer()
epochs = self.config.get('epochs', 15)
batch_size = self.config.get('batch_size', 16)
grad_accum = max(self.config.get('gradient_accumulation_steps', 1), 1)
# The scheduler is stepped once per optimizer update, so account for gradient accumulation.
total_steps = max(len(train_dataset) // batch_size // grad_accum, 1) * epochs
self.scheduler = FastCosineScheduler(self.optimizer, total_steps)
self.scaler = GradScaler()
self.global_step = 0
self.best_loss = float('inf')
self.grad_accum_steps = self.config.get('gradient_accumulation_steps', 1)
self.eval_every = self.config.get('eval_every', 500)
def _fast_init_weights(self):
def fast_init(module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=0.02)
self.model.apply(fast_init)
def _setup_fast_optimizer(self):
lr = self.config.get('learning_rate', 5e-4)
params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = optim.AdamW(params, lr=lr, betas=(0.9, 0.99), weight_decay=0.01, eps=1e-6)
def compute_fast_loss(self, logits, targets, mask):
logits_flat = logits.view(-1, logits.size(-1))
targets_flat = targets.view(-1)
mask_flat = mask.view(-1).bool()
if not mask_flat.any():
return torch.tensor(0.0, device=logits.device, requires_grad=True)
loss = F.cross_entropy(logits_flat[mask_flat], targets_flat[mask_flat])
return loss
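# Shape walk-through for compute_fast_loss: logits arrive as (batch, seq_len, vocab)
# and are flattened to (batch * seq_len, vocab); targets and mask flatten to
# (batch * seq_len,). Cross-entropy is taken only over positions where the
# attention mask is 1, so padding tokens never contribute to the loss.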
def train_epoch_fast(self, epoch: int, dataloader: DataLoader) -> Dict[str, float]:
self.model.train()
total_loss = 0
num_batches = 0
start_time = time.time()
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False, miniters=50)
for batch_idx, batch in enumerate(progress_bar):
input_ids = batch['input_ids'].to(self.device, non_blocking=True)
targets = batch['targets'].to(self.device, non_blocking=True)
mask = batch['attention_mask'].to(self.device, non_blocking=True)
with autocast():
logits, comp_loss = self.model(input_ids, mask)
lm_loss = self.compute_fast_loss(logits, targets, mask)
total_loss_step = lm_loss + 0.0001 * comp_loss
if self.grad_accum_steps > 1:
total_loss_step = total_loss_step / self.grad_accum_steps
self.scaler.scale(total_loss_step).backward()
if (batch_idx + 1) % self.grad_accum_steps == 0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad(set_to_none=True)
self.scheduler.step()
self.global_step += 1
total_loss += lm_loss.item()
num_batches += 1
if batch_idx % 100 == 0:
current_loss = total_loss / num_batches
progress_bar.set_postfix({'loss': f"{current_loss:.3f}", 'ppl': f"{math.exp(min(current_loss, 10)):.1f}"})
if batch_idx % 200 == 0 and batch_idx > 0:
torch.cuda.empty_cache()
epoch_time = time.time() - start_time
avg_loss = total_loss / max(num_batches, 1)
return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10)), 'epoch_time_min': epoch_time / 60}
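# Note: with gradient accumulation the effective batch size is
# batch_size * grad_accum_steps (e.g. 16 * 4 = 64 sequences per optimizer update),
# while the loss accumulated above is the per-batch language-modeling loss.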
def validate_fast(self, dataloader: DataLoader) -> Dict[str, float]:
self.model.eval()
total_loss = 0
num_batches = 0
max_val_batches = min(100, len(dataloader))
with torch.no_grad():
for batch_idx, batch in enumerate(dataloader):
if batch_idx >= max_val_batches:
break
input_ids = batch['input_ids'].to(self.device, non_blocking=True)
targets = batch['targets'].to(self.device, non_blocking=True)
mask = batch['attention_mask'].to(self.device, non_blocking=True)
with autocast():
logits, _ = self.model(input_ids, mask)
loss = self.compute_fast_loss(logits, targets, mask)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10))}
def save_checkpoint_fast(self, epoch: int, metrics: Dict, save_dir: str = "checkpoints"):
os.makedirs(save_dir, exist_ok=True)
val_loss = metrics.get('val_loss', metrics.get('loss', float('inf')))
if val_loss < self.best_loss:
self.best_loss = val_loss
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'metrics': metrics,
'scaler_state_dict': self.scaler.state_dict()
}
best_path = os.path.join(save_dir, "best_model.pt")
torch.save(checkpoint, best_path)
print(f"New best! Loss: {val_loss:.4f}")
return best_path
return None
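# Restoring the best checkpoint later (sketch; assumes the same model/optimizer
# configuration used for training):
#
#   ckpt = torch.load("checkpoints/best_model.pt", map_location="cpu")
#   model.load_state_dict(ckpt['model_state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#   scaler.load_state_dict(ckpt['scaler_state_dict'])
#   start_epoch = ckpt['epoch'] + 1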
def train_ultra_fast(self, num_epochs: int = 15, batch_size: int = 16):
print(f"\n ULTRA-FAST TRAINING")
print(f" Target: Loss < 2.0, PPL < 12")
print(f" Time target: 4-5 hours")
print(f" Epochs: {num_epochs}")
print(f" Batch size: {batch_size}")
print("-" * 60)
train_loader = DataLoader(
self.train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
persistent_workers=True,
drop_last=True
)
val_loader = None
if self.val_dataset:
val_loader = DataLoader(
self.val_dataset,
batch_size=batch_size * 2,
shuffle=False,
num_workers=2,
pin_memory=True
)
total_start_time = time.time()
history = []
for epoch in range(1, num_epochs + 1):
epoch_start = time.time()
print(f"\n EPOCH {epoch}/{num_epochs}")
train_metrics = self.train_epoch_fast(epoch, train_loader)
val_metrics = {}
if val_loader and (epoch % 2 == 0 or epoch == num_epochs):
val_metrics = self.validate_fast(val_loader)
epoch_time = time.time() - epoch_start
epoch_info = {
'epoch': epoch,
'train_loss': train_metrics['loss'],
'train_ppl': train_metrics['perplexity'],
'epoch_time_min': epoch_time / 60
}
if val_metrics:
epoch_info.update({'val_loss': val_metrics['loss'], 'val_ppl': val_metrics['perplexity']})
history.append(epoch_info)
elapsed_hours = (time.time() - total_start_time) / 3600
remaining_hours = elapsed_hours * (num_epochs - epoch) / epoch
print(f"\n EPOCH {epoch} RESULTS:")
print(f" Epoch time: {epoch_time/60:.1f} min")
print(f" Total elapsed: {elapsed_hours:.1f}h")
print(f" Est. remaining: {remaining_hours:.1f}h")
print(f" Train Loss: {train_metrics['loss']:.4f}")
print(f" Train PPL: {train_metrics['perplexity']:.1f}")
if val_metrics:
print(f" Val Loss: {val_metrics['loss']:.4f}")
print(f" Val PPL: {val_metrics['perplexity']:.1f}")
current_loss = val_metrics.get('loss', train_metrics['loss'])
current_ppl = val_metrics.get('perplexity', train_metrics['perplexity'])
if current_loss < 2.0 and current_ppl < 12:
print(f" TARGETS ACHIEVED!")
print(f" Loss: {current_loss:.4f} < 2.0")
print(f" PPL: {current_ppl:.1f} < 12")
combined_metrics = {**train_metrics}
if val_metrics:
combined_metrics.update({f"val_{k}": v for k, v in val_metrics.items()})
self.save_checkpoint_fast(epoch, combined_metrics)
torch.cuda.empty_cache()
gc.collect()
if current_loss < 1.8 and current_ppl < 10:
print(f"EARLY STOPPING - Excellent performance achieved!")
break
total_time = time.time() - total_start_time
print(f"\n TRAINING COMPLETED!")
print(f"Total time: {total_time/3600:.1f} hours")
print(f" Best loss: {self.best_loss:.4f}")
return history
def run_ultra_fast_training():
parser = argparse.ArgumentParser(description="Ultra-Fast Training for 4-5 Hours")
parser.add_argument("--train_file", default=None)
parser.add_argument("--val_file", default=None)
parser.add_argument("--tokenizer_dir", default="tokenizer")
parser.add_argument("--max_examples", type=int, default=50000)
parser.add_argument("--d_model", type=int, default=384)
parser.add_argument("--n_layers", type=int, default=6)
parser.add_argument("--n_heads", type=int, default=6)
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--eval_every", type=int, default=500)
args = parser.parse_args()
torch.manual_seed(42)
np.random.seed(42)
print("Training My Model")
print("-" * 50)
if args.train_file is None:
patterns = ["*train*.txt", "*_train.txt"]
files = []
for pattern in patterns:
files.extend(glob.glob(pattern))
files.extend(glob.glob(f"split_data/{pattern}"))
files.extend(glob.glob(f"data/{pattern}"))
if files:
args.train_file = files[0]
print(f"Found: {args.train_file}")
else:
print(" No training files found!")
return 1
tokenizer = TechnicalTokenizer()
try:
tokenizer.load(args.tokenizer_dir)
print(f"Tokenizer loaded. Vocab size: {tokenizer.get_vocab_size()}")
except Exception as e:
print(f" Tokenizer error: {e}")
return 1
print(" Creating ultra-fast dataset...")
train_dataset = FastTechnicalTextDataset(
args.train_file, tokenizer, args.max_seq_len, args.max_examples
)
val_dataset = None
if args.val_file and os.path.exists(args.val_file):
val_dataset = FastTechnicalTextDataset(
args.val_file, tokenizer, args.max_seq_len, max_examples=5000
)
model = MixtureOfRecursions(
vocab_size=tokenizer.get_vocab_size(),
d_model=args.d_model,
n_layers=args.n_layers,
n_heads=args.n_heads,
max_seq_len=args.max_seq_len - 1,  # inputs/targets are shifted by one token, so the model sees max_seq_len - 1 positions
padding_idx=tokenizer.vocab.get('<pad>', 0)
)
config = {
'learning_rate': args.learning_rate,
'gradient_accumulation_steps': args.gradient_accumulation_steps,
'eval_every': args.eval_every,
'batch_size': args.batch_size,
'epochs': args.epochs
}
trainer = UltraFastTrainer(model, tokenizer, train_dataset, val_dataset, config)
print(f"\n START TRAINING")
results = trainer.train_ultra_fast(args.epochs, args.batch_size)
with open('ultra_fast_results.json', 'w') as f:
json.dump(results, f, indent=2)
print("\n Training Completed!")
print(" Results saved to: ultra_fast_results.json")
return 0
if __name__ == "__main__":
sys.exit(run_ultra_fast_training())
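# Example invocation (paths and values are illustrative; adjust to your data):
#
#   python train.py --train_file split_data/corpus_train.txt \
#       --val_file split_data/corpus_val.txt --tokenizer_dir tokenizer \
#       --epochs 15 --batch_size 16 --learning_rate 5e-4 --max_seq_len 128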