#!/usr/bin/python3
# -*- coding: utf-8 -*-
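"""
Streaming voice activity detection (VAD).

A ``RingVad`` instance feeds fixed-size frames to a pluggable per-frame
classifier (WebRTC, a Silero TorchScript model, or a zip-packaged CC model)
and turns the per-frame speech probabilities into [start, end] speech
segments via a ring-buffer trigger.
"""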
import argparse
import collections
from functools import lru_cache
import os
from pathlib import Path
import shutil
import tempfile
import zipfile
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
import webrtcvad
from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
class FrameVoiceClassifier(object):
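    """Interface for per-frame classifiers: ``predict`` maps one chunk of int16 samples to a speech probability in [0.0, 1.0]."""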
def predict(self, chunk: np.ndarray) -> float:
raise NotImplementedError
class WebRTCVoiceClassifier(FrameVoiceClassifier):
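    """webrtcvad wrapper; ``agg`` is the aggressiveness mode (0-3) and predictions are hard 0.0 / 1.0 decisions."""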
def __init__(self,
agg: int = 3,
sample_rate: int = 8000
):
self.agg = agg
self.sample_rate = sample_rate
self.model = webrtcvad.Vad(mode=agg)
def predict(self, chunk: np.ndarray) -> float:
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
audio_bytes = bytes(chunk)
is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
return 1.0 if is_speech else 0.0
class SileroVoiceClassifier(FrameVoiceClassifier):
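    """Silero VAD TorchScript wrapper; expects 32 ms chunks (256 samples at 8 kHz, 512 samples at 16 kHz)."""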
def __init__(self,
model_path: str,
sample_rate: int = 8000):
self.model_path = model_path
self.sample_rate = sample_rate
with open(self.model_path, "rb") as f:
model = torch.jit.load(f, map_location="cpu")
self.model = model
self.model.reset_states()
def predict(self, chunk: np.ndarray) -> float:
if self.sample_rate / len(chunk) > 31.25:
raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
num_samples = len(chunk)
        if self.sample_rate == 8000 and num_samples != 256:
            raise AssertionError(f"silero vad requires 256-sample (32 ms) chunks at 8000 Hz, got {num_samples}")
        if self.sample_rate == 16000 and num_samples != 512:
            raise AssertionError(f"silero vad requires 512-sample (32 ms) chunks at 16000 Hz, got {num_samples}")
chunk = chunk / 32768
chunk = torch.tensor(chunk, dtype=torch.float32)
speech_prob = self.model(chunk, self.sample_rate).item()
return float(speech_prob)
class CCSoundsClassifier(FrameVoiceClassifier):
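    """Classifier packaged as a zip archive holding a traced TorchScript model and a vocabulary; ``predict`` returns the softmax probability of the "voice" label."""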
def __init__(self,
model_path: str,
sample_rate: int = 8000):
self.model_path = model_path
self.sample_rate = sample_rate
d = self.load_model(Path(model_path))
model = d["model"]
vocabulary = d["vocabulary"]
self.model = model
self.vocabulary = vocabulary
@staticmethod
@lru_cache(maxsize=100)
def load_model(model_file: Path):
with zipfile.ZipFile(model_file, "r") as f_zip:
out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
if out_root.exists():
shutil.rmtree(out_root.as_posix())
out_root.mkdir(parents=True, exist_ok=True)
f_zip.extractall(path=out_root)
tgt_path = out_root / model_file.stem
jit_model_file = tgt_path / "trace_model.zip"
vocab_path = tgt_path / "vocabulary"
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
with open(jit_model_file.as_posix(), "rb") as f:
model = torch.jit.load(f)
model.eval()
shutil.rmtree(tgt_path)
d = {
"model": model,
"vocabulary": vocabulary
}
return d
def predict(self, chunk: np.ndarray) -> float:
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
chunk = chunk / (1 << 15)
inputs = torch.tensor(chunk, dtype=torch.float32)
inputs = torch.unsqueeze(inputs, dim=0)
with torch.no_grad():
logits = self.model(inputs)
probs = torch.nn.functional.softmax(logits, dim=-1)
voice_idx = self.vocabulary.get_token_index(token="voice", namespace="labels")
probs = probs.cpu()
voice_prob = probs[0][voice_idx]
return float(voice_prob)
class Frame(object):
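    """One analysis frame: the raw int16 samples and the frame start time in seconds."""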
def __init__(self, signal: np.ndarray, timestamp_s: float):
self.signal = signal
self.timestamp_s = timestamp_s
class RingVad(object):
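    """
    Ring-buffer VAD over a stream of frames.

    While untriggered, scored frames accumulate in a deque of
    ``num_padding_frames`` entries; when the summed speech probability exceeds
    ``start_ring_rate * maxlen`` the collector triggers and starts a voiced
    segment.  While triggered, the segment ends once the summed probability
    drops below ``end_ring_rate * maxlen``.  ``vad_segments_generator`` then
    merges segments separated by short silences and splits or drops runs using
    ``max_speech_length_s`` / ``min_speech_length_s``.
    """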
def __init__(self,
model: FrameVoiceClassifier,
start_ring_rate: float = 0.5,
end_ring_rate: float = 0.5,
frame_size_ms: int = 30,
frame_step_ms: int = 30,
padding_length_ms: int = 300,
max_silence_length_ms: int = 300,
max_speech_length_s: float = 2.0,
min_speech_length_s: float = 0.3,
sample_rate: int = 8000
):
self.model = model
self.start_ring_rate = start_ring_rate
self.end_ring_rate = end_ring_rate
self.frame_size_ms = frame_size_ms
self.frame_step_ms = frame_step_ms
self.padding_length_ms = padding_length_ms
self.max_silence_length_ms = max_silence_length_ms
self.max_speech_length_s = max_speech_length_s
self.min_speech_length_s = min_speech_length_s
self.sample_rate = sample_rate
# frames
self.frame_size = int(sample_rate * (frame_size_ms / 1000.0))
self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
self.frame_timestamp_s = 0.0
self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)
# segments
self.num_padding_frames = int(padding_length_ms / frame_step_ms)
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False
self.voiced_frames: List[Frame] = list()
self.segments = list()
# vad segments
self.is_first_segment = True
self.timestamp_start_s = 0.0
self.timestamp_end_s = 0.0
# speech probs
self.speech_probs: List[float] = list()
def reset(self):
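        """Restore the initial streaming state so the instance can process a new signal."""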
# frames
self.frame_size = int(self.sample_rate * (self.frame_size_ms / 1000.0))
self.frame_step = int(self.sample_rate * (self.frame_step_ms / 1000.0))
self.frame_timestamp_s = 0.0
self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)
# segments
self.num_padding_frames = int(self.padding_length_ms / self.frame_step_ms)
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False
self.voiced_frames: List[Frame] = list()
self.segments = list()
# vad segments
self.is_first_segment = True
self.timestamp_start_s = 0.0
self.timestamp_end_s = 0.0
# speech probs
self.speech_probs: List[float] = list()
def signal_to_frames(self, signal: np.ndarray):
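        """Split ``signal`` into frames of ``frame_size`` samples taken every ``frame_step`` samples; the running timestamp advances by ``frame_step / sample_rate`` seconds per frame."""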
frames = list()
l = len(signal)
duration_s = float(self.frame_step) / self.sample_rate
for offset in range(0, l - self.frame_size + 1, self.frame_step):
sub_signal = signal[offset:offset+self.frame_size]
frame = Frame(sub_signal, self.frame_timestamp_s)
self.frame_timestamp_s += duration_s
frames.append(frame)
return frames
def segments_generator(self, signal: np.ndarray):
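        """Consume one chunk of the stream: prepend the cached remainder, frame and score it, and yield [samples, start_s, end_s] voiced segments; samples beyond the last frame boundary are cached for the next call."""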
# signal rounding
if self.signal_cache is not None:
signal = np.concatenate([self.signal_cache, signal])
# rest
rest = (len(signal) - self.frame_size) % self.frame_step
if rest == 0:
self.signal_cache = None
signal_ = signal
else:
self.signal_cache = signal[-rest:]
signal_ = signal[:-rest]
# frames
frames = self.signal_to_frames(signal_)
for frame in frames:
speech_prob = self.model.predict(frame.signal)
self.speech_probs.append(speech_prob)
if not self.triggered:
self.ring_buffer.append((frame, speech_prob))
num_voiced = sum([p for _, p in self.ring_buffer])
if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
self.triggered = True
for f, _ in self.ring_buffer:
self.voiced_frames.append(f)
continue
self.voiced_frames.append(frame)
self.ring_buffer.append((frame, speech_prob))
num_voiced = sum([p for _, p in self.ring_buffer])
if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
segment = [
np.concatenate([f.signal for f in self.voiced_frames]),
self.voiced_frames[0].timestamp_s,
self.voiced_frames[-1].timestamp_s,
]
yield segment
self.triggered = False
self.ring_buffer.clear()
self.voiced_frames = []
continue
def vad_segments_generator(self, segments_generator):
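        """Turn raw voiced segments into [start_s, end_s] pairs: merge gaps shorter than ``max_silence_length_ms``, split runs longer than ``max_speech_length_s``, and drop runs shorter than ``min_speech_length_s``."""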
segments = list(segments_generator)
for i, segment in enumerate(segments):
start = round(segment[1], 4)
end = round(segment[2], 4)
            if self.is_first_segment:
                # seed the running segment with the first voiced span
                self.is_first_segment = False
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue
if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
end_ = self.timestamp_start_s + self.max_speech_length_s
vad_segment = [self.timestamp_start_s, end_]
yield vad_segment
self.timestamp_start_s = end_
silence_length_ms = (start - self.timestamp_end_s) * 1000
if silence_length_ms < self.max_silence_length_ms:
self.timestamp_end_s = end
continue
if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
self.timestamp_start_s = start
self.timestamp_end_s = end
continue
vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
yield vad_segment
self.timestamp_start_s = start
self.timestamp_end_s = end
def vad(self, signal: np.ndarray) -> List[list]:
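        """Process one chunk of signal and return the [start_s, end_s] segments completed so far."""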
segments = self.segments_generator(signal)
vad_segments = self.vad_segments_generator(segments)
vad_segments = list(vad_segments)
return vad_segments
def last_vad_segments(self) -> List[list]:
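        """Flush at end of stream: wrap any pending voiced frames into a final segment and append the still-open [start_s, end_s] pair if it is non-empty."""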
# last segments
if len(self.voiced_frames) == 0:
segments = []
else:
segment = [
np.concatenate([f.signal for f in self.voiced_frames]),
self.voiced_frames[0].timestamp_s,
self.voiced_frames[-1].timestamp_s
]
segments = [segment]
# last vad segments
vad_segments = self.vad_segments_generator(segments)
vad_segments = list(vad_segments)
        if self.timestamp_end_s > 1e-5:
vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
return vad_segments
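# A minimal chunked-streaming sketch (illustrative only; the 100 ms chunk size
# and the WebRTC model choice are assumptions, not requirements of RingVad):
#
#     ring_vad = RingVad(model=WebRTCVoiceClassifier(agg=3, sample_rate=8000), sample_rate=8000)
#     for begin in range(0, len(signal), 800):  # 800 samples = 100 ms at 8 kHz
#         for start_s, end_s in ring_vad.vad(signal[begin:begin + 800]):
#             print(start_s, end_s)
#     for start_s, end_s in ring_vad.last_vad_segments():
#         print(start_s, end_s)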
def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
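    """Expand per-frame speech probabilities to per-sample values for plotting: repeat each probability (the first one is dropped) ``frame_step`` times, then zero-pad to the signal length."""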
speech_probs_ = list()
for p in speech_probs[1:]:
speech_probs_.extend([p] * frame_step)
pad = (signal.shape[0] - len(speech_probs_))
speech_probs_ = speech_probs_ + [0.0] * pad
speech_probs_ = np.array(speech_probs_, dtype=np.float32)
if len(speech_probs_) != len(signal):
raise AssertionError
return speech_probs_
def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
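    """Plot the normalized waveform, the per-sample speech probability curve, and dashed vertical lines at each VAD segment start (green) and end (red)."""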
time = np.arange(0, len(signal)) / sample_rate
plt.figure(figsize=(12, 5))
plt.plot(time, signal / 32768, color='b')
plt.plot(time, speech_probs, color='gray')
for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="speech start")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="speech end")
plt.show()
return
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--wav_file",
# default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
# default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/en-PH/2025-05-15/active_media_r_0617d225-f396-4011-a86e-eaf68cdda5a8_3.wav",
type=str,
)
parser.add_argument(
"--model_path",
default=(project_path / "trained_models/silero_vad.jit").as_posix(),
type=str,
)
args = parser.parse_args()
return args
SAMPLE_RATE = 8000
def main():
args = get_args()
sample_rate, signal = wavfile.read(args.wav_file)
if SAMPLE_RATE != sample_rate:
raise AssertionError
# model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
model = WebRTCVoiceClassifier(agg=3, sample_rate=SAMPLE_RATE)
    # model = CCSoundsClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
    # silero vad configuration (note: overridden by the webrtcvad configuration below)
ring_vad = RingVad(model=model,
start_ring_rate=0.2,
end_ring_rate=0.1,
frame_size_ms=32,
frame_step_ms=32,
padding_length_ms=320,
max_silence_length_ms=320,
max_speech_length_s=100,
min_speech_length_s=0.1,
sample_rate=SAMPLE_RATE,
)
# webrtcvad
ring_vad = RingVad(model=model,
start_ring_rate=0.9,
end_ring_rate=0.1,
frame_size_ms=30,
frame_step_ms=30,
padding_length_ms=30,
max_silence_length_ms=0,
max_speech_length_s=100,
min_speech_length_s=0.1,
sample_rate=SAMPLE_RATE,
)
print(ring_vad)
vad_segments = list()
segments = ring_vad.vad(signal)
vad_segments += segments
for segment in segments:
print(segment)
# last vad segment
segments = ring_vad.last_vad_segments()
vad_segments += segments
for segment in segments:
print(segment)
print(ring_vad.speech_probs)
print(len(ring_vad.speech_probs))
# speech_probs
speech_probs = process_speech_probs(
signal=signal,
speech_probs=ring_vad.speech_probs,
frame_step=ring_vad.frame_step,
)
# plot
make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
return
if __name__ == "__main__":
main()