import json

import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
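    """Optimizer wrapper for bf16 training with optional FP32 master weights.

    Gradients are clipped, applied by the wrapped optimizer (AdamW, 8-bit
    AdamW, or APOLLO), and the learning rate follows a cosine-annealing
    schedule with warmup. When ``use_fp32_params`` is enabled, the optimizer
    updates FP32 master copies of the parameters and the results are cast back
    into the model parameters after every step.

    Minimal usage sketch (``model`` and ``loss`` are illustrative names, not
    part of this module):

        optimizer = BF16Optimizer(model, lr=1e-4, total_steps=10_000)
        loss.backward()
        optimizer.step()  # clip, update, schedule, sync weights, clear grads
    """
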
    def __init__(
        self,
        model,
        lr,
        weight_decay=0.0,
        max_grad_norm=0.5,
        total_steps=800_000,
        warmup_ratio=0.015,
        use_fp32_params=True,
        optimizer_type="adamw",
        optimizer_config=None,
    ):
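        """
        Args:
            model: module whose trainable parameters are optimized.
            lr: peak learning rate.
            weight_decay: weight decay passed to the underlying optimizer.
            max_grad_norm: gradient-clipping threshold.
            total_steps: total number of scheduler steps.
            warmup_ratio: fraction of total_steps used for warmup.
            use_fp32_params: keep FP32 master copies of the parameters.
            optimizer_type: one of "adamw", "adamw_8bit", "apollo".
            optimizer_config: path to a JSON config, required for "apollo".
        """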
        self.model = model
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        self.max_grad_norm = max_grad_norm
        self.use_fp32_params = use_fp32_params
        self.optimizer_type = optimizer_type

        if use_fp32_params:
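            # Keep an FP32 master copy of every trainable parameter; the
            # optimizer updates these copies rather than the bf16 weights.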
            self.fp32_params = [
                p.detach().clone().to(torch.float32) for p in self.model_params
            ]
            for mp in self.fp32_params:
                mp.requires_grad = True
            self.optim_params = self.fp32_params
        else:
            self.fp32_params = None
            self.optim_params = self.model_params

        self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config)
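        # Warm up for the first `warmup_ratio * total_steps` steps, then anneal.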
        self.scheduler = CosineAnnealingWarmupLR(
            self.optimizer,
            total_steps=total_steps,
            warmup_steps=int(warmup_ratio * total_steps),
        )

    def _create_optimizer(self, lr, weight_decay, optimizer_config):
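        """Instantiate the underlying optimizer over self.optim_params.

        "adamw_8bit" requires bitsandbytes; "apollo" requires apollo_torch and
        a JSON config file passed via optimizer_config.
        """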
| if self.optimizer_type == "adamw": |
| return torch.optim.AdamW( |
| self.optim_params, lr=lr, weight_decay=weight_decay |
| ) |
| elif self.optimizer_type == "adamw_8bit": |
| import bitsandbytes as bnb |
|
|
| return bnb.optim.AdamW8bit( |
| self.optim_params, lr=lr, weight_decay=weight_decay |
| ) |
| elif self.optimizer_type == "apollo": |
| from apollo_torch import APOLLOAdamW |
|
|
| assert optimizer_config is not None, ( |
| "optimizer_config path is required when optimizer_type='apollo'" |
| ) |
| with open(optimizer_config, "r") as f: |
| apollo_cfg = json.load(f) |
| param_groups = self._build_apollo_param_groups( |
| apollo_cfg, lr, weight_decay |
| ) |
| return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay) |
| else: |
| raise ValueError( |
| f"Unknown optimizer_type: {self.optimizer_type}. " |
| f"Supported types: adamw, adamw_8bit, apollo" |
| ) |
|
|
    def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay):
        """Build param groups for the APOLLO optimizer.

        Splits parameters into two groups:
        - non_lowrank_params: 1D params (bias, layernorm) - standard Adam update
        - lowrank_params: nD params (weight matrices) - low-rank projected update
        """
        lowrank_params = []
        non_lowrank_params = []
        for p in self.optim_params:
            if p.ndim >= 2:
                lowrank_params.append(p)
            else:
                non_lowrank_params.append(p)

        param_groups = [
            {"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay},
            {
                "params": lowrank_params,
                "lr": lr,
                "weight_decay": weight_decay,
                "rank": apollo_cfg.get("rank", 1),
                "proj": apollo_cfg.get("proj", "random"),
                "scale_type": apollo_cfg.get("scale_type", "tensor"),
                "scale": apollo_cfg.get("scale", 128),
                "update_proj_gap": apollo_cfg.get("update_proj_gap", 200),
                "proj_type": apollo_cfg.get("proj_type", "std"),
            },
        ]
        return param_groups

    def step(self):
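        """Clip gradients, step the optimizer and LR scheduler, and clear grads.

        With FP32 master params, gradients are first copied onto the master
        copies, and the updated master weights are cast back into the model.
        """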
        if self.use_fp32_params:
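            # Mirror the model's (bf16) gradients onto the FP32 master copies.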
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    mp.grad = (
                        p.grad.detach().to(torch.float32)
                        if p.grad is not None
                        else None
                    )
            torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()

        if self.use_fp32_params:
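            # Cast the updated FP32 master weights back into the model params.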
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    p.data.copy_(mp.data.to(p.dtype))
                    p.grad = None
        else:
            with torch.no_grad():
                for p in self.model_params:
                    p.grad = None

    def load_state_dict(self, state_dict):
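        """Restore optimizer and scheduler state saved by state_dict()."""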
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        print_on_rank0("Successfully loaded optimizer state_dict.")
        self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
        print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
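        """Return optimizer and scheduler state for checkpointing."""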
        return {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }

    def get_learning_rate(self):
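        """Return the current learning rate of the first param group."""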
        return self.optimizer.param_groups[0]["lr"]