zach committed on
Commit
0e508c8
·
1 Parent(s): 83c6aee

Refactor tts integration functions to write audio to file and return file path, audio players to play mp3 file written to temp folder, fix audioplayer loading, remove unused imports

Browse files
.gitignore CHANGED
@@ -38,4 +38,4 @@ Thumbs.db
38
  *.cache
39
 
40
  # Temp files
41
- src/static/audio/
 
38
  *.cache
39
 
40
  # Temp files
41
+ static/audio/*
src/app.py CHANGED
@@ -11,13 +11,14 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
11
  # Standard Library Imports
12
  from concurrent.futures import ThreadPoolExecutor
13
  import random
 
14
  from typing import Union, Tuple
15
 
16
  # Third-Party Library Imports
17
  import gradio as gr
18
 
19
  # Local Application Imports
20
- from src.config import logger
21
  from src.constants import (
22
  ELEVENLABS,
23
  HUME_AI,
@@ -27,7 +28,6 @@ from src.constants import (
27
  PROMPT_MIN_LENGTH,
28
  SAMPLE_PROMPTS,
29
  TROPHY_EMOJI,
30
- UNKNOWN_PROVIDER,
31
  VOTE_FOR_OPTION_A,
32
  VOTE_FOR_OPTION_B,
33
  )
@@ -41,7 +41,7 @@ from src.integrations import (
41
  )
42
  from src.theme import CustomTheme
43
  from src.types import OptionMap
44
- from src.utils import truncate_text, validate_prompt_length
45
 
46
 
47
  def generate_text(
@@ -130,13 +130,7 @@ def text_to_speech(
130
  audio_a = future_audio_a.result()
131
  audio_b = future_audio_b.result()
132
 
133
- logger.info(
134
- f"TTS generated: {provider_a}={len(audio_a)} bytes, {provider_b}={len(audio_b)} bytes"
135
- )
136
- options = [
137
- (audio_a, provider_a),
138
- (audio_b, provider_b),
139
- ]
140
  random.shuffle(options)
141
  option_a_audio, option_b_audio = options[0][0], options[1][0]
142
  options_map: OptionMap = {OPTION_A: options[0][1], OPTION_B: options[1][1]}
@@ -444,16 +438,11 @@ def build_gradio_interface() -> gr.Blocks:
444
  ],
445
  )
446
 
447
- # Auto-play second audio after first finishes (Workaround to play audio back-to-back)
448
- # Audio player A stop event handler chain:
449
- # 1. Clear the audio player A
450
- # 2. Load audio player A with audio and set auto play to True
451
  option_a_audio_player.stop(
452
- fn=lambda _: gr.update(value=None),
453
- inputs=[],
454
- outputs=[option_b_audio_player],
455
- ).then(
456
- fn=lambda audio: gr.update(value=audio, autoplay=True),
457
  inputs=[option_b_audio_state],
458
  outputs=[option_b_audio_player],
459
  )
@@ -476,4 +465,4 @@ def build_gradio_interface() -> gr.Blocks:
476
  if __name__ == "__main__":
477
  logger.info("Launching TTS Arena Gradio app...")
478
  demo = build_gradio_interface()
479
- demo.launch()
 
11
  # Standard Library Imports
12
  from concurrent.futures import ThreadPoolExecutor
13
  import random
14
+ import time
15
  from typing import Union, Tuple
16
 
17
  # Third-Party Library Imports
18
  import gradio as gr
19
 
20
  # Local Application Imports
21
+ from src.config import AUDIO_DIR, logger
22
  from src.constants import (
23
  ELEVENLABS,
24
  HUME_AI,
 
28
  PROMPT_MIN_LENGTH,
29
  SAMPLE_PROMPTS,
30
  TROPHY_EMOJI,
 
31
  VOTE_FOR_OPTION_A,
32
  VOTE_FOR_OPTION_B,
33
  )
 
41
  )
42
  from src.theme import CustomTheme
43
  from src.types import OptionMap
44
+ from src.utils import validate_prompt_length
45
 
46
 
47
  def generate_text(
 
130
  audio_a = future_audio_a.result()
131
  audio_b = future_audio_b.result()
132
 
133
+ options = [(audio_a, provider_a), (audio_b, provider_b)]
 
 
 
 
 
 
134
  random.shuffle(options)
135
  option_a_audio, option_b_audio = options[0][0], options[1][0]
136
  options_map: OptionMap = {OPTION_A: options[0][1], OPTION_B: options[1][1]}
 
438
  ],
439
  )
440
 
441
+ # Reload audio player B with audio and set autoplay to True (workaround to play audio back-to-back)
 
 
 
442
  option_a_audio_player.stop(
443
+ fn=lambda current_audio_path: gr.update(
444
+ value=f"{current_audio_path}?t={int(time.time())}", autoplay=True
445
+ ),
 
 
446
  inputs=[option_b_audio_state],
447
  outputs=[option_b_audio_player],
448
  )
 
465
  if __name__ == "__main__":
466
  logger.info("Launching TTS Arena Gradio app...")
467
  demo = build_gradio_interface()
468
+ demo.launch(allowed_paths=[AUDIO_DIR])
src/config.py CHANGED
@@ -35,6 +35,11 @@ logging.basicConfig(
35
  )
36
  logger: logging.Logger = logging.getLogger("tts_arena")
37
  logger.info(f'Debug mode is {"enabled" if DEBUG else "disabled"}.')
38
-
39
  if DEBUG:
40
  logger.debug(f"DEBUG mode enabled.")
 
 
 
 
 
 
 
35
  )
36
  logger: logging.Logger = logging.getLogger("tts_arena")
37
  logger.info(f'Debug mode is {"enabled" if DEBUG else "disabled"}.')
 
38
  if DEBUG:
39
  logger.debug(f"DEBUG mode enabled.")
40
+
41
+
42
+ # Define the directory for audio files relative to the project root
43
+ AUDIO_DIR = os.path.join(os.getcwd(), "static", "audio")
44
+ os.makedirs(AUDIO_DIR, exist_ok=True)
45
+ logger.info(f"Audio directory set to {AUDIO_DIR}")
src/constants.py CHANGED
@@ -11,7 +11,6 @@ from src.types import OptionKey, TTSProviderName
11
  # UI constants
12
  HUME_AI: TTSProviderName = "Hume AI"
13
  ELEVENLABS: TTSProviderName = "ElevenLabs"
14
- UNKNOWN_PROVIDER: TTSProviderName = "Unknown"
15
 
16
  PROMPT_MIN_LENGTH: int = 20
17
  PROMPT_MAX_LENGTH: int = 800
 
11
  # UI constants
12
  HUME_AI: TTSProviderName = "Hume AI"
13
  ELEVENLABS: TTSProviderName = "ElevenLabs"
 
14
 
15
  PROMPT_MIN_LENGTH: int = 20
16
  PROMPT_MAX_LENGTH: int = 800
src/integrations/anthropic_api.py CHANGED
@@ -40,10 +40,23 @@ class AnthropicConfig:
40
  api_key: str = validate_env_var("ANTHROPIC_API_KEY")
41
  model: ModelParam = "claude-3-5-sonnet-latest"
42
  max_tokens: int = 150
43
- system_prompt: str = f"""You are an expert at generating micro-content optimized for text-to-speech synthesis. Your absolute priority is delivering complete, untruncated responses within strict length limits.
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  CRITICAL LENGTH CONSTRAINTS:
45
 
46
- Maximum length: {max_tokens} tokens (approximately 400 characters)
47
  You MUST complete all thoughts and sentences
48
  Responses should be 25% shorter than you initially plan
49
  Never exceed 400 characters total
@@ -73,17 +86,7 @@ Resolution (75-100 characters)
73
 
74
  MANDATORY: If you find yourself reaching 300 characters, immediately begin your conclusion regardless of where you are in the narrative.
75
  Remember: A shorter, complete response is ALWAYS better than a longer, truncated one."""
76
-
77
- def __post_init__(self):
78
- # Validate that required attributes are set
79
- if not self.api_key:
80
- raise ValueError("Anthropic API key is not set.")
81
- if not self.model:
82
- raise ValueError("Anthropic Model is not set.")
83
- if not self.max_tokens:
84
- raise ValueError("Anthropic Max Tokens is not set.")
85
- if not self.system_prompt:
86
- raise ValueError("Anthropic System Prompt is not set.")
87
 
88
  @property
89
  def client(self) -> Anthropic:
 
40
  api_key: str = validate_env_var("ANTHROPIC_API_KEY")
41
  model: ModelParam = "claude-3-5-sonnet-latest"
42
  max_tokens: int = 150
43
+ system_prompt: Optional[str] = (
44
+ None # system prompt is set post initialization, since self.max_tokens is leveraged in the prompt.
45
+ )
46
+
47
+ def __post_init__(self):
48
+ # Validate that required attributes are set
49
+ if not self.api_key:
50
+ raise ValueError("Anthropic API key is not set.")
51
+ if not self.model:
52
+ raise ValueError("Anthropic Model is not set.")
53
+ if not self.max_tokens:
54
+ raise ValueError("Anthropic Max Tokens is not set.")
55
+ if self.system_prompt is None:
56
+ system_prompt: str = f"""You are an expert at generating micro-content optimized for text-to-speech synthesis. Your absolute priority is delivering complete, untruncated responses within strict length limits.
57
  CRITICAL LENGTH CONSTRAINTS:
58
 
59
+ Maximum length: {self.max_tokens} tokens (approximately 400 characters)
60
  You MUST complete all thoughts and sentences
61
  Responses should be 25% shorter than you initially plan
62
  Never exceed 400 characters total
 
86
 
87
  MANDATORY: If you find yourself reaching 300 characters, immediately begin your conclusion regardless of where you are in the narrative.
88
  Remember: A shorter, complete response is ALWAYS better than a longer, truncated one."""
89
+ object.__setattr__(self, "system_prompt", system_prompt)
 
 
 
 
 
 
 
 
 
 
90
 
91
  @property
92
  def client(self) -> Anthropic:
src/integrations/elevenlabs_api.py CHANGED
@@ -20,20 +20,18 @@ Functions:
20
  """
21
 
22
  # Standard Library Imports
23
- import base64
24
  from dataclasses import dataclass
25
- from enum import Enum
26
  import logging
27
  import random
28
- from typing import Literal, Optional, Tuple
29
 
30
  # Third-Party Library Imports
31
- from elevenlabs import ElevenLabs
32
  from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_log
33
 
34
  # Local Application Imports
35
  from src.config import logger
36
- from src.utils import validate_env_var
37
 
38
 
39
  @dataclass(frozen=True)
@@ -41,6 +39,7 @@ class ElevenLabsConfig:
41
  """Immutable configuration for interacting with the ElevenLabs TTS API."""
42
 
43
  api_key: str = validate_env_var("ELEVENLABS_API_KEY")
 
44
 
45
  def __post_init__(self):
46
  # Validate that required attributes are set
@@ -79,14 +78,14 @@ elevenlabs_config = ElevenLabsConfig()
79
  )
80
  def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
81
  """
82
- Synthesizes text to speech using the ElevenLabs TTS API.
83
 
84
  Args:
85
  prompt (str): The original user prompt used as the voice description.
86
  text (str): The text to be synthesized to speech.
87
 
88
  Returns:
89
- bytes: The raw binary audio data for playback.
90
 
91
  Raises:
92
  ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
@@ -102,6 +101,7 @@ def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
102
  response = elevenlabs_config.client.text_to_voice.create_previews(
103
  voice_description=prompt,
104
  text=text,
 
105
  )
106
 
107
  previews = response.previews
@@ -110,10 +110,14 @@ def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
110
  logger.error(msg)
111
  raise ElevenLabsError(message=msg)
112
 
 
113
  preview = random.choice(previews)
 
114
  base64_audio = preview.audio_base_64
115
- audio = base64.b64decode(base64_audio)
116
- return audio
 
 
117
 
118
  except Exception as e:
119
  logger.exception(f"Error synthesizing speech with ElevenLabs: {e}")
 
20
  """
21
 
22
  # Standard Library Imports
 
23
  from dataclasses import dataclass
 
24
  import logging
25
  import random
26
+ from typing import Optional
27
 
28
  # Third-Party Library Imports
29
+ from elevenlabs import ElevenLabs, TextToVoiceCreatePreviewsRequestOutputFormat
30
  from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_log
31
 
32
  # Local Application Imports
33
  from src.config import logger
34
+ from src.utils import save_base64_audio_to_file, validate_env_var
35
 
36
 
37
  @dataclass(frozen=True)
 
39
  """Immutable configuration for interacting with the ElevenLabs TTS API."""
40
 
41
  api_key: str = validate_env_var("ELEVENLABS_API_KEY")
42
+ output_format: TextToVoiceCreatePreviewsRequestOutputFormat = "mp3_44100_128"
43
 
44
  def __post_init__(self):
45
  # Validate that required attributes are set
 
78
  )
79
  def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
80
  """
81
+ Synthesizes text to speech using the ElevenLabs TTS API, processes audio data, and writes audio to a file.
82
 
83
  Args:
84
  prompt (str): The original user prompt used as the voice description.
85
  text (str): The text to be synthesized to speech.
86
 
87
  Returns:
88
+ str: The relative path for the file the synthesized audio was written to.
89
 
90
  Raises:
91
  ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
 
101
  response = elevenlabs_config.client.text_to_voice.create_previews(
102
  voice_description=prompt,
103
  text=text,
104
+ output_format=elevenlabs_config.output_format,
105
  )
106
 
107
  previews = response.previews
 
110
  logger.error(msg)
111
  raise ElevenLabsError(message=msg)
112
 
113
+ # Extract the base64 encoded audio and generated voice ID from the preview
114
  preview = random.choice(previews)
115
+ generated_voice_id = preview.generated_voice_id
116
  base64_audio = preview.audio_base_64
117
+ filename = f"{generated_voice_id}.mp3"
118
+
119
+ # Write audio to file and return the relative path
120
+ return save_base64_audio_to_file(base64_audio, filename)
121
 
122
  except Exception as e:
123
  logger.exception(f"Error synthesizing speech with ElevenLabs: {e}")
src/integrations/hume_api.py CHANGED
@@ -19,11 +19,11 @@ Functions:
19
  """
20
 
21
  # Standard Library Imports
22
- import base64
23
  from dataclasses import dataclass
24
  import logging
 
25
  import random
26
- from typing import List, Literal, Optional, Tuple
27
 
28
  # Third-Party Library Imports
29
  import requests
@@ -31,7 +31,11 @@ from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_lo
31
 
32
  # Local Application Imports
33
  from src.config import logger
34
- from src.utils import validate_env_var, truncate_text
 
 
 
 
35
 
36
 
37
  @dataclass(frozen=True)
@@ -41,6 +45,7 @@ class HumeConfig:
41
  api_key: str = validate_env_var("HUME_API_KEY")
42
  url: str = "https://test-api.hume.ai/v0/tts/octave"
43
  headers: dict = None
 
44
 
45
  def __post_init__(self):
46
  # Validate required attributes
@@ -48,6 +53,8 @@ class HumeConfig:
48
  raise ValueError("Hume API key is not set.")
49
  if not self.url:
50
  raise ValueError("Hume TTS endpoint URL is not set.")
 
 
51
 
52
  # Set headers dynamically after validation
53
  object.__setattr__(
@@ -81,14 +88,14 @@ hume_config = HumeConfig()
81
  )
82
  def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
83
  """
84
- Synthesizes text to speech using the Hume TTS API and processes raw binary audio data.
85
 
86
  Args:
87
  prompt (str): The original user prompt to use as the description for generating the voice.
88
  text (str): The generated text to be converted to speech.
89
 
90
  Returns:
91
- bytes: The raw binary audio data for playback.
92
 
93
  Raises:
94
  HumeError: If there is an error communicating with the Hume TTS API or parsing the response.
@@ -108,24 +115,25 @@ def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
108
  )
109
  response.raise_for_status()
110
  response_data = response.json()
111
- except requests.RequestException as re:
112
- request_error_msg = f"Error communicating with Hume TTS API: {re}"
113
- logger.exception(request_error_msg)
114
- raise HumeError(request_error_msg) from re
115
 
116
- try:
117
- # Safely extract the generation result from the response JSON
118
- generations = response_data.get("generations", [])
119
  if not generations:
120
- logger.error("Missing 'audio' data in the response.")
121
- raise HumeError("Missing audio data in response from Hume TTS API")
 
 
 
122
  generation = generations[0]
 
123
  base64_audio = generation.get("audio")
124
- # Decode base64 encoded audio
125
- audio = base64.b64decode(base64_audio)
126
- except (KeyError, TypeError, base64.binascii.Error) as ae:
127
- logger.exception(f"Error processing audio data: {ae}")
128
- raise HumeError(f"Error processing audio data from Hume TTS API: {ae}") from ae
129
-
130
- logger.info(f"Received audio data from Hume ({len(audio)} bytes).")
131
- return audio
 
 
 
 
19
  """
20
 
21
  # Standard Library Imports
 
22
  from dataclasses import dataclass
23
  import logging
24
+ import os
25
  import random
26
+ from typing import Literal, Optional
27
 
28
  # Third-Party Library Imports
29
  import requests
 
31
 
32
  # Local Application Imports
33
  from src.config import logger
34
+ from src.utils import save_base64_audio_to_file, validate_env_var
35
+
36
+
37
+ HumeSupportedFileFormat = Literal["mp3", "pcm", "wav"]
38
+ """ Support audio file formats for the Hume TTS API"""
39
 
40
 
41
  @dataclass(frozen=True)
 
45
  api_key: str = validate_env_var("HUME_API_KEY")
46
  url: str = "https://test-api.hume.ai/v0/tts/octave"
47
  headers: dict = None
48
+ file_format: HumeSupportedFileFormat = "mp3"
49
 
50
  def __post_init__(self):
51
  # Validate required attributes
 
53
  raise ValueError("Hume API key is not set.")
54
  if not self.url:
55
  raise ValueError("Hume TTS endpoint URL is not set.")
56
+ if not self.file_format:
57
+ raise ValueError("Hume TTS file format is not set.")
58
 
59
  # Set headers dynamically after validation
60
  object.__setattr__(
 
88
  )
89
  def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
90
  """
91
+ Synthesizes text to speech using the Hume TTS API, processes audio data, and writes audio to a file.
92
 
93
  Args:
94
  prompt (str): The original user prompt to use as the description for generating the voice.
95
  text (str): The generated text to be converted to speech.
96
 
97
  Returns:
98
+ str: The relative path for the file the synthesized audio was written to.
99
 
100
  Raises:
101
  HumeError: If there is an error communicating with the Hume TTS API or parsing the response.
 
115
  )
116
  response.raise_for_status()
117
  response_data = response.json()
 
 
 
 
118
 
119
+ generations = response_data.get("generations")
 
 
120
  if not generations:
121
+ msg = "No generations returned by Hume API."
122
+ logger.error(msg)
123
+ raise HumeError(msg)
124
+
125
+ # Extract the base64 encoded audio and generation ID from the generation
126
  generation = generations[0]
127
+ generation_id = generation.get("generation_id")
128
  base64_audio = generation.get("audio")
129
+ filename = f"{generation_id}.mp3"
130
+
131
+ # Write audio to file and return the relative path
132
+ return save_base64_audio_to_file(base64_audio, filename)
133
+
134
+ except Exception as e:
135
+ logger.exception(f"Error synthesizing speech with Hume: {e}")
136
+ raise HumeError(
137
+ message=f"Failed to synthesize speech with Hume: {e}",
138
+ original_exception=e,
139
+ ) from e
src/types.py CHANGED
@@ -7,7 +7,7 @@ has a consistent structure including both the provider and the associated voice.
7
  """
8
 
9
  # Standard Library Imports
10
- from typing import TypedDict, Literal, Dict
11
 
12
 
13
  TTSProviderName = Literal["Hume AI", "ElevenLabs"]
 
7
  """
8
 
9
  # Standard Library Imports
10
+ from typing import Literal, Dict
11
 
12
 
13
  TTSProviderName = Literal["Hume AI", "ElevenLabs"]
src/utils.py CHANGED
@@ -11,10 +11,11 @@ Functions:
11
  """
12
 
13
  # Standard Library Imports
 
14
  import os
15
 
16
  # Local Application Imports
17
- from src.config import logger
18
 
19
 
20
  def truncate_text(text: str, max_length: int = 50) -> str:
@@ -116,3 +117,44 @@ def validate_prompt_length(prompt: str, max_length: int, min_length: int) -> Non
116
  logger.debug(
117
  f"Prompt length validation passed for prompt: {truncate_text(stripped_prompt)}"
118
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
 
13
  # Standard Library Imports
14
+ import base64
15
  import os
16
 
17
  # Local Application Imports
18
+ from src.config import AUDIO_DIR, logger
19
 
20
 
21
  def truncate_text(text: str, max_length: int = 50) -> str:
 
117
  logger.debug(
118
  f"Prompt length validation passed for prompt: {truncate_text(stripped_prompt)}"
119
  )
120
+
121
+
122
+ def save_base64_audio_to_file(base64_audio: str, filename: str) -> str:
123
+ """
124
+ Decode a base64-encoded audio string and write the resulting binary data to a file
125
+ within the preconfigured AUDIO_DIR directory. This function verifies the file was created,
126
+ logs the absolute and relative file paths, and returns a path relative to the current
127
+ working directory (which is what Gradio requires to serve static files).
128
+
129
+ Args:
130
+ base64_audio (str): The base64-encoded string representing the audio data.
131
+ filename (str): The name of the file (including extension, e.g.,
132
+ 'b4a335da-9786-483a-b0a5-37e6e4ad5fd1.mp3') where the decoded
133
+ audio will be saved.
134
+
135
+ Returns:
136
+ str: The relative file path to the saved audio file.
137
+
138
+ Raises:
139
+ Exception: Propagates any exceptions raised during the decoding or file I/O operations.
140
+ """
141
+ # Decode the base64-encoded audio into binary data.
142
+ audio_bytes = base64.b64decode(base64_audio)
143
+
144
+ # Construct the full absolute file path within the AUDIO_DIR directory.
145
+ file_path = os.path.join(AUDIO_DIR, filename)
146
+
147
+ # Write the binary audio data to the file.
148
+ with open(file_path, "wb") as audio_file:
149
+ audio_file.write(audio_bytes)
150
+
151
+ # Verify that the file was created.
152
+ if not os.path.exists(file_path):
153
+ raise FileNotFoundError(f"Audio file was not created at {file_path}")
154
+
155
+ # Compute a relative path for Gradio to serve (relative to the project root).
156
+ relative_path = os.path.relpath(file_path, os.getcwd())
157
+ logger.debug(f"Audio file absolute path: {file_path}")
158
+ logger.debug(f"Audio file relative path: {relative_path}")
159
+
160
+ return relative_path