# VoxSum / src/utils.py
# (from commit 59519b7: "add gemma 270m for test use")
# utils.py
import opencc
import os
from pathlib import Path
import sys
from typing import Optional
import multiprocessing
# Detect logical cores (vCPUs available to the container).
# On HF Spaces free tier, cpu_count() reports 16 but only 2 are actually
# available, so cap the count when the SPACE_ID env var marks a Space.
detected_cpus = multiprocessing.cpu_count()
num_vcpus = min(detected_cpus, 2) if os.environ.get('SPACE_ID') else detected_cpus
# UI display name -> model identifier for the small ASR models.
# NOTE(review): presumably ids for a whisper-style backend ("tiny"/"base"
# size classes with language-specific variants) — confirm against the caller.
model_names = {
"tiny English":"tiny",
"tiny Arabic":"tiny-ar",
"tiny Chinese":"tiny-zh",
"tiny Japanese":"tiny-ja",
"tiny Korean":"tiny-ko",
"tiny Ukrainian":"tiny-uk",
"tiny Vietnamese":"tiny-vi",
"base English":"base",
"base Spanish":"base-es"
}
# Using only the two specified sherpa-onnx models from Hugging Face.
# UI display name -> Hugging Face repo id (passed to snapshot_download).
sensevoice_models = {
"SenseVoice Small (2024)": "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17",
"SenseVoice Small (2025 int8)": "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-int8-2025-09-09",
}
# UI display name -> (Hugging Face repo id, GGUF filename) for the
# quantized LLMs used for summarization.
available_gguf_llms = {
"Gemma-3-1B": ("bartowski/google_gemma-3-1b-it-qat-GGUF", "google_gemma-3-1b-it-qat-Q4_0.gguf"),
"Gemma-3-270M": ("bartowski/google_gemma-3-270m-it-qat-GGUF", "google_gemma-3-270m-it-qat-Q8_0.gguf"),
"Gemma-3-3N-E2B": ("unsloth/gemma-3n-E2B-it-GGUF", "gemma-3n-E2B-it-Q4_0.gguf"),
"Gemma-3-3N-E4B": ("unsloth/gemma-3n-E4B-it-GGUF", "gemma-3n-E4B-it-Q4_0.gguf"),
}
# OpenCC converter: Simplified Chinese -> Traditional Chinese, Taiwan
# standard with phrase conversion ('s2twp' profile).
s2tw_converter = opencc.OpenCC('s2twp')
def get_writable_model_dir():
    """Return a writable directory for cached models, creating it if needed.

    Inside a Hugging Face Space (``SPACE_ID`` set) models go under
    ``/tmp/models``; otherwise the per-user cache directory is used.
    """
    running_on_spaces = bool(os.environ.get('SPACE_ID'))
    if running_on_spaces:
        # /tmp is the reliably writable location on HF Spaces
        target = Path('/tmp/models')
    else:
        target = Path.home() / ".cache" / "speech_assistant" / "models"
    # Create the full path up-front so callers can write into it directly
    target.mkdir(parents=True, exist_ok=True)
    return target
def download_sensevoice_model(model_name: str) -> Path:
    """Download a SenseVoice sherpa-onnx model from Hugging Face.

    Skips the download when both the model file and ``tokens.txt`` are
    already present; a directory missing either file is treated as a
    leftover partial download and removed first.

    Args:
        model_name: Hugging Face repo id (e.g. "csukuangfj/sherpa-onnx-...").

    Returns:
        Path to the local directory holding the model and tokens files.

    Raises:
        ImportError: if huggingface_hub is not installed.
        Exception: any download error, re-raised after cleaning up the
            partial download directory.
    """
    try:
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import HFValidationError
    except ImportError:
        raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
    import shutil  # imported once here instead of separately in each cleanup branch

    # Use model_name directly as repo_id
    repo_id = model_name
    local_dir = get_writable_model_dir() / model_name.replace("/", "--")

    # int8-quantized repos ship model.int8.onnx instead of model.onnx
    model_file = "model.int8.onnx" if "int8" in model_name else "model.onnx"
    model_file_path = local_dir / model_file
    tokens_file_path = local_dir / "tokens.txt"
    if model_file_path.exists() and tokens_file_path.exists():
        print(f"Model {model_name} already exists, skipping download")
        return local_dir

    # Remove an existing but incomplete model directory before retrying
    if local_dir.exists():
        print(f"Removing incomplete model directory: {local_dir}")
        shutil.rmtree(local_dir)

    print(f"Downloading {model_name} from Hugging Face")
    print("This may take several minutes depending on your connection...")
    try:
        # snapshot_download resumes interrupted downloads by default; the
        # deprecated resume_download=True flag is intentionally not passed.
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
            max_workers=4,  # parallel file downloads
        )
        print(f"Model {model_name} downloaded successfully!")
        return local_dir
    except HFValidationError as e:
        print(f"Hugging Face validation error: {e}")
        raise
    except Exception as e:
        print(f"Download failed: {str(e)}")
        # Clean up the partial download so the next attempt starts fresh
        if local_dir.exists():
            shutil.rmtree(local_dir)
        raise  # bare raise preserves the original traceback
def load_sensevoice_model(model_name: str = "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17"):
    """Load a SenseVoice ONNX recognizer via sherpa-onnx.

    Downloads the model from Hugging Face on first use. On a load failure
    the cached model directory is removed so the next call re-downloads it.

    Args:
        model_name: Hugging Face repo id of the sherpa-onnx SenseVoice model.

    Returns:
        An initialized ``sherpa_onnx.OfflineRecognizer``.

    Raises:
        ImportError: if sherpa-onnx is not installed (cache left untouched).
        Exception: any download/initialization error, re-raised after
            clearing the cached model directory.
    """
    try:
        import sherpa_onnx
    except ImportError:
        # A missing runtime is an environment problem, not a corrupt
        # download: fail fast WITHOUT deleting the cached model files.
        raise ImportError("Please install sherpa-onnx: pip install sherpa-onnx")
    try:
        print(f"Loading model: {model_name}")
        # Download model if not already cached
        model_path = download_sensevoice_model(model_name)
        # int8-quantized repos ship model.int8.onnx instead of model.onnx
        model_file = "model.int8.onnx" if "int8" in model_name else "model.onnx"
        model_file_path = model_path / model_file
        print("Initializing recognizer...")
        recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
            model=str(model_file_path),
            tokens=str(model_path / "tokens.txt"),
            use_itn=True,  # Enable inverse text normalization
            language="auto"  # Auto-detect language
        )
        print("Model loaded successfully!")
        return recognizer
    except Exception as e:
        print(f"Failed to load SenseVoice model: {e}")
        # Assume the cached files are bad; remove them so the next call
        # forces a clean re-download.
        import shutil
        model_dir = get_writable_model_dir() / model_name.replace("/", "--")
        if model_dir.exists():
            print(f"Removing model directory for redownload: {model_dir}")
            shutil.rmtree(model_dir)
        raise  # bare raise preserves the original traceback