# DMOSpeech2-demo / app.py
# Last change by yl4579: commit 4241600 "Update app.py" (verified)
import gradio as gr
import torch
import torchaudio
import numpy as np
import tempfile
import time
from pathlib import Path
from huggingface_hub import hf_hub_download
import os
import spaces
from transformers import pipeline
# Import the inference module
from infer import DMOInference
# Module-level state, filled in once at import time and read by the
# request handlers further down.
model = None  # DMOInference instance; assigned by initialize_model()
asr_pipe = None  # Whisper ASR pipeline; assigned by initialize_asr_pipeline()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Initialize ASR pipeline
def initialize_asr_pipeline(device=device, dtype=None, asr_device="cpu"):
    """Load the Whisper ASR pipeline used for auto-transcription.

    The pipeline is stored in the module-level ``asr_pipe``. On failure the
    error is printed and ``asr_pipe`` is set to ``None``, so the app keeps
    running without auto-transcription.

    Args:
        device: Kept for backward compatibility. It used to pick the default
            dtype, even though the pipeline never ran on that device.
        dtype: Torch dtype for the pipeline. When ``None`` it is derived from
            ``asr_device``: float16 on a CUDA device with compute capability
            major >= 7 (excluding ZLUDA), float32 otherwise.
        asr_device: Device the pipeline runs on. Defaults to ``"cpu"`` to
            leave GPU memory for the TTS model.
    """
    global asr_pipe
    if dtype is None:
        # Choose the precision for the device the pipeline actually runs on.
        # fp16 on CPU is slow and not supported by every op.
        use_half = (
            "cuda" in asr_device
            and torch.cuda.is_available()
            and torch.cuda.get_device_properties(asr_device).major >= 7
            and not torch.cuda.get_device_name(asr_device).endswith("[ZLUDA]")
        )
        dtype = torch.float16 if use_half else torch.float32
    print("Initializing ASR pipeline...")
    try:
        asr_pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=dtype,
            device=asr_device,
        )
        print("ASR pipeline initialized successfully")
    except Exception as e:
        # Best-effort: the app still works without auto-transcription.
        print(f"Error initializing ASR pipeline: {e}")
        asr_pipe = None
# Transcribe function
def transcribe(ref_audio, language=None):
    """Return the stripped Whisper transcription of ``ref_audio``.

    Uses the module-level ``asr_pipe``. Returns ``""`` when the pipeline is
    unavailable or the call fails; failures are printed, not raised.

    Args:
        ref_audio: Audio handed straight to the ASR pipeline (the UI passes
            a file path).
        language: Optional language hint forwarded in ``generate_kwargs``;
            omitted when falsy.
    """
    pipe = asr_pipe
    if pipe is None:
        return ""  # auto-transcription unavailable

    gen_kwargs = {"task": "transcribe"}
    if language:
        gen_kwargs["language"] = language

    try:
        output = pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs=gen_kwargs,
            return_timestamps=False,
        )
        return output["text"].strip()
    except Exception as err:
        print(f"Transcription error: {err}")
        return ""
def download_models():
    """Fetch the DMOSpeech2 checkpoints from the Hugging Face Hub.

    Files are cached under ``./models``.

    Returns:
        ``(student_path, duration_path)``: local paths of the student model
        checkpoint and the duration predictor checkpoint, or ``(None, None)``
        if anything fails (the error is printed).
    """
    repo = "yl4579/DMOSpeech2"
    try:
        print("Downloading models from HuggingFace...")
        # Student model first, then the duration predictor. Unpacking the
        # generator downloads them in that order and stops at the first failure.
        student_path, duration_path = (
            hf_hub_download(repo_id=repo, filename=name, cache_dir="./models")
            for name in ("model_85000.pt", "model_1500.pt")
        )
        print(f"Student model: {student_path}")
        print(f"Duration model: {duration_path}")
        return student_path, duration_path
    except Exception as err:
        print(f"Error downloading models: {err}")
        return None, None
def initialize_model():
    """Download the checkpoints and build the global ``DMOInference`` model.

    Returns:
        ``(ok, message)``: ``ok`` is True only when ``model`` was created.
        ``message`` is a human-readable status string in every case.
    """
    global model
    try:
        student_path, duration_path = download_models()
        if not (student_path and duration_path):
            return False, "Failed to download models from HuggingFace"
        model = DMOInference(
            student_checkpoint_path=student_path,
            duration_predictor_path=duration_path,
            device=device,
            model_type="F5TTS_Base",
        )
        return True, f"Model loaded successfully on {device.upper()}"
    except Exception as err:
        return False, f"Error initializing model: {err!s}"
# Initialize models on startup.
# This runs at import time: the TTS model is loaded first, then the ASR
# pipeline. Neither step raises on failure. initialize_model() returns
# (False, message) and initialize_asr_pipeline() leaves asr_pipe as None,
# so the Gradio UI below is still built when either one fails.
print("Initializing models...")
model_loaded, status_message = initialize_model()
initialize_asr_pipeline() # Initialize ASR pipeline (result stored in asr_pipe)
# Step presets for the fixed generation modes:
# mode label -> (teacher_steps, teacher_stopping_time, student_start_step).
# Any mode label not listed here is treated as "Custom".
_MODE_PRESETS = {
    "Student Only (4 steps)": (0, 1.0, 0),
    # Default configuration from the notebook.
    # NOTE(review): the label says 8 steps while teacher_steps is 16; confirm
    # how the label counts steps.
    "Teacher-Guided (8 steps)": (16, 0.07, 1),
    "High Diversity (16 steps)": (24, 0.3, 2),
}

# Sample rate (Hz) assumed for the generated waveform. It is used both for
# the duration/RTF metrics and for the saved WAV file.
_OUTPUT_SAMPLE_RATE = 24000


def _resolve_steps(mode, custom_teacher_steps, custom_teacher_stopping_time,
                   custom_student_start_step):
    """Return (teacher_steps, teacher_stopping_time, student_start_step) for ``mode``."""
    preset = _MODE_PRESETS.get(mode)
    if preset is not None:
        return preset
    # Custom mode: slider values may arrive as floats, but step counts are
    # whole numbers.
    return (
        int(custom_teacher_steps),
        float(custom_teacher_stopping_time),
        int(custom_student_start_step),
    )


def _save_wav(audio, sample_rate):
    """Write ``audio`` (numpy array or torch tensor) to a temp .wav file and return its path."""
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    audio = audio.detach().cpu()  # make sure a CPU tensor is saved
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # mono -> (1, samples)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        output_path = tmp_file.name
    # Write after the handle is closed: on some platforms (e.g. Windows) an
    # open NamedTemporaryFile cannot be reopened by name.
    torchaudio.save(output_path, audio, sample_rate)
    return output_path


@spaces.GPU(duration=120)  # Request GPU for up to 120 seconds
def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose
):
    """Synthesize ``target_text`` in the voice of ``prompt_audio``.

    Args:
        prompt_audio: Path to the reference recording, or ``None`` if missing.
        prompt_text: Transcript of the reference. If it is None, empty, or
            whitespace-only, the reference is auto-transcribed with
            ``transcribe``.
        target_text: Text to synthesize.
        mode: Generation-mode label. Labels not in ``_MODE_PRESETS`` use the
            custom slider values.
        temperature: Duration temperature forwarded to ``model.generate``.
        custom_teacher_steps: Teacher steps (Custom mode only).
        custom_teacher_stopping_time: Teacher stopping time (Custom mode only).
        custom_student_start_step: Student start step (Custom mode only).
        verbose: Forwarded to ``model.generate``.

    Returns:
        ``(audio_path, status, metrics, info)``. On any error ``audio_path`` is
        ``None`` and ``metrics``/``info`` are empty strings.
    """
    if not model_loaded or model is None:
        return None, "Model not loaded! Please refresh the page.", "", ""
    if prompt_audio is None:
        return None, "Please upload a reference audio!", "", ""
    if not target_text or not target_text.strip():
        return None, "Please enter text to generate!", "", ""

    try:
        # Auto-transcribe when no reference text was given. The old check
        # excluded "" and so never fired for an empty text box.
        auto_transcribed = not prompt_text or not prompt_text.strip()
        if auto_transcribed:
            print("Auto-transcribing reference audio...")
            prompt_text = transcribe(prompt_audio)
            print(f"Transcribed: {prompt_text}")

        teacher_steps, teacher_stopping_time, student_start_step = _resolve_steps(
            mode,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
        )

        start_time = time.perf_counter()
        generated_audio = model.generate(
            gen_text=target_text,
            audio_path=prompt_audio,
            prompt_text=prompt_text if prompt_text else None,
            teacher_steps=teacher_steps,
            teacher_stopping_time=teacher_stopping_time,
            student_start_step=student_start_step,
            temperature=temperature,
            verbose=verbose
        )
        processing_time = time.perf_counter() - start_time

        audio_duration = generated_audio.shape[-1] / _OUTPUT_SAMPLE_RATE
        output_path = _save_wav(generated_audio, _OUTPUT_SAMPLE_RATE)

        timing = f"Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
        if audio_duration > 0 and processing_time > 0:
            rtf = processing_time / audio_duration
            metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | {timing}"
        else:
            metrics = timing  # RTF undefined for empty output / zero time

        # Show the transcript only when it was produced automatically.
        if auto_transcribed and prompt_text:
            preview = prompt_text if len(prompt_text) <= 50 else prompt_text[:50] + "..."
            info = f"Mode: {mode} | Transcribed: {preview}"
        else:
            info = f"Mode: {mode}"
        return output_path, "Success!", metrics, info
    except Exception as e:
        import traceback
        traceback.print_exc()  # keep the full stack in the server log
        return None, f"Error: {str(e)}", "", ""
# Create Gradio interface
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
    # Page header (plain string; it has no placeholders, so no f-string).
    gr.Markdown("""
# 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
Generate natural speech in any voice with just a short reference audio!
**NOTE: The entire space was generated by Claude for demo purposes and may not demonstrate the model's real performance because it can contain glitches/bugs**
This space will retire when a better and more cleaned up space comes up later.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Reference audio input
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2
            )
            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4
            )
            # Generation mode. These strings must match the labels that
            # generate_speech() compares against.
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom"
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose speed vs quality/diversity tradeoff"
            )
            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm"
                )
                # Hidden until "Custom" is selected (toggled by mode.change below).
                with gr.Group(visible=False) as custom_settings:
                    gr.Markdown("### Custom Mode Settings")
                    custom_teacher_steps = gr.Slider(
                        minimum=0,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality"
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to student"
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from"
                    )
                verbose = gr.Checkbox(
                    value=False,
                    label="Verbose Output",
                    info="Show detailed generation steps"
                )
            generate_btn = gr.Button("🎡 Generate Speech", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech",
                type="filepath",
                autoplay=True
            )
            status = gr.Textbox(
                label="Status",
                interactive=False
            )
            metrics = gr.Textbox(
                label="Performance Metrics",
                interactive=False
            )
            info = gr.Textbox(
                label="Generation Info",
                interactive=False
            )
            # Tips
            gr.Markdown("""
### 💡 Quick Tips:
- **Auto-transcription**: Leave reference text empty to auto-transcribe
- **Student Only**: Fastest (4 steps), good quality
- **Teacher-Guided**: Best balance (8 steps), recommended
- **High Diversity**: More natural prosody (16 steps)
- **Custom Mode**: Fine-tune all parameters
### 📊 Expected RTF (Real-Time Factor):
- Student Only: ~0.05x (20x faster than real-time)
- Teacher-Guided: ~0.10x (10x faster)
- High Diversity: ~0.20x (5x faster)
""")
    # Event handler: the order of `inputs` must match generate_speech's
    # positional parameters, and `outputs` its 4-tuple return value.
    generate_btn.click(
        generate_speech,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            temperature,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose
        ],
        outputs=[output_audio, status, metrics, info]
    )

    # Update visibility of custom settings based on mode
    def update_custom_visibility(selected_mode):
        """Show the custom-mode sliders only when "Custom" is selected."""
        return gr.update(visible=(selected_mode == "Custom"))

    mode.change(
        update_custom_visibility,
        inputs=[mode],
        outputs=[custom_settings]
    )
# Launch the app
if __name__ == "__main__":
    # Report startup failures in the server log. The UI launches regardless,
    # and generate_speech() reports a missing model to the user.
    if not model_loaded:
        print(f"Warning: Model failed to load - {status_message}")
    # initialize_asr_pipeline() stores None on failure. Test for that
    # explicitly rather than relying on the pipeline object's truthiness.
    if asr_pipe is None:
        print("Warning: ASR pipeline not available - auto-transcription disabled")
    demo.launch()