# PyPilot / training_manager.py
# Author: prelington
# Commit: ff17dd4 (verified) - "Create training_manager.py"
"""
PyPilot Training Manager - Advanced distributed training with monitoring
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
import wandb
import numpy as np
import time
from datetime import datetime
import os
class CodeDataset(Dataset):
    """Map-style dataset over an already-tokenized collection.

    Items are handed back exactly as stored: no tensor conversion, padding
    or collation happens here.
    """

    def __init__(self, tokenized_data):
        # The backing object only needs to support len() and indexing.
        self.data = tokenized_data

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

    def __len__(self):
        return len(self.data)
class PyPilotTrainingManager:
    """Drive Hugging Face ``Trainer`` runs for a PyPilot model.

    Bundles training-argument construction, Weights & Biases logging,
    early stopping, evaluation metrics and a simple sequential
    hyperparameter search around a single model instance.
    """

    def __init__(self, model, model_name="PyPilot"):
        """Store the model to train.

        Args:
            model: Model handed to every ``Trainer`` this manager creates.
            model_name: Display name; stored but not otherwise used here.
        """
        self.model = model
        self.model_name = model_name
        # Not populated by this class; kept as a public attribute for callers.
        self.training_history = []
        # Lowest training loss ever observed by hyperparameter_search.
        self.best_loss = float('inf')

    @staticmethod
    def _eval_strategy_kwarg():
        """Return the ``TrainingArguments`` field name that selects when to evaluate.

        transformers renamed ``evaluation_strategy`` to ``eval_strategy``;
        probe the installed version instead of hard-coding either name.
        """
        import dataclasses

        field_names = {f.name for f in dataclasses.fields(TrainingArguments)}
        if "eval_strategy" in field_names:
            return "eval_strategy"
        return "evaluation_strategy"

    @staticmethod
    def _status(flag):
        """Render a boolean feature flag for the startup banner."""
        return "Enabled" if flag else "Disabled"

    def setup_distributed_training(self, use_fp16=True, use_gradient_checkpointing=True,
                                   enable_evaluation=True):
        """Build the ``TrainingArguments`` used by train_with_advanced_features.

        Nothing here configures a distributed backend explicitly; multi-device
        execution presumably comes from ``Trainer`` plus however the script is
        launched (e.g. torchrun / accelerate).

        Args:
            use_fp16: Enable fp16 mixed precision (typically requires a GPU).
            use_gradient_checkpointing: Enable gradient checkpointing. The model
                presumably must support it (Hugging Face models expose
                ``gradient_checkpointing_enable``) -- confirm for PyPilotModel.
            enable_evaluation: Evaluate every ``eval_steps`` and keep the best
                checkpoint by ``eval_loss``. EarlyStoppingCallback depends on
                this; pass False when there is no eval dataset.

        Returns:
            A configured ``TrainingArguments`` instance.
        """
        eval_kwargs = {}
        if enable_evaluation:
            eval_kwargs = {
                # Without an explicit strategy evaluation is off, so
                # eval_steps below would never trigger an evaluation.
                self._eval_strategy_kwarg(): "steps",
                # EarlyStoppingCallback tracks this metric on the best checkpoint.
                "load_best_model_at_end": True,
                "metric_for_best_model": "eval_loss",
                "greater_is_better": False,
            }
        training_args = TrainingArguments(
            output_dir="./pypilot-checkpoints",
            overwrite_output_dir=True,
            num_train_epochs=10,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            gradient_accumulation_steps=8,
            learning_rate=5e-5,
            weight_decay=0.01,
            warmup_steps=1000,
            logging_dir="./logs",
            logging_steps=500,
            eval_steps=1000,
            # Keep a round multiple of eval_steps: required with load_best_model_at_end.
            save_steps=2000,
            save_total_limit=5,
            # Only the loss is gathered during evaluation, so Trainer presumably
            # never calls compute_metrics (it needs predictions) -- confirm
            # before relying on perplexity/accuracy being logged.
            prediction_loss_only=True,
            remove_unused_columns=False,
            fp16=use_fp16,
            dataloader_pin_memory=False,
            gradient_checkpointing=use_gradient_checkpointing,
            report_to=["wandb"],
            run_name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
            **eval_kwargs,
        )
        return training_args

    def setup_wandb_monitoring(self, project_name="pypilot"):
        """Start a Weights & Biases run for experiment tracking.

        The logged config mirrors setup_distributed_training: 10 epochs and an
        effective per-device batch of 4 * 8 gradient-accumulation steps = 32.
        """
        wandb.init(
            project=project_name,
            name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
            config={
                "architecture": "Transformer",
                "dataset": "GitHub Code",
                "epochs": 10,
                "batch_size": 32,
            }
        )

    def create_advanced_callbacks(self):
        """Return the Trainer callbacks used for training optimization.

        EarlyStoppingCallback only works when the training arguments enable
        evaluation, load_best_model_at_end and metric_for_best_model
        (setup_distributed_training(enable_evaluation=True)).
        """
        callbacks = [
            EarlyStoppingCallback(early_stopping_patience=3),
        ]
        return callbacks

    def compute_metrics(self, eval_pred):
        """Compute loss, perplexity and token accuracy from logits and labels.

        Args:
            eval_pred: ``(predictions, labels)`` pair. Predictions are assumed to
                be logits shaped ``(..., vocab)`` with labels shaped ``(...)`` --
                TODO confirm for PyPilotModel. If predictions is a tuple/list,
                its first element is taken as the logits. No causal shift is
                applied; labels are assumed already aligned with the logits.

        Returns:
            Dict with ``perplexity``, ``accuracy`` and ``loss`` as floats.
            Positions labelled with the loss's ignore index (-100) are excluded
            from both the loss and the accuracy.
        """
        predictions, labels = eval_pred
        if isinstance(predictions, (tuple, list)):
            predictions = predictions[0]
        # float32 keeps cross-entropy well-defined for fp16 logits on CPU.
        logits = torch.as_tensor(predictions).float()
        labels = torch.as_tensor(labels).long()
        # Calculate perplexity
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        perplexity = torch.exp(loss)
        # Calculate accuracy over the same positions the loss counts.
        preds = torch.argmax(logits, dim=-1)
        mask = labels != loss_fct.ignore_index
        correct = (preds == labels) & mask
        # clamp avoids 0/0 when every position is ignored (accuracy -> 0.0).
        accuracy = correct.sum().float() / mask.sum().clamp(min=1)
        return {
            "perplexity": perplexity.item(),
            "accuracy": accuracy.item(),
            "loss": loss.item()
        }

    def train_with_advanced_features(self, train_dataset, eval_dataset=None):
        """Run a full training job and save the final model.

        Starts W&B monitoring, builds the training arguments, trains, and saves
        the model to ``./pypilot-final-model``. Periodic evaluation and early
        stopping are enabled only when *eval_dataset* is provided, since both
        need an evaluation set.

        Returns:
            The ``Trainer`` used for the run.
        """
        print("🚀 Starting Advanced PyPilot Training...")
        # Setup monitoring
        self.setup_wandb_monitoring()
        # Evaluation-dependent features are only valid with an eval set.
        has_eval = eval_dataset is not None
        training_args = self.setup_distributed_training(enable_evaluation=has_eval)
        callbacks = self.create_advanced_callbacks() if has_eval else []
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=self.compute_metrics,
            callbacks=callbacks,
        )
        # Report what is actually configured rather than a fixed list.
        print("🎯 Training started with advanced features:")
        print(f" - FP16 Precision: {self._status(training_args.fp16)}")
        print(f" - Gradient Checkpointing: {self._status(training_args.gradient_checkpointing)}")
        print(f" - Early Stopping: {self._status(has_eval)}")
        print(" - W&B Monitoring: Enabled")
        trainer.train()
        # Save final model
        trainer.save_model("./pypilot-final-model")
        print("✅ Training completed and model saved!")
        return trainer

    def update_model_hyperparams(self, params):
        """Hook for applying model-level hyperparameters before a search trial.

        The base implementation does nothing: the keys hyperparameter_search
        reads itself (``batch_size``, ``learning_rate``) go straight into
        ``TrainingArguments``. Override in a subclass to change the model.
        """
        return None

    def hyperparameter_search(self, train_dataset, param_combinations):
        """Try each parameter combination with a one-epoch run; return the best.

        Args:
            train_dataset: Dataset used for every trial.
            param_combinations: Sized sequence of dicts, each with at least
                ``batch_size`` and ``learning_rate``.

        Returns:
            The combination with the lowest training loss in THIS call, or None
            if *param_combinations* is empty. ``self.best_loss`` is updated to
            the lowest loss seen across all calls.
        """
        best_params = None
        best_loss = float('inf')
        total = len(param_combinations)
        for i, params in enumerate(param_combinations):
            print(f"🔍 Testing hyperparameter combination {i+1}/{total}")
            self.update_model_hyperparams(params)
            # NOTE(review): self.model is reused, so each trial continues from
            # the weights left by the previous one; trials are not independent.
            quick_trainer = Trainer(
                model=self.model,
                args=TrainingArguments(
                    output_dir=f"./hparam-search-{i}",
                    num_train_epochs=1,
                    per_device_train_batch_size=params['batch_size'],
                    learning_rate=params['learning_rate'],
                ),
                train_dataset=train_dataset,
            )
            results = quick_trainer.train()
            if results.training_loss < best_loss:
                best_loss = results.training_loss
                best_params = params
        self.best_loss = min(self.best_loss, best_loss)
        print(f"🎯 Best hyperparameters: {best_params}")
        return best_params
if __name__ == "__main__":
    # Example usage: build a default PyPilot model and wrap it in a manager.
    # modeling_pypilot is a project-local module imported only when run as a script.
    from modeling_pypilot import PyPilotModel, PyPilotConfig

    config = PyPilotConfig()
    model = PyPilotModel(config)
    manager = PyPilotTrainingManager(model)
    print("✅ Training Manager ready!")