fffiloni's picture
Migrated from GitHub
406f22d verified
from .tiger import TIGER
from .tiger_dnr import TIGERDNR
from .base_model import BaseModel
# Names exported by ``from <this package> import *``.
# TIGERDNR is imported above alongside TIGER (and is resolvable through
# ``get``), so it is exported too; previously star-imports omitted it.
__all__ = [
    "TIGER",
    "TIGERDNR",
]
def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The object is stored in this module's global namespace under its
    ``__name__``. Name lookup in ``get`` is case-insensitive over every
    global name, so registration is refused when *any* existing global
    matches ``custom_model.__name__`` ignoring case. Otherwise two entries
    differing only by case (e.g. ``Foo`` and ``FOO``) would silently shadow
    one another when looked up.

    Args:
        custom_model: Custom model to register. Only its ``__name__``
            attribute is read here (presumably a model class — confirm
            against callers).

    Raises:
        ValueError: If a global with the same name, compared
            case-insensitively, already exists.
    """
    name = custom_model.__name__
    # Mirror the case-insensitive matching used by ``get`` so that a new
    # registration can never collide with an existing name there.
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals()[name] = custom_model
def get(identifier):
    """Look up a registered model by name, ignoring case.

    The search runs over this module's entire global namespace, comparing
    each global name to ``identifier`` after lower-casing both. When several
    globals differ only by case, the one defined last wins.

    Args:
        identifier (str): The model name, e.g. ``"TIGER"`` or ``"tiger"``.

    Returns:
        :class:`torch.nn.Module`: The matching global (expected to be a
        model class).

    Raises:
        ValueError: If ``identifier`` is not a string, or if no global
            matches it (or the matching global is ``None``).
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    wanted = identifier.lower()
    found = None
    # No early break: later globals override earlier ones with the same
    # lower-cased name.
    for global_name, candidate in globals().items():
        if global_name.lower() == wanted:
            found = candidate

    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found