Spaces:
Running
Running
zach
committed on
Commit
·
0e508c8
1
Parent(s):
83c6aee
Refactor tts integration functions to write audio to file and return file path, audio players to play mp3 file written to temp folder, fix audioplayer loading, remove unused imports
Browse files- .gitignore +1 -1
- src/app.py +9 -20
- src/config.py +6 -1
- src/constants.py +0 -1
- src/integrations/anthropic_api.py +16 -13
- src/integrations/elevenlabs_api.py +13 -9
- src/integrations/hume_api.py +30 -22
- src/types.py +1 -1
- src/utils.py +43 -1
.gitignore
CHANGED
@@ -38,4 +38,4 @@ Thumbs.db
|
|
38 |
*.cache
|
39 |
|
40 |
# Temp files
|
41 |
-
|
|
|
38 |
*.cache
|
39 |
|
40 |
# Temp files
|
41 |
+
static/audio/*
|
src/app.py
CHANGED
@@ -11,13 +11,14 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
|
|
11 |
# Standard Library Imports
|
12 |
from concurrent.futures import ThreadPoolExecutor
|
13 |
import random
|
|
|
14 |
from typing import Union, Tuple
|
15 |
|
16 |
# Third-Party Library Imports
|
17 |
import gradio as gr
|
18 |
|
19 |
# Local Application Imports
|
20 |
-
from src.config import logger
|
21 |
from src.constants import (
|
22 |
ELEVENLABS,
|
23 |
HUME_AI,
|
@@ -27,7 +28,6 @@ from src.constants import (
|
|
27 |
PROMPT_MIN_LENGTH,
|
28 |
SAMPLE_PROMPTS,
|
29 |
TROPHY_EMOJI,
|
30 |
-
UNKNOWN_PROVIDER,
|
31 |
VOTE_FOR_OPTION_A,
|
32 |
VOTE_FOR_OPTION_B,
|
33 |
)
|
@@ -41,7 +41,7 @@ from src.integrations import (
|
|
41 |
)
|
42 |
from src.theme import CustomTheme
|
43 |
from src.types import OptionMap
|
44 |
-
from src.utils import
|
45 |
|
46 |
|
47 |
def generate_text(
|
@@ -130,13 +130,7 @@ def text_to_speech(
|
|
130 |
audio_a = future_audio_a.result()
|
131 |
audio_b = future_audio_b.result()
|
132 |
|
133 |
-
|
134 |
-
f"TTS generated: {provider_a}={len(audio_a)} bytes, {provider_b}={len(audio_b)} bytes"
|
135 |
-
)
|
136 |
-
options = [
|
137 |
-
(audio_a, provider_a),
|
138 |
-
(audio_b, provider_b),
|
139 |
-
]
|
140 |
random.shuffle(options)
|
141 |
option_a_audio, option_b_audio = options[0][0], options[1][0]
|
142 |
options_map: OptionMap = {OPTION_A: options[0][1], OPTION_B: options[1][1]}
|
@@ -444,16 +438,11 @@ def build_gradio_interface() -> gr.Blocks:
|
|
444 |
],
|
445 |
)
|
446 |
|
447 |
-
#
|
448 |
-
# Audio player A stop event handler chain:
|
449 |
-
# 1. Clear the audio player A
|
450 |
-
# 2. Load audio player A with audio and set auto play to True
|
451 |
option_a_audio_player.stop(
|
452 |
-
fn=lambda
|
453 |
-
|
454 |
-
|
455 |
-
).then(
|
456 |
-
fn=lambda audio: gr.update(value=audio, autoplay=True),
|
457 |
inputs=[option_b_audio_state],
|
458 |
outputs=[option_b_audio_player],
|
459 |
)
|
@@ -476,4 +465,4 @@ def build_gradio_interface() -> gr.Blocks:
|
|
476 |
if __name__ == "__main__":
|
477 |
logger.info("Launching TTS Arena Gradio app...")
|
478 |
demo = build_gradio_interface()
|
479 |
-
demo.launch()
|
|
|
11 |
# Standard Library Imports
|
12 |
from concurrent.futures import ThreadPoolExecutor
|
13 |
import random
|
14 |
+
import time
|
15 |
from typing import Union, Tuple
|
16 |
|
17 |
# Third-Party Library Imports
|
18 |
import gradio as gr
|
19 |
|
20 |
# Local Application Imports
|
21 |
+
from src.config import AUDIO_DIR, logger
|
22 |
from src.constants import (
|
23 |
ELEVENLABS,
|
24 |
HUME_AI,
|
|
|
28 |
PROMPT_MIN_LENGTH,
|
29 |
SAMPLE_PROMPTS,
|
30 |
TROPHY_EMOJI,
|
|
|
31 |
VOTE_FOR_OPTION_A,
|
32 |
VOTE_FOR_OPTION_B,
|
33 |
)
|
|
|
41 |
)
|
42 |
from src.theme import CustomTheme
|
43 |
from src.types import OptionMap
|
44 |
+
from src.utils import validate_prompt_length
|
45 |
|
46 |
|
47 |
def generate_text(
|
|
|
130 |
audio_a = future_audio_a.result()
|
131 |
audio_b = future_audio_b.result()
|
132 |
|
133 |
+
options = [(audio_a, provider_a), (audio_b, provider_b)]
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
random.shuffle(options)
|
135 |
option_a_audio, option_b_audio = options[0][0], options[1][0]
|
136 |
options_map: OptionMap = {OPTION_A: options[0][1], OPTION_B: options[1][1]}
|
|
|
438 |
],
|
439 |
)
|
440 |
|
441 |
+
# Reload audio player B with audio and set autoplay to True (workaround to play audio back-to-back)
|
|
|
|
|
|
|
442 |
option_a_audio_player.stop(
|
443 |
+
fn=lambda current_audio_path: gr.update(
|
444 |
+
value=f"{current_audio_path}?t={int(time.time())}", autoplay=True
|
445 |
+
),
|
|
|
|
|
446 |
inputs=[option_b_audio_state],
|
447 |
outputs=[option_b_audio_player],
|
448 |
)
|
|
|
465 |
if __name__ == "__main__":
|
466 |
logger.info("Launching TTS Arena Gradio app...")
|
467 |
demo = build_gradio_interface()
|
468 |
+
demo.launch(allowed_paths=[AUDIO_DIR])
|
src/config.py
CHANGED
@@ -35,6 +35,11 @@ logging.basicConfig(
|
|
35 |
)
|
36 |
logger: logging.Logger = logging.getLogger("tts_arena")
|
37 |
logger.info(f'Debug mode is {"enabled" if DEBUG else "disabled"}.')
|
38 |
-
|
39 |
if DEBUG:
|
40 |
logger.debug(f"DEBUG mode enabled.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
)
|
36 |
logger: logging.Logger = logging.getLogger("tts_arena")
|
37 |
logger.info(f'Debug mode is {"enabled" if DEBUG else "disabled"}.')
|
|
|
38 |
if DEBUG:
|
39 |
logger.debug(f"DEBUG mode enabled.")
|
40 |
+
|
41 |
+
|
42 |
+
# Define the directory for audio files relative to the project root
|
43 |
+
AUDIO_DIR = os.path.join(os.getcwd(), "static", "audio")
|
44 |
+
os.makedirs(AUDIO_DIR, exist_ok=True)
|
45 |
+
logger.info(f"Audio directory set to {AUDIO_DIR}")
|
src/constants.py
CHANGED
@@ -11,7 +11,6 @@ from src.types import OptionKey, TTSProviderName
|
|
11 |
# UI constants
|
12 |
HUME_AI: TTSProviderName = "Hume AI"
|
13 |
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
14 |
-
UNKNOWN_PROVIDER: TTSProviderName = "Unknown"
|
15 |
|
16 |
PROMPT_MIN_LENGTH: int = 20
|
17 |
PROMPT_MAX_LENGTH: int = 800
|
|
|
11 |
# UI constants
|
12 |
HUME_AI: TTSProviderName = "Hume AI"
|
13 |
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
|
|
14 |
|
15 |
PROMPT_MIN_LENGTH: int = 20
|
16 |
PROMPT_MAX_LENGTH: int = 800
|
src/integrations/anthropic_api.py
CHANGED
@@ -40,10 +40,23 @@ class AnthropicConfig:
|
|
40 |
api_key: str = validate_env_var("ANTHROPIC_API_KEY")
|
41 |
model: ModelParam = "claude-3-5-sonnet-latest"
|
42 |
max_tokens: int = 150
|
43 |
-
system_prompt: str =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
CRITICAL LENGTH CONSTRAINTS:
|
45 |
|
46 |
-
Maximum length: {max_tokens} tokens (approximately 400 characters)
|
47 |
You MUST complete all thoughts and sentences
|
48 |
Responses should be 25% shorter than you initially plan
|
49 |
Never exceed 400 characters total
|
@@ -73,17 +86,7 @@ Resolution (75-100 characters)
|
|
73 |
|
74 |
MANDATORY: If you find yourself reaching 300 characters, immediately begin your conclusion regardless of where you are in the narrative.
|
75 |
Remember: A shorter, complete response is ALWAYS better than a longer, truncated one."""
|
76 |
-
|
77 |
-
def __post_init__(self):
|
78 |
-
# Validate that required attributes are set
|
79 |
-
if not self.api_key:
|
80 |
-
raise ValueError("Anthropic API key is not set.")
|
81 |
-
if not self.model:
|
82 |
-
raise ValueError("Anthropic Model is not set.")
|
83 |
-
if not self.max_tokens:
|
84 |
-
raise ValueError("Anthropic Max Tokens is not set.")
|
85 |
-
if not self.system_prompt:
|
86 |
-
raise ValueError("Anthropic System Prompt is not set.")
|
87 |
|
88 |
@property
|
89 |
def client(self) -> Anthropic:
|
|
|
40 |
api_key: str = validate_env_var("ANTHROPIC_API_KEY")
|
41 |
model: ModelParam = "claude-3-5-sonnet-latest"
|
42 |
max_tokens: int = 150
|
43 |
+
system_prompt: Optional[str] = (
|
44 |
+
None # system prompt is set post initialization, since self.max_tokens is leveraged in the prompt.
|
45 |
+
)
|
46 |
+
|
47 |
+
def __post_init__(self):
|
48 |
+
# Validate that required attributes are set
|
49 |
+
if not self.api_key:
|
50 |
+
raise ValueError("Anthropic API key is not set.")
|
51 |
+
if not self.model:
|
52 |
+
raise ValueError("Anthropic Model is not set.")
|
53 |
+
if not self.max_tokens:
|
54 |
+
raise ValueError("Anthropic Max Tokens is not set.")
|
55 |
+
if self.system_prompt is None:
|
56 |
+
system_prompt: str = f"""You are an expert at generating micro-content optimized for text-to-speech synthesis. Your absolute priority is delivering complete, untruncated responses within strict length limits.
|
57 |
CRITICAL LENGTH CONSTRAINTS:
|
58 |
|
59 |
+
Maximum length: {self.max_tokens} tokens (approximately 400 characters)
|
60 |
You MUST complete all thoughts and sentences
|
61 |
Responses should be 25% shorter than you initially plan
|
62 |
Never exceed 400 characters total
|
|
|
86 |
|
87 |
MANDATORY: If you find yourself reaching 300 characters, immediately begin your conclusion regardless of where you are in the narrative.
|
88 |
Remember: A shorter, complete response is ALWAYS better than a longer, truncated one."""
|
89 |
+
object.__setattr__(self, "system_prompt", system_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
@property
|
92 |
def client(self) -> Anthropic:
|
src/integrations/elevenlabs_api.py
CHANGED
@@ -20,20 +20,18 @@ Functions:
|
|
20 |
"""
|
21 |
|
22 |
# Standard Library Imports
|
23 |
-
import base64
|
24 |
from dataclasses import dataclass
|
25 |
-
from enum import Enum
|
26 |
import logging
|
27 |
import random
|
28 |
-
from typing import
|
29 |
|
30 |
# Third-Party Library Imports
|
31 |
-
from elevenlabs import ElevenLabs
|
32 |
from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_log
|
33 |
|
34 |
# Local Application Imports
|
35 |
from src.config import logger
|
36 |
-
from src.utils import validate_env_var
|
37 |
|
38 |
|
39 |
@dataclass(frozen=True)
|
@@ -41,6 +39,7 @@ class ElevenLabsConfig:
|
|
41 |
"""Immutable configuration for interacting with the ElevenLabs TTS API."""
|
42 |
|
43 |
api_key: str = validate_env_var("ELEVENLABS_API_KEY")
|
|
|
44 |
|
45 |
def __post_init__(self):
|
46 |
# Validate that required attributes are set
|
@@ -79,14 +78,14 @@ elevenlabs_config = ElevenLabsConfig()
|
|
79 |
)
|
80 |
def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
|
81 |
"""
|
82 |
-
Synthesizes text to speech using the ElevenLabs TTS API.
|
83 |
|
84 |
Args:
|
85 |
prompt (str): The original user prompt used as the voice description.
|
86 |
text (str): The text to be synthesized to speech.
|
87 |
|
88 |
Returns:
|
89 |
-
|
90 |
|
91 |
Raises:
|
92 |
ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
|
@@ -102,6 +101,7 @@ def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
|
|
102 |
response = elevenlabs_config.client.text_to_voice.create_previews(
|
103 |
voice_description=prompt,
|
104 |
text=text,
|
|
|
105 |
)
|
106 |
|
107 |
previews = response.previews
|
@@ -110,10 +110,14 @@ def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
|
|
110 |
logger.error(msg)
|
111 |
raise ElevenLabsError(message=msg)
|
112 |
|
|
|
113 |
preview = random.choice(previews)
|
|
|
114 |
base64_audio = preview.audio_base_64
|
115 |
-
|
116 |
-
|
|
|
|
|
117 |
|
118 |
except Exception as e:
|
119 |
logger.exception(f"Error synthesizing speech with ElevenLabs: {e}")
|
|
|
20 |
"""
|
21 |
|
22 |
# Standard Library Imports
|
|
|
23 |
from dataclasses import dataclass
|
|
|
24 |
import logging
|
25 |
import random
|
26 |
+
from typing import Optional
|
27 |
|
28 |
# Third-Party Library Imports
|
29 |
+
from elevenlabs import ElevenLabs, TextToVoiceCreatePreviewsRequestOutputFormat
|
30 |
from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_log
|
31 |
|
32 |
# Local Application Imports
|
33 |
from src.config import logger
|
34 |
+
from src.utils import save_base64_audio_to_file, validate_env_var
|
35 |
|
36 |
|
37 |
@dataclass(frozen=True)
|
|
|
39 |
"""Immutable configuration for interacting with the ElevenLabs TTS API."""
|
40 |
|
41 |
api_key: str = validate_env_var("ELEVENLABS_API_KEY")
|
42 |
+
output_format: TextToVoiceCreatePreviewsRequestOutputFormat = "mp3_44100_128"
|
43 |
|
44 |
def __post_init__(self):
|
45 |
# Validate that required attributes are set
|
|
|
78 |
)
|
79 |
def text_to_speech_with_elevenlabs(prompt: str, text: str) -> bytes:
|
80 |
"""
|
81 |
+
Synthesizes text to speech using the ElevenLabs TTS API, processes audio data, and writes audio to a file.
|
82 |
|
83 |
Args:
|
84 |
prompt (str): The original user prompt used as the voice description.
|
85 |
text (str): The text to be synthesized to speech.
|
86 |
|
87 |
Returns:
|
88 |
+
str: The relative path for the file the synthesized audio was written to.
|
89 |
|
90 |
Raises:
|
91 |
ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
|
|
|
101 |
response = elevenlabs_config.client.text_to_voice.create_previews(
|
102 |
voice_description=prompt,
|
103 |
text=text,
|
104 |
+
output_format=elevenlabs_config.output_format,
|
105 |
)
|
106 |
|
107 |
previews = response.previews
|
|
|
110 |
logger.error(msg)
|
111 |
raise ElevenLabsError(message=msg)
|
112 |
|
113 |
+
# Extract the base64 encoded audio and generated voice ID from the preview
|
114 |
preview = random.choice(previews)
|
115 |
+
generated_voice_id = preview.generated_voice_id
|
116 |
base64_audio = preview.audio_base_64
|
117 |
+
filename = f"{generated_voice_id}.mp3"
|
118 |
+
|
119 |
+
# Write audio to file and return the relative path
|
120 |
+
return save_base64_audio_to_file(base64_audio, filename)
|
121 |
|
122 |
except Exception as e:
|
123 |
logger.exception(f"Error synthesizing speech with ElevenLabs: {e}")
|
src/integrations/hume_api.py
CHANGED
@@ -19,11 +19,11 @@ Functions:
|
|
19 |
"""
|
20 |
|
21 |
# Standard Library Imports
|
22 |
-
import base64
|
23 |
from dataclasses import dataclass
|
24 |
import logging
|
|
|
25 |
import random
|
26 |
-
from typing import
|
27 |
|
28 |
# Third-Party Library Imports
|
29 |
import requests
|
@@ -31,7 +31,11 @@ from tenacity import retry, stop_after_attempt, wait_fixed, before_log, after_lo
|
|
31 |
|
32 |
# Local Application Imports
|
33 |
from src.config import logger
|
34 |
-
from src.utils import
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
@dataclass(frozen=True)
|
@@ -41,6 +45,7 @@ class HumeConfig:
|
|
41 |
api_key: str = validate_env_var("HUME_API_KEY")
|
42 |
url: str = "https://test-api.hume.ai/v0/tts/octave"
|
43 |
headers: dict = None
|
|
|
44 |
|
45 |
def __post_init__(self):
|
46 |
# Validate required attributes
|
@@ -48,6 +53,8 @@ class HumeConfig:
|
|
48 |
raise ValueError("Hume API key is not set.")
|
49 |
if not self.url:
|
50 |
raise ValueError("Hume TTS endpoint URL is not set.")
|
|
|
|
|
51 |
|
52 |
# Set headers dynamically after validation
|
53 |
object.__setattr__(
|
@@ -81,14 +88,14 @@ hume_config = HumeConfig()
|
|
81 |
)
|
82 |
def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
|
83 |
"""
|
84 |
-
Synthesizes text to speech using the Hume TTS API
|
85 |
|
86 |
Args:
|
87 |
prompt (str): The original user prompt to use as the description for generating the voice.
|
88 |
text (str): The generated text to be converted to speech.
|
89 |
|
90 |
Returns:
|
91 |
-
|
92 |
|
93 |
Raises:
|
94 |
HumeError: If there is an error communicating with the Hume TTS API or parsing the response.
|
@@ -108,24 +115,25 @@ def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
|
|
108 |
)
|
109 |
response.raise_for_status()
|
110 |
response_data = response.json()
|
111 |
-
except requests.RequestException as re:
|
112 |
-
request_error_msg = f"Error communicating with Hume TTS API: {re}"
|
113 |
-
logger.exception(request_error_msg)
|
114 |
-
raise HumeError(request_error_msg) from re
|
115 |
|
116 |
-
|
117 |
-
# Safely extract the generation result from the response JSON
|
118 |
-
generations = response_data.get("generations", [])
|
119 |
if not generations:
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
122 |
generation = generations[0]
|
|
|
123 |
base64_audio = generation.get("audio")
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
19 |
"""
|
20 |
|
21 |
# Standard Library Imports
|
|
|
22 |
from dataclasses import dataclass
|
23 |
import logging
|
24 |
+
import os
|
25 |
import random
|
26 |
+
from typing import Literal, Optional
|
27 |
|
28 |
# Third-Party Library Imports
|
29 |
import requests
|
|
|
31 |
|
32 |
# Local Application Imports
|
33 |
from src.config import logger
|
34 |
+
from src.utils import save_base64_audio_to_file, validate_env_var
|
35 |
+
|
36 |
+
|
37 |
+
HumeSupportedFileFormat = Literal["mp3", "pcm", "wav"]
|
38 |
+
"""Supported audio file formats for the Hume TTS API."""
|
39 |
|
40 |
|
41 |
@dataclass(frozen=True)
|
|
|
45 |
api_key: str = validate_env_var("HUME_API_KEY")
|
46 |
url: str = "https://test-api.hume.ai/v0/tts/octave"
|
47 |
headers: dict = None
|
48 |
+
file_format: HumeSupportedFileFormat = "mp3"
|
49 |
|
50 |
def __post_init__(self):
|
51 |
# Validate required attributes
|
|
|
53 |
raise ValueError("Hume API key is not set.")
|
54 |
if not self.url:
|
55 |
raise ValueError("Hume TTS endpoint URL is not set.")
|
56 |
+
if not self.file_format:
|
57 |
+
raise ValueError("Hume TTS file format is not set.")
|
58 |
|
59 |
# Set headers dynamically after validation
|
60 |
object.__setattr__(
|
|
|
88 |
)
|
89 |
def text_to_speech_with_hume(prompt: str, text: str) -> bytes:
|
90 |
"""
|
91 |
+
Synthesizes text to speech using the Hume TTS API, processes audio data, and writes audio to a file.
|
92 |
|
93 |
Args:
|
94 |
prompt (str): The original user prompt to use as the description for generating the voice.
|
95 |
text (str): The generated text to be converted to speech.
|
96 |
|
97 |
Returns:
|
98 |
+
str: The relative path for the file the synthesized audio was written to.
|
99 |
|
100 |
Raises:
|
101 |
HumeError: If there is an error communicating with the Hume TTS API or parsing the response.
|
|
|
115 |
)
|
116 |
response.raise_for_status()
|
117 |
response_data = response.json()
|
|
|
|
|
|
|
|
|
118 |
|
119 |
+
generations = response_data.get("generations")
|
|
|
|
|
120 |
if not generations:
|
121 |
+
msg = "No generations returned by Hume API."
|
122 |
+
logger.error(msg)
|
123 |
+
raise HumeError(msg)
|
124 |
+
|
125 |
+
# Extract the base64 encoded audio and generation ID from the generation
|
126 |
generation = generations[0]
|
127 |
+
generation_id = generation.get("generation_id")
|
128 |
base64_audio = generation.get("audio")
|
129 |
+
filename = f"{generation_id}.mp3"
|
130 |
+
|
131 |
+
# Write audio to file and return the relative path
|
132 |
+
return save_base64_audio_to_file(base64_audio, filename)
|
133 |
+
|
134 |
+
except Exception as e:
|
135 |
+
logger.exception(f"Error synthesizing speech with Hume: {e}")
|
136 |
+
raise HumeError(
|
137 |
+
message=f"Failed to synthesize speech with Hume: {e}",
|
138 |
+
original_exception=e,
|
139 |
+
) from e
|
src/types.py
CHANGED
@@ -7,7 +7,7 @@ has a consistent structure including both the provider and the associated voice.
|
|
7 |
"""
|
8 |
|
9 |
# Standard Library Imports
|
10 |
-
from typing import
|
11 |
|
12 |
|
13 |
TTSProviderName = Literal["Hume AI", "ElevenLabs"]
|
|
|
7 |
"""
|
8 |
|
9 |
# Standard Library Imports
|
10 |
+
from typing import Literal, Dict
|
11 |
|
12 |
|
13 |
TTSProviderName = Literal["Hume AI", "ElevenLabs"]
|
src/utils.py
CHANGED
@@ -11,10 +11,11 @@ Functions:
|
|
11 |
"""
|
12 |
|
13 |
# Standard Library Imports
|
|
|
14 |
import os
|
15 |
|
16 |
# Local Application Imports
|
17 |
-
from src.config import logger
|
18 |
|
19 |
|
20 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
@@ -116,3 +117,44 @@ def validate_prompt_length(prompt: str, max_length: int, min_length: int) -> Non
|
|
116 |
logger.debug(
|
117 |
f"Prompt length validation passed for prompt: {truncate_text(stripped_prompt)}"
|
118 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"""
|
12 |
|
13 |
# Standard Library Imports
|
14 |
+
import base64
|
15 |
import os
|
16 |
|
17 |
# Local Application Imports
|
18 |
+
from src.config import AUDIO_DIR, logger
|
19 |
|
20 |
|
21 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
|
|
117 |
logger.debug(
|
118 |
f"Prompt length validation passed for prompt: {truncate_text(stripped_prompt)}"
|
119 |
)
|
120 |
+
|
121 |
+
|
122 |
+
def save_base64_audio_to_file(base64_audio: str, filename: str) -> str:
|
123 |
+
"""
|
124 |
+
Decode a base64-encoded audio string and write the resulting binary data to a file
|
125 |
+
within the preconfigured AUDIO_DIR directory. This function verifies the file was created,
|
126 |
+
logs the absolute and relative file paths, and returns a path relative to the current
|
127 |
+
working directory (which is what Gradio requires to serve static files).
|
128 |
+
|
129 |
+
Args:
|
130 |
+
base64_audio (str): The base64-encoded string representing the audio data.
|
131 |
+
filename (str): The name of the file (including extension, e.g.,
|
132 |
+
'b4a335da-9786-483a-b0a5-37e6e4ad5fd1.mp3') where the decoded
|
133 |
+
audio will be saved.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
str: The relative file path to the saved audio file.
|
137 |
+
|
138 |
+
Raises:
|
139 |
+
Exception: Propagates any exceptions raised during the decoding or file I/O operations.
|
140 |
+
"""
|
141 |
+
# Decode the base64-encoded audio into binary data.
|
142 |
+
audio_bytes = base64.b64decode(base64_audio)
|
143 |
+
|
144 |
+
# Construct the full absolute file path within the AUDIO_DIR directory.
|
145 |
+
file_path = os.path.join(AUDIO_DIR, filename)
|
146 |
+
|
147 |
+
# Write the binary audio data to the file.
|
148 |
+
with open(file_path, "wb") as audio_file:
|
149 |
+
audio_file.write(audio_bytes)
|
150 |
+
|
151 |
+
# Verify that the file was created.
|
152 |
+
if not os.path.exists(file_path):
|
153 |
+
raise FileNotFoundError(f"Audio file was not created at {file_path}")
|
154 |
+
|
155 |
+
# Compute a path relative to the current working directory for Gradio to serve.
|
156 |
+
relative_path = os.path.relpath(file_path, os.getcwd())
|
157 |
+
logger.debug(f"Audio file absolute path: {file_path}")
|
158 |
+
logger.debug(f"Audio file relative path: {relative_path}")
|
159 |
+
|
160 |
+
return relative_path
|