#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Gradio web demo for speech enhancement / noise suppression.

Wraps the DeepFilterNet family of models (DeepFilterNet / 2 / 3): the user
uploads a wav file, picks a model, and gets back the enhanced waveform.
A second tab exposes a simple shell-command runner.
"""
import argparse
from functools import lru_cache
import logging
from pathlib import Path
import platform
import zipfile
from typing import Tuple

from project_settings import log_directory
import log

# Logging must be configured before the heavier imports below so that any
# import-time warnings they emit are captured by the rotating file handlers.
log.setup_size_rotating(log_directory=log_directory)

import gradio as gr
from df.enhance import enhance as df_enhance
from df.enhance import init_df as df_init_df
from df.enhance import load_audio as df_load_audio
from libdf import DF
import numpy as np
import torch
import torch.nn as nn

from project_settings import project_path, environment, temp_directory
from toolbox.os.command import Command

main_logger = logging.getLogger("main")


def shell(cmd: str):
    """Run *cmd* in a subshell and return its captured output (for the 'shell' tab)."""
    return Command.popen(cmd)


def get_args():
    """Parse command-line arguments.

    --noise_suppression_examples_wav_dir: directory of example wav files,
        laid out as <dir>/<model_name>/<file>.wav.
    --server_port: HTTP port for the Gradio server (default from environment).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--noise_suppression_examples_wav_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    args = parser.parse_args()
    return args


@lru_cache(maxsize=10)
def load_df_model(model_name: str) -> Tuple[nn.Module, DF]:
    """Load a DeepFilterNet model by name, extracting its zip on first use.

    The model archive at trained_models/df/<model_name>.zip is unpacked into
    the temp directory once; subsequent calls for the same name are served
    from the lru_cache without touching disk again.

    :param model_name: e.g. "DeepFilterNet3".
    :return: (model, df_state) as produced by ``df.enhance.init_df``.
    """
    model_base_dir = temp_directory / "df"
    model_dir = model_base_dir / model_name

    main_logger.info("load model: {}".format(model_name))
    model_file = project_path / "trained_models/df/{}.zip".format(model_name)
    if not model_dir.exists():
        with zipfile.ZipFile(model_file) as zf:
            zf.extractall(
                model_base_dir.as_posix()
            )

    model, df_state, _ = df_init_df(
        model_base_dir=model_dir.as_posix(),
    )
    return model, df_state


def do_df_noise_suppression(filename: str, model_name: str) -> Tuple[int, np.ndarray]:
    """Enhance one audio file with a DeepFilterNet model.

    :param filename: path to the input audio file.
    :param model_name: name of the DeepFilterNet model to use.
    :return: (sample_rate, int16 waveform) suitable for a gr.Audio output.
    """
    model, df_state = load_df_model(model_name)

    main_logger.info("load audio: {}".format(filename))
    audio, _ = df_load_audio(
        file=filename,
        sr=df_state.sr()
    )
    main_logger.info("run enhance.")
    enhanced: torch.Tensor = df_enhance(
        model=model,
        df_state=df_state,
        audio=audio,
    )
    enhanced = enhanced[0].numpy()
    # Scale the float waveform (nominally in [-1, 1]) to int16 range.
    # Clip before the cast: samples at or beyond +/-1.0 would otherwise
    # wrap around modulo 2**16 and produce loud clicks.
    enhanced = np.clip(enhanced * (1 << 15), -32768, 32767)
    enhanced = np.array(enhanced, dtype=np.int16)
    return df_state.sr(), enhanced


def do_noise_suppression(filename: str, model_name: str):
    """Dispatch to the correct backend by *model_name*.

    :raises AssertionError: if *model_name* is not a known model.
    """
    if model_name in (
        "DeepFilterNet",
        "DeepFilterNet2",
        "DeepFilterNet3",
    ):
        return do_df_noise_suppression(filename, model_name)
    else:
        raise AssertionError("invalid model name: {}".format(model_name))


def main():
    """Build the Gradio UI and launch the server."""
    args = get_args()

    # Collect examples from <dir>/<model_name>/<file>.wav; the "df"
    # subdirectory is an alias for the DeepFilterNet3 model.
    noise_suppression_examples_wav_dir = Path(args.noise_suppression_examples_wav_dir)
    noise_suppression_examples = list()
    for filename in noise_suppression_examples_wav_dir.glob("*/*.wav"):
        name = filename.parts[-2]
        model_name = "DeepFilterNet3" if name == "df" else name
        noise_suppression_examples.append([
            filename.as_posix(),
            model_name,
        ])

    title = "## Speech Enhancement and Noise Suppression."

    # blocks
    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("SE"):
                se_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="file",
                )
                se_model_name = gr.Dropdown(
                    choices=[
                        "DeepFilterNet",
                        "DeepFilterNet2",
                        "DeepFilterNet3"
                    ],
                    value="DeepFilterNet3",
                    label="model_name",
                )
                se_button = gr.Button("run")
                se_enhanced = gr.Audio(type="numpy", label="enhanced")

                # NOTE(review): examples bypass the model-name validation in
                # do_noise_suppression by calling do_df_noise_suppression
                # directly — example dirs with non-DeepFilterNet names would
                # fail inside load_df_model rather than with a clear error.
                gr.Examples(
                    examples=noise_suppression_examples,
                    inputs=[
                        se_file, se_model_name
                    ],
                    outputs=[
                        se_enhanced
                    ],
                    fn=do_df_noise_suppression
                )

                se_button.click(
                    do_noise_suppression,
                    inputs=[
                        se_file, se_model_name
                    ],
                    outputs=[
                        se_enhanced
                    ],
                )
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[
                        shell_text,
                    ],
                    outputs=[
                        shell_output
                    ],
                )

    blocks.queue().launch(
        # original expression was `False if Windows else False` — both
        # branches were False, so this is simply never shared publicly.
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return


if __name__ == '__main__':
    main()