kemuriririn committed on
Commit
1edfb59
·
1 Parent(s): 87f7c84

(wip)modify for voice clone

Browse files
Files changed (6) hide show
  1. README.md +4 -4
  2. app.py +22 -6
  3. templates/arena.html +43 -16
  4. templates/base.html +6 -6
  5. tts.old.py +0 -117
  6. tts.py +112 -63
README.md CHANGED
@@ -1,16 +1,16 @@
1
  ---
2
- title: TTS Arena V2
3
  emoji: 🏆
4
  colorFrom: blue
5
  colorTo: blue
6
  sdk: gradio
7
  app_file: app.py
8
- short_description: Vote on the latest TTS models!
9
  pinned: true
10
 
11
  hf_oauth: true
12
  ---
13
 
14
- Please see the [GitHub repo](https://github.com/TTS-AGI/TTS-Arena-V2) for information.
15
 
16
- Join the [Discord server](https://discord.gg/HB8fMR6GTr) for updates and support.
 
1
  ---
2
+ title: Voice Clone Arena
3
  emoji: 🏆
4
  colorFrom: blue
5
  colorTo: blue
6
  sdk: gradio
7
  app_file: app.py
8
+ short_description: Vote on the latest Voice Clone TTS models!
9
  pinned: true
10
 
11
  hf_oauth: true
12
  ---
13
 
14
+ [//]: # (Please see the [GitHub repo](https://github.com/TTS-AGI/TTS-Arena-V2) for information.)
15
 
16
+ [//]: # (Join the [Discord server](https://discord.gg/HB8fMR6GTr) for updates and support.)
app.py CHANGED
@@ -509,8 +509,19 @@ def generate_tts():
509
  if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
510
  return jsonify({"error": "Turnstile verification required"}), 403
511
 
512
- data = request.json
513
- text = data.get("text", "").strip() # Ensure text is stripped
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  if not text or len(text) > 1000:
516
  return jsonify({"error": "Invalid or too long text"}), 400
@@ -584,9 +595,8 @@ def generate_tts():
584
 
585
  # Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
586
  def process_model_on_the_fly(model):
587
- # Generate and save directly to the main temp dir
588
- # Assume predict_tts handles saving temporary files
589
- temp_audio_path = predict_tts(text, model.id)
590
  if not temp_audio_path or not os.path.exists(temp_audio_path):
591
  raise ValueError(f"predict_tts failed for model {model.id}")
592
 
@@ -597,7 +607,6 @@ def generate_tts():
597
 
598
  return {"model_id": model.id, "audio_path": dest_path}
599
 
600
-
601
  # Use ThreadPoolExecutor to process models concurrently
602
  with ThreadPoolExecutor(max_workers=2) as executor:
603
  results = list(executor.map(process_model_on_the_fly, selected_models))
@@ -620,6 +629,10 @@ def generate_tts():
620
  "voted": False,
621
  }
622
 
 
 
 
 
623
  # Return audio file paths and session
624
  return jsonify(
625
  {
@@ -641,6 +654,9 @@ def generate_tts():
641
  os.remove(res['audio_path'])
642
  except OSError:
643
  pass
 
 
 
644
  return jsonify({"error": "Failed to generate TTS"}), 500
645
  # --- End Cache Miss ---
646
 
 
509
  if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
510
  return jsonify({"error": "Turnstile verification required"}), 403
511
 
512
+ # 新增:支持 multipart/form-data 以接收音频文件
513
+ if request.content_type and request.content_type.startswith('multipart/form-data'):
514
+ text = request.form.get("text", "").strip()
515
+ voice_file = request.files.get("voice_file")
516
+ reference_audio_path = None
517
+ if voice_file:
518
+ temp_voice_path = os.path.join(TEMP_AUDIO_DIR, f"ref_{uuid.uuid4()}.wav")
519
+ voice_file.save(temp_voice_path)
520
+ reference_audio_path = temp_voice_path
521
+ else:
522
+ data = request.json
523
+ text = data.get("text", "").strip() # Ensure text is stripped
524
+ reference_audio_path = None
525
 
526
  if not text or len(text) > 1000:
527
  return jsonify({"error": "Invalid or too long text"}), 400
 
595
 
596
  # Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
597
  def process_model_on_the_fly(model):
598
+ # Pass reference_audio_path through to predict_tts
599
+ temp_audio_path = predict_tts(text, model.id, reference_audio_path=reference_audio_path)
 
600
  if not temp_audio_path or not os.path.exists(temp_audio_path):
601
  raise ValueError(f"predict_tts failed for model {model.id}")
602
 
 
607
 
608
  return {"model_id": model.id, "audio_path": dest_path}
609
 
 
610
  # Use ThreadPoolExecutor to process models concurrently
611
  with ThreadPoolExecutor(max_workers=2) as executor:
612
  results = list(executor.map(process_model_on_the_fly, selected_models))
 
629
  "voted": False,
630
  }
631
 
632
+ # 清理临时参考音频文件
633
+ if reference_audio_path and os.path.exists(reference_audio_path):
634
+ os.remove(reference_audio_path)
635
+
636
  # Return audio file paths and session
637
  return jsonify(
638
  {
 
654
  os.remove(res['audio_path'])
655
  except OSError:
656
  pass
657
+ # 清理临时参考音频文件
658
+ if reference_audio_path and os.path.exists(reference_audio_path):
659
+ os.remove(reference_audio_path)
660
  return jsonify({"error": "Failed to generate TTS"}), 500
661
  # --- End Cache Miss ---
662
 
templates/arena.html CHANGED
@@ -12,6 +12,11 @@
12
 
13
  <div id="tts-tab" class="tab-content active">
14
  <form class="input-container">
 
 
 
 
 
15
  <div class="input-group">
16
  <button type="button" class="segmented-btn random-btn" title="Roll random text">
17
  <svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-shuffle-icon lucide-shuffle">
@@ -62,7 +67,7 @@
62
  </span>
63
  </button>
64
  </div>
65
-
66
  <div class="player">
67
  <div class="player-label">Model B <span class="model-name-display"></span></div>
68
  <div class="wave-player-container" data-model="b"></div>
@@ -76,7 +81,6 @@
76
  </div>
77
  </div>
78
  </div>
79
-
80
  <div class="vote-results" style="display: none;">
81
  <h3 class="results-heading">Vote Recorded!</h3>
82
  <div class="results-content">
@@ -88,11 +92,9 @@
88
  </div>
89
  </div>
90
  </div>
91
-
92
  <div class="next-round-container" style="display: none;">
93
  <button class="next-round-btn">Next Round</button>
94
  </div>
95
-
96
  <div id="playback-keyboard-hint" class="keyboard-hint" style="display: none;">
97
  Press <kbd>Space</kbd> to play/pause, <kbd>A</kbd>/<kbd>B</kbd> to vote, <kbd>R</kbd> for random text, <kbd>N</kbd> for next random round
98
  </div>
@@ -1017,7 +1019,8 @@
1017
  let modelNames = { a: '', b: '' };
1018
  let wavePlayers = { a: null, b: null };
1019
  let cachedSentences = []; // To store sentences available in cache
1020
-
 
1021
  // Initialize WavePlayers with mobile settings
1022
  wavePlayerContainers.forEach(container => {
1023
  const model = container.dataset.model;
@@ -1137,15 +1140,31 @@
1137
 
1138
  // Reset the flag for both samples played
1139
  bothSamplesPlayed = false;
1140
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
  // Call the API to generate TTS
1142
- fetch('/api/tts/generate', {
1143
- method: 'POST',
1144
- headers: {
1145
- 'Content-Type': 'application/json',
1146
- },
1147
- body: JSON.stringify({ text: text }),
1148
- })
1149
  .then(response => {
1150
  if (!response.ok) {
1151
  return response.json().then(err => {
@@ -1199,6 +1218,11 @@
1199
  }
1200
 
1201
  function handleVote(model) {
 
 
 
 
 
1202
  // Disable both vote buttons
1203
  voteButtons.forEach(btn => {
1204
  btn.disabled = true;
@@ -1220,8 +1244,9 @@
1220
  })
1221
  .then(response => {
1222
  if (!response.ok) {
 
1223
  return response.json().then(err => {
1224
- throw new Error(err.error || 'Failed to submit vote');
1225
  });
1226
  }
1227
  return response.json();
@@ -1257,9 +1282,10 @@
1257
  nextRoundContainer.style.display = 'block';
1258
 
1259
  // Show success toast
1260
- openToast("Vote recorded successfully!", "success");
1261
  })
1262
  .catch(error => {
 
1263
  // Re-enable vote buttons
1264
  voteButtons.forEach(btn => {
1265
  btn.disabled = false;
@@ -1311,6 +1337,7 @@
1311
  // Show initial hint, hide playback hint
1312
  initialKeyboardHint.style.display = 'block';
1313
  playbackKeyboardHint.style.display = 'none';
 
1314
  }
1315
 
1316
  function handleRandom() {
@@ -1990,4 +2017,4 @@
1990
  initializePodcastLines();
1991
  });
1992
  </script>
1993
- {% endblock %}
 
12
 
13
  <div id="tts-tab" class="tab-content active">
14
  <form class="input-container">
15
+ <div class="input-group">
16
+ <label for="voice-file">上传参考音色:</label>
17
+ <input type="file" id="voice-file" accept="audio/*">
18
+ <audio id="voice-preview" controls style="display:none;"></audio>
19
+ </div>
20
  <div class="input-group">
21
  <button type="button" class="segmented-btn random-btn" title="Roll random text">
22
  <svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-shuffle-icon lucide-shuffle">
 
67
  </span>
68
  </button>
69
  </div>
70
+
71
  <div class="player">
72
  <div class="player-label">Model B <span class="model-name-display"></span></div>
73
  <div class="wave-player-container" data-model="b"></div>
 
81
  </div>
82
  </div>
83
  </div>
 
84
  <div class="vote-results" style="display: none;">
85
  <h3 class="results-heading">Vote Recorded!</h3>
86
  <div class="results-content">
 
92
  </div>
93
  </div>
94
  </div>
 
95
  <div class="next-round-container" style="display: none;">
96
  <button class="next-round-btn">Next Round</button>
97
  </div>
 
98
  <div id="playback-keyboard-hint" class="keyboard-hint" style="display: none;">
99
  Press <kbd>Space</kbd> to play/pause, <kbd>A</kbd>/<kbd>B</kbd> to vote, <kbd>R</kbd> for random text, <kbd>N</kbd> for next random round
100
  </div>
 
1019
  let modelNames = { a: '', b: '' };
1020
  let wavePlayers = { a: null, b: null };
1021
  let cachedSentences = []; // To store sentences available in cache
1022
+ let hasVoted = false; // 防止重复投票
1023
+
1024
  // Initialize WavePlayers with mobile settings
1025
  wavePlayerContainers.forEach(container => {
1026
  const model = container.dataset.model;
 
1140
 
1141
  // Reset the flag for both samples played
1142
  bothSamplesPlayed = false;
1143
+
1144
+ // 新增:处理参考音色文件上传
1145
+ const voiceFileInput = document.getElementById('voice-file');
1146
+ const file = voiceFileInput.files[0];
1147
+ let fetchOptions;
1148
+ if (file) {
1149
+ const formData = new FormData();
1150
+ formData.append('text', text);
1151
+ formData.append('voice_file', file);
1152
+ fetchOptions = {
1153
+ method: 'POST',
1154
+ body: formData
1155
+ };
1156
+ } else {
1157
+ fetchOptions = {
1158
+ method: 'POST',
1159
+ headers: {
1160
+ 'Content-Type': 'application/json',
1161
+ },
1162
+ body: JSON.stringify({ text: text }),
1163
+ };
1164
+ }
1165
+
1166
  // Call the API to generate TTS
1167
+ fetch('/api/tts/generate', fetchOptions)
 
 
 
 
 
 
1168
  .then(response => {
1169
  if (!response.ok) {
1170
  return response.json().then(err => {
 
1218
  }
1219
 
1220
  function handleVote(model) {
1221
+ if (hasVoted) {
1222
+ openToast("You have already voted. Duplicate voting is not allowed.", "warning");
1223
+ return;
1224
+ }
1225
+ hasVoted = true;
1226
  // Disable both vote buttons
1227
  voteButtons.forEach(btn => {
1228
  btn.disabled = true;
 
1244
  })
1245
  .then(response => {
1246
  if (!response.ok) {
1247
+ hasVoted = false; // allow retry
1248
  return response.json().then(err => {
1249
+ throw new Error(err.error || 'Vote failed, please try again later.');
1250
  });
1251
  }
1252
  return response.json();
 
1282
  nextRoundContainer.style.display = 'block';
1283
 
1284
  // Show success toast
1285
+ openToast("Vote successful!", "success");
1286
  })
1287
  .catch(error => {
1288
+ hasVoted = false;
1289
  // Re-enable vote buttons
1290
  voteButtons.forEach(btn => {
1291
  btn.disabled = false;
 
1337
  // Show initial hint, hide playback hint
1338
  initialKeyboardHint.style.display = 'block';
1339
  playbackKeyboardHint.style.display = 'none';
1340
+ hasVoted = false;
1341
  }
1342
 
1343
  function handleRandom() {
 
2017
  initializePodcastLines();
2018
  });
2019
  </script>
2020
+ {% endblock %}
templates/base.html CHANGED
@@ -1086,12 +1086,12 @@
1086
  </nav>
1087
 
1088
  <div class="sidebar-footer">
1089
- <a href="https://discord.gg/HB8fMR6GTr" target="_blank" rel="noopener noreferrer" class="discord-link">
1090
- <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 127.14 96.36" fill="currentColor">
1091
- <path d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/>
1092
- </svg>
1093
- Join our Discord
1094
- </a>
1095
 
1096
  {% if current_user.is_authenticated %}
1097
  <div class="user-auth" onclick="toggleUserDropdown(event)">
 
1086
  </nav>
1087
 
1088
  <div class="sidebar-footer">
1089
+ {# <a href="https://discord.gg/HB8fMR6GTr" target="_blank" rel="noopener noreferrer" class="discord-link">#}
1090
+ {# <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 127.14 96.36" fill="currentColor">#}
1091
+ {# <path d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/>#}
1092
+ {# </svg>#}
1093
+ {# Join our Discord#}
1094
+ {# </a>#}
1095
 
1096
  {% if current_user.is_authenticated %}
1097
  <div class="user-auth" onclick="toggleUserDropdown(event)">
tts.old.py DELETED
@@ -1,117 +0,0 @@
1
- # TODO: V2 of TTS Router
2
- # Currently just use current TTS router.
3
- from gradio_client import Client
4
- import os
5
- from dotenv import load_dotenv
6
- import fal_client
7
- import requests
8
- import time
9
- import io
10
- from pyht import Client as PyhtClient
11
- from pyht.client import TTSOptions
12
-
13
- load_dotenv()
14
-
15
- try:
16
- client = Client("TTS-AGI/tts-router", hf_token=os.getenv("HF_TOKEN"))
17
- except Exception as e:
18
- print(f"Error initializing client: {e}")
19
- client = None
20
-
21
- model_mapping = {
22
- "eleven-multilingual-v2": "eleven",
23
- "playht-2.0": "playht",
24
- "styletts2": "styletts2",
25
- "kokoro-v1": "kokorov1",
26
- "cosyvoice-2.0": "cosyvoice",
27
- "playht-3.0-mini": "playht3",
28
- "papla-p1": "papla",
29
- "hume-octave": "hume",
30
- }
31
-
32
-
33
- def predict_csm(script):
34
- result = fal_client.subscribe(
35
- "fal-ai/csm-1b",
36
- arguments={
37
- # "scene": [{
38
- # "text": "Hey how are you doing.",
39
- # "speaker_id": 0
40
- # }, {
41
- # "text": "Pretty good, pretty good.",
42
- # "speaker_id": 1
43
- # }, {
44
- # "text": "I'm great, so happy to be speaking to you.",
45
- # "speaker_id": 0
46
- # }]
47
- "scene": script
48
- },
49
- with_logs=True,
50
- )
51
- return requests.get(result["audio"]["url"]).content
52
-
53
-
54
- def predict_playdialog(script):
55
- # Initialize the PyHT client
56
- pyht_client = PyhtClient(
57
- user_id=os.getenv("PLAY_USERID"),
58
- api_key=os.getenv("PLAY_SECRETKEY"),
59
- )
60
-
61
- # Define the voices
62
- voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
63
- voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"
64
-
65
- # Convert script format from CSM to PlayDialog format
66
- if isinstance(script, list):
67
- # Process script in CSM format (list of dictionaries)
68
- text = ""
69
- for turn in script:
70
- speaker_id = turn.get("speaker_id", 0)
71
- prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
72
- text += f"{prefix} {turn['text']}\n"
73
- else:
74
- # If it's already a string, use as is
75
- text = script
76
-
77
- # Set up TTSOptions
78
- options = TTSOptions(
79
- voice=voice_1, voice_2=voice_2, turn_prefix="Host 1:", turn_prefix_2="Host 2:"
80
- )
81
-
82
- # Generate audio using PlayDialog
83
- audio_chunks = []
84
- for chunk in pyht_client.tts(text, options, voice_engine="PlayDialog"):
85
- audio_chunks.append(chunk)
86
-
87
- # Combine all chunks into a single audio file
88
- return b"".join(audio_chunks)
89
-
90
-
91
- def predict_tts(text, model):
92
- global client
93
- # Exceptions: special models that shouldn't be passed to the router
94
- if model == "csm-1b":
95
- return predict_csm(text)
96
- elif model == "playdialog-1.0":
97
- return predict_playdialog(text)
98
-
99
- if not model in model_mapping:
100
- raise ValueError(f"Model {model} not found")
101
- result = client.predict(
102
- text=text, model=model_mapping[model], api_name="/synthesize"
103
- ) # returns path to audio file
104
- return result
105
-
106
-
107
- if __name__ == "__main__":
108
- print("Predicting PlayDialog")
109
- print(
110
- predict_playdialog(
111
- [
112
- {"text": "Hey how are you doing.", "speaker_id": 0},
113
- {"text": "Pretty good, pretty good.", "speaker_id": 1},
114
- {"text": "I'm great, so happy to be speaking to you.", "speaker_id": 0},
115
- ]
116
- )
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts.py CHANGED
@@ -23,65 +23,65 @@ def get_zerogpu_token():
23
 
24
 
25
  model_mapping = {
26
- "eleven-multilingual-v2": {
27
- "provider": "elevenlabs",
28
- "model": "eleven_multilingual_v2",
29
- },
30
- "eleven-turbo-v2.5": {
31
- "provider": "elevenlabs",
32
- "model": "eleven_turbo_v2_5",
33
- },
34
- "eleven-flash-v2.5": {
35
- "provider": "elevenlabs",
36
- "model": "eleven_flash_v2_5",
37
- },
38
- "cartesia-sonic-2": {
39
- "provider": "cartesia",
40
- "model": "sonic-2",
41
- },
42
  "spark-tts": {
43
  "provider": "spark",
44
  "model": "spark-tts",
45
  },
46
- "playht-2.0": {
47
- "provider": "playht",
48
- "model": "PlayHT2.0",
49
- },
50
- "styletts2": {
51
- "provider": "styletts",
52
- "model": "styletts2",
53
- },
54
- "kokoro-v1": {
55
- "provider": "kokoro",
56
- "model": "kokoro_v1",
57
- },
58
- "cosyvoice-2.0": {
59
- "provider": "cosyvoice",
60
- "model": "cosyvoice_2_0",
61
- },
62
- "papla-p1": {
63
- "provider": "papla",
64
- "model": "papla_p1",
65
- },
66
- "hume-octave": {
67
- "provider": "hume",
68
- "model": "octave",
69
- },
70
- "megatts3": {
71
- "provider": "megatts3",
72
- "model": "megatts3",
73
- },
74
- "minimax-02-hd": {
75
- "provider": "minimax",
76
- "model": "speech-02-hd",
77
- },
78
- "minimax-02-turbo": {
79
- "provider": "minimax",
80
- "model": "speech-02-turbo",
81
- },
82
- "lanternfish-1": {
83
- "provider": "lanternfish",
84
- "model": "lanternfish-1",
85
  },
86
  }
87
  url = "https://tts-agi-tts-router-v2.hf.space/tts"
@@ -194,7 +194,38 @@ def predict_dia(script):
194
  return requests.get(json.loads(audio_data)[0]["url"]).content
195
 
196
 
197
- def predict_tts(text, model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  global client
199
  print(f"Predicting TTS for {model}")
200
  # Exceptions: special models that shouldn't be passed to the router
@@ -204,20 +235,38 @@ def predict_tts(text, model):
204
  return predict_playdialog(text)
205
  elif model == "dia-1.6b":
206
  return predict_dia(text)
 
 
 
 
207
 
208
  if not model in model_mapping:
209
  raise ValueError(f"Model {model} not found")
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  result = requests.post(
212
  url,
213
  headers=headers,
214
- data=json.dumps(
215
- {
216
- "text": text,
217
- "provider": model_mapping[model]["provider"],
218
- "model": model_mapping[model]["model"],
219
- }
220
- ),
221
  )
222
 
223
  response_json = result.json()
 
23
 
24
 
25
  model_mapping = {
26
+ # "eleven-multilingual-v2": {
27
+ # "provider": "elevenlabs",
28
+ # "model": "eleven_multilingual_v2",
29
+ # },
30
+ # "eleven-turbo-v2.5": {
31
+ # "provider": "elevenlabs",
32
+ # "model": "eleven_turbo_v2_5",
33
+ # },
34
+ # "eleven-flash-v2.5": {
35
+ # "provider": "elevenlabs",
36
+ # "model": "eleven_flash_v2_5",
37
+ # },
38
+ # "cartesia-sonic-2": {
39
+ # "provider": "cartesia",
40
+ # "model": "sonic-2",
41
+ # },
42
  "spark-tts": {
43
  "provider": "spark",
44
  "model": "spark-tts",
45
  },
46
+ # "playht-2.0": {
47
+ # "provider": "playht",
48
+ # "model": "PlayHT2.0",
49
+ # },
50
+ # "styletts2": {
51
+ # "provider": "styletts",
52
+ # "model": "styletts2",
53
+ # },
54
+ # "cosyvoice-2.0": {
55
+ # "provider": "cosyvoice",
56
+ # "model": "cosyvoice_2_0",
57
+ # },
58
+ # "papla-p1": {
59
+ # "provider": "papla",
60
+ # "model": "papla_p1",
61
+ # },
62
+ # "hume-octave": {
63
+ # "provider": "hume",
64
+ # "model": "octave",
65
+ # },
66
+ # "megatts3": {
67
+ # "provider": "megatts3",
68
+ # "model": "megatts3",
69
+ # },
70
+ # "minimax-02-hd": {
71
+ # "provider": "minimax",
72
+ # "model": "speech-02-hd",
73
+ # },
74
+ # "minimax-02-turbo": {
75
+ # "provider": "minimax",
76
+ # "model": "speech-02-turbo",
77
+ # },
78
+ # "lanternfish-1": {
79
+ # "provider": "lanternfish",
80
+ # "model": "lanternfish-1",
81
+ # },
82
+ "index-tts": {
83
+ "provider": "bilibili",
84
+ "model": "index-tts",
85
  },
86
  }
87
  url = "https://tts-agi-tts-router-v2.hf.space/tts"
 
194
  return requests.get(json.loads(audio_data)[0]["url"]).content
195
 
196
 
197
def predict_index_tts(text, reference_audio_path=None):
    """Synthesize ``text`` with the IndexTeam/IndexTTS Space using a reference voice.

    Args:
        text: Text to be spoken.
        reference_audio_path: Path to the reference audio whose voice is
            cloned. Required in practice; it defaults to ``None`` only so the
            signature lines up with ``predict_spark_tts``.

    Returns:
        The value returned by ``Client.predict`` for the ``/gen_single``
        endpoint (presumably a path to the generated audio file; verify
        against the Space's API page).

    Raises:
        ValueError: If ``reference_audio_path`` is ``None`` or empty.
    """
    # Validate before importing gradio_client or building the Client, so a
    # request that can never succeed fails fast without contacting the Space.
    if not reference_audio_path:
        raise ValueError("index-tts 需要 reference_audio_path")

    from gradio_client import Client, handle_file

    client = Client("IndexTeam/IndexTTS")
    return client.predict(
        prompt=handle_file(reference_audio_path),
        text=text,
        api_name="/gen_single",
    )
210
+
211
+
212
def predict_spark_tts(text, reference_audio_path=None):
    """Request voice-cloned speech from the amortalize/Spark-TTS-Zero Space.

    Args:
        text: Text to be spoken.
        reference_audio_path: Optional path to a reference audio clip. When
            it is falsy, ``None`` is sent for both audio inputs.

    Returns:
        The value returned by ``Client.predict`` for the ``/voice_clone``
        endpoint.
    """
    from gradio_client import Client, handle_file

    spark_client = Client("amortalize/Spark-TTS-Zero")
    reference_clip = handle_file(reference_audio_path) if reference_audio_path else None

    # The same clip is supplied for both the "upload" and "record" inputs.
    # NOTE(review): prompt_text is given the target text; presumably the
    # endpoint expects the transcript of the reference clip here. Confirm
    # against the Space's API before changing.
    request_fields = {
        "text": text,
        "prompt_text": text,
        "prompt_wav_upload": reference_clip,
        "prompt_wav_record": reference_clip,
    }
    return spark_client.predict(api_name="/voice_clone", **request_fields)
226
+
227
+
228
+ def predict_tts(text, model, reference_audio_path=None):
229
  global client
230
  print(f"Predicting TTS for {model}")
231
  # Exceptions: special models that shouldn't be passed to the router
 
235
  return predict_playdialog(text)
236
  elif model == "dia-1.6b":
237
  return predict_dia(text)
238
+ elif model == "index-tts":
239
+ return predict_index_tts(text, reference_audio_path)
240
+ elif model == "spark-tts":
241
+ return predict_spark_tts(text, reference_audio_path)
242
 
243
  if not model in model_mapping:
244
  raise ValueError(f"Model {model} not found")
245
 
246
+ # 构建请求体
247
+ payload = {
248
+ "text": text,
249
+ "provider": model_mapping[model]["provider"],
250
+ "model": model_mapping[model]["model"],
251
+ }
252
+ # 仅对支持音色克隆的模型传递参考音色
253
+ supports_reference = model in [
254
+ "styletts2", "eleven-multilingual-v2", "eleven-turbo-v2.5", "eleven-flash-v2.5"
255
+ ]
256
+ if reference_audio_path and supports_reference:
257
+ with open(reference_audio_path, "rb") as f:
258
+ audio_bytes = f.read()
259
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
260
+ # 不同模型参考音色字段不同
261
+ if model == "styletts2":
262
+ payload["reference_speaker"] = audio_b64
263
+ else: # elevenlabs 系列
264
+ payload["reference_audio"] = audio_b64
265
+
266
  result = requests.post(
267
  url,
268
  headers=headers,
269
+ data=json.dumps(payload),
 
 
 
 
 
 
270
  )
271
 
272
  response_json = result.json()