from src.utils.typing_utils import *
import os
from omegaconf import OmegaConf
from torch import optim
from torch.optim import lr_scheduler
from diffusers.training_utils import *
from diffusers.optimization import get_scheduler


# https://github.com/huggingface/diffusers/pull/9812: fix `self.use_ema_warmup`
class MyEMAModel(EMAModel):
"""
Exponential Moving Average of models weights
"""
def __init__(
self,
parameters: Iterable[torch.nn.Parameter],
decay: float = 0.9999,
min_decay: float = 0.0,
update_after_step: int = 0,
use_ema_warmup: bool = False,
inv_gamma: Union[float, int] = 1.0,
power: Union[float, int] = 2 / 3,
foreach: bool = False,
model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Args:
parameters (Iterable[torch.nn.Parameter]): The parameters to track.
decay (float): The decay factor for the exponential moving average.
min_decay (float): The minimum decay factor for the exponential moving average.
update_after_step (int): The number of steps to wait before starting to update the EMA weights.
use_ema_warmup (bool): Whether to use EMA warmup.
inv_gamma (float):
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
weights will be stored on CPU.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
"""
if isinstance(parameters, torch.nn.Module):
deprecation_message = (
"Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
"Please pass the parameters of the module instead."
)
deprecate(
"passing a `torch.nn.Module` to `ExponentialMovingAverage`",
"1.0.0",
deprecation_message,
standard_warn=False,
)
parameters = parameters.parameters()
# # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
# use_ema_warmup = True
if kwargs.get("max_value", None) is not None:
deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
decay = kwargs["max_value"]
if kwargs.get("min_value", None) is not None:
deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
min_decay = kwargs["min_value"]
parameters = list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
if kwargs.get("device", None) is not None:
deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
self.to(device=kwargs["device"])
self.temp_stored_params = None
self.decay = decay
self.min_decay = min_decay
self.update_after_step = update_after_step
self.use_ema_warmup = use_ema_warmup
self.inv_gamma = inv_gamma
self.power = power
self.optimization_step = 0
self.cur_decay_value = None # set in `step()`
self.foreach = foreach
self.model_cls = model_cls
self.model_config = model_config

    def get_decay(self, optimization_step: int) -> float:
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
if step <= 0:
return 0.0
if self.use_ema_warmup:
cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
else:
# cur_decay_value = (1 + step) / (10 + step)
cur_decay_value = self.decay
cur_decay_value = min(cur_decay_value, self.decay)
# make sure decay is not smaller than min_decay
cur_decay_value = max(cur_decay_value, self.min_decay)
return cur_decay_value
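

# Usage sketch (illustrative only; `model`, `dataloader` and the training loop are
# hypothetical placeholders, not part of this module):
#
#   ema = MyEMAModel(model.parameters(), decay=0.9999, use_ema_warmup=True, power=2 / 3)
#   for batch in dataloader:
#       ...                           # forward / backward / optimizer.step()
#       ema.step(model.parameters())  # update shadow params with the current decay
#   ema.copy_to(model.parameters())   # load the EMA weights for evaluation / export
#
# With use_ema_warmup=True, inv_gamma=1 and power=2/3, the warmup curve reaches
# roughly 0.999 around step 3.16e4 and 0.9999 around step 1e6 (and is always
# capped at `decay`), matching @crowsonkb's notes in the docstring above.

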
def get_configs(yaml_path: str, cli_configs: List[str] = [], **kwargs) -> DictConfig:
yaml_configs = OmegaConf.load(yaml_path)
cli_configs = OmegaConf.from_cli(cli_configs)
configs = OmegaConf.merge(yaml_configs, cli_configs, kwargs)
OmegaConf.resolve(configs) # resolve ${...} placeholders
return configs
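

# Example (file path and config keys are hypothetical):
#
#   configs = get_configs(
#       "configs/train.yaml",                        # base YAML config
#       cli_configs=["trainer.total_steps=100000"],  # dot-list overrides
#       seed=42,                                     # keyword overrides
#   )
#   print(configs.trainer.total_steps)  # -> 100000
#
# `OmegaConf.merge` gives later arguments precedence, so CLI overrides beat the
# YAML file and keyword arguments beat both.

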
def get_optimizer(name: str, params: Iterable[Parameter], **kwargs) -> Optimizer:
if name == "adamw":
return optim.AdamW(params=params, **kwargs)
else:
raise NotImplementedError(f"Not implemented optimizer: {name}")


def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
if name == "one_cycle":
return lr_scheduler.OneCycleLR(
optimizer,
max_lr=kwargs["max_lr"],
total_steps=kwargs["total_steps"],
pct_start=kwargs["pct_start"],
)
elif name == "cosine_warmup":
return get_scheduler(
"cosine", optimizer,
num_warmup_steps=kwargs["num_warmup_steps"],
num_training_steps=kwargs["total_steps"],
)
elif name == "constant_warmup":
return get_scheduler(
"constant_with_warmup", optimizer,
num_warmup_steps=kwargs["num_warmup_steps"],
num_training_steps=kwargs["total_steps"],
)
elif name == "constant":
return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda _: 1)
elif name == "linear_decay":
return lr_scheduler.LambdaLR(
optimizer=optimizer,
lr_lambda=lambda epoch: max(0., 1. - epoch / kwargs["total_epochs"]),
)
else:
raise NotImplementedError(f"Not implemented lr scheduler: {name}")
def save_experiment_params(
args: Namespace,
configs: DictConfig,
save_dir: str
) -> Dict[str, Any]:
params = OmegaConf.merge(configs, {"args": {k: str(v) for k, v in vars(args).items()}})
OmegaConf.save(params, os.path.join(save_dir, "params.yaml"))
return dict(params)


def save_model_architecture(model: Module, save_dir: str) -> None:
num_buffers = sum(b.numel() for b in model.buffers())
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
message = f"Number of buffers: {num_buffers}\n" +\
f"Number of trainable / all parameters: {num_trainable_params} / {num_params}\n\n" +\
f"Model architecture:\n{model}"
with open(os.path.join(save_dir, "model.txt"), "w") as f:
        f.write(message)
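

# Example (inside a hypothetical training entry point; `args` comes from argparse
# and `model` is the module being trained):
#
#   configs = get_configs(args.config, args.opts)
#   os.makedirs(args.save_dir, exist_ok=True)
#   save_experiment_params(args, configs, args.save_dir)  # -> <save_dir>/params.yaml
#   save_model_architecture(model, args.save_dir)         # -> <save_dir>/model.txt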