# Diffsplat / src/models/__init__.py
# Uploaded by paulpanwang via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 476e0f0 (verified).
from typing import *
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch import optim
from torch.optim import lr_scheduler
from diffusers.optimization import get_scheduler
from src.models.elevest import ElevEst
from src.models.gsrecon import GSRecon
from src.models.gsvae import GSAutoencoderKL
def get_optimizer(
    name: str,
    params: Union[Iterable[Parameter], Iterable[Dict[str, Any]]],
    **kwargs,
) -> Optimizer:
    """Build an optimizer selected by ``name``.

    Args:
        name: Optimizer identifier. Only ``"adamw"`` is currently supported.
        params: What to optimize. This is either an iterable of ``Parameter``
            objects or an iterable of parameter-group dicts, exactly as
            accepted by ``torch.optim`` optimizers. It is not a single
            ``Parameter``.
        **kwargs: Passed unchanged to the optimizer constructor
            (for ``AdamW``: ``lr``, ``weight_decay``, ``betas``, ...).

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: If ``name`` is not a supported optimizer.
    """
    if name == "adamw":
        return optim.AdamW(params=params, **kwargs)
    raise NotImplementedError(f"Not implemented optimizer: {name}")
def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
    """Construct a learning-rate scheduler for ``optimizer`` selected by ``name``.

    Supported names and the keyword arguments each one reads:

    * ``"one_cycle"``: ``torch`` ``OneCycleLR``. Needs ``max_lr``,
      ``total_steps`` and ``pct_start``.
    * ``"cosine_warmup"``: diffusers ``get_scheduler("cosine", ...)``. Needs
      ``num_warmup_steps`` and ``total_steps``.
    * ``"constant_warmup"``: diffusers
      ``get_scheduler("constant_with_warmup", ...)``. Needs
      ``num_warmup_steps`` and ``total_steps``.
    * ``"constant"``: ``LambdaLR`` whose multiplier is always 1. Needs nothing.
    * ``"linear_decay"``: ``LambdaLR`` whose multiplier is
      ``max(0, 1 - step / total_epochs)``, where ``step`` is the count that
      ``LambdaLR`` passes to the multiplier. Needs ``total_epochs``, which
      must be non-zero.

    Keyword arguments that the chosen schedule does not use are ignored.

    Raises:
        KeyError: If a keyword argument required by the chosen schedule is
            missing.
        NotImplementedError: If ``name`` is not one of the names above.
    """
    # These names are only aliases for diffusers' own schedule names.
    diffusers_alias = {
        "cosine_warmup": "cosine",
        "constant_warmup": "constant_with_warmup",
    }
    if name in diffusers_alias:
        return get_scheduler(
            diffusers_alias[name],
            optimizer,
            num_warmup_steps=kwargs["num_warmup_steps"],
            num_training_steps=kwargs["total_steps"],
        )

    if name == "one_cycle":
        return lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=kwargs["max_lr"],
            total_steps=kwargs["total_steps"],
            pct_start=kwargs["pct_start"],
        )

    if name == "constant":
        return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda _step: 1)

    if name == "linear_decay":
        def _linear_decay(step):
            # Falls linearly from 1 to 0 over `total_epochs` steps, then stays at 0.
            return max(0., 1. - step / kwargs["total_epochs"])

        return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=_linear_decay)

    raise NotImplementedError(f"Not implemented lr scheduler: {name}")