Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,21 @@ def tts(text):
|
|
| 21 |
audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
|
| 22 |
api_name="/predict"
|
| 23 |
)
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
except Exception as e:
|
| 26 |
print(f"TTS API Error: {e}")
|
| 27 |
return None
|
|
@@ -157,7 +171,7 @@ def model_inference(input_dict, history):
|
|
| 157 |
padding=True,
|
| 158 |
).to("cuda")
|
| 159 |
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
|
| 160 |
-
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=
|
| 161 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 162 |
thread.start()
|
| 163 |
buffer = ""
|
|
@@ -197,8 +211,15 @@ with gr.Blocks() as demo:
|
|
| 197 |
|
| 198 |
# Generate audio after streaming is complete
|
| 199 |
try:
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
except Exception as e:
|
| 203 |
print(f"TTS Error: {e}")
|
| 204 |
yield "", chat_history, None
|
|
|
|
| 21 |
audio_reference=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
|
| 22 |
api_name="/predict"
|
| 23 |
)
|
| 24 |
+
print(f"TTS result: {result}") # Debug print to see what's returned
|
| 25 |
+
|
| 26 |
+
# Handle different possible return formats
|
| 27 |
+
if isinstance(result, tuple):
|
| 28 |
+
# If result is a tuple, the audio file might be in the first element
|
| 29 |
+
return result[0] if result else None
|
| 30 |
+
elif isinstance(result, str):
|
| 31 |
+
# If result is a string (file path)
|
| 32 |
+
return result
|
| 33 |
+
elif hasattr(result, 'name'):
|
| 34 |
+
# If result is a file object with a name attribute
|
| 35 |
+
return result.name
|
| 36 |
+
else:
|
| 37 |
+
# Try to return the result as-is
|
| 38 |
+
return result
|
| 39 |
except Exception as e:
|
| 40 |
print(f"TTS API Error: {e}")
|
| 41 |
return None
|
|
|
|
| 171 |
padding=True,
|
| 172 |
).to("cuda")
|
| 173 |
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
|
| 174 |
+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=128)
|
| 175 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 176 |
thread.start()
|
| 177 |
buffer = ""
|
|
|
|
| 211 |
|
| 212 |
# Generate audio after streaming is complete
|
| 213 |
try:
|
| 214 |
+
if bot_message.strip(): # Only generate TTS if there's actual text
|
| 215 |
+
audio_path = tts(bot_message)
|
| 216 |
+
if audio_path:
|
| 217 |
+
yield "", chat_history, audio_path
|
| 218 |
+
else:
|
| 219 |
+
print("TTS returned None or empty result")
|
| 220 |
+
yield "", chat_history, None
|
| 221 |
+
else:
|
| 222 |
+
yield "", chat_history, None
|
| 223 |
except Exception as e:
|
| 224 |
print(f"TTS Error: {e}")
|
| 225 |
yield "", chat_history, None
|