# F5-TTS_Space / app.py
# chenxie95 · "Update app.py" · commit fc78750 (verified)
import gradio as gr
import numpy as np
import spaces
import torch
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
from f5_tts.model import DiT
# Vocoder shared by every TTS model registered below; loaded once at import time.
vocoder = load_vocoder()

# DiT architecture hyper-parameters (common usage).
v1_base_cfg = {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
v1_small_cfg = {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}

# A single vocabulary file is shared by all of the Algerian checkpoints.
alg_vocab_path = str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/vocab.txt"))

tts_model_choice = "v1-base_zh-en"  # default

# (dropdown label, checkpoint URI) for each Algerian v1-small checkpoint.
# Tuple order is the registry / dropdown order.
_ALG_SMALL_CHECKPOINTS = (
    ("v1-small_alg-64h_300k", "hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000.safetensors"),
    ("v1-small_alg-64h_300k_no-ema", "hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000_no-ema.safetensors"),
    ("v1-small_alg-64h_200k", "hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000.safetensors"),
    ("v1-small_alg-64h_200k_no-ema", "hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000_no-ema.safetensors"),
)

# Every model is loaded eagerly. The keys double as the choices shown in the UI.
tts_model_collections = {
    "v1-base_zh-en": load_model(
        DiT,
        v1_base_cfg,
        str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")),
        vocab_file=str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt")),
    ),
    **{
        label: load_model(DiT, v1_small_cfg, str(cached_path(ckpt_uri)), vocab_file=alg_vocab_path)
        for label, ckpt_uri in _ALG_SMALL_CHECKPOINTS
    },
}
@spaces.GPU
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    seed,
    show_info=gr.Info,
):
    """Synthesize ``gen_text`` in the voice of a reference recording.

    Args:
        ref_audio_orig: Reference audio as delivered by the UI. Presumably a
            file path (the basic tab uses ``gr.Audio(type="filepath")``);
            confirm for other callers.
        ref_text: Transcription of the reference audio.
        gen_text: Text to synthesize.
        model: Key into ``tts_model_collections`` selecting the TTS model.
            Unknown keys fall back to the module-level ``tts_model_choice``.
        seed: RNG seed expected in ``[0, 2**31 - 1]``. A missing (``None``) or
            out-of-range value triggers a warning and is replaced by a random seed.
        show_info: Callback forwarded to the preprocessing and inference helpers
            (defaults to ``gr.Info``).

    Returns:
        ``((sample_rate, waveform), processed_ref_text, seed_used)`` on success.
        Returns ``(gr.update(), ref_text, seed)`` unchanged when a required input
        is missing.
    """
    if not ref_audio_orig or not (ref_text or "").strip() or not (gen_text or "").strip():
        gr.Warning("Please ensure [Reference Audio] [Reference Text] [Text to Generate] are all provided.")
        return gr.update(), ref_text, seed

    # A cleared number field may arrive as None; handle it like an invalid seed
    # instead of crashing on the comparison.
    if seed is None or not 0 <= seed <= 2**31 - 1:
        gr.Warning("Please set a seed in range 0 ~ 2**31 - 1.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)

    # Use the model the caller asked for. Unknown keys fall back to the global
    # default, so callers that relied on the old behaviour keep working.
    if model not in tts_model_collections:
        model = tts_model_choice

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        tts_model_collections[model],
        vocoder,
        show_info=show_info,
        progress=gr.Progress(),
    )
    return (final_sample_rate, final_wave), ref_text, seed
with gr.Blocks() as app_basic_tts:
    # Components are created in the same order as before; creation order drives layout.
    with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_txt_input = gr.Textbox(label="Reference Text")
            gen_txt_input = gr.Textbox(label="Text to Generate")
            generate_btn = gr.Button("Synthesize", variant="primary")
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
                    value=True,
                    scale=3,
                )
                seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
        audio_output = gr.Audio(label="Synthesized Audio")

    def basic_tts(
        ref_wav_input,
        ref_txt_input,
        gen_txt_input,
        randomize_seed,
        seed_input,
    ):
        """Click handler for the Synthesize button.

        Draws a fresh seed when ``randomize_seed`` is set, otherwise uses
        ``seed_input``. Delegates to ``infer`` with the current global model
        choice. Returns ``(audio, reference_text, seed_used)``, matching the
        outputs the handler is wired to below.
        """
        chosen_seed = np.random.randint(0, 2**31 - 1) if randomize_seed else seed_input
        synthesized, cleaned_ref_text, seed_used = infer(
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            tts_model_choice,
            chosen_seed,
        )
        return synthesized, cleaned_ref_text, seed_used

    # NOTE(review): gr.update() with no arguments leaves ref_txt_input unchanged,
    # so this listener currently has no visible effect. Confirm whether the
    # intent was to clear the reference text when the audio is removed.
    ref_wav_input.clear(fn=lambda: gr.update(), inputs=None, outputs=ref_txt_input)

    generate_btn.click(
        fn=basic_tts,
        inputs=[ref_wav_input, ref_txt_input, gen_txt_input, randomize_seed, seed_input],
        outputs=[audio_output, ref_txt_input, seed_input],
    )
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🗣️ F5-TTS Online Demo for Dev Test
Upload or record a reference voice, give its transcription text, then enter the text to generate and have fun!
"""
    )

    def switch_tts_model(new_choice):
        """Make ``new_choice`` the model used by subsequent syntheses.

        NOTE(review): this rebinds the module-level ``tts_model_choice``. The
        selection is therefore shared by every session served by this process,
        not scoped to the user who changed the dropdown.
        """
        global tts_model_choice
        tts_model_choice = new_choice

    # Dropdown choices are the registry keys, in registration order.
    choose_tts_model = gr.Dropdown(
        choices=list(tts_model_collections), label="Choose TTS Model", value=tts_model_choice
    )
    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model],
    )
    gr.TabbedInterface(
        [app_basic_tts],
        ["Basic-TTS"],
    )
def _launch_demo():
    """Start the Gradio server for the top-level ``demo`` app."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch_demo()