# Hugging Face page residue (not part of the program): uploaded by jordand,
# commit message "Update app.py", commit 46f334b (verified).
# The code in this file is almost entirely written by LLMs and is much, much, much messier than it needs to be (at this point it's not clear to what extent it is even human-modifiable). We'd hope to improve this for any future local gradio release(s).
import tempfile
import os
import json
import time
import secrets
import logging
from pathlib import Path
from typing import Tuple, Any
from functools import partial
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
import warnings
# Suppress torchaudio TorchCodec parameter warnings
warnings.filterwarnings('ignore', message='.*encoding.*parameter is not fully supported by TorchCodec')
warnings.filterwarnings('ignore', message='.*bits_per_sample.*parameter is not directly supported by TorchCodec')
warnings.filterwarnings('ignore', message='.* is not used by TorchCodec AudioEncoder. Format is determined by the file extension.')
import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import snapshot_download
import spaces
from inference import (
load_model_from_hf,
load_fish_ae_from_hf,
load_pca_state_from_hf,
load_audio,
ae_reconstruct,
sample_pipeline
)
from samplers import sample_euler_cfg_any, GuidanceMode
import tarfile
# --------------------------------------------------------------------
### Configuration
# dtype handed to load_model_from_hf() when the main model is loaded.
MODEL_DTYPE = torch.bfloat16
# dtype handed to load_fish_ae_from_hf() when fish_ae is loaded.
FISH_AE_DTYPE = torch.float32
# FISH_AE_DTYPE = torch.bfloat16  # Alternative: maybe slightly worse quality; prefer float32 if you have room
USE_16_BIT_WAV = True  # Save WAV files as 16-bit PCM instead of 32-bit float
# Audio Prompt Library for Custom Audio Panel (included in repo)
AUDIO_PROMPT_FOLDER = Path("./prompt_audio")
# If not on Zero GPU, compile fish_ae encoder/decoder on initialization
COMPILE_FISH_IF_NOT_ON_ZERO_GPU = True
# Silentcipher watermarking configuration
USE_SILENTCIPHER = True  # Enable/disable audio watermarking
SILENTCIPHER_MESSAGE = [91, 57, 81, 60, 83]  # Watermark message (list of integers)
SILENTCIPHER_SDR = 47  # Message SDR in dB (higher = less perceptible but less robust)
# Get HF token from environment for private model access (None when the variable is unset)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# --------------------------------------------------------------------
# Check if running on Zero GPU (compile incompatible with Zero GPU).
# Only the presence of SPACES_ZERO_GPU is tested; its value is ignored.
IS_ZEROGPU = os.environ.get("SPACES_ZERO_GPU") is not None
def _safe_members(tf, prefix):
    """Yield the members of *tf* that are safe to extract under *prefix*.

    A member is yielded only when all of the following hold:

    * its name starts with ``prefix`` (a trailing ``/`` is added if missing),
    * its name is not absolute and has no ``..`` component,
    * it is not a device node or FIFO,
    * if it is a symlink or hard link, its target is not absolute and has no
      ``..`` component. This is deliberately conservative: a link that points
      outside the tree would let a later member be written through it.

    Args:
        tf: An open ``tarfile.TarFile`` (anything exposing ``getmembers()``).
        prefix: Top-level directory name inside the archive.

    Yields:
        ``tarfile.TarInfo`` objects that passed every check.
    """
    if not prefix.endswith('/'):
        prefix += '/'
    for member in tf.getmembers():
        if not member.name.startswith(prefix):
            continue
        member_path = Path(member.name)
        if member_path.is_absolute() or '..' in member_path.parts:
            continue
        # Character/block devices and FIFOs have no place in a data archive.
        if member.isdev():
            continue
        if member.issym() or member.islnk():
            link_target = Path(member.linkname)
            if link_target.is_absolute() or '..' in link_target.parts:
                continue
        yield member
def ensure_tar_tree(repo_id: str, root: str, *, token: str | None = None, max_workers: int = 4):
    """Download a dataset snapshot and make sure ``<root>/`` is unpacked in it.

    The snapshot is restricted to ``<root>.tar`` plus a few metadata files.
    If ``<snapshot>/<root>`` already exists it is returned as-is; otherwise
    the tar is extracted into the snapshot directory, keeping only members
    accepted by ``_safe_members``.

    Args:
        repo_id: Hugging Face dataset repository id.
        root: Name of both the tar file (``<root>.tar``) and the directory it
            is expected to contain.
        token: Optional access token forwarded to ``snapshot_download``.
        max_workers: Forwarded to ``snapshot_download``.

    Returns:
        ``Path`` of the extracted ``<root>`` directory inside the snapshot.

    Raises:
        FileNotFoundError: The snapshot holds neither ``<root>/`` nor ``<root>.tar``.
    """
    os.environ.setdefault('HF_HUB_ENABLE_HF_TRANSFER', '1')
    from huggingface_hub import snapshot_download

    wanted_files = [f'{root}.tar', 'index.jsonl', 'README.md', 'LICENSE']
    snapshot_dir = Path(
        snapshot_download(
            repo_id=repo_id,
            repo_type='dataset',
            allow_patterns=wanted_files,
            resume_download=True,
            token=token,
            max_workers=max_workers,
        )
    )

    extracted_dir = snapshot_dir / root
    if not extracted_dir.exists():
        archive = snapshot_dir / f'{root}.tar'
        if not archive.exists():
            raise FileNotFoundError(f'Expected {archive} in snapshot')
        with tarfile.open(archive, 'r') as tf:
            tf.extractall(snapshot_dir, members=_safe_members(tf, root))
    return extracted_dir
# Speaker-embedding libraries. Each call downloads (or reuses) a dataset
# snapshot and unpacks its tar on first use; this runs at import time.
EARS_PATH = ensure_tar_tree(
    repo_id="jordand/echo-embeddings-ears-tar",
    root="EARS",
    token=HF_TOKEN,
)
VCTK_PATH = ensure_tar_tree(
    repo_id="jordand/echo-embeddings-vctk-tar",
    root="VCTK",
    token=HF_TOKEN,
)
EXPRESSO_PATH = ensure_tar_tree(
    repo_id="jordand/echo-embeddings-expresso-tar",
    root="Expresso",
    token=HF_TOKEN,
)

from huggingface_hub import snapshot_download

# The custom set is not tarred: only the three per-voice files are fetched.
_HF_CUSTOM_SNAPSHOT_DIR = snapshot_download(
    repo_id="jordand/echo-embeddings-custom",
    repo_type="dataset",
    allow_patterns=[
        "HF-Custom/**/speaker_latent.safetensors",
        "HF-Custom/**/metadata.json",
        "HF-Custom/**/audio.mp3",
    ],
    token=HF_TOKEN,
)
HF_CUSTOM_PATH = Path(_HF_CUSTOM_SNAPSHOT_DIR + "/HF-Custom")

# Scratch directory for generated/reference audio handed back to the UI.
TEMP_AUDIO_DIR = Path('./temp_gradio_audio')
TEMP_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
# Helper functions for unique filenames and cleanup
def make_stem(prefix: str, user_id: str | None = None) -> str:
"""Create unique filename stem: prefix__user__timestamp_random or prefix__timestamp_random if no user_id."""
ts = int(time.time() * 1000)
rand = secrets.token_hex(4)
if user_id:
return f"{prefix}__{user_id}__{ts}_{rand}"
return f"{prefix}__{ts}_{rand}"
def cleanup_temp_audio(dir_: Path, user_id: str | None, max_age_sec: int = 60 * 5):
"""Remove old files globally and all previous files for this user."""
now = time.time()
# 1) Global TTL: remove any file older than max_age_sec
for p in dir_.glob("*"):
try:
if p.is_file() and (now - p.stat().st_mtime) > max_age_sec:
p.unlink(missing_ok=True)
except Exception:
pass
# 2) Per-user: remove ALL previous files for this user (we don't need to keep any)
if user_id:
for p in dir_.glob(f"*__{user_id}__*"):
try:
if p.is_file():
p.unlink(missing_ok=True)
except Exception:
pass
# Preset files expected next to the app (relative to the working directory).
TEXT_PRESETS_PATH = Path('./text_presets.txt')
SAMPLER_PRESETS_PATH = Path('./sampler_presets.json')
# Global model variables (loaded lazily for Zero GPU); populated by load_models().
model = None
model_compiled = None  # Separate compiled model for toggling; set by do_compile()
fish_ae = None
pca_state = None
silentcipher_model = None  # Silentcipher watermarking model (stays None if loading fails)
# True once do_compile() has finished building model_compiled.
_model_compiled = False
def load_models():
    """Lazy load models on first use (required for Zero GPU).

    Does nothing once ``model`` is set. Otherwise it loads the main model,
    ``fish_ae`` and ``pca_state``, and (when ``USE_SILENTCIPHER`` is on) tries
    to load the silentcipher watermarking model.

    The three core objects are loaded into locals and published together,
    with ``model`` assigned last. ``model is not None`` is the "already
    loaded" flag, so a failure part-way through leaves it unset and the next
    call retries instead of running with ``fish_ae``/``pca_state`` as ``None``.

    A silentcipher failure is reported and otherwise ignored; generation then
    continues without watermarking.
    """
    global model, model_compiled, fish_ae, pca_state, silentcipher_model
    if model is not None:
        return

    loaded_model = load_model_from_hf(dtype=MODEL_DTYPE, compile=False, token=HF_TOKEN)
    loaded_fish_ae = load_fish_ae_from_hf(
        compile=(COMPILE_FISH_IF_NOT_ON_ZERO_GPU and not IS_ZEROGPU),
        dtype=FISH_AE_DTYPE,
        token=HF_TOKEN,
    )
    loaded_pca_state = load_pca_state_from_hf(token=HF_TOKEN)

    # Publish only after every load succeeded; `model` goes last on purpose.
    fish_ae = loaded_fish_ae
    pca_state = loaded_pca_state
    model = loaded_model

    # Load silentcipher model if enabled
    if USE_SILENTCIPHER:
        try:
            import silentcipher
            silentcipher_model = silentcipher.get_model(model_type='44.1k', device='cuda')
        except Exception as e:
            print(f"Warning: Failed to load silentcipher model: {e}")
            print("Continuing without watermarking...")
def compile_model(should_compile):
    """Return UI updates for the "compile" checkbox and its status label.

    No compilation happens here; this only decides what the checkbox and the
    status text should show for the requested state.

    Returns:
        ``(checkbox_update, status_update)`` as ``gr.update`` objects.
    """
    if IS_ZEROGPU:
        # Compilation is not supported on Zero GPU: force off and lock.
        checked, clickable, message = False, False, "⚠️ Compile disabled on Zero GPU"
    elif not should_compile:
        # User unchecked the box: clear the status and keep it toggleable.
        checked, clickable, message = False, True, ""
    elif _model_compiled:
        checked, clickable, message = True, True, "✓ Model already compiled"
    else:
        # Compilation is about to start: lock the checkbox meanwhile.
        checked, clickable, message = True, False, "⏳ Compiling... (1-3 minutes)"

    return (
        gr.update(value=checked, interactive=clickable),
        gr.update(value=message, visible=bool(message)),
    )
def do_compile():
    """Build the compiled copy of the model used when "compile" is enabled.

    On Zero GPU this is a no-op that reports compilation as disabled. If a
    compiled model already exists nothing is rebuilt. Otherwise the models
    are loaded (if needed) and ``torch.compile`` is applied to the model and
    to its two KV-cache methods.

    The compiled object is assembled in a local and published to
    ``model_compiled`` only after every step succeeded, so a failure cannot
    leave a half-patched ``model_compiled`` for ``generate_audio`` to use.

    Returns:
        ``(status_update, checkbox_update)`` as ``gr.update`` objects.
    """
    global model, model_compiled, _model_compiled

    # Skip if on Zero GPU
    if IS_ZEROGPU:
        return gr.update(value="⚠️ Compile disabled on Zero GPU", visible=True), gr.update(interactive=False)

    if not _model_compiled:
        try:
            # Zero GPU never reaches this point, so eager loading is fine here.
            load_models()
            compiled = torch.compile(model)
            compiled.get_kv_cache = torch.compile(model.get_kv_cache)
            compiled.get_kv_cache_from_precomputed_speaker_state = torch.compile(
                model.get_kv_cache_from_precomputed_speaker_state
            )
        except Exception as e:
            print(f"Compilation failed: {str(e)}")
            return gr.update(value=f"✗ Compilation failed: {str(e)}", visible=True), gr.update(interactive=True)
        model_compiled = compiled
        _model_compiled = True

    return gr.update(value="", visible=False), gr.update(interactive=True)
def save_audio_with_format(audio_tensor: torch.Tensor, base_path: Path, filename: str, sample_rate: int, audio_format: str) -> Path:
    """Write *audio_tensor* to ``base_path/filename.<ext>`` and return the path.

    ``audio_format == "mp3"`` tries MP3 first and falls back to WAV if the
    encoder raises. Any other value writes WAV directly. WAV output is 16-bit
    PCM when ``USE_16_BIT_WAV`` is set, otherwise torchaudio's default.
    """
    if audio_format == "mp3":
        try:
            mp3_path = base_path / f"{filename}.mp3"
            torchaudio.save(
                str(mp3_path),
                audio_tensor,
                sample_rate,
                format="mp3",
                encoding="mp3",
                bits_per_sample=None,
            )
            return mp3_path
        except Exception as exc:
            print(f"MP3 encoding failed: {exc}, falling back to WAV")

    # WAV: either requested explicitly or reached as the MP3 fallback.
    wav_path = base_path / f"{filename}.wav"
    wav_options = {"encoding": "PCM_S", "bits_per_sample": 16} if USE_16_BIT_WAV else {}
    torchaudio.save(str(wav_path), audio_tensor, sample_rate, **wav_options)
    return wav_path
def _optional_cast(value, cast):
    """Return ``cast(value)``, or ``None`` when *value* is ``None``."""
    return None if value is None else cast(value)


def _optional_cast_text(text, cast):
    """Return ``cast(text)``, or ``None`` when *text* is ``None`` or blank."""
    if text is None or not str(text).strip():
        return None
    return cast(text)


@spaces.GPU
def generate_audio(
    text_prompt: str,
    speaker_st_path: str,
    speaker_audio_path: str,
    # Sampling parameters
    num_steps: int,
    rng_seed: int,
    cfg_mode: str,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float,
    rescale_k: float,
    rescale_sigma: float,
    speaker_k_enable: bool,
    speaker_k_scale: float,
    speaker_k_min_t: float,
    speaker_k_max_layers: int,
    apg_eta_text: float,
    apg_eta_speaker: float,
    apg_momentum_text: float,
    apg_momentum_speaker: float,
    apg_norm_text: str,
    apg_norm_speaker: str,
    reconstruct_first_30_seconds: bool,
    use_custom_shapes: bool,
    max_text_byte_length: str,
    max_speaker_latent_length: str,
    sample_latent_len: str,
    audio_format: str,
    use_compile: bool,
    show_original_audio: bool,
    session_id: str,
) -> Tuple[Any, ...]:
    """Generate speech for *text_prompt* and return updates for the result panel.

    Steps, in order: load models if needed, pick the compiled or plain model,
    clean old temp files for this session, parse the UI values, run
    ``sample_pipeline``, optionally watermark the result with silentcipher,
    save it, and optionally save a reconstruction and/or copy of the speaker
    reference audio.

    Args:
        text_prompt: Text to synthesize; also echoed back in the UI.
        speaker_st_path: Accepted for the UI wiring but not read here.
        speaker_audio_path: Reference audio file; empty/``None`` runs without
            a speaker reference.
        num_steps: Sampler step count, clamped to ``[1, 80]``.
        rng_seed: Seed; ``None`` is treated as ``0``.
        cfg_mode: ``"independent"``, ``"alternating"`` or ``"apg-independent"``;
            any other value selects ``GuidanceMode.JOINT``, for which the
            speaker scale is forced to ``None``.
        cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t,
        truncation_factor, rescale_sigma: Forwarded to ``sample_euler_cfg_any``.
        rescale_k: Forwarded, except that ``1.0`` (and ``None``) means "off"
            and is passed as ``None``.
        speaker_k_enable: When false the three ``speaker_k_*`` values are
            passed as ``None``.
        apg_*: Only used in ``"apg-independent"`` mode; the two ``apg_norm_*``
            fields are text, and blank means ``None``.
        reconstruct_first_30_seconds: Also save an ``ae_reconstruct`` pass of
            the start of the reference audio (at most ``2048 * 640`` samples).
        use_custom_shapes: Use the three text fields below instead of the
            defaults 768 / 2560 / 640.
        max_text_byte_length, max_speaker_latent_length: Blank means ``None``.
        sample_latent_len: Blank means 640; passed as ``block_size``.
        audio_format: ``"mp3"`` or anything else for WAV.
        use_compile: Prefer ``model_compiled`` when it exists.
        show_original_audio: Also save the loaded reference audio.
        session_id: Used in temp-file names and for per-session cleanup.

    Returns:
        Nine ``gr.update`` objects: a no-op update, generated audio, text
        display, original audio, timing text, reconstructed audio, and the
        visibility of the original accordion, the reference accordion and the
        reference header.
    """
    # Load models on first use (required for Zero GPU)
    load_models()

    # Choose which model to use based on compile setting
    active_model = model_compiled if (use_compile and model_compiled is not None) else model
    if use_compile and model_compiled is None:
        print("Warning: Compile requested but model not yet compiled. Using uncompiled model.")

    # Cleanup old temp files globally and remove ALL previous files for this user
    cleanup_temp_audio(TEMP_AUDIO_DIR, session_id)

    # The speaker reference is optional: normalize "" / None to None.
    if not speaker_audio_path:
        speaker_audio_path = None
    has_speaker = speaker_audio_path is not None

    start_time = time.time()
    sample_rate = 44100

    # --- Parse sampler parameters ---------------------------------------
    num_steps_int = min(max(int(num_steps), 1), 80)  # Clamp to [1, 80]
    rng_seed_int = int(rng_seed) if rng_seed is not None else 0
    cfg_scale_text_val = float(cfg_scale_text)
    cfg_min_t_val = float(cfg_min_t)
    cfg_max_t_val = float(cfg_max_t)
    truncation_factor_val = float(truncation_factor)
    # 1.0 means "off"; a cleared field (None) is treated the same way.
    rescale_k_val = None if rescale_k is None or rescale_k == 1.0 else float(rescale_k)
    rescale_sigma_val = float(rescale_sigma)

    # --- Guidance mode ---------------------------------------------------
    named_modes = {
        "independent": GuidanceMode.INDEPENDENT,
        "alternating": GuidanceMode.ALTERNATING,
        "apg-independent": GuidanceMode.APG,
    }
    guidance_mode = named_modes.get(cfg_mode, GuidanceMode.JOINT)
    is_joint = cfg_mode not in named_modes
    is_apg = cfg_mode == "apg-independent"

    # For joint-unconditional guidance the speaker scale must be None.
    cfg_scale_speaker_val = None if is_joint else _optional_cast(cfg_scale_speaker, float)

    # APG parameters only apply in APG mode.
    apg_eta_text_val = _optional_cast(apg_eta_text, float) if is_apg else None
    apg_eta_speaker_val = _optional_cast(apg_eta_speaker, float) if is_apg else None
    apg_momentum_text_val = _optional_cast(apg_momentum_text, float) if is_apg else None
    apg_momentum_speaker_val = _optional_cast(apg_momentum_speaker, float) if is_apg else None
    apg_norm_text_val = _optional_cast_text(apg_norm_text, float) if is_apg else None
    apg_norm_speaker_val = _optional_cast_text(apg_norm_speaker, float) if is_apg else None

    # --- Speaker K scaling (available for all modes) ----------------------
    if speaker_k_enable:
        speaker_k_scale_val = _optional_cast(speaker_k_scale, float)
        speaker_k_min_t_val = _optional_cast(speaker_k_min_t, float)
        speaker_k_max_layers_val = _optional_cast(speaker_k_max_layers, int)
    else:
        speaker_k_scale_val = None
        speaker_k_min_t_val = None
        speaker_k_max_layers_val = None

    # --- Shapes ------------------------------------------------------------
    if use_custom_shapes:
        # Blank values for the first two fields mean None.
        pad_to_max_text_seq_len = _optional_cast_text(max_text_byte_length, int)
        pad_to_max_speaker_latent_len = _optional_cast_text(max_speaker_latent_length, int)
        sample_latent_len_val = _optional_cast_text(sample_latent_len, int)
        if sample_latent_len_val is None:
            sample_latent_len_val = 640
    else:
        pad_to_max_text_seq_len = 768
        pad_to_max_speaker_latent_len = 2560
        sample_latent_len_val = 640

    # Create sample function with parameters
    sample_fn = partial(
        sample_euler_cfg_any,
        num_steps=num_steps_int,
        guidance_mode=guidance_mode,
        cfg_scale_text=cfg_scale_text_val,
        cfg_scale_speaker=cfg_scale_speaker_val,
        cfg_min_t=cfg_min_t_val,
        cfg_max_t=cfg_max_t_val,
        truncation_factor=truncation_factor_val,
        rescale_k=rescale_k_val,
        rescale_sigma=rescale_sigma_val,
        speaker_k_scale=speaker_k_scale_val,
        speaker_k_min_t=speaker_k_min_t_val,
        speaker_k_max_layers=speaker_k_max_layers_val,
        apg_eta_text=apg_eta_text_val,
        apg_eta_speaker=apg_eta_speaker_val,
        apg_momentum_text=apg_momentum_text_val,
        apg_momentum_speaker=apg_momentum_speaker_val,
        apg_norm_text=apg_norm_text_val,
        apg_norm_speaker=apg_norm_speaker_val,
        block_size=sample_latent_len_val,
    )

    # Load speaker audio if provided
    speaker_audio = load_audio(speaker_audio_path).cuda() if has_speaker else None

    # Generate audio using raw audio (with selected model - compiled or not)
    audio_out = sample_pipeline(
        model=active_model,
        fish_ae=fish_ae,
        pca_state=pca_state,
        sample_fn=sample_fn,
        text_prompt=text_prompt,
        speaker_audio=speaker_audio,
        rng_seed=rng_seed_int,
        pad_to_max_text_seq_len=pad_to_max_text_seq_len,
        pad_to_max_speaker_latent_len=pad_to_max_speaker_latent_len,
    )

    # Apply silentcipher watermarking if enabled; on failure save unmarked audio.
    audio_to_save = audio_out[0].cpu()
    if USE_SILENTCIPHER and silentcipher_model is not None:
        try:
            audio_numpy = audio_to_save.squeeze(0).numpy()
            encoded_audio, sdr = silentcipher_model.encode_wav(
                audio_numpy,
                sample_rate,
                SILENTCIPHER_MESSAGE,
                message_sdr=SILENTCIPHER_SDR,
            )
            audio_to_save = torch.tensor(encoded_audio).unsqueeze(0)
        except Exception as e:
            print(f"Warning: Watermarking failed: {e}")
            print("Saving audio without watermark...")

    # Save generated audio with format selection (unique filename per session)
    output_path = save_audio_with_format(
        audio_to_save,
        TEMP_AUDIO_DIR,
        make_stem("generated", session_id),
        sample_rate,
        audio_format,
    )

    # Timing covers everything from parameter parsing through saving the result.
    generation_time = time.time() - start_time
    time_str = f"⏱️ Total generation time: {generation_time:.2f}s"

    # Format text prompt for display
    text_display = f"**Text Prompt:**\n\n{text_prompt}"

    recon_output_path = None
    original_output_path = None

    # Optionally reconstruct the start of the reference audio.
    if reconstruct_first_30_seconds and has_speaker:
        max_ref_samples = 2048 * 640
        ref_len = speaker_audio.shape[-1]
        # Crop to max_ref_samples, zero-pad shorter clips up to it, then trim
        # the reconstruction back to the reference length.
        padded_ref = torch.nn.functional.pad(
            speaker_audio[..., :max_ref_samples],
            (0, max(0, max_ref_samples - ref_len)),
        )
        audio_recon = ae_reconstruct(
            fish_ae=fish_ae,
            pca_state=pca_state,
            audio=padded_ref[None],
        )[..., :ref_len]
        recon_output_path = save_audio_with_format(
            audio_recon.cpu()[0],
            TEMP_AUDIO_DIR,
            make_stem("speaker_recon", session_id),
            sample_rate,
            audio_format,
        )

    # Optionally save the reference audio exactly as load_audio() returned it.
    if show_original_audio and has_speaker:
        original_output_path = save_audio_with_format(
            speaker_audio.cpu(),
            TEMP_AUDIO_DIR,
            make_stem("original_audio", session_id),
            sample_rate,
            audio_format,
        )

    # Return results with visibility control for accordions
    show_reference_section = (show_original_audio or reconstruct_first_30_seconds) and has_speaker
    return (
        gr.update(),
        gr.update(value=str(output_path), visible=True),
        gr.update(value=text_display, visible=True),
        gr.update(value=str(original_output_path) if original_output_path else None, visible=True),
        gr.update(value=time_str, visible=True),
        gr.update(value=str(recon_output_path) if recon_output_path else None, visible=True),
        gr.update(visible=(show_original_audio and has_speaker)),  # original_accordion visibility
        gr.update(visible=(reconstruct_first_30_seconds and has_speaker)),  # reference_accordion visibility
        gr.update(visible=show_reference_section),  # reference_audio_header visibility
    )
# UI Helper Functions
def load_speaker_metadata(speaker_id):
    """Return the ``speaker_metadata`` dict of the first readable EARS folder for a speaker.

    Folders named ``<speaker_id>_*`` under ``EARS_PATH`` are tried in
    directory order; folders with no ``metadata.json`` or an unreadable one
    are skipped. Returns ``{}`` when the file has no ``speaker_metadata`` key
    and ``None`` when nothing usable is found.
    """
    if not EARS_PATH.exists():
        return None

    folder_prefix = f"{speaker_id}_"
    for entry in EARS_PATH.iterdir():
        if not (entry.is_dir() and entry.name.startswith(folder_prefix)):
            continue
        meta_file = entry / "metadata.json"
        if not meta_file.exists():
            continue
        try:
            with open(meta_file, 'r') as fh:
                payload = json.load(fh)
                return payload.get("speaker_metadata", {})
        except Exception:
            # Unreadable metadata: try the speaker's next folder.
            continue
    return None
def get_speakers():
    """List the unique EARS speakers with basic metadata.

    Speaker ids are the part before the first ``_`` of every directory under
    ``EARS_PATH`` whose name starts with ``p`` and contains ``_``. Each
    result is a dict with ``id``, ``gender``, ``age``, ``ethnicity`` and
    ``native_language``; missing values become ``'unknown'``. Results are
    sorted by id.
    """
    if not EARS_PATH.exists():
        return []

    speaker_ids = set()
    for entry in sorted(EARS_PATH.iterdir()):
        if not entry.is_dir():
            continue
        folder = entry.name
        if folder.startswith('p') and '_' in folder:
            speaker_ids.add(folder.split('_')[0])

    # (output key, key inside speaker_metadata)
    field_map = (
        ('gender', 'gender'),
        ('age', 'age'),
        ('ethnicity', 'ethnicity'),
        ('native_language', 'native language'),
    )

    speakers = []
    for sid in sorted(speaker_ids):
        # None and {} both mean "no metadata": every field becomes 'unknown'.
        meta = load_speaker_metadata(sid) or {}
        record = {'id': sid}
        for out_key, meta_key in field_map:
            record[out_key] = meta.get(meta_key, 'unknown')
        speakers.append(record)
    return speakers
def get_speakers_table(search_query=""):
    """Return EARS speakers as table rows, optionally filtered.

    Each row is ``[id, gender, age, ethnicity, native_language]`` with gender
    abbreviated to ``M``/``F`` (or the upper-cased first letter, ``?`` when
    empty). A non-empty *search_query* keeps rows whose space-joined text
    contains it, case-insensitively.
    """
    rows = []
    for spk in get_speakers():
        raw_gender = spk['gender']
        gender_code = {'male': 'M', 'female': 'F'}.get(raw_gender.lower())
        if gender_code is None:
            gender_code = raw_gender[0].upper() if raw_gender else '?'

        if search_query:
            haystack = f"{spk['id']} {gender_code} {spk['age']} {spk['ethnicity']} {spk['native_language']}".lower()
            if search_query.lower() not in haystack:
                continue

        rows.append([spk['id'], gender_code, spk['age'], spk['ethnicity'], spk['native_language']])
    return rows
def get_audio_length_from_metadata(voice_dir):
    """Return ``audio_length_seconds`` from ``voice_dir/metadata.json`` as ``"12.3s"``.

    Returns ``"N/A"`` when the file is missing, cannot be parsed, or holds a
    value that cannot be formatted as a number. A missing key counts as 0.
    """
    meta_file = voice_dir / "metadata.json"
    if not meta_file.exists():
        return "N/A"
    try:
        with open(meta_file, 'r') as fh:
            seconds = json.load(fh).get("audio_length_seconds", 0)
            return f"{seconds:.1f}s"
    except Exception:
        return "N/A"
def get_freeform_table(speaker_id):
    """Return ``[["Freeform", <length>]]`` if the speaker has a freeform voice, else ``[]``.

    The voice counts as present when ``<speaker_id>_freeform`` under
    ``EARS_PATH`` contains both ``audio.mp3`` and ``speaker_latent.safetensors``.
    """
    if not EARS_PATH.exists() or not speaker_id:
        return []

    voice_dir = EARS_PATH / f"{speaker_id}_freeform"
    if not voice_dir.exists():
        return []

    required_files = (voice_dir / "audio.mp3", voice_dir / "speaker_latent.safetensors")
    if all(f.exists() for f in required_files):
        return [["Freeform", get_audio_length_from_metadata(voice_dir)]]
    return []
def get_emotions_for_speaker(speaker_id):
    """Return ``(emotion, audio_length)`` pairs available for a speaker.

    Scans ``EARS_PATH`` in sorted order for directories named
    ``<speaker_id>_emo_<emotion>``, skipping names containing ``_joint_`` and
    names where ``_emo_`` occurs more than once. A folder counts only if it
    has both ``audio.mp3`` and ``speaker_latent.safetensors``.
    """
    if not EARS_PATH.exists() or not speaker_id:
        return []

    wanted_prefix = f"{speaker_id}_emo_"
    found = []
    for entry in sorted(EARS_PATH.iterdir()):
        if not entry.is_dir():
            continue
        folder = entry.name
        if not folder.startswith(wanted_prefix) or "_joint_" in folder:
            continue
        pieces = folder.split('_emo_')
        if len(pieces) != 2:
            continue
        has_audio = (entry / "audio.mp3").exists()
        if has_audio and (entry / "speaker_latent.safetensors").exists():
            found.append((pieces[1], get_audio_length_from_metadata(entry)))
    return found
def get_emotions_table(speaker_id):
    """Return the speaker's emotions as ``[emotion, length]`` table rows (``[]`` if no speaker)."""
    if not speaker_id:
        return []
    rows = []
    for emotion, length in get_emotions_for_speaker(speaker_id):
        rows.append([emotion, length])
    return rows
# VCTK Helper Functions
def get_vctk_speakers():
    """List VCTK speakers that have audio, latent and metadata files.

    Every directory under ``VCTK_PATH`` whose name starts with ``p`` and that
    contains ``audio.mp3``, ``speaker_latent.safetensors`` and
    ``metadata.json`` yields a dict with ``id``, ``gender``, ``age``,
    ``details`` and ``audio_length``. Folders whose metadata cannot be read
    or formatted are skipped.
    """
    if not VCTK_PATH.exists():
        return []

    collected = []
    for entry in sorted(VCTK_PATH.iterdir()):
        if not (entry.is_dir() and entry.name.startswith('p')):
            continue
        folder_id = entry.name
        meta_file = entry / "metadata.json"
        required = (entry / "audio.mp3", entry / "speaker_latent.safetensors", meta_file)
        if not all(f.exists() for f in required):
            continue
        try:
            with open(meta_file, 'r') as fh:
                payload = json.load(fh)
            info = payload.get("speaker_info", {})
            seconds = payload.get("total_audio_length_seconds", 0)
            collected.append({
                'id': info.get('id', folder_id),
                'gender': info.get('gender', 'unknown'),
                'age': info.get('age', 'unknown'),
                'details': info.get('details', 'unknown'),
                'audio_length': f"{seconds:.1f}s",
            })
        except Exception:
            continue
    return collected
def get_vctk_speakers_table(search_query=""):
    """Return VCTK speakers as table rows, optionally filtered.

    Each row is ``[id, gender, age, details, audio_length]`` with gender
    normalized to ``M``/``F`` (or the upper-cased first letter, ``?`` when
    empty). A non-empty *search_query* keeps rows whose space-joined text
    contains it, case-insensitively.
    """
    rows = []
    for spk in get_vctk_speakers():
        raw_gender = spk['gender']
        lowered = raw_gender.lower()
        if lowered == 'male' or raw_gender == 'M':
            gender_code = 'M'
        elif lowered == 'female' or raw_gender == 'F':
            gender_code = 'F'
        elif raw_gender:
            gender_code = raw_gender[0].upper()
        else:
            gender_code = '?'

        if search_query:
            haystack = f"{spk['id']} {gender_code} {spk['age']} {spk['details']} {spk['audio_length']}".lower()
            if search_query.lower() not in haystack:
                continue

        rows.append([spk['id'], gender_code, spk['age'], spk['details'], spk['audio_length']])
    return rows
def load_text_presets():
    """Read ``TEXT_PRESETS_PATH`` into ``[category, word_count, text]`` rows.

    Blank lines are ignored. A line is split on its first ``" | "`` into
    category and text; lines without that separator get the category
    ``"Uncategorized"``. The word count is returned as a string. Returns
    ``[]`` when the file does not exist.
    """
    if not TEXT_PRESETS_PATH.exists():
        return []

    with open(TEXT_PRESETS_PATH, 'r', encoding='utf-8') as fh:
        entries = [raw.strip() for raw in fh if raw.strip()]

    table = []
    for entry in entries:
        category, separator, text = entry.partition(" | ")
        if not separator:
            # No category prefix on this line.
            category, text = "Uncategorized", entry
        table.append([category, str(len(text.split())), text])
    return table
def search_speakers(search_query):
    """Return a table update holding the EARS speakers that match *search_query*."""
    return gr.update(value=get_speakers_table(search_query))
def select_speaker_from_table(evt: gr.SelectData, table_data):
    """Handle an EARS speaker selection: fill the freeform and emotions tables.

    Args:
        evt: Selection event; ``evt.index`` is ``(row, col)`` or a bare row.
        table_data: The table as currently displayed (possibly filtered), so
            the row index refers to what the user actually clicked. It is
            accessed through ``.iloc``, i.e. as a pandas DataFrame.

    Returns:
        Eight ``gr.update`` objects: speaker summary, freeform table, emotions
        table, stored speaker id, audio preview, safetensors path, audio path
        and voice-selection display. The last four are always cleared; without
        a valid selection the summary is hidden and both tables are emptied.
    """
    if evt.value and table_data is not None:
        # evt.index is a tuple/list (row, col); only the row is needed.
        if isinstance(evt.index, (tuple, list)) and len(evt.index) >= 2:
            row_index = evt.index[0]
        else:
            row_index = evt.index

        if isinstance(row_index, int) and row_index < len(table_data):
            speaker_row = table_data.iloc[row_index]
            speaker_id = speaker_row.iloc[0]  # First column is the ID
            gender_code = speaker_row.iloc[1]
            gender_full = {"M": "Male", "F": "Female"}.get(gender_code, gender_code)
            # Separate the fields with " • " (as the Expresso summary does);
            # previously they were run together, e.g. "Male25...".
            selection_text = (
                f"Selected Speaker: {speaker_id}\n"
                f"{gender_full} • {speaker_row.iloc[2]} • {speaker_row.iloc[3]}"
            )

            freeform_data = get_freeform_table(speaker_id)
            emotions_data = get_emotions_table(speaker_id)
            return (
                gr.update(value=selection_text, visible=True),  # Show speaker selection
                gr.update(value=freeform_data, visible=True),  # Update freeform table
                gr.update(value=emotions_data, visible=True),  # Update emotions table
                gr.update(value=speaker_id),  # Store speaker ID
                gr.update(value=None),  # Clear audio preview
                gr.update(value=""),  # Clear safetensors path
                gr.update(value=""),  # Clear audio path
                gr.update(value="", visible=False),  # Clear voice selection display
            )

    return (
        gr.update(value="", visible=False),
        gr.update(value=[], visible=True),
        gr.update(value=[], visible=True),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=""),
        gr.update(value="", visible=False),
    )
def select_freeform_from_table(evt: gr.SelectData, speaker_id: str):
    """Handle a click on the freeform table: load the speaker's freeform voice.

    Returns four ``gr.update`` objects (selection label, audio player,
    safetensors path, audio path). They are cleared when no speaker is
    stored or the ``<speaker_id>_freeform`` folder does not exist.
    """
    if speaker_id:
        voice_dir = EARS_PATH / f"{speaker_id}_freeform"
        audio_file = str(voice_dir / "audio.mp3")
        latent_file = str(voice_dir / "speaker_latent.safetensors")
        if voice_dir.exists():
            label = f"Selected: Freeform\n{speaker_id}_freeform"
            return (
                gr.update(value=label, visible=True),  # selection label
                gr.update(value=audio_file),  # audio player
                gr.update(value=latent_file),  # safetensors path
                gr.update(value=audio_file),  # audio path for reconstruction
            )

    return (
        gr.update(value="", visible=False),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=""),
    )
def select_emotion_from_table(evt: gr.SelectData, speaker_id: str):
    """Handle a click on the emotions table: load that emotion's voice files.

    The emotion name comes from the clicked row of ``get_emotions_table``.
    When ``evt.index`` is not a ``(row, col)`` pair it is used as the row
    index itself, as the other selection handlers do, instead of silently
    falling back to the first row.

    Returns:
        Four ``gr.update`` objects (selection label, audio player, safetensors
        path, audio path); cleared when nothing valid was selected.
    """
    if evt.value and speaker_id:
        # evt.index is (row, col); the row picks the emotion.
        if isinstance(evt.index, (tuple, list)) and len(evt.index) >= 2:
            row_index = evt.index[0]
        else:
            row_index = evt.index

        emotions_data = get_emotions_table(speaker_id)
        if isinstance(row_index, int) and row_index < len(emotions_data):
            emotion = emotions_data[row_index][0]  # First column is emotion name
            voice_name = f"{speaker_id}_emo_{emotion}"
            voice_dir = EARS_PATH / voice_name
            audio_path = str(voice_dir / "audio.mp3")
            st_path = str(voice_dir / "speaker_latent.safetensors")
            if voice_dir.exists():
                emotion_display = f"Selected Emotion: {emotion.title()}\n{speaker_id}_emo_{emotion}"
                return (
                    gr.update(value=emotion_display, visible=True),  # Show emotion selection
                    gr.update(value=audio_path),  # Update audio player
                    gr.update(value=st_path),  # Update safetensors path
                    gr.update(value=audio_path),  # Update audio path for reconstruction
                )

    return gr.update(value="", visible=False), gr.update(value=None), gr.update(value=""), gr.update(value="")
def select_vctk_speaker_from_table(evt: gr.SelectData, table_data):
    """Handle a VCTK speaker selection: load that speaker's voice files.

    Args:
        evt: Selection event; ``evt.index`` is ``(row, col)`` or a bare row.
        table_data: The table as currently displayed (possibly filtered),
            accessed through ``.iloc``, i.e. as a pandas DataFrame.

    Returns:
        Five ``gr.update`` objects (speaker summary, stored speaker id, audio
        player, safetensors path, audio path); cleared when the selection is
        invalid or the speaker folder is missing under ``VCTK_PATH``.
    """
    if evt.value and table_data is not None:
        # evt.index is a tuple/list (row, col); only the row is needed.
        if isinstance(evt.index, (tuple, list)) and len(evt.index) >= 2:
            row_index = evt.index[0]
        else:
            row_index = evt.index

        if isinstance(row_index, int) and row_index < len(table_data):
            speaker_row = table_data.iloc[row_index]
            speaker_id = speaker_row.iloc[0]  # First column is the ID

            voice_dir = VCTK_PATH / speaker_id
            audio_path = str(voice_dir / "audio.mp3")
            st_path = str(voice_dir / "speaker_latent.safetensors")
            if voice_dir.exists():
                gender_code = speaker_row.iloc[1]
                gender_full = {"M": "Male", "F": "Female"}.get(gender_code, gender_code)
                # Separate the fields with " • " (as the Expresso summary
                # does); previously they were run together.
                selection_text = (
                    f"Selected Speaker: {speaker_id}\n"
                    f"{gender_full} • {speaker_row.iloc[2]} • {speaker_row.iloc[3]}"
                )
                return (
                    gr.update(value=selection_text, visible=True),  # Show speaker selection
                    gr.update(value=speaker_id),  # Store speaker ID
                    gr.update(value=audio_path),  # Update audio player
                    gr.update(value=st_path),  # Update safetensors path
                    gr.update(value=audio_path),  # Update audio path for reconstruction
                )

    return (
        gr.update(value="", visible=False),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=""),
    )
def search_vctk_speakers(search_query):
    """Return a table update holding the VCTK speakers that match *search_query*."""
    return gr.update(value=get_vctk_speakers_table(search_query))
# Expresso Helper Functions
def get_expresso_speakers():
    """List Expresso voices that have audio, latent and metadata files.

    Every directory under ``EXPRESSO_PATH`` whose name starts with
    ``expresso_`` and that contains ``audio.mp3``,
    ``speaker_latent.safetensors`` and ``metadata.json`` yields a dict with
    ``id``, ``type``, ``speakers``, ``style`` and ``audio_length``. Folders
    whose metadata cannot be read or formatted are skipped.
    """
    if not EXPRESSO_PATH.exists():
        return []

    collected = []
    for entry in sorted(EXPRESSO_PATH.iterdir()):
        if not (entry.is_dir() and entry.name.startswith('expresso_')):
            continue
        meta_file = entry / "metadata.json"
        required = (entry / "audio.mp3", entry / "speaker_latent.safetensors", meta_file)
        if not all(f.exists() for f in required):
            continue
        try:
            with open(meta_file, 'r') as fh:
                payload = json.load(fh)
            seconds = payload.get("audio_length_seconds", 0)
            collected.append({
                'id': entry.name,
                'type': payload.get('type', 'unknown'),
                'speakers': payload.get('speakers', 'unknown'),
                'style': payload.get('style', 'unknown'),
                'audio_length': f"{seconds:.1f}s",
            })
        except Exception:
            continue
    return collected
def get_expresso_speakers_table(search_query=""):
    """Return Expresso voices as table rows, optionally filtered.

    Each row is ``[id, type, speakers, style, audio_length]``. A non-empty
    *search_query* keeps a row when it occurs, case-insensitively, in the id,
    type, speakers or style field (audio length is not searched).
    """
    column_order = ('id', 'type', 'speakers', 'style', 'audio_length')
    rows = []
    for voice in get_expresso_speakers():
        if search_query:
            needle = search_query.lower()
            searchable = (voice['id'], voice['type'], voice['speakers'], voice['style'])
            if all(needle not in str(field).lower() for field in searchable):
                continue
        rows.append([voice[column] for column in column_order])
    return rows
def select_expresso_speaker_from_table(evt: gr.SelectData, table_data):
    """React to a click in the Expresso table by loading that voice's files.

    ``table_data`` is the table as currently displayed (a pandas DataFrame,
    possibly filtered), so the clicked row index is resolved against it.

    Returns a 5-tuple of Gradio updates: selection text (made visible),
    stored speaker ID, audio player path, safetensors path, and the audio
    path used for reconstruction. When nothing usable was clicked, or the
    voice directory is missing, the tuple clears all five instead.
    """
    def no_selection():
        return (
            gr.update(value="", visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value="")
        )

    if not evt.value or table_data is None:
        return no_selection()

    # evt.index is normally (row, col); fall back to using it as the row.
    clicked = evt.index
    if isinstance(clicked, (tuple, list)) and len(clicked) >= 2:
        row_index = clicked[0]
    else:
        row_index = clicked
    if not (isinstance(row_index, int) and row_index < len(table_data)):
        return no_selection()

    speaker_row = table_data.iloc[row_index]
    speaker_id = speaker_row.iloc[0]  # first column holds the ID
    voice_dir = EXPRESSO_PATH / speaker_id
    if not voice_dir.exists():
        return no_selection()

    audio_path = str(voice_dir / "audio.mp3")
    st_path = str(voice_dir / "speaker_latent.safetensors")
    selection_text = f"Selected Voice: {speaker_id}\nType: {speaker_row.iloc[1]} • Speakers: {speaker_row.iloc[2]} • Style: {speaker_row.iloc[3]}"
    return (
        gr.update(value=selection_text, visible=True),
        gr.update(value=speaker_id),
        gr.update(value=audio_path),
        gr.update(value=st_path),
        gr.update(value=audio_path)
    )
def search_expresso_speakers(search_query):
    """Return a Gradio update carrying the Expresso rows that match ``search_query``.

    The filtering itself is delegated to ``get_expresso_speakers_table``.
    """
    return gr.update(value=get_expresso_speakers_table(search_query))
# HF-Custom Helper Functions
def get_hf_custom_speakers():
    """Get list of all HF-Custom speakers with their metadata.

    Scans every sub-directory of ``HF_CUSTOM_PATH`` (in sorted order) that
    contains ``audio.mp3``, ``speaker_latent.safetensors`` and
    ``metadata.json``.

    Returns:
        A list of dicts with the keys ``name`` (metadata ``speaker_name``,
        falling back to the directory name), ``dataset``, ``description``
        (both ``''`` when absent) and ``audio_length``
        (``audio_duration_seconds`` formatted like ``"12.3s"``). Empty when
        ``HF_CUSTOM_PATH`` does not exist. Voices whose metadata cannot be
        read or formatted are left out.

    NOTE(review): ``select_hf_custom_speaker_from_table`` resolves the voice
    directory from the displayed name, so a metadata ``speaker_name`` that
    differs from the directory name would make that voice unselectable.
    Confirm the two always match.
    """
    if not HF_CUSTOM_PATH.exists():
        return []
    speakers_with_metadata = []
    for subdir in sorted(HF_CUSTOM_PATH.iterdir()):
        if not subdir.is_dir():
            continue
        speaker_name = subdir.name
        audio_path = subdir / "audio.mp3"
        st_path = subdir / "speaker_latent.safetensors"
        metadata_path = subdir / "metadata.json"
        if not (audio_path.exists() and st_path.exists() and metadata_path.exists()):
            continue
        try:
            # JSON is UTF-8; read it as such instead of relying on the
            # platform's locale-dependent default encoding (free-text
            # descriptions are the likely place for non-ASCII characters).
            with open(metadata_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            audio_length = data.get("audio_duration_seconds", 0)
            speakers_with_metadata.append({
                'name': data.get('speaker_name', speaker_name),
                'dataset': data.get('dataset_name', ''),
                'description': data.get('speaker_description', ''),
                'audio_length': f"{audio_length:.1f}s"
            })
        except Exception:
            # Deliberately best-effort: one unreadable or malformed voice
            # must not prevent the remaining voices from being listed.
            continue
    return speakers_with_metadata
def get_hf_custom_speakers_table(search_query=""):
    """Build the HF-Custom voice rows shown in the Gradio table.

    Each row is ``[name, dataset, description, audio_length]``. When
    ``search_query`` is non-empty, only voices whose name, dataset or
    description contains it (case-insensitively) are kept; the length
    column is not searched.
    """
    searched_keys = ('name', 'dataset', 'description')
    row_keys = searched_keys + ('audio_length',)

    def matches(speaker):
        # An empty query keeps every voice.
        if not search_query:
            return True
        needle = search_query.lower()
        return any(needle in str(speaker[key]).lower() for key in searched_keys)

    return [
        [speaker[key] for key in row_keys]
        for speaker in get_hf_custom_speakers()
        if matches(speaker)
    ]
def select_hf_custom_speaker_from_table(evt: gr.SelectData, table_data):
    """Handle HF-Custom speaker selection - load voice files directly.

    The clicked row is resolved from ``evt.index`` alone. The clicked cell's
    own value is deliberately not consulted: the Dataset and Description
    columns may legitimately be empty, and clicking such a cell must still
    select the row.

    Args:
        evt: Gradio select event; ``evt.index`` is ``(row, col)`` or a bare
            row index.
        table_data: The table as currently displayed (pandas DataFrame,
            possibly filtered). Column 0 is the name, 1 the dataset,
            2 the description.

    Returns:
        A 5-tuple of Gradio updates: selection text (made visible), stored
        speaker name, audio player path, safetensors path, and the audio
        path used for reconstruction. When the click cannot be mapped to an
        existing voice directory, the tuple clears all five instead.
    """
    index = evt.index
    if isinstance(index, (tuple, list)) and len(index) >= 2:
        row_index = index[0]
    else:
        row_index = index

    # Reject negative indices too: .iloc would silently wrap them around.
    row_is_valid = (
        table_data is not None
        and isinstance(row_index, int)
        and 0 <= row_index < len(table_data)
    )
    if row_is_valid:
        speaker_row = table_data.iloc[row_index]
        speaker_name = speaker_row.iloc[0]  # First column is the name
        # An empty name would resolve to HF_CUSTOM_PATH itself, which exists.
        if isinstance(speaker_name, str) and speaker_name:
            voice_dir = HF_CUSTOM_PATH / speaker_name
            if voice_dir.exists():
                audio_path = str(voice_dir / "audio.mp3")
                st_path = str(voice_dir / "speaker_latent.safetensors")
                dataset_info = f" • {speaker_row.iloc[1]}" if speaker_row.iloc[1] else ""
                selection_text = f"Selected Voice: {speaker_name}{dataset_info}\n{speaker_row.iloc[2]}"
                return (
                    gr.update(value=selection_text, visible=True),  # Show speaker selection
                    gr.update(value=speaker_name),  # Store speaker name
                    gr.update(value=audio_path),  # Update audio player
                    gr.update(value=st_path),  # Update safetensors path
                    gr.update(value=audio_path)  # Update audio path for reconstruction
                )
    return (
        gr.update(value="", visible=False),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value="")
    )
def search_hf_custom_speakers(search_query):
    """Return a Gradio update carrying the HF-Custom rows that match ``search_query``.

    The filtering itself is delegated to ``get_hf_custom_speakers_table``.
    """
    return gr.update(value=get_hf_custom_speakers_table(search_query))
# Audio Prompt Library functions
# File suffixes (leading dot included) that get_audio_prompt_files() accepts;
# suffixes are lower-cased before the lookup, so entries here must be lower-case.
AUDIO_EXTS = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".webm", ".aac", ".opus"}
def get_audio_prompt_files():
    """List the audio files in ``AUDIO_PROMPT_FOLDER`` as single-column table rows.

    Only regular files whose lower-cased suffix is in ``AUDIO_EXTS`` are
    listed, sorted case-insensitively by filename. Returns ``[]`` when the
    folder is unset or does not exist.
    """
    if AUDIO_PROMPT_FOLDER is None or not AUDIO_PROMPT_FOLDER.exists():
        return []
    names = []
    for entry in AUDIO_PROMPT_FOLDER.iterdir():
        if entry.is_file() and entry.suffix.lower() in AUDIO_EXTS:
            names.append(entry.name)
    names.sort(key=str.lower)
    # One-element rows: the table has a single "Filename" column.
    return [[name] for name in names]
def select_audio_prompt_file(evt: gr.SelectData):
    """Handle audio prompt file selection from table.

    The clicked cell value is interpreted as a bare filename inside
    ``AUDIO_PROMPT_FOLDER``. It is accepted only when it names exactly what
    ``get_audio_prompt_files`` would list: a regular file, directly inside
    the folder, with a suffix from ``AUDIO_EXTS``.

    Returns:
        ``gr.update(value=<path as str>)`` for an accepted file, otherwise a
        no-op ``gr.update()``.
    """
    selected = evt.value
    if AUDIO_PROMPT_FOLDER is None or not isinstance(selected, str) or not selected:
        return gr.update()
    # NOTE(review): evt.value arrives with the client's event payload, so it
    # is treated as untrusted here. Anything with a directory component
    # ("../x.wav", "/abs/x.wav") is refused so the resolved path cannot
    # leave the library folder.
    if Path(selected).name != selected:
        return gr.update()
    file_path = AUDIO_PROMPT_FOLDER / selected
    # is_file() rather than exists(): a directory is not a selectable prompt.
    if file_path.is_file() and file_path.suffix.lower() in AUDIO_EXTS:
        return gr.update(value=str(file_path))
    return gr.update()
def switch_dataset(dataset_name):
    """Lay out the voice-selection UI for the chosen dataset and clear selections.

    ``dataset_name`` is one of "Custom Audio Panel", "EARS", "VCTK" or
    "Expresso"; any other value is handled as HF-Custom.

    Returns a 19-tuple of Gradio updates, in this order:
    dataset_license_info, custom_audio_row, voicebank_row, voice_type_column,
    ears_column, vctk_column, expresso_column, hf_custom_column,
    selected_speaker_display, freeform_table, emotions_table,
    selected_voice_display, vctk_speaker_display, expresso_speaker_display,
    hf_custom_speaker_display, selected_speaker_state, audio_preview,
    speaker_st_path_state, speaker_audio_path_state.
    """
    license_texts = {
        "EARS": "**EARS Dataset License:** Creative Commons Attribution 4.0 International ([CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/))",
        "VCTK": "**VCTK Dataset License:** Creative Commons Attribution 4.0 International ([CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/))",
        "Expresso": "**Expresso Dataset License:** Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International ([CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/))",
        "HF-Custom": "**HF-Custom Voices:** Available in dataset cache (information in metadata.json per voice). Also view dataset at [jordand/echo-embeddings-custom](https://huggingface.co/datasets/jordand/echo-embeddings-custom)",
    }

    # Resolve the layout with plain equality checks; None means the custom panel.
    if dataset_name == "Custom Audio Panel":
        layout = None
    elif dataset_name == "EARS":
        layout = "EARS"
    elif dataset_name == "VCTK":
        layout = "VCTK"
    elif dataset_name == "Expresso":
        layout = "Expresso"
    else:
        layout = "HF-Custom"
    custom_panel = layout is None

    if custom_panel:
        # No license blurb; the tables are emptied without touching visibility.
        license_update = gr.update(value="", visible=False)
        speaker_display_update = gr.update(value="", visible=False)
        freeform_update = gr.update(value=[])
        emotions_update = gr.update(value=[])
    else:
        license_update = gr.update(value=license_texts[layout], visible=True)
        speaker_display_update = gr.update(value="")
        freeform_update = gr.update(value=[], visible=True)
        emotions_update = gr.update(value=[], visible=True)

    # Exactly one voicebank column is flagged visible. On the custom panel that
    # is ears_column, although its parent voicebank_row is hidden.
    shown_column = "EARS" if custom_panel else layout
    column_updates = tuple(
        gr.update(visible=(column == shown_column))
        for column in ("EARS", "VCTK", "Expresso", "HF-Custom")
    )
    # selected_voice / vctk / expresso / hf_custom displays are always reset and hidden.
    hidden_displays = tuple(gr.update(value="", visible=False) for _ in range(4))

    return (
        license_update,                          # dataset_license_info
        gr.update(visible=custom_panel),         # custom_audio_row
        gr.update(visible=not custom_panel),     # voicebank_row
        gr.update(visible=(layout == "EARS")),   # voice_type_column (EARS only)
        *column_updates,                         # ears / vctk / expresso / hf_custom columns
        speaker_display_update,                  # selected_speaker_display
        freeform_update,                         # freeform_table
        emotions_update,                         # emotions_table
        *hidden_displays,
        gr.update(value=""),                     # selected_speaker_state
        gr.update(value=None),                   # audio_preview
        gr.update(value=""),                     # speaker_st_path_state
        gr.update(value=""),                     # speaker_audio_path_state
    )
def select_text_preset(evt: gr.SelectData):
    """Copy the preset text of the clicked row into the text prompt.

    The presets are reloaded via ``load_text_presets()`` and the text is
    taken from column 2 of the clicked row. Returns a no-op update when the
    clicked cell is empty or the row cannot be resolved.
    """
    if not evt.value:
        return gr.update()

    # evt.index is normally (row, col); fall back to using it as the row.
    position = evt.index
    if isinstance(position, (tuple, list)) and len(position) >= 2:
        row_index = position[0]
    else:
        row_index = position

    presets_data = load_text_presets()
    if isinstance(row_index, int) and row_index < len(presets_data):
        return gr.update(value=presets_data[row_index][2])
    return gr.update()
def update_cfg_visibility(cfg_mode):
    """Relabel and show/hide the CFG controls for the selected ``cfg_mode``.

    Returns three Gradio updates:
      1. label/info of the text CFG scale ("Text/Speaker" wording in
         "joint-unconditional" mode, "Text" wording otherwise);
      2. visibility of the second control - hidden only in
         "joint-unconditional" mode;
      3. visibility of the third control - shown only in "apg-independent"
         mode.
    NOTE(review): controls 2 and 3 are presumably the speaker CFG scale and
    the APG parameters; confirm against the event wiring.
    """
    joint = cfg_mode == "joint-unconditional"
    if joint:
        scale_update = gr.update(label="Text/Speaker CFG Scale", info="Guidance strength for text and speaker (joint)")
    else:
        scale_update = gr.update(label="Text CFG Scale", info="Guidance strength for text")
    show_third = (not joint) and cfg_mode == "apg-independent"
    return (
        scale_update,
        gr.update(visible=not joint),
        gr.update(visible=show_third)
    )
def toggle_speaker_k_fields(enabled):
    """Show the speaker K row when ``enabled`` is set, hide it otherwise.

    Only visibility is updated; no field values are written.
    """
    row_update = gr.update(visible=enabled)
    return row_update
def toggle_custom_shapes_fields(enabled):
    """Show the custom shapes row when ``enabled`` is truthy, hide it otherwise.

    Only the row's visibility is updated; the fields inside it are not
    reset when the row is hidden.
    """
    return gr.update(visible=bool(enabled))
def toggle_mode(mode, speaker_k_enable_val, speaker_kv_simple_val):
    """Switch between Simple and Advanced mode and carry the speaker KV flag across.

    Returns five Gradio updates: visibility of simple_mode_row,
    advanced_mode_compile_column and advanced_mode_column, followed by the
    value for the simple checkbox and for speaker_k_enable. Entering Simple
    mode copies the advanced flag into both checkboxes; any other mode is
    treated as Advanced and copies the simple checkbox into both.
    """
    simple = mode == "Simple Mode"
    carried_flag = speaker_k_enable_val if simple else speaker_kv_simple_val
    return (
        gr.update(visible=simple),        # simple_mode_row
        gr.update(visible=not simple),    # advanced_mode_compile_column
        gr.update(visible=not simple),    # advanced_mode_column
        gr.update(value=carried_flag),    # simple-mode checkbox
        gr.update(value=carried_flag),    # speaker_k_enable
    )
def sync_simple_to_advanced(simple_enabled):
    """Mirror the Simple-mode speaker KV checkbox onto the Advanced-mode controls.

    Returns five Gradio updates for speaker_k_enable, speaker_k_row,
    speaker_k_scale, speaker_k_min_t and speaker_k_max_layers. Enabling
    shows the row and writes scale 1.5, min_t 0.9 and max layers 24;
    disabling hides the row and leaves those three values untouched.
    """
    if not simple_enabled:
        return (
            gr.update(value=False),
            gr.update(visible=False),
            gr.update(),
            gr.update(),
            gr.update(),
        )
    return (
        gr.update(value=True),
        gr.update(visible=True),
        gr.update(value=1.5),
        gr.update(value=0.9),
        gr.update(value=24),
    )
def apply_core_preset(preset_name):
    """Apply the core sampling preset.

    Only "default" is recognised. It yields updates for rng_seed (0),
    num_steps (40), cfg_mode ("independent") and switches the main preset
    selector to "Custom". Any other name yields four no-op updates.
    """
    if preset_name != "default":
        return [gr.update()] * 4
    # Order: rng_seed, num_steps, cfg_mode, main preset selector.
    core_defaults = (0, 40, "independent", "Custom")
    return [gr.update(value=setting) for setting in core_defaults]
def apply_cfg_preset(preset_name):
    """Apply a CFG guidance preset.

    Returns five Gradio updates: cfg_scale_text, cfg_scale_speaker,
    cfg_min_t, cfg_max_t, and the main preset selector set to "Custom".
    Unknown preset names yield five no-op updates.
    """
    # name -> (text scale, speaker scale, min_t, max_t)
    cfg_table = {
        "default": (3.0, 5.0, 0.5, 1.0),
        "higher speaker": (3.0, 8.0, 0.5, 1.0),
        "large guidances": (8.0, 8.0, 0.5, 1.0),
    }
    try:
        chosen = cfg_table[preset_name]
    except KeyError:
        return [gr.update()] * 5
    return [gr.update(value=setting) for setting in (*chosen, "Custom")]
def apply_speaker_kv_preset(preset_name):
    """Apply the speaker KV attention preset ("enable" or "off").

    Returns three Gradio updates: speaker_k_enable value, speaker_k_row
    visibility, and the main preset selector set to "Custom". Any other
    name yields three no-op updates.
    """
    if preset_name == "enable":
        enabled = True
    elif preset_name == "off":
        enabled = False
    else:
        return [gr.update()] * 3
    return [
        gr.update(value=enabled),
        gr.update(visible=enabled),
        gr.update(value="Custom"),
    ]
def apply_truncation_preset(preset_name):
    """Apply a truncation & temporal rescaling preset.

    Returns four Gradio updates: truncation factor, rescale k, rescale
    sigma, and the main preset selector set to "Custom". "custom" and any
    unknown name yield four no-op updates.
    """
    # name -> (truncation, rescale_k, rescale_sigma)
    truncation_table = {
        "flat": (0.8, 1.2, 3.0),
        "sharp": (0.9, 0.96, 3.0),
        "baseline(sharp)": (1.0, 1.0, 3.0),
    }
    chosen = truncation_table.get(preset_name)
    if chosen is None:
        # "custom" is simply not a table entry, so it lands here as well.
        return [gr.update()] * 4
    return [gr.update(value=setting) for setting in chosen] + [gr.update(value="Custom")]
def apply_apg_preset(preset_name):
    """Apply an APG parameters preset.

    Returns seven Gradio updates: apg_eta_text, apg_eta_speaker,
    apg_momentum_text, apg_momentum_speaker, apg_norm_text,
    apg_norm_speaker, and the main preset selector set to "Custom".
    The norm values are strings, with "" meaning no value. Unknown preset
    names yield seven no-op updates.
    """
    # name -> (eta_text, eta_speaker, momentum_text, momentum_speaker, norm_text, norm_speaker)
    apg_table = {
        "default": (0.5, 0.5, -0.25, -0.25, "", ""),
        "no momentum": (0.0, 0.0, 0.0, 0.0, "", ""),
        "norms": (0.5, 0.5, -0.25, -0.25, "7.5", "7.5"),
        "no eta": (0.0, 0.0, -0.25, -0.25, "", ""),
    }
    chosen = apg_table.get(preset_name)
    if chosen is None:
        return [gr.update()] * 7
    updates = [gr.update(value=setting) for setting in chosen]
    updates.append(gr.update(value="Custom"))
    return updates
def load_sampler_presets():
    """Load sampler presets from JSON file.

    If ``SAMPLER_PRESETS_PATH`` exists, its parsed content is returned.
    Otherwise two built-in presets are written to that path and returned.

    Returns:
        Dict mapping preset name to a dict of settings. The built-in
        presets store every setting as a string.

    Notes:
        Writing the built-in presets is best-effort: if the location is not
        writable the presets are still returned, just not persisted.
    """
    if SAMPLER_PRESETS_PATH.exists():
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(SAMPLER_PRESETS_PATH, 'r', encoding='utf-8') as f:
            return json.load(f)

    default_presets = {
        "Flat (Independent)": {
            "num_steps": "30",
            "cfg_mode": "independent",
            "cfg_scale_text": "3.0",
            "cfg_scale_speaker": "5.0",
            "cfg_min_t": "0.5",
            "cfg_max_t": "1.0",
            "truncation_factor": "0.8",
            "rescale_k": "1.2",
            "rescale_sigma": "3.0"
        },
        "Sharp (Independent)": {
            "num_steps": "30",
            "cfg_mode": "independent",
            "cfg_scale_text": "3.0",
            "cfg_scale_speaker": "5.0",
            "cfg_min_t": "0.5",
            "cfg_max_t": "1.0",
            "truncation_factor": "0.9",
            "rescale_k": "0.96",
            "rescale_sigma": "3.0"
        },
    }
    try:
        with open(SAMPLER_PRESETS_PATH, 'w', encoding='utf-8') as f:
            json.dump(default_presets, f, indent=2)
    except OSError:
        # The file is only a cache of the values above; a read-only or
        # missing directory must not stop the UI from being built.
        pass
    return default_presets
def apply_sampler_preset(preset_name):
    """Apply a sampler preset to all fields.

    Args:
        preset_name: Key into the presets returned by
            ``load_sampler_presets()``. "Custom" or an unknown name changes
            nothing.

    Returns:
        A list of 20 Gradio updates (see the inline comments for the
        order). For "Custom" or unknown names all 20 are no-op updates.

    Raises:
        KeyError: if the preset lacks one of the required keys
            (num_steps, cfg_mode, cfg_scale_text, cfg_scale_speaker,
            cfg_min_t, cfg_max_t, truncation_factor, rescale_k,
            rescale_sigma).
    """
    presets = load_sampler_presets()
    if preset_name == "Custom" or preset_name not in presets:
        return [gr.update()] * 20  # Return no changes for custom
    preset = presets[preset_name]

    def to_num(val, default):
        # Preset values may be stored as strings; unparsable ones fall back.
        try:
            return float(val) if isinstance(val, str) else val
        except (ValueError, TypeError):
            return default

    def to_bool(val):
        # The built-in presets serialise settings as strings, and a raw
        # "false" string would be truthy; normalise before using the flag.
        if isinstance(val, str):
            return val.strip().lower() in ("true", "1", "yes", "on")
        return bool(val)

    # The speaker CFG scale is hidden only in joint-unconditional mode.
    speaker_visible = (preset["cfg_mode"] != "joint-unconditional")
    speaker_k_enabled = to_bool(preset.get("speaker_k_enable", False))

    return [
        gr.update(value=int(to_num(preset["num_steps"], 40))),  # num_steps
        gr.update(value=preset["cfg_mode"]),  # cfg_mode
        gr.update(value=to_num(preset["cfg_scale_text"], 3.0)),  # cfg_scale_text
        gr.update(value=to_num(preset["cfg_scale_speaker"], 5.0), visible=speaker_visible),  # cfg_scale_speaker
        gr.update(value=to_num(preset["cfg_min_t"], 0.5)),  # cfg_min_t
        gr.update(value=to_num(preset["cfg_max_t"], 1.0)),  # cfg_max_t
        gr.update(value=to_num(preset["truncation_factor"], 0.8)),  # truncation_factor
        gr.update(value=to_num(preset["rescale_k"], 1.2)),  # rescale_k
        gr.update(value=to_num(preset["rescale_sigma"], 3.0)),  # rescale_sigma
        gr.update(value=speaker_k_enabled),  # speaker_k_enable
        gr.update(visible=speaker_k_enabled),  # speaker_k_row
        gr.update(value=to_num(preset.get("speaker_k_scale", "1.5"), 1.5)),  # speaker_k_scale
        gr.update(value=to_num(preset.get("speaker_k_min_t", "0.9"), 0.9)),  # speaker_k_min_t
        gr.update(value=int(to_num(preset.get("speaker_k_max_layers", "24"), 24))),  # speaker_k_max_layers
        gr.update(value=to_num(preset.get("apg_eta_text", "0.0"), 0.0)),  # apg_eta_text
        gr.update(value=to_num(preset.get("apg_eta_speaker", "0.0"), 0.0)),  # apg_eta_speaker
        gr.update(value=to_num(preset.get("apg_momentum_text", "0.0"), 0.0)),  # apg_momentum_text
        gr.update(value=to_num(preset.get("apg_momentum_speaker", "0.0"), 0.0)),  # apg_momentum_speaker
        gr.update(value=preset.get("apg_norm_text", "")),  # Keep as string (can be empty)
        gr.update(value=preset.get("apg_norm_speaker", "")),  # Keep as string (can be empty)
    ]
# Build Gradio Interface
# Stylesheet passed to gr.Blocks(css=...) below. It styles, in order: the inline
# preset links (.preset-inline, a.preset-link), the zero-size proxy buttons
# (.proxy-btn), parameter group boxes (.gr-group), the generated-audio player
# highlight, the mode selector (#component-mode-selector), section separators,
# h1/h2 headers and the tip box - with .dark / [data-theme="dark"] overrides
# where defined.
LINK_CSS = """
.preset-inline { display:flex; align-items:baseline; gap:6px; margin-top:-4px; margin-bottom:-12px; }
.preset-inline .title { font-weight:600; font-size:.95rem; }
.preset-inline .dim { color:#666; margin:0 4px; }
/* blue, linky */
a.preset-link { color: #0a5bd8; text-decoration: underline; cursor: pointer; font-weight: 400; }
a.preset-link:hover { text-decoration: none; opacity: 0.8; }
/* Dark mode support for preset links */
.dark a.preset-link,
[data-theme="dark"] a.preset-link {
color: #60a5fa !important;
}
.dark a.preset-link:hover,
[data-theme="dark"] a.preset-link:hover {
color: #93c5fd !important;
}
.dark .preset-inline .dim,
[data-theme="dark"] .preset-inline .dim {
color: #9ca3af !important;
}
/* keep proxy buttons in DOM but invisible */
.proxy-btn { position:absolute; width:0; height:0; overflow:hidden; padding:0 !important; margin:0 !important; border:0 !important; opacity:0; pointer-events:none; }
/* Better contrast for parameter group boxes */
.gr-group {
border: 1px solid #d1d5db !important;
background: #f3f4f6 !important;
}
.dark .gr-group,
[data-theme="dark"] .gr-group {
border: 1px solid #4b5563 !important;
background: #1f2937 !important;
}
/* Highlight generated audio */
.generated-audio-player {
border: 3px solid #667eea !important;
border-radius: 12px !important;
padding: 20px !important;
background: linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.05) 100%) !important;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.2) !important;
margin: 1rem 0 !important;
}
.generated-audio-player > div {
background: transparent !important;
}
/* Make Parameter Mode selector more prominent */
#component-mode-selector {
text-align: center;
padding: 1rem 0;
}
#component-mode-selector label {
font-size: 1.1rem !important;
font-weight: 600 !important;
margin-bottom: 0.5rem !important;
}
#component-mode-selector .wrap {
justify-content: center !important;
}
#component-mode-selector fieldset {
border: 2px solid #e5e7eb !important;
border-radius: 8px !important;
padding: 1rem !important;
background: #f9fafb !important;
}
.dark #component-mode-selector fieldset,
[data-theme="dark"] #component-mode-selector fieldset {
border: 2px solid #4b5563 !important;
background: #1f2937 !important;
}
/* Stronger section separators */
.section-separator {
height: 3px !important;
background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important;
border: none !important;
margin: 2rem 0 !important;
}
.dark .section-separator,
[data-theme="dark"] .section-separator {
background: linear-gradient(90deg, transparent 0%, #667eea 20%, #764ba2 80%, transparent 100%) !important;
}
/* Section headers styling */
.gradio-container h1,
.gradio-container h2 {
font-weight: 700 !important;
margin-top: 1.5rem !important;
margin-bottom: 1rem !important;
}
/* Highlighted tip box */
.tip-box {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%) !important;
border-left: 4px solid #f59e0b !important;
border-radius: 8px !important;
padding: 1rem 1.5rem !important;
margin: 1rem 0 !important;
box-shadow: 0 2px 4px rgba(245, 158, 11, 0.1) !important;
}
.tip-box strong {
color: #92400e !important;
}
.dark .tip-box,
[data-theme="dark"] .tip-box {
background: linear-gradient(135deg, #451a03 0%, #78350f 100%) !important;
border-left: 4px solid #f59e0b !important;
}
.dark .tip-box strong,
[data-theme="dark"] .tip-box strong {
color: #fbbf24 !important;
}
"""
# Client-side script passed to gr.Blocks(js=...) below. It registers a single
# capturing click listener on the app root (the <gradio-app> shadow root when
# there is one, otherwise the document). A click on <a class="preset-link"
# data-fire="ID"> is cancelled and re-issued as a click on the button inside
# the element whose id is ID.
# NOTE(review): those targets are presumably the hidden .proxy-btn buttons
# styled in LINK_CSS - confirm against the UI definition.
JS_CODE = r"""
function () {
// Get a queryable root, regardless of Shadow DOM
const appEl = document.querySelector("gradio-app");
const root = appEl && appEl.shadowRoot ? appEl.shadowRoot : document;
function clickHiddenButtonById(id) {
if (!id) return;
const host = root.getElementById(id);
if (!host) return;
const realBtn = host.querySelector("button, [role='button']") || host;
realBtn.click();
}
// Delegate clicks from any <a class="preset-link" data-fire="...">
root.addEventListener("click", (ev) => {
const a = ev.target.closest("a.preset-link");
if (!a) return;
ev.preventDefault();
ev.stopPropagation();
ev.stopImmediatePropagation();
clickHiddenButtonById(a.getAttribute("data-fire"));
return false;
}, true);
}
"""
def init_session():
    """Create a fresh session ID for this browser tab/session.

    The ID is 8 random bytes from ``secrets``, rendered as 16 hex characters.
    """
    session_token = secrets.token_hex(8)
    return session_token
def init_and_compile():
    """Create the session ID and, off Zero GPU, compile on page load.

    Returns ``(session_id, status_update, checkbox_update)``. On Zero GPU no
    compilation is attempted and the two updates are no-ops; otherwise they
    are whatever ``do_compile()`` returns.
    """
    session_id = secrets.token_hex(8)
    if IS_ZEROGPU:
        # On Zero GPU, don't try to compile.
        return session_id, gr.update(), gr.update()
    # Compile up front so Simple mode (which defaults compile=True) gets compiled.
    status_update, checkbox_update = do_compile()
    return session_id, status_update, checkbox_update
with gr.Blocks(title="Echo-TTS", css=LINK_CSS, js=JS_CODE) as demo:
gr.Markdown("# Echo-TTS")
gr.Markdown("*Jordan Darefsky, 2025. See technical details [here](https://jordandarefsky.com/blog/2025/echo/)*")
# License notice for Fish Speech autoencoder
gr.Markdown("**License Notice:** All audio outputs are subject to non-commercial use only due to the [Fish Speech S1-DAC autoencoder](https://github.com/fishaudio/fish-speech) being licensed under [CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/).")
# Silentcipher watermarking notice
if USE_SILENTCIPHER:
gr.Markdown(f"*Audio output is watermarked with [silentcipher](https://github.com/sony/silentcipher) using message `{SILENTCIPHER_MESSAGE}`*")
# Instructions for Simple Mode
with gr.Accordion("📖 Quick Start Instructions", open=True):
gr.Markdown("""
### Simple Mode (Recommended for Beginners)
1. **Pick or upload a voice** - Choose from the voicebank or upload your own audio (up to 2 minutes)
2. **Choose a text prompt preset or enter your own prompt** - What you want the voice to say (the presets are a good guide for format/style)
3. **Select a Sampling preset** - The default preset "Independent (High Speaker CFG)" is usually good to start
4. **Click Generate Audio** - Wait for the model to generate your audio
<div class="tip-box">
💡 **Tip:** If the generated voice doesn't match the reference speaker at all, enable "Speaker KV Attention Scaling" and click Generate Audio again.
</div>
### Advanced Mode
Switch to Advanced mode for full control over all generation parameters including CFG scales, sampling steps, truncation, and more.
### Other tips
High CFG settings are recommended but may lead to oversaturation; APG might help with this. Flat settings tend to reduce "impulse" artifacts but might result in worse (blunted/compressed/artifact-y) laughter, breathing, etc. generation.
Echo will try to fit the entire text-prompt into (<=) 30 seconds of audio. If your prompt is very long, the generated speech may be too quick (this is not an issue for shorter text-prompts). For disfluent, single-speaker speech, we recommend trying the reference text beginning with "[S1] ... explore how we can design" as a starting point.
""")
# Session state for per-user file management
session_id_state = gr.State(None)
# Hidden state variables to store paths and selection
selected_speaker_state = gr.Textbox(visible=False, value="")
speaker_st_path_state = gr.Textbox(visible=False, value="")
speaker_audio_path_state = gr.Textbox(visible=False, value="")
gr.Markdown("# Voice Selection")
# Dataset selector
dataset_selector = gr.Radio(
choices=["Custom Audio Panel", "EARS", "VCTK", "Expresso", "HF-Custom"],
value="Custom Audio Panel",
label="Select Dataset",
info="Choose which voicebank to use"
)
dataset_license_info = gr.Markdown(
"",
visible=False
)
# Custom Audio Panel UI (visible by default, takes full width)
with gr.Row(visible=True) as custom_audio_row:
# Optional: Audio prompt library table (only shown if AUDIO_PROMPT_FOLDER is configured)
if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
with gr.Column(scale=1, min_width=200):
gr.Markdown("#### Audio Library (favorite examples from voicebank datasets)")
audio_prompt_table = gr.Dataframe(
value=get_audio_prompt_files(),
headers=["Filename"],
datatype=["str"],
row_count=(10, "dynamic"),
col_count=(1, "fixed"),
interactive=False,
label="Click to select (or upload your own audio file directly on the right)"
)
with gr.Column(scale=2):
custom_audio_input = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Speaker Reference Audio (only first two minutes will be used; leave empty for zero speaker conditioning)",
max_length=600 # Maximum duration in seconds (10 minutes)
)
with gr.Row(visible=False) as voicebank_row:
# Voice selection UI for all voicebank datasets
# EARS UI (visible by default when voicebank_row is shown)
with gr.Column(scale=2, visible=True) as ears_column:
gr.Markdown("### 1. Speakers (EARS)")
selected_speaker_display = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
lines=2,
max_lines=2
)
speaker_search = gr.Textbox(
placeholder="Search speakers (by ID, gender, age, ethnicity, language)...",
label="",
show_label=False,
container=False
)
speakers_table = gr.Dataframe(
value=get_speakers_table(),
headers=["ID", "G", "Age", "Ethnicity", "Native Lang"],
datatype=["str", "str", "str", "str", "str"],
row_count=(8, "dynamic"),
col_count=(5, "fixed"),
interactive=False,
label="Click any cell to select",
column_widths=["10%", "8%", "15%", "30%", "37%"]
)
# VCTK UI (hidden by default)
with gr.Column(scale=2, visible=False) as vctk_column:
gr.Markdown("### 1. Speakers (VCTK)")
vctk_speaker_display = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
lines=2,
max_lines=2
)
vctk_speaker_search = gr.Textbox(
placeholder="Search speakers (by ID, gender, age, details)...",
label="",
show_label=False,
container=False
)
vctk_speakers_table = gr.Dataframe(
value=get_vctk_speakers_table(),
headers=["ID", "G", "Age", "Details", "Length"],
datatype=["str", "str", "str", "str", "str"],
row_count=(8, "dynamic"),
col_count=(5, "fixed"),
interactive=False,
label="Click any cell to select",
column_widths=["10%", "8%", "12%", "50%", "20%"]
)
# Expresso UI (hidden by default)
with gr.Column(scale=2, visible=False) as expresso_column:
gr.Markdown("### 1. Voices (Expresso)")
expresso_speaker_display = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
lines=2,
max_lines=2
)
expresso_speaker_search = gr.Textbox(
placeholder="Search voices (by ID, type, speakers, style)...",
label="",
show_label=False,
container=False
)
expresso_speakers_table = gr.Dataframe(
value=get_expresso_speakers_table(),
headers=["ID", "Type", "Speakers", "Style", "Length"],
datatype=["str", "str", "str", "str", "str"],
row_count=(8, "dynamic"),
col_count=(5, "fixed"),
interactive=False,
label="Click any cell to select",
column_widths=["35%", "15%", "15%", "15%", "20%"]
)
# HF-Custom UI (hidden by default)
with gr.Column(scale=2, visible=False) as hf_custom_column:
gr.Markdown("### 1. Voices (HF-Custom)")
hf_custom_speaker_display = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
lines=2,
max_lines=2
)
hf_custom_speaker_search = gr.Textbox(
placeholder="Search voices (by name, dataset, description)...",
label="",
show_label=False,
container=False
)
hf_custom_speakers_table = gr.Dataframe(
value=get_hf_custom_speakers_table(),
headers=["Name", "Dataset", "Description", "Length"],
datatype=["str", "str", "str", "str"],
row_count=(8, "dynamic"),
col_count=(4, "fixed"),
interactive=False,
label="Click any cell to select",
column_widths=["15%", "15%", "50%", "20%"]
)
# Step 2: voice-type picker. This column is an output of switch_dataset.
with gr.Column(scale=1, visible=True) as voice_type_column:
gr.Markdown("### 2. Voice Type")
# Non-interactive textbox, created hidden and empty; it is written by the EARS
# speaker selection and by the freeform/emotion selection handlers.
selected_voice_display = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
lines=2,
max_lines=2
)
# Starts empty; it is an output of select_speaker_from_table. Selecting a cell
# triggers select_freeform_from_table.
freeform_table = gr.Dataframe(
value=[],
headers=["Type", "Length"],
datatype=["str", "str"],
row_count=(1, "fixed"),
col_count=(2, "fixed"),
interactive=False,
label="Freeform voice",
visible=True,
column_widths=["60%", "40%"]
)
gr.Markdown("**Emotions:**")
# Starts empty; it is an output of select_speaker_from_table. Selecting a cell
# triggers select_emotion_from_table.
emotions_table = gr.Dataframe(
value=[],
headers=["Emotion", "Length"],
datatype=["str", "str"],
row_count=(8, "dynamic"),
col_count=(2, "fixed"),
interactive=False,
visible=True,
column_widths=["60%", "40%"]
)
# Step 3: read-only player. audio_preview is an output of every voice-selection
# handler and of switch_dataset.
with gr.Column(scale=1):
gr.Markdown("### 3. Audio Preview")
audio_preview = gr.Audio(label="Voice Sample", type="filepath", interactive=False)
# --- Text prompt section ---
gr.HTML('<hr class="section-separator">')
gr.Markdown("# Text Prompt")
with gr.Accordion("Text Presets", open=True):
# Read-only table filled from load_text_presets(); its select event is wired to
# select_text_preset, which writes into text_prompt.
text_presets_table = gr.Dataframe(
value=load_text_presets(),
headers=["Category", "Words", "Preset Text"],
datatype=["str", "str", "str"],
row_count=(3, "dynamic"),
col_count=(3, "fixed"),
interactive=False,
column_widths=["12%", "6%", "82%"]
)
# Editable prompt; passed as the first input to generate_audio.
text_prompt = gr.Textbox(
label="Text Prompt",
placeholder="[S1] Enter your text prompt here...",
lines=4
)
# --- Generation section ---
gr.HTML('<hr class="section-separator">')
gr.Markdown("# Generation")
# Mode selector: Simple or Advanced (outside the accordion, centered and prominent)
# Three columns with scale 1 / 2 / 1; the outer two are intentionally empty.
with gr.Row():
with gr.Column(scale=1):
pass # Empty column for spacing
with gr.Column(scale=2):
# Starts in "Simple Mode"; its change event is wired to toggle_mode.
mode_selector = gr.Radio(
choices=["Simple Mode", "Advanced Mode"],
value="Simple Mode",
label="",
info=None,
elem_id="component-mode-selector"
)
with gr.Column(scale=1):
pass # Empty column for spacing
# Generation parameters: preset selector, seed, and the mode-dependent checkboxes.
with gr.Accordion("⚙️ Generation Parameters", open=True):
with gr.Row():
# Presets are read once while the UI is built; the dropdown offers "Custom"
# followed by every preset key, and falls back to "Custom" when none exist.
presets = load_sampler_presets()
preset_keys = list(presets.keys())
first_preset = preset_keys[0] if preset_keys else "Custom"
preset_dropdown = gr.Dropdown(
choices=["Custom"] + preset_keys,
value=first_preset, # Default to first preset instead of Custom
label="Sampler Preset",
info="Load preset configurations",
scale=2
)
# Integer seed (precision=0); passed to generate_audio.
rng_seed = gr.Number(
label="RNG Seed",
value=0,
info="Random seed for starting noise",
precision=0,
scale=1
)
# Simple mode: Speaker KV checkbox on same row (visible by default)
with gr.Column(scale=1, visible=True) as simple_mode_row:
# Its change event is wired to sync_simple_to_advanced.
speaker_kv_simple_checkbox = gr.Checkbox(
label="\"Force Speaker\" (Enable Speaker KV Attention Scaling)",
value=False,
info="Enable if generation does not match reference voice (otherwise leave off)"
)
# Advanced mode: Compile and custom shapes checkboxes (hidden by default)
with gr.Column(scale=1, visible=False) as advanced_mode_compile_column:
# The checkbox is non-interactive when IS_ZEROGPU is truthy.
compile_checkbox = gr.Checkbox(
label="Compile Model",
value=True, # Default to True in simple mode
interactive=not IS_ZEROGPU,
info="Compile disabled on Zero GPU" if IS_ZEROGPU else "~20-30% faster after initial compilation"
)
# Status line; only visible (with a warning text) when IS_ZEROGPU is truthy.
compile_status = gr.Markdown(
value="⚠️ Compile disabled on Zero GPU" if IS_ZEROGPU else "",
visible=IS_ZEROGPU
)
# Its change event is wired to toggle_custom_shapes.
use_custom_shapes_checkbox = gr.Checkbox(
label="Use Custom Shapes (Advanced)",
value=False,
info="Override default sequence lengths for text, speaker, and sample"
)
# Advanced mode controls (hidden by default)
# advanced_mode_column is an output of the mode_selector change handler.
with gr.Column(visible=False) as advanced_mode_column:
# Hidden row of three length overrides, kept as text boxes. Its visibility and
# the three values are the outputs of toggle_custom_shapes.
with gr.Row(visible=False) as custom_shapes_row:
max_text_byte_length = gr.Textbox(
label="Max Text Byte Length (padded)",
value="768",
info="Maximum text utf-8 byte sequence length (blank -> no padding)",
scale=1
)
max_speaker_latent_length = gr.Textbox(
label="Max Speaker Latent Length (padded)",
value="2560",
info="Maximum (unpatched)speaker latent length (blank -> no padding), default 2560 = ~30s",
scale=1
)
sample_latent_len = gr.Textbox(
label="Sample Latent Length",
value="640",
info="Maximum sample latent length (EXPERIMENTAL!!! ONLY TRAINED WITH 640 BUT SOMEHOW WORKS WITH < 640 TO GENERATE PREFIXES)",
scale=1
)
# Two-column parameter area. Each group below pairs an HTML header, whose links
# carry a data-fire attribute, with empty-label gr.Button components whose
# elem_id equals that attribute value.
# NOTE(review): the script/CSS that forwards link clicks to these "proxy-btn"
# buttons is not visible in this part of the file - confirm where it lives.
with gr.Row():
# Left column: Core Sampling Parameters
with gr.Column(scale=1):
with gr.Group():
gr.HTML("""
<div class="preset-inline">
<span class="title">Core Sampling Parameters</span><span class="dim">(</span>
<a href="javascript:void(0)" class="preset-link" data-fire="core_default">default</a>
<span class="dim">)</span>
</div>
""")
core_preset_default = gr.Button("", elem_id="core_default", elem_classes=["proxy-btn"])
# Integer step count, constrained to 1..80 by the component.
num_steps = gr.Number(label="Number of Steps", value=40, info="Number of sampling steps (consider 20 - 80) (capped at 80)", precision=0, minimum=1, step=5, maximum=80)
# Its change event is wired to update_cfg_visibility.
cfg_mode = gr.Radio(
choices=[
"independent",
"apg-independent",
"alternating",
"joint-unconditional"
],
value="independent",
label="CFG Mode",
info="Independent (3 NFE), Adaptive Projected Guidance (3 NFE, see https://arxiv.org/abs/2410.02416), Alternating (2 NFE), Joint-Unconditional (2 NFE)"
)
with gr.Group():
gr.HTML("""
<div class="preset-inline">
<span class="title">CFG Guidance</span><span class="dim">(</span>
<a href="javascript:void(0)" class="preset-link" data-fire="cfg_default">default</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="cfg_higher">higher speaker</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="cfg_large">large guidances(works with apg)</a>
<span class="dim">)</span>
</div>
""")
# Click handlers for these three buttons call apply_cfg_preset.
cfg_preset_default = gr.Button("", elem_id="cfg_default", elem_classes=["proxy-btn"])
cfg_preset_higher_speaker = gr.Button("", elem_id="cfg_higher", elem_classes=["proxy-btn"])
cfg_preset_large_guidances = gr.Button("", elem_id="cfg_large", elem_classes=["proxy-btn"])
with gr.Row():
cfg_scale_text = gr.Number(label="Text CFG Scale", value=3.0, info="Guidance strength for text", minimum=0, step=0.5)
cfg_scale_speaker = gr.Number(label="Speaker CFG Scale", value=5.0, info="Guidance strength for speaker", minimum=0, step=0.5)
with gr.Row():
cfg_min_t = gr.Number(label="CFG Min t", value=0.5, info="(0-1), CFG applied when t >= val", minimum=0, maximum=1, step=0.05)
cfg_max_t = gr.Number(label="CFG Max t", value=1.0, info="(0-1), CFG applied when t <= val", minimum=0, maximum=1, step=0.05)
# Right column: Speaker KV, Truncation + APG
with gr.Column(scale=1):
with gr.Group():
gr.HTML("""
<div class="preset-inline">
<span class="title">Speaker KV Attention Scaling</span><span class="dim">(</span>
<a href="javascript:void(0)" class="preset-link" data-fire="spk_kv_enable">enable if generation does not match reference</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="spk_kv_off">off</a>
<span class="dim">)</span>
</div>
""")
# Click handlers for these two buttons call apply_speaker_kv_preset.
spk_kv_preset_enable = gr.Button("", elem_id="spk_kv_enable", elem_classes=["proxy-btn"])
spk_kv_preset_off = gr.Button("", elem_id="spk_kv_off", elem_classes=["proxy-btn"])
# Its change handler toggles speaker_k_row and mirrors the value into the
# simple-mode checkbox.
speaker_k_enable = gr.Checkbox(label="Enable Speaker KV Scaling", value=False, info="Scale speaker attention key-values; useful when the model-generated audio does not at all match the reference audio (i.e. ignores speaker-reference)")
# Created hidden; made visible by the handlers that list speaker_k_row as an output.
with gr.Row(visible=False) as speaker_k_row:
speaker_k_scale = gr.Number(label="KV Scale", value=1.5, info="Scale factor", minimum=0, step=0.1)
speaker_k_min_t = gr.Number(label="KV Min t", value=0.9, info="(0-1), scale applied from steps t=1. to val", minimum=0, maximum=1, step=0.05)
speaker_k_max_layers = gr.Number(label="Max Layers", value=24, info="(0-24), scale applied in first N layers", precision=0, minimum=0, maximum=24)
with gr.Group():
gr.HTML("""
<div class="preset-inline">
<span class="title">Truncation &amp; Temporal Rescaling</span><span class="dim">(</span>
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_flat">flat</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_sharp">sharp</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="trunc_baseline">baseline(sharp)</a>
<span class="dim">)</span>
</div>
""")
# Click handlers for these three buttons call apply_truncation_preset.
trunc_preset_flat = gr.Button("", elem_id="trunc_flat", elem_classes=["proxy-btn"])
trunc_preset_sharp = gr.Button("", elem_id="trunc_sharp", elem_classes=["proxy-btn"])
trunc_preset_baseline = gr.Button("", elem_id="trunc_baseline", elem_classes=["proxy-btn"])
with gr.Row():
truncation_factor = gr.Number(label="Truncation Factor", value=0.8, info="Multiply initial noise (<1 helps artifacts)", minimum=0, step=0.05)
rescale_k = gr.Number(label="Rescale k", value=1.2, info="<1=sharpen, >1=flatten, 1=off", minimum=0, step=0.05)
rescale_sigma = gr.Number(label="Rescale σ", value=3.0, info="Sigma parameter", minimum=0, step=0.1)
# APG parameter group, created hidden; apg_row is an output of the cfg_mode
# change handler (update_cfg_visibility).
with gr.Group(visible=False) as apg_row:
gr.HTML("""
<div class="preset-inline">
<span class="title">APG Parameters</span><span class="dim">(</span>
<a href="javascript:void(0)" class="preset-link" data-fire="apg_default">default</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="apg_no_momentum">no momentum</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="apg_norms">norms</a>
<span class="dim">,</span>
<a href="javascript:void(0)" class="preset-link" data-fire="apg_no_eta">no eta</a>
<span class="dim">)</span>
</div>
""")
# Click handlers for these four buttons call apply_apg_preset.
apg_preset_default = gr.Button("", elem_id="apg_default", elem_classes=["proxy-btn"])
apg_preset_no_momentum = gr.Button("", elem_id="apg_no_momentum", elem_classes=["proxy-btn"])
apg_preset_norms = gr.Button("", elem_id="apg_norms", elem_classes=["proxy-btn"])
apg_preset_no_eta = gr.Button("", elem_id="apg_no_eta", elem_classes=["proxy-btn"])
with gr.Row():
apg_eta_text = gr.Number(label="APG η (text)", value=0.5, info="Eta for text projection (0-1, higher -> more like CFG)", minimum=0, maximum=1, step=0.25)
apg_eta_speaker = gr.Number(label="APG η (speaker)", value=0.5, info="Eta for speaker projection (0-1, higher -> more like CFG)", minimum=0, maximum=1, step=0.25)
with gr.Row() as apg_row2:
# No minimum is set on the momentum fields, so negative values are accepted.
apg_momentum_text = gr.Number(label="APG Momentum (text)", value=-0.25, info="Text momentum (can try 0., -.25, -0.5, -0.75...)", step=0.25)
apg_momentum_speaker = gr.Number(label="APG Momentum (speaker)", value=-0.25, info="Speaker momentum (can try 0., -.25, -0.5, -0.75...)", step=0.25)
with gr.Row():
# Norm clips are text boxes so that they can be left blank.
apg_norm_text = gr.Textbox(label="APG Norm (text)", value="", info="Text norm clip (leave blank to disable, can try 7.5, 15.0)")
apg_norm_speaker = gr.Textbox(label="APG Norm (speaker)", value="", info="Speaker norm clip (leave blank to disable, can try 7.5, 15.0)")
# End of advanced_mode_column
# Generate row: output format, the main action button, and two optional extras.
# All four values are passed to generate_audio as inputs.
with gr.Row(equal_height=True):
audio_format = gr.Radio(
choices=["wav", "mp3"],
value="wav",
label="Format",
scale=1,
min_width=90
)
# Its click event is wired to generate_audio.
generate_btn = gr.Button("Generate Audio", variant="primary", size="lg", scale=10)
with gr.Column(scale=1):
show_original_audio = gr.Checkbox(
label="Re-display original audio (full 2-minute cropped mono)",
value=False
)
reconstruct_first_30_seconds = gr.Checkbox(
label="Show Autoencoder Reconstruction (only first 30s of reference)",
value=False
)
# --- Results section ---
# Every named component below is an output of the generate_btn click handler.
gr.HTML('<hr class="section-separator">')
with gr.Accordion("Generated Audio", open=True, visible=True) as generated_section:
# Created hidden and empty.
generation_time_display = gr.Markdown("", visible=False)
with gr.Group(elem_classes=["generated-audio-player"]):
generated_audio = gr.Audio(label="Generated Audio", visible=True)
# Created hidden and empty.
text_prompt_display = gr.Markdown("", visible=False)
gr.Markdown("---")
# The header and both accordions are created hidden.
reference_audio_header = gr.Markdown("#### Reference Audio", visible=False)
with gr.Accordion("Original Audio (2 min Cropped Mono)", open=False, visible=False) as original_accordion:
original_audio = gr.Audio(label="Original Reference Audio (2 min)", visible=True)
with gr.Accordion("Autoencoder Reconstruction of First 30s of Reference", open=False, visible=False) as reference_accordion:
reference_audio = gr.Audio(label="Decoded Reference Audio (30s)", visible=True)
# Event handlers
# Custom Audio Panel: copy the component's current value into
# speaker_audio_path_state, using "" whenever that value is falsy.
custom_audio_input.change(
    fn=lambda audio: gr.update(value=audio or ""),
    inputs=[custom_audio_input],
    outputs=[speaker_audio_path_state],
)
# Audio prompt library: the table's select event is only wired when a prompt
# folder is configured and exists on disk.
if AUDIO_PROMPT_FOLDER is not None and AUDIO_PROMPT_FOLDER.exists():
    audio_prompt_table.select(fn=select_audio_prompt_file, outputs=[custom_audio_input])
# Dataset selector: switch between Custom Audio Panel, EARS, VCTK, Expresso, and HF-Custom.
# switch_dataset must return one value per entry below, in this exact order.
_dataset_switch_targets = [
    # containers
    dataset_license_info, custom_audio_row, voicebank_row, voice_type_column,
    ears_column, vctk_column, expresso_column, hf_custom_column,
    # EARS widgets
    selected_speaker_display, freeform_table, emotions_table,
    # per-dataset display boxes and the selected-speaker state
    selected_voice_display, vctk_speaker_display, expresso_speaker_display,
    hf_custom_speaker_display, selected_speaker_state,
    # preview player and the two path states consumed by generate_audio
    audio_preview, speaker_st_path_state, speaker_audio_path_state,
]
dataset_selector.change(
    fn=switch_dataset,
    inputs=[dataset_selector],
    outputs=_dataset_switch_targets,
)
# Outputs shared by every "a speaker was picked" handler, after that handler's
# own display box.
_loaded_voice_targets = [
    selected_speaker_state,
    audio_preview,
    speaker_st_path_state,
    speaker_audio_path_state,
]

# EARS: search filters the table; selecting a speaker also populates the
# freeform and emotions tables, so its output list has its own layout.
speaker_search.change(fn=search_speakers, inputs=[speaker_search], outputs=[speakers_table])
speakers_table.select(
    fn=select_speaker_from_table,
    inputs=[speakers_table],
    outputs=[
        selected_speaker_display,
        freeform_table,
        emotions_table,
        *_loaded_voice_targets,
        selected_voice_display,
    ],
)

# VCTK, Expresso and HF-Custom share one wiring shape: a search box that filters
# the table, and a table selection that loads the voice files directly.
# Listeners are registered in the same order as before (search, then select,
# dataset by dataset).
_direct_load_datasets = (
    (vctk_speaker_search, search_vctk_speakers,
     vctk_speakers_table, select_vctk_speaker_from_table, vctk_speaker_display),
    (expresso_speaker_search, search_expresso_speakers,
     expresso_speakers_table, select_expresso_speaker_from_table, expresso_speaker_display),
    (hf_custom_speaker_search, search_hf_custom_speakers,
     hf_custom_speakers_table, select_hf_custom_speaker_from_table, hf_custom_speaker_display),
)
for _search_box, _search_fn, _table, _select_fn, _display in _direct_load_datasets:
    _search_box.change(fn=_search_fn, inputs=[_search_box], outputs=[_table])
    _table.select(
        fn=_select_fn,
        inputs=[_table],
        outputs=[_display, *_loaded_voice_targets],
    )
# Freeform and emotion selection both take the current speaker state and write
# to the same four components; each listener gets its own copy of the list.
_voice_pick_targets = [
    selected_voice_display,
    audio_preview,
    speaker_st_path_state,
    speaker_audio_path_state,
]
freeform_table.select(
    fn=select_freeform_from_table,
    inputs=[selected_speaker_state],
    outputs=list(_voice_pick_targets),
)
emotions_table.select(
    fn=select_emotion_from_table,
    inputs=[selected_speaker_state],
    outputs=list(_voice_pick_targets),
)
# Selecting a text preset writes into the prompt box.
text_presets_table.select(fn=select_text_preset, outputs=text_prompt)
# Mode selector handler: toggle_mode receives the selected mode plus both
# speaker-KV checkboxes and updates the three mode containers and the two
# checkboxes.
mode_selector.change(
    toggle_mode,
    inputs=[mode_selector, speaker_k_enable, speaker_kv_simple_checkbox],
    outputs=[simple_mode_row, advanced_mode_compile_column, advanced_mode_column, speaker_kv_simple_checkbox, speaker_k_enable]
).then(
    # Follow-up step after the mode switch: show the KV row only when speaker-KV
    # scaling is enabled, and set the three KV numbers to 1.5 / 0.9 / 24.
    # The values are written unconditionally; the former
    # `x if enabled else x` ternaries had identical branches.
    lambda enabled: (
        gr.update(visible=enabled),
        gr.update(value=1.5),
        gr.update(value=0.9),
        gr.update(value=24),
    ),
    inputs=[speaker_k_enable],
    outputs=[speaker_k_row, speaker_k_scale, speaker_k_min_t, speaker_k_max_layers]
)
# Simple-mode "Force Speaker" checkbox: sync_simple_to_advanced drives the
# advanced checkbox, the KV row and the three KV numbers.
_kv_sync_targets = [
    speaker_k_enable,
    speaker_k_row,
    speaker_k_scale,
    speaker_k_min_t,
    speaker_k_max_layers,
]
speaker_kv_simple_checkbox.change(
    fn=sync_simple_to_advanced,
    inputs=[speaker_kv_simple_checkbox],
    outputs=_kv_sync_targets,
)

# CFG mode: update_cfg_visibility updates both CFG scale fields and the APG group.
cfg_mode.change(
    fn=update_cfg_visibility,
    inputs=cfg_mode,
    outputs=[cfg_scale_text, cfg_scale_speaker, apg_row],
)

# Advanced speaker-KV checkbox: the KV row's visibility follows the checkbox,
# and the same value is mirrored into the simple-mode checkbox.
speaker_k_enable.change(
    fn=lambda enabled: (gr.update(visible=enabled), gr.update(value=enabled)),
    inputs=[speaker_k_enable],
    outputs=[speaker_k_row, speaker_kv_simple_checkbox],
)
# Custom shapes enable handler - toggle row visibility and reset to defaults on disable
def toggle_custom_shapes(enabled):
    """Build the updates for the custom-shapes row and its three fields.

    Returns a 4-tuple of updates, in order: the row's visibility, then
    max_text_byte_length, max_speaker_latent_length and sample_latent_len.
    When enabled, the row is shown and the three fields get empty updates;
    when disabled, the row is hidden and the fields are set to "768",
    "2560" and "640".
    """
    if not enabled:
        return (
            gr.update(visible=False),
            gr.update(value="768"),
            gr.update(value="2560"),
            gr.update(value="640"),
        )
    unchanged_fields = [gr.update() for _ in range(3)]
    return (gr.update(visible=True), *unchanged_fields)


use_custom_shapes_checkbox.change(
    fn=toggle_custom_shapes,
    inputs=[use_custom_shapes_checkbox],
    outputs=[
        custom_shapes_row,
        max_text_byte_length,
        max_speaker_latent_length,
        sample_latent_len,
    ],
)
# Preset proxy buttons. Each click calls the matching apply_*_preset helper with
# a fixed preset name. Every target list ends with preset_dropdown, and each
# listener receives its own copy of the list.

# Core preset handler
_core_preset_targets = [rng_seed, num_steps, cfg_mode, preset_dropdown]
core_preset_default.click(
    fn=lambda: apply_core_preset("default"),
    outputs=list(_core_preset_targets),
)

# CFG preset handlers
_cfg_preset_targets = [cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, preset_dropdown]
cfg_preset_default.click(
    fn=lambda: apply_cfg_preset("default"),
    outputs=list(_cfg_preset_targets),
)
cfg_preset_higher_speaker.click(
    fn=lambda: apply_cfg_preset("higher speaker"),
    outputs=list(_cfg_preset_targets),
)
cfg_preset_large_guidances.click(
    fn=lambda: apply_cfg_preset("large guidances"),
    outputs=list(_cfg_preset_targets),
)

# Speaker KV preset handlers
_spk_kv_preset_targets = [speaker_k_enable, speaker_k_row, preset_dropdown]
spk_kv_preset_enable.click(
    fn=lambda: apply_speaker_kv_preset("enable"),
    outputs=list(_spk_kv_preset_targets),
)
spk_kv_preset_off.click(
    fn=lambda: apply_speaker_kv_preset("off"),
    outputs=list(_spk_kv_preset_targets),
)

# Truncation preset handlers
_trunc_preset_targets = [truncation_factor, rescale_k, rescale_sigma, preset_dropdown]
trunc_preset_flat.click(
    fn=lambda: apply_truncation_preset("flat"),
    outputs=list(_trunc_preset_targets),
)
trunc_preset_sharp.click(
    fn=lambda: apply_truncation_preset("sharp"),
    outputs=list(_trunc_preset_targets),
)
trunc_preset_baseline.click(
    fn=lambda: apply_truncation_preset("baseline(sharp)"),
    outputs=list(_trunc_preset_targets),
)

# APG preset handlers
_apg_preset_targets = [
    apg_eta_text,
    apg_eta_speaker,
    apg_momentum_text,
    apg_momentum_speaker,
    apg_norm_text,
    apg_norm_speaker,
    preset_dropdown,
]
apg_preset_default.click(
    fn=lambda: apply_apg_preset("default"),
    outputs=list(_apg_preset_targets),
)
apg_preset_no_momentum.click(
    fn=lambda: apply_apg_preset("no momentum"),
    outputs=list(_apg_preset_targets),
)
apg_preset_norms.click(
    fn=lambda: apply_apg_preset("norms"),
    outputs=list(_apg_preset_targets),
)
apg_preset_no_eta.click(
    fn=lambda: apply_apg_preset("no eta"),
    outputs=list(_apg_preset_targets),
)
# Preset handler: apply_sampler_preset receives the dropdown value and must
# return one update per component below, in this order.
_sampler_preset_targets = [
    # core + CFG
    num_steps, cfg_mode, cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t,
    # truncation / rescaling
    truncation_factor, rescale_k, rescale_sigma,
    # speaker KV
    speaker_k_enable, speaker_k_row, speaker_k_scale, speaker_k_min_t, speaker_k_max_layers,
    # APG
    apg_eta_text, apg_eta_speaker, apg_momentum_text, apg_momentum_speaker,
    apg_norm_text, apg_norm_speaker,
]
preset_dropdown.change(
    fn=apply_sampler_preset,
    inputs=preset_dropdown,
    outputs=_sampler_preset_targets,
)

# Compile handler: compile_model runs first, then do_compile as a follow-up
# step. The two steps list the same two components in opposite order, so the
# order of each list must be kept as is.
_compile_toggled = compile_checkbox.change(
    fn=compile_model,
    inputs=compile_checkbox,
    outputs=[compile_checkbox, compile_status],
)
_compile_toggled.then(
    fn=do_compile,
    outputs=[compile_status, compile_checkbox],
)
# Main generation trigger. The input list is positional: it must match the
# parameter order of generate_audio.
_generate_inputs = [
    # prompt and speaker reference
    text_prompt,
    speaker_st_path_state,
    speaker_audio_path_state,
    # core sampling
    num_steps,
    rng_seed,
    cfg_mode,
    cfg_scale_text,
    cfg_scale_speaker,
    cfg_min_t,
    cfg_max_t,
    # truncation / rescaling
    truncation_factor,
    rescale_k,
    rescale_sigma,
    # speaker KV scaling
    speaker_k_enable,
    speaker_k_scale,
    speaker_k_min_t,
    speaker_k_max_layers,
    # APG
    apg_eta_text,
    apg_eta_speaker,
    apg_momentum_text,
    apg_momentum_speaker,
    apg_norm_text,
    apg_norm_speaker,
    # extras and shapes
    reconstruct_first_30_seconds,
    use_custom_shapes_checkbox,
    max_text_byte_length,
    max_speaker_latent_length,
    sample_latent_len,
    audio_format,
    compile_checkbox,  # Pass compile state to choose model
    show_original_audio,
    session_id_state,
]
_generate_outputs = [
    generated_section,
    generated_audio,
    text_prompt_display,
    original_audio,
    generation_time_display,
    reference_audio,
    original_accordion,
    reference_accordion,
    reference_audio_header,
]
generate_btn.click(
    fn=generate_audio,
    inputs=_generate_inputs,
    outputs=_generate_outputs,
)
# Initialize session ID and trigger compilation when the page loads
demo.load(init_and_compile, outputs=[session_id_state, compile_status, compile_checkbox]).then(
    # Apply the first preset on load. Presets are re-read here; when none are
    # defined, fall back to "Custom" (the same fallback used for the dropdown's
    # initial value) instead of raising IndexError on an empty key list.
    lambda: apply_sampler_preset(next(iter(load_sampler_presets().keys()), "Custom")),
    outputs=[num_steps, cfg_mode, cfg_scale_text, cfg_scale_speaker, cfg_min_t, cfg_max_t, truncation_factor, rescale_k, rescale_sigma, speaker_k_enable, speaker_k_row, speaker_k_scale, speaker_k_min_t, speaker_k_max_layers, apg_eta_text, apg_eta_speaker, apg_momentum_text, apg_momentum_speaker, apg_norm_text, apg_norm_speaker]
)
if __name__ == "__main__":
    # For HF-Custom, allow the entire dataset cache directory to handle subdirectories
    hf_custom_cache = HF_CUSTOM_PATH.parent.parent.parent
    # Directories Gradio is allowed to serve files from. Unconfigured (None)
    # entries are skipped, so that e.g. a disabled AUDIO_PROMPT_FOLDER does not
    # end up as the literal relative path "None".
    candidate_paths = [
        EARS_PATH,
        VCTK_PATH,
        EXPRESSO_PATH,
        hf_custom_cache,
        TEMP_AUDIO_DIR,
        AUDIO_PROMPT_FOLDER,
    ]
    allowed_paths = [str(path) for path in candidate_paths if path is not None]
    # Enable queue for better handling of concurrent requests on HF Spaces
    demo.queue(max_size=20)
    demo.launch(allowed_paths=allowed_paths)