Spaces:
Running
Running
zach
commited on
Commit
·
104737f
1
Parent(s):
80026d8
Update API integration code to be async
Browse files- pyproject.toml +1 -1
- src/app.py +31 -78
- src/constants.py +1 -1
- src/database/database.py +15 -14
- src/integrations/anthropic_api.py +4 -1
- src/integrations/elevenlabs_api.py +18 -24
- src/integrations/hume_api.py +58 -73
- src/main.py +12 -2
- uv.lock +0 -2
pyproject.toml
CHANGED
@@ -12,7 +12,6 @@ dependencies = [
|
|
12 |
"greenlet>=2.0.0",
|
13 |
"httpx>=0.24.1",
|
14 |
"python-dotenv>=1.0.1",
|
15 |
-
"requests>=2.32.3",
|
16 |
"sqlalchemy>=2.0.0",
|
17 |
"tenacity>=9.0.0",
|
18 |
]
|
@@ -45,6 +44,7 @@ ignore = [
|
|
45 |
"PLR0912",
|
46 |
"PLR0913",
|
47 |
"PLR2004",
|
|
|
48 |
"TD002",
|
49 |
"TD003",
|
50 |
]
|
|
|
12 |
"greenlet>=2.0.0",
|
13 |
"httpx>=0.24.1",
|
14 |
"python-dotenv>=1.0.1",
|
|
|
15 |
"sqlalchemy>=2.0.0",
|
16 |
"tenacity>=9.0.0",
|
17 |
]
|
|
|
44 |
"PLR0912",
|
45 |
"PLR0913",
|
46 |
"PLR2004",
|
47 |
+
"RUF006",
|
48 |
"TD002",
|
49 |
"TD003",
|
50 |
]
|
src/app.py
CHANGED
@@ -10,9 +10,7 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
|
|
10 |
|
11 |
# Standard Library Imports
|
12 |
import asyncio
|
13 |
-
import threading
|
14 |
import time
|
15 |
-
from concurrent.futures import ThreadPoolExecutor
|
16 |
from typing import Tuple
|
17 |
|
18 |
# Third-Party Library Imports
|
@@ -83,7 +81,7 @@ class App:
|
|
83 |
logger.error(f"Unexpected error while generating text: {e}")
|
84 |
raise gr.Error("Failed to generate text. Please try again later.")
|
85 |
|
86 |
-
def _synthesize_speech(
|
87 |
self,
|
88 |
character_description: str,
|
89 |
text: str,
|
@@ -130,38 +128,34 @@ class App:
|
|
130 |
if provider_b == constants.HUME_AI:
|
131 |
num_generations = 2
|
132 |
# If generating 2 Hume outputs, do so in a single API call.
|
133 |
-
result = text_to_speech_with_hume(character_description, text, num_generations, self.config)
|
134 |
# Enforce that 4 values are returned.
|
135 |
if not (isinstance(result, tuple) and len(result) == 4):
|
136 |
raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
|
137 |
generation_id_a, audio_a, generation_id_b, audio_b = result
|
138 |
else:
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
if isinstance(result_b, tuple) and len(result_b) >= 2:
|
162 |
-
generation_id_b, audio_b = result_b[0], result_b[1] # type: ignore
|
163 |
-
else:
|
164 |
-
raise ValueError("Unexpected return from text_to_speech_with_elevenlabs")
|
165 |
|
166 |
# Shuffle options so that placement of options in the UI will always be random.
|
167 |
option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
|
@@ -190,47 +184,7 @@ class App:
|
|
190 |
raise gr.Error("An unexpected error occurred. Please try again later.")
|
191 |
|
192 |
|
193 |
-
def
|
194 |
-
self,
|
195 |
-
option_map: OptionMap,
|
196 |
-
selected_option: constants.OptionKey,
|
197 |
-
text_modified: bool,
|
198 |
-
character_description: str,
|
199 |
-
text: str,
|
200 |
-
) -> None:
|
201 |
-
"""
|
202 |
-
Runs the vote submission in a background thread.
|
203 |
-
Creates a new event loop and runs the async submit_voting_results function in it.
|
204 |
-
|
205 |
-
Args:
|
206 |
-
Same as submit_voting_results
|
207 |
-
|
208 |
-
Returns:
|
209 |
-
None
|
210 |
-
"""
|
211 |
-
try:
|
212 |
-
# Create a new event loop for this thread
|
213 |
-
loop = asyncio.new_event_loop()
|
214 |
-
asyncio.set_event_loop(loop)
|
215 |
-
|
216 |
-
# Run the async function in the new loop
|
217 |
-
loop.run_until_complete(submit_voting_results(
|
218 |
-
option_map,
|
219 |
-
selected_option,
|
220 |
-
text_modified,
|
221 |
-
character_description,
|
222 |
-
text,
|
223 |
-
self.db_session_maker,
|
224 |
-
self.config,
|
225 |
-
))
|
226 |
-
except Exception as e:
|
227 |
-
logger.error(f"Error in background vote submission thread: {e}", exc_info=True)
|
228 |
-
finally:
|
229 |
-
# Close the loop when done
|
230 |
-
loop.close()
|
231 |
-
|
232 |
-
|
233 |
-
def _vote(
|
234 |
self,
|
235 |
vote_submitted: bool,
|
236 |
option_map: OptionMap,
|
@@ -261,19 +215,18 @@ class App:
|
|
261 |
selected_provider = option_map[selected_option]["provider"]
|
262 |
other_provider = option_map[other_option]["provider"]
|
263 |
|
264 |
-
#
|
265 |
-
|
266 |
-
|
267 |
-
args=(
|
268 |
option_map,
|
269 |
selected_option,
|
270 |
text_modified,
|
271 |
character_description,
|
272 |
text,
|
273 |
-
|
274 |
-
|
|
|
275 |
)
|
276 |
-
thread.start()
|
277 |
|
278 |
# Build button text, displaying the provider and voice name, appending the trophy emoji to the selected option.
|
279 |
selected_label = f"{selected_provider} {constants.TROPHY_EMOJI}"
|
|
|
10 |
|
11 |
# Standard Library Imports
|
12 |
import asyncio
|
|
|
13 |
import time
|
|
|
14 |
from typing import Tuple
|
15 |
|
16 |
# Third-Party Library Imports
|
|
|
81 |
logger.error(f"Unexpected error while generating text: {e}")
|
82 |
raise gr.Error("Failed to generate text. Please try again later.")
|
83 |
|
84 |
+
async def _synthesize_speech(
|
85 |
self,
|
86 |
character_description: str,
|
87 |
text: str,
|
|
|
128 |
if provider_b == constants.HUME_AI:
|
129 |
num_generations = 2
|
130 |
# If generating 2 Hume outputs, do so in a single API call.
|
131 |
+
result = await text_to_speech_with_hume(character_description, text, num_generations, self.config)
|
132 |
# Enforce that 4 values are returned.
|
133 |
if not (isinstance(result, tuple) and len(result) == 4):
|
134 |
raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
|
135 |
generation_id_a, audio_a, generation_id_b, audio_b = result
|
136 |
else:
|
137 |
+
num_generations = 1
|
138 |
+
# Run both API calls concurrently using asyncio
|
139 |
+
tasks = []
|
140 |
+
# Generate a single Hume output
|
141 |
+
tasks.append(text_to_speech_with_hume(character_description, text, num_generations, self.config))
|
142 |
+
|
143 |
+
# Generate a second TTS output from the second provider
|
144 |
+
match provider_b:
|
145 |
+
case constants.ELEVENLABS:
|
146 |
+
tasks.append(text_to_speech_with_elevenlabs(character_description, text, self.config))
|
147 |
+
case _:
|
148 |
+
# Additional TTS Providers can be added here.
|
149 |
+
raise ValueError(f"Unsupported provider: {provider_b}")
|
150 |
+
|
151 |
+
# Await both tasks concurrently
|
152 |
+
result_a, result_b = await asyncio.gather(*tasks)
|
153 |
+
|
154 |
+
if not isinstance(result_a, tuple) or len(result_a) != 2:
|
155 |
+
raise ValueError("Expected 2 values from Hume TTS call when generating 1 output")
|
156 |
+
|
157 |
+
generation_id_a, audio_a = result_a[0], result_a[1]
|
158 |
+
generation_id_b, audio_b = result_b[0], result_b[1]
|
|
|
|
|
|
|
|
|
159 |
|
160 |
# Shuffle options so that placement of options in the UI will always be random.
|
161 |
option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
|
|
|
184 |
raise gr.Error("An unexpected error occurred. Please try again later.")
|
185 |
|
186 |
|
187 |
+
async def _vote(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
self,
|
189 |
vote_submitted: bool,
|
190 |
option_map: OptionMap,
|
|
|
215 |
selected_provider = option_map[selected_option]["provider"]
|
216 |
other_provider = option_map[other_option]["provider"]
|
217 |
|
218 |
+
# Process vote in the background without blocking the UI
|
219 |
+
asyncio.create_task(
|
220 |
+
submit_voting_results(
|
|
|
221 |
option_map,
|
222 |
selected_option,
|
223 |
text_modified,
|
224 |
character_description,
|
225 |
text,
|
226 |
+
self.db_session_maker,
|
227 |
+
self.config,
|
228 |
+
)
|
229 |
)
|
|
|
230 |
|
231 |
# Build button text, displaying the provider and voice name, appending the trophy emoji to the selected option.
|
232 |
selected_label = f"{selected_provider} {constants.TROPHY_EMOJI}"
|
src/constants.py
CHANGED
@@ -59,7 +59,7 @@ SAMPLE_CHARACTER_DESCRIPTIONS: dict = {
|
|
59 |
"building tension through perfectly timed pauses and haunting inflections."
|
60 |
),
|
61 |
"🌿 British Naturalist": (
|
62 |
-
"A passionate nature documentarian with a voice that brings the wild to life—crisp, refined "
|
63 |
"tones brimming with wonder and expertise. It shifts seamlessly from hushed observation to "
|
64 |
"animated excitement, painting vivid pictures of the natural world's endless marvels."
|
65 |
),
|
|
|
59 |
"building tension through perfectly timed pauses and haunting inflections."
|
60 |
),
|
61 |
"🌿 British Naturalist": (
|
62 |
+
"A passionate, British nature documentarian with a voice that brings the wild to life—crisp, refined "
|
63 |
"tones brimming with wonder and expertise. It shifts seamlessly from hushed observation to "
|
64 |
"animated excitement, painting vivid pictures of the natural world's endless marvels."
|
65 |
),
|
src/database/database.py
CHANGED
@@ -9,25 +9,22 @@ If no DATABASE_URL environment variable is set, then create a dummy database to
|
|
9 |
"""
|
10 |
|
11 |
# Standard Library Imports
|
12 |
-
from typing import Callable, Optional
|
13 |
|
14 |
# Third-Party Library Imports
|
15 |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
16 |
from sqlalchemy.orm import DeclarativeBase
|
17 |
|
18 |
# Local Application Imports
|
19 |
-
from src.config import Config
|
20 |
|
21 |
|
22 |
-
# Define the SQLAlchemy Base
|
23 |
class Base(DeclarativeBase):
|
24 |
pass
|
25 |
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
class AsyncDummySession:
|
31 |
is_dummy = True # Flag to indicate this is a dummy session.
|
32 |
|
33 |
async def __enter__(self):
|
@@ -42,11 +39,11 @@ class AsyncDummySession:
|
|
42 |
|
43 |
async def commit(self):
|
44 |
# Raise an exception to simulate failure when attempting a write.
|
45 |
-
raise RuntimeError("
|
46 |
|
47 |
async def refresh(self, _instance):
|
48 |
# Raise an exception to simulate failure when attempting to refresh.
|
49 |
-
raise RuntimeError("
|
50 |
|
51 |
async def rollback(self):
|
52 |
# No-op: there's nothing to roll back.
|
@@ -57,8 +54,8 @@ class AsyncDummySession:
|
|
57 |
pass
|
58 |
|
59 |
|
60 |
-
|
61 |
-
|
62 |
|
63 |
|
64 |
def init_db(config: Config) -> AsyncDBSessionMaker:
|
@@ -88,21 +85,25 @@ def init_db(config: Config) -> AsyncDBSessionMaker:
|
|
88 |
# In production, a valid DATABASE_URL is required.
|
89 |
if not config.database_url:
|
90 |
raise ValueError("DATABASE_URL must be set in production!")
|
|
|
91 |
async_db_url = convert_to_async_url(config.database_url)
|
92 |
engine = create_async_engine(async_db_url)
|
|
|
93 |
return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
|
94 |
|
95 |
# In development, if a DATABASE_URL is provided, use it.
|
96 |
if config.database_url:
|
97 |
async_db_url = convert_to_async_url(config.database_url)
|
98 |
engine = create_async_engine(async_db_url)
|
|
|
99 |
return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
|
100 |
|
101 |
-
# No DATABASE_URL is provided; use a
|
102 |
engine = None
|
|
|
103 |
|
104 |
-
def async_dummy_session_factory() ->
|
105 |
-
return
|
106 |
|
107 |
return async_dummy_session_factory
|
108 |
|
|
|
9 |
"""
|
10 |
|
11 |
# Standard Library Imports
|
12 |
+
from typing import Callable, Optional, TypeAlias, Union
|
13 |
|
14 |
# Third-Party Library Imports
|
15 |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
16 |
from sqlalchemy.orm import DeclarativeBase
|
17 |
|
18 |
# Local Application Imports
|
19 |
+
from src.config import Config, logger
|
20 |
|
21 |
|
22 |
+
# Define the SQLAlchemy Base
|
23 |
class Base(DeclarativeBase):
|
24 |
pass
|
25 |
|
26 |
|
27 |
+
class DummyAsyncSession:
|
|
|
|
|
|
|
28 |
is_dummy = True # Flag to indicate this is a dummy session.
|
29 |
|
30 |
async def __enter__(self):
|
|
|
39 |
|
40 |
async def commit(self):
|
41 |
# Raise an exception to simulate failure when attempting a write.
|
42 |
+
raise RuntimeError("DummyAsyncSession does not support commit operations.")
|
43 |
|
44 |
async def refresh(self, _instance):
|
45 |
# Raise an exception to simulate failure when attempting to refresh.
|
46 |
+
raise RuntimeError("DummyAsyncSession does not support refresh operations.")
|
47 |
|
48 |
async def rollback(self):
|
49 |
# No-op: there's nothing to roll back.
|
|
|
54 |
pass
|
55 |
|
56 |
|
57 |
+
AsyncDBSessionMaker: TypeAlias = Union[async_sessionmaker[AsyncSession], Callable[[], DummyAsyncSession]]
|
58 |
+
engine: Optional[AsyncEngine] = None
|
59 |
|
60 |
|
61 |
def init_db(config: Config) -> AsyncDBSessionMaker:
|
|
|
85 |
# In production, a valid DATABASE_URL is required.
|
86 |
if not config.database_url:
|
87 |
raise ValueError("DATABASE_URL must be set in production!")
|
88 |
+
|
89 |
async_db_url = convert_to_async_url(config.database_url)
|
90 |
engine = create_async_engine(async_db_url)
|
91 |
+
|
92 |
return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
|
93 |
|
94 |
# In development, if a DATABASE_URL is provided, use it.
|
95 |
if config.database_url:
|
96 |
async_db_url = convert_to_async_url(config.database_url)
|
97 |
engine = create_async_engine(async_db_url)
|
98 |
+
|
99 |
return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
|
100 |
|
101 |
+
# No DATABASE_URL is provided; use a DummyAsyncSession that does nothing.
|
102 |
engine = None
|
103 |
+
logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession")
|
104 |
|
105 |
+
def async_dummy_session_factory() -> DummyAsyncSession:
|
106 |
+
return DummyAsyncSession()
|
107 |
|
108 |
return async_dummy_session_factory
|
109 |
|
src/integrations/anthropic_api.py
CHANGED
@@ -84,7 +84,8 @@ class AnthropicConfig:
|
|
84 |
from anthropic import AsyncAnthropic # Import the async client from Anthropic SDK
|
85 |
return AsyncAnthropic(api_key=self.api_key)
|
86 |
|
87 |
-
|
|
|
88 |
"""
|
89 |
Constructs and returns a prompt based solely on the provided character description.
|
90 |
The returned prompt is intended to instruct Claude to generate expressive text from a character,
|
@@ -120,6 +121,8 @@ class UnretryableAnthropicError(AnthropicError):
|
|
120 |
|
121 |
def __init__(self, message: str, original_exception: Optional[Exception] = None) -> None:
|
122 |
super().__init__(message, original_exception)
|
|
|
|
|
123 |
|
124 |
|
125 |
@retry(
|
|
|
84 |
from anthropic import AsyncAnthropic # Import the async client from Anthropic SDK
|
85 |
return AsyncAnthropic(api_key=self.api_key)
|
86 |
|
87 |
+
@staticmethod
|
88 |
+
def build_expressive_prompt(character_description: str) -> str:
|
89 |
"""
|
90 |
Constructs and returns a prompt based solely on the provided character description.
|
91 |
The returned prompt is intended to instruct Claude to generate expressive text from a character,
|
|
|
121 |
|
122 |
def __init__(self, message: str, original_exception: Optional[Exception] = None) -> None:
|
123 |
super().__init__(message, original_exception)
|
124 |
+
self.original_exception = original_exception
|
125 |
+
self.message = message
|
126 |
|
127 |
|
128 |
@retry(
|
src/integrations/elevenlabs_api.py
CHANGED
@@ -10,13 +10,6 @@ Key Features:
|
|
10 |
- Handles received audio and processes it for playback on the web.
|
11 |
- Provides detailed logging for debugging and error tracking.
|
12 |
- Utilizes robust error handling (EAFP) to validate API responses.
|
13 |
-
|
14 |
-
Classes:
|
15 |
-
- ElevenLabsConfig: Immutable configuration for interacting with ElevenLabs' TTS API.
|
16 |
-
- ElevenLabsError: Custom exception for ElevenLabs API-related errors.
|
17 |
-
|
18 |
-
Functions:
|
19 |
-
- text_to_speech_with_elevenlabs: Synthesizes speech from text using ElevenLabs' TTS API.
|
20 |
"""
|
21 |
|
22 |
# Standard Library Imports
|
@@ -26,9 +19,9 @@ from dataclasses import dataclass, field
|
|
26 |
from typing import Optional, Tuple
|
27 |
|
28 |
# Third-Party Library Imports
|
29 |
-
from elevenlabs import
|
30 |
from elevenlabs.core import ApiError
|
31 |
-
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
32 |
|
33 |
# Local Application Imports
|
34 |
from src.config import Config, logger
|
@@ -48,19 +41,18 @@ class ElevenLabsConfig:
|
|
48 |
if not self.output_format:
|
49 |
raise ValueError("ElevenLabs TTS API output format is not set.")
|
50 |
|
51 |
-
# Compute the API key from the environment.
|
52 |
computed_key = validate_env_var("ELEVENLABS_API_KEY")
|
53 |
object.__setattr__(self, "api_key", computed_key)
|
54 |
|
55 |
@property
|
56 |
-
def client(self) ->
|
57 |
"""
|
58 |
-
Lazy initialization of the ElevenLabs client.
|
59 |
|
60 |
Returns:
|
61 |
-
|
62 |
"""
|
63 |
-
return
|
64 |
|
65 |
|
66 |
class ElevenLabsError(Exception):
|
@@ -77,42 +69,43 @@ class UnretryableElevenLabsError(ElevenLabsError):
|
|
77 |
|
78 |
def __init__(self, message: str, original_exception: Optional[Exception] = None):
|
79 |
super().__init__(message, original_exception)
|
|
|
|
|
80 |
|
81 |
|
82 |
@retry(
|
|
|
83 |
stop=stop_after_attempt(3),
|
84 |
wait=wait_fixed(2),
|
85 |
before=before_log(logger, logging.DEBUG),
|
86 |
after=after_log(logger, logging.DEBUG),
|
87 |
reraise=True,
|
88 |
)
|
89 |
-
def text_to_speech_with_elevenlabs(
|
90 |
character_description: str, text: str, config: Config
|
91 |
) -> Tuple[None, str]:
|
92 |
"""
|
93 |
-
|
94 |
|
95 |
Args:
|
96 |
-
character_description (str): The character description used
|
97 |
text (str): The text to be synthesized into speech.
|
|
|
98 |
|
99 |
Returns:
|
100 |
Tuple[None, str]: A tuple containing:
|
101 |
-
- generation_id (None):
|
102 |
-
|
103 |
-
- file_path (str): The relative file path to the audio file where the synthesized speech was saved.
|
104 |
|
105 |
Raises:
|
106 |
ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
|
107 |
"""
|
108 |
-
|
109 |
logger.debug(f"Synthesizing speech with ElevenLabs. Text length: {len(text)} characters.")
|
110 |
-
|
111 |
elevenlabs_config = config.elevenlabs_config
|
112 |
|
113 |
try:
|
114 |
# Synthesize speech using the ElevenLabs SDK
|
115 |
-
response = elevenlabs_config.client.text_to_voice.create_previews(
|
116 |
voice_description=character_description,
|
117 |
text=text,
|
118 |
output_format=elevenlabs_config.output_format,
|
@@ -129,9 +122,10 @@ def text_to_speech_with_elevenlabs(
|
|
129 |
generated_voice_id = preview.generated_voice_id
|
130 |
base64_audio = preview.audio_base_64
|
131 |
filename = f"{generated_voice_id}.mp3"
|
132 |
-
audio_file_path = save_base64_audio_to_file(base64_audio, filename, config)
|
133 |
|
134 |
# Write audio to file and return the relative path
|
|
|
|
|
135 |
return None, audio_file_path
|
136 |
|
137 |
except Exception as e:
|
|
|
10 |
- Handles received audio and processes it for playback on the web.
|
11 |
- Provides detailed logging for debugging and error tracking.
|
12 |
- Utilizes robust error handling (EAFP) to validate API responses.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
"""
|
14 |
|
15 |
# Standard Library Imports
|
|
|
19 |
from typing import Optional, Tuple
|
20 |
|
21 |
# Third-Party Library Imports
|
22 |
+
from elevenlabs import AsyncElevenLabs, TextToVoiceCreatePreviewsRequestOutputFormat
|
23 |
from elevenlabs.core import ApiError
|
24 |
+
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
25 |
|
26 |
# Local Application Imports
|
27 |
from src.config import Config, logger
|
|
|
41 |
if not self.output_format:
|
42 |
raise ValueError("ElevenLabs TTS API output format is not set.")
|
43 |
|
|
|
44 |
computed_key = validate_env_var("ELEVENLABS_API_KEY")
|
45 |
object.__setattr__(self, "api_key", computed_key)
|
46 |
|
47 |
@property
|
48 |
+
def client(self) -> AsyncElevenLabs:
|
49 |
"""
|
50 |
+
Lazy initialization of the asynchronous ElevenLabs client.
|
51 |
|
52 |
Returns:
|
53 |
+
AsyncElevenLabs: Configured async client instance.
|
54 |
"""
|
55 |
+
return AsyncElevenLabs(api_key=self.api_key)
|
56 |
|
57 |
|
58 |
class ElevenLabsError(Exception):
|
|
|
69 |
|
70 |
def __init__(self, message: str, original_exception: Optional[Exception] = None):
|
71 |
super().__init__(message, original_exception)
|
72 |
+
self.original_exception = original_exception
|
73 |
+
self.message = message
|
74 |
|
75 |
|
76 |
@retry(
|
77 |
+
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableElevenLabsError)),
|
78 |
stop=stop_after_attempt(3),
|
79 |
wait=wait_fixed(2),
|
80 |
before=before_log(logger, logging.DEBUG),
|
81 |
after=after_log(logger, logging.DEBUG),
|
82 |
reraise=True,
|
83 |
)
|
84 |
+
async def text_to_speech_with_elevenlabs(
|
85 |
character_description: str, text: str, config: Config
|
86 |
) -> Tuple[None, str]:
|
87 |
"""
|
88 |
+
Asynchronously synthesizes speech using the ElevenLabs TTS API, processes the audio data, and writes it to a file.
|
89 |
|
90 |
Args:
|
91 |
+
character_description (str): The character description used for voice synthesis.
|
92 |
text (str): The text to be synthesized into speech.
|
93 |
+
config (Config): Application configuration containing ElevenLabs API settings.
|
94 |
|
95 |
Returns:
|
96 |
Tuple[None, str]: A tuple containing:
|
97 |
+
- generation_id (None): A placeholder (no generation ID is returned).
|
98 |
+
- file_path (str): The relative file path to the saved audio file.
|
|
|
99 |
|
100 |
Raises:
|
101 |
ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
|
102 |
"""
|
|
|
103 |
logger.debug(f"Synthesizing speech with ElevenLabs. Text length: {len(text)} characters.")
|
|
|
104 |
elevenlabs_config = config.elevenlabs_config
|
105 |
|
106 |
try:
|
107 |
# Synthesize speech using the ElevenLabs SDK
|
108 |
+
response = await elevenlabs_config.client.text_to_voice.create_previews(
|
109 |
voice_description=character_description,
|
110 |
text=text,
|
111 |
output_format=elevenlabs_config.output_format,
|
|
|
122 |
generated_voice_id = preview.generated_voice_id
|
123 |
base64_audio = preview.audio_base_64
|
124 |
filename = f"{generated_voice_id}.mp3"
|
|
|
125 |
|
126 |
# Write audio to file and return the relative path
|
127 |
+
audio_file_path = save_base64_audio_to_file(base64_audio, filename, config)
|
128 |
+
|
129 |
return None, audio_file_path
|
130 |
|
131 |
except Exception as e:
|
src/integrations/hume_api.py
CHANGED
@@ -9,13 +9,6 @@ Key Features:
|
|
9 |
- Implements retry logic for handling transient API errors.
|
10 |
- Handles received audio and processes it for playback on the web.
|
11 |
- Provides detailed logging for debugging and error tracking.
|
12 |
-
|
13 |
-
Classes:
|
14 |
-
- HumeConfig: Immutable configuration for interacting with Hume's TTS API.
|
15 |
-
- HumeError: Custom exception for Hume API-related errors.
|
16 |
-
|
17 |
-
Functions:
|
18 |
-
- text_to_speech_with_hume: Synthesizes speech from text using Hume's TTS API.
|
19 |
"""
|
20 |
|
21 |
# Standard Library Imports
|
@@ -24,9 +17,8 @@ from dataclasses import dataclass, field
|
|
24 |
from typing import Any, Dict, Literal, Tuple, Union
|
25 |
|
26 |
# Third-Party Library Imports
|
27 |
-
import
|
28 |
-
from
|
29 |
-
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
30 |
|
31 |
# Local Application Imports
|
32 |
from src.config import Config, logger
|
@@ -41,12 +33,9 @@ HumeSupportedFileFormat = Literal["mp3", "pcm", "wav"]
|
|
41 |
class HumeConfig:
|
42 |
"""Immutable configuration for interacting with the Hume TTS API."""
|
43 |
|
44 |
-
# Computed fields.
|
45 |
api_key: str = field(init=False)
|
46 |
headers: Dict[str, str] = field(init=False)
|
47 |
-
|
48 |
-
# Provided fields.
|
49 |
-
url: str = "https://test-api.hume.ai/v0/tts/octave"
|
50 |
file_format: HumeSupportedFileFormat = "mp3"
|
51 |
|
52 |
def __post_init__(self) -> None:
|
@@ -56,11 +45,8 @@ class HumeConfig:
|
|
56 |
if not self.file_format:
|
57 |
raise ValueError("Hume TTS file format is not set.")
|
58 |
|
59 |
-
# Compute the API key from the environment.
|
60 |
computed_api_key = validate_env_var("HUME_API_KEY")
|
61 |
object.__setattr__(self, "api_key", computed_api_key)
|
62 |
-
|
63 |
-
# Compute the headers.
|
64 |
computed_headers = {
|
65 |
"X-Hume-Api-Key": f"{computed_api_key}",
|
66 |
"Content-Type": "application/json",
|
@@ -83,38 +69,36 @@ class UnretryableHumeError(HumeError):
|
|
83 |
def __init__(self, message: str, original_exception: Union[Exception, None] = None):
|
84 |
super().__init__(message, original_exception)
|
85 |
self.original_exception = original_exception
|
|
|
86 |
|
87 |
|
88 |
@retry(
|
|
|
89 |
stop=stop_after_attempt(3),
|
90 |
wait=wait_fixed(2),
|
91 |
before=before_log(logger, logging.DEBUG),
|
92 |
after=after_log(logger, logging.DEBUG),
|
93 |
reraise=True,
|
94 |
)
|
95 |
-
def text_to_speech_with_hume(
|
96 |
character_description: str,
|
97 |
text: str,
|
98 |
num_generations: int,
|
99 |
config: Config,
|
100 |
) -> Union[Tuple[str, str], Tuple[str, str, str, str]]:
|
101 |
"""
|
102 |
-
|
103 |
|
104 |
-
This function sends a POST request to the Hume TTS API with a character description and text
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
and returns the relevant details.
|
109 |
|
110 |
Args:
|
111 |
-
character_description (str):
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
Allowed values are 1 or 2. If 1, only a single generation is processed; if 2, a second
|
116 |
-
generation is expected in the API response.
|
117 |
-
config (Config): The application configuration containing Hume API settings.
|
118 |
|
119 |
Returns:
|
120 |
Union[Tuple[str, str], Tuple[str, str, str, str]]:
|
@@ -123,15 +107,13 @@ def text_to_speech_with_hume(
|
|
123 |
|
124 |
Raises:
|
125 |
ValueError: If num_generations is not 1 or 2.
|
126 |
-
HumeError:
|
127 |
-
UnretryableHumeError:
|
128 |
-
Exception: Any other exceptions raised during the request or processing will be wrapped and
|
129 |
-
re-raised as HumeError.
|
130 |
"""
|
131 |
-
|
132 |
logger.debug(
|
133 |
-
|
134 |
-
f"
|
|
|
135 |
)
|
136 |
|
137 |
if num_generations < 1 or num_generations > 2:
|
@@ -145,14 +127,15 @@ def text_to_speech_with_hume(
|
|
145 |
}
|
146 |
|
147 |
try:
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
156 |
|
157 |
generations = response_data.get("generations")
|
158 |
if not generations:
|
@@ -160,7 +143,6 @@ def text_to_speech_with_hume(
|
|
160 |
logger.error(msg)
|
161 |
raise HumeError(msg)
|
162 |
|
163 |
-
# Extract the base64 encoded audio and generation ID from the generation.
|
164 |
generation_a = generations[0]
|
165 |
generation_a_id, audio_a_path = _parse_hume_tts_generation(generation_a, config)
|
166 |
|
@@ -171,48 +153,51 @@ def text_to_speech_with_hume(
|
|
171 |
generation_b_id, audio_b_path = _parse_hume_tts_generation(generation_b, config)
|
172 |
return (generation_a_id, audio_a_path, generation_b_id, audio_b_path)
|
173 |
|
174 |
-
except
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
)
|
|
|
|
|
|
|
|
|
|
|
180 |
raise UnretryableHumeError(
|
181 |
-
message=
|
182 |
original_exception=e,
|
183 |
) from e
|
184 |
-
|
|
|
185 |
raise HumeError(
|
186 |
-
message=
|
187 |
original_exception=e,
|
188 |
) from e
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
def _parse_hume_tts_generation(generation: Dict[str, Any], config: Config) -> Tuple[str, str]:
|
192 |
"""
|
193 |
-
|
194 |
-
|
195 |
-
This function extracts the generation ID and the base64-encoded audio from the provided
|
196 |
-
dictionary. It then decodes and saves the audio data to an MP3 file, naming the file using
|
197 |
-
the generation ID. Finally, it returns a tuple containing the generation ID and the file path
|
198 |
-
of the saved audio.
|
199 |
|
200 |
Args:
|
201 |
-
generation (Dict[str, Any]):
|
202 |
-
|
203 |
-
- "generation_id" (str): A unique identifier for the generated audio.
|
204 |
-
- "audio" (str): A base64 encoded string of the audio data.
|
205 |
-
config (Config): The application configuration used for saving the audio file.
|
206 |
|
207 |
Returns:
|
208 |
-
Tuple[str, str]:
|
209 |
-
- generation_id (str): The unique identifier for the audio generation.
|
210 |
-
- audio_path (str): The filesystem path where the audio file was saved.
|
211 |
|
212 |
Raises:
|
213 |
-
KeyError: If
|
214 |
-
Exception: Propagates
|
215 |
-
the decoding or file saving process.
|
216 |
"""
|
217 |
generation_id = generation.get("generation_id")
|
218 |
if generation_id is None:
|
|
|
9 |
- Implements retry logic for handling transient API errors.
|
10 |
- Handles received audio and processes it for playback on the web.
|
11 |
- Provides detailed logging for debugging and error tracking.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"""
|
13 |
|
14 |
# Standard Library Imports
|
|
|
17 |
from typing import Any, Dict, Literal, Tuple, Union
|
18 |
|
19 |
# Third-Party Library Imports
|
20 |
+
import httpx
|
21 |
+
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
|
|
22 |
|
23 |
# Local Application Imports
|
24 |
from src.config import Config, logger
|
|
|
33 |
class HumeConfig:
|
34 |
"""Immutable configuration for interacting with the Hume TTS API."""
|
35 |
|
|
|
36 |
api_key: str = field(init=False)
|
37 |
headers: Dict[str, str] = field(init=False)
|
38 |
+
url: str = "https://api.hume.ai/v0/tts/octave"
|
|
|
|
|
39 |
file_format: HumeSupportedFileFormat = "mp3"
|
40 |
|
41 |
def __post_init__(self) -> None:
|
|
|
45 |
if not self.file_format:
|
46 |
raise ValueError("Hume TTS file format is not set.")
|
47 |
|
|
|
48 |
computed_api_key = validate_env_var("HUME_API_KEY")
|
49 |
object.__setattr__(self, "api_key", computed_api_key)
|
|
|
|
|
50 |
computed_headers = {
|
51 |
"X-Hume-Api-Key": f"{computed_api_key}",
|
52 |
"Content-Type": "application/json",
|
|
|
69 |
def __init__(self, message: str, original_exception: Union[Exception, None] = None):
|
70 |
super().__init__(message, original_exception)
|
71 |
self.original_exception = original_exception
|
72 |
+
self.message = message
|
73 |
|
74 |
|
75 |
@retry(
|
76 |
+
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableHumeError)),
|
77 |
stop=stop_after_attempt(3),
|
78 |
wait=wait_fixed(2),
|
79 |
before=before_log(logger, logging.DEBUG),
|
80 |
after=after_log(logger, logging.DEBUG),
|
81 |
reraise=True,
|
82 |
)
|
83 |
+
async def text_to_speech_with_hume(
|
84 |
character_description: str,
|
85 |
text: str,
|
86 |
num_generations: int,
|
87 |
config: Config,
|
88 |
) -> Union[Tuple[str, str], Tuple[str, str, str, str]]:
|
89 |
"""
|
90 |
+
Asynchronously synthesizes speech using the Hume TTS API, processes audio data, and writes audio to a file.
|
91 |
|
92 |
+
This function sends a POST request to the Hume TTS API with a character description and text to be converted to
|
93 |
+
speech. Depending on the specified number of generations (1 or 2), the API returns one or two generations.
|
94 |
+
For each generation, the function extracts the base64-encoded audio and generation ID, saves the audio as an MP3
|
95 |
+
file, and returns the relevant details.
|
|
|
96 |
|
97 |
Args:
|
98 |
+
character_description (str): Description used for voice synthesis.
|
99 |
+
text (str): Text to be converted to speech.
|
100 |
+
num_generations (int): Number of audio generations to request (1 or 2).
|
101 |
+
config (Config): Application configuration containing Hume API settings.
|
|
|
|
|
|
|
102 |
|
103 |
Returns:
|
104 |
Union[Tuple[str, str], Tuple[str, str, str, str]]:
|
|
|
107 |
|
108 |
Raises:
|
109 |
ValueError: If num_generations is not 1 or 2.
|
110 |
+
HumeError: For errors communicating with the Hume API.
|
111 |
+
UnretryableHumeError: For client-side HTTP errors (status code 4xx).
|
|
|
|
|
112 |
"""
|
|
|
113 |
logger.debug(
|
114 |
+
"Processing TTS with Hume. "
|
115 |
+
f"Character description length: {len(character_description)}. "
|
116 |
+
f"Text length: {len(text)}."
|
117 |
)
|
118 |
|
119 |
if num_generations < 1 or num_generations > 2:
|
|
|
127 |
}
|
128 |
|
129 |
try:
|
130 |
+
async with httpx.AsyncClient() as client:
|
131 |
+
response = await client.post(
|
132 |
+
url=hume_config.url,
|
133 |
+
headers=hume_config.headers,
|
134 |
+
json=request_body,
|
135 |
+
timeout=30.0,
|
136 |
+
)
|
137 |
+
response.raise_for_status()
|
138 |
+
response_data = response.json()
|
139 |
|
140 |
generations = response_data.get("generations")
|
141 |
if not generations:
|
|
|
143 |
logger.error(msg)
|
144 |
raise HumeError(msg)
|
145 |
|
|
|
146 |
generation_a = generations[0]
|
147 |
generation_a_id, audio_a_path = _parse_hume_tts_generation(generation_a, config)
|
148 |
|
|
|
153 |
generation_b_id, audio_b_path = _parse_hume_tts_generation(generation_b, config)
|
154 |
return (generation_a_id, audio_a_path, generation_b_id, audio_b_path)
|
155 |
|
156 |
+
except httpx.ReadTimeout as e:
|
157 |
+
# Handle timeout specifically
|
158 |
+
raise HumeError(
|
159 |
+
message="Request to Hume API timed out. Please try again later.",
|
160 |
+
original_exception=e,
|
161 |
+
) from e
|
162 |
+
|
163 |
+
except httpx.HTTPStatusError as e:
|
164 |
+
if e.response is not None and CLIENT_ERROR_CODE <= e.response.status_code < SERVER_ERROR_CODE:
|
165 |
+
error_message = f"HTTP Error {e.response.status_code}: {e.response.text}"
|
166 |
+
logger.error(error_message)
|
167 |
raise UnretryableHumeError(
|
168 |
+
message=error_message,
|
169 |
original_exception=e,
|
170 |
) from e
|
171 |
+
error_message = f"HTTP Error {e.response.status_code if e.response else 'unknown'}"
|
172 |
+
logger.error(error_message)
|
173 |
raise HumeError(
|
174 |
+
message=error_message,
|
175 |
original_exception=e,
|
176 |
) from e
|
177 |
|
178 |
+
except Exception as e:
|
179 |
+
error_type = type(e).__name__
|
180 |
+
error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
|
181 |
+
logger.error("Error during Hume API call: %s - %s", error_type, error_message)
|
182 |
+
raise HumeError(
|
183 |
+
message=error_message,
|
184 |
+
original_exception=e,
|
185 |
+
) from e
|
186 |
|
187 |
def _parse_hume_tts_generation(generation: Dict[str, Any], config: Config) -> Tuple[str, str]:
|
188 |
"""
|
189 |
+
Parses a Hume TTS generation response and saves the decoded audio as an MP3 file.
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
Args:
|
192 |
+
generation (Dict[str, Any]): TTS generation response containing 'generation_id' and 'audio'.
|
193 |
+
config (Config): Application configuration for saving the audio file.
|
|
|
|
|
|
|
194 |
|
195 |
Returns:
|
196 |
+
Tuple[str, str]: (generation_id, audio_path)
|
|
|
|
|
197 |
|
198 |
Raises:
|
199 |
+
KeyError: If expected keys are missing.
|
200 |
+
Exception: Propagates exceptions from saving the audio file.
|
|
|
201 |
"""
|
202 |
generation_id = generation.get("generation_id")
|
203 |
if generation_id is None:
|
src/main.py
CHANGED
@@ -4,16 +4,26 @@ main.py
|
|
4 |
This module is the entry point for the app. It loads configuration and starts the Gradio app.
|
5 |
"""
|
6 |
|
|
|
|
|
|
|
7 |
# Local Application Imports
|
8 |
from src.app import App
|
9 |
from src.config import Config, logger
|
10 |
from src.database import init_db
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
config = Config.get()
|
14 |
logger.info("Launching TTS Arena Gradio app...")
|
15 |
db_session_maker = init_db(config)
|
16 |
app = App(config, db_session_maker)
|
17 |
demo = app.build_gradio_interface()
|
18 |
-
init_db(config)
|
19 |
demo.launch(server_name="0.0.0.0", allowed_paths=[str(config.audio_dir)])
|
|
|
|
|
|
|
|
|
|
4 |
This module is the entry point for the app. It loads configuration and starts the Gradio app.
|
5 |
"""
|
6 |
|
7 |
+
# Standard Library Imports
|
8 |
+
import asyncio
|
9 |
+
|
10 |
# Local Application Imports
|
11 |
from src.app import App
|
12 |
from src.config import Config, logger
|
13 |
from src.database import init_db
|
14 |
|
15 |
+
|
16 |
+
async def main():
|
17 |
+
"""
|
18 |
+
Asynchronous main function to initialize the application.
|
19 |
+
"""
|
20 |
config = Config.get()
|
21 |
logger.info("Launching TTS Arena Gradio app...")
|
22 |
db_session_maker = init_db(config)
|
23 |
app = App(config, db_session_maker)
|
24 |
demo = app.build_gradio_interface()
|
|
|
25 |
demo.launch(server_name="0.0.0.0", allowed_paths=[str(config.audio_dir)])
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
asyncio.run(main())
|
uv.lock
CHANGED
@@ -262,7 +262,6 @@ dependencies = [
|
|
262 |
{ name = "greenlet" },
|
263 |
{ name = "httpx" },
|
264 |
{ name = "python-dotenv" },
|
265 |
-
{ name = "requests" },
|
266 |
{ name = "sqlalchemy" },
|
267 |
{ name = "tenacity" },
|
268 |
]
|
@@ -287,7 +286,6 @@ requires-dist = [
|
|
287 |
{ name = "greenlet", specifier = ">=2.0.0" },
|
288 |
{ name = "httpx", specifier = ">=0.24.1" },
|
289 |
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
290 |
-
{ name = "requests", specifier = ">=2.32.3" },
|
291 |
{ name = "sqlalchemy", specifier = ">=2.0.0" },
|
292 |
{ name = "tenacity", specifier = ">=9.0.0" },
|
293 |
]
|
|
|
262 |
{ name = "greenlet" },
|
263 |
{ name = "httpx" },
|
264 |
{ name = "python-dotenv" },
|
|
|
265 |
{ name = "sqlalchemy" },
|
266 |
{ name = "tenacity" },
|
267 |
]
|
|
|
286 |
{ name = "greenlet", specifier = ">=2.0.0" },
|
287 |
{ name = "httpx", specifier = ">=0.24.1" },
|
288 |
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
|
|
289 |
{ name = "sqlalchemy", specifier = ">=2.0.0" },
|
290 |
{ name = "tenacity", specifier = ">=9.0.0" },
|
291 |
]
|