#!/usr/bin/python3
# -*- coding: utf-8 -*-
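"""
Gradio demo for speech enhancement / noise suppression with DeepFilterNet.

The app has two tabs: an "SE" tab that runs a selected DeepFilterNet model on an
uploaded audio file, and a "shell" tab that runs commands through the project's
Command helper.
"""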
import argparse
from functools import lru_cache
import logging
from pathlib import Path
import platform
from typing import Tuple
import zipfile

from project_settings import log_directory

# Set up rotating file logging before the heavy third-party imports.
import log

log.setup_size_rotating(log_directory=log_directory)

import gradio as gr
from df.enhance import enhance as df_enhance
from df.enhance import init_df as df_init_df
from df.enhance import load_audio as df_load_audio
from libdf import DF
import numpy as np
import torch
import torch.nn as nn

from project_settings import project_path, environment, temp_directory
from toolbox.os.command import Command

main_logger = logging.getLogger("main")
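

# Callback for the "shell" tab: run a command via the project's Command helper.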
def shell(cmd: str):
    return Command.popen(cmd)
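

# Command-line options: the example-audio directory and the Gradio server port
# (the port default is read from the project environment).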
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--noise_suppression_examples_wav_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )
    args = parser.parse_args()
    return args
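

# Load a DeepFilterNet model by name. The zipped checkpoint under trained_models/df
# is extracted into the temp directory on first use; lru_cache keeps up to 10
# initialised models in memory.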
@lru_cache(maxsize=10)
def load_df_model(model_name: str) -> Tuple[nn.Module, DF]:
    model_base_dir = temp_directory / "df"
    model_dir = model_base_dir / model_name

    main_logger.info("load model: {}".format(model_name))
    model_file = project_path / "trained_models/df/{}.zip".format(model_name)

    # Extract the zipped checkpoint on first use.
    if not model_dir.exists():
        with zipfile.ZipFile(model_file) as zf:
            zf.extractall(model_base_dir.as_posix())

    model, df_state, _ = df_init_df(
        model_base_dir=model_dir.as_posix(),
    )
    return model, df_state
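

# Enhance one file with DeepFilterNet and return (sample_rate, int16 waveform),
# the format expected by the gr.Audio(type="numpy") output.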
def do_df_noise_suppression(filename: str, model_name: str) -> Tuple[int, np.ndarray]:
    model, df_state = load_df_model(model_name)

    main_logger.info("load audio: {}".format(filename))
    # Load and resample the audio to the model's sample rate.
    audio, _ = df_load_audio(
        file=filename,
        sr=df_state.sr(),
    )

    main_logger.info("run enhance.")
    enhanced: torch.Tensor = df_enhance(
        model=model,
        df_state=df_state,
        audio=audio,
    )

    # Convert the float waveform to 16-bit PCM for the audio output.
    enhanced = enhanced[0].numpy()
    enhanced = enhanced * (1 << 15)
    enhanced = np.array(enhanced, dtype=np.int16)

    return df_state.sr(), enhanced
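

# Dispatch on the selected model name; only the DeepFilterNet family is supported.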
def do_noise_suppression(filename: str, model_name: str):
    if model_name in (
        "DeepFilterNet",
        "DeepFilterNet2",
        "DeepFilterNet3",
    ):
        return do_df_noise_suppression(filename, model_name)
    else:
        raise AssertionError("invalid model name: {}".format(model_name))
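

# Build the Gradio UI and launch the server.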
def main():
    args = get_args()

    # Collect example audio from per-model subdirectories; the "df" directory
    # maps to DeepFilterNet3.
    noise_suppression_examples_wav_dir = Path(args.noise_suppression_examples_wav_dir)
    noise_suppression_examples = list()
    for filename in noise_suppression_examples_wav_dir.glob("*/*.wav"):
        name = filename.parts[-2]
        model_name = "DeepFilterNet3" if name == "df" else name
        noise_suppression_examples.append([
            filename.as_posix(),
            model_name,
        ])

    title = "## Speech Enhancement and Noise Suppression."

    # blocks
    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("SE"):
                se_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="file",
                )
                se_model_name = gr.Dropdown(
                    choices=[
                        "DeepFilterNet", "DeepFilterNet2", "DeepFilterNet3",
                    ],
                    value="DeepFilterNet3",
                    label="model_name",
                )
                se_button = gr.Button("run")
                se_enhanced = gr.Audio(type="numpy", label="enhanced")
                gr.Examples(
                    examples=noise_suppression_examples,
                    inputs=[
                        se_file, se_model_name,
                    ],
                    outputs=[
                        se_enhanced,
                    ],
                    fn=do_noise_suppression,
                )
                se_button.click(
                    do_noise_suppression,
                    inputs=[
                        se_file, se_model_name,
                    ],
                    outputs=[
                        se_enhanced,
                    ],
                )
            with gr.TabItem("shell"):
                # Note: this tab executes arbitrary shell commands on the host.
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[
                        shell_text,
                    ],
                    outputs=[
                        shell_output,
                    ],
                )
    blocks.queue().launch(
        share=False,
        # Bind to localhost on Windows, all interfaces otherwise.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
    )
    return


if __name__ == '__main__':
    main()