"""Thin client layer that routes TTS (voice-clone) requests to Hugging Face Spaces.

Each ``predict_*`` function opens a gradio client against a specific Space and
runs one synthesis call; ``predict_tts`` dispatches on the model name.
"""

import os
import random

from dotenv import load_dotenv
from gradio_client import Client, handle_file, file
from huggingface_hub.constants import HF_TOKEN_PATH

load_dotenv()

# Comma-separated pool of ZeroGPU tokens. Empty entries (e.g. from a trailing
# comma or an unset variable) are dropped so random.choice never yields "".
ZEROGPU_TOKENS = [t for t in os.getenv("ZEROGPU_TOKENS", "").split(",") if t.strip()]


def get_zerogpu_token():
    """Return a random token from the ZeroGPU pool, or HF_TOKEN if the pool is empty.

    Returns:
        str | None: a token string, or None when neither source is configured.
    """
    if not ZEROGPU_TOKENS:
        return os.getenv("HF_TOKEN")
    return random.choice(ZEROGPU_TOKENS)


# Maps public model names to their provider/model identifiers.
# NOTE(review): nothing in this file reads model_mapping — presumably consumed
# by the router service or another module; confirm before removing.
model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2",
    },
}

# Router endpoint and request template. NOTE(review): these are not used by the
# functions below — likely kept for the HTTP fallback path; verify before removal.
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
data = {"text": "string", "provider": "string", "model": "string"}


def set_client_for_session(space: str, user_token=None):
    """Create a gradio Client for *space*.

    Args:
        space: Hugging Face Space identifier, e.g. "owner/name".
        user_token: per-user X-IP-Token value; when absent, authenticate with a
            token from the ZeroGPU pool instead.

    Returns:
        Client: a connected gradio client.
    """
    if user_token is None:
        return Client(space, hf_token=get_zerogpu_token())
    # Forward the caller's IP token so ZeroGPU quota is billed to the user.
    return Client(space, headers={"X-IP-Token": user_token})


def predict_index_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with IndexTTS, cloning the voice in *reference_audio_path*.

    Raises:
        ValueError: if no reference audio is supplied (IndexTTS requires one).
    """
    client = set_client_for_session("kemuriririn/IndexTTS", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("index-tts 需要 reference_audio_path")
    prompt = handle_file(reference_audio_path)
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    # The Space may return either a plain path or a dict like {"value": path}.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result


def predict_spark_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with SparkTTS; reference audio is optional."""
    client = set_client_for_session("thunnai/SparkTTS", user_token=user_token)
    prompt_wav = handle_file(reference_audio_path) if reference_audio_path else None
    result = client.predict(
        text=text,
        # NOTE(review): prompt_text is set to the *target* text, not a transcript
        # of the reference audio — confirm this matches the Space's contract.
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone",
    )
    print("spark-tts result:", result)
    return result


def predict_cosyvoice_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with CosyVoice 2.0, transcribing the reference audio first.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    client = set_client_for_session("kemuriririn/CosyVoice2-0.5B", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    # First ask the Space to transcribe the reference audio; the transcript is
    # fed back as prompt_text for zero-shot cloning.
    recog_result = client.predict(
        prompt_wav=handle_file(reference_audio_path),
        api_name="/prompt_wav_recognition",
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,
        stream=False,
        api_name="/generate_audio",
    )
    print("cosyvoice-2.0 result:", result)
    return result


def predict_maskgct(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with Amphion MaskGCT.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    client = set_client_for_session("amphion/maskgct", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("maskgct 需要 reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,  # -1 lets the model pick the output duration
        n_timesteps=25,
        api_name="/predict",
    )
    print("maskgct result:", result)
    return result


def predict_gpt_sovits_v2(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with GPT-SoVITS v2 in reference-free prompt mode.

    Raises:
        ValueError: if no reference audio is supplied.
    """
    client = set_client_for_session("kemuriririn/GPT-SoVITS-v2", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 需要 reference_audio_path")
    result = client.predict(
        ref_wav_path=handle_file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav",
    )
    print("gpt-sovits-v2 result:", result)
    return result


def predict_tts(text, model, user_token=None, reference_audio_path=None):
    """Dispatch a synthesis request to the handler registered for *model*.

    Args:
        text: target text to synthesize.
        model: one of the keys below (special models bypassing the router).
        user_token: optional per-user X-IP-Token.
        reference_audio_path: path to the voice-clone reference audio.

    Raises:
        ValueError: for an unknown model name.
    """
    print(
        f"Predicting TTS for {model}, user_token: {user_token}, "
        f"reference_audio_path: {reference_audio_path}"
    )
    # Exceptions: special models that shouldn't be passed to the router
    handlers = {
        "index-tts": predict_index_tts,
        "spark-tts": predict_spark_tts,
        "cosyvoice-2.0": predict_cosyvoice_tts,
        "maskgct": predict_maskgct,
        "gpt-sovits-v2": predict_gpt_sovits_v2,
    }
    try:
        handler = handlers[model]
    except KeyError:
        raise ValueError(f"Model {model} not found") from None
    return handler(text, user_token, reference_audio_path)


if __name__ == "__main__":
    pass