import gradio as gr
import torch
import torchaudio
import numpy as np
import tempfile
import time
from pathlib import Path
from huggingface_hub import hf_hub_download
import os
import spaces
from transformers import pipeline

# Import the inference module
from infer import DMOInference

# Global variables
model = None
asr_pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"


# Initialize ASR pipeline
def initialize_asr_pipeline(device=device, dtype=None):
    """Initialize the ASR pipeline on startup."""
    global asr_pipe

    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.is_available()
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )

    print("Initializing ASR pipeline...")
    try:
        asr_pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=dtype,
            device="cpu",  # Keep ASR on CPU to save GPU memory
        )
        print("ASR pipeline initialized successfully")
    except Exception as e:
        print(f"Error initializing ASR pipeline: {e}")
        asr_pipe = None


# Transcribe function
def transcribe(ref_audio, language=None):
    """Transcribe audio using the pre-loaded ASR pipeline."""
    global asr_pipe

    if asr_pipe is None:
        return ""  # Return empty string if ASR is not available

    try:
        result = asr_pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs=(
                {"task": "transcribe", "language": language}
                if language
                else {"task": "transcribe"}
            ),
            return_timestamps=False,
        )
        return result["text"].strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        return ""


def download_models():
    """Download models from the HuggingFace Hub."""
    try:
        print("Downloading models from HuggingFace...")

        # Download student model
        student_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_85000.pt",
            cache_dir="./models",
        )

        # Download duration predictor
        duration_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_1500.pt",
            cache_dir="./models",
        )

        print(f"Student model: {student_path}")
        print(f"Duration model: {duration_path}")

        return student_path, duration_path
    except Exception as e:
        print(f"Error downloading models: {e}")
        return None, None


def initialize_model():
    """Initialize the model on startup."""
    global model

    try:
        # Download models
        student_path, duration_path = download_models()
        if not student_path or not duration_path:
            return False, "Failed to download models from HuggingFace"

        # Initialize model
        model = DMOInference(
            student_checkpoint_path=student_path,
            duration_predictor_path=duration_path,
            device=device,
            model_type="F5TTS_Base",
        )

        return True, f"Model loaded successfully on {device.upper()}"
    except Exception as e:
        return False, f"Error initializing model: {str(e)}"


# Initialize models on startup
print("Initializing models...")
model_loaded, status_message = initialize_model()
initialize_asr_pipeline()  # Initialize ASR pipeline


@spaces.GPU(duration=120)  # Request GPU for up to 120 seconds
def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose,
):
    """Generate speech with different configurations."""
    if not model_loaded or model is None:
        return None, "Model not loaded! Please refresh the page.", "", ""

    if prompt_audio is None:
        return None, "Please upload a reference audio!", "", ""

    if not target_text:
        return None, "Please enter text to generate!", "", ""

    try:
        # Auto-transcribe if the reference text is empty
        auto_transcribed = False
        if not prompt_text:
            print("Auto-transcribing reference audio...")
            prompt_text = transcribe(prompt_audio)
            auto_transcribed = True
            print(f"Transcribed: {prompt_text}")

        start_time = time.time()

        # Configure parameters based on mode
        if mode == "Student Only (4 steps)":
            teacher_steps = 0
            student_start_step = 0
            teacher_stopping_time = 1.0
        elif mode == "Teacher-Guided (8 steps)":
            # Default configuration from the notebook
            teacher_steps = 16
            teacher_stopping_time = 0.07
            student_start_step = 1
        elif mode == "High Diversity (16 steps)":
            teacher_steps = 24
            teacher_stopping_time = 0.3
            student_start_step = 2
        else:  # Custom
            teacher_steps = custom_teacher_steps
            teacher_stopping_time = custom_teacher_stopping_time
            student_start_step = custom_student_start_step

        # Generate speech
        generated_audio = model.generate(
            gen_text=target_text,
            audio_path=prompt_audio,
            prompt_text=prompt_text if prompt_text else None,
            teacher_steps=teacher_steps,
            teacher_stopping_time=teacher_stopping_time,
            student_start_step=student_start_step,
            temperature=temperature,
            verbose=verbose,
        )

        end_time = time.time()

        # Calculate metrics
        processing_time = end_time - start_time
        audio_duration = generated_audio.shape[-1] / 24000
        rtf = processing_time / audio_duration

        # Save audio
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name

            if isinstance(generated_audio, np.ndarray):
                generated_audio = torch.from_numpy(generated_audio)

            if generated_audio.dim() == 1:
                generated_audio = generated_audio.unsqueeze(0)

            torchaudio.save(output_path, generated_audio, 24000)

        # Format metrics
        metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"

        # Report the transcription only when it was produced automatically
        info = (
            f"Mode: {mode} | Transcribed: {prompt_text[:50]}..."
            if auto_transcribed
            else f"Mode: {mode}"
        )

        return output_path, "Success!", metrics, info

    except Exception as e:
        return None, f"Error: {str(e)}", "", ""


# Create Gradio interface
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech

    Generate natural speech in any voice with just a short reference audio!

    **NOTE: This entire Space was generated by Claude for demo purposes and may not demonstrate the model's real performance, since it can contain glitches/bugs.**
    This Space will be retired once a better, cleaned-up Space is published later.
""") with gr.Row(): with gr.Column(scale=1): # Reference audio input prompt_audio = gr.Audio( label="📎 Reference Audio", type="filepath", sources=["upload", "microphone"] ) prompt_text = gr.Textbox( label="📝 Reference Text (leave empty for auto-transcription)", placeholder="The text spoken in the reference audio...", lines=2 ) target_text = gr.Textbox( label="✍️ Text to Generate", placeholder="Enter the text you want to synthesize...", lines=4 ) # Generation mode mode = gr.Radio( choices=[ "Student Only (4 steps)", "Teacher-Guided (8 steps)", "High Diversity (16 steps)", "Custom" ], value="Teacher-Guided (8 steps)", label="🚀 Generation Mode", info="Choose speed vs quality/diversity tradeoff" ) # Advanced settings (collapsible) with gr.Accordion("⚙️ Advanced Settings", open=False): temperature = gr.Slider( minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Duration Temperature", info="0 = deterministic, >0 = more variation in speech rhythm" ) with gr.Group(visible=False) as custom_settings: gr.Markdown("### Custom Mode Settings") custom_teacher_steps = gr.Slider( minimum=0, maximum=32, value=16, step=1, label="Teacher Steps", info="More steps = higher quality" ) custom_teacher_stopping_time = gr.Slider( minimum=0.0, maximum=1.0, value=0.07, step=0.01, label="Teacher Stopping Time", info="When to switch to student" ) custom_student_start_step = gr.Slider( minimum=0, maximum=4, value=1, step=1, label="Student Start Step", info="Which student step to start from" ) verbose = gr.Checkbox( value=False, label="Verbose Output", info="Show detailed generation steps" ) generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg") with gr.Column(scale=1): # Output output_audio = gr.Audio( label="🔊 Generated Speech", type="filepath", autoplay=True ) status = gr.Textbox( label="Status", interactive=False ) metrics = gr.Textbox( label="Performance Metrics", interactive=False ) info = gr.Textbox( label="Generation Info", interactive=False ) # Tips gr.Markdown(""" ### 💡 Quick Tips: - **Auto-transcription**: Leave reference text empty to auto-transcribe - **Student Only**: Fastest (4 steps), good quality - **Teacher-Guided**: Best balance (8 steps), recommended - **High Diversity**: More natural prosody (16 steps) - **Custom Mode**: Fine-tune all parameters ### 📊 Expected RTF (Real-Time Factor): - Student Only: ~0.05x (20x faster than real-time) - Teacher-Guided: ~0.10x (10x faster) - High Diversity: ~0.20x (5x faster) """) # Event handler generate_btn.click( generate_speech, inputs=[ prompt_audio, prompt_text, target_text, mode, temperature, custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step, verbose ], outputs=[output_audio, status, metrics, info] ) # Update visibility of custom settings based on mode def update_custom_visibility(mode): is_custom = (mode == "Custom") return gr.update(visible=is_custom) mode.change( update_custom_visibility, inputs=[mode], outputs=[custom_settings] ) # Launch the app if __name__ == "__main__": if not model_loaded: print(f"Warning: Model failed to load - {status_message}") if not asr_pipe: print("Warning: ASR pipeline not available - auto-transcription disabled") demo.launch()