""" | |
Gradio UI for Text-to-Speech using HiggsAudioServeEngine | |
Enhanced with visual improvements and better user experience | |
""" | |
import argparse
import base64
import os
import uuid
import json
from typing import Optional

import gradio as gr
from loguru import logger
import numpy as np
import time
from functools import lru_cache
import re
import spaces
import torch

# Import HiggsAudio components
from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
from higgs_audio.data_types import ChatMLSample, AudioContent, Message

# Global engine instance
engine = None
VOICE_PRESETS = {}

# Default model configuration
DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
SAMPLE_RATE = 24000

DEFAULT_SYSTEM_PROMPT = (
    "Generate audio following instruction.\n\n"
    "<|scene_desc_start|>\n"
    "Audio is recorded from a quiet room.\n"
    "Support for multiple languages including English, Chinese, Korean, Japanese, and more.\n"
    "<|scene_desc_end|>"
)

DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
# Predefined examples for system and input messages
PREDEFINED_EXAMPLES = {
    "voice-clone": {
        "system_prompt": "",
        "input_text": "Hey there! I'm your friendly voice twin in the making. Pick a voice preset below or upload your own audio - let's clone some vocals and bring your voice to life!",
        "description": "🎭 <b>Voice Clone</b> - Clone any voice with reference audio. Leave the system prompt empty for best results.",
        "icon": "🎭",
        "color": "#FF6B6B",
    },
    "smart-voice": {
        "system_prompt": DEFAULT_SYSTEM_PROMPT,
        "input_text": "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
        "description": "🧠 <b>Smart Voice</b> - Generate natural speech based on context",
        "icon": "🧠",
        "color": "#4ECDC4",
    },
    "multispeaker-voice-description": {
        "system_prompt": "You are an AI assistant designed to convert text into speech.\n"
        "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
        "If no speaker tag is present, select a suitable voice on your own.\n\n"
        "<|scene_desc_start|>\n"
        "SPEAKER0: feminine\n"
        "SPEAKER1: masculine\n"
        "<|scene_desc_end|>",
        "input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
        "[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
        "[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
        "[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
        "description": "👥 <b>Multi-Speaker</b> - Different voices for dialogue and conversations",
        "icon": "👥",
        "color": "#95E1D3",
    },
    "single-speaker-voice-description": {
        "system_prompt": "Generate audio following instruction.\n\n"
        "<|scene_desc_start|>\n"
        "SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
        "<|scene_desc_end|>",
        "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
        "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
        "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
        "\n"
        "So here's the big question: Do you want to understand how deep learning works?\n",
        "description": "🎙️ <b>Voice Description</b> - Generate speech with specific voice characteristics",
        "icon": "🎙️",
        "color": "#F38181",
    },
    "single-speaker-zh": {
        "system_prompt": "Generate audio following instruction.\n\n"
        "<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n"
        "<|scene_desc_end|>",
        "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
        "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
        "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
        "或者说, 你能察觉到我其实是个机器人吗?",
        "description": "🇨🇳 <b>Chinese Speech</b> - Generate natural Chinese speech",
        "icon": "🇨🇳",
        "color": "#AA96DA",
    },
    "single-speaker-kr": {
        "system_prompt": "Generate audio following instruction.\n\n"
        "<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n"
        "<|scene_desc_end|>",
        "input_text": "안녕하세요, 오늘은 인공지능의 미래에 대해 이야기해보겠습니다.\n"
        "최근 AI 기술의 발전이 정말 놀라운데요,\n"
        "특히 음성 합성 기술은 이제 사람과 구별하기 어려울 정도로 자연스러워졌습니다.\n"
        "여러분은 제가 실제 사람인지 AI인지 구별할 수 있으신가요?",
        "description": "🇰🇷 <b>Korean Speech</b> - Generate natural Korean speech",
        "icon": "🇰🇷",
        "color": "#FFB6C1",
    },
    "single-speaker-bgm": {
        "system_prompt": DEFAULT_SYSTEM_PROMPT,
        "input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
        "description": "🎵 <b>Speech with BGM</b> - Add background music to your speech (experimental)",
        "icon": "🎵",
        "color": "#FCBAD3",
    },
}
def encode_audio_file(file_path):
    """Encode an audio file to base64."""
    with open(file_path, "rb") as audio_file:
        return base64.b64encode(audio_file.read()).decode("utf-8")
def get_current_device():
    """Get the current device."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_voice_presets():
    """Load the voice presets from the voice_examples directory."""
    try:
        config_path = os.path.join(os.path.dirname(__file__), "voice_examples", "config.json")
        # Check if directory exists
        if not os.path.exists(os.path.dirname(config_path)):
            logger.warning("Voice examples directory not found")
            return {"EMPTY": "No reference voice"}
        with open(config_path, "r") as f:
            voice_dict = json.load(f)
        voice_presets = {k: v["transcript"] for k, v in voice_dict.items()}
        voice_presets["EMPTY"] = "No reference voice"
        logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
        return voice_presets
    except FileNotFoundError:
        logger.warning("Voice examples config file not found. Using empty voice presets.")
        return {"EMPTY": "No reference voice"}
    except Exception as e:
        logger.error(f"Error loading voice presets: {e}")
        return {"EMPTY": "No reference voice"}
def get_voice_preset(voice_preset):
    """Get the voice path and text for a given voice preset."""
    voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
    if not os.path.exists(voice_path):
        logger.warning(f"Voice preset file not found: {voice_path}")
        return None, "Voice preset not found"
    text = VOICE_PRESETS.get(voice_preset, "No transcript available")
    return voice_path, text
def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
    """
    # Mapping of Chinese punctuation to English punctuation
    chinese_to_english_punct = {
        "，": ", ",  # comma
        "。": ".",  # period
        "：": ":",  # colon
        "；": ";",  # semicolon
        "？": "?",  # question mark
        "！": "!",  # exclamation mark
        "（": "(",  # left parenthesis
        "）": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }
    # Replace each Chinese punctuation mark with its English counterpart
    for zh_punct, en_punct in chinese_to_english_punct.items():
        text = text.replace(zh_punct, en_punct)
    return text
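
# Illustrative example of the mapping above (not executed at import time):
#   normalize_chinese_punctuation("你好，世界。你在吗？")
#   -> "你好, 世界.你在吗?"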
def normalize_text(transcript: str):
    # Skip normalization for Korean text to preserve it properly
    if any("\u3131" <= char <= "\u3163" or "\uac00" <= char <= "\ud7a3" for char in transcript):
        # Korean text detected - minimal normalization
        transcript = transcript.strip()
        if transcript and not any([transcript.endswith(c) for c in [".", "!", "?", "。", "！", "？"]]):
            transcript += "."
        return transcript

    # Chinese punctuation normalization
    transcript = normalize_chinese_punctuation(transcript)
    # Other normalizations (e.g., parentheses and other symbols)
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")

    for tag, replacement in [
        ("[laugh]", "<SE>[Laughter]</SE>"),
        ("[humming start]", "<SE>[Humming]</SE>"),
        ("[humming end]", "<SE_e>[Humming]</SE_e>"),
        ("[music start]", "<SE_s>[Music]</SE_s>"),
        ("[music end]", "<SE_e>[Music]</SE_e>"),
        ("[music]", "<SE>[Music]</SE>"),
        ("[sing start]", "<SE_s>[Singing]</SE_s>"),
        ("[sing end]", "<SE_e>[Singing]</SE_e>"),
        ("[applause]", "<SE>[Applause]</SE>"),
        ("[cheering]", "<SE>[Cheering]</SE>"),
        ("[cough]", "<SE>[Cough]</SE>"),
    ]:
        transcript = transcript.replace(tag, replacement)

    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()

    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
        transcript += "."
    return transcript
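
# Illustrative examples of normalize_text: whitespace is collapsed per line,
# sound-effect tags are rewritten, and a terminal period is appended when the
# text does not already end with a recognized closer.
#   normalize_text("[music start] hello   world [music end]")
#   -> "<SE_s>[Music]</SE_s> hello world <SE_e>[Music]</SE_e>"
#   normalize_text("The sun rises in the east")
#   -> "The sun rises in the east."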
def initialize_engine(model_path, audio_tokenizer_path) -> bool:
    """Initialize the HiggsAudioServeEngine."""
    global engine
    try:
        if engine is not None:
            logger.info("Engine already initialized")
            return True
        logger.info(f"Initializing engine with model: {model_path} and audio tokenizer: {audio_tokenizer_path}")
        engine = HiggsAudioServeEngine(
            model_name_or_path=model_path,
            audio_tokenizer_name_or_path=audio_tokenizer_path,
            device=get_current_device(),
        )
        logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize engine: {e}")
        return False
def check_return_audio(audio_wv: np.ndarray):
    # Warn if the returned audio is completely silent
    if np.all(audio_wv == 0):
        logger.warning("Generated audio is silent")
def process_text_output(text_output: str):
    # Collapse each run of consecutive <|AUDIO_OUT|> tokens into a single <|AUDIO_OUT|>
    text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
    return text_output
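
# Illustrative example of the collapsing regex above:
#   process_text_output("<|AUDIO_OUT|><|AUDIO_OUT|><|AUDIO_OUT|> done")
#   -> "<|AUDIO_OUT|> done"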
def prepare_chatml_sample(
    voice_preset: str,
    text: str,
    reference_audio: Optional[str] = None,
    reference_text: Optional[str] = None,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
):
    """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
    messages = []

    # Add system message if provided
    if len(system_prompt) > 0:
        messages.append(Message(role="system", content=system_prompt))

    # Add reference audio if provided
    audio_base64 = None
    ref_text = ""
    if reference_audio:
        # Custom reference audio
        audio_base64 = encode_audio_file(reference_audio)
        ref_text = reference_text or ""
    elif voice_preset != "EMPTY":
        # Voice preset
        voice_path, ref_text = get_voice_preset(voice_preset)
        if voice_path is None:
            logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
        else:
            audio_base64 = encode_audio_file(voice_path)

    # Only add reference audio if we have it
    if audio_base64 is not None:
        # Add user message with reference text
        messages.append(Message(role="user", content=ref_text))
        # Add assistant message with audio content
        audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
        messages.append(Message(role="assistant", content=[audio_content]))

    # Add the main user message
    text = normalize_text(text)
    messages.append(Message(role="user", content=text))
    return ChatMLSample(messages=messages)
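
# Sketch of the message layout produced above when a reference voice is used:
#   [system]    system_prompt (if non-empty)
#   [user]      reference transcript
#   [assistant] AudioContent with the base64-encoded reference audio
#   [user]      normalized target text
# Without a reference, only the optional system turn and the final user turn remain.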
# Assumption: this Space runs on ZeroGPU, and `spaces` is imported but otherwise
# unused, so the GPU-bound entry point is presumed to carry the @spaces.GPU decorator.
@spaces.GPU
def text_to_speech(
    text,
    voice_preset,
    reference_audio=None,
    reference_text=None,
    max_completion_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    system_prompt=DEFAULT_SYSTEM_PROMPT,
    stop_strings=None,
    ras_win_len=7,
    ras_win_max_num_repeat=2,
):
"""Convert text to speech using HiggsAudioServeEngine.""" | |
global engine | |
if engine is None: | |
if not initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH): | |
return "❌ Failed to initialize engine", None | |
try: | |
# Prepare ChatML sample | |
chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt) | |
# Convert stop strings format | |
if stop_strings is None: | |
stop_list = DEFAULT_STOP_STRINGS | |
else: | |
stop_list = [s for s in stop_strings["stops"] if s.strip()] | |
request_id = f"tts-playground-{str(uuid.uuid4())}" | |
logger.info( | |
f"{request_id}: Generating speech for text: {text[:100]}..., \n" | |
f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, " | |
f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}" | |
) | |
start_time = time.time() | |
# Generate using the engine | |
response = engine.generate( | |
chat_ml_sample=chatml_sample, | |
max_new_tokens=max_completion_tokens, | |
temperature=temperature, | |
top_k=top_k if top_k > 0 else None, | |
top_p=top_p, | |
stop_strings=stop_list, | |
ras_win_len=ras_win_len if ras_win_len > 0 else None, | |
ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat), | |
) | |
generation_time = time.time() - start_time | |
logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds") | |
gr.Info(f"Generated audio in {generation_time:.3f} seconds") | |
# Process the response | |
text_output = process_text_output(response.generated_text) | |
if response.audio is not None: | |
# Convert to int16 for Gradio | |
audio_data = (response.audio * 32767).astype(np.int16) | |
check_return_audio(audio_data) | |
return text_output, (response.sampling_rate, audio_data) | |
else: | |
logger.warning("No audio generated") | |
return text_output, None | |
except Exception as e: | |
error_msg = f"Error generating speech: {e}" | |
logger.error(error_msg) | |
gr.Error(error_msg) | |
return f"❌ {error_msg}", None | |
def initialize_globals():
    """Initialize global variables."""
    global VOICE_PRESETS
    VOICE_PRESETS = load_voice_presets()
def create_ui():
    """Create the enhanced Gradio UI."""
    # Try to load theme
    try:
        my_theme = gr.Theme.load("theme.json")
    except Exception as e:
        logger.warning(f"Failed to load theme.json: {e}, using default theme")
        my_theme = gr.themes.Default()
    # Enhanced CSS with animations and visual improvements
    custom_css = """
    /* Remove focus highlighting */
    .gradio-container input:focus,
    .gradio-container textarea:focus,
    .gradio-container select:focus,
    .gradio-container .gr-input:focus,
    .gradio-container .gr-textarea:focus,
    .gradio-container .gr-textbox:focus,
    .gradio-container .gr-textbox:focus-within,
    .gradio-container .gr-form:focus-within,
    .gradio-container *:focus {
        box-shadow: none !important;
        border-color: var(--border-color-primary) !important;
        outline: none !important;
        background-color: var(--input-background-fill) !important;
    }

    /* Gradient background */
    .gradio-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        min-height: 100vh;
    }

    /* Main container styling */
    .container {
        backdrop-filter: blur(10px);
        background: rgba(255, 255, 255, 0.95);
        border-radius: 20px;
        box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
    }

    /* Fix dropdown visibility issues */
    .gr-dropdown {
        position: relative !important;
        z-index: 999 !important;
    }
    .gr-dropdown-container {
        position: relative !important;
        overflow: visible !important;
    }
    .gr-dropdown .gr-dropdown-list {
        position: absolute !important;
        z-index: 1000 !important;
        background: white !important;
        border: 1px solid #e0e0e0 !important;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
        max-height: 300px !important;
        overflow-y: auto !important;
    }

    /* Ensure parent containers don't clip dropdown */
    .gr-form, .gr-box, .gr-group {
        overflow: visible !important;
    }
    .template-selector {
        position: relative !important;
        z-index: 100 !important;
    }

    /* Main content area fix */
    .main-content {
        overflow: visible !important;
        position: relative;
        z-index: 1;
    }
    .input-column {
        overflow: visible !important;
        position: relative;
    }

    /* Global overflow fix for dropdown visibility */
    .gr-panel {
        overflow: visible !important;
    }

    /* Header styling */
    .header-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    }
    .header-title {
        color: white;
        font-size: 2.5rem;
        font-weight: bold;
        text-align: center;
        margin: 0;
        text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
    }
    .header-subtitle {
        color: rgba(255, 255, 255, 0.9);
        text-align: center;
        margin-top: 0.5rem;
        font-size: 1.1rem;
    }

    /* Template cards */
    .template-card {
        background: white;
        border-radius: 12px;
        padding: 1.5rem;
        margin: 0.5rem;
        border: 2px solid transparent;
        transition: all 0.3s ease;
        cursor: pointer;
        box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
    }
    .template-card:hover {
        transform: translateY(-3px);
        box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15);
        border-color: var(--primary-500);
    }
    .template-card.selected {
        border-color: var(--primary-500);
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    }
    .template-icon {
        font-size: 2rem;
        margin-bottom: 0.5rem;
    }

    /* Voice preset cards */
    .voice-card {
        background: white;
        border-radius: 10px;
        padding: 1rem;
        margin: 0.5rem;
        border: 2px solid #e0e0e0;
        transition: all 0.3s ease;
        cursor: pointer;
        text-align: center;
    }
    .voice-card:hover {
        border-color: var(--primary-500);
        transform: scale(1.05);
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    }
    .voice-card.selected {
        border-color: var(--primary-500);
        background: #f0f8ff;
    }

    /* Generate button animation */
    .generate-btn {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        font-size: 1.2rem;
        font-weight: bold;
        padding: 0.8rem 2rem;
        border-radius: 30px;
        border: none;
        cursor: pointer;
        transition: all 0.3s ease;
        box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
    }
    .generate-btn:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6);
    }
    .generate-btn:active {
        transform: translateY(0);
    }

    /* Audio player styling */
    .audio-container {
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        padding: 2rem;
        border-radius: 15px;
        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    }

    /* Progress indicator */
    .progress-bar {
        height: 4px;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        border-radius: 2px;
        animation: progress 2s ease-in-out infinite;
    }
    @keyframes progress {
        0% { transform: translateX(-100%); }
        100% { transform: translateX(100%); }
    }

    /* Accordion styling */
    .gr-accordion {
        background: white;
        border-radius: 10px;
        border: 1px solid #e0e0e0;
        margin-top: 1rem;
    }

    /* Info cards */
    .info-card {
        background: #f8f9fa;
        border-left: 4px solid var(--primary-500);
        padding: 1rem;
        margin: 1rem 0;
        border-radius: 5px;
    }

    /* Tooltips */
    .tooltip {
        position: relative;
        display: inline-block;
        border-bottom: 1px dotted black;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 200px;
        background-color: #555;
        color: #fff;
        text-align: center;
        border-radius: 6px;
        padding: 5px;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        margin-left: -100px;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }

    /* Dropdown specific styling to ensure visibility */
    .template-selector {
        min-width: 300px;
        z-index: 1000;
    }
    .gr-dropdown {
        position: relative;
    }
    .gr-dropdown .gr-dropdown-list {
        max-height: 300px;
        overflow-y: auto;
        z-index: 1001;
    }

    @media (max-width: 768px) {
        .header-title {
            font-size: 2rem;
        }
        .template-card {
            margin: 0.25rem;
            padding: 1rem;
        }
    }
    """
    default_template = "smart-voice"
    with gr.Blocks(theme=my_theme, css=custom_css, title="Higgs Audio TTS") as demo:
        # Header with gradient background
        gr.HTML("""
            <div class="header-container">
                <h1 class="header-title">🎙️ Higgs Audio Text-to-Speech</h1>
                <p class="header-subtitle">Transform your text into natural, expressive speech with AI</p>
            </div>
        """)
        # Main UI section with fixed overflow
        with gr.Row(elem_classes=["main-content"]):
            with gr.Column(scale=2, elem_classes=["input-column"]):
                # Template selection with visual cards
                gr.Markdown("### 🎯 Choose Your Template")

                # Define available templates
                available_templates = list(PREDEFINED_EXAMPLES.keys())

                # Use Radio instead of Dropdown for better visibility
                template_dropdown = gr.Radio(
                    label="TTS Template",
                    choices=available_templates,
                    value=default_template,
                    info="Select a predefined template to get started quickly",
                    type="value",
                )

                # Template description with enhanced styling
                template_description = gr.HTML(
                    value=f'<div class="info-card">{PREDEFINED_EXAMPLES[default_template]["description"]}</div>',
                    visible=True,
                )

                # System prompt with better styling
                with gr.Group():
                    gr.Markdown("### 🔧 System Configuration")
                    system_prompt = gr.TextArea(
                        label="System Prompt",
                        placeholder="Enter system prompt to guide the model...",
                        value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
                        lines=3,
                        elem_classes=["system-prompt"],
                    )

                # Input text with character counter
                with gr.Group():
                    gr.Markdown("### ✍️ Your Text")
                    input_text = gr.TextArea(
                        label="Input Text",
                        placeholder="Type the text you want to convert to speech...",
                        value=PREDEFINED_EXAMPLES[default_template]["input_text"],
                        lines=6,
                        elem_classes=["input-text"],
                    )
                    char_count = gr.Markdown(f"Character count: {len(PREDEFINED_EXAMPLES[default_template]['input_text'])}")
                # Voice selection section
                with gr.Group(visible=False) as voice_section:
                    gr.Markdown("### 🎭 Voice Selection")
                    voice_preset = gr.Dropdown(
                        label="Voice Preset",
                        choices=list(VOICE_PRESETS.keys()),
                        value="EMPTY",
                        interactive=False,
                        visible=False,
                        elem_classes=["voice-preset"],
                    )
                    with gr.Accordion(
                        "🎤 Custom Reference Audio", open=False, visible=False
                    ) as custom_reference_accordion:
                        reference_audio = gr.Audio(
                            label="Upload Reference Audio",
                            type="filepath",
                            elem_classes=["reference-audio"],
                        )
                        reference_text = gr.TextArea(
                            label="Reference Text (transcript of the reference audio)",
                            placeholder="Enter the transcript of your reference audio for better voice cloning...",
                            lines=3,
                            elem_classes=["reference-text"],
                        )
                # Advanced parameters with better organization
                with gr.Accordion("⚙️ Advanced Parameters", open=False):
                    with gr.Row():
                        with gr.Column():
                            max_completion_tokens = gr.Slider(
                                minimum=128,
                                maximum=4096,
                                value=1024,
                                step=10,
                                label="Max Completion Tokens",
                                info="Maximum number of tokens to generate",
                            )
                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.5,
                                value=1.0,
                                step=0.1,
                                label="Temperature",
                                info="Controls randomness in generation",
                            )
                        with gr.Column():
                            top_p = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.95,
                                step=0.05,
                                label="Top P",
                                info="Nucleus sampling parameter",
                            )
                            top_k = gr.Slider(
                                minimum=-1,
                                maximum=100,
                                value=50,
                                step=1,
                                label="Top K",
                                info="Top-k sampling parameter (-1 to disable)",
                            )
                    with gr.Row():
                        with gr.Column():
                            ras_win_len = gr.Slider(
                                minimum=0,
                                maximum=10,
                                value=7,
                                step=1,
                                label="RAS Window Length",
                                info="Window length for repetition avoidance sampling",
                            )
                        with gr.Column():
                            ras_win_max_num_repeat = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=2,
                                step=1,
                                label="RAS Max Num Repeat",
                                info="Maximum repetitions allowed in the window",
                            )

                    # Stop strings with better UI
                    gr.Markdown("#### Stop Strings")
                    stop_strings = gr.Dataframe(
                        label="Stop Strings",
                        headers=["stops"],
                        datatype=["str"],
                        value=[[s] for s in DEFAULT_STOP_STRINGS],
                        interactive=True,
                        col_count=(1, "fixed"),
                        elem_classes=["stop-strings"],
                    )
                # Generate button with enhanced styling
                with gr.Row():
                    submit_btn = gr.Button(
                        "🚀 Generate Speech",
                        variant="primary",
                        scale=1,
                        elem_classes=["generate-btn"],
                    )
            # Output column with better organization
            with gr.Column(scale=2):
                # Status and progress section
                with gr.Group():
                    gr.Markdown("### 📊 Generation Status")
                    status_text = gr.Markdown("Ready to generate speech...", elem_classes=["status-text"])

                # Model response section
                with gr.Group():
                    gr.Markdown("### 💬 Model Response")
                    output_text = gr.TextArea(
                        label="Generated Text Output",
                        lines=3,
                        interactive=False,
                        elem_classes=["output-text"],
                    )

                # Audio output with enhanced player
                with gr.Group():
                    gr.Markdown("### 🎵 Generated Audio")
                    output_audio = gr.Audio(
                        label="Audio Player",
                        interactive=False,
                        autoplay=True,
                        elem_classes=["audio-container"],
                    )
                    with gr.Row():
                        stop_btn = gr.Button(
                            "⏹️ Stop Playback",
                            variant="secondary",
                            elem_classes=["stop-btn"],
                        )
                        download_btn = gr.Button(
                            "💾 Download Audio",
                            variant="secondary",
                            elem_classes=["download-btn"],
                            visible=False,
                        )

                # Quick tips section
                gr.Markdown("""
                <div class="info-card">
                    <h4>💡 Quick Tips:</h4>
                    <ul>
                        <li>For voice cloning, upload a clear 10-30 second audio sample</li>
                        <li>Use [music start] and [music end] tags for background music</li>
                        <li>Add [SPEAKER0] and [SPEAKER1] tags for multi-speaker dialogue</li>
                        <li>Experiment with temperature (0.8-1.2) for varied speech styles</li>
                    </ul>
                </div>
                """)
        # Voice samples section with visual cards
        with gr.Row(visible=False) as voice_samples_section:
            gr.Markdown("### 🎧 Voice Samples Library")
            voice_samples_table = gr.Dataframe(
                headers=["Voice Preset", "Sample Text"],
                datatype=["str", "str"],
                value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
                interactive=False,
                elem_classes=["voice-samples-table"],
            )
            sample_audio = gr.Audio(
                label="🔊 Preview Voice Sample",
                elem_classes=["sample-audio"],
            )
        # Function to update the character count
        def update_char_count(text):
            return f"Character count: {len(text)}"
        # Function to play a voice sample when a table row is selected
        def play_voice_sample(evt: gr.SelectData):
            try:
                preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
                if evt.index[0] < len(preset_names):
                    preset = preset_names[evt.index[0]]
                    voice_path, _ = get_voice_preset(preset)
                    if voice_path and os.path.exists(voice_path):
                        return voice_path
                    else:
                        gr.Warning(f"Voice sample file not found for preset: {preset}")
                        return None
                else:
                    gr.Warning("Invalid voice preset selection")
                    return None
            except Exception as e:
                logger.error(f"Error playing voice sample: {e}")
                gr.Error(f"Error playing voice sample: {e}")
                return None
        # Function to handle template selection
        def apply_template(template_name):
            if template_name in PREDEFINED_EXAMPLES:
                template = PREDEFINED_EXAMPLES[template_name]
                is_voice_clone = template_name == "voice-clone"
                voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
                ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
                description_html = f'<div class="info-card">{template["description"]}</div>'
                return (
                    template["system_prompt"],  # system_prompt
                    template["input_text"],  # input_text
                    description_html,  # template_description
                    gr.update(
                        value=voice_preset_value,
                        interactive=is_voice_clone,
                        visible=is_voice_clone,
                    ),  # voice_preset
                    gr.update(visible=is_voice_clone),  # custom reference accordion
                    gr.update(visible=is_voice_clone),  # voice samples section
                    ras_win_len_value,  # ras_win_len
                    gr.update(visible=is_voice_clone),  # voice_section
                    update_char_count(template["input_text"]),  # char_count
                )
            return (gr.update(),) * 9
        # Enhanced text_to_speech wrapper with status updates
        def text_to_speech_with_status(
            text, voice_preset, reference_audio, reference_text,
            max_completion_tokens, temperature, top_p, top_k,
            system_prompt, stop_strings, ras_win_len, ras_win_max_num_repeat,
        ):
            # Update status
            yield "🔄 Initializing model...", None, None, gr.update(visible=False)

            # Call the actual TTS function
            result_text, audio_result = text_to_speech(
                text, voice_preset, reference_audio, reference_text,
                max_completion_tokens, temperature, top_p, top_k,
                system_prompt, stop_strings, ras_win_len, ras_win_max_num_repeat,
            )

            if audio_result:
                status = "✅ Speech generated successfully!"
                download_visible = True
            else:
                status = "❌ Failed to generate speech"
                download_visible = False
            yield status, result_text, audio_result, gr.update(visible=download_visible)
        # Set up event handlers
        # Character count update
        input_text.change(
            fn=update_char_count,
            inputs=[input_text],
            outputs=[char_count],
        )

        # Template selection
        template_dropdown.change(
            fn=apply_template,
            inputs=[template_dropdown],
            outputs=[
                system_prompt,
                input_text,
                template_description,
                voice_preset,
                custom_reference_accordion,
                voice_samples_section,
                ras_win_len,
                voice_section,
                char_count,
            ],
        )

        # Voice sample preview
        voice_samples_table.select(
            fn=play_voice_sample,
            outputs=[sample_audio],
        )

        # Generate button with status updates
        submit_btn.click(
            fn=text_to_speech_with_status,
            inputs=[
                input_text,
                voice_preset,
                reference_audio,
                reference_text,
                max_completion_tokens,
                temperature,
                top_p,
                top_k,
                system_prompt,
                stop_strings,
                ras_win_len,
                ras_win_max_num_repeat,
            ],
            outputs=[status_text, output_text, output_audio, download_btn],
            api_name="generate_speech",
        )

        # Stop button functionality
        stop_btn.click(
            fn=lambda: None,
            inputs=[],
            outputs=[output_audio],
            js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
        )

        # Download button functionality
        download_btn.click(
            fn=lambda x: x,
            inputs=[output_audio],
            outputs=[],
            js="(audio) => {if(audio) {const a = document.createElement('a'); a.href = audio.url; a.download = 'generated_speech.wav'; a.click();}}",
        )

    return demo
def main():
    """Main function to parse arguments and launch the UI."""
    global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH

    parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to run the model on.",
    )
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
    parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
    args = parser.parse_args()

    # Initialize global variables
    initialize_globals()

    # Create and launch the UI
    demo = create_ui()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        show_error=True,
    )
if __name__ == "__main__":
    main()
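
# Minimal local-launch sketch, assuming the higgs_audio package is installed and
# the bosonai model/tokenizer checkpoints are reachable (the "app.py" filename is
# an assumption; get_current_device() falls back to CPU when no CUDA GPU exists,
# so a GPU is strongly recommended):
#
#   python app.py --host 0.0.0.0 --port 7860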