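"""Model-loading utilities for Kyrgyz Whisper speech recognition.

Defines the available fine-tuned Whisper checkpoints (Small and Medium) and an
InitModels helper that loads each model, tokenizer, and feature extractor and
wraps them in transformers ASR pipelines, kept on CPU for ZeroGPU compatibility.
"""
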
import os
from dataclasses import dataclass
from typing import Any, Dict

import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    WhisperFeatureExtractor,
    pipeline,
)


@dataclass
class ModelConfig:
"""Configuration for Whisper models"""
model_id: str
display_name: str


class ModelConfigs:
"""Available model configurations"""
SMALL = ModelConfig(
model_id="nineninesix/kyrgyz-whisper-small",
display_name="Small"
)
MEDIUM = ModelConfig(
model_id="nineninesix/kyrgyz-whisper-medium",
display_name="Medium"
)

    @classmethod
def get_all_configs(cls) -> Dict[str, ModelConfig]:
"""Get all available model configurations"""
return {
"Small": cls.SMALL,
"Medium": cls.MEDIUM
}


class InitModels:
"""Initialize and manage Whisper models for Kyrgyz speech recognition"""

    def __init__(self):
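        # HF token read from the environment; None is fine for public checkpoints.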
self.token = os.getenv('HF_TOKEN')
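        # fp16 when a GPU is available, fp32 on CPU; models stay on CPU until a ZeroGPU call moves them.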
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
self.models: Dict[str, Any] = {}
self.pipelines: Dict[str, Any] = {}

    def initialize_model(self, model_config: ModelConfig) -> None:
"""Initialize a specific model and its pipeline"""
model_id = model_config.model_id

        # Load model (keep on CPU for ZeroGPU compatibility)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
token=self.token
)

        # Load feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained(
model_id,
token=self.token
)

        # Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True,
language="kyrgyz",
task="transcribe",
token=self.token
)

        # Create pipeline (device will be set during inference)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
torch_dtype=self.torch_dtype,
device=-1 # CPU, will move to GPU in decorated function
)

        # Store model components
self.models[model_config.display_name] = {
"model": model,
"tokenizer": tokenizer,
"feature_extractor": feature_extractor
}
self.pipelines[model_config.display_name] = pipe

    def initialize_all_models(self) -> None:
"""Initialize all available models"""
configs = ModelConfigs.get_all_configs()
for name, config in configs.items():
print(f"Initializing {name} model: {config.model_id}")
self.initialize_model(config)

    def get_pipeline(self, model_name: str) -> Any:
"""Get pipeline for a specific model"""
return self.pipelines.get(model_name)

    def get_tokenizer(self, model_name: str) -> Any:
"""Get tokenizer for a specific model"""
return self.models.get(model_name, {}).get("tokenizer")

    def get_model(self, model_name: str) -> Any:
        """Get the underlying model for a specific model name"""
return self.models.get(model_name, {}).get("model")
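

# Minimal usage sketch. Assumptions: "example.wav" is a placeholder audio path,
# and on a ZeroGPU Space the call below would normally live inside a
# @spaces.GPU-decorated function that first moves pipe.model to "cuda".
# On a CPU-only machine this runs as-is (float32), just slowly.
if __name__ == "__main__":
    models = InitModels()
    models.initialize_all_models()

    pipe = models.get_pipeline("Small")
    # The ASR pipeline accepts an audio file path (decoded via ffmpeg) and
    # returns a dict with the transcription under the "text" key.
    result = pipe("example.wav")  # placeholder path
    print(result["text"])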