import os
import torch
from dataclasses import dataclass
from typing import Dict, Any

from transformers import AutoModelForSpeechSeq2Seq, pipeline, WhisperFeatureExtractor, AutoTokenizer

@dataclass
class ModelConfig:
    """Configuration for a Whisper model"""
    model_id: str
    display_name: str

class ModelConfigs:
    """Available model configurations"""
    SMALL = ModelConfig(
        model_id="nineninesix/kyrgyz-whisper-small",
        display_name="Small"
    )
    MEDIUM = ModelConfig(
        model_id="nineninesix/kyrgyz-whisper-medium",
        display_name="Medium"
    )

    @classmethod
    def get_all_configs(cls) -> Dict[str, ModelConfig]:
        """Get all available model configurations"""
        return {
            "Small": cls.SMALL,
            "Medium": cls.MEDIUM
        }

class InitModels:
    """Initialize and manage Whisper models for Kyrgyz speech recognition"""

    def __init__(self):
        self.token = os.getenv('HF_TOKEN')
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.models: Dict[str, Any] = {}
        self.pipelines: Dict[str, Any] = {}

    def initialize_model(self, model_config: ModelConfig) -> None:
        """Initialize a specific model and its pipeline"""
        model_id = model_config.model_id

        # Load model (keep on CPU for ZeroGPU compatibility)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            token=self.token
        )

        # Load feature extractor
        feature_extractor = WhisperFeatureExtractor.from_pretrained(
            model_id,
            token=self.token
        )

        # Load tokenizer (language/task set here so generation defaults to Kyrgyz transcription)
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            language="kyrgyz",
            task="transcribe",
            token=self.token
        )

        # Create pipeline (device will be set during inference)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            torch_dtype=self.torch_dtype,
            device=-1  # CPU; moved to GPU inside the @spaces.GPU-decorated function
        )

        # Store model components
        self.models[model_config.display_name] = {
            "model": model,
            "tokenizer": tokenizer,
            "feature_extractor": feature_extractor
        }
        self.pipelines[model_config.display_name] = pipe

    def initialize_all_models(self) -> None:
        """Initialize all available models"""
        configs = ModelConfigs.get_all_configs()
        for name, config in configs.items():
            print(f"Initializing {name} model: {config.model_id}")
            self.initialize_model(config)

    def get_pipeline(self, model_name: str) -> Any:
        """Get pipeline for a specific model"""
        return self.pipelines.get(model_name)

    def get_tokenizer(self, model_name: str) -> Any:
        """Get tokenizer for a specific model"""
        return self.models.get(model_name, {}).get("tokenizer")

    def get_model(self, model_name: str) -> Any:
        """Get model for a specific model name"""
        return self.models.get(model_name, {}).get("model")
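

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): one way this
# class might be wired into a ZeroGPU Space with a Gradio UI. The `transcribe`
# function name and the interface layout are illustrative; only the pattern of
# loading on CPU and moving to GPU inside a @spaces.GPU-decorated function
# follows the comments above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gradio as gr
    import spaces

    init_models = InitModels()
    init_models.initialize_all_models()

    @spaces.GPU  # ZeroGPU attaches a GPU only while this function runs
    def transcribe(audio_path: str, model_name: str) -> str:
        pipe = init_models.get_pipeline(model_name)
        pipe.model.to("cuda")               # move the weights onto the GPU
        pipe.device = torch.device("cuda")  # keep pipeline inputs on the same device
        return pipe(audio_path)["text"]

    demo = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", label="Audio"),
            gr.Dropdown(choices=list(ModelConfigs.get_all_configs()), value="Small", label="Model"),
        ],
        outputs=gr.Textbox(label="Transcription"),
    )
    demo.launch()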