File size: 588 Bytes
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
780954b
91394e0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from .base import AbstractSVSModel

# Maps a model-id prefix to the AbstractSVSModel subclass that handles it.
# Populated by the @register_svs_model(prefix) decorator and read by get_svs_model.
SVS_MODEL_REGISTRY: "dict[str, type[AbstractSVSModel]]" = {}


def register_svs_model(prefix: str):
    """Return a class decorator that registers an SVS model class under *prefix*.

    The decorated class is stored in ``SVS_MODEL_REGISTRY`` keyed by
    *prefix*, where ``get_svs_model`` looks it up by model-id prefix.

    Args:
        prefix: Model-id prefix the decorated class handles. Registering the
            same prefix again replaces the previously registered class.

    Returns:
        A decorator that records the class in ``SVS_MODEL_REGISTRY`` and
        returns the class itself unchanged.

    Raises:
        TypeError: Raised by the returned decorator when the decorated object
            is not a subclass of ``AbstractSVSModel``.
    """

    def wrapper(cls):
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, which would silently admit invalid classes.
        # ``isinstance(cls, type)`` comes first because ``issubclass`` raises
        # its own, less helpful TypeError when handed a non-class.
        if not (isinstance(cls, type) and issubclass(cls, AbstractSVSModel)):
            raise TypeError(f"{cls} must inherit AbstractSVSModel")
        SVS_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper


def get_svs_model(model_id: str, device="auto", **kwargs) -> AbstractSVSModel:
    """Instantiate the registered SVS model wrapper that handles *model_id*.

    A registered class matches when *model_id* starts with its prefix. If
    several prefixes match, the longest (most specific) one wins. This means a
    specific prefix is never shadowed by a more generic prefix that happened
    to be registered earlier.

    Args:
        model_id: Identifier of the model to load; also passed to the
            wrapper's constructor as its first positional argument.
        device: Forwarded to the wrapper's constructor as ``device=``.
        **kwargs: Any further keyword arguments, forwarded unchanged to the
            wrapper's constructor.

    Returns:
        An instance of the matching registered ``AbstractSVSModel`` subclass.

    Raises:
        ValueError: If no registered prefix matches *model_id*.
    """
    matching = [prefix for prefix in SVS_MODEL_REGISTRY if model_id.startswith(prefix)]
    if not matching:
        raise ValueError(f"No SVS wrapper found for model: {model_id}")
    # Two distinct prefixes of the same length cannot both be prefixes of
    # model_id, so the longest match is unique.
    best_prefix = max(matching, key=len)
    return SVS_MODEL_REGISTRY[best_prefix](model_id, device=device, **kwargs)