# specforge/optimizer.py
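"""BF16 training optimizer wrapper for SpecForge.

Wraps a pluggable optimizer backend (AdamW, 8-bit AdamW via bitsandbytes,
or APOLLO) with gradient clipping, an optional fp32 master copy of the
trainable parameters, and a cosine-annealing warmup LR schedule.
"""
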
import json

import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
    """Optimizer wrapper for bf16 training.

    Optionally keeps an fp32 master copy of the trainable parameters
    (updates are applied in fp32 and copied back to the bf16 model weights),
    clips gradients, and drives a cosine-annealing LR schedule with warmup.
    """

    def __init__(
        self,
        model,
        lr,
        weight_decay=0.0,
        max_grad_norm=0.5,
        total_steps=800_000,
        warmup_ratio=0.015,
        use_fp32_params=True,
        optimizer_type="adamw",
        optimizer_config=None,
    ):
        self.model = model
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        self.max_grad_norm = max_grad_norm
        self.use_fp32_params = use_fp32_params
        self.optimizer_type = optimizer_type

        if use_fp32_params:
            # Maintain fp32 master copies; the optimizer updates these and
            # step() copies the result back into the bf16 model parameters.
            self.fp32_params = [
                p.detach().clone().to(torch.float32) for p in self.model_params
            ]
            for mp in self.fp32_params:
                mp.requires_grad = True
            self.optim_params = self.fp32_params
        else:
            self.fp32_params = None
            self.optim_params = self.model_params

        self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config)
        self.scheduler = CosineAnnealingWarmupLR(
            self.optimizer,
            total_steps=total_steps,
            warmup_steps=int(warmup_ratio * total_steps),
        )
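
    # Note: with the defaults above, warmup spans int(0.015 * 800_000) =
    # 12_000 steps before the cosine decay takes over (assuming
    # CosineAnnealingWarmupLR interprets warmup_steps as the length of the
    # initial ramp, as its name suggests).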

    def _create_optimizer(self, lr, weight_decay, optimizer_config):
        """Instantiate the underlying optimizer over ``self.optim_params``."""
        if self.optimizer_type == "adamw":
            return torch.optim.AdamW(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "adamw_8bit":
            # Imported lazily so bitsandbytes is only required when used.
            import bitsandbytes as bnb

            return bnb.optim.AdamW8bit(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "apollo":
            from apollo_torch import APOLLOAdamW

            assert optimizer_config is not None, (
                "optimizer_config path is required when optimizer_type='apollo'"
            )
            with open(optimizer_config, "r") as f:
                apollo_cfg = json.load(f)
            param_groups = self._build_apollo_param_groups(
                apollo_cfg, lr, weight_decay
            )
            return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(
                f"Unknown optimizer_type: {self.optimizer_type}. "
                f"Supported types: adamw, adamw_8bit, apollo"
            )
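
    # A minimal example of the JSON file passed via ``optimizer_config`` when
    # ``optimizer_type="apollo"``. The keys mirror those read (with these same
    # defaults) in ``_build_apollo_param_groups`` below; the values are
    # illustrative, not tuned recommendations:
    #
    #   {
    #       "rank": 1,
    #       "proj": "random",
    #       "scale_type": "tensor",
    #       "scale": 128,
    #       "update_proj_gap": 200,
    #       "proj_type": "std"
    #   }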

    def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay):
        """Build param groups for the APOLLO optimizer.

        Splits parameters into two groups:
        - non_lowrank_params: 1D params (bias, layernorm) - standard Adam update
        - lowrank_params: nD params (weight matrices) - low-rank projected update
        """
        lowrank_params = []
        non_lowrank_params = []
        for p in self.optim_params:
            if p.ndim >= 2:
                lowrank_params.append(p)
            else:
                non_lowrank_params.append(p)
        param_groups = [
            {"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay},
            {
                "params": lowrank_params,
                "lr": lr,
                "weight_decay": weight_decay,
                "rank": apollo_cfg.get("rank", 1),
                "proj": apollo_cfg.get("proj", "random"),
                "scale_type": apollo_cfg.get("scale_type", "tensor"),
                "scale": apollo_cfg.get("scale", 128),
                "update_proj_gap": apollo_cfg.get("update_proj_gap", 200),
                "proj_type": apollo_cfg.get("proj_type", "std"),
            },
        ]
        return param_groups
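
    # Illustrative consequence of the ndim split above: in a typical
    # transformer, embedding and projection weight matrices (ndim >= 2) take
    # the low-rank APOLLO update, while biases and norm scales (ndim == 1)
    # keep the standard AdamW update.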

    def step(self):
        """Clip gradients, apply one optimizer and scheduler step, and sync
        the bf16 model parameters with the fp32 master copy (if enabled)."""
        if self.use_fp32_params:
            with torch.no_grad():
                # Copy the bf16 gradients onto the fp32 master parameters so
                # clipping and the update happen in full precision.
                for p, mp in zip(self.model_params, self.fp32_params):
                    mp.grad = (
                        p.grad.detach().to(torch.float32)
                        if p.grad is not None
                        else None
                    )
            torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()

        if self.use_fp32_params:
            with torch.no_grad():
                # Write the updated fp32 master weights back into the bf16
                # model parameters and drop the stale gradients.
                for p, mp in zip(self.model_params, self.fp32_params):
                    p.data.copy_(mp.data.to(p.dtype))
                    p.grad = None
        else:
            with torch.no_grad():
                for p in self.model_params:
                    p.grad = None

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        print_on_rank0("Successfully loaded optimizer state_dict.")
        self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
        print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
        return {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }

    def get_learning_rate(self):
        return self.optimizer.param_groups[0]["lr"]
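

if __name__ == "__main__":
    # Minimal usage sketch (illustrative smoke test, not part of the library
    # API): run one optimizer step on a toy bf16 module. Assumes a PyTorch
    # build where bf16 Linear layers work on the current device.
    toy = torch.nn.Linear(8, 8).to(torch.bfloat16)
    opt = BF16Optimizer(toy, lr=1e-4, total_steps=1_000)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    loss = toy(x).float().pow(2).mean()
    loss.backward()
    # Clips grads, updates the fp32 masters, syncs the bf16 weights, and
    # advances the LR schedule in one call.
    opt.step()
    print_on_rank0(f"lr after one step: {opt.get_learning_rate():.2e}")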