from typing import *

from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch import optim
from torch.optim import lr_scheduler

from diffusers.optimization import get_scheduler

from src.models.elevest import ElevEst
from src.models.gsrecon import GSRecon
from src.models.gsvae import GSAutoencoderKL
def get_optimizer(name: str, params: Iterator[Parameter], **kwargs) -> Optimizer:
    """Build an optimizer by name; extra keyword arguments are forwarded to it."""
    if name == "adamw":
        return optim.AdamW(params=params, **kwargs)
    else:
        raise NotImplementedError(f"Optimizer not implemented: {name}")
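

# Note: `**kwargs` is forwarded verbatim to `torch.optim.AdamW`, so callers pass the
# standard AdamW keyword arguments such as `lr`, `betas`, `eps`, and `weight_decay`.
# A hedged usage sketch (the hyperparameter values are illustrative, not from the
# training config):
#
#   optimizer = get_optimizer("adamw", model.parameters(), lr=1e-4, weight_decay=0.05)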
def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
    """Build a learning-rate scheduler by name; extra keyword arguments configure it."""
    if name == "one_cycle":
        # One-cycle policy: ramp up to `max_lr`, then anneal over `total_steps`.
        return lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=kwargs["max_lr"],
            total_steps=kwargs["total_steps"],
            pct_start=kwargs["pct_start"],
        )
    elif name == "cosine_warmup":
        # Linear warmup followed by cosine decay (from diffusers).
        return get_scheduler(
            "cosine", optimizer,
            num_warmup_steps=kwargs["num_warmup_steps"],
            num_training_steps=kwargs["total_steps"],
        )
    elif name == "constant_warmup":
        # Linear warmup, then a constant learning rate (from diffusers).
        return get_scheduler(
            "constant_with_warmup", optimizer,
            num_warmup_steps=kwargs["num_warmup_steps"],
            num_training_steps=kwargs["total_steps"],
        )
    elif name == "constant":
        # Constant learning rate: the multiplier is always 1.
        return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda _: 1.0)
    elif name == "linear_decay":
        # Linearly decay the learning rate to zero over `total_epochs`.
        return lr_scheduler.LambdaLR(
            optimizer=optimizer,
            lr_lambda=lambda epoch: max(0.0, 1.0 - epoch / kwargs["total_epochs"]),
        )
    else:
        raise NotImplementedError(f"LR scheduler not implemented: {name}")
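

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only): a toy linear layer stands in for
    # the real networks (GSRecon, GSAutoencoderKL, ElevEst), and every hyperparameter
    # below is an assumed placeholder rather than a value from the training config.
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = get_optimizer("adamw", model.parameters(), lr=1e-4, weight_decay=0.01)
    scheduler = get_lr_scheduler("cosine_warmup", optimizer, num_warmup_steps=10, total_steps=100)

    for _ in range(100):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-step schedulers are advanced once per optimizer step
    print("final lr:", scheduler.get_last_lr())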