# Page-capture residue (not part of the program): Spaces status "Runtime error"; file size 5,659 bytes; id 8970226.
# (Line-number gutter 1-179 from the file viewer omitted.)
"""
Voice cloning utility for Coqui TTS XTTS v2 with a cached, reusable model service.
- Provides a CLI for one-off synthesis
- Exposes a clone_voice() API that reuses a loaded model across calls
- Exposes warm_model() and is_model_loaded() for backend progress integration
"""
import argparse
import os
import sys
import threading
from typing import Optional
try:
import torch
_HAS_CUDA = torch.cuda.is_available()
except Exception:
torch = None
_HAS_CUDA = False
try:
from torch.serialization import add_safe_globals
except Exception:
add_safe_globals = None
try:
from TTS.config.shared_configs import BaseDatasetConfig
except Exception:
BaseDatasetConfig = None
try:
from TTS.tts.configs.xtts_config import XttsConfig
except Exception:
XttsConfig = None
try:
from TTS.tts.models.xtts import XttsAudioConfig
except Exception:
XttsAudioConfig = None
from TTS.api import TTS
# Model identifier passed to TTS(...) when a ModelService loads its model.
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
def _collect_safe_globals():
    """Return the list of config classes that were importable.

    The module-level names BaseDatasetConfig, XttsConfig and XttsAudioConfig
    are set to None when their import failed, so falsy entries are dropped.
    XttsArgs is imported here on a best-effort basis and appended last when
    the import succeeds.
    """
    candidates = (BaseDatasetConfig, XttsConfig, XttsAudioConfig)
    safe_classes = [candidate for candidate in candidates if candidate]
    try:
        from TTS.tts.models.xtts import XttsArgs  # type: ignore
    except Exception:
        # Optional class: any failure to import it simply leaves it out.
        return safe_classes
    safe_classes.append(XttsArgs)
    return safe_classes
class ModelService:
    """Lazily loaded XTTS model wrapper bound to a single device.

    The underlying TTS object is created on first use and kept on the
    instance; its creation is serialized by a per-instance lock.
    """

    def __init__(self, device: Optional[str] = None) -> None:
        """Record the target device; the model itself is not loaded here.

        Args:
            device: Device string handed to ``.to()`` at load time. When
                falsy, "cuda" is chosen if CUDA was detected at import time,
                otherwise "cpu".
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if _HAS_CUDA else "cpu"
        self._tts = None
        self._load_lock = threading.Lock()

    def _register_safe_globals(self) -> None:
        """Register the collected config classes via add_safe_globals, if available.

        Does nothing when add_safe_globals could not be imported or when no
        classes were collected. A failure during registration is reported
        with a warning line and otherwise ignored.
        """
        if add_safe_globals:
            safe_classes = _collect_safe_globals()
            if safe_classes:
                try:
                    add_safe_globals(safe_classes)
                    print(f"[INFO] Registered safe globals: {[c.__name__ for c in safe_classes]}")
                except Exception as e:
                    print(f"[WARN] Could not register safe globals: {e}")

    def load(self) -> None:
        """Create the TTS model and move it to ``self.device`` unless already loaded."""
        if self._tts is None:
            with self._load_lock:
                # Check again while holding the lock: another caller may have
                # finished loading while this one was waiting.
                if self._tts is None:
                    print(f"[INFO] Loading model '{MODEL_NAME}' on device: {self.device} ...", flush=True)
                    self._register_safe_globals()
                    self._tts = TTS(MODEL_NAME).to(self.device)

    @property
    def tts(self):
        """The TTS object, loading it first when it has not been loaded yet."""
        model = self._tts
        if model is None:
            self.load()
            model = self._tts
        return model

    def tts_to_file(self, *, text: str, speaker_wav: str, language: str, file_path: str) -> None:
        """Synthesize ``text`` with the reference voice and write it to ``file_path``.

        Args:
            text: Text to synthesize.
            speaker_wav: Path to the reference voice file; must be an existing file.
            language: Language code forwarded to the model.
            file_path: Output path; its parent directory is created if missing.

        Raises:
            FileNotFoundError: If ``speaker_wav`` is not an existing file.
        """
        reference_exists = os.path.isfile(speaker_wav)
        if not reference_exists:
            raise FileNotFoundError(f"Reference voice file not found: {speaker_wav}")
        # A bare file name has an empty dirname; fall back to the current directory.
        target_dir = os.path.dirname(file_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        print(f"[INFO] Generating audio => {file_path}", flush=True)
        engine = self.tts
        engine.tts_to_file(
            text=text,
            speaker_wav=speaker_wav,
            language=language,
            file_path=file_path,
        )
# Global cache of services per device.
# Keys are the lowercased device strings computed in get_service(); values are
# the ModelService instances created for them.
_SERVICES: dict[str, ModelService] = {}
# Guards reads and writes of _SERVICES in get_service() and is_model_loaded().
_SERVICES_LOCK = threading.Lock()
def get_service(device: Optional[str] = None) -> ModelService:
    """Return the cached, loaded ModelService for ``device``, creating it on first use.

    Args:
        device: Device string; compared case-insensitively. When falsy, "cuda"
            is used if CUDA was detected at import time, otherwise "cpu".

    Returns:
        The ModelService registered for the device, with its model loaded.

    Raises:
        Exception: Whatever ``ModelService.load()`` raises. The service stays
            registered in that case, so a later call retries the load.
    """
    key = (device or ("cuda" if _HAS_CUDA else "cpu")).lower()
    with _SERVICES_LOCK:
        svc = _SERVICES.get(key)
        if svc is None:
            svc = ModelService(key)
            _SERVICES[key] = svc
    # Load outside the registry lock. Loading can take a long time, and holding
    # _SERVICES_LOCK for its duration would block is_model_loaded() polls and
    # lookups for other devices. ModelService.load() serializes concurrent
    # loaders with its own lock and returns immediately once the model exists.
    svc.load()
    return svc
def is_model_loaded(device: Optional[str] = None) -> bool:
    """Report whether a service is registered for the device and holds a loaded model.

    Args:
        device: Device string; compared case-insensitively. When falsy, the
            default device ("cuda" if detected at import time, else "cpu") is used.

    Returns:
        True only if a service exists in the cache under that key and its
        ``_tts`` attribute is not None.
    """
    requested = device or ("cuda" if _HAS_CUDA else "cpu")
    key = requested.lower()
    with _SERVICES_LOCK:
        svc = _SERVICES.get(key)
    if not svc:
        return False
    return getattr(svc, "_tts", None) is not None
def warm_model(device: Optional[str] = None) -> None:
    """Make sure the model for ``device`` is in memory before it is needed.

    Fetches the service through get_service() and then calls its load(),
    which returns immediately when the model is already present.
    """
    get_service(device).load()
def clone_voice(text: str, speaker_wav: str, language: str, output: str, device: Optional[str] = None) -> None:
    """Synthesize ``text`` in the voice of ``speaker_wav`` and write it to ``output``.

    The service is obtained from get_service(), so one model instance per
    device is reused across calls in the same process. Creating and loading
    that instance is guarded by locks; the synthesis call itself is not
    serialized in this module.

    Args:
        text: Text to synthesize.
        speaker_wav: Path to the reference voice file.
        language: Language code forwarded to the model.
        output: Path of the audio file to write.
        device: Optional device string; see get_service() for the default.

    Raises:
        FileNotFoundError: If ``speaker_wav`` is not an existing file.
    """
    service = get_service(device)
    request = {
        "text": text,
        "speaker_wav": speaker_wav,
        "language": language,
        "file_path": output,
    }
    service.tts_to_file(**request)
    print("[SUCCESS] Done.")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a one-off synthesis run.

    Returns:
        Namespace with ``text``, ``speaker_wav``, ``language``, ``output`` and
        ``device`` (None when --device is not given).
    """
    cli = argparse.ArgumentParser(
        description="Clone a voice with Coqui TTS XTTS v2 and synthesize text to a WAV file.",
    )
    # (flags, keyword settings) in the order they should appear in --help.
    option_table = (
        (("--text", "-t"), {"required": True, "help": "Text to synthesize."}),
        (("--speaker_wav", "-s"), {"required": True, "help": "Path to the reference voice WAV file."}),
        (("--language", "-l"), {"default": "en", "help": "Target language code (default: en)."}),
        (("--output", "-o"), {"default": "output.wav", "help": "Output WAV file path (default: output.wav)."}),
        (
            ("--device", "-d"),
            {
                "choices": ["cpu", "cuda"],
                "help": "Execution device. Defaults to CUDA if available, otherwise CPU.",
            },
        ),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
# Script entry point: parse the CLI options, run one synthesis, and turn any
# failure into a one-line error on stderr with exit status 1.
if __name__ == "__main__":
    args = parse_args()
    try:
        clone_voice(args.text, args.speaker_wav, args.language, args.output, args.device)
    except Exception as e:
        print(f"[ERROR] {e}", file=sys.stderr)
        raise SystemExit(1)
|