"""
PyPilot Training Manager - Advanced distributed training with monitoring
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
import wandb
import numpy as np
import time
from datetime import datetime
import os


class CodeDataset(Dataset):
    """Thin Dataset wrapper around a list of pre-tokenized examples."""

    def __init__(self, tokenized_data):
        self.data = tokenized_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
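
# Expected input for CodeDataset (an assumption based on the Trainer usage
# below): a list of dicts as produced by a Hugging Face tokenizer, e.g.
#   tokenized = [{"input_ids": [...], "attention_mask": [...], "labels": [...]}, ...]
#   dataset = CodeDataset(tokenized)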


class PyPilotTrainingManager:
    """Coordinates PyPilot training: configuration, monitoring, and tuning."""

    def __init__(self, model, model_name="PyPilot"):
        self.model = model
        self.model_name = model_name
        self.training_history = []
        self.best_loss = float('inf')

    def setup_distributed_training(self, use_fp16=True, use_gradient_checkpointing=True):
        """Configure training arguments for (optionally distributed) training."""
        training_args = TrainingArguments(
            output_dir="./pypilot-checkpoints",
            overwrite_output_dir=True,
            num_train_epochs=10,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            # Effective batch size per device: 4 * 8 = 32.
            gradient_accumulation_steps=8,
            learning_rate=5e-5,
            weight_decay=0.01,
            warmup_steps=1000,
            logging_dir="./logs",
            logging_steps=500,
            # Periodic evaluation is required for EarlyStoppingCallback and
            # load_best_model_at_end (newer transformers versions name this
            # argument `eval_strategy`).
            evaluation_strategy="steps",
            eval_steps=1000,
            save_steps=2000,
            save_total_limit=5,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            # Must be False so that compute_metrics receives the logits.
            prediction_loss_only=False,
            remove_unused_columns=False,
            fp16=use_fp16,
            dataloader_pin_memory=False,
            gradient_checkpointing=use_gradient_checkpointing,
            report_to=["wandb"],
            run_name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
        )
        return training_args

    def setup_wandb_monitoring(self, project_name="pypilot"):
        """Initialize Weights & Biases for experiment tracking."""
        wandb.init(
            project=project_name,
            name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
            config={
                "architecture": "Transformer",
                "dataset": "GitHub Code",
                "epochs": 10,
                # Effective per-device batch size (4 * 8 gradient accumulation).
                "batch_size": 32,
            },
        )

    def create_advanced_callbacks(self):
        """Create callbacks for training optimization."""
        callbacks = [
            # Stop if the tracked eval metric fails to improve for 3 evaluations.
            EarlyStoppingCallback(early_stopping_patience=3),
        ]
        return callbacks

    def compute_metrics(self, eval_pred):
        """Compute perplexity and token accuracy for code generation."""
        predictions, labels = eval_pred
        # The Trainer passes numpy arrays: logits of shape
        # (batch, seq_len, vocab_size) and labels of shape (batch, seq_len).
        predictions = torch.tensor(predictions)
        labels = torch.tensor(labels).long()
        # Perplexity from cross-entropy loss (positions labeled -100 are ignored).
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(predictions.view(-1, predictions.size(-1)), labels.view(-1))
        perplexity = torch.exp(loss)
        # Token accuracy, excluding padded/ignored positions.
        preds = torch.argmax(predictions, dim=-1)
        mask = labels != -100
        accuracy = (preds == labels)[mask].float().mean()
        return {
            "perplexity": perplexity.item(),
            "accuracy": accuracy.item(),
            "loss": loss.item(),
        }

    def train_with_advanced_features(self, train_dataset, eval_dataset=None):
        """Start advanced training with all features."""
        print("🚀 Starting Advanced PyPilot Training...")
        # Setup monitoring
        self.setup_wandb_monitoring()
        # Configure training; note that early stopping and periodic evaluation
        # require an eval_dataset to be supplied.
        training_args = self.setup_distributed_training()
        callbacks = self.create_advanced_callbacks()
        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=self.compute_metrics,
            callbacks=callbacks,
        )
        # Start training
        print("🎯 Training started with advanced features:")
        print("  - FP16 Precision: Enabled")
        print("  - Gradient Checkpointing: Enabled")
        print("  - Early Stopping: Enabled")
        print("  - W&B Monitoring: Enabled")
        trainer.train()
        # Save final model
        trainer.save_model("./pypilot-final-model")
        print("✅ Training completed and model saved!")
        return trainer
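
    def update_model_hyperparams(self, params):
        # Placeholder referenced by hyperparameter_search below; this file does
        # not define how hyperparameters are applied to the model, so this is a
        # minimal, assumed implementation that simply records them. Adapt it to
        # rebuild or reconfigure the model for your setup.
        self.current_hyperparams = params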

    def hyperparameter_search(self, train_dataset, param_combinations):
        """Perform a simple grid search over hyperparameter combinations."""
        best_params = None
        for i, params in enumerate(param_combinations):
            print(f"🔍 Testing hyperparameter combination {i+1}/{len(param_combinations)}")
            # Update model with new params
            self.update_model_hyperparams(params)
            # Quick training run to evaluate
            quick_trainer = Trainer(
                model=self.model,
                args=TrainingArguments(
                    output_dir=f"./hparam-search-{i}",
                    num_train_epochs=1,
                    per_device_train_batch_size=params['batch_size'],
                    learning_rate=params['learning_rate'],
                ),
                train_dataset=train_dataset,
            )
            results = quick_trainer.train()
            # Keep the combination with the lowest training loss.
            if results.training_loss < self.best_loss:
                self.best_loss = results.training_loss
                best_params = params
        print(f"🎯 Best hyperparameters: {best_params}")
        return best_params


if __name__ == "__main__":
    # Example usage
    from modeling_pypilot import PyPilotModel, PyPilotConfig

    config = PyPilotConfig()
    model = PyPilotModel(config)
    manager = PyPilotTrainingManager(model)
    print("✅ Training Manager ready!")
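
    # Illustrative hyperparameter grid for hyperparameter_search; the specific
    # values are assumptions, not tuned recommendations.
    param_grid = [
        {"batch_size": 2, "learning_rate": 1e-4},
        {"batch_size": 4, "learning_rate": 5e-5},
        {"batch_size": 8, "learning_rate": 3e-5},
    ]
    # With a tokenized dataset in hand (see CodeDataset above), training or a
    # search could be launched like:
    #   dataset = CodeDataset(tokenized_examples)
    #   manager.train_with_advanced_features(dataset, eval_dataset=dataset)
    #   best = manager.hyperparameter_search(dataset, param_grid)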