Spaces:
Sleeping
Sleeping
File size: 5,170 Bytes
054b62d f50d71c 8ae9295 054b62d f50d71c 054b62d f50d71c 054b62d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import logging
from pathlib import Path
import platform
import zipfile
from typing import Tuple
from project_settings import log_directory
import log
log.setup_size_rotating(log_directory=log_directory)
import gradio as gr
from df.enhance import enhance as df_enhance
from df.enhance import init_df as df_init_df
from df.enhance import load_audio as df_load_audio
from libdf import DF
import numpy as np
import torch
import torch.nn as nn
from project_settings import project_path, environment, temp_directory
from toolbox.os.command import Command
# Module-wide logger named "main"; used by the model-loading and inference
# helpers to report progress.
main_logger = logging.getLogger("main")
def shell(cmd: str):
    """Pass ``cmd`` to ``Command.popen`` and return whatever it returns.

    This is the handler behind the "shell" tab of the UI.

    WARNING: ``cmd`` arrives straight from the browser and is not validated.
    Presumably ``Command.popen`` executes it on the host; confirm this in
    ``toolbox.os.command`` before exposing the app to untrusted users.
    """
    output = Command.popen(cmd)
    return output
def get_args():
    """Parse the demo's command-line options.

    Options:
        --noise_suppression_examples_wav_dir: directory holding the example
            audio files (default ``<project_path>/data/examples``).
        --server_port: port for the Gradio server. The default is read from
            ``environment`` under the key ``server_port``, presumably falling
            back to 7860 when the key is absent.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    examples_dir_default = (project_path / "data/examples").as_posix()
    port_default = environment.get("server_port", 7860)

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--noise_suppression_examples_wav_dir",
        default=examples_dir_default,
        type=str,
    )
    arg_parser.add_argument(
        "--server_port",
        default=port_default,
        type=int,
    )
    return arg_parser.parse_args()
@lru_cache(maxsize=10)
def load_df_model(model_name: str) -> Tuple[nn.Module, DF]:
    """Load a DeepFilterNet model by name, memoized with ``lru_cache``.

    On a cache miss:

    1. ``trained_models/df/<model_name>.zip`` is unpacked into
       ``<temp_directory>/df``, unless ``<temp_directory>/df/<model_name>``
       already exists.
    2. The model is initialised from that directory with ``df.enhance.init_df``.

    Returns:
        ``(model, df_state)``, the first two values returned by ``init_df``.
    """
    archive_path = project_path / "trained_models/df/{}.zip".format(model_name)
    extract_root = temp_directory / "df"
    unpacked_dir = extract_root / model_name

    main_logger.info("load model: {}".format(model_name))

    # NOTE(review): only existence is checked, so a directory left behind by an
    # interrupted extraction will be reused as-is on the next run.
    if not unpacked_dir.exists():
        # Presumably the archive has a top-level "<model_name>/" folder, so
        # unpacking into extract_root yields unpacked_dir. TODO: confirm.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(extract_root.as_posix())

    model, df_state, _ = df_init_df(model_base_dir=unpacked_dir.as_posix())
    return model, df_state
def do_df_noise_suppression(filename: str, model_name: str) -> Tuple[int, np.ndarray]:
    """Denoise an audio file with a DeepFilterNet model.

    Args:
        filename: path of the audio file to enhance.
        model_name: model to use; passed to ``load_df_model``.

    Returns:
        ``(sample_rate, samples)``:
        - ``sample_rate`` is the model's rate (``df_state.sr()``).
        - ``samples`` is row 0 of the enhanced tensor (presumably the first
          channel), as 16-bit PCM in an ``np.int16`` array.
    """
    model, df_state = load_df_model(model_name)

    main_logger.info("load audio: {}".format(filename))
    # Load the file at the model's sample rate.
    audio, _ = df_load_audio(
        file=filename,
        sr=df_state.sr(),
    )

    main_logger.info("run enhance.")
    enhanced: torch.Tensor = df_enhance(
        model=model,
        df_state=df_state,
        audio=audio,
    )

    samples = enhanced[0].numpy()
    # Scale float samples to the int16 range and saturate before casting.
    # - A full-scale 1.0 maps to 32768, which does not fit in int16.
    # - Enhanced audio may also overshoot +/-1.0.
    # NumPy's float->int cast of out-of-range values is undefined (it
    # typically wraps), which turns peaks into loud clicks. Clipping keeps
    # in-range samples identical and saturates the rest.
    samples = np.clip(samples * (1 << 15), -(1 << 15), (1 << 15) - 1)
    samples = samples.astype(np.int16)
    return df_state.sr(), samples
def do_noise_suppression(filename: str, model_name: str):
    """Route a noise-suppression request to the backend for ``model_name``.

    Only the DeepFilterNet family is supported. Every supported name is
    handled by ``do_df_noise_suppression``, whose result is returned unchanged.

    Raises:
        AssertionError: if ``model_name`` is not one of the supported models.
    """
    deepfilternet_models = ("DeepFilterNet", "DeepFilterNet2", "DeepFilterNet3")
    if model_name not in deepfilternet_models:
        raise AssertionError("invalid model name: {}".format(model_name))
    return do_df_noise_suppression(filename, model_name)
def main():
    """Build the Gradio UI and launch the server.

    Tabs:
        "SE": upload an audio file, pick a DeepFilterNet model, and get the
            enhanced audio back.
        "shell": send a command to ``shell`` (see the security warning below).
    """
    args = get_args()

    # Examples are laid out as <wav_dir>/<model dir>/<file>.wav.
    # - The parent directory name is the model name, except "df", which is an
    #   alias for DeepFilterNet3.
    # - glob() yields files in arbitrary filesystem order, so the list is
    #   sorted to keep the example gallery stable.
    noise_suppression_examples_wav_dir = Path(args.noise_suppression_examples_wav_dir)
    noise_suppression_examples = []
    for filename in sorted(noise_suppression_examples_wav_dir.glob("*/*.wav")):
        name = filename.parts[-2]
        model_name = "DeepFilterNet3" if name == "df" else name
        noise_suppression_examples.append([filename.as_posix(), model_name])

    title = "## Speech Enhancement and Noise Suppression."

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("SE"):
                se_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="file",
                )
                se_model_name = gr.Dropdown(
                    choices=["DeepFilterNet", "DeepFilterNet2", "DeepFilterNet3"],
                    value="DeepFilterNet3",
                    label="model_name",
                )
                se_button = gr.Button("run")
                se_enhanced = gr.Audio(type="numpy", label="enhanced")
                gr.Examples(
                    examples=noise_suppression_examples,
                    inputs=[se_file, se_model_name],
                    outputs=[se_enhanced],
                    # Same entry point as the "run" button, so model names
                    # derived from example directories are validated the same way.
                    fn=do_noise_suppression,
                )
                se_button.click(
                    do_noise_suppression,
                    inputs=[se_file, se_model_name],
                    outputs=[se_enhanced],
                )
            # SECURITY WARNING: this tab forwards arbitrary text from the
            # browser to `shell` (Command.popen), which presumably executes it
            # on the host. On non-Windows hosts the server binds 0.0.0.0, so
            # anyone who can reach the port could run commands. Remove or
            # guard this tab before exposing the app.
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    # Listen only on localhost on Windows; bind all interfaces elsewhere.
    # Public share links are never requested.
    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )
# Script entry point: launch the demo only when this file is run directly,
# not when it is imported.
if __name__ == '__main__':
    main()
|