import json

import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
    def __init__(
        self,
        model,
        lr,
        weight_decay=0.0,
        max_grad_norm=0.5,
        total_steps=800_000,
        warmup_ratio=0.015,
        use_fp32_params=True,
        optimizer_type="adamw",
        optimizer_config=None,
    ):
        self.model = model
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        self.max_grad_norm = max_grad_norm
        self.use_fp32_params = use_fp32_params
        self.optimizer_type = optimizer_type

        if use_fp32_params:
            # Keep an fp32 master copy of every trainable parameter; the
            # optimizer updates the master copy and the bf16 model weights
            # are refreshed from it after each step.
            self.fp32_params = [
                p.detach().clone().to(torch.float32) for p in self.model_params
            ]
            for mp in self.fp32_params:
                mp.requires_grad = True
            self.optim_params = self.fp32_params
        else:
            self.fp32_params = None
            self.optim_params = self.model_params

        self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config)
        self.scheduler = CosineAnnealingWarmupLR(
            self.optimizer,
            total_steps=total_steps,
            warmup_steps=int(warmup_ratio * total_steps),
        )

    def _create_optimizer(self, lr, weight_decay, optimizer_config):
        if self.optimizer_type == "adamw":
            return torch.optim.AdamW(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "adamw_8bit":
            import bitsandbytes as bnb

            return bnb.optim.AdamW8bit(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "apollo":
            from apollo_torch import APOLLOAdamW

            assert optimizer_config is not None, (
                "optimizer_config path is required when optimizer_type='apollo'"
            )
            with open(optimizer_config, "r") as f:
                apollo_cfg = json.load(f)
            param_groups = self._build_apollo_param_groups(
                apollo_cfg, lr, weight_decay
            )
            return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(
                f"Unknown optimizer_type: {self.optimizer_type}. "
                f"Supported types: adamw, adamw_8bit, apollo"
            )

    def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay):
        """Build param groups for the APOLLO optimizer.

        Splits parameters into two groups:
        - non_lowrank_params: 1D params (bias, layernorm) - standard Adam update
        - lowrank_params: nD params (weight matrices) - low-rank projected update
        """
        lowrank_params = []
        non_lowrank_params = []
        for p in self.optim_params:
            if p.ndim >= 2:
                lowrank_params.append(p)
            else:
                non_lowrank_params.append(p)

        param_groups = [
            {"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay},
            {
                "params": lowrank_params,
                "lr": lr,
                "weight_decay": weight_decay,
                "rank": apollo_cfg.get("rank", 1),
                "proj": apollo_cfg.get("proj", "random"),
                "scale_type": apollo_cfg.get("scale_type", "tensor"),
                "scale": apollo_cfg.get("scale", 128),
                "update_proj_gap": apollo_cfg.get("update_proj_gap", 200),
                "proj_type": apollo_cfg.get("proj_type", "std"),
            },
        ]
        return param_groups

    def step(self):
        if self.use_fp32_params:
            # Copy the bf16 gradients onto the fp32 master params, then clip
            # in fp32 so the optimizer only ever sees full-precision values.
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    mp.grad = (
                        p.grad.detach().to(torch.float32)
                        if p.grad is not None
                        else None
                    )
            torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()

        if self.use_fp32_params:
            # Write the updated fp32 master weights back into the bf16 model
            # params and drop their now-stale gradients.
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    p.data.copy_(mp.data.to(p.dtype))
                    p.grad = None
        else:
            with torch.no_grad():
                for p in self.model_params:
                    p.grad = None

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        print_on_rank0("Successfully loaded optimizer state_dict.")
        self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
        print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
        return {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }

    def get_learning_rate(self):
        return self.optimizer.param_groups[0]["lr"]
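

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): a tiny
    # bf16 linear layer and a synthetic regression loss stand in for the real
    # model and training loop. Only the BF16Optimizer API shown here is taken
    # from this module; everything else is a placeholder.
    model = torch.nn.Linear(16, 16).to(torch.bfloat16)
    optimizer = BF16Optimizer(model, lr=1e-4, total_steps=1_000)

    for _ in range(10):
        x = torch.randn(4, 16, dtype=torch.bfloat16)
        loss = (model(x) - x).float().pow(2).mean()
        loss.backward()
        # step() clips gradients, updates the fp32 master weights, copies them
        # back into the bf16 parameters, and advances the LR schedule.
        optimizer.step()

    print(f"final lr: {optimizer.get_learning_rate():.2e}")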