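"""PodcastGen 2: a Gradio app that turns a text prompt (or uploaded file) into a
two-speaker podcast, using a causal language model to draft the dialogue as JSON
and edge-tts to synthesize each line."""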
import gradio as gr
from pydub import AudioSegment
import json
import uuid
import edge_tts
import asyncio
import aiofiles
import os
import time
import mimetypes
import torch
import re
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
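
# Dependency sketch (inferred from the imports above, not an authoritative list):
#   gradio, pydub, edge-tts, aiofiles, torch, transformers
# pydub additionally expects an ffmpeg binary to be available for audio decoding/encoding.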
# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
MODEL_ID = "unsloth/gemma-3-1b-pt"

# Initialize model with proper error handling
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True
    ).eval()

    # Configure generation parameters
    generation_config = GenerationConfig(
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"Model loaded successfully on device: {model.device}")
except Exception as e:
    print(f"Model initialization error: {e}")
    model = None
    tokenizer = None
    generation_config = None
class PodcastGenerator:
    def __init__(self):
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config

    def extract_json_from_text(self, text: str) -> Dict:
        """Extract JSON from model output using regex patterns"""
        # Look for JSON-like structures in the raw model output
        json_patterns = [
            r'\{[^{}]*"topic"[^{}]*"podcast"[^{}]*\[.*?\]\s*\}',
            r'\{.*?"topic".*?"podcast".*?\[.*?\].*?\}',
        ]

        for pattern in json_patterns:
            matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
            for match in matches:
                try:
                    # Clean up the match before parsing
                    cleaned_match = match.strip()
                    return json.loads(cleaned_match)
                except json.JSONDecodeError:
                    continue

        # If no valid JSON was found, create a fallback structure
        return self.create_fallback_podcast(text)

    def create_fallback_podcast(self, text: str) -> Dict:
        """Create a basic podcast structure when JSON parsing fails"""
        # Extract meaningful sentences from the text
        sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10]
        if not sentences:
            sentences = ["Let's discuss this interesting topic.", "That's a great point to consider."]

        podcast_lines = []
        for i, sentence in enumerate(sentences[:10]):  # Limit to 10 exchanges
            speaker = (i % 2) + 1
            podcast_lines.append({
                "speaker": speaker,
                "line": sentence + "." if not sentence.endswith('.') else sentence
            })

        return {
            "topic": "Generated Discussion",
            "podcast": podcast_lines
        }
    async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
        if not self.model or not self.tokenizer:
            raise Exception("Model not properly initialized. Please check model loading.")

        example_json = {
            "topic": "AGI",
            "podcast": [
                {"speaker": 1, "line": "So, AGI, huh? Seems like everyone's talking about it these days."},
                {"speaker": 2, "line": "Yeah, it's definitely having a moment, isn't it?"},
                {"speaker": 1, "line": "It really is. What got you hooked on this topic?"},
                {"speaker": 2, "line": "The potential implications are fascinating and concerning at the same time."}
            ]
        }

        if language == "Auto Detect":
            language_instruction = "Use the same language as the input text"
        else:
            language_instruction = f"Generate the podcast in {language} language"

        # Simplified, direct prompt
        system_prompt = f"""Generate a podcast script as valid JSON. {language_instruction}.

Requirements:
- Exactly 2 speakers (speaker 1 and speaker 2)
- Natural, engaging conversation
- JSON format only

Example format:
{json.dumps(example_json, indent=2)}

Input topic: {prompt}

Generate JSON:"""

        try:
            if progress:
                progress(0.3, "Generating podcast script...")

            # Tokenize with proper attention mask
            inputs = self.tokenizer(
                system_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate the script
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    generation_config=self.generation_config,
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            # Decode only the newly generated tokens (skip the prompt)
            generated_text = self.tokenizer.decode(
                output[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"Generated text: {generated_text[:500]}...")

            if progress:
                progress(0.4, "Processing generated script...")

            # Extract JSON from the generated text
            result = self.extract_json_from_text(generated_text)

            if progress:
                progress(0.5, "Script generated successfully!")

            return result

        except Exception as e:
            print(f"Generation error: {e}")
            # Return a fallback podcast so the pipeline can still produce audio
            return {
                "topic": prompt or "Discussion",
                "podcast": [
                    {"speaker": 1, "line": f"Welcome to our discussion about {prompt or 'this topic'}."},
                    {"speaker": 2, "line": "Thanks for having me. This is indeed an interesting subject."},
                    {"speaker": 1, "line": "Let's dive into the key points and explore different perspectives."},
                    {"speaker": 2, "line": "Absolutely. There's a lot to unpack here."},
                    {"speaker": 1, "line": "What aspects do you find most compelling?"},
                    {"speaker": 2, "line": "The implications and potential applications are fascinating."},
                    {"speaker": 1, "line": "That's a great point. Thanks for the insightful discussion."},
                    {"speaker": 2, "line": "Thank you. This has been a valuable conversation."}
                ]
            }
    async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
        """Generate TTS audio with improved error handling"""
        voice = speaker1 if speaker == 1 else speaker2
        # edge-tts emits MP3 data by default; pydub/ffmpeg detects the format when combining
        temp_filename = f"temp_audio_{uuid.uuid4()}.mp3"

        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Create a fresh Communicate object per attempt; its stream can only be consumed once
                speech = edge_tts.Communicate(text, voice)
                await asyncio.wait_for(speech.save(temp_filename), timeout=30)

                if os.path.exists(temp_filename) and os.path.getsize(temp_filename) > 0:
                    return temp_filename
                raise Exception("Generated audio file is empty")

            except asyncio.TimeoutError:
                if os.path.exists(temp_filename):
                    os.remove(temp_filename)
                if attempt == max_retries - 1:
                    raise Exception("TTS generation timed out after multiple attempts")
                await asyncio.sleep(1)  # Brief delay before retry
            except Exception as e:
                if os.path.exists(temp_filename):
                    os.remove(temp_filename)
                if attempt == max_retries - 1:
                    raise Exception(f"TTS generation failed: {str(e)}")
                await asyncio.sleep(1)
    async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
        """Combine audio files with silence padding"""
        if progress:
            progress(0.9, "Combining audio files...")

        try:
            combined_audio = AudioSegment.empty()
            silence_padding = AudioSegment.silent(duration=500)  # 500 ms of silence

            for i, audio_file in enumerate(audio_files):
                try:
                    audio_segment = AudioSegment.from_file(audio_file)
                    combined_audio += audio_segment

                    # Add silence between speakers (except after the last file)
                    if i < len(audio_files) - 1:
                        combined_audio += silence_padding
                except Exception as e:
                    print(f"Warning: Could not process audio file {audio_file}: {e}")
                finally:
                    # Clean up the temporary file
                    if os.path.exists(audio_file):
                        os.remove(audio_file)

            if len(combined_audio) == 0:
                raise Exception("No audio content generated")

            output_filename = f"podcast_output_{uuid.uuid4()}.wav"
            combined_audio.export(output_filename, format="wav")

            if progress:
                progress(1.0, "Podcast generated successfully!")

            return output_filename

        except Exception as e:
            # Clean up any remaining temp files
            for audio_file in audio_files:
                if os.path.exists(audio_file):
                    os.remove(audio_file)
            raise Exception(f"Audio combination failed: {str(e)}")
    async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
        """Main podcast generation pipeline with improved error handling"""
        try:
            if progress:
                progress(0.1, "Starting podcast generation...")

            # Generate the script
            podcast_json = await self.generate_script(input_text, language, file_obj, progress)

            if not podcast_json.get('podcast'):
                raise Exception("No podcast content generated")

            if progress:
                progress(0.5, "Converting text to speech...")

            # Generate TTS sequentially to avoid overloading the service
            audio_files = []
            total_lines = len(podcast_json['podcast'])

            for i, item in enumerate(podcast_json['podcast']):
                try:
                    audio_file = await self.tts_generate(
                        item['line'],
                        item['speaker'],
                        speaker1,
                        speaker2
                    )
                    audio_files.append(audio_file)

                    # Update progress
                    if progress:
                        current_progress = 0.5 + (0.4 * (i + 1) / total_lines)
                        progress(current_progress, f"Generated speech {i + 1}/{total_lines}")
                except Exception as e:
                    print(f"TTS error for line {i}: {e}")
                    # Continue with the remaining lines
                    continue

            if not audio_files:
                raise Exception("No audio files generated successfully")

            # Combine the individual clips into the final podcast
            combined_audio = await self.combine_audio_files(audio_files, progress)
            return combined_audio

        except Exception as e:
            raise Exception(f"Podcast generation failed: {str(e)}")
# Voice mapping: display name -> edge-tts voice ID
VOICE_MAPPING = {
    "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
    "Ava - English (United States)": "en-US-AvaMultilingualNeural",
    "Brian - English (United States)": "en-US-BrianMultilingualNeural",
    "Emma - English (United States)": "en-US-EmmaMultilingualNeural",
    "Florian - German (Germany)": "de-DE-FlorianMultilingualNeural",
    "Seraphina - German (Germany)": "de-DE-SeraphinaMultilingualNeural",
    "Remy - French (France)": "fr-FR-RemyMultilingualNeural",
    "Vivienne - French (France)": "fr-FR-VivienneMultilingualNeural"
}
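# More edge-tts voices can be added to this mapping; `edge-tts --list-voices` prints the available IDs.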
async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, progress=None) -> str:
    """Process input and generate the podcast"""
    start_time = time.time()

    try:
        if progress:
            progress(0.05, "Processing input...")

        # Map speaker display names to voice IDs
        speaker1_voice = VOICE_MAPPING.get(speaker1, "en-US-AndrewMultilingualNeural")
        speaker2_voice = VOICE_MAPPING.get(speaker2, "en-US-AvaMultilingualNeural")

        # Validate input
        if not input_text or input_text.strip() == "":
            if input_file is None:
                raise Exception("Please provide either text input or upload a file")
            # TODO: Add file processing logic here if needed

        podcast_generator = PodcastGenerator()
        result = await podcast_generator.generate_podcast(
            input_text, language, speaker1_voice, speaker2_voice, input_file, progress
        )

        end_time = time.time()
        print(f"Total generation time: {end_time - start_time:.2f} seconds")
        return result

    except Exception as e:
        error_msg = str(e)
        print(f"Processing error: {error_msg}")
        raise Exception(f"Generation failed: {error_msg}")
def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2):
    """Gradio interface function with proper error handling"""
    try:
        # Validate inputs
        if not input_text and input_file is None:
            return None

        if input_text and len(input_text.strip()) == 0:
            input_text = None

        # Simple progress tracker that logs to stdout
        progress_history = []

        def progress_callback(value, text):
            progress_history.append(f"{value:.1%}: {text}")
            print(f"Progress: {value:.1%} - {text}")

        # Run the async pipeline on a dedicated event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(
                process_input(input_text, input_file, language, speaker1, speaker2, progress_callback)
            )
            return result
        finally:
            loop.close()

    except Exception as e:
        print(f"Gradio function error: {e}")
        raise gr.Error(f"Failed to generate podcast: {str(e)}")
def create_interface():
    """Create the Gradio interface with proper component configuration"""
    language_options = [
        "Auto Detect", "English", "German", "French", "Spanish", "Italian",
        "Portuguese", "Dutch", "Russian", "Chinese", "Japanese", "Korean"
    ]
    voice_options = list(VOICE_MAPPING.keys())

    with gr.Blocks(
        title="PodcastGen 2🎙️",
        theme=gr.themes.Soft(),
        css=".gradio-container {max-width: 1200px; margin: auto;}"
    ) as demo:
        gr.Markdown("# 🎙️ PodcastGen 2")
        gr.Markdown("Generate professional 2-speaker podcasts from text input!")

        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(
                    label="Input Text",
                    lines=8,
                    placeholder="Enter your topic or text for podcast generation...",
                    info="Describe what you want the podcast to discuss"
                )
            with gr.Column(scale=1):
                # gr.File has no `info` kwarg, so the size hint goes in the label
                input_file = gr.File(
                    label=f"Upload File (Optional, max {MAX_FILE_SIZE_MB}MB)",
                    file_types=[".pdf", ".txt"],
                    type="filepath"
                )
        with gr.Row():
            language = gr.Dropdown(
                label="Language",
                choices=language_options,
                value="Auto Detect",
                info="Select output language"
            )
            speaker1 = gr.Dropdown(
                label="Speaker 1 Voice",
                choices=voice_options,
                value="Andrew - English (United States)"
            )
            speaker2 = gr.Dropdown(
                label="Speaker 2 Voice",
                choices=voice_options,
                value="Ava - English (United States)"
            )

        generate_btn = gr.Button(
            "🎙️ Generate Podcast",
            variant="primary",
            size="lg"
        )

        output_audio = gr.Audio(
            label="Generated Podcast",
            type="filepath",
            format="wav",
            show_download_button=True
        )

        # Connect the interface
        generate_btn.click(
            fn=generate_podcast_gradio,
            inputs=[input_text, input_file, language, speaker1, speaker2],
            outputs=[output_audio],
            show_progress="full"
        )
        # Add usage instructions
        with gr.Accordion("Usage Instructions", open=False):
            gr.Markdown("""
            ### How to use:
            1. **Input**: Enter your topic or text in the text box, or upload a PDF/TXT file
            2. **Language**: Choose the output language (Auto Detect recommended)
            3. **Voices**: Select different voices for Speaker 1 and Speaker 2
            4. **Generate**: Click the button and wait for processing

            ### Tips:
            - Provide clear, specific topics for better results
            - The AI will create a natural conversation between two speakers
            - Generation may take 1-3 minutes depending on text length
            """)

    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )