# specforge/optimizer.py
import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
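    """Mixed-precision optimizer that keeps FP32 master copies of the trainable
    (typically BF16) model parameters, runs AdamW with a cosine annealing warmup
    LR schedule on those copies, and casts the updated values back into the model
    after every step.

    Rough usage sketch (the training loop, dataloader and loss computation here
    are illustrative, not part of this module):

        optimizer = BF16Optimizer(model, lr=1e-4)
        for batch in dataloader:
            loss = compute_loss(model, batch)
            loss.backward()
            optimizer.step()
    """
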
def __init__(
self,
model,
lr,
weight_decay=0.0,
max_grad_norm=0.5,
total_steps=800_000,
warmup_ratio=0.015,
):
        # TODO: For now we only support the cosine annealing warmup LR scheduler and the AdamW optimizer;
        # TODO: these choices should eventually be made configurable.
        # The magic numbers (weight_decay=0.0, max_grad_norm=0.5, total_steps=800k, warmup_steps=12k)
        # are copied from https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/ds_config.json
self.model = model
self.model_params = [p for p in model.parameters() if p.requires_grad]
self.max_grad_norm = max_grad_norm
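
        # FP32 master copies of the trainable parameters: AdamW updates these copies,
        # and step() writes the results back into the (typically BF16) model weights.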
self.fp32_params = [
p.detach().clone().to(torch.float32) for p in self.model_params
]
for mp in self.fp32_params:
mp.requires_grad = True
self.optimizer = torch.optim.AdamW(
self.fp32_params, lr=lr, weight_decay=weight_decay
)
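
        # warmup_steps is derived as warmup_ratio * total_steps (12k of 800k with the defaults above).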
self.scheduler = CosineAnnealingWarmupLR(
self.optimizer,
total_steps=total_steps,
warmup_steps=int(warmup_ratio * total_steps),
)

    def step(self):
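        """Mirror BF16 gradients into the FP32 master params, clip them, take an
        AdamW step, advance the LR schedule, and copy the updated weights back
        into the model."""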
with torch.no_grad():
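            # Mirror the freshly computed BF16 gradients into the FP32 master copies.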
for p, mp in zip(self.model_params, self.fp32_params):
mp.grad = (
p.grad.detach().to(torch.float32) if p.grad is not None else None
)
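
        # Clip the FP32 gradients by global norm, then apply the AdamW update to the master weights.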
torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
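
        # Copy the updated FP32 master weights back into the model parameters (casting to the model
        # dtype) and drop the now-stale model gradients.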
with torch.no_grad():
for p, mp in zip(self.model_params, self.fp32_params):
p.data.copy_(mp.data.to(p.dtype))
p.grad = None

    def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
print_on_rank0("Successfully loaded optimizer state_dict.")
self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
return {
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
}

    def get_learning_rate(self):
return self.optimizer.param_groups[0]["lr"]