|
|
|
|
|
import opencc |
|
|
import os |
|
|
from pathlib import Path |
|
|
import sys |
|
|
from typing import Optional |
|
|
import multiprocessing |
|
|
|
|
|
|
|
|
|
|
|
# Worker-count selection: Hugging Face Spaces (signalled by the SPACE_ID
# environment variable) allot limited CPU, so cap usage there; everywhere
# else use every core the machine reports.
detected_cpus = multiprocessing.cpu_count()
num_vcpus = min(detected_cpus, 2) if os.environ.get('SPACE_ID') else detected_cpus
|
|
|
|
|
# Display label -> short model identifier for the tiny/base ASR models
# offered in the UI.
model_names = {
    "tiny English": "tiny",
    "tiny Arabic": "tiny-ar",
    "tiny Chinese": "tiny-zh",
    "tiny Japanese": "tiny-ja",
    "tiny Korean": "tiny-ko",
    "tiny Ukrainian": "tiny-uk",
    "tiny Vietnamese": "tiny-vi",
    "base English": "base",
    "base Spanish": "base-es",
}
|
|
|
|
|
|
|
|
# Display label -> Hugging Face repo id for the SenseVoice ASR models.
# Repos with "int8" in the id ship quantized weights (model.int8.onnx).
sensevoice_models = {
    "SenseVoice Small (2024)": "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17",
    "SenseVoice Small (2025 int8)": "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-int8-2025-09-09",
}
|
|
|
|
|
# Display label -> (Hugging Face repo id, GGUF filename) for the local
# quantized LLMs available to the app.
available_gguf_llms = {
    "Gemma-3-1B": ("bartowski/google_gemma-3-1b-it-qat-GGUF", "google_gemma-3-1b-it-qat-Q4_0.gguf"),
    "Gemma-3-270M": ("bartowski/google_gemma-3-270m-it-qat-GGUF", "google_gemma-3-270m-it-qat-Q8_0.gguf"),
    "Gemma-3-3N-E2B": ("unsloth/gemma-3n-E2B-it-GGUF", "gemma-3n-E2B-it-Q4_0.gguf"),
    "Gemma-3-3N-E4B": ("unsloth/gemma-3n-E4B-it-GGUF", "gemma-3n-E4B-it-Q4_0.gguf"),
}
|
|
|
|
|
s2tw_converter = opencc.OpenCC('s2twp') |
|
|
|
|
|
def get_writable_model_dir():
    """Return a writable directory for cached models, creating it if needed.

    On Hugging Face Spaces (SPACE_ID env var set) only /tmp is assumed
    writable, so models live under /tmp/models; otherwise a per-user
    cache directory under ~/.cache is used.
    """
    running_on_spaces = bool(os.environ.get('SPACE_ID'))
    target = (
        Path('/tmp/models')
        if running_on_spaces
        else Path.home() / ".cache" / "speech_assistant" / "models"
    )
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
def download_sensevoice_model(model_name: str) -> Path:
    """Download a SenseVoice model from Hugging Face, with local caching.

    Args:
        model_name: Hugging Face repo id, e.g.
            "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17".

    Returns:
        Path to the local directory containing the model files.

    Raises:
        ImportError: if huggingface_hub is not installed.
        Exception: download/validation errors are re-raised after cleaning
            up any partially downloaded files.
    """
    try:
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import HFValidationError
    except ImportError:
        raise ImportError("Please install huggingface_hub: pip install huggingface_hub")

    import shutil  # single local import instead of re-importing in each branch

    model_cache_dir = get_writable_model_dir()
    local_dir = model_cache_dir / model_name.replace("/", "--")

    # int8-quantized repos ship "model.int8.onnx" instead of "model.onnx".
    model_file = "model.int8.onnx" if "int8" in model_name else "model.onnx"
    model_file_path = local_dir / model_file
    tokens_file_path = local_dir / "tokens.txt"

    # Both the weights and the token table must exist for the cache to count.
    if model_file_path.exists() and tokens_file_path.exists():
        print(f"Model {model_name} already exists, skipping download")
        return local_dir

    # A directory missing either file is a partial/failed download; remove it
    # so snapshot_download starts from a clean slate.
    if local_dir.exists():
        print(f"Removing incomplete model directory: {local_dir}")
        shutil.rmtree(local_dir)

    print(f"Downloading {model_name} from Hugging Face")
    print("This may take several minutes depending on your connection...")

    try:
        # NOTE: the deprecated `resume_download` flag was dropped — modern
        # huggingface_hub always resumes interrupted downloads and the flag
        # only emitted a FutureWarning.
        snapshot_download(
            repo_id=model_name,
            local_dir=str(local_dir),
            max_workers=4,
        )
        print(f"Model {model_name} downloaded successfully!")
        return local_dir
    except HFValidationError as e:
        print(f"Hugging Face validation error: {e}")
        raise
    except Exception as e:
        print(f"Download failed: {str(e)}")
        # Clean up whatever partial state was left behind, then re-raise with
        # the original traceback (bare `raise`, not `raise e`).
        if local_dir.exists():
            shutil.rmtree(local_dir)
        raise
|
|
|
|
|
def load_sensevoice_model(model_name: str = "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17"):
    """Load a SenseVoice ONNX recognizer via sherpa-onnx.

    Downloads the model first if it is not cached. On any failure the
    cached model directory is removed (it may be corrupt) so the next call
    triggers a fresh download, and the original exception is re-raised.

    Args:
        model_name: Hugging Face repo id of the SenseVoice model.

    Returns:
        A configured sherpa_onnx.OfflineRecognizer.

    Raises:
        Exception: whatever download or initialization raised, re-raised
            with its original traceback after cache cleanup.
    """
    try:
        import sherpa_onnx

        print(f"Loading model: {model_name}")

        model_path = download_sensevoice_model(model_name)

        # int8-quantized repos ship "model.int8.onnx" instead of "model.onnx".
        model_file = "model.int8.onnx" if "int8" in model_name else "model.onnx"
        model_file_path = model_path / model_file

        print("Initializing recognizer...")
        recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
            model=str(model_file_path),
            tokens=str(model_path / "tokens.txt"),
            use_itn=True,      # inverse text normalization (numbers, punctuation)
            language="auto"    # let the model detect the spoken language
        )
        print("Model loaded successfully!")
        return recognizer
    except Exception as e:
        print(f"Failed to load SenseVoice model: {e}")
        # The cached files may be incomplete or corrupt; remove them so a
        # retry re-downloads from scratch.
        model_cache_dir = get_writable_model_dir()
        model_dir = model_cache_dir / model_name.replace("/", "--")
        if model_dir.exists():
            import shutil
            print(f"Removing model directory for redownload: {model_dir}")
            shutil.rmtree(model_dir)
        # Bare `raise` preserves the original traceback (unlike `raise e`).
        raise