# LaunchLLM/fine_tuning/training_loop.py
"""
Custom Training Loop Module
Provides a custom training loop implementation with fine-grained control over training.
"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
@dataclass
class TrainingConfig:
"""Configuration for custom training loop."""
num_epochs: int = 3
learning_rate: float = 2e-4
batch_size: int = 4
gradient_accumulation_steps: int = 4
max_grad_norm: float = 1.0
warmup_steps: int = 100
logging_steps: int = 10
eval_steps: int = 500
save_steps: int = 500
output_dir: str = "./models/output"
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
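# Example (illustrative only): override selected fields when constructing a config;
# any field not passed keeps the default defined above.
#   config = TrainingConfig(num_epochs=1, gradient_accumulation_steps=8, output_dir="./models/run1")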
class TrainingLoop:
"""
Custom training loop for fine-grained control over the training process.
Provides manual control over:
- Forward/backward passes
- Gradient accumulation
- Learning rate scheduling
- Logging and evaluation
- Checkpointing
"""
def __init__(
self,
model: torch.nn.Module,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
config: Optional[TrainingConfig] = None
):
"""
Initialize custom training loop.
Args:
model: PyTorch model to train
train_dataloader: Training data loader
eval_dataloader: Optional evaluation data loader
config: Training configuration
"""
        # Resolve the config first so the model can be moved to the configured device
        # (train_step moves each batch there, so the model must live on it too).
        self.config = config or TrainingConfig()
        self.model = model.to(self.config.device)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
self.optimizer = None
self.scheduler = None
self.global_step = 0
self.current_epoch = 0
def setup_optimizer(self, optimizer_class=torch.optim.AdamW, **optimizer_kwargs):
"""
Setup optimizer and learning rate scheduler.
Args:
optimizer_class: Optimizer class to use
**optimizer_kwargs: Additional optimizer arguments
"""
self.optimizer = optimizer_class(
self.model.parameters(),
lr=self.config.learning_rate,
**optimizer_kwargs
)
# Linear warmup scheduler
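        # The factor returned below multiplies the base learning rate: it ramps
        # linearly from 0 to 1 over `warmup_steps`, then stays constant at 1.0
        # (no decay is applied after warmup).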
def lr_lambda(current_step: int):
if current_step < self.config.warmup_steps:
return float(current_step) / float(max(1, self.config.warmup_steps))
return 1.0
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
lr_lambda
)
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""
Perform a single training step.
Args:
batch: Batch of training data
Returns:
Loss value
"""
# Move batch to device
batch = {k: v.to(self.config.device) for k, v in batch.items()}
# Forward pass
outputs = self.model(**batch)
loss = outputs.loss
# Scale loss for gradient accumulation
loss = loss / self.config.gradient_accumulation_steps
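        # Dividing by the accumulation count keeps the accumulated gradients
        # equivalent to a single update over
        # batch_size * gradient_accumulation_steps samples.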
# Backward pass
loss.backward()
return loss.item()
def train_epoch(self) -> Dict[str, float]:
"""
Train for one epoch.
Returns:
Training metrics
"""
self.model.train()
total_loss = 0
num_batches = 0
progress_bar = tqdm(
self.train_dataloader,
desc=f"Epoch {self.current_epoch + 1}/{self.config.num_epochs}"
)
for step, batch in enumerate(progress_bar):
# Training step
loss = self.train_step(batch)
total_loss += loss
# Gradient accumulation
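            # An optimizer update happens only every `gradient_accumulation_steps`
            # micro-batches; `global_step` counts these updates rather than raw
            # batches, so the logging and evaluation intervals below are measured
            # in optimizer steps.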
if (step + 1) % self.config.gradient_accumulation_steps == 0:
# Clip gradients
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
num_batches += 1
# Update progress bar
progress_bar.set_postfix({
"loss": total_loss / num_batches,
"lr": self.scheduler.get_last_lr()[0]
})
# Logging
if self.global_step % self.config.logging_steps == 0:
avg_loss = total_loss / num_batches
print(f"Step {self.global_step}: loss={avg_loss:.4f}")
# Evaluation
if self.eval_dataloader and self.global_step % self.config.eval_steps == 0:
eval_metrics = self.evaluate()
print(f"Evaluation: {eval_metrics}")
self.model.train()
return {
"loss": total_loss / max(num_batches, 1),
"epoch": self.current_epoch
}
def evaluate(self) -> Dict[str, float]:
"""
Evaluate model on validation set.
Returns:
Evaluation metrics
"""
if self.eval_dataloader is None:
return {}
self.model.eval()
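        # eval() disables dropout (and switches batch-norm layers to their running
        # statistics); torch.no_grad() below skips building the autograd graph.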
total_loss = 0
num_batches = 0
with torch.no_grad():
for batch in tqdm(self.eval_dataloader, desc="Evaluating"):
batch = {k: v.to(self.config.device) for k, v in batch.items()}
outputs = self.model(**batch)
total_loss += outputs.loss.item()
num_batches += 1
return {
"eval_loss": total_loss / max(num_batches, 1)
}
def train(self, callback: Optional[Callable] = None) -> Dict[str, Any]:
"""
Run full training loop.
Args:
callback: Optional callback function called after each epoch
Returns:
Training history
"""
if self.optimizer is None:
self.setup_optimizer()
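            # Default setup: AdamW at config.learning_rate with linear warmup
            # over config.warmup_steps (see setup_optimizer above).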
print(f"Starting training for {self.config.num_epochs} epochs")
print(f"Device: {self.config.device}")
print(f"Batch size: {self.config.batch_size}")
print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")
history = {
"train_loss": [],
"eval_loss": []
}
for epoch in range(self.config.num_epochs):
self.current_epoch = epoch
# Train epoch
train_metrics = self.train_epoch()
history["train_loss"].append(train_metrics["loss"])
# Evaluate
if self.eval_dataloader:
eval_metrics = self.evaluate()
history["eval_loss"].append(eval_metrics.get("eval_loss", 0))
# Callback
if callback:
callback(epoch, train_metrics)
print("✅ Training complete!")
return history
def save_checkpoint(self, path: str) -> None:
"""
Save training checkpoint.
Args:
path: Path to save checkpoint
"""
checkpoint = {
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"global_step": self.global_step,
"epoch": self.current_epoch
}
torch.save(checkpoint, path)
print(f"Checkpoint saved to: {path}")
def load_checkpoint(self, path: str) -> None:
"""
Load training checkpoint.
Args:
path: Path to checkpoint
"""
        checkpoint = torch.load(path, map_location=self.config.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
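        # The optimizer and scheduler must already exist (via setup_optimizer())
        # to receive their saved state below.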
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.global_step = checkpoint["global_step"]
self.current_epoch = checkpoint["epoch"]
print(f"Checkpoint loaded from: {path}")
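

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import).
# Assumes a Hugging Face-style model whose forward pass returns an object with
# a `.loss` attribute, and DataLoaders that yield dicts of tensors
# (e.g. input_ids, attention_mask, labels). `train_dataset` and `eval_dataset`
# are placeholders for the caller's own datasets.
#
#   from transformers import AutoModelForCausalLM
#   from torch.utils.data import DataLoader
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
#   eval_loader = DataLoader(eval_dataset, batch_size=4)
#
#   config = TrainingConfig(num_epochs=1, learning_rate=2e-4)
#   loop = TrainingLoop(model, train_loader, eval_loader, config)
#   loop.setup_optimizer(weight_decay=0.01)
#   history = loop.train()
#   loop.save_checkpoint("./models/output/checkpoint.pt")
# ---------------------------------------------------------------------------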