import json

import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
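    """Optimizer wrapper for bf16 training with optional fp32 master weights.

    Gradients from the (typically bf16) model parameters are copied onto an
    fp32 master copy, clipped, and stepped there; the updated masters are then
    cast back into the model parameters. Keeping optimizer state and updates
    in fp32 avoids the precision loss of accumulating small updates directly
    in bf16.
    """
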
def __init__(
self,
model,
lr,
weight_decay=0.0,
max_grad_norm=0.5,
total_steps=800_000,
warmup_ratio=0.015,
use_fp32_params=True,
optimizer_type="adamw",
optimizer_config=None,
):
self.model = model
self.model_params = [p for p in model.parameters() if p.requires_grad]
self.max_grad_norm = max_grad_norm
self.use_fp32_params = use_fp32_params
self.optimizer_type = optimizer_type
if use_fp32_params:
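            # Maintain an fp32 master copy of every trainable parameter; the
            # optimizer steps these masters rather than the bf16 weights.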
self.fp32_params = [
p.detach().clone().to(torch.float32) for p in self.model_params
]
for mp in self.fp32_params:
mp.requires_grad = True
self.optim_params = self.fp32_params
else:
self.fp32_params = None
self.optim_params = self.model_params
self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config)
self.scheduler = CosineAnnealingWarmupLR(
self.optimizer,
total_steps=total_steps,
warmup_steps=int(warmup_ratio * total_steps),
        )

def _create_optimizer(self, lr, weight_decay, optimizer_config):
if self.optimizer_type == "adamw":
return torch.optim.AdamW(
self.optim_params, lr=lr, weight_decay=weight_decay
)
elif self.optimizer_type == "adamw_8bit":
import bitsandbytes as bnb
return bnb.optim.AdamW8bit(
self.optim_params, lr=lr, weight_decay=weight_decay
)
elif self.optimizer_type == "apollo":
from apollo_torch import APOLLOAdamW
assert optimizer_config is not None, (
"optimizer_config path is required when optimizer_type='apollo'"
)
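            # The config file is plain JSON, read for the keys consumed in
            # _build_apollo_param_groups. An illustrative example (values are
            # placeholders, not tuned recommendations):
            # {"rank": 256, "proj": "random", "scale_type": "tensor",
            #  "scale": 128, "update_proj_gap": 200, "proj_type": "std"}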
with open(optimizer_config, "r") as f:
apollo_cfg = json.load(f)
param_groups = self._build_apollo_param_groups(
apollo_cfg, lr, weight_decay
)
return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay)
else:
raise ValueError(
f"Unknown optimizer_type: {self.optimizer_type}. "
f"Supported types: adamw, adamw_8bit, apollo"
            )

def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay):
"""Build param groups for APOLLO optimizer.
Splits parameters into two groups:
- non_lowrank_params: 1D params (bias, layernorm) - standard Adam update
- lowrank_params: nD params (weight matrices) - low-rank projected update
"""
lowrank_params = []
non_lowrank_params = []
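        # Tensors with ndim >= 2 (weight matrices) get APOLLO's low-rank
        # projected update; 1D tensors keep the standard dense update.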
for p in self.optim_params:
if p.ndim >= 2:
lowrank_params.append(p)
else:
non_lowrank_params.append(p)
param_groups = [
{"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay},
{
"params": lowrank_params,
"lr": lr,
"weight_decay": weight_decay,
"rank": apollo_cfg.get("rank", 1),
"proj": apollo_cfg.get("proj", "random"),
"scale_type": apollo_cfg.get("scale_type", "tensor"),
"scale": apollo_cfg.get("scale", 128),
"update_proj_gap": apollo_cfg.get("update_proj_gap", 200),
"proj_type": apollo_cfg.get("proj_type", "std"),
},
]
        return param_groups

def step(self):
if self.use_fp32_params:
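            # Upcast and copy the model grads onto the fp32 masters, then
            # clip in fp32 so the measured norm is not distorted by bf16.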
with torch.no_grad():
for p, mp in zip(self.model_params, self.fp32_params):
mp.grad = (
p.grad.detach().to(torch.float32)
if p.grad is not None
else None
)
torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
if self.use_fp32_params:
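            # Write the updated fp32 masters back into the model parameters
            # (casting to the model dtype) and drop the stale grads.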
with torch.no_grad():
for p, mp in zip(self.model_params, self.fp32_params):
p.data.copy_(mp.data.to(p.dtype))
p.grad = None
else:
with torch.no_grad():
for p in self.model_params:
                    p.grad = None

def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
print_on_rank0("Successfully loaded optimizer state_dict.")
self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
print_on_rank0("Successfully loaded scheduler state_dict.")
def state_dict(self):
return {
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
        }

def get_learning_rate(self):
return self.optimizer.param_groups[0]["lr"]