# Hugging Face Spaces status: Running on ZeroGPU
import os | |
import time | |
import numpy as np | |
import gradio as gr | |
import librosa | |
import soundfile as sf | |
import torch | |
import traceback | |
import threading | |
from spaces import GPU | |
from datetime import datetime | |
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference | |
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor | |
from vibevoice.modular.streamer import AudioStreamer | |
from transformers.utils import logging | |
from transformers import set_seed | |
# NOTE: `logging` here is transformers.utils.logging (imported above), not the stdlib module.
# Turn the transformers library log level up to INFO, then grab a logger named after this module.
# NOTE(review): `logger` is not used anywhere in this file as shown; diagnostics go through print().
logging.set_verbosity_info()
logger = logging.get_logger(name=__name__)
class VibeVoiceDemo:
    """Hold a VibeVoice processor/model plus the voice presets and example scripts used by the demo UI.

    Construction is eager. It loads the processor and model from ``model_path``,
    indexes the voice samples in the ``voices`` directory next to this file, and
    reads example scripts from the ``text_examples`` directory next to this file.
    """

    # File extensions accepted as voice samples in the ``voices`` directory.
    _AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
        """
        Args:
            model_path: Hub id or local path passed to ``from_pretrained``.
            device: Device the model is moved to before each generation.
            inference_steps: Value passed to ``set_ddpm_inference_steps`` before each generation.
        """
        self.model_path = model_path
        self.device = device
        self.inference_steps = inference_steps
        self.is_generating = False  # True while generate_podcast is running
        self.processor = None
        self.model = None
        self.available_voices = {}  # voice name (file stem) -> audio file path
        self.example_scripts = []  # [num_speakers, script] pairs offered in the UI
        self.load_model()
        self.setup_voice_presets()
        self.load_example_scripts()

    def load_model(self):
        """Load the VibeVoice processor and the model (bfloat16 weights) from ``self.model_path``.

        The model is not moved or configured here. ``generate_podcast`` moves it to
        ``self.device``, calls ``eval()`` and sets the DDPM inference steps
        before every generation.
        """
        print(f"Loading processor & model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
        )

    def setup_voice_presets(self):
        """Index voice sample files from the ``voices`` directory next to this file.

        Fills ``self.available_voices`` with ``{file stem: path}`` for every file
        whose extension is in ``_AUDIO_EXTENSIONS``. Files are visited in sorted
        order, so the preset order is deterministic. When two files share a stem,
        the one sorted last wins. A missing directory only prints a warning.
        """
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            return
        wav_files = sorted(
            f for f in os.listdir(voices_dir) if f.lower().endswith(self._AUDIO_EXTENSIONS)
        )
        for wav_file in wav_files:
            name = os.path.splitext(wav_file)[0]
            self.available_voices[name] = os.path.join(voices_dir, wav_file)
        print(f"Voices loaded: {list(self.available_voices.keys())}")

    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        """Load an audio file as a mono waveform resampled to ``target_sr`` Hz.

        Multi-channel audio is down-mixed by averaging over axis 1.

        Returns:
            The waveform as a NumPy array, or an empty array if reading or
            resampling fails. The error is printed, not raised.
        """
        try:
            wav, sr = sf.read(audio_path)
            if wav.ndim > 1:
                # soundfile returns (frames, channels) for multi-channel files.
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])

    @staticmethod
    def _format_script(script: str, num_speakers: int) -> list:
        """Normalize *script* into one ``Speaker N: text`` line per non-empty input line.

        Lines that already start with ``Speaker `` and contain ``:`` are kept
        verbatim. Every other line gets a 0-based speaker id in rotation, computed
        from the number of turns that precede it (labelled turns included).
        """
        formatted = []
        for raw_line in script.strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('Speaker ') and ':' in line:
                formatted.append(line)
            else:
                formatted.append(f"Speaker {len(formatted) % num_speakers}: {line}")
        return formatted

    def generate_podcast(self,
                         num_speakers: int,
                         script: str,
                         speaker_1: str = None,
                         speaker_2: str = None,
                         speaker_3: str = None,
                         speaker_4: str = None,
                         cfg_scale: float = 1.3):
        """Generate a podcast as a single audio clip from a script and save it as a WAV file.

        This is a non-streaming call: the whole clip is generated before returning.

        Args:
            num_speakers: Number of speakers (1-4). Only the first ``num_speakers``
                of ``speaker_1``..``speaker_4`` are used.
            script: Conversation text. Lines of the form ``Speaker N: ...`` are
                kept as-is. Other non-empty lines get 0-based speaker ids in rotation.
            speaker_1 .. speaker_4: Voice names, i.e. keys of ``self.available_voices``.
            cfg_scale: Guidance scale passed to ``model.generate``.

        Returns:
            On success, ``((sample_rate, audio), log)``, where ``audio`` is a
            float32 NumPy array (squeezed when it has extra dimensions).
            On any failure, ``(None, error_message)``. Errors are reported through
            the return value and never raised.

        NOTE(review): ``GPU`` is imported from ``spaces`` at the top of the file
        but never applied. On ZeroGPU hardware this method presumably needs the
        ``@GPU`` decorator to obtain a CUDA device; confirm.
        """
        self.is_generating = True
        try:
            # 1. Validate inputs first, so bad requests fail before the model is moved.
            if not script or not script.strip():
                raise gr.Error("Error: Please provide a script.")
            # Defend against a common mistake with apostrophes: normalize typographic
            # single quotes (U+2019 right, U+2018 left) to ASCII.
            script = script.replace("\u2019", "'").replace("\u2018", "'")
            if not 1 <= num_speakers <= 4:
                raise gr.Error("Error: Number of speakers must be between 1 and 4.")

            # 2. Collect and validate the selected speakers.
            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            for i, speaker_name in enumerate(selected_speakers):
                if not speaker_name or speaker_name not in self.available_voices:
                    raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")

            # 3. Move the model to the target device and configure inference.
            self.model = self.model.to(self.device)
            print(f"Model successfully moved to device: {self.device.upper()}")
            self.model.eval()
            self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

            # 4. Build the initial log.
            log = f"ποΈ Generating podcast with {num_speakers} speakers\n"
            log += f"π Parameters: CFG Scale={cfg_scale}\n"
            log += f"π Speakers: {', '.join(selected_speakers)}\n"

            # 5. Load one reference voice sample per speaker.
            voice_samples = []
            for speaker_name in selected_speakers:
                audio_data = self.read_audio(self.available_voices[speaker_name])
                if len(audio_data) == 0:
                    raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
                voice_samples.append(audio_data)
            log += f"β Loaded {len(voice_samples)} voice samples\n"

            # 6. Parse and format the script.
            formatted_script_lines = self._format_script(script, num_speakers)
            formatted_script = '\n'.join(formatted_script_lines)
            log += f"π Formatted script with {len(formatted_script_lines)} turns\n"
            log += "π Processing with VibeVoice...\n"

            # 7. Prepare model inputs (a batch of one script).
            inputs = self.processor(
                text=[formatted_script],
                voice_samples=[voice_samples],
                padding=True,
                return_tensors="pt",
                return_attention_mask=True,
            )

            # 8. Generate audio.
            start_time = time.time()
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={'do_sample': False},
                verbose=False,  # Verbose is off for cleaner logs
            )
            generation_time = time.time() - start_time

            # 9. Extract the audio for the single batch item. A missing, None or
            #    empty `speech_outputs` yields a clear error instead of an IndexError.
            speech_outputs = getattr(outputs, 'speech_outputs', None)
            if speech_outputs is None or len(speech_outputs) == 0 or speech_outputs[0] is None:
                raise gr.Error("β Error: No audio was generated by the model. Please try again.")
            # .float() first: bfloat16 tensors cannot be converted to NumPy directly.
            audio = speech_outputs[0].cpu().float().numpy()
            if audio.ndim > 1:
                audio = audio.squeeze()
            sample_rate = 24000  # Standard sample rate for this model

            # 10. Save the audio file.
            output_dir = "outputs"
            os.makedirs(output_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = os.path.join(output_dir, f"podcast_{timestamp}.wav")
            sf.write(file_path, audio, sample_rate)
            print(f"πΎ Podcast saved to {file_path}")

            # 11. Finalize the log and return.
            total_duration = len(audio) / sample_rate
            log += f"β±οΈ Generation completed in {generation_time:.2f} seconds\n"
            log += f"π΅ Final audio duration: {total_duration:.2f} seconds\n"
            log += f"β Successfully saved podcast to: {file_path}\n"
            return (sample_rate, audio), log
        except gr.Error as e:
            # User-facing errors are reported through the log output, not raised.
            error_msg = f"β Input Error: {str(e)}"
            print(error_msg)
            return None, error_msg
        except Exception as e:
            # Anything unexpected: print the traceback and report it.
            error_msg = f"β An unexpected error occurred: {str(e)}"
            print(error_msg)
            traceback.print_exc()
            return None, error_msg
        finally:
            # Single reset point for every exit path.
            self.is_generating = False

    @staticmethod
    def _infer_num_speakers_from_script(script: str) -> int:
        """Infer the number of speakers by counting distinct 'Speaker X:' tags in the script.

        Tags are matched case-insensitively at the start of a line. Counting
        distinct ids makes this robust to 0- or 1-indexed labels and repeated
        turns. Returns 1 if no tag is found.

        This must be a staticmethod because it is called as
        ``self._infer_num_speakers_from_script(...)``. Without the decorator,
        ``self`` would bind to ``script`` and every call would raise TypeError.
        """
        import re
        ids = re.findall(r'(?mi)^\s*Speaker\s+(\d+)\s*:', script)
        return len({int(x) for x in ids}) if ids else 1

    def load_example_scripts(self):
        """Load example scripts from the ``text_examples`` directory next to this file.

        Each non-empty ``.txt`` file, read as UTF-8 in sorted filename order,
        becomes a ``[num_speakers, script]`` entry in ``self.example_scripts``.
        Files that fail to load are reported and skipped.
        """
        examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
        self.example_scripts = []
        if not os.path.exists(examples_dir):
            return
        txt_files = sorted(f for f in os.listdir(examples_dir) if f.lower().endswith('.txt'))
        for txt_file in txt_files:
            try:
                with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f:
                    script_content = f.read().strip()
                if script_content:
                    num_speakers = self._infer_num_speakers_from_script(script_content)
                    self.example_scripts.append([num_speakers, script_content])
            except Exception as e:
                print(f"Error loading {txt_file}: {e}")
def convert_to_16_bit_wav(data):
    """Convert audio samples to 16-bit PCM (``np.int16``).

    Accepts a torch tensor or anything ``np.array`` accepts. If the peak absolute
    amplitude exceeds 1.0, the signal is first normalized by that peak so the
    result stays within the int16 range. Otherwise samples are scaled as-is.

    Args:
        data: Audio samples, nominally floats in [-1.0, 1.0].

    Returns:
        An int16 NumPy array with the same shape as the input. Scaling uses
        ``* 32767`` and ``astype`` truncation. An empty input yields an empty
        int16 array.
    """
    if torch.is_tensor(data):
        data = data.detach().cpu()
        if data.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so .numpy() would raise; upcast first.
            data = data.float()
        data = data.numpy()
    data = np.array(data)
    if data.size == 0:
        # np.max on an empty array raises ValueError; there is nothing to scale.
        return data.astype(np.int16)
    peak = np.max(np.abs(data))
    if peak > 1.0:
        data = data / peak
    return (data * 32767).astype(np.int16)
def create_demo_interface(demo_instance: VibeVoiceDemo):
    """Create the Gradio interface (final audio only, no streaming).

    Args:
        demo_instance: A loaded ``VibeVoiceDemo``. Its ``available_voices``,
            ``example_scripts`` and ``generate_podcast`` back the UI.

    Returns:
        The ``gr.Blocks`` app. It is built here but not launched.
    """
    # Custom CSS for high-end aesthetics
    # NOTE(review): this stylesheet is a placeholder; the real CSS appears to have been elided.
    custom_css = """ ... """  # (keep your CSS unchanged)

    # Preferred default voice for each of the four speaker slots.
    default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman']
    available_speaker_names = sorted(demo_instance.available_voices)

    def default_voice_for(slot):
        """Return the preferred voice for *slot* if installed, else the slot-th available voice, else None."""
        preferred = default_speakers[slot] if slot < len(default_speakers) else None
        if preferred in available_speaker_names:
            return preferred
        # The preferred voice is missing from the voices/ directory. A Dropdown
        # value outside `choices` would be submitted as-is and then rejected by
        # generate_podcast, so fall back to an installed voice.
        if slot < len(available_speaker_names):
            return available_speaker_names[slot]
        return None

    with gr.Blocks(
        title="VibeVoice - AI Podcast Generator",
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
            neutral_hue="slate",
        ),
    ) as interface:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>ποΈ Vibe Podcasting</h1>
            <p>Generating Long-form Multi-speaker AI Podcast with VibeVoice</p>
        </div>
        """)

        with gr.Row():
            # Left column - Settings
            with gr.Column(scale=1, elem_classes="settings-card"):
                gr.Markdown("### ποΈ **Podcast Settings**")
                num_speakers = gr.Slider(
                    minimum=1, maximum=4, value=2, step=1,
                    label="Number of Speakers",
                    elem_classes="slider-container",
                )

                gr.Markdown("### π **Speaker Selection**")
                speaker_selections = []
                for i in range(4):
                    speaker = gr.Dropdown(
                        choices=available_speaker_names,
                        value=default_voice_for(i),
                        label=f"Speaker {i+1}",
                        visible=(i < 2),  # matches the slider's initial value of 2
                        elem_classes="speaker-item",
                    )
                    speaker_selections.append(speaker)

                gr.Markdown("### βοΈ **Advanced Settings**")
                with gr.Accordion("Generation Parameters", open=False):
                    cfg_scale = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.3, step=0.05,
                        label="CFG Scale (Guidance Strength)",
                        elem_classes="slider-container",
                    )

            # Right column - Generation
            with gr.Column(scale=2, elem_classes="generation-card"):
                gr.Markdown("### π **Script Input**")
                script_input = gr.Textbox(
                    label="Conversation Script",
                    placeholder="Enter your podcast script here...",
                    lines=12,
                    max_lines=20,
                    elem_classes="script-input",
                )
                with gr.Row():
                    random_example_btn = gr.Button(
                        "π² Random Example", size="lg",
                        variant="secondary", elem_classes="random-btn", scale=1,
                    )
                    generate_btn = gr.Button(
                        "π Generate Podcast", size="lg",
                        variant="primary", elem_classes="generate-btn", scale=2,
                    )

                # Output section
                gr.Markdown("### π΅ **Generated Podcast**")
                complete_audio_output = gr.Audio(
                    label="Complete Podcast (Download)",
                    type="numpy",
                    elem_classes="audio-output complete-audio-section",
                    autoplay=False,
                    show_download_button=True,
                    visible=True,
                )
                log_output = gr.Textbox(
                    label="Generation Log",
                    lines=8, max_lines=15,
                    interactive=False,
                    elem_classes="log-output",
                )

        # === logic ===
        def update_speaker_visibility(num_speakers):
            """Show only the first `num_speakers` speaker dropdowns."""
            return [gr.update(visible=(i < num_speakers)) for i in range(4)]

        num_speakers.change(
            fn=update_speaker_visibility,
            inputs=[num_speakers],
            outputs=speaker_selections,
        )

        def generate_podcast_wrapper(num_speakers, script, *speakers_and_params):
            """Unpack the flat Gradio inputs (4 speakers, then cfg_scale) and run generation."""
            try:
                speakers = speakers_and_params[:4]
                cfg_scale = speakers_and_params[4]
                audio, log = demo_instance.generate_podcast(
                    num_speakers=int(num_speakers),
                    script=script,
                    speaker_1=speakers[0],
                    speaker_2=speakers[1],
                    speaker_3=speakers[2],
                    speaker_4=speakers[3],
                    cfg_scale=cfg_scale,
                )
                return audio, log
            except Exception as e:
                traceback.print_exc()
                return None, f"β Error: {str(e)}"

        generate_btn.click(
            fn=generate_podcast_wrapper,
            inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale],
            outputs=[complete_audio_output, log_output],
            queue=True,
        )

        def load_random_example():
            """Return a random (num_speakers, script) pair, with a built-in fallback when none were loaded."""
            import random
            examples = getattr(demo_instance, "example_scripts", [])
            if not examples:
                examples = [
                    [2, "Speaker 0: Welcome to our AI podcast demo!\nSpeaker 1: Thanks, excited to be here!"]
                ]
            num_speakers_value, script_value = random.choice(examples)
            return num_speakers_value, script_value

        random_example_btn.click(
            fn=load_random_example,
            inputs=[],
            outputs=[num_speakers, script_input],
            queue=False,
        )

        gr.Markdown("### π **Example Scripts**")
        examples = getattr(demo_instance, "example_scripts", []) or [
            [1, "Speaker 1: Welcome to our AI podcast demo. This is a sample script."]
        ]
        gr.Examples(
            examples=examples,
            inputs=[num_speakers, script_input],
            label="Try these example scripts:",
        )

    return interface
def run_demo(
    model_path: str = "aoi-ot/VibeVoice-Large",
    device: str = "cuda",
    inference_steps: int = 5,
    share: bool = True,
):
    """Load VibeVoice, build the Gradio UI around it, and launch the server.

    Args:
        model_path: Hub id or local path of the VibeVoice checkpoint.
        device: Device the model is moved to for generation.
        inference_steps: DDPM inference steps applied before each generation.
        share: Forwarded to ``launch(share=...)``. It also picks the bind address:
            all interfaces ("0.0.0.0") when True, loopback ("127.0.0.1") otherwise.
    """
    set_seed(42)  # seed RNGs via transformers.set_seed for repeatable runs
    demo = VibeVoiceDemo(model_path, device, inference_steps)
    app = create_demo_interface(demo)
    bind_host = "0.0.0.0" if share else "127.0.0.1"
    app.queue().launch(
        share=share,
        server_name=bind_host,
        show_error=True,
        show_api=False,
    )


if __name__ == "__main__":
    # Script entry point: launch with the default settings above.
    run_demo()