# TODO: V2 of TTS Router
# For now this still uses the current TTS router.
import os
import json
import sys
from dotenv import load_dotenv
import fal_client
import requests
import time
import io
from gradio_client import handle_file
from pyht import Client as PyhtClient
from pyht.client import TTSOptions
import base64
import tempfile
import random
load_dotenv()
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")
def get_zerogpu_token():
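    """Return a randomly chosen token from the ZEROGPU_TOKENS environment variable."""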
return random.choice(ZEROGPU_TOKENS)
model_mapping = {
# "eleven-multilingual-v2": {
# "provider": "elevenlabs",
# "model": "eleven_multilingual_v2",
# },
# "eleven-turbo-v2.5": {
# "provider": "elevenlabs",
# "model": "eleven_turbo_v2_5",
# },
# "eleven-flash-v2.5": {
# "provider": "elevenlabs",
# "model": "eleven_flash_v2_5",
# },
"spark-tts": {
"provider": "spark",
"model": "spark-tts",
},
# "playht-2.0": {
# "provider": "playht",
# "model": "PlayHT2.0",
# },
# "styletts2": {
# "provider": "styletts",
# "model": "styletts2",
# },
"cosyvoice-2.0": {
"provider": "cosyvoice",
"model": "cosyvoice_2_0",
},
# "papla-p1": {
# "provider": "papla",
# "model": "papla_p1",
# },
# "hume-octave": {
# "provider": "hume",
# "model": "octave",
# },
# "minimax-02-hd": {
# "provider": "minimax",
# "model": "speech-02-hd",
# },
# "minimax-02-turbo": {
# "provider": "minimax",
# "model": "speech-02-turbo",
# },
# "lanternfish-1": {
# "provider": "lanternfish",
# "model": "lanternfish-1",
# },
"index-tts": {
"provider": "bilibili",
"model": "index-tts",
},
"maskgct": {
"provider": "amphion",
"model": "maskgct",
},
"gpt-sovits-v2": {
"provider": "gpt-sovits",
"model": "gpt-sovits-v2",
},
}
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
data = {"text": "string", "provider": "string", "model": "string"}  # placeholder showing the router's request shape (unused; predict_tts builds the real payload)
def predict_csm(script):
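    """Generate dialogue audio from a scene script via the fal.ai CSM-1B endpoint and return the audio bytes."""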
result = fal_client.subscribe(
"fal-ai/csm-1b",
arguments={
# "scene": [{
# "text": "Hey how are you doing.",
# "speaker_id": 0
# }, {
# "text": "Pretty good, pretty good.",
# "speaker_id": 1
# }, {
# "text": "I'm great, so happy to be speaking to you.",
# "speaker_id": 0
# }]
"scene": script
},
with_logs=True,
)
return requests.get(result["audio"]["url"]).content
def predict_playdialog(script):
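    """Generate two-host dialogue audio with Play.ht's PlayDialog engine and return the concatenated audio bytes."""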
# Initialize the PyHT client
pyht_client = PyhtClient(
user_id=os.getenv("PLAY_USERID"),
api_key=os.getenv("PLAY_SECRETKEY"),
)
# Define the voices
voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"
# Convert script format from CSM to PlayDialog format
if isinstance(script, list):
# Process script in CSM format (list of dictionaries)
text = ""
for turn in script:
speaker_id = turn.get("speaker_id", 0)
prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
text += f"{prefix} {turn['text']}\n"
else:
# If it's already a string, use as is
text = script
# Set up TTSOptions
options = TTSOptions(
voice=voice_1, voice_2=voice_2, turn_prefix="Host 1:", turn_prefix_2="Host 2:"
)
# Generate audio using PlayDialog
audio_chunks = []
for chunk in pyht_client.tts(text, options, voice_engine="PlayDialog"):
audio_chunks.append(chunk)
# Combine all chunks into a single audio file
return b"".join(audio_chunks)
def predict_dia(script):
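    """Generate dialogue audio via the Dia-1.6B Space's Gradio HTTP API (authenticated with a ZeroGPU token) and return the audio bytes."""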
# Convert script to the required format for Dia
if isinstance(script, list):
# Convert from list of dictionaries to formatted string
formatted_text = ""
for turn in script:
speaker_id = turn.get("speaker_id", 0)
speaker_tag = "[S1]" if speaker_id == 0 else "[S2]"
text = turn.get("text", "").strip().replace("[S1]", "").replace("[S2]", "")
formatted_text += f"{speaker_tag} {text} "
text = formatted_text.strip()
else:
# If it's already a string, use as is
text = script
print(text)
# Make a POST request to initiate the dialogue generation
headers = {
# "Content-Type": "application/json",
"Authorization": f"Bearer {get_zerogpu_token()}"
}
response = requests.post(
"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue",
headers=headers,
json={"data": [text]},
)
# Extract the event ID from the response
event_id = response.json()["event_id"]
# Make a streaming request to get the generated dialogue
stream_url = f"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue/{event_id}"
# Use a streaming request to get the audio data
with requests.get(stream_url, headers=headers, stream=True) as stream_response:
# Process the streaming response
for line in stream_response.iter_lines():
if line:
if line.startswith(b"data: ") and not line.startswith(b"data: null"):
audio_data = line[6:]
return requests.get(json.loads(audio_data)[0]["url"]).content
def predict_index_tts(text, reference_audio_path=None):
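    """Synthesize `text` with the IndexTeam/IndexTTS Space, cloning the voice in `reference_audio_path`; returns the generated audio file path."""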
from gradio_client import Client, handle_file
client = Client("IndexTeam/IndexTTS")
if reference_audio_path:
prompt = handle_file(reference_audio_path)
else:
        raise ValueError("index-tts requires reference_audio_path")
result = client.predict(
prompt=prompt,
text=text,
api_name="/gen_single"
)
    if not isinstance(result, str):
        # Gradio may wrap the output in a component dict, e.g. {"value": path, ...}; unwrap the file path
        result = result.get("value")
print("index-tts result:", result)
return result
def predict_spark_tts(text, reference_audio_path=None):
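    """Synthesize `text` with the Spark-TTS Space's voice_clone endpoint, optionally using `reference_audio_path` as the voice prompt."""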
from gradio_client import Client, handle_file
client = Client("amortalize/Spark-TTS-Zero")
prompt_wav = None
if reference_audio_path:
prompt_wav = handle_file(reference_audio_path)
result = client.predict(
text=text,
prompt_text=text,
prompt_wav_upload=prompt_wav,
prompt_wav_record=prompt_wav,
api_name="/voice_clone"
)
print("spark-tts result:", result)
return result
def predict_cosyvoice_tts(text, reference_audio_path=None):
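    """Synthesize `text` locally with CosyVoice2-0.5B (downloaded on first use), cloning the voice in `reference_audio_path`; returns a temporary wav path."""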
import tempfile
import soundfile as sf
from huggingface_hub import snapshot_download
model_dir = os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B", "pretrained_models", "CosyVoice2-0.5B")
if not os.path.exists(model_dir) or not os.listdir(model_dir):
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir=model_dir)
sys.path.append(os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B"))
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
    # Initialize the model once and cache it as a module-level global
global _cosyvoice_model
if '_cosyvoice_model' not in globals() or _cosyvoice_model is None:
_cosyvoice_model = CosyVoice2(model_dir)
model = _cosyvoice_model
if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")
    # Load the reference audio at 16 kHz
    prompt_speech_16k = load_wav(reference_audio_path, 16000)
    # A prompt transcript is optional; skip ASR here and pass an empty string
    prompt_text = ""
    # Zero-shot inference
result = None
for i in model.inference_zero_shot(text, prompt_text, prompt_speech_16k):
result = i['tts_speech'].numpy().flatten()
    # Write the synthesized audio to a temporary wav file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
sf.write(temp_file.name, result, 24000)
return temp_file.name
def predict_maskgct(text, reference_audio_path=None):
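    """Synthesize `text` with the MaskGCT Space, cloning the voice in `reference_audio_path`; returns the generated audio file path."""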
from gradio_client import Client, handle_file
client = Client("cocktailpeanut/maskgct")
if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
prompt_wav = handle_file(reference_audio_path)
result = client.predict(
prompt_wav=prompt_wav,
target_text=text,
target_len=-1,
n_timesteps=25,
api_name="/predict"
)
print("maskgct result:", result)
return result
def predict_gpt_sovits_v2(text, reference_audio_path=None):
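    """Synthesize `text` with the GPT-SoVITS-v2 Space's get_tts_wav endpoint, cloning the voice in `reference_audio_path`."""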
from gradio_client import Client, file
client = Client("lj1995/GPT-SoVITS-v2")
if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
result = client.predict(
ref_wav_path=file(reference_audio_path),
prompt_text="",
prompt_language="English",
text=text,
text_language="English",
how_to_cut="Slice once every 4 sentences",
top_k=15,
top_p=1,
temperature=1,
ref_free=False,
speed=1,
if_freeze=False,
inp_refs=[],
api_name="/get_tts_wav"
)
print("gpt-sovits-v2 result:", result)
return result
def predict_tts(text, model, reference_audio_path=None):
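    """Dispatch `text` to the requested TTS model: special models are called directly, everything else goes through the HTTP TTS router.
    Returns either raw audio bytes or a local audio file path, depending on the model."""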
print(f"Predicting TTS for {model}")
# Exceptions: special models that shouldn't be passed to the router
if model == "csm-1b":
return predict_csm(text)
elif model == "playdialog-1.0":
return predict_playdialog(text)
elif model == "dia-1.6b":
return predict_dia(text)
elif model == "index-tts":
return predict_index_tts(text, reference_audio_path)
elif model == "spark-tts":
return predict_spark_tts(text, reference_audio_path)
elif model == "cosyvoice-2.0":
return predict_cosyvoice_tts(text, reference_audio_path)
elif model == "maskgct":
return predict_maskgct(text, reference_audio_path)
elif model == "gpt-sovits-v2":
return predict_gpt_sovits_v2(text, reference_audio_path)
    if model not in model_mapping:
        raise ValueError(f"Model {model} not found")
    # Build the request body for the TTS router
payload = {
"text": text,
"provider": model_mapping[model]["provider"],
"model": model_mapping[model]["model"],
}
    # Only pass a reference voice to models that support voice cloning
supports_reference = model in [
"styletts2", "eleven-multilingual-v2", "eleven-turbo-v2.5", "eleven-flash-v2.5"
]
if reference_audio_path and supports_reference:
with open(reference_audio_path, "rb") as f:
audio_bytes = f.read()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
        # The reference-voice field name differs between models
if model == "styletts2":
payload["reference_speaker"] = audio_b64
else: # elevenlabs 系列
payload["reference_audio"] = audio_b64
result = requests.post(
url,
headers=headers,
data=json.dumps(payload),
)
response_json = result.json()
audio_data = response_json["audio_data"] # base64 encoded audio data
extension = response_json["extension"]
# Decode the base64 audio data
audio_bytes = base64.b64decode(audio_data)
# Create a temporary file to store the audio data
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as temp_file:
temp_file.write(audio_bytes)
temp_path = temp_file.name
return temp_path
if __name__ == "__main__":
print(
predict_dia(
[
{"text": "Hello, how are you?", "speaker_id": 0},
{"text": "I'm great, thank you!", "speaker_id": 1},
]
)
)
# print("Predicting PlayDialog")
# print(
# predict_playdialog(
# [
# {"text": "Hey how are you doing.", "speaker_id": 0},
# {"text": "Pretty good, pretty good.", "speaker_id": 1},
# {"text": "I'm great, so happy to be speaking to you.", "speaker_id": 0},
# ]
# )
# )
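    # Example: synthesize a single utterance via predict_tts with voice cloning
    # ("ref.wav" below is a placeholder path to a reference clip):
    # print(predict_tts("Hello there!", "index-tts", reference_audio_path="ref.wav"))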