zach committed on
Commit
104737f
·
1 Parent(s): 80026d8

Update API integration code to be async

Browse files
pyproject.toml CHANGED
@@ -12,7 +12,6 @@ dependencies = [
12
  "greenlet>=2.0.0",
13
  "httpx>=0.24.1",
14
  "python-dotenv>=1.0.1",
15
- "requests>=2.32.3",
16
  "sqlalchemy>=2.0.0",
17
  "tenacity>=9.0.0",
18
  ]
@@ -45,6 +44,7 @@ ignore = [
45
  "PLR0912",
46
  "PLR0913",
47
  "PLR2004",
 
48
  "TD002",
49
  "TD003",
50
  ]
 
12
  "greenlet>=2.0.0",
13
  "httpx>=0.24.1",
14
  "python-dotenv>=1.0.1",
 
15
  "sqlalchemy>=2.0.0",
16
  "tenacity>=9.0.0",
17
  ]
 
44
  "PLR0912",
45
  "PLR0913",
46
  "PLR2004",
47
+ "RUF006",
48
  "TD002",
49
  "TD003",
50
  ]
src/app.py CHANGED
@@ -10,9 +10,7 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
10
 
11
  # Standard Library Imports
12
  import asyncio
13
- import threading
14
  import time
15
- from concurrent.futures import ThreadPoolExecutor
16
  from typing import Tuple
17
 
18
  # Third-Party Library Imports
@@ -83,7 +81,7 @@ class App:
83
  logger.error(f"Unexpected error while generating text: {e}")
84
  raise gr.Error("Failed to generate text. Please try again later.")
85
 
86
- def _synthesize_speech(
87
  self,
88
  character_description: str,
89
  text: str,
@@ -130,38 +128,34 @@ class App:
130
  if provider_b == constants.HUME_AI:
131
  num_generations = 2
132
  # If generating 2 Hume outputs, do so in a single API call.
133
- result = text_to_speech_with_hume(character_description, text, num_generations, self.config)
134
  # Enforce that 4 values are returned.
135
  if not (isinstance(result, tuple) and len(result) == 4):
136
  raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
137
  generation_id_a, audio_a, generation_id_b, audio_b = result
138
  else:
139
- with ThreadPoolExecutor(max_workers=2) as executor:
140
- num_generations = 1
141
- # Generate a single Hume output.
142
- future_audio_a = executor.submit(
143
- text_to_speech_with_hume, character_description, text, num_generations, self.config
144
- )
145
- # Generate a second TTS output from the second provider.
146
- match provider_b:
147
- case constants.ELEVENLABS:
148
- future_audio_b = executor.submit(
149
- text_to_speech_with_elevenlabs, character_description, text, self.config
150
- )
151
- case _:
152
- # Additional TTS Providers can be added here.
153
- raise ValueError(f"Unsupported provider: {provider_b}")
154
-
155
- result_a = future_audio_a.result()
156
- result_b = future_audio_b.result()
157
- if isinstance(result_a, tuple) and len(result_a) >= 2:
158
- generation_id_a, audio_a = result_a[0], result_a[1]
159
- else:
160
- raise ValueError("Unexpected return from text_to_speech_with_hume")
161
- if isinstance(result_b, tuple) and len(result_b) >= 2:
162
- generation_id_b, audio_b = result_b[0], result_b[1] # type: ignore
163
- else:
164
- raise ValueError("Unexpected return from text_to_speech_with_elevenlabs")
165
 
166
  # Shuffle options so that placement of options in the UI will always be random.
167
  option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
@@ -190,47 +184,7 @@ class App:
190
  raise gr.Error("An unexpected error occurred. Please try again later.")
191
 
192
 
193
- def _background_submit_vote(
194
- self,
195
- option_map: OptionMap,
196
- selected_option: constants.OptionKey,
197
- text_modified: bool,
198
- character_description: str,
199
- text: str,
200
- ) -> None:
201
- """
202
- Runs the vote submission in a background thread.
203
- Creates a new event loop and runs the async submit_voting_results function in it.
204
-
205
- Args:
206
- Same as submit_voting_results
207
-
208
- Returns:
209
- None
210
- """
211
- try:
212
- # Create a new event loop for this thread
213
- loop = asyncio.new_event_loop()
214
- asyncio.set_event_loop(loop)
215
-
216
- # Run the async function in the new loop
217
- loop.run_until_complete(submit_voting_results(
218
- option_map,
219
- selected_option,
220
- text_modified,
221
- character_description,
222
- text,
223
- self.db_session_maker,
224
- self.config,
225
- ))
226
- except Exception as e:
227
- logger.error(f"Error in background vote submission thread: {e}", exc_info=True)
228
- finally:
229
- # Close the loop when done
230
- loop.close()
231
-
232
-
233
- def _vote(
234
  self,
235
  vote_submitted: bool,
236
  option_map: OptionMap,
@@ -261,19 +215,18 @@ class App:
261
  selected_provider = option_map[selected_option]["provider"]
262
  other_provider = option_map[other_option]["provider"]
263
 
264
- # Start a background thread for the database operation
265
- thread = threading.Thread(
266
- target=self._background_submit_vote,
267
- args=(
268
  option_map,
269
  selected_option,
270
  text_modified,
271
  character_description,
272
  text,
273
- ),
274
- daemon=True
 
275
  )
276
- thread.start()
277
 
278
  # Build button text, displaying the provider and voice name, appending the trophy emoji to the selected option.
279
  selected_label = f"{selected_provider} {constants.TROPHY_EMOJI}"
 
10
 
11
  # Standard Library Imports
12
  import asyncio
 
13
  import time
 
14
  from typing import Tuple
15
 
16
  # Third-Party Library Imports
 
81
  logger.error(f"Unexpected error while generating text: {e}")
82
  raise gr.Error("Failed to generate text. Please try again later.")
83
 
84
+ async def _synthesize_speech(
85
  self,
86
  character_description: str,
87
  text: str,
 
128
  if provider_b == constants.HUME_AI:
129
  num_generations = 2
130
  # If generating 2 Hume outputs, do so in a single API call.
131
+ result = await text_to_speech_with_hume(character_description, text, num_generations, self.config)
132
  # Enforce that 4 values are returned.
133
  if not (isinstance(result, tuple) and len(result) == 4):
134
  raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
135
  generation_id_a, audio_a, generation_id_b, audio_b = result
136
  else:
137
+ num_generations = 1
138
+ # Run both API calls concurrently using asyncio
139
+ tasks = []
140
+ # Generate a single Hume output
141
+ tasks.append(text_to_speech_with_hume(character_description, text, num_generations, self.config))
142
+
143
+ # Generate a second TTS output from the second provider
144
+ match provider_b:
145
+ case constants.ELEVENLABS:
146
+ tasks.append(text_to_speech_with_elevenlabs(character_description, text, self.config))
147
+ case _:
148
+ # Additional TTS Providers can be added here.
149
+ raise ValueError(f"Unsupported provider: {provider_b}")
150
+
151
+ # Await both tasks concurrently
152
+ result_a, result_b = await asyncio.gather(*tasks)
153
+
154
+ if not isinstance(result_a, tuple) or len(result_a) != 2:
155
+ raise ValueError("Expected 2 values from Hume TTS call when generating 1 output")
156
+
157
+ generation_id_a, audio_a = result_a[0], result_a[1]
158
+ generation_id_b, audio_b = result_b[0], result_b[1]
 
 
 
 
159
 
160
  # Shuffle options so that placement of options in the UI will always be random.
161
  option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
 
184
  raise gr.Error("An unexpected error occurred. Please try again later.")
185
 
186
 
187
+ async def _vote(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  self,
189
  vote_submitted: bool,
190
  option_map: OptionMap,
 
215
  selected_provider = option_map[selected_option]["provider"]
216
  other_provider = option_map[other_option]["provider"]
217
 
218
+ # Process vote in the background without blocking the UI
219
+ asyncio.create_task(
220
+ submit_voting_results(
 
221
  option_map,
222
  selected_option,
223
  text_modified,
224
  character_description,
225
  text,
226
+ self.db_session_maker,
227
+ self.config,
228
+ )
229
  )
 
230
 
231
  # Build button text, displaying the provider and voice name, appending the trophy emoji to the selected option.
232
  selected_label = f"{selected_provider} {constants.TROPHY_EMOJI}"
src/constants.py CHANGED
@@ -59,7 +59,7 @@ SAMPLE_CHARACTER_DESCRIPTIONS: dict = {
59
  "building tension through perfectly timed pauses and haunting inflections."
60
  ),
61
  "🌿 British Naturalist": (
62
- "A passionate nature documentarian with a voice that brings the wild to life—crisp, refined "
63
  "tones brimming with wonder and expertise. It shifts seamlessly from hushed observation to "
64
  "animated excitement, painting vivid pictures of the natural world's endless marvels."
65
  ),
 
59
  "building tension through perfectly timed pauses and haunting inflections."
60
  ),
61
  "🌿 British Naturalist": (
62
+ "A passionate, British nature documentarian with a voice that brings the wild to life—crisp, refined "
63
  "tones brimming with wonder and expertise. It shifts seamlessly from hushed observation to "
64
  "animated excitement, painting vivid pictures of the natural world's endless marvels."
65
  ),
src/database/database.py CHANGED
@@ -9,25 +9,22 @@ If no DATABASE_URL environment variable is set, then create a dummy database to
9
  """
10
 
11
  # Standard Library Imports
12
- from typing import Callable, Optional
13
 
14
  # Third-Party Library Imports
15
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
16
  from sqlalchemy.orm import DeclarativeBase
17
 
18
  # Local Application Imports
19
- from src.config import Config
20
 
21
 
22
- # Define the SQLAlchemy Base using SQLAlchemy 2.0 style.
23
  class Base(DeclarativeBase):
24
  pass
25
 
26
 
27
- engine: Optional[AsyncEngine] = None
28
-
29
-
30
- class AsyncDummySession:
31
  is_dummy = True # Flag to indicate this is a dummy session.
32
 
33
  async def __enter__(self):
@@ -42,11 +39,11 @@ class AsyncDummySession:
42
 
43
  async def commit(self):
44
  # Raise an exception to simulate failure when attempting a write.
45
- raise RuntimeError("DummySession does not support commit operations.")
46
 
47
  async def refresh(self, _instance):
48
  # Raise an exception to simulate failure when attempting to refresh.
49
- raise RuntimeError("DummySession does not support refresh operations.")
50
 
51
  async def rollback(self):
52
  # No-op: there's nothing to roll back.
@@ -57,8 +54,8 @@ class AsyncDummySession:
57
  pass
58
 
59
 
60
- # AsyncDBSessionMaker is either a async_sessionmaker instance or a callable that returns a AsyncDummySession.
61
- AsyncDBSessionMaker = async_sessionmaker | Callable[[], AsyncDummySession]
62
 
63
 
64
  def init_db(config: Config) -> AsyncDBSessionMaker:
@@ -88,21 +85,25 @@ def init_db(config: Config) -> AsyncDBSessionMaker:
88
  # In production, a valid DATABASE_URL is required.
89
  if not config.database_url:
90
  raise ValueError("DATABASE_URL must be set in production!")
 
91
  async_db_url = convert_to_async_url(config.database_url)
92
  engine = create_async_engine(async_db_url)
 
93
  return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
94
 
95
  # In development, if a DATABASE_URL is provided, use it.
96
  if config.database_url:
97
  async_db_url = convert_to_async_url(config.database_url)
98
  engine = create_async_engine(async_db_url)
 
99
  return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
100
 
101
- # No DATABASE_URL is provided; use a DummySession that does nothing.
102
  engine = None
 
103
 
104
- def async_dummy_session_factory() -> AsyncDummySession:
105
- return AsyncDummySession()
106
 
107
  return async_dummy_session_factory
108
 
 
9
  """
10
 
11
  # Standard Library Imports
12
+ from typing import Callable, Optional, TypeAlias, Union
13
 
14
  # Third-Party Library Imports
15
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
16
  from sqlalchemy.orm import DeclarativeBase
17
 
18
  # Local Application Imports
19
+ from src.config import Config, logger
20
 
21
 
22
+ # Define the SQLAlchemy Base
23
  class Base(DeclarativeBase):
24
  pass
25
 
26
 
27
+ class DummyAsyncSession:
 
 
 
28
  is_dummy = True # Flag to indicate this is a dummy session.
29
 
30
  async def __enter__(self):
 
39
 
40
  async def commit(self):
41
  # Raise an exception to simulate failure when attempting a write.
42
+ raise RuntimeError("DummyAsyncSession does not support commit operations.")
43
 
44
  async def refresh(self, _instance):
45
  # Raise an exception to simulate failure when attempting to refresh.
46
+ raise RuntimeError("DummyAsyncSession does not support refresh operations.")
47
 
48
  async def rollback(self):
49
  # No-op: there's nothing to roll back.
 
54
  pass
55
 
56
 
57
+ AsyncDBSessionMaker: TypeAlias = Union[async_sessionmaker[AsyncSession], Callable[[], DummyAsyncSession]]
58
+ engine: Optional[AsyncEngine] = None
59
 
60
 
61
  def init_db(config: Config) -> AsyncDBSessionMaker:
 
85
  # In production, a valid DATABASE_URL is required.
86
  if not config.database_url:
87
  raise ValueError("DATABASE_URL must be set in production!")
88
+
89
  async_db_url = convert_to_async_url(config.database_url)
90
  engine = create_async_engine(async_db_url)
91
+
92
  return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
93
 
94
  # In development, if a DATABASE_URL is provided, use it.
95
  if config.database_url:
96
  async_db_url = convert_to_async_url(config.database_url)
97
  engine = create_async_engine(async_db_url)
98
+
99
  return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)
100
 
101
+ # No DATABASE_URL is provided; use a DummyAsyncSession that does nothing.
102
  engine = None
103
+ logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession")
104
 
105
+ def async_dummy_session_factory() -> DummyAsyncSession:
106
+ return DummyAsyncSession()
107
 
108
  return async_dummy_session_factory
109
 
src/integrations/anthropic_api.py CHANGED
@@ -84,7 +84,8 @@ class AnthropicConfig:
84
  from anthropic import AsyncAnthropic # Import the async client from Anthropic SDK
85
  return AsyncAnthropic(api_key=self.api_key)
86
 
87
- def build_expressive_prompt(self, character_description: str) -> str:
 
88
  """
89
  Constructs and returns a prompt based solely on the provided character description.
90
  The returned prompt is intended to instruct Claude to generate expressive text from a character,
@@ -120,6 +121,8 @@ class UnretryableAnthropicError(AnthropicError):
120
 
121
  def __init__(self, message: str, original_exception: Optional[Exception] = None) -> None:
122
  super().__init__(message, original_exception)
 
 
123
 
124
 
125
  @retry(
 
84
  from anthropic import AsyncAnthropic # Import the async client from Anthropic SDK
85
  return AsyncAnthropic(api_key=self.api_key)
86
 
87
+ @staticmethod
88
+ def build_expressive_prompt(character_description: str) -> str:
89
  """
90
  Constructs and returns a prompt based solely on the provided character description.
91
  The returned prompt is intended to instruct Claude to generate expressive text from a character,
 
121
 
122
  def __init__(self, message: str, original_exception: Optional[Exception] = None) -> None:
123
  super().__init__(message, original_exception)
124
+ self.original_exception = original_exception
125
+ self.message = message
126
 
127
 
128
  @retry(
src/integrations/elevenlabs_api.py CHANGED
@@ -10,13 +10,6 @@ Key Features:
10
  - Handles received audio and processes it for playback on the web.
11
  - Provides detailed logging for debugging and error tracking.
12
  - Utilizes robust error handling (EAFP) to validate API responses.
13
-
14
- Classes:
15
- - ElevenLabsConfig: Immutable configuration for interacting with ElevenLabs' TTS API.
16
- - ElevenLabsError: Custom exception for ElevenLabs API-related errors.
17
-
18
- Functions:
19
- - text_to_speech_with_elevenlabs: Synthesizes speech from text using ElevenLabs' TTS API.
20
  """
21
 
22
  # Standard Library Imports
@@ -26,9 +19,9 @@ from dataclasses import dataclass, field
26
  from typing import Optional, Tuple
27
 
28
  # Third-Party Library Imports
29
- from elevenlabs import ElevenLabs, TextToVoiceCreatePreviewsRequestOutputFormat
30
  from elevenlabs.core import ApiError
31
- from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
32
 
33
  # Local Application Imports
34
  from src.config import Config, logger
@@ -48,19 +41,18 @@ class ElevenLabsConfig:
48
  if not self.output_format:
49
  raise ValueError("ElevenLabs TTS API output format is not set.")
50
 
51
- # Compute the API key from the environment.
52
  computed_key = validate_env_var("ELEVENLABS_API_KEY")
53
  object.__setattr__(self, "api_key", computed_key)
54
 
55
  @property
56
- def client(self) -> ElevenLabs:
57
  """
58
- Lazy initialization of the ElevenLabs client.
59
 
60
  Returns:
61
- ElevenLabs: Configured client instance.
62
  """
63
- return ElevenLabs(api_key=self.api_key)
64
 
65
 
66
  class ElevenLabsError(Exception):
@@ -77,42 +69,43 @@ class UnretryableElevenLabsError(ElevenLabsError):
77
 
78
  def __init__(self, message: str, original_exception: Optional[Exception] = None):
79
  super().__init__(message, original_exception)
 
 
80
 
81
 
82
  @retry(
 
83
  stop=stop_after_attempt(3),
84
  wait=wait_fixed(2),
85
  before=before_log(logger, logging.DEBUG),
86
  after=after_log(logger, logging.DEBUG),
87
  reraise=True,
88
  )
89
- def text_to_speech_with_elevenlabs(
90
  character_description: str, text: str, config: Config
91
  ) -> Tuple[None, str]:
92
  """
93
- Synthesizes text to speech using the ElevenLabs TTS API, processes the audio data, and writes it to a file.
94
 
95
  Args:
96
- character_description (str): The character description used as the voice description.
97
  text (str): The text to be synthesized into speech.
 
98
 
99
  Returns:
100
  Tuple[None, str]: A tuple containing:
101
- - generation_id (None): We do not record the generation ID for ElevenLabs, but return None for uniformity
102
- across TTS integrations.
103
- - file_path (str): The relative file path to the audio file where the synthesized speech was saved.
104
 
105
  Raises:
106
  ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
107
  """
108
-
109
  logger.debug(f"Synthesizing speech with ElevenLabs. Text length: {len(text)} characters.")
110
-
111
  elevenlabs_config = config.elevenlabs_config
112
 
113
  try:
114
  # Synthesize speech using the ElevenLabs SDK
115
- response = elevenlabs_config.client.text_to_voice.create_previews(
116
  voice_description=character_description,
117
  text=text,
118
  output_format=elevenlabs_config.output_format,
@@ -129,9 +122,10 @@ def text_to_speech_with_elevenlabs(
129
  generated_voice_id = preview.generated_voice_id
130
  base64_audio = preview.audio_base_64
131
  filename = f"{generated_voice_id}.mp3"
132
- audio_file_path = save_base64_audio_to_file(base64_audio, filename, config)
133
 
134
  # Write audio to file and return the relative path
 
 
135
  return None, audio_file_path
136
 
137
  except Exception as e:
 
10
  - Handles received audio and processes it for playback on the web.
11
  - Provides detailed logging for debugging and error tracking.
12
  - Utilizes robust error handling (EAFP) to validate API responses.
 
 
 
 
 
 
 
13
  """
14
 
15
  # Standard Library Imports
 
19
  from typing import Optional, Tuple
20
 
21
  # Third-Party Library Imports
22
+ from elevenlabs import AsyncElevenLabs, TextToVoiceCreatePreviewsRequestOutputFormat
23
  from elevenlabs.core import ApiError
24
+ from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
25
 
26
  # Local Application Imports
27
  from src.config import Config, logger
 
41
  if not self.output_format:
42
  raise ValueError("ElevenLabs TTS API output format is not set.")
43
 
 
44
  computed_key = validate_env_var("ELEVENLABS_API_KEY")
45
  object.__setattr__(self, "api_key", computed_key)
46
 
47
  @property
48
+ def client(self) -> AsyncElevenLabs:
49
  """
50
+ Lazy initialization of the asynchronous ElevenLabs client.
51
 
52
  Returns:
53
+ AsyncElevenLabs: Configured async client instance.
54
  """
55
+ return AsyncElevenLabs(api_key=self.api_key)
56
 
57
 
58
  class ElevenLabsError(Exception):
 
69
 
70
  def __init__(self, message: str, original_exception: Optional[Exception] = None):
71
  super().__init__(message, original_exception)
72
+ self.original_exception = original_exception
73
+ self.message = message
74
 
75
 
76
  @retry(
77
+ retry=retry_if_exception(lambda e: not isinstance(e, UnretryableElevenLabsError)),
78
  stop=stop_after_attempt(3),
79
  wait=wait_fixed(2),
80
  before=before_log(logger, logging.DEBUG),
81
  after=after_log(logger, logging.DEBUG),
82
  reraise=True,
83
  )
84
+ async def text_to_speech_with_elevenlabs(
85
  character_description: str, text: str, config: Config
86
  ) -> Tuple[None, str]:
87
  """
88
+ Asynchronously synthesizes speech using the ElevenLabs TTS API, processes the audio data, and writes it to a file.
89
 
90
  Args:
91
+ character_description (str): The character description used for voice synthesis.
92
  text (str): The text to be synthesized into speech.
93
+ config (Config): Application configuration containing ElevenLabs API settings.
94
 
95
  Returns:
96
  Tuple[None, str]: A tuple containing:
97
+ - generation_id (None): A placeholder (no generation ID is returned).
98
+ - file_path (str): The relative file path to the saved audio file.
 
99
 
100
  Raises:
101
  ElevenLabsError: If there is an error communicating with the ElevenLabs API or processing the response.
102
  """
 
103
  logger.debug(f"Synthesizing speech with ElevenLabs. Text length: {len(text)} characters.")
 
104
  elevenlabs_config = config.elevenlabs_config
105
 
106
  try:
107
  # Synthesize speech using the ElevenLabs SDK
108
+ response = await elevenlabs_config.client.text_to_voice.create_previews(
109
  voice_description=character_description,
110
  text=text,
111
  output_format=elevenlabs_config.output_format,
 
122
  generated_voice_id = preview.generated_voice_id
123
  base64_audio = preview.audio_base_64
124
  filename = f"{generated_voice_id}.mp3"
 
125
 
126
  # Write audio to file and return the relative path
127
+ audio_file_path = save_base64_audio_to_file(base64_audio, filename, config)
128
+
129
  return None, audio_file_path
130
 
131
  except Exception as e:
src/integrations/hume_api.py CHANGED
@@ -9,13 +9,6 @@ Key Features:
9
  - Implements retry logic for handling transient API errors.
10
  - Handles received audio and processes it for playback on the web.
11
  - Provides detailed logging for debugging and error tracking.
12
-
13
- Classes:
14
- - HumeConfig: Immutable configuration for interacting with Hume's TTS API.
15
- - HumeError: Custom exception for Hume API-related errors.
16
-
17
- Functions:
18
- - text_to_speech_with_hume: Synthesizes speech from text using Hume's TTS API.
19
  """
20
 
21
  # Standard Library Imports
@@ -24,9 +17,8 @@ from dataclasses import dataclass, field
24
  from typing import Any, Dict, Literal, Tuple, Union
25
 
26
  # Third-Party Library Imports
27
- import requests
28
- from requests.exceptions import HTTPError
29
- from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
30
 
31
  # Local Application Imports
32
  from src.config import Config, logger
@@ -41,12 +33,9 @@ HumeSupportedFileFormat = Literal["mp3", "pcm", "wav"]
41
  class HumeConfig:
42
  """Immutable configuration for interacting with the Hume TTS API."""
43
 
44
- # Computed fields.
45
  api_key: str = field(init=False)
46
  headers: Dict[str, str] = field(init=False)
47
-
48
- # Provided fields.
49
- url: str = "https://test-api.hume.ai/v0/tts/octave"
50
  file_format: HumeSupportedFileFormat = "mp3"
51
 
52
  def __post_init__(self) -> None:
@@ -56,11 +45,8 @@ class HumeConfig:
56
  if not self.file_format:
57
  raise ValueError("Hume TTS file format is not set.")
58
 
59
- # Compute the API key from the environment.
60
  computed_api_key = validate_env_var("HUME_API_KEY")
61
  object.__setattr__(self, "api_key", computed_api_key)
62
-
63
- # Compute the headers.
64
  computed_headers = {
65
  "X-Hume-Api-Key": f"{computed_api_key}",
66
  "Content-Type": "application/json",
@@ -83,38 +69,36 @@ class UnretryableHumeError(HumeError):
83
  def __init__(self, message: str, original_exception: Union[Exception, None] = None):
84
  super().__init__(message, original_exception)
85
  self.original_exception = original_exception
 
86
 
87
 
88
  @retry(
 
89
  stop=stop_after_attempt(3),
90
  wait=wait_fixed(2),
91
  before=before_log(logger, logging.DEBUG),
92
  after=after_log(logger, logging.DEBUG),
93
  reraise=True,
94
  )
95
- def text_to_speech_with_hume(
96
  character_description: str,
97
  text: str,
98
  num_generations: int,
99
  config: Config,
100
  ) -> Union[Tuple[str, str], Tuple[str, str, str, str]]:
101
  """
102
- Synthesizes text to speech using the Hume TTS API, processes audio data, and writes audio to a file.
103
 
104
- This function sends a POST request to the Hume TTS API with a character description and text
105
- to be converted to speech. Depending on the specified number of generations (allowed values: 1 or 2),
106
- the API returns one or two generations. For each generation, the function extracts the base64-encoded
107
- audio and the generation ID, saves the audio as an MP3 file via the `save_base64_audio_to_file` helper,
108
- and returns the relevant details.
109
 
110
  Args:
111
- character_description (str): A description of the character, which is used as contextual input
112
- for generating the voice.
113
- text (str): The text to be converted to speech.
114
- num_generations (int): The number of audio generations to request from the API.
115
- Allowed values are 1 or 2. If 1, only a single generation is processed; if 2, a second
116
- generation is expected in the API response.
117
- config (Config): The application configuration containing Hume API settings.
118
 
119
  Returns:
120
  Union[Tuple[str, str], Tuple[str, str, str, str]]:
@@ -123,15 +107,13 @@ def text_to_speech_with_hume(
123
 
124
  Raises:
125
  ValueError: If num_generations is not 1 or 2.
126
- HumeError: If there is an error communicating with the Hume TTS API or parsing its response.
127
- UnretryableHumeError: If a client-side HTTP error (status code in the 4xx range) is encountered.
128
- Exception: Any other exceptions raised during the request or processing will be wrapped and
129
- re-raised as HumeError.
130
  """
131
-
132
  logger.debug(
133
- f"Processing TTS with Hume. Prompt length: {len(character_description)} characters. "
134
- f"Text length: {len(text)} characters."
 
135
  )
136
 
137
  if num_generations < 1 or num_generations > 2:
@@ -145,14 +127,15 @@ def text_to_speech_with_hume(
145
  }
146
 
147
  try:
148
- # Synthesize speech using the Hume TTS API
149
- response = requests.post(
150
- url=hume_config.url,
151
- headers=hume_config.headers,
152
- json=request_body,
153
- )
154
- response.raise_for_status()
155
- response_data = response.json()
 
156
 
157
  generations = response_data.get("generations")
158
  if not generations:
@@ -160,7 +143,6 @@ def text_to_speech_with_hume(
160
  logger.error(msg)
161
  raise HumeError(msg)
162
 
163
- # Extract the base64 encoded audio and generation ID from the generation.
164
  generation_a = generations[0]
165
  generation_a_id, audio_a_path = _parse_hume_tts_generation(generation_a, config)
166
 
@@ -171,48 +153,51 @@ def text_to_speech_with_hume(
171
  generation_b_id, audio_b_path = _parse_hume_tts_generation(generation_b, config)
172
  return (generation_a_id, audio_a_path, generation_b_id, audio_b_path)
173
 
174
- except Exception as e:
175
- if (
176
- isinstance(e, HTTPError)
177
- and e.response is not None
178
- and CLIENT_ERROR_CODE <= e.response.status_code < SERVER_ERROR_CODE
179
- ):
 
 
 
 
 
180
  raise UnretryableHumeError(
181
- message=f"{e.response.text}",
182
  original_exception=e,
183
  ) from e
184
-
 
185
  raise HumeError(
186
- message=f"{e}",
187
  original_exception=e,
188
  ) from e
189
 
 
 
 
 
 
 
 
 
190
 
191
  def _parse_hume_tts_generation(generation: Dict[str, Any], config: Config) -> Tuple[str, str]:
192
  """
193
- Parse a Hume TTS generation response and save the decoded audio as an MP3 file.
194
-
195
- This function extracts the generation ID and the base64-encoded audio from the provided
196
- dictionary. It then decodes and saves the audio data to an MP3 file, naming the file using
197
- the generation ID. Finally, it returns a tuple containing the generation ID and the file path
198
- of the saved audio.
199
 
200
  Args:
201
- generation (Dict[str, Any]): A dictionary representing the TTS generation response from Hume.
202
- Expected keys are:
203
- - "generation_id" (str): A unique identifier for the generated audio.
204
- - "audio" (str): A base64 encoded string of the audio data.
205
- config (Config): The application configuration used for saving the audio file.
206
 
207
  Returns:
208
- Tuple[str, str]: A tuple containing:
209
- - generation_id (str): The unique identifier for the audio generation.
210
- - audio_path (str): The filesystem path where the audio file was saved.
211
 
212
  Raises:
213
- KeyError: If the "generation_id" or "audio" key is missing from the generation dictionary.
214
- Exception: Propagates any exceptions raised by save_base64_audio_to_file, such as errors during
215
- the decoding or file saving process.
216
  """
217
  generation_id = generation.get("generation_id")
218
  if generation_id is None:
 
9
  - Implements retry logic for handling transient API errors.
10
  - Handles received audio and processes it for playback on the web.
11
  - Provides detailed logging for debugging and error tracking.
 
 
 
 
 
 
 
12
  """
13
 
14
  # Standard Library Imports
 
17
  from typing import Any, Dict, Literal, Tuple, Union
18
 
19
  # Third-Party Library Imports
20
+ import httpx
21
+ from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
 
22
 
23
  # Local Application Imports
24
  from src.config import Config, logger
 
33
  class HumeConfig:
34
  """Immutable configuration for interacting with the Hume TTS API."""
35
 
 
36
  api_key: str = field(init=False)
37
  headers: Dict[str, str] = field(init=False)
38
+ url: str = "https://api.hume.ai/v0/tts/octave"
 
 
39
  file_format: HumeSupportedFileFormat = "mp3"
40
 
41
  def __post_init__(self) -> None:
 
45
  if not self.file_format:
46
  raise ValueError("Hume TTS file format is not set.")
47
 
 
48
  computed_api_key = validate_env_var("HUME_API_KEY")
49
  object.__setattr__(self, "api_key", computed_api_key)
 
 
50
  computed_headers = {
51
  "X-Hume-Api-Key": f"{computed_api_key}",
52
  "Content-Type": "application/json",
 
69
  def __init__(self, message: str, original_exception: Union[Exception, None] = None):
70
  super().__init__(message, original_exception)
71
  self.original_exception = original_exception
72
+ self.message = message
73
 
74
 
75
  @retry(
76
+ retry=retry_if_exception(lambda e: not isinstance(e, UnretryableHumeError)),
77
  stop=stop_after_attempt(3),
78
  wait=wait_fixed(2),
79
  before=before_log(logger, logging.DEBUG),
80
  after=after_log(logger, logging.DEBUG),
81
  reraise=True,
82
  )
83
+ async def text_to_speech_with_hume(
84
  character_description: str,
85
  text: str,
86
  num_generations: int,
87
  config: Config,
88
  ) -> Union[Tuple[str, str], Tuple[str, str, str, str]]:
89
  """
90
+ Asynchronously synthesizes speech using the Hume TTS API, processes audio data, and writes audio to a file.
91
 
92
+ This function sends a POST request to the Hume TTS API with a character description and text to be converted to
93
+ speech. Depending on the specified number of generations (1 or 2), the API returns one or two generations.
94
+ For each generation, the function extracts the base64-encoded audio and generation ID, saves the audio as an MP3
95
+ file, and returns the relevant details.
 
96
 
97
  Args:
98
+ character_description (str): Description used for voice synthesis.
99
+ text (str): Text to be converted to speech.
100
+ num_generations (int): Number of audio generations to request (1 or 2).
101
+ config (Config): Application configuration containing Hume API settings.
 
 
 
102
 
103
  Returns:
104
  Union[Tuple[str, str], Tuple[str, str, str, str]]:
 
107
 
108
  Raises:
109
  ValueError: If num_generations is not 1 or 2.
110
+ HumeError: For errors communicating with the Hume API.
111
+ UnretryableHumeError: For client-side HTTP errors (status code 4xx).
 
 
112
  """
 
113
  logger.debug(
114
+ "Processing TTS with Hume. "
115
+ f"Character description length: {len(character_description)}. "
116
+ f"Text length: {len(text)}."
117
  )
118
 
119
  if num_generations < 1 or num_generations > 2:
 
127
  }
128
 
129
  try:
130
+ async with httpx.AsyncClient() as client:
131
+ response = await client.post(
132
+ url=hume_config.url,
133
+ headers=hume_config.headers,
134
+ json=request_body,
135
+ timeout=30.0,
136
+ )
137
+ response.raise_for_status()
138
+ response_data = response.json()
139
 
140
  generations = response_data.get("generations")
141
  if not generations:
 
143
  logger.error(msg)
144
  raise HumeError(msg)
145
 
 
146
  generation_a = generations[0]
147
  generation_a_id, audio_a_path = _parse_hume_tts_generation(generation_a, config)
148
 
 
153
  generation_b_id, audio_b_path = _parse_hume_tts_generation(generation_b, config)
154
  return (generation_a_id, audio_a_path, generation_b_id, audio_b_path)
155
 
156
+ except httpx.ReadTimeout as e:
157
+ # Handle timeout specifically
158
+ raise HumeError(
159
+ message="Request to Hume API timed out. Please try again later.",
160
+ original_exception=e,
161
+ ) from e
162
+
163
+ except httpx.HTTPStatusError as e:
164
+ if e.response is not None and CLIENT_ERROR_CODE <= e.response.status_code < SERVER_ERROR_CODE:
165
+ error_message = f"HTTP Error {e.response.status_code}: {e.response.text}"
166
+ logger.error(error_message)
167
  raise UnretryableHumeError(
168
+ message=error_message,
169
  original_exception=e,
170
  ) from e
171
+ error_message = f"HTTP Error {e.response.status_code if e.response else 'unknown'}"
172
+ logger.error(error_message)
173
  raise HumeError(
174
+ message=error_message,
175
  original_exception=e,
176
  ) from e
177
 
178
+ except Exception as e:
179
+ error_type = type(e).__name__
180
+ error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
181
+ logger.error("Error during Hume API call: %s - %s", error_type, error_message)
182
+ raise HumeError(
183
+ message=error_message,
184
+ original_exception=e,
185
+ ) from e
186
 
187
  def _parse_hume_tts_generation(generation: Dict[str, Any], config: Config) -> Tuple[str, str]:
188
  """
189
+ Parses a Hume TTS generation response and saves the decoded audio as an MP3 file.
 
 
 
 
 
190
 
191
  Args:
192
+ generation (Dict[str, Any]): TTS generation response containing 'generation_id' and 'audio'.
193
+ config (Config): Application configuration for saving the audio file.
 
 
 
194
 
195
  Returns:
196
+ Tuple[str, str]: (generation_id, audio_path)
 
 
197
 
198
  Raises:
199
+ KeyError: If expected keys are missing.
200
+ Exception: Propagates exceptions from saving the audio file.
 
201
  """
202
  generation_id = generation.get("generation_id")
203
  if generation_id is None:
src/main.py CHANGED
@@ -4,16 +4,26 @@ main.py
4
  This module is the entry point for the app. It loads configuration and starts the Gradio app.
5
  """
6
 
 
 
 
7
  # Local Application Imports
8
  from src.app import App
9
  from src.config import Config, logger
10
  from src.database import init_db
11
 
12
- if __name__ == "__main__":
 
 
 
 
13
  config = Config.get()
14
  logger.info("Launching TTS Arena Gradio app...")
15
  db_session_maker = init_db(config)
16
  app = App(config, db_session_maker)
17
  demo = app.build_gradio_interface()
18
- init_db(config)
19
  demo.launch(server_name="0.0.0.0", allowed_paths=[str(config.audio_dir)])
 
 
 
 
 
4
  This module is the entry point for the app. It loads configuration and starts the Gradio app.
5
  """
6
 
7
+ # Standard Library Imports
8
+ import asyncio
9
+
10
  # Local Application Imports
11
  from src.app import App
12
  from src.config import Config, logger
13
  from src.database import init_db
14
 
15
+
16
+ async def main():
17
+ """
18
+ Asynchronous main function to initialize the application.
19
+ """
20
  config = Config.get()
21
  logger.info("Launching TTS Arena Gradio app...")
22
  db_session_maker = init_db(config)
23
  app = App(config, db_session_maker)
24
  demo = app.build_gradio_interface()
 
25
  demo.launch(server_name="0.0.0.0", allowed_paths=[str(config.audio_dir)])
26
+
27
+
28
+ if __name__ == "__main__":
29
+ asyncio.run(main())
uv.lock CHANGED
@@ -262,7 +262,6 @@ dependencies = [
262
  { name = "greenlet" },
263
  { name = "httpx" },
264
  { name = "python-dotenv" },
265
- { name = "requests" },
266
  { name = "sqlalchemy" },
267
  { name = "tenacity" },
268
  ]
@@ -287,7 +286,6 @@ requires-dist = [
287
  { name = "greenlet", specifier = ">=2.0.0" },
288
  { name = "httpx", specifier = ">=0.24.1" },
289
  { name = "python-dotenv", specifier = ">=1.0.1" },
290
- { name = "requests", specifier = ">=2.32.3" },
291
  { name = "sqlalchemy", specifier = ">=2.0.0" },
292
  { name = "tenacity", specifier = ">=9.0.0" },
293
  ]
 
262
  { name = "greenlet" },
263
  { name = "httpx" },
264
  { name = "python-dotenv" },
 
265
  { name = "sqlalchemy" },
266
  { name = "tenacity" },
267
  ]
 
286
  { name = "greenlet", specifier = ">=2.0.0" },
287
  { name = "httpx", specifier = ">=0.24.1" },
288
  { name = "python-dotenv", specifier = ">=1.0.1" },
 
289
  { name = "sqlalchemy", specifier = ">=2.0.0" },
290
  { name = "tenacity", specifier = ">=9.0.0" },
291
  ]