import os
from abc import ABC, abstractmethod
import librosa
import numpy as np
from transformers import pipeline

# Maps a model-id prefix (e.g. "openai/whisper") to the wrapper class that handles it.
ASR_MODEL_REGISTRY = {}
# Optional Hugging Face access token, read from the environment.
hf_token = os.getenv("HF_TOKEN")


class AbstractASRModel(ABC):
    """Base interface for ASR wrappers; subclasses implement transcribe()."""

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        print(f"Loading ASR model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    @abstractmethod
    def transcribe(self, audio: np.ndarray, audio_sample_rate: int, **kwargs) -> str:
        pass


def register_asr_model(prefix):
    """Class decorator: register an ASR wrapper under a model-id prefix."""

    def wrapper(cls):
        assert issubclass(cls, AbstractASRModel), f"{cls} must inherit AbstractASRModel"
        ASR_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper


def get_asr_model(model_id: str, device="cpu", **kwargs) -> AbstractASRModel:
    """Instantiate the wrapper whose registered prefix matches model_id."""
    for prefix, cls in ASR_MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return cls(model_id, device=device, **kwargs)
    raise ValueError(f"No ASR wrapper found for model: {model_id}")
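

# Hedged sketch, not from the original file: additional backends are meant to be
# added by registering a new prefix with the same decorator, e.g.
#
#     @register_asr_model("facebook/wav2vec2")
#     class Wav2Vec2ASR(AbstractASRModel):
#         ...
#
# after which get_asr_model("facebook/wav2vec2-base-960h") would resolve to that
# wrapper instead of WhisperASR.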


@register_asr_model("openai/whisper")
class WhisperASR(AbstractASRModel):
    """Wraps the transformers ASR pipeline for openai/whisper-* checkpoints."""

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        # Forward cache_dir to the underlying model via the pipeline's model_kwargs.
        model_kwargs = kwargs.setdefault("model_kwargs", {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=0 if device == "cuda" else -1,
            token=hf_token,
            **kwargs,
        )

    def transcribe(self, audio: np.ndarray, audio_sample_rate: int, language: str, **kwargs) -> str:
        # Whisper expects 16 kHz audio; resample if the input rate differs.
        if audio_sample_rate != 16000:
            audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
        return self.pipe(audio, generate_kwargs={"language": language}, return_timestamps=False).get("text", "")
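

# Hedged usage sketch (not part of the original file): exercises the registry end
# to end. The model id, audio path, and language below are illustrative assumptions.
if __name__ == "__main__":
    asr = get_asr_model("openai/whisper-tiny", device="cpu", cache_dir="cache")
    # librosa.load with sr=None keeps the file's native sample rate; the wrapper
    # resamples to 16 kHz itself inside transcribe().
    audio, sr = librosa.load("example.wav", sr=None)
    print(asr.transcribe(audio, sr, language="en"))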