kemuriririn committed
Commit 7fcb739 · 1 Parent(s): f55b556

(wip)debug

Files changed (2):
  1. app.py +2 -1
  2. tts.py +25 -25
app.py CHANGED
@@ -541,6 +541,7 @@ def initialize_tts_cache():
 @limiter.limit("10 per minute")  # Keep limit, cached responses are still requests
 def generate_tts():
     # If verification not setup, handle it first
+    user_token = request.headers['x-ip-token']
     if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
         return jsonify({"error": "Turnstile verification required"}), 403
 
@@ -631,7 +632,7 @@ def generate_tts():
     # Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
     def process_model_on_the_fly(model):
         # Pass reference_audio_path on to predict_tts
-        temp_audio_path = predict_tts(text, model.id, reference_audio_path=reference_audio_path)
+        temp_audio_path = predict_tts(text, model.id, reference_audio_path=reference_audio_path, user_token=user_token)
         if not temp_audio_path or not os.path.exists(temp_audio_path):
             raise ValueError(f"predict_tts failed for model {model.id}")
 
tts.py CHANGED
@@ -1,9 +1,7 @@
 import os
 from dotenv import load_dotenv
 import random
-
-from fal_client import stream
-from gradio_client.exceptions import AppError
+from gradio_client import Client, handle_file, file
 
 load_dotenv()
 
@@ -44,10 +42,17 @@ headers = {
 }
 data = {"text": "string", "provider": "string", "model": "string"}
 
+def set_client_for_session(space: str, user_token=None):
+    if user_token is None:
+        x_ip_token = get_zerogpu_token()
+    else:
+        x_ip_token = user_token
+
+    # The target space is a ZeroGPU space; forward the X-IP-Token so requests count against the caller's quota
+    return Client(space, headers={"X-IP-Token": x_ip_token})
 
-def predict_index_tts(text, reference_audio_path=None):
-    from gradio_client import Client, handle_file
-    client = Client("kemuriririn/IndexTTS", verbose=True)
+def predict_index_tts(text, user_token=None, reference_audio_path=None):
+    client = set_client_for_session("kemuriririn/IndexTTS", user_token=user_token)
     if reference_audio_path:
         prompt = handle_file(reference_audio_path)
     else:
@@ -63,9 +68,8 @@ def predict_index_tts(text, reference_audio_path=None):
     return result
 
 
-def predict_spark_tts(text, reference_audio_path=None):
-    from gradio_client import Client, handle_file
-    client = Client("kemuriririn/SparkTTS")
+def predict_spark_tts(text, user_token=None, reference_audio_path=None):
+    client = set_client_for_session("kemuriririn/SparkTTS", user_token=user_token)
     prompt_wav = None
     if reference_audio_path:
         prompt_wav = handle_file(reference_audio_path)
@@ -80,9 +84,8 @@ def predict_spark_tts(text, reference_audio_path=None):
     return result
 
 
-def predict_cosyvoice_tts(text, reference_audio_path=None):
-    from gradio_client import Client, file, handle_file
-    client = Client("kemuriririn/CosyVoice2-0.5B")
+def predict_cosyvoice_tts(text, user_token=None, reference_audio_path=None):
+    client = set_client_for_session("kemuriririn/CosyVoice2-0.5B", user_token=user_token)
     if not reference_audio_path:
         raise ValueError("cosyvoice-2.0 requires reference_audio_path")
     prompt_wav = handle_file(reference_audio_path)
@@ -106,9 +109,8 @@ def predict_cosyvoice_tts(text, reference_audio_path=None):
     return result
 
 
-def predict_maskgct(text, reference_audio_path=None):
-    from gradio_client import Client, handle_file
-    client = Client("amphion/maskgct")
+def predict_maskgct(text, user_token=None, reference_audio_path=None):
+    client = set_client_for_session("amphion/maskgct", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
@@ -123,9 +125,8 @@ def predict_maskgct(text, reference_audio_path=None):
     return result
 
 
-def predict_gpt_sovits_v2(text, reference_audio_path=None):
-    from gradio_client import Client, file
-    client = Client("kemuriririn/GPT-SoVITS-v2")
+def predict_gpt_sovits_v2(text, user_token=None, reference_audio_path=None):
+    client = set_client_for_session("kemuriririn/GPT-SoVITS-v2", user_token=user_token)
     if not reference_audio_path:
         raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
     result = client.predict(
@@ -148,20 +149,19 @@ def predict_gpt_sovits_v2(text, reference_audio_path=None):
     return result
 
 
-def predict_tts(text, model, reference_audio_path=None):
-    global client
+def predict_tts(text, model, user_token=None, reference_audio_path=None):
     print(f"Predicting TTS for {model}")
     # Exceptions: special models that shouldn't be passed to the router
     if model == "index-tts":
-        result = predict_index_tts(text, reference_audio_path)
+        result = predict_index_tts(text, user_token, reference_audio_path)
     elif model == "spark-tts":
-        result = predict_spark_tts(text, reference_audio_path)
+        result = predict_spark_tts(text, user_token, reference_audio_path)
     elif model == "cosyvoice-2.0":
-        result = predict_cosyvoice_tts(text, reference_audio_path)
+        result = predict_cosyvoice_tts(text, user_token, reference_audio_path)
     elif model == "maskgct":
-        result = predict_maskgct(text, reference_audio_path)
+        result = predict_maskgct(text, user_token, reference_audio_path)
     elif model == "gpt-sovits-v2":
-        result = predict_gpt_sovits_v2(text, reference_audio_path)
+        result = predict_gpt_sovits_v2(text, user_token, reference_audio_path)
     else:
         raise ValueError(f"Model {model} not found")
     return result
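set_client_for_session falls back to get_zerogpu_token() when no user token is supplied, but that helper is defined elsewhere in tts.py and does not appear in this diff. Purely for illustration, one plausible shape of such a fallback is sketched below, assuming server-side tokens arrive via a ZEROGPU_TOKENS environment variable (an assumption for this sketch, not the repo's actual mechanism).

import os
import random

def get_zerogpu_token():
    # Hypothetical sketch: pick one of the comma-separated server-side tokens
    # configured in the environment, e.g. ZEROGPU_TOKENS="tok_a,tok_b".
    tokens = [t.strip() for t in os.getenv("ZEROGPU_TOKENS", "").split(",") if t.strip()]
    if not tokens:
        raise RuntimeError("No ZeroGPU token available and no user token was supplied")
    return random.choice(tokens)

Whatever the real implementation is, the dispatch path introduced by this commit stays the same: app.py reads the visitor's X-IP-Token and passes it through predict_tts, each predict_* helper builds its gradio_client.Client via set_client_for_session, and the server-side fallback applies only when the user token is missing.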