| """ | |
| Custom Training Loop Module | |
| Provides custom training loop implementation with fine-grained control over training. | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Dict, Any, Callable | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
@dataclass
class TrainingConfig:
    """Configuration for the custom training loop."""
    num_epochs: int = 3
    learning_rate: float = 2e-4
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    max_grad_norm: float = 1.0
    warmup_steps: int = 100
    logging_steps: int = 10
    eval_steps: int = 500
    save_steps: int = 500
    output_dir: str = "./models/output"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
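# Example (illustrative values): override any field at construction time;
# the rest keep their defaults.
#
#   config = TrainingConfig(num_epochs=5, learning_rate=1e-4, warmup_steps=200)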
class TrainingLoop:
    """
    Custom training loop for fine-grained control over the training process.

    Provides manual control over:
    - Forward/backward passes
    - Gradient accumulation
    - Learning rate scheduling
    - Logging and evaluation
    - Checkpointing
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        config: Optional[TrainingConfig] = None
    ):
        """
        Initialize the custom training loop.

        Args:
            model: PyTorch model to train
            train_dataloader: Training data loader
            eval_dataloader: Optional evaluation data loader
            config: Training configuration
        """
        self.config = config or TrainingConfig()
        # Move the model to the target device up front so it matches the
        # batches, which train_step/evaluate also move to config.device.
        self.model = model.to(self.config.device)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.current_epoch = 0
    def setup_optimizer(self, optimizer_class=torch.optim.AdamW, **optimizer_kwargs):
        """
        Set up the optimizer and learning rate scheduler.

        Args:
            optimizer_class: Optimizer class to use
            **optimizer_kwargs: Additional optimizer arguments
        """
        self.optimizer = optimizer_class(
            self.model.parameters(),
            lr=self.config.learning_rate,
            **optimizer_kwargs
        )

        # Linear warmup scheduler: ramp the LR from 0 to the configured value
        # over warmup_steps optimizer steps, then hold it constant.
        def lr_lambda(current_step: int):
            if current_step < self.config.warmup_steps:
                return float(current_step) / float(max(1, self.config.warmup_steps))
            return 1.0

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda
        )
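    # Example (hypothetical kwargs): extra keyword arguments are forwarded to
    # the optimizer constructor, e.g.
    #
    #   loop.setup_optimizer(torch.optim.AdamW, weight_decay=0.01, betas=(0.9, 0.999))
    #
    # If setup_optimizer is never called, train() falls back to AdamW with the
    # configured learning rate.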
    def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """
        Perform a single training step.

        Args:
            batch: Batch of training data

        Returns:
            Unscaled loss value for this batch
        """
        # Move batch to device
        batch = {k: v.to(self.config.device) for k, v in batch.items()}

        # Forward pass
        outputs = self.model(**batch)
        loss = outputs.loss

        # Backward pass, scaled for gradient accumulation so the accumulated
        # gradient matches a full batch; the reported loss stays unscaled.
        (loss / self.config.gradient_accumulation_steps).backward()

        return loss.item()
    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        Returns:
            Training metrics
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {self.current_epoch + 1}/{self.config.num_epochs}"
        )

        for step, batch in enumerate(progress_bar):
            # Training step
            loss = self.train_step(batch)
            total_loss += loss
            num_batches += 1

            # Gradient accumulation: only update weights every N micro-batches
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer step
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                # Update progress bar
                progress_bar.set_postfix({
                    "loss": total_loss / num_batches,
                    "lr": self.scheduler.get_last_lr()[0]
                })

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = total_loss / num_batches
                    print(f"Step {self.global_step}: loss={avg_loss:.4f}")

                # Evaluation
                if self.eval_dataloader and self.global_step % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate()
                    print(f"Evaluation: {eval_metrics}")
                    self.model.train()

        return {
            "loss": total_loss / max(num_batches, 1),
            "epoch": self.current_epoch
        }
    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model on the validation set.

        Returns:
            Evaluation metrics
        """
        if self.eval_dataloader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.eval_dataloader, desc="Evaluating"):
                batch = {k: v.to(self.config.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        return {
            "eval_loss": total_loss / max(num_batches, 1)
        }
    def train(self, callback: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Run the full training loop.

        Args:
            callback: Optional callback function called after each epoch

        Returns:
            Training history
        """
        if self.optimizer is None:
            self.setup_optimizer()

        print(f"Starting training for {self.config.num_epochs} epochs")
        print(f"Device: {self.config.device}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")

        history = {
            "train_loss": [],
            "eval_loss": []
        }

        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch

            # Train epoch
            train_metrics = self.train_epoch()
            history["train_loss"].append(train_metrics["loss"])

            # Evaluate
            if self.eval_dataloader:
                eval_metrics = self.evaluate()
                history["eval_loss"].append(eval_metrics.get("eval_loss", 0))

            # Callback
            if callback:
                callback(epoch, train_metrics)

        print("✅ Training complete!")
        return history
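    # Example callback (illustrative): invoked after every epoch with the
    # epoch index and that epoch's training metrics, e.g. for custom logging
    # or periodic checkpointing.
    #
    #   def on_epoch_end(epoch, metrics):
    #       print(f"epoch {epoch}: train loss {metrics['loss']:.4f}")
    #
    #   loop.train(callback=on_epoch_end)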
    def save_checkpoint(self, path: str) -> None:
        """
        Save a training checkpoint.

        Args:
            path: Path to save the checkpoint to
        """
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.current_epoch
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to: {path}")
    def load_checkpoint(self, path: str) -> None:
        """
        Load a training checkpoint.

        Args:
            path: Path to the checkpoint
        """
        # map_location ensures a checkpoint saved on GPU also loads on CPU
        checkpoint = torch.load(path, map_location=self.config.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["global_step"]
        self.current_epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from: {path}")