import os
import random

from dotenv import load_dotenv
from gradio_client import Client, handle_file

load_dotenv()

# Comma-separated pool of HF tokens used to spread ZeroGPU quota across accounts.
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")


def get_zerogpu_token():
    # Fall back to the single HF_TOKEN when no token pool is configured.
    if not ZEROGPU_TOKENS or ZEROGPU_TOKENS == [""]:
        return os.getenv("HF_TOKEN")
    return random.choice(ZEROGPU_TOKENS)
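

# Illustrative .env layout consumed above (placeholder values, not real tokens):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   ZEROGPU_TOKENS=hf_token_a,hf_token_b,hf_token_c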

model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2",
    },
}

# TTS router endpoint used for models that are not handled locally below.
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Placeholder payload illustrating the router's request schema.
data = {"text": "string", "provider": "string", "model": "string"}
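

# A hedged sketch of how the router endpoint above might be called for models
# that go through the router rather than a dedicated Space. `call_router` is
# illustrative only, and the JSON response shape is an assumption, not a
# documented contract.
def call_router(text: str, model: str):
    import requests  # local import keeps the sketch self-contained

    entry = model_mapping[model]
    payload = {"text": text, "provider": entry["provider"], "model": entry["model"]}
    resp = requests.post(url, headers=headers, json=payload, timeout=60)
    resp.raise_for_status()
    return resp.json()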


def set_client_for_session(space: str, user_token=None):
    # Without a per-user token, authenticate with a pooled ZeroGPU token;
    # otherwise forward the user's token via the X-IP-Token header.
    if user_token is None:
        return Client(space, hf_token=get_zerogpu_token())
    return Client(space, headers={"X-IP-Token": user_token})


def predict_index_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/IndexTTS", user_token=user_token)
    if reference_audio_path:
        prompt = handle_file(reference_audio_path)
    else:
        raise ValueError("index-tts requires reference_audio_path")
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    # /gen_single may return a component update dict instead of a plain path.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result


def predict_spark_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("thunnai/SparkTTS", user_token=user_token)
    prompt_wav = None
    if reference_audio_path:
        prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        text=text,
        prompt_text=text,  # the target text doubles as the prompt transcript
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone",
    )
    print("spark-tts result:", result)
    return result


def predict_cosyvoice_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/CosyVoice2-0.5B", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    # First transcribe the reference audio so it can serve as the prompt text.
    recog_result = client.predict(
        prompt_wav=prompt_wav,
        api_name="/prompt_wav_recognition",
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,
        stream=False,
        api_name="/generate_audio",
    )
    print("cosyvoice-2.0 result:", result)
    return result


def predict_maskgct(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("amphion/maskgct", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,  # -1 lets the Space infer the target duration
        n_timesteps=25,
        api_name="/predict",
    )
    print("maskgct result:", result)
    return result


def predict_gpt_sovits_v2(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/GPT-SoVITS-v2", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
    result = client.predict(
        ref_wav_path=handle_file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav",
    )
    print("gpt-sovits-v2 result:", result)
    return result


def predict_tts(text, model, user_token=None, reference_audio_path=None):
    print(f"Predicting TTS for {model}, user_token: {user_token}, reference_audio_path: {reference_audio_path}")
    # Exceptions: special models that shouldn't be passed to the router.
    if model == "index-tts":
        result = predict_index_tts(text, user_token, reference_audio_path)
    elif model == "spark-tts":
        result = predict_spark_tts(text, user_token, reference_audio_path)
    elif model == "cosyvoice-2.0":
        result = predict_cosyvoice_tts(text, user_token, reference_audio_path)
    elif model == "maskgct":
        result = predict_maskgct(text, user_token, reference_audio_path)
    elif model == "gpt-sovits-v2":
        result = predict_gpt_sovits_v2(text, user_token, reference_audio_path)
    else:
        raise ValueError(f"Model {model} not found")
    return result


if __name__ == "__main__":
    pass
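    # Minimal smoke test, a sketch only: "ref.wav" is a hypothetical local
    # reference clip, and HF_TOKEN / ZEROGPU_TOKENS must be set in the env.
    sample_ref = "ref.wav"
    if os.path.exists(sample_ref):
        print(predict_tts("Hello there!", "spark-tts", reference_audio_path=sample_ref))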