import gradio as gr
import numpy as np
import spaces
import torch
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
)
from f5_tts.model import DiT

vocoder = load_vocoder()

# DiT backbone configurations shared by the checkpoints below
v1_base_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
v1_small_cfg = dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)

alg_vocab_path = str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/vocab.txt"))

tts_model_choice = "v1-base_zh-en"  # default selection in the model dropdown
tts_model_collections = {
    "v1-base_zh-en": load_model(
        DiT,
        v1_base_cfg,
        str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")),
        vocab_file=str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt")),
    ),
    "v1-small_alg-64h_300k": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_300k_no-ema": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_300000_no-ema.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_200k": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000.safetensors")),
        vocab_file=alg_vocab_path,
    ),
    "v1-small_alg-64h_200k_no-ema": load_model(
        DiT,
        v1_small_cfg,
        str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_200000_no-ema.safetensors")),
        vocab_file=alg_vocab_path,
    ),
}


@spaces.GPU
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    seed,
    show_info=gr.Info,
):
    if not ref_audio_orig or not ref_text.strip() or not gen_text.strip():
        gr.Warning("Please ensure [Reference Audio], [Reference Text], and [Text to Generate] are all provided.")
        return gr.update(), ref_text, seed

    # Fall back to a random seed if the given one is out of range
    if seed < 0 or seed > 2**31 - 1:
        gr.Warning("Please set a seed in range 0 ~ 2**31 - 1.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    used_seed = seed

    # Clip and/or transcribe the reference audio as needed
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        tts_model_collections[model],  # use the model passed in, not the global default
        vocoder,
        show_info=show_info,
        progress=gr.Progress(),
    )

    return (final_sample_rate, final_wave), ref_text, used_seed


with gr.Blocks() as app_basic_tts:
    with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_txt_input = gr.Textbox(label="Reference Text")
            gen_txt_input = gr.Textbox(label="Text to Generate")

    generate_btn = gr.Button("Synthesize", variant="primary")

    with gr.Row():
        randomize_seed = gr.Checkbox(
            label="Randomize Seed",
            info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
            value=True,
            scale=3,
        )
        seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)

    audio_output = gr.Audio(label="Synthesized Audio")

    def basic_tts(
        ref_wav_input,
        ref_txt_input,
        gen_txt_input,
        randomize_seed,
        seed_input,
    ):
        if randomize_seed:
            seed_input = np.random.randint(0, 2**31 - 1)
        audio_out, ref_text_out, used_seed = infer(
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            tts_model_choice,
            seed_input,
        )
        return audio_out, ref_text_out, used_seed

    # Reset the reference text when the reference audio is cleared,
    # so a stale transcription is not reused with new audio
    ref_wav_input.clear(
        lambda: gr.update(value=""),
        None,
        ref_txt_input,
    )

    generate_btn.click(
        basic_tts,
        inputs=[
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            randomize_seed,
            seed_input,
        ],
        outputs=[audio_output, ref_txt_input, seed_input],
    )


with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🗣️ F5-TTS Online Demo for Dev Test

Upload or record a reference voice, provide its transcription, then enter the text to generate and have fun!
"""
    )

    def switch_tts_model(new_choice):
        global tts_model_choice
        tts_model_choice = new_choice

    choose_tts_model = gr.Dropdown(
        choices=list(tts_model_collections),
        label="Choose TTS Model",
        value=tts_model_choice,
    )
    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model],
    )

    gr.TabbedInterface(
        [app_basic_tts],
        ["Basic-TTS"],
    )


if __name__ == "__main__":
    demo.launch()
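# Usage note (a hedged sketch, not part of the original script): `demo.launch()`
# serves the app locally by default. For remote dev testing, Gradio's standard
# options can be chained instead, e.g.:
#
#     demo.queue().launch(share=True)  # enable request queuing and a temporary public link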