#!/usr/bin/python3 # -*- coding: utf-8 -*- import argparse import collections from functools import lru_cache import os from pathlib import Path import shutil import tempfile import zipfile from typing import List import matplotlib.pyplot as plt import numpy as np from scipy.io import wavfile import torch import webrtcvad from project_settings import project_path from toolbox.torch.utils.data.vocabulary import Vocabulary class FrameVoiceClassifier(object): def predict(self, chunk: np.ndarray) -> float: raise NotImplementedError class WebRTCVoiceClassifier(FrameVoiceClassifier): def __init__(self, agg: int = 3, sample_rate: int = 8000 ): self.agg = agg self.sample_rate = sample_rate self.model = webrtcvad.Vad(mode=agg) def predict(self, chunk: np.ndarray) -> float: if chunk.dtype != np.int16: raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype)) audio_bytes = bytes(chunk) is_speech = self.model.is_speech(audio_bytes, self.sample_rate) return 1.0 if is_speech else 0.0 class SileroVoiceClassifier(FrameVoiceClassifier): def __init__(self, model_path: str, sample_rate: int = 8000): self.model_path = model_path self.sample_rate = sample_rate with open(self.model_path, "rb") as f: model = torch.jit.load(f, map_location="cpu") self.model = model self.model.reset_states() def predict(self, chunk: np.ndarray) -> float: if self.sample_rate / len(chunk) > 31.25: raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25)) if chunk.dtype != np.int16: raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype)) num_samples = len(chunk) if self.sample_rate == 8000 and num_samples != 256: raise AssertionError(f"win size must be 32 ms for silero vad. ") if self.sample_rate == 16000 and num_samples != 512: raise AssertionError(f"win size must be 32 ms for silero vad. ") chunk = chunk / 32768 chunk = torch.tensor(chunk, dtype=torch.float32) speech_prob = self.model(chunk, self.sample_rate).item() return float(speech_prob) class CCSoundsClassifier(FrameVoiceClassifier): def __init__(self, model_path: str, sample_rate: int = 8000): self.model_path = model_path self.sample_rate = sample_rate d = self.load_model(Path(model_path)) model = d["model"] vocabulary = d["vocabulary"] self.model = model self.vocabulary = vocabulary @staticmethod @lru_cache(maxsize=100) def load_model(model_file: Path): with zipfile.ZipFile(model_file, "r") as f_zip: out_root = Path(tempfile.gettempdir()) / "cc_audio_8" if out_root.exists(): shutil.rmtree(out_root.as_posix()) out_root.mkdir(parents=True, exist_ok=True) f_zip.extractall(path=out_root) tgt_path = out_root / model_file.stem jit_model_file = tgt_path / "trace_model.zip" vocab_path = tgt_path / "vocabulary" vocabulary = Vocabulary.from_files(vocab_path.as_posix()) with open(jit_model_file.as_posix(), "rb") as f: model = torch.jit.load(f) model.eval() shutil.rmtree(tgt_path) d = { "model": model, "vocabulary": vocabulary } return d def predict(self, chunk: np.ndarray) -> float: if chunk.dtype != np.int16: raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype)) chunk = chunk / (1 << 15) inputs = torch.tensor(chunk, dtype=torch.float32) inputs = torch.unsqueeze(inputs, dim=0) with torch.no_grad(): logits = self.model(inputs) probs = torch.nn.functional.softmax(logits, dim=-1) voice_idx = self.vocabulary.get_token_index(token="voice", namespace="labels") probs = probs.cpu() voice_prob = probs[0][voice_idx] return float(voice_prob) class Frame(object): def __init__(self, signal: np.ndarray, timestamp_s: float): self.signal = signal self.timestamp_s = timestamp_s class RingVad(object): def __init__(self, model: FrameVoiceClassifier, start_ring_rate: float = 0.5, end_ring_rate: float = 0.5, frame_size_ms: int = 30, frame_step_ms: int = 30, padding_length_ms: int = 300, max_silence_length_ms: int = 300, max_speech_length_s: float = 2.0, min_speech_length_s: float = 0.3, sample_rate: int = 8000 ): self.model = model self.start_ring_rate = start_ring_rate self.end_ring_rate = end_ring_rate self.frame_size_ms = frame_size_ms self.frame_step_ms = frame_step_ms self.padding_length_ms = padding_length_ms self.max_silence_length_ms = max_silence_length_ms self.max_speech_length_s = max_speech_length_s self.min_speech_length_s = min_speech_length_s self.sample_rate = sample_rate # frames self.frame_size = int(sample_rate * (frame_size_ms / 1000.0)) self.frame_step = int(sample_rate * (frame_step_ms / 1000.0)) self.frame_timestamp_s = 0.0 self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16) # segments self.num_padding_frames = int(padding_length_ms / frame_step_ms) self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.triggered = False self.voiced_frames: List[Frame] = list() self.segments = list() # vad segments self.is_first_segment = True self.timestamp_start_s = 0.0 self.timestamp_end_s = 0.0 # speech probs self.speech_probs: List[float] = list() def reset(self): # frames self.frame_size = int(self.sample_rate * (self.frame_size_ms / 1000.0)) self.frame_step = int(self.sample_rate * (self.frame_step_ms / 1000.0)) self.frame_timestamp_s = 0.0 self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16) # segments self.num_padding_frames = int(self.padding_length_ms / self.frame_step_ms) self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.triggered = False self.voiced_frames: List[Frame] = list() self.segments = list() # vad segments self.is_first_segment = True self.timestamp_start_s = 0.0 self.timestamp_end_s = 0.0 # speech probs self.speech_probs: List[float] = list() def signal_to_frames(self, signal: np.ndarray): frames = list() l = len(signal) duration_s = float(self.frame_step) / self.sample_rate for offset in range(0, l - self.frame_size + 1, self.frame_step): sub_signal = signal[offset:offset+self.frame_size] frame = Frame(sub_signal, self.frame_timestamp_s) self.frame_timestamp_s += duration_s frames.append(frame) return frames def segments_generator(self, signal: np.ndarray): # signal rounding if self.signal_cache is not None: signal = np.concatenate([self.signal_cache, signal]) # rest rest = (len(signal) - self.frame_size) % self.frame_step if rest == 0: self.signal_cache = None signal_ = signal else: self.signal_cache = signal[-rest:] signal_ = signal[:-rest] # frames frames = self.signal_to_frames(signal_) for frame in frames: speech_prob = self.model.predict(frame.signal) self.speech_probs.append(speech_prob) if not self.triggered: self.ring_buffer.append((frame, speech_prob)) num_voiced = sum([p for _, p in self.ring_buffer]) if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen: self.triggered = True for f, _ in self.ring_buffer: self.voiced_frames.append(f) continue self.voiced_frames.append(frame) self.ring_buffer.append((frame, speech_prob)) num_voiced = sum([p for _, p in self.ring_buffer]) if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen: segment = [ np.concatenate([f.signal for f in self.voiced_frames]), self.voiced_frames[0].timestamp_s, self.voiced_frames[-1].timestamp_s, ] yield segment self.triggered = False self.ring_buffer.clear() self.voiced_frames = [] continue def vad_segments_generator(self, segments_generator): segments = list(segments_generator) for i, segment in enumerate(segments): start = round(segment[1], 4) end = round(segment[2], 4) if self.timestamp_start_s is None and self.timestamp_end_s is None: self.timestamp_start_s = start self.timestamp_end_s = end continue if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s: end_ = self.timestamp_start_s + self.max_speech_length_s vad_segment = [self.timestamp_start_s, end_] yield vad_segment self.timestamp_start_s = end_ silence_length_ms = (start - self.timestamp_end_s) * 1000 if silence_length_ms < self.max_silence_length_ms: self.timestamp_end_s = end continue if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s: self.timestamp_start_s = start self.timestamp_end_s = end continue vad_segment = [self.timestamp_start_s, self.timestamp_end_s] yield vad_segment self.timestamp_start_s = start self.timestamp_end_s = end def vad(self, signal: np.ndarray) -> List[list]: segments = self.segments_generator(signal) vad_segments = self.vad_segments_generator(segments) vad_segments = list(vad_segments) return vad_segments def last_vad_segments(self) -> List[list]: # last segments if len(self.voiced_frames) == 0: segments = [] else: segment = [ np.concatenate([f.signal for f in self.voiced_frames]), self.voiced_frames[0].timestamp_s, self.voiced_frames[-1].timestamp_s ] segments = [segment] # last vad segments vad_segments = self.vad_segments_generator(segments) vad_segments = list(vad_segments) if self.timestamp_end_s > 1e-5 and self.timestamp_end_s > 1e-5: vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]] return vad_segments def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray: speech_probs_ = list() for p in speech_probs[1:]: speech_probs_.extend([p] * frame_step) pad = (signal.shape[0] - len(speech_probs_)) speech_probs_ = speech_probs_ + [0.0] * pad speech_probs_ = np.array(speech_probs_, dtype=np.float32) if len(speech_probs_) != len(signal): raise AssertionError return speech_probs_ def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list): time = np.arange(0, len(signal)) / sample_rate plt.figure(figsize=(12, 5)) plt.plot(time, signal / 32768, color='b') plt.plot(time, speech_probs, color='gray') for start, end in vad_segments: plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="开始端点") plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="结束端点") plt.show() return def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--wav_file", # default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(), # default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav", default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/en-PH/2025-05-15/active_media_r_0617d225-f396-4011-a86e-eaf68cdda5a8_3.wav", type=str, ) parser.add_argument( "--model_path", default=(project_path / "trained_models/silero_vad.jit").as_posix(), type=str, ) args = parser.parse_args() return args SAMPLE_RATE = 8000 def main(): args = get_args() sample_rate, signal = wavfile.read(args.wav_file) if SAMPLE_RATE != sample_rate: raise AssertionError # model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE) model = WebRTCVoiceClassifier(agg=3, sample_rate=SAMPLE_RATE) # model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix()) # silero vad ring_vad = RingVad(model=model, start_ring_rate=0.2, end_ring_rate=0.1, frame_size_ms=32, frame_step_ms=32, padding_length_ms=320, max_silence_length_ms=320, max_speech_length_s=100, min_speech_length_s=0.1, sample_rate=SAMPLE_RATE, ) # webrtcvad ring_vad = RingVad(model=model, start_ring_rate=0.9, end_ring_rate=0.1, frame_size_ms=30, frame_step_ms=30, padding_length_ms=30, max_silence_length_ms=0, max_speech_length_s=100, min_speech_length_s=0.1, sample_rate=SAMPLE_RATE, ) print(ring_vad) vad_segments = list() segments = ring_vad.vad(signal) vad_segments += segments for segment in segments: print(segment) # last vad segment segments = ring_vad.last_vad_segments() vad_segments += segments for segment in segments: print(segment) print(ring_vad.speech_probs) print(len(ring_vad.speech_probs)) # speech_probs speech_probs = process_speech_probs( signal=signal, speech_probs=ring_vad.speech_probs, frame_step=ring_vad.frame_step, ) # plot make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments) return if __name__ == "__main__": main()