# TODO: V2 of TTS Router
# Currently just uses the current TTS router.
import os
import json
import sys
from dotenv import load_dotenv
import fal_client
import requests
import time
import io
from gradio_client import handle_file
from pyht import Client as PyhtClient
from pyht.client import TTSOptions
import base64
import tempfile
import random

load_dotenv()

ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")


def get_zerogpu_token():
    return random.choice(ZEROGPU_TOKENS)


model_mapping = {
    # "eleven-multilingual-v2": {
    #     "provider": "elevenlabs",
    #     "model": "eleven_multilingual_v2",
    # },
    # "eleven-turbo-v2.5": {
    #     "provider": "elevenlabs",
    #     "model": "eleven_turbo_v2_5",
    # },
    # "eleven-flash-v2.5": {
    #     "provider": "elevenlabs",
    #     "model": "eleven_flash_v2_5",
    # },
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    # "playht-2.0": {
    #     "provider": "playht",
    #     "model": "PlayHT2.0",
    # },
    # "styletts2": {
    #     "provider": "styletts",
    #     "model": "styletts2",
    # },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    # "papla-p1": {
    #     "provider": "papla",
    #     "model": "papla_p1",
    # },
    # "hume-octave": {
    #     "provider": "hume",
    #     "model": "octave",
    # },
    # "minimax-02-hd": {
    #     "provider": "minimax",
    #     "model": "speech-02-hd",
    # },
    # "minimax-02-turbo": {
    #     "provider": "minimax",
    #     "model": "speech-02-turbo",
    # },
    # "lanternfish-1": {
    #     "provider": "lanternfish",
    #     "model": "lanternfish-1",
    # },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2",
    },
}

url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Example request-body shape for the router endpoint.
data = {"text": "string", "provider": "string", "model": "string"}


def predict_csm(script):
    result = fal_client.subscribe(
        "fal-ai/csm-1b",
        arguments={
            # "scene": [{
            #     "text": "Hey how are you doing.",
            #     "speaker_id": 0
            # }, {
            #     "text": "Pretty good, pretty good.",
            #     "speaker_id": 1
            # }, {
            #     "text": "I'm great, so happy to be speaking to you.",
            #     "speaker_id": 0
            # }]
            "scene": script
        },
        with_logs=True,
    )
    return requests.get(result["audio"]["url"]).content


def predict_playdialog(script):
    # Initialize the PyHT client
    pyht_client = PyhtClient(
        user_id=os.getenv("PLAY_USERID"),
        api_key=os.getenv("PLAY_SECRETKEY"),
    )

    # Define the voices
    voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
    voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"

    # Convert script format from CSM to PlayDialog format
    if isinstance(script, list):
        # Process script in CSM format (list of dictionaries)
        text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
            text += f"{prefix} {turn['text']}\n"
    else:
        # If it's already a string, use as is
        text = script

    # Set up TTSOptions
    options = TTSOptions(
        voice=voice_1,
        voice_2=voice_2,
        turn_prefix="Host 1:",
        turn_prefix_2="Host 2:",
    )

    # Generate audio using PlayDialog
    audio_chunks = []
    for chunk in pyht_client.tts(text, options, voice_engine="PlayDialog"):
        audio_chunks.append(chunk)

    # Combine all chunks into a single audio file
    return b"".join(audio_chunks)


def predict_dia(script):
    # Convert script to the required format for Dia
    if isinstance(script, list):
        # Convert from list of dictionaries to formatted string
        formatted_text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            speaker_tag = "[S1]" if speaker_id == 0 else "[S2]"
            text = turn.get("text", "").strip().replace("[S1]", "").replace("[S2]", "")
            formatted_text += f"{speaker_tag} {text} "
        text = formatted_text.strip()
    else:
        # If it's already a string, use as is
        text = script

    print(text)

    # Make a POST request to initiate the dialogue generation
    headers = {
        # "Content-Type": "application/json",
        "Authorization": f"Bearer {get_zerogpu_token()}"
    }
    response = requests.post(
        "https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue",
        headers=headers,
        json={"data": [text]},
    )

    # Extract the event ID from the response
    event_id = response.json()["event_id"]

    # Make a streaming request to get the generated dialogue
    stream_url = f"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue/{event_id}"

    # Use a streaming request to get the audio data
    with requests.get(stream_url, headers=headers, stream=True) as stream_response:
        # Process the streaming response
        for line in stream_response.iter_lines():
            if line:
                if line.startswith(b"data: ") and not line.startswith(b"data: null"):
                    audio_data = line[6:]
                    return requests.get(json.loads(audio_data)[0]["url"]).content


def predict_index_tts(text, reference_audio_path=None):
    from gradio_client import Client, handle_file

    client = Client("IndexTeam/IndexTTS")
    if reference_audio_path:
        prompt = handle_file(reference_audio_path)
    else:
        raise ValueError("index-tts requires reference_audio_path")
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    # The Space may return a dict instead of a path; extract the file value if so.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result


def predict_spark_tts(text, reference_audio_path=None):
    from gradio_client import Client, handle_file

    client = Client("amortalize/Spark-TTS-Zero")
    prompt_wav = None
    if reference_audio_path:
        prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone",
    )
    print("spark-tts result:", result)
    return result


def predict_cosyvoice_tts(text, reference_audio_path=None):
    import tempfile
    import soundfile as sf
    from huggingface_hub import snapshot_download

    model_dir = os.path.join(
        os.path.dirname(__file__), "CosyVoice2-0.5B", "pretrained_models", "CosyVoice2-0.5B"
    )
    if not os.path.exists(model_dir) or not os.listdir(model_dir):
        snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir=model_dir)
    sys.path.append(os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B"))
    from cosyvoice.cli.cosyvoice import CosyVoice2
    from cosyvoice.utils.file_utils import load_wav

    # Initialize the model once and cache it globally.
    global _cosyvoice_model
    if '_cosyvoice_model' not in globals() or _cosyvoice_model is None:
        _cosyvoice_model = CosyVoice2(model_dir)
    model = _cosyvoice_model

    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")

    # Load the reference audio at 16 kHz.
    prompt_speech_16k = load_wav(reference_audio_path, 16000)
    # The reference transcript is optional; no ASR is run here, so pass an empty string.
    prompt_text = ""
    # Run zero-shot inference.
    result = None
    for i in model.inference_zero_shot(text, prompt_text, prompt_speech_16k):
        result = i['tts_speech'].numpy().flatten()
    # Save the output to a temporary WAV file.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    sf.write(temp_file.name, result, 24000)
    return temp_file.name


def predict_maskgct(text, reference_audio_path=None):
    from gradio_client import Client, handle_file

    client = Client("cocktailpeanut/maskgct")
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,
        n_timesteps=25,
        api_name="/predict",
    )
    print("maskgct result:", result)
    return result


def predict_gpt_sovits_v2(text, reference_audio_path=None):
    from gradio_client import Client, file

    client = Client("lj1995/GPT-SoVITS-v2")
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
    result = client.predict(
        ref_wav_path=file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav",
    )
    print("gpt-sovits-v2 result:", result)
    return result


def predict_tts(text, model, reference_audio_path=None):
    print(f"Predicting TTS for {model}")

    # Exceptions: special models that shouldn't be passed to the router
    if model == "csm-1b":
        return predict_csm(text)
    elif model == "playdialog-1.0":
        return predict_playdialog(text)
    elif model == "dia-1.6b":
        return predict_dia(text)
    elif model == "index-tts":
        return predict_index_tts(text, reference_audio_path)
    elif model == "spark-tts":
        return predict_spark_tts(text, reference_audio_path)
    elif model == "cosyvoice-2.0":
        return predict_cosyvoice_tts(text, reference_audio_path)
    elif model == "maskgct":
        return predict_maskgct(text, reference_audio_path)
    elif model == "gpt-sovits-v2":
        return predict_gpt_sovits_v2(text, reference_audio_path)

    if model not in model_mapping:
        raise ValueError(f"Model {model} not found")

    # Build the request payload.
    payload = {
        "text": text,
        "provider": model_mapping[model]["provider"],
        "model": model_mapping[model]["model"],
    }
    # Only pass the reference audio to models that support voice cloning.
    supports_reference = model in [
        "styletts2",
        "eleven-multilingual-v2",
        "eleven-turbo-v2.5",
        "eleven-flash-v2.5",
    ]
    if reference_audio_path and supports_reference:
        with open(reference_audio_path, "rb") as f:
            audio_bytes = f.read()
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
        # Different models expect the reference audio under different fields.
        if model == "styletts2":
            payload["reference_speaker"] = audio_b64
        else:
            # ElevenLabs models
            payload["reference_audio"] = audio_b64

    result = requests.post(
        url,
        headers=headers,
        data=json.dumps(payload),
    )

    response_json = result.json()
    audio_data = response_json["audio_data"]  # base64 encoded audio data
    extension = response_json["extension"]

    # Decode the base64 audio data
    audio_bytes = base64.b64decode(audio_data)

    # Create a temporary file to store the audio data
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as temp_file:
        temp_file.write(audio_bytes)
        temp_path = temp_file.name

    return temp_path


if __name__ == "__main__":
    print(
        predict_dia(
            [
                {"text": "Hello, how are you?", "speaker_id": 0},
                {"text": "I'm great, thank you!", "speaker_id": 1},
            ]
        )
    )
    # print("Predicting PlayDialog")
    # print(
    #     predict_playdialog(
    #         [
    #             {"text": "Hey how are you doing.", "speaker_id": 0},
    #             {"text": "Pretty good, pretty good.", "speaker_id": 1},
    #             {"text": "I'm great, so happy to be speaking to you.", "speaker_id": 0},
    #         ]
    #     )
    # )
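    # A minimal, commented-out sketch of routing a voice-cloning request through
    # predict_tts. "spark-tts" is one of the active entries in model_mapping above;
    # "reference.wav" is a hypothetical local file that would need to exist on disk.
    # print(
    #     predict_tts(
    #         "Hello from the TTS router.",
    #         "spark-tts",
    #         reference_audio_path="reference.wav",
    #     )
    # )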