from torch.optim.optimizer import Optimizer
from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
from torch_optimizer import (
    AccSGD,
    AdaBound,
    AdaMod,
    DiffGrad,
    Lamb,
    NovoGrad,
    PID,
    QHAdam,
    QHM,
    RAdam,
    SGDW,
    Yogi,
    Ranger,
    RangerQH,
    RangerVA,
)

__all__ = [
    "AccSGD",
    "AdaBound",
    "AdaMod",
    "DiffGrad",
    "Lamb",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDW",
    "Yogi",
    "Ranger",
    "RangerQH",
    "RangerVA",
    "Adam",
    "RMSprop",
    "SGD",
    "Adadelta",
    "Adagrad",
    "Adamax",
    "AdamW",
    "ASGD",
    "make_optimizer",
    "register_optimizer",
    "get",
]


def make_optimizer(params, optim_name="adam", **kwargs):
    """Build an optimizer over ``params`` from an identifier.

    Args:
        params (iterable): Output of `nn.Module.parameters()`.
        optim_name (str or Callable): Identifier understood by :func:`~.get`,
            e.g. ``"adam"`` / ``"sgd"`` (case-insensitive) or an optimizer
            class such as ``torch.optim.SGD``. Defaults to ``"adam"``.
        **kwargs (dict): keyword arguments forwarded to the optimizer
            constructor (e.g. ``lr``).

    Returns:
        torch.optim.Optimizer

    Raises:
        ValueError: if ``optim_name`` cannot be interpreted by :func:`~.get`.

    Examples:
        >>> from torch import nn
        >>> model = nn.Sequential(nn.Linear(10, 10))
        >>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
        ...                            lr=1e-3)
    """
    return get(optim_name)(params, **kwargs)


def register_optimizer(custom_opt):
    """Register a custom optimizer, gettable with `optimizers.get`.

    The optimizer is stored in this module's globals under
    ``custom_opt.__name__``. Because :func:`get` resolves string identifiers
    case-insensitively, the collision check is case-insensitive too:
    otherwise a name differing only in case (e.g. ``ADAM``) would silently
    shadow an existing optimizer (``Adam``) in :func:`get`.

    Args:
        custom_opt: Custom optimizer to register (any object with a
            ``__name__``, typically a :class:`torch.optim.Optimizer`
            subclass).

    Raises:
        ValueError: if a name equal to ``custom_opt.__name__`` (ignoring
            case) is already defined in this module.
    """
    name = custom_opt.__name__
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(f"Optimizer {name} already exists. Choose another name.")
    globals().update({name: custom_opt})


def get(identifier):
    """Returns an optimizer class from a string identifier.

    String identifiers are matched case-insensitively against the names
    defined in this module, including optimizers added with
    :func:`register_optimizer`. Callable identifiers (e.g. an optimizer class
    such as ``torch.optim.SGD``) are returned unchanged, and so is an
    already-built :class:`torch.optim.Optimizer` instance.

    Args:
        identifier (str or Callable or torch.optim.Optimizer): the optimizer
            identifier.

    Returns:
        The resolved optimizer class (or registered callable), or
        ``identifier`` itself when it is callable or an Optimizer instance.

    Raises:
        ValueError: if ``identifier`` is not a known name, resolves to a
            non-callable module attribute, or is of an unsupported type.
    """
    # Instances are kept for backward compatibility; callables (optimizer
    # classes or factories) are passed through as the docstring promises.
    if isinstance(identifier, Optimizer) or callable(identifier):
        return identifier
    if isinstance(identifier, str):
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        # Reject misses and non-callable module globals (e.g. "__all__"),
        # which would otherwise fail obscurely when make_optimizer calls them.
        if cls is None or not callable(cls):
            raise ValueError(f"Could not interpret optimizer : {identifier}")
        return cls
    raise ValueError(f"Could not interpret optimizer : {identifier}")