Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from functools import lru_cache | |
import logging | |
from pathlib import Path | |
import platform | |
import zipfile | |
from typing import Tuple | |
from project_settings import log_directory | |
import log | |
log.setup_size_rotating(log_directory=log_directory) | |
import gradio as gr | |
from df.enhance import enhance as df_enhance | |
from df.enhance import init_df as df_init_df | |
from df.enhance import load_audio as df_load_audio | |
from libdf import DF | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from project_settings import project_path, environment, temp_directory | |
from toolbox.os.command import Command | |
# Module-wide named logger used by every function below to report progress.
main_logger = logging.getLogger(name="main")
def shell(cmd: str):
    """Hand ``cmd`` to the project's ``Command.popen`` helper and return its result.

    SECURITY WARNING: ``cmd`` is forwarded unmodified. In this file it is wired to a
    free-text box in the web UI's "shell" tab, so whatever a visitor types is what
    ``Command.popen`` receives. ``Command.popen`` presumably runs it as a shell
    command (verify in ``toolbox.os.command``). If so, this is remote command
    execution for anyone who can reach the server.

    Args:
        cmd: command text as entered by the user.

    Returns:
        The value returned by ``Command.popen`` (shown in the UI's output textbox).
    """
    command_result = Command.popen(cmd)
    return command_result
def get_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Options:
        --noise_suppression_examples_wav_dir: directory scanned for example wav
            files; defaults to ``<project_path>/data/examples``.
        --server_port: port handed to Gradio's ``launch``; defaults to
            ``environment.get("server_port", 7860)``. ``type=int`` converts values
            given on the command line (and string defaults). A non-string default
            is used as-is.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    arg_parser = argparse.ArgumentParser()

    examples_dir_default = (project_path / "data/examples").as_posix()
    arg_parser.add_argument(
        "--noise_suppression_examples_wav_dir",
        type=str,
        default=examples_dir_default,
    )

    port_default = environment.get("server_port", 7860)
    arg_parser.add_argument(
        "--server_port",
        type=int,
        default=port_default,
    )

    parsed = arg_parser.parse_args()
    return parsed
def load_df_model(model_name: str) -> Tuple[nn.Module, DF]:
    """Load a DeepFilterNet model and its DF state by name.

    The archive ``<project_path>/trained_models/df/<model_name>.zip`` is extracted
    into ``<temp_directory>/df`` while ``<temp_directory>/df/<model_name>`` does not
    exist yet. That extracted directory is then handed to ``df.enhance.init_df``.

    Nothing is cached here: ``init_df`` is called on every invocation.

    Args:
        model_name: archive stem, e.g. ``DeepFilterNet3``.

    Returns:
        ``(model, df_state)``. The third value returned by ``init_df`` is discarded.
    """
    extract_root = temp_directory / "df"
    extracted_dir = extract_root / model_name
    main_logger.info("load model: {}".format(model_name))

    archive_path = project_path / "trained_models/df/{}.zip".format(model_name)
    # NOTE(review): assumes the archive contains a top-level ``<model_name>/`` folder.
    # If it does not, ``extracted_dir`` never appears. The archive is then
    # re-extracted on every call and ``df_init_df`` receives a missing directory.
    already_extracted = extracted_dir.exists()
    if not already_extracted:
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(extract_root.as_posix())

    model, df_state, _ignored = df_init_df(model_base_dir=extracted_dir.as_posix())
    return model, df_state
def do_df_noise_suppression(filename: str, model_name: str) -> Tuple[int, np.ndarray]:
    """Denoise an audio file with a DeepFilterNet model.

    Args:
        filename: path of the audio file to enhance.
        model_name: name of a model archive under ``trained_models/df``
            (resolved by ``load_df_model``).

    Returns:
        ``(sample_rate, samples)``.
        - ``sample_rate``: the model's rate (``df_state.sr()``).
        - ``samples``: row 0 of the enhanced audio converted to 16-bit PCM, saturated
          at the int16 limits. This tuple shape is what the
          ``gr.Audio(type="numpy")`` output in ``main`` receives.
    """
    model, df_state = load_df_model(model_name)
    sample_rate = df_state.sr()

    main_logger.info("load audio: {}".format(filename))
    # Load the file at the model's sample rate.
    audio, _ = df_load_audio(
        file=filename,
        sr=sample_rate,
    )

    main_logger.info("run enhance.")
    enhanced: torch.Tensor = df_enhance(
        model=model,
        df_state=df_state,
        audio=audio,
    )

    # Row 0 only. Presumably the first channel of a (channels, samples) tensor with
    # float samples nominally in [-1, 1]; TODO confirm against df.enhance.
    first_channel = enhanced[0].numpy()

    # Scale to the int16 range and clip BEFORE casting. Without the clip, a sample of
    # +1.0 becomes 32768, which does not fit in int16. Out-of-range values would then
    # wrap around (audible clicks) instead of saturating.
    int16_info = np.iinfo(np.int16)
    scaled = np.clip(first_channel * (1 << 15), int16_info.min, int16_info.max)
    pcm = scaled.astype(np.int16)
    return sample_rate, pcm
def do_noise_suppression(filename: str, model_name: str):
    """Validate ``model_name`` and forward the request to the DeepFilterNet backend.

    Args:
        filename: path of the uploaded audio file.
        model_name: one of ``DeepFilterNet``, ``DeepFilterNet2``, ``DeepFilterNet3``.

    Returns:
        Whatever ``do_df_noise_suppression`` returns for these arguments.

    Raises:
        AssertionError: if ``model_name`` is not one of the supported names. This also
            stops arbitrary strings from being turned into a model-archive path.
    """
    supported_models = ("DeepFilterNet", "DeepFilterNet2", "DeepFilterNet3")
    if model_name not in supported_models:
        raise AssertionError("invalid model name: {}".format(model_name))
    return do_df_noise_suppression(filename, model_name)
def _collect_noise_suppression_examples(examples_dir: Path) -> list:
    """Build ``gr.Examples`` rows ``[wav_path, model_name]`` from ``examples_dir/*/*.wav``.

    - Model choice: the wav's parent folder names the model, and the folder name
      ``df`` maps to ``DeepFilterNet3``.
    - Ordering: paths are sorted so the example list is stable. Directory scanning
      yields entries in arbitrary order.
    """
    examples = []
    for wav_path in sorted(examples_dir.glob("*/*.wav")):
        folder_name = wav_path.parts[-2]
        model_name = "DeepFilterNet3" if folder_name == "df" else folder_name
        examples.append([wav_path.as_posix(), model_name])
    return examples


def main():
    """Parse CLI options, build the Gradio UI (enhancement + shell tabs) and launch it.

    Server settings:
    - Bind address: 127.0.0.1 on Windows, 0.0.0.0 elsewhere.
    - Share link: never requested.
    - Port: ``--server_port``.
    """
    args = get_args()
    noise_suppression_examples = _collect_noise_suppression_examples(
        Path(args.noise_suppression_examples_wav_dir)
    )

    title = "## Speech Enhancement and Noise Suppression."

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("SE"):
                se_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="file",
                )
                se_model_name = gr.Dropdown(
                    choices=["DeepFilterNet", "DeepFilterNet2", "DeepFilterNet3"],
                    value="DeepFilterNet3",
                    label="model_name",
                )
                se_button = gr.Button("run")
                se_enhanced = gr.Audio(type="numpy", label="enhanced")
                gr.Examples(
                    examples=noise_suppression_examples,
                    inputs=[se_file, se_model_name],
                    outputs=[se_enhanced],
                    # Same validating entry point as the button below, so model
                    # names taken from example folder names are checked before
                    # they are used to build a model-archive path.
                    fn=do_noise_suppression,
                )
                se_button.click(
                    do_noise_suppression,
                    inputs=[se_file, se_model_name],
                    outputs=[se_enhanced],
                )
            # SECURITY WARNING: this tab forwards arbitrary user text to shell().
            # On non-Windows hosts the server binds 0.0.0.0 (see launch() below), so
            # anyone who can reach the port can presumably run commands on the host.
            # Remove this tab or put it behind authentication before exposing the server.
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        # The previous expression `False if <Windows> else False` was False on
        # every platform; state that directly.
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )
if __name__ == '__main__':
    # Script entry point: parse CLI options, build the Gradio UI and start serving it.
    main()