# cc_vad / main.py
# Author: HoneyTian
# Commit: 5dd7349 ("update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import json
import logging
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, Tuple
import uuid
import zipfile
import gradio as gr
import librosa
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import log
from project_settings import environment, project_path, log_directory, time_zone_info
from toolbox.os.command import Command
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess
# Initialise project logging at import time, before the module logger below is
# created. NOTE(review): rotation and timezone handling live in the project
# `log` module and are not visible from this file.
log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
# Logger used by this app module.
logger = logging.getLogger("main")
def get_args():
    """Parse the command-line options of the VAD demo.

    Path defaults are derived from ``project_path``; the Hugging Face token and
    server port fall back to the project ``environment`` settings.

    :return: ``argparse.Namespace`` with ``examples_dir``, ``models_repo_id``,
        ``trained_model_dir``, ``hf_token`` and ``server_port``.
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, default, type) -- registered in this order.
    option_specs = [
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--models_repo_id", "qgyd2021/cc_vad", str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--hf_token", environment.get("hf_token"), str),
        ("--server_port", environment.get("server_port", 7860), int),
    ]
    for flag, default_value, value_type in option_specs:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    return arg_parser.parse_args()
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
    """Write int16 audio to a uniquely named wav file in the temp directory.

    :param sample_rate: sample rate written into the wav header.
    :param signal: samples to write; must have dtype ``np.int16``.
    :return: POSIX-style path of the new file under ``<tempdir>/input_audio``.
    :raises AssertionError: if ``signal`` is not int16.
    """
    if signal.dtype == np.int16:
        out_dir = Path(tempfile.gettempdir()) / "input_audio"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = (out_dir / f"{uuid.uuid4()}.wav").as_posix()
        wavfile.write(out_path, sample_rate, signal)
        return out_path
    raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")
def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
    """Resample int16 audio to ``target_sample_rate`` and return it as int16.

    The signal is written to a temporary wav file (``save_input_audio``) and
    re-read with ``librosa.load(..., sr=target_sample_rate)``. With librosa's
    default ``mono=True``, that also mixes multi-channel input down to mono.

    :param signal: int16 samples; ``save_input_audio`` raises AssertionError
        for any other dtype.
    :param sample_rate: sample rate of ``signal`` in Hz.
    :param target_sample_rate: desired output sample rate in Hz.
    :return: int16 array at ``target_sample_rate``.
    """
    filename = save_input_audio(sample_rate, signal)
    try:
        resampled, _ = librosa.load(filename, sr=target_sample_rate)
    finally:
        # The wav file is only an intermediate; don't leave one behind per call.
        Path(filename).unlink(missing_ok=True)
    # librosa returns floats nominally in [-1, 1], but resampling can overshoot.
    # Casting an out-of-range float to int16 is undefined and in practice wraps
    # around (audible clicks), so clip to the int16 range before the cast.
    scaled = np.clip(resampled * (1 << 15), -(1 << 15), (1 << 15) - 1)
    return np.array(scaled, dtype=np.int16)
def shell(cmd: str):
    """Run ``cmd`` through ``Command.popen`` and return its result to the UI.

    SECURITY NOTE(review): ``cmd`` comes straight from the web UI's "shell"
    tab. Anyone who can reach the server can therefore run arbitrary commands,
    since ``Command.popen`` presumably executes them as a subprocess of this
    app. Every call is logged to leave at least an audit trail. Consider
    removing the tab or gating it behind an explicit opt-in before serving on
    a network interface.
    """
    logger.warning("shell tab executing command: %r", cmd)
    return Command.popen(cmd)
def get_infer_cls_by_model_name(model_name: str):
    """Choose the VAD inference class from substrings of a model name.

    :param model_name: model identifier, e.g. the stem of a ``*.zip`` archive
        in the trained-models directory.
    :return: ``InferenceFSMNVadOnnx`` if the name contains "fsmn"; otherwise
        ``InferenceSileroVad`` if it contains "silero". "fsmn" is checked first.
    :raises AssertionError: if neither substring occurs. This stays an
        AssertionError for compatibility with existing callers, but now names
        the offending model.
    """
    if "fsmn" in model_name:
        return InferenceFSMNVadOnnx
    if "silero" in model_name:
        return InferenceSileroVad
    raise AssertionError(
        f"unsupported vad model name: {model_name!r} "
        f"(expected 'fsmn' or 'silero' in the name)"
    )
# Registry of selectable VAD engines, populated by `main()`:
#   engine name -> {"infer_cls": <inference class>, "kwargs": <constructor kwargs>}
# It is None until `main()` runs; `when_click_vad_button` reads it, so the
# callback only works after `main()` has filled it in.
vad_engines: Dict[str, dict] = None
@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
    """Build ``infer_cls(**kwargs)``, memoising only the most recent engine.

    With ``maxsize=1`` exactly one loaded model is kept. Repeated requests for
    the same engine reuse it; switching engines builds the new one and evicts
    the old. Arguments must be hashable, and callers should pass them in a
    consistent form (positional vs keyword) so that equal calls share a cache
    entry.
    """
    return infer_cls(**kwargs)
def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
    """Plot a waveform with a second curve overlaid and save it as a PNG.

    :param signal: audio samples, drawn in blue against time in seconds.
    :param speech_probs: curve drawn in gray on the same x axis, so it must
        have one value per sample of ``signal`` (e.g. expanded VAD flags,
        probabilities or scaled lsnr).
    :param sample_rate: converts sample indices to seconds on the x axis.
    :param title: figure title.
    :return: path of a PNG in the system temp directory; the caller owns it.
    """
    time_axis = np.arange(0, len(signal)) / sample_rate

    # Reserve a unique path, then close the handle right away. A still-open
    # handle leaks one file descriptor per image, and on Windows an open
    # NamedTemporaryFile cannot be reopened by savefig.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_path = temp_file.name

    fig = plt.figure(figsize=(12, 5))
    try:
        plt.plot(time_axis, signal, color='b')
        plt.plot(time_axis, speech_probs, color='gray')
        plt.title(title)
        fig.savefig(image_path, bbox_inches="tight")
    finally:
        # Release the figure even if plotting fails, so the long-running
        # server does not accumulate open pyplot figures.
        plt.close(fig)
    return image_path
def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
                          ring_max_length: int = 10,
                          min_silence_length: int = 2,
                          max_speech_length: int = 10000, min_speech_length: int = 10,
                          engine: str = None,
                          ):
    """Gradio callback: run the selected VAD engine on one clip and plot the results.

    :param audio_file_t: ``(sample_rate, samples)`` from the "file" Audio input, or None.
    :param audio_microphone_t: ``(sample_rate, samples)`` from the "microphone"
        input, or None. When both are given, the file input wins.
    :param start_ring_rate: forwarded to ``PostProcess``.
    :param end_ring_rate: forwarded to ``PostProcess``.
    :param ring_max_length: forwarded to ``PostProcess``.
    :param min_silence_length: forwarded to ``PostProcess``.
    :param max_speech_length: forwarded to ``PostProcess``.
    :param min_speech_length: forwarded to ``PostProcess``.
    :param engine: key into the module-level ``vad_engines`` registry.
    :return: ``(vad_image, probs_image, lsnr_image, message)``: three PNG paths
        and a JSON string with the segments (seconds), time cost, duration
        (seconds) and real-time factor.
    :raises gr.Error: for missing/empty audio, an unknown engine, or any
        failure during inference or post-processing.
    """
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("audio file and microphone is null.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone file is provided, audio file taking priority.")
    audio_t: Tuple = audio_file_t or audio_microphone_t

    sample_rate, signal = audio_t
    # The pipeline below works on 8 kHz single-channel audio. Multi-channel
    # clips are routed through convert_sample_rate too (it mixes down to
    # mono); otherwise an 8 kHz stereo clip would reach the model as 2-D.
    if sample_rate != 8000 or signal.ndim > 1:
        signal = convert_sample_rate(signal, sample_rate, 8000)
        sample_rate = 8000
    if signal.shape[-1] == 0:
        raise gr.Error("audio is empty.")

    # True division: floor division gave clips shorter than 1 s a duration of
    # 0 (ZeroDivisionError when computing rtf) and truncated the reported
    # duration/rtf for every other clip.
    audio_duration = signal.shape[-1] / sample_rate
    audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = vad_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid vad engine: {engine}.")

    try:
        infer_engine = load_vad_model(
            infer_cls=infer_engine_param["infer_cls"],
            **infer_engine_param["kwargs"],
        )

        # perf_counter is monotonic and meant for measuring elapsed time.
        begin = time.perf_counter()
        vad_info = infer_engine.infer(audio)
        time_cost = time.perf_counter() - begin

        probs = vad_info["probs"]
        # Fixed scaling of lsnr for plotting. NOTE(review): the constant 30 is
        # not explained here (max-abs normalisation was used at some point).
        lsnr = vad_info["lsnr"] / 30
        frame_step = infer_engine.config.hop_size

        # Turn frame-level probabilities into speech segments / flags.
        vad_post_process = PostProcess(
            start_ring_rate=start_ring_rate,
            end_ring_rate=end_ring_rate,
            ring_max_length=ring_max_length,
            min_silence_length=min_silence_length,
            max_speech_length=max_speech_length,
            min_speech_length=min_speech_length,
        )
        vad_segments = vad_post_process.get_vad_segments(probs)
        vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)

        # Convert each frame-level curve to one value per audio sample so it
        # can be drawn over the waveform.
        vad_image = generate_image(audio, process_speech_probs(audio, vad_flags, frame_step))
        probs_image = generate_image(audio, process_speech_probs(audio, probs, frame_step))
        lsnr_image = generate_image(audio, process_speech_probs(audio, lsnr, frame_step))

        # Segment boundaries are frame indices; report them in seconds.
        vad_segments = [
            [
                begin_frame * frame_step / sample_rate,
                end_frame * frame_step / sample_rate,
            ]
            for begin_frame, end_frame, *_ in vad_segments
        ]

        rtf = time_cost / audio_duration
        info = {
            "vad_segments": vad_segments,
            "time_cost": round(time_cost, 4),
            "duration": round(audio_duration, 4),
            "rtf": round(rtf, 4),
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)
    except Exception as e:
        logger.exception("vad failed for engine %s", engine)
        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.") from e
    return vad_image, probs_image, lsnr_image, message
def _discover_vad_engines(trained_model_dir: Path) -> Dict[str, dict]:
    """Build the engine registry from the ``*.zip`` archives in ``trained_model_dir``.

    Each archive's stem becomes the engine name. Its inference class comes from
    ``get_infer_cls_by_model_name``, which raises AssertionError for names it
    does not recognise. Archives are visited in sorted order so the default
    (first) engine does not depend on filesystem listing order.
    """
    # Archives that are not selectable VAD engines in this app.
    excluded_archives = {
        "cnn-vad-by-webrtcvad-nx-dns3.zip",
        "fsmn-vad-by-webrtcvad-nx-dns3.zip",
        "examples.zip",
        "sound-2-ch32.zip",
        "sound-3-ch32.zip",
        "sound-4-ch32.zip",
        "sound-8-ch32.zip",
    }
    return {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            },
        }
        for filename in sorted(trained_model_dir.glob("*.zip"))
        if filename.name not in excluded_archives
    }


def _prepare_examples(examples_dir: Path, trained_model_dir: Path, default_engine: str) -> list:
    """Unpack ``examples.zip`` on first run and list the example wav files.

    :return: rows ``[wav_path, None, default_engine]`` matching the
        ``gr.Examples`` inputs (file audio, microphone audio, engine).
    """
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            examples_dir.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=examples_dir)
    return [
        [filename.as_posix(), None, default_engine]
        for filename in sorted(examples_dir.glob("**/*.wav"))
    ]


def main():
    """Download models if needed, build the Gradio UI and launch the server."""
    args = get_args()
    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # Download models on first run only; an existing directory is used as-is.
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # Discover engines in the directory the models live in (honouring
    # --trained_model_dir) rather than a hard-coded project path.
    global vad_engines
    vad_engines = _discover_vad_engines(trained_model_dir)
    vad_engine_choices = list(vad_engines.keys())
    if not vad_engine_choices:
        raise RuntimeError(f"no vad models found in {trained_model_dir.as_posix()}.")

    examples = _prepare_examples(examples_dir, trained_model_dir, vad_engine_choices[0])

    # Post-processing defaults shown in the UI, also used for example runs.
    ui_defaults = {
        "start_ring_rate": 0.5,
        "end_ring_rate": 0.3,
        "ring_max_length": 10,
        "min_silence_length": 6,
        "max_speech_length": 100000,
        "min_speech_length": 15,
    }

    def run_example(audio_file_t, audio_microphone_t, engine):
        # gr.Examples calls fn with its three inputs positionally (when examples
        # are cached or run on click). Pass `engine` by keyword so it is not
        # bound to `start_ring_rate`.
        return when_click_vad_button(
            audio_file_t, audio_microphone_t, engine=engine, **ui_defaults
        )

    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("vad"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                vad_audio_file = gr.Audio(label="audio")
                            with gr.TabItem("microphone"):
                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
                        with gr.Row():
                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=ui_defaults["start_ring_rate"], step=0.1, label="start_ring_rate")
                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=ui_defaults["end_ring_rate"], step=0.1, label="end_ring_rate")
                        with gr.Row():
                            vad_ring_max_length = gr.Number(value=ui_defaults["ring_max_length"], label="ring_max_length (*10ms)")
                            vad_min_silence_length = gr.Number(value=ui_defaults["min_silence_length"], label="min_silence_length (*10ms)")
                        with gr.Row():
                            vad_max_speech_length = gr.Number(value=ui_defaults["max_speech_length"], label="max_speech_length (*10ms)")
                            vad_min_speech_length = gr.Number(value=ui_defaults["min_speech_length"], label="min_speech_length (*10ms)")
                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
                        vad_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        vad_vad_image = gr.Image(label="vad")
                        vad_prob_image = gr.Image(label="prob")
                        vad_lsnr_image = gr.Image(label="lsnr")
                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")
                vad_button.click(
                    when_click_vad_button,
                    inputs=[
                        vad_audio_file, vad_audio_microphone,
                        vad_start_ring_rate, vad_end_ring_rate,
                        vad_ring_max_length,
                        vad_min_silence_length,
                        vad_max_speech_length, vad_min_speech_length,
                        vad_engine,
                    ],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                )
                gr.Examples(
                    examples=examples,
                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
                    fn=run_example,
                    # cache_examples=True / cache_mode="lazy" would invoke run_example.
                )
            # SECURITY NOTE(review): this tab executes arbitrary commands typed
            # into the web UI (see `shell`).
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text, ],
                    outputs=[shell_output],
                )

    # NOTE(review): off Windows the server binds 0.0.0.0, which exposes the
    # "shell" tab above to anyone on the network. Do not deploy this on an
    # untrusted network without removing or protecting that tab.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
        show_error=True,
    )
# Script entry point: launch the demo only when run directly, not on import.
if __name__ == "__main__":
    main()