import os
import json
import sys
import base64
import tempfile
import random

from dotenv import load_dotenv
import fal_client
import requests

from pyht import Client as PyhtClient
from pyht.client import TTSOptions

load_dotenv()

# Comma-separated pool of HF ZeroGPU tokens, e.g. ZEROGPU_TOKENS="hf_aaa,hf_bbb".
# Empty entries are filtered out so a trailing comma or an unset variable cannot
# inject an empty token into the pool.
ZEROGPU_TOKENS = [t for t in os.getenv("ZEROGPU_TOKENS", "").split(",") if t]


def get_zerogpu_token():
    """Pick a random token from the pool to spread ZeroGPU quota usage."""
    return random.choice(ZEROGPU_TOKENS)
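

# get_zerogpu_token() is used below to authorize calls to ZeroGPU-backed
# Gradio spaces, e.g.:
#   headers = {"Authorization": f"Bearer {get_zerogpu_token()}"}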


# Name -> provider/model pairs for the hosted TTS router; models that have a
# dedicated local handler in predict_tts are dispatched before this mapping
# is consulted.
model_mapping = {
"spark-tts": { |
|
"provider": "spark", |
|
"model": "spark-tts", |
|
}, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"cosyvoice-2.0": { |
|
"provider": "cosyvoice", |
|
"model": "cosyvoice_2_0", |
|
}, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"index-tts": { |
|
"provider": "bilibili", |
|
"model": "index-tts", |
|
}, |
|
"maskgct": { |
|
"provider": "amphion", |
|
"model": "maskgct", |
|
}, |
|
"gpt-sovits-v2": { |
|
"provider": "gpt-sovits", |
|
"model": "gpt-sovits-v2", |
|
}, |
|
} |
|

url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Placeholder documenting the router's request schema; predict_tts builds the
# actual payload per call.
data = {"text": "string", "provider": "string", "model": "string"}
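

# Example router round trip (shape inferred from predict_tts below; the exact
# response fields depend on the deployed service):
#   payload = {"text": "Hi there", **model_mapping["spark-tts"]}
#   resp = requests.post(url, headers=headers, json=payload)
#   audio_bytes = base64.b64decode(resp.json()["audio_data"])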


def predict_csm(script):
    """Generate conversation audio with CSM-1B on fal.ai and return raw bytes."""
    result = fal_client.subscribe(
        "fal-ai/csm-1b",
        arguments={
            "scene": script,
        },
        with_logs=True,
    )
    return requests.get(result["audio"]["url"]).content


def predict_playdialog(script):
    """Generate two-host dialogue audio with Play.ht's PlayDialog engine."""
    pyht_client = PyhtClient(
        user_id=os.getenv("PLAY_USERID"),
        api_key=os.getenv("PLAY_SECRETKEY"),
    )

    voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
    voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"

    if isinstance(script, list):
        # Flatten the structured script into the "Host N:"-prefixed text
        # that PlayDialog's turn_prefix options expect.
        text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
            text += f"{prefix} {turn['text']}\n"
    else:
        text = script

    options = TTSOptions(
        voice=voice_1, voice_2=voice_2, turn_prefix="Host 1:", turn_prefix_2="Host 2:"
    )

    # The client streams audio chunks; collect and concatenate them.
    audio_chunks = []
    for chunk in pyht_client.tts(text, options, voice_engine="PlayDialog"):
        audio_chunks.append(chunk)

    return b"".join(audio_chunks)


def predict_dia(script):
    """Generate dialogue audio with Dia-1.6B via its Hugging Face space."""
    if isinstance(script, list):
        # Dia marks speakers with [S1]/[S2] tags; strip any tags already
        # present in the text to avoid duplicates.
        formatted_text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            speaker_tag = "[S1]" if speaker_id == 0 else "[S2]"
            text = turn.get("text", "").strip().replace("[S1]", "").replace("[S2]", "")
            formatted_text += f"{speaker_tag} {text} "
        text = formatted_text.strip()
    else:
        text = script
    print(text)

    headers = {"Authorization": f"Bearer {get_zerogpu_token()}"}

    # Gradio's REST API is a two-step flow: the POST returns an event_id, then
    # an SSE stream at .../{event_id} delivers the result as "data: ..." lines.
    response = requests.post(
        "https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue",
        headers=headers,
        json={"data": [text]},
    )
    event_id = response.json()["event_id"]

    stream_url = f"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue/{event_id}"

    with requests.get(stream_url, headers=headers, stream=True) as stream_response:
        for line in stream_response.iter_lines():
            if line and line.startswith(b"data: ") and not line.startswith(b"data: null"):
                audio_data = line[len(b"data: "):]
                return requests.get(json.loads(audio_data)[0]["url"]).content
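

# A minimal sketch of the same two-step Gradio REST flow, generalized; it
# assumes the space exposes the standard /gradio_api/call/{api_name} layout
# used by predict_dia above. Not wired into the router; shown for reference.
def _gradio_rest_call(space_url, api_name, data, token=None):
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    resp = requests.post(
        f"{space_url}/gradio_api/call/{api_name}", headers=headers, json={"data": data}
    )
    event_id = resp.json()["event_id"]
    stream_url = f"{space_url}/gradio_api/call/{api_name}/{event_id}"
    with requests.get(stream_url, headers=headers, stream=True) as s:
        for line in s.iter_lines():
            if line and line.startswith(b"data: ") and not line.startswith(b"data: null"):
                # The payload is the JSON-encoded list of output values.
                return json.loads(line[len(b"data: "):])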


def predict_index_tts(text, reference_audio_path=None):
    """Zero-shot voice cloning with IndexTTS via its Hugging Face space."""
    from gradio_client import Client, handle_file

    client = Client("IndexTeam/IndexTTS")
    if reference_audio_path:
        prompt = handle_file(reference_audio_path)
    else:
        raise ValueError("index-tts requires reference_audio_path")
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    # Some space versions wrap the output path in a dict under "value".
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result


def predict_spark_tts(text, reference_audio_path=None):
    """Voice cloning with Spark-TTS via a Hugging Face space."""
    from gradio_client import Client, handle_file

    client = Client("amortalize/Spark-TTS-Zero")
    prompt_wav = None
    if reference_audio_path:
        prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone",
    )
    print("spark-tts result:", result)
    return result


def predict_cosyvoice_tts(text, reference_audio_path=None):
    """Zero-shot voice cloning with a locally cached CosyVoice2-0.5B model."""
    import soundfile as sf
    from huggingface_hub import snapshot_download

    # Download the checkpoint on first use and cache it next to this file.
    model_dir = os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B", "pretrained_models", "CosyVoice2-0.5B")
    if not os.path.exists(model_dir) or not os.listdir(model_dir):
        snapshot_download("FunAudioLLM/CosyVoice2-0.5B", local_dir=model_dir)
    sys.path.append(os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B"))
    from cosyvoice.cli.cosyvoice import CosyVoice2
    from cosyvoice.utils.file_utils import load_wav

    # Build the model once and reuse it across calls.
    global _cosyvoice_model
    if "_cosyvoice_model" not in globals() or _cosyvoice_model is None:
        _cosyvoice_model = CosyVoice2(model_dir)
    model = _cosyvoice_model

    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")

    prompt_speech_16k = load_wav(reference_audio_path, 16000)
    # No transcript of the reference clip is available, so the prompt text
    # is left empty.
    prompt_text = ""

    # The generator may yield multiple segments; as written, only the last
    # segment's waveform is kept.
    result = None
    for i in model.inference_zero_shot(text, prompt_text, prompt_speech_16k):
        result = i["tts_speech"].numpy().flatten()

    # CosyVoice2 outputs 24 kHz audio; write it to a temp wav and return the path.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    sf.write(temp_file.name, result, 24000)
    return temp_file.name
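

# Example (hypothetical file path):
#   wav_path = predict_cosyvoice_tts("Hello there!", "samples/reference.wav")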


def predict_maskgct(text, reference_audio_path=None):
    """Zero-shot TTS with Amphion's MaskGCT via a Hugging Face space."""
    from gradio_client import Client, handle_file

    client = Client("cocktailpeanut/maskgct")
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,
        n_timesteps=25,
        api_name="/predict",
    )
    print("maskgct result:", result)
    return result


def predict_gpt_sovits_v2(text, reference_audio_path=None):
    """Voice cloning with GPT-SoVITS-v2 via its Hugging Face space."""
    from gradio_client import Client, file

    client = Client("lj1995/GPT-SoVITS-v2")
    if not reference_audio_path:
        raise ValueError("gpt-sovits-v2 requires reference_audio_path")
    result = client.predict(
        ref_wav_path=file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav",
    )
    print("gpt-sovits-v2 result:", result)
    return result


def predict_tts(text, model, reference_audio_path=None):
    print(f"Predicting TTS for {model}")

    # Models with dedicated handlers are dispatched directly.
    if model == "csm-1b":
        return predict_csm(text)
    elif model == "playdialog-1.0":
        return predict_playdialog(text)
    elif model == "dia-1.6b":
        return predict_dia(text)
    elif model == "index-tts":
        return predict_index_tts(text, reference_audio_path)
    elif model == "spark-tts":
        return predict_spark_tts(text, reference_audio_path)
    elif model == "cosyvoice-2.0":
        return predict_cosyvoice_tts(text, reference_audio_path)
    elif model == "maskgct":
        return predict_maskgct(text, reference_audio_path)
    elif model == "gpt-sovits-v2":
        return predict_gpt_sovits_v2(text, reference_audio_path)

    # Everything else goes through the hosted router.
    if model not in model_mapping:
        raise ValueError(f"Model {model} not found")

    payload = {
        "text": text,
        "provider": model_mapping[model]["provider"],
        "model": model_mapping[model]["model"],
    }

    supports_reference = model in [
        "styletts2", "eleven-multilingual-v2", "eleven-turbo-v2.5", "eleven-flash-v2.5"
    ]
    if reference_audio_path and supports_reference:
        with open(reference_audio_path, "rb") as f:
            audio_bytes = f.read()
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")

        # styletts2 expects the reference clip under a different key.
        if model == "styletts2":
            payload["reference_speaker"] = audio_b64
        else:
            payload["reference_audio"] = audio_b64

    result = requests.post(
        url,
        headers=headers,
        data=json.dumps(payload),
    )
    response_json = result.json()

    # The router returns base64-encoded audio plus a file extension; decode it
    # into a temp file and return the path.
    audio_data = response_json["audio_data"]
    extension = response_json["extension"]
    audio_bytes = base64.b64decode(audio_data)

    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as temp_file:
        temp_file.write(audio_bytes)
        temp_path = temp_file.name

    return temp_path
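

# Note: return types are heterogeneous by design -- csm-1b, playdialog-1.0 and
# dia-1.6b return raw audio bytes, while the remaining models return a path to
# a temporary audio file, so callers must handle both cases.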


if __name__ == "__main__":
    # Quick smoke test: synthesize a short two-speaker exchange with Dia.
    print(
        predict_dia(
            [
                {"text": "Hello, how are you?", "speaker_id": 0},
                {"text": "I'm great, thank you!", "speaker_id": 1},
            ]
        )
    )