yaya-sy commited on
Commit
264bf64
Β·
verified Β·
1 Parent(s): f212225

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -21,7 +21,21 @@ def tts(text):
21
  audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
22
  api_name="/predict"
23
  )
24
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  except Exception as e:
26
  print(f"TTS API Error: {e}")
27
  return None
@@ -157,7 +171,7 @@ def model_inference(input_dict, history):
157
  padding=True,
158
  ).to("cuda")
159
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
160
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
161
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
162
  thread.start()
163
  buffer = ""
@@ -197,8 +211,15 @@ with gr.Blocks() as demo:
197
 
198
  # Generate audio after streaming is complete
199
  try:
200
- audio_path = tts(bot_message)
201
- yield "", chat_history, audio_path
 
 
 
 
 
 
 
202
  except Exception as e:
203
  print(f"TTS Error: {e}")
204
  yield "", chat_history, None
 
21
  audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
22
  api_name="/predict"
23
  )
24
+ print(f"TTS result: {result}") # Debug print to see what's returned
25
+
26
+ # Handle different possible return formats
27
+ if isinstance(result, tuple):
28
+ # If result is a tuple, the audio file might be in the first element
29
+ return result[0] if result else None
30
+ elif isinstance(result, str):
31
+ # If result is a string (file path)
32
+ return result
33
+ elif hasattr(result, 'name'):
34
+ # If result is a file object with a name attribute
35
+ return result.name
36
+ else:
37
+ # Try to return the result as-is
38
+ return result
39
  except Exception as e:
40
  print(f"TTS API Error: {e}")
41
  return None
 
171
  padding=True,
172
  ).to("cuda")
173
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
174
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128)
175
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
176
  thread.start()
177
  buffer = ""
 
211
 
212
  # Generate audio after streaming is complete
213
  try:
214
+ if bot_message.strip(): # Only generate TTS if there's actual text
215
+ audio_path = tts(bot_message)
216
+ if audio_path:
217
+ yield "", chat_history, audio_path
218
+ else:
219
+ print("TTS returned None or empty result")
220
+ yield "", chat_history, None
221
+ else:
222
+ yield "", chat_history, None
223
  except Exception as e:
224
  print(f"TTS Error: {e}")
225
  yield "", chat_history, None