SingingSDS / modules /asr /registry.py
jhansss's picture
add more llms; add system prompt support; add device=auto
780954b
raw
history blame contribute delete
588 Bytes
from .base import AbstractASRModel
ASR_MODEL_REGISTRY = {}
def register_asr_model(prefix: str):
def wrapper(cls):
assert issubclass(cls, AbstractASRModel), f"{cls} must inherit AbstractASRModel"
ASR_MODEL_REGISTRY[prefix] = cls
return cls
return wrapper
def get_asr_model(model_id: str, device="auto", **kwargs) -> AbstractASRModel:
for prefix, cls in ASR_MODEL_REGISTRY.items():
if model_id.startswith(prefix):
return cls(model_id, device=device, **kwargs)
raise ValueError(f"No ASR wrapper found for model: {model_id}")