import json

import torch

from specforge.lr_scheduler import CosineAnnealingWarmupLR
from specforge.utils import print_on_rank0


class BF16Optimizer:
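    """Optimizer wrapper for bf16 training.

    Optionally keeps fp32 master copies of the trainable parameters, clips
    gradients, and advances a cosine-annealing LR schedule with warmup on
    every update step.
    """
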
    def __init__(
        self,
        model,
        lr,
        weight_decay=0.0,
        max_grad_norm=0.5,
        total_steps=800_000,
        warmup_ratio=0.015,
        use_fp32_params=True,
        optimizer_type="adamw",
        optimizer_config=None,
    ):
        self.model = model
        self.model_params = [p for p in model.parameters() if p.requires_grad]
        self.max_grad_norm = max_grad_norm
        self.use_fp32_params = use_fp32_params
        self.optimizer_type = optimizer_type

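        # Optionally keep an fp32 master copy of the trainable parameters so
        # that gradient clipping and the optimizer update run in full
        # precision while the model itself stays in bf16.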
        if use_fp32_params:
            self.fp32_params = [
                p.detach().clone().to(torch.float32) for p in self.model_params
            ]
            for mp in self.fp32_params:
                mp.requires_grad = True
            self.optim_params = self.fp32_params
        else:
            self.fp32_params = None
            self.optim_params = self.model_params

        self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config)
        self.scheduler = CosineAnnealingWarmupLR(
            self.optimizer,
            total_steps=total_steps,
            warmup_steps=int(warmup_ratio * total_steps),
        )

    def _create_optimizer(self, lr, weight_decay, optimizer_config):
        if self.optimizer_type == "adamw":
            return torch.optim.AdamW(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "adamw_8bit":
            import bitsandbytes as bnb

            return bnb.optim.AdamW8bit(
                self.optim_params, lr=lr, weight_decay=weight_decay
            )
        elif self.optimizer_type == "apollo":
            from apollo_torch import APOLLOAdamW

            assert optimizer_config is not None, (
                "optimizer_config path is required when optimizer_type='apollo'"
            )
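            # Illustrative shape of the JSON file pointed to by
            # `optimizer_config`. The keys mirror those read in
            # _build_apollo_param_groups; the values shown are that method's
            # fallback defaults, not a tuned recommendation:
            #   {"rank": 1, "proj": "random", "scale_type": "tensor",
            #    "scale": 128, "update_proj_gap": 200, "proj_type": "std"}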
            with open(optimizer_config, "r") as f:
                apollo_cfg = json.load(f)
            param_groups = self._build_apollo_param_groups(
                apollo_cfg, lr, weight_decay
            )
            return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(
                f"Unknown optimizer_type: {self.optimizer_type}. "
                f"Supported types: adamw, adamw_8bit, apollo"
            )

    def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay):
        """Build param groups for APOLLO optimizer.

        Splits parameters into two groups:
        - non_lowrank_params: 1D params (bias, layernorm) - standard Adam update
        - lowrank_params: nD params (weight matrices) - low-rank projected update
        """
        lowrank_params = []
        non_lowrank_params = []
        for p in self.optim_params:
            if p.ndim >= 2:
                lowrank_params.append(p)
            else:
                non_lowrank_params.append(p)

        param_groups = [
            {"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay},
            {
                "params": lowrank_params,
                "lr": lr,
                "weight_decay": weight_decay,
                "rank": apollo_cfg.get("rank", 1),
                "proj": apollo_cfg.get("proj", "random"),
                "scale_type": apollo_cfg.get("scale_type", "tensor"),
                "scale": apollo_cfg.get("scale", 128),
                "update_proj_gap": apollo_cfg.get("update_proj_gap", 200),
                "proj_type": apollo_cfg.get("proj_type", "std"),
            },
        ]
        return param_groups

    def step(self):
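        # When fp32 master params are enabled, mirror the model grads into
        # them (cast to float32) so clipping and the update happen in full
        # precision; otherwise clip the model grads directly.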
        if self.use_fp32_params:
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    mp.grad = (
                        p.grad.detach().to(torch.float32)
                        if p.grad is not None
                        else None
                    )
            torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()

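        # Copy the updated fp32 master weights back into the model parameters
        # (cast to the model dtype) and drop the now-stale gradients.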
        if self.use_fp32_params:
            with torch.no_grad():
                for p, mp in zip(self.model_params, self.fp32_params):
                    p.data.copy_(mp.data.to(p.dtype))
                    p.grad = None
        else:
            with torch.no_grad():
                for p in self.model_params:
                    p.grad = None

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        print_on_rank0("Successfully loaded optimizer state_dict.")
        self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])
        print_on_rank0("Successfully loaded scheduler state_dict.")

    def state_dict(self):
        return {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }

    def get_learning_rate(self):
        return self.optimizer.param_groups[0]["lr"]
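

# Minimal usage sketch (illustrative, not part of this module). Assumes a bf16
# model with trainable parameters and a standard training loop; `model`,
# `dataloader`, and the hyperparameters below are placeholders:
#
#     optim = BF16Optimizer(model, lr=1e-4, total_steps=10_000)
#     for batch in dataloader:
#         loss = model(**batch)   # forward pass returning a scalar loss
#         loss.backward()         # populates .grad on the model parameters
#         optim.step()            # clip, update, LR step, and zero grads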