# podcastgen / app.py
# Last commit: 52f0f4a "Update app.py" by Rausda6 (verified); file size 18.5 kB
import gradio as gr
from pydub import AudioSegment
import json
import uuid
import edge_tts
import asyncio
import aiofiles
import os
import time
import mimetypes
import torch
import re
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Constants
# Upload size limit (also shown in the file-upload widget's help text).
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
# Hugging Face Hub id of the language model that writes the podcast script.
MODEL_ID = "unsloth/gemma-3-1b-pt"
# Initialize model with proper error handling. On any failure the three
# globals tokenizer / model / generation_config are set to None, and
# PodcastGenerator.generate_script then refuses to run.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        # generate() needs a pad id; reuse EOS when the tokenizer has none.
        tokenizer.pad_token = tokenizer.eos_token
    # Gemma 3 is reported to produce inf/NaN activations in float16, so use
    # bfloat16 on GPUs that support it and float32 everywhere else.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        _model_dtype = torch.bfloat16
    else:
        _model_dtype = torch.float32
    # trust_remote_code is deliberately not passed. Gemma 3 is implemented
    # natively in recent transformers releases, so no code from the model
    # repository needs to be downloaded and executed.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_model_dtype,
        device_map="auto",
    ).eval()
    # Configure generation parameters (sampled decoding, up to 1024 new tokens).
    generation_config = GenerationConfig(
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print(f"Model loaded successfully on device: {model.device}")
except Exception as e:
    print(f"Model initialization error: {e}")
    model = None
    tokenizer = None
    generation_config = None
class PodcastGenerator:
    """Generate a two-speaker podcast: LLM script -> edge-tts speech -> one WAV.

    The model, tokenizer and generation config are the module-level objects
    created at import time. They are None when model loading failed.
    """

    # Token budget for the user's text inside the prompt. This leaves headroom
    # for the fixed instructions and JSON example within _MAX_INPUT_TOKENS.
    _MAX_TOPIC_TOKENS = 1536
    # Hard cap on the tokenized prompt passed to the model.
    _MAX_INPUT_TOKENS = 2048
    # Attempts and per-attempt timeout (seconds) for one edge-tts request.
    _TTS_MAX_RETRIES = 3
    _TTS_TIMEOUT_S = 30

    def __init__(self):
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config

    # ------------------------------------------------------------------ #
    # Script parsing
    # ------------------------------------------------------------------ #
    def extract_json_from_text(self, text: str) -> Dict:
        """Return the first usable podcast script embedded in *text*.

        Every ``{`` in *text* is tried as the start of a JSON value with
        ``json.JSONDecoder.raw_decode``. This respects nesting and string
        quoting, so brackets or braces inside dialogue lines cannot cut the
        object short. The first decoded object that yields at least one valid
        line (see ``_normalize_podcast``) is returned. If there is none, a
        fallback script is built from the raw text.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            normalized = self._normalize_podcast(candidate)
            if normalized is not None:
                return normalized
            start = text.find("{", start + 1)
        # If no valid JSON found, create a fallback structure
        return self.create_fallback_podcast(text)

    @staticmethod
    def _speaker_id(value) -> int:
        """Map a model-supplied speaker value to 1 or 2.

        Anything that is not integer-like 1 (e.g. 2, "2", "Host") becomes 2.
        This matches ``tts_generate``, which voices every non-1 speaker with
        the second voice.
        """
        try:
            return 1 if int(value) == 1 else 2
        except (TypeError, ValueError):
            return 2

    @classmethod
    def _normalize_podcast(cls, candidate):
        """Validate a decoded JSON value as a podcast script.

        Returns ``{"topic": str, "podcast": [{"speaker": 1|2, "line": str}]}``,
        or None when *candidate* is not a dict with a ``podcast`` list holding
        at least one entry with non-empty string text. Entries without text
        are dropped. A missing or blank topic becomes "Generated Discussion".
        """
        if not isinstance(candidate, dict):
            return None
        raw_lines = candidate.get("podcast")
        if not isinstance(raw_lines, list):
            return None
        lines = []
        for entry in raw_lines:
            if not isinstance(entry, dict):
                continue
            line_text = entry.get("line")
            if not isinstance(line_text, str) or not line_text.strip():
                continue
            lines.append({"speaker": cls._speaker_id(entry.get("speaker")),
                          "line": line_text.strip()})
        if not lines:
            return None
        topic = candidate.get("topic")
        if not isinstance(topic, str) or not topic.strip():
            topic = "Generated Discussion"
        return {"topic": topic, "podcast": lines}

    def create_fallback_podcast(self, text: str) -> Dict:
        """Build a script from *text* when no usable JSON was found.

        The text is split on periods. Fragments longer than 10 characters
        become up to 10 lines alternating between speaker 1 and 2. If nothing
        qualifies, two generic lines are used.
        """
        sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10]
        if not sentences:
            sentences = ["Let's discuss this interesting topic.", "That's a great point to consider."]
        podcast_lines = []
        for i, sentence in enumerate(sentences[:10]):  # Limit to 10 exchanges
            # The split removed the periods. Restore a terminator unless the
            # fragment already ends with one (e.g. "!" or "?").
            if not sentence.endswith(('.', '!', '?')):
                sentence += "."
            podcast_lines.append({"speaker": (i % 2) + 1, "line": sentence})
        return {
            "topic": "Generated Discussion",
            "podcast": podcast_lines
        }

    # ------------------------------------------------------------------ #
    # Script generation
    # ------------------------------------------------------------------ #
    async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
        """Generate a podcast script for *prompt* with the language model.

        Args:
            prompt: topic or source text; may be None or empty.
            language: a UI language name, or "Auto Detect".
            file_obj: accepted for interface compatibility; not used here.
            progress: optional ``progress(fraction, message)`` callback.

        Returns:
            ``{"topic": str, "podcast": [{"speaker": int, "line": str}, ...]}``.
            If generation itself raises, a canned generic script is returned.

        Raises:
            Exception: if the model or tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            raise Exception("Model not properly initialized. Please check model loading.")
        try:
            system_prompt = self._build_prompt(prompt, language)
            if progress:
                progress(0.3, "Generating podcast script...")
            generated_text = self._run_model(system_prompt)
            print(f"Generated text: {generated_text[:500]}...")
            if progress:
                progress(0.4, "Processing generated script...")
            result = self.extract_json_from_text(generated_text)
            if progress:
                progress(0.5, "Script generated successfully!")
            return result
        except Exception as e:
            print(f"Generation error: {e}")
            return self._default_podcast(prompt)

    def _build_prompt(self, prompt, language: str) -> str:
        """Assemble the instruction prompt: rules, JSON example, topic, cue."""
        example_json = {
            "topic": "AGI",
            "podcast": [
                {"speaker": 1, "line": "So, AGI, huh? Seems like everyone's talking about it these days."},
                {"speaker": 2, "line": "Yeah, it's definitely having a moment, isn't it?"},
                {"speaker": 1, "line": "It really is. What got you hooked on this topic?"},
                {"speaker": 2, "line": "The potential implications are fascinating and concerning at the same time."}
            ]
        }
        if language == "Auto Detect":
            language_instruction = "Use the same language as the input text"
        else:
            language_instruction = f"Generate the podcast in {language} language"
        # Truncate the user text on its own, so the closing "Generate JSON:"
        # cue is never the part that gets cut off.
        topic_text = self._truncate_to_tokens(prompt or "", self._MAX_TOPIC_TOKENS)
        return "\n".join([
            f"Generate a podcast script as valid JSON. {language_instruction}.",
            "Requirements:",
            "- Exactly 2 speakers (speaker 1 and 2)",
            "- Natural, engaging conversation",
            "- JSON format only",
            "Example format:",
            json.dumps(example_json, indent=2),
            f"Input topic: {topic_text}",
            "Generate JSON:",
        ])

    def _truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Return *text* cut to at most *max_tokens* tokenizer tokens."""
        token_ids = self.tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(token_ids) <= max_tokens:
            return text
        return self.tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)

    def _run_model(self, prompt_text: str) -> str:
        """Run generation on *prompt_text* and decode only the new tokens."""
        inputs = self.tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=self._MAX_INPUT_TOKENS,
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                generation_config=self.generation_config,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Slice off the echoed prompt so only generated tokens are decoded.
        return self.tokenizer.decode(
            output[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

    @staticmethod
    def _default_podcast(prompt) -> Dict:
        """Canned eight-line script used when model generation fails."""
        return {
            "topic": prompt or "Discussion",
            "podcast": [
                {"speaker": 1, "line": f"Welcome to our discussion about {prompt or 'this topic'}."},
                {"speaker": 2, "line": "Thanks for having me. This is indeed an interesting subject."},
                {"speaker": 1, "line": "Let's dive into the key points and explore different perspectives."},
                {"speaker": 2, "line": "Absolutely. There's a lot to unpack here."},
                {"speaker": 1, "line": "What aspects do you find most compelling?"},
                {"speaker": 2, "line": "The implications and potential applications are fascinating."},
                {"speaker": 1, "line": "That's a great point. Thanks for the insightful discussion."},
                {"speaker": 2, "line": "Thank you. This has been a valuable conversation."}
            ]
        }

    # ------------------------------------------------------------------ #
    # Audio
    # ------------------------------------------------------------------ #
    async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
        """Synthesize *text* with edge-tts and return the temp audio path.

        Speaker 1 uses voice *speaker1*; any other speaker uses *speaker2*.
        Each attempt has a timeout of ``_TTS_TIMEOUT_S`` seconds and up to
        ``_TTS_MAX_RETRIES`` attempts are made, one second apart.

        Raises:
            Exception: when every attempt times out, fails, or yields an
                empty file. The message describes the last failure.
        """
        voice = speaker1 if speaker == 1 else speaker2
        # edge-tts writes MP3 by default. Naming the file .mp3 lets pydub hand
        # it straight to ffmpeg instead of first trying to parse it as WAV.
        temp_filename = f"temp_audio_{uuid.uuid4()}.mp3"
        last_error = "TTS generation failed"
        for attempt in range(self._TTS_MAX_RETRIES):
            # Build a fresh Communicate per attempt: an edge-tts stream can
            # only be consumed once, so reusing it would make retries fail.
            speech = edge_tts.Communicate(text, voice)
            try:
                await asyncio.wait_for(speech.save(temp_filename), timeout=self._TTS_TIMEOUT_S)
                if os.path.exists(temp_filename) and os.path.getsize(temp_filename) > 0:
                    return temp_filename
                raise Exception("Generated audio file is empty")
            except asyncio.TimeoutError:
                last_error = "TTS generation timed out after multiple attempts"
            except Exception as e:
                last_error = f"TTS generation failed: {str(e)}"
            # Remove any partial output before the next attempt.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
            if attempt < self._TTS_MAX_RETRIES - 1:
                await asyncio.sleep(1)  # Brief delay before retry
        raise Exception(last_error)

    async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
        """Concatenate *audio_files* with 500 ms gaps into a new WAV file.

        Every input file is deleted afterwards, whether or not it could be
        loaded. Files that fail to load are skipped with a warning.

        Returns:
            Path of the exported ``podcast_output_<uuid>.wav``.

        Raises:
            Exception: if no audio could be loaded or export fails.
        """
        if progress:
            progress(0.9, "Combining audio files...")
        try:
            combined_audio = AudioSegment.empty()
            silence_padding = AudioSegment.silent(duration=500)  # 500ms silence
            for audio_file in audio_files:
                try:
                    segment = AudioSegment.from_file(audio_file)
                except Exception as e:
                    print(f"Warning: Could not process audio file {audio_file}: {e}")
                    continue
                finally:
                    # Clean up temporary file
                    if os.path.exists(audio_file):
                        os.remove(audio_file)
                # Put the gap before every clip except the first one that
                # loaded, so a failed last clip leaves no trailing silence.
                if len(combined_audio) > 0:
                    combined_audio += silence_padding
                combined_audio += segment
            if len(combined_audio) == 0:
                raise Exception("No audio content generated")
            output_filename = f"podcast_output_{uuid.uuid4()}.wav"
            # export() returns the open output file; close it so the data is
            # flushed before the path is handed to the UI.
            combined_audio.export(output_filename, format="wav").close()
            if progress:
                progress(1.0, "Podcast generated successfully!")
            return output_filename
        except Exception as e:
            # Clean up any remaining temp files
            for audio_file in audio_files:
                if os.path.exists(audio_file):
                    os.remove(audio_file)
            raise Exception(f"Audio combination failed: {str(e)}") from e

    async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
        """Run the full pipeline and return the path of the finished WAV.

        Lines whose speech synthesis fails are skipped (logged), not fatal.

        Raises:
            Exception: prefixed "Podcast generation failed:" on any failure.
        """
        try:
            if progress:
                progress(0.1, "Starting podcast generation...")
            podcast_json = await self.generate_script(input_text, language, file_obj, progress)
            if not podcast_json.get('podcast'):
                raise Exception("No podcast content generated")
            if progress:
                progress(0.5, "Converting text to speech...")
            # Sequential TTS requests to avoid overloading the service.
            audio_files = []
            total_lines = len(podcast_json['podcast'])
            for i, item in enumerate(podcast_json['podcast']):
                try:
                    audio_file = await self.tts_generate(
                        item['line'],
                        item['speaker'],
                        speaker1,
                        speaker2
                    )
                    audio_files.append(audio_file)
                    if progress:
                        current_progress = 0.5 + (0.4 * (i + 1) / total_lines)
                        progress(current_progress, f"Generated speech {i + 1}/{total_lines}")
                except Exception as e:
                    print(f"TTS error for line {i}: {e}")
                    continue
            if not audio_files:
                raise Exception("No audio files generated successfully")
            return await self.combine_audio_files(audio_files, progress)
        except Exception as e:
            raise Exception(f"Podcast generation failed: {str(e)}") from e
# Voice mapping
# UI display label -> edge-tts voice name. Insertion order matters:
# create_interface() lists the dropdown choices from these keys in order.
VOICE_MAPPING = dict([
    ("Andrew - English (United States)", "en-US-AndrewMultilingualNeural"),
    ("Ava - English (United States)", "en-US-AvaMultilingualNeural"),
    ("Brian - English (United States)", "en-US-BrianMultilingualNeural"),
    ("Emma - English (United States)", "en-US-EmmaMultilingualNeural"),
    ("Florian - German (Germany)", "de-DE-FlorianMultilingualNeural"),
    ("Seraphina - German (Germany)", "de-DE-SeraphinaMultilingualNeural"),
    ("Remy - French (France)", "fr-FR-RemyMultilingualNeural"),
    ("Vivienne - French (France)", "fr-FR-VivienneMultilingualNeural"),
])
def read_uploaded_text(input_file, max_bytes=None) -> str:
    """Return the text content of an uploaded ``.txt`` file.

    Args:
        input_file: a path string (as produced by ``gr.File(type="filepath")``)
            or an object whose ``.name`` attribute is the path.
        max_bytes: size limit in bytes; defaults to ``MAX_FILE_SIZE_BYTES``.

    Raises:
        Exception: if the file is not a ``.txt`` file, exceeds the size limit,
            or contains no text.
    """
    path = str(getattr(input_file, "name", input_file))
    # Checked before any disk access: no PDF text extractor is available here.
    if not path.lower().endswith(".txt"):
        raise Exception("Only .txt uploads can be processed at the moment; please paste the text instead")
    if max_bytes is None:
        max_bytes = MAX_FILE_SIZE_BYTES
    size = os.path.getsize(path)
    if size > max_bytes:
        raise Exception(f"Uploaded file is too large: {size} bytes (limit {max_bytes} bytes)")
    # errors="replace" keeps a stray non-UTF-8 byte from aborting the run.
    with open(path, encoding="utf-8", errors="replace") as fh:
        text = fh.read()
    if not text.strip():
        raise Exception("The uploaded file contains no text")
    return text


async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, progress=None) -> str:
    """Validate the UI inputs, map voices, and run the podcast pipeline.

    Text in the text box takes precedence. When it is empty, the uploaded
    file's text is used instead. Unknown voice labels fall back to the
    Andrew (speaker 1) and Ava (speaker 2) voices.

    Returns:
        Path of the generated podcast WAV file.

    Raises:
        Exception: prefixed "Generation failed:" for any failure, including
            missing input or an unreadable upload.
    """
    start_time = time.time()
    try:
        if progress:
            progress(0.05, "Processing input...")
        speaker1_voice = VOICE_MAPPING.get(speaker1, "en-US-AndrewMultilingualNeural")
        speaker2_voice = VOICE_MAPPING.get(speaker2, "en-US-AvaMultilingualNeural")
        if not input_text or not input_text.strip():
            if input_file is None:
                raise Exception("Please provide either text input or upload a file")
            # No typed text: use the uploaded file's contents as the topic.
            input_text = read_uploaded_text(input_file)
        podcast_generator = PodcastGenerator()
        result = await podcast_generator.generate_podcast(
            input_text, language, speaker1_voice, speaker2_voice, input_file, progress
        )
        print(f"Total generation time: {time.time() - start_time:.2f} seconds")
        return result
    except Exception as e:
        error_msg = str(e)
        print(f"Processing error: {error_msg}")
        raise Exception(f"Generation failed: {error_msg}") from e
def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2):
    """Gradio click handler: run the async pipeline and return the audio path.

    Whitespace-only text counts as no text. If there is neither text nor a
    file, a ``gr.Error`` asks for input. Any pipeline failure is re-raised as
    ``gr.Error`` so Gradio shows the message in the UI.
    """
    # Normalize before validating, so the "no input" check sees exactly
    # what the pipeline would receive.
    if input_text is not None and not input_text.strip():
        input_text = None
    if not input_text and input_file is None:
        raise gr.Error("Please provide either text input or upload a file")

    def progress_callback(value, text):
        # Console-only progress reporting.
        print(f"Progress: {value:.1%} - {text}")

    try:
        # asyncio.run creates a fresh event loop, runs the pipeline, then
        # closes the loop and detaches it from this thread.
        return asyncio.run(
            process_input(input_text, input_file, language, speaker1, speaker2, progress_callback)
        )
    except Exception as e:
        print(f"Gradio function error: {e}")
        raise gr.Error(f"Failed to generate podcast: {str(e)}") from e
def create_interface():
    """Build and return the PodcastGen Gradio Blocks app (not launched).

    Components are created inside ``gr.Blocks`` / ``gr.Row`` / ``gr.Column``
    contexts, so their statement order here is their on-screen layout order.
    The button calls ``generate_podcast_gradio`` with the five inputs and
    shows the returned file path in the audio player.
    """
    # Output language choices; "Auto Detect" tells the model to reuse the
    # input text's language (see the prompt built in PodcastGenerator).
    language_options = [
        "Auto Detect", "English", "German", "French", "Spanish", "Italian",
        "Portuguese", "Dutch", "Russian", "Chinese", "Japanese", "Korean"
    ]
    # Dropdown labels come from VOICE_MAPPING's keys, in insertion order.
    voice_options = list(VOICE_MAPPING.keys())
    with gr.Blocks(
        title="PodcastGen 2🎙️",
        theme=gr.themes.Soft(),
        css=".gradio-container {max-width: 1200px; margin: auto;}"
    ) as demo:
        gr.Markdown("# 🎙️ PodcastGen 2")
        gr.Markdown("Generate professional 2-speaker podcasts from text input!")
        # First row: free-text topic (wide) next to an optional upload (narrow).
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(
                    label="Input Text",
                    lines=8,
                    placeholder="Enter your topic or text for podcast generation...",
                    info="Describe what you want the podcast to discuss"
                )
            with gr.Column(scale=1):
                # type="filepath" hands the handler a path string.
                # NOTE(review): confirm which upload types process_input
                # actually reads before advertising .pdf here.
                input_file = gr.File(
                    label="Upload File (Optional)",
                    file_types=[".pdf", ".txt"],
                    type="filepath",
                    info=f"Max size: {MAX_FILE_SIZE_MB}MB"
                )
        # Second row: output language and the two voice pickers.
        with gr.Row():
            language = gr.Dropdown(
                label="Language",
                choices=language_options,
                value="Auto Detect",
                info="Select output language"
            )
            speaker1 = gr.Dropdown(
                label="Speaker 1 Voice",
                choices=voice_options,
                value="Andrew - English (United States)"
            )
            speaker2 = gr.Dropdown(
                label="Speaker 2 Voice",
                choices=voice_options,
                value="Ava - English (United States)"
            )
        generate_btn = gr.Button(
            "🎙️ Generate Podcast",
            variant="primary",
            size="lg"
        )
        # The handler returns a WAV file path, hence type="filepath".
        output_audio = gr.Audio(
            label="Generated Podcast",
            type="filepath",
            format="wav",
            show_download_button=True
        )
        # Connect the interface; input order must match the handler's
        # positional parameters (text, file, language, speaker1, speaker2).
        generate_btn.click(
            fn=generate_podcast_gradio,
            inputs=[input_text, input_file, language, speaker1, speaker2],
            outputs=[output_audio],
            show_progress=True
        )
        # Add usage instructions (collapsed by default).
        with gr.Accordion("Usage Instructions", open=False):
            gr.Markdown("""
### How to use:
1. **Input**: Enter your topic or text in the text box, or upload a PDF/TXT file
2. **Language**: Choose the output language (Auto Detect recommended)
3. **Voices**: Select different voices for Speaker 1 and Speaker 2
4. **Generate**: Click the button and wait for processing
### Tips:
- Provide clear, specific topics for better results
- The AI will create a natural conversation between two speakers
- Generation may take 1-3 minutes depending on text length
""")
    return demo
if __name__ == "__main__":
    # Build the UI, then serve it on all network interfaces at port 7860.
    # show_error displays handler errors in the browser; share=False means
    # no public share link is requested.
    demo = create_interface()
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    demo.launch(**launch_kwargs)