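"""
Gradio demo for voice activity detection (VAD).

The app loads VAD engines (FSMN ONNX, Silero) from zip archives in the
trained-models directory, downloading them from the Hugging Face Hub when
missing. For an uploaded or recorded clip it runs inference, post-processes
the frame-wise speech probabilities into segments, and renders the VAD flags,
probabilities, and LSNR over the waveform.
"""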
import argparse
from functools import lru_cache
import json
import logging
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, Optional, Tuple
import uuid
import zipfile

import gradio as gr
import librosa
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile

import log
from project_settings import environment, project_path, log_directory, time_zone_info
from toolbox.os.command import Command
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess

log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)

logger = logging.getLogger("main")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/cc_vad",
        type=str,
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )

    args = parser.parse_args()
    return args
|
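# Typical local launch (assumed script name; all flags fall back to the
# defaults above when omitted):
#
#   python main.py --server_port 7860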
|
|
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
    if signal.dtype != np.int16:
        raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")
    temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
    temp_audio_dir.mkdir(parents=True, exist_ok=True)
    filename = (temp_audio_dir / f"{uuid.uuid4()}.wav").as_posix()
    wavfile.write(filename, sample_rate, signal)
    return filename
|
|
def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
    # round-trip through a temp wav file and let librosa do the resampling.
    filename = save_input_audio(sample_rate, signal)

    signal, _ = librosa.load(filename, sr=target_sample_rate)
    # librosa returns float32 in [-1, 1); scale back to int16 PCM.
    signal = np.array(signal * (1 << 15), dtype=np.int16)
    return signal
|
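# Illustrative use: resample a 44.1 kHz int16 recording to the 8 kHz rate the
# vad models expect:
#
#   signal_8k = convert_sample_rate(signal, 44100, 8000)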
|
|
def shell(cmd: str):
    return Command.popen(cmd)
|
|
def get_infer_cls_by_model_name(model_name: str):
    if "fsmn" in model_name:
        infer_cls = InferenceFSMNVadOnnx
    elif "silero" in model_name:
        infer_cls = InferenceSileroVad
    else:
        raise AssertionError(f"unsupported model name: {model_name}")
    return infer_cls
|
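# Engine names are matched by substring, e.g. (hypothetical model name):
#
#   get_infer_cls_by_model_name("fsmn-vad-demo")  # -> InferenceFSMNVadOnnx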
|
|
# engine registry; populated in main(), mapping engine name to its inference
# class and constructor kwargs.
vad_engines: Optional[Dict[str, dict]] = None
|
|
@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
    # maxsize=1: only the most recently used engine stays in memory, so
    # switching engines triggers a reload.
    infer_engine = infer_cls(**kwargs)
    return infer_engine
|
|
def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
    duration = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(duration, signal, color='b')
    plt.plot(duration, speech_probs, color='gray')
    plt.title(title)

    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    plt.savefig(temp_file.name, bbox_inches="tight")
    plt.close()
    return temp_file.name
|
|
def when_click_vad_button(audio_file_t=None, audio_microphone_t=None,
                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
                          ring_max_length: int = 10,
                          min_silence_length: int = 2,
                          max_speech_length: int = 10000, min_speech_length: int = 10,
                          engine: str = None,
                          ):
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("audio file and microphone are both null.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone are provided; the audio file takes priority.")
    audio_t: Tuple = audio_file_t or audio_microphone_t

    sample_rate, signal = audio_t
    # the vad models operate on 8 kHz audio; resample when necessary.
    if sample_rate != 8000:
        signal = convert_sample_rate(signal, sample_rate, 8000)
        sample_rate = 8000

    audio_duration = signal.shape[-1] / sample_rate
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = vad_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid vad engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)

        begin = time.time()
        vad_info = infer_engine.infer(audio)
        time_cost = time.time() - begin

        probs = vad_info["probs"]
        lsnr = vad_info["lsnr"]

        # scale lsnr down for display alongside the probabilities.
        lsnr = lsnr / 30

        frame_step = infer_engine.config.hop_size

        # post-process frame probabilities into speech segments; the length
        # parameters are in frames (10 ms each).
        vad_post_process = PostProcess(
            start_ring_rate=start_ring_rate,
            end_ring_rate=end_ring_rate,
            ring_max_length=ring_max_length,
            min_silence_length=min_silence_length,
            max_speech_length=max_speech_length,
            min_speech_length=min_speech_length
        )
        vad_segments = vad_post_process.get_vad_segments(probs)
        vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)

        vad_ = process_speech_probs(audio, vad_flags, frame_step)
        vad_image = generate_image(audio, vad_)

        probs_ = process_speech_probs(audio, probs, frame_step)
        probs_image = generate_image(audio, probs_)

        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
        lsnr_image = generate_image(audio, lsnr_)

        # convert segment boundaries from frame indices to seconds.
        vad_segments = [
            [
                v[0] * frame_step / sample_rate,
                v[1] * frame_step / sample_rate
            ] for v in vad_segments
        ]

        rtf = time_cost / audio_duration
        info = {
            "vad_segments": vad_segments,
            "time_cost": round(time_cost, 4),
            "duration": round(audio_duration, 4),
            "rtf": round(rtf, 4)
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)

    except Exception as e:
        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")

    return vad_image, probs_image, lsnr_image, message
|
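# Note: the reported rtf (real-time factor) is time_cost / duration; values
# below 1.0 mean inference runs faster than real time.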
|
|
def main():
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download the model archives from the Hugging Face Hub on first run.
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # build the engine registry from the zip archives; the listed archives
    # (example data, sound data, and models without a matching inference
    # class) are excluded from the dropdown.
    global vad_engines
    vad_engines = {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            }
        }
        for filename in trained_model_dir.glob("*.zip")
        if filename.name not in (
            "cnn-vad-by-webrtcvad-nx-dns3.zip",
            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
            "examples.zip",
            "sound-2-ch32.zip",
            "sound-3-ch32.zip",
            "sound-4-ch32.zip",
            "sound-8-ch32.zip",
        )
    }

    vad_engine_choices = list(vad_engines.keys())

    # unpack the bundled example audio on first run.
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            None,
            vad_engine_choices[0],
        ])
|
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("vad"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                vad_audio_file = gr.Audio(label="audio")
                            with gr.TabItem("microphone"):
                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")

                        with gr.Row():
                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
                        with gr.Row():
                            vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
                            vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
                        with gr.Row():
                            vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
                            vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")
                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
                        vad_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        vad_vad_image = gr.Image(label="vad")
                        vad_prob_image = gr.Image(label="prob")
                        vad_lsnr_image = gr.Image(label="lsnr")
                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")

                vad_button.click(
                    when_click_vad_button,
                    inputs=[
                        vad_audio_file, vad_audio_microphone,
                        vad_start_ring_rate, vad_end_ring_rate,
                        vad_ring_max_length,
                        vad_min_silence_length,
                        vad_max_speech_length, vad_min_speech_length,
                        vad_engine,
                    ],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                )
                gr.Examples(
                    examples=examples,
                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                    fn=when_click_vad_button,
                )
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    # bind to localhost on Windows (local dev) and to all interfaces elsewhere.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
        show_error=True
    )
    return
|
|
if __name__ == "__main__":
    main()
|
|