import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
    """AdamW wrapper that trains a bf16 model against fp32 master weights.

    Gradients produced by the (bf16) model are copied into an fp32 copy of
    the trainable parameters, clipped, and stepped with AdamW under a
    cosine-annealing warmup schedule; the updated fp32 weights are then cast
    back into the model parameters.
    """

    def __init__(
        self,
        model,
        lr,
        weight_decay=0.0,
        max_grad_norm=0.5,
        total_steps=800_000,
        warmup_ratio=0.015,
    ):
        self.model = model
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        self.max_grad_norm = max_grad_norm
        # Keep an fp32 master copy of every trainable parameter; AdamW and
        # gradient clipping operate on these copies rather than on the
        # (typically bf16) model parameters.
        self.fp32_params = [
            p.detach().clone().to(torch.float32) for p in self.model_params
        ]
        for mp in self.fp32_params:
            mp.requires_grad = True
        self.optimizer = torch.optim.AdamW(
            self.fp32_params, lr=lr, weight_decay=weight_decay
        )
        self.scheduler = CosineAnnealingWarmupLR(
            self.optimizer,
            total_steps=total_steps,
            warmup_steps=int(warmup_ratio * total_steps),
        )

    def step(self):
        # Copy the model's gradients into the fp32 master params, clip them,
        # and take an AdamW + scheduler step on the fp32 copies.
        with torch.no_grad():
            for p, mp in zip(self.model_params, self.fp32_params):
                mp.grad = (
                    p.grad.detach().to(torch.float32) if p.grad is not None else None
                )
            torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
        # Cast the updated fp32 master weights back into the model parameters
        # and drop the stale model gradients.
        with torch.no_grad():
            for p, mp in zip(self.model_params, self.fp32_params):
                p.data.copy_(mp.data.to(p.dtype))
                p.grad = None

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        print_on_rank0("Successfully loaded optimizer state_dict.")
        self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
        print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
        return {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }

    def get_learning_rate(self):
        return self.optimizer.param_groups[0]["lr"]
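

# Illustrative usage sketch (not part of the original module): exercises the
# optimizer end to end with a toy bf16 linear layer. The layer, shapes, lr,
# and step count below are placeholders chosen only for this demo.
if __name__ == "__main__":
    toy_model = torch.nn.Linear(8, 8).to(torch.bfloat16)
    optimizer = BF16Optimizer(toy_model, lr=1e-4, total_steps=100)

    # One forward/backward pass in bf16, then a master-weight update.
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    loss = toy_model(x).float().pow(2).mean()
    loss.backward()
    optimizer.step()
    print(f"lr after one step: {optimizer.get_learning_rate():.2e}")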