import gradio as gr
import numpy as np
import spaces
import torch
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
)
from f5_tts.model import DiT
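
# Shared vocoder: converts the model's mel-spectrogram output back into a waveform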
vocoder = load_vocoder()

# Commonly used DiT model configurations
v1_base_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
v1_small_cfg = dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)
alg_vocab_path = str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/vocab.txt"))

tts_model_choice = "v1-base_zh-en"  # default
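
# Preload every checkpoint once at startup so switching models in the dropdown does not reload weights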
tts_model_collections = {
    "v1-base_zh-en": load_model(
        DiT,
        v1_base_cfg,
        str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")),
        vocab_file=str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt")),
    ),
    "v1-small_alg-64h_300k": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_300k_no-ema": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000_no-ema.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_200k": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_200k_no-ema": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000_no-ema.safetensors")),
        vocab_file=alg_vocab_path,
    ),
}
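

# Assumption based on the otherwise-unused `spaces` import and the Space running on ZeroGPU:
# request a GPU for the duration of each inference call when deployed on Hugging Face Spaces.
@spaces.GPU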
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    seed,
    show_info=gr.Info,
):
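    # Bail out early when any of the three required inputs is missing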
    if not ref_audio_orig or not ref_text.strip() or not gen_text.strip():
        gr.Warning("Please ensure [Reference Audio] [Reference Text] [Text to Generate] are all provided.")
        return gr.update(), ref_text, seed

    if seed < 0 or seed > 2**31 - 1:
        gr.Warning("Please set a seed in range 0 ~ 2**31 - 1.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    used_seed = seed
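
    # Library helper: prepares the reference audio (e.g. clipping/normalization) and tidies the reference text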
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        tts_model_collections[model],  # use the checkpoint selected in the UI (passed in as `model`)
        vocoder,
        show_info=show_info,
        progress=gr.Progress(),
    )
    return (final_sample_rate, final_wave), ref_text, used_seed
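

# Basic-TTS tab: reference clip + reference text + target text in, synthesized audio out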
with gr.Blocks() as app_basic_tts:
    with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_txt_input = gr.Textbox(label="Reference Text")
            gen_txt_input = gr.Textbox(label="Text to Generate")
            generate_btn = gr.Button("Synthesize", variant="primary")
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
                    value=True,
                    scale=3,
                )
                seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)

    audio_output = gr.Audio(label="Synthesized Audio")
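
    # Click handler: optionally draw a fresh random seed, then synthesize with the currently selected model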
    def basic_tts(
        ref_wav_input,
        ref_txt_input,
        gen_txt_input,
        randomize_seed,
        seed_input,
    ):
        if randomize_seed:
            seed_input = np.random.randint(0, 2**31 - 1)
        audio_out, ref_text_out, used_seed = infer(
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            tts_model_choice,
            seed_input,
        )
        return audio_out, ref_text_out, used_seed
    # Clear the reference text together with the reference audio so a stale transcription is not reused
    ref_wav_input.clear(
        lambda: gr.update(value=""),
        None,
        ref_txt_input,
    )

    generate_btn.click(
        basic_tts,
        inputs=[
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            randomize_seed,
            seed_input,
        ],
        outputs=[audio_output, ref_txt_input, seed_input],
    )
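

# Top-level app: model selector above a tabbed interface (currently just the Basic-TTS tab)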
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🗣️ F5-TTS Online Demo for Dev Test

Upload or record a reference voice, provide its transcription, then enter the text you want to generate and have fun!
"""
    )
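
    # The dropdown writes the chosen key into module-level tts_model_choice, which basic_tts passes to infer on each click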
    def switch_tts_model(new_choice):
        global tts_model_choice
        tts_model_choice = new_choice

    choose_tts_model = gr.Dropdown(
        choices=list(tts_model_collections), label="Choose TTS Model", value=tts_model_choice
    )
    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model],
    )

    gr.TabbedInterface(
        [app_basic_tts],
        ["Basic-TTS"],
    )


if __name__ == "__main__":
    demo.launch()