#!/usr/bin/python3
# -*- coding: utf-8 -*-
import collections
from typing import Iterable, Iterator, List, Tuple


class PostProcess(object):
    """Turn per-frame speech probabilities into VAD segments and frame flags.

    Processing happens in two stages:

    1. ``segments_generator`` slides a ring buffer of ``ring_max_length``
       frames over the probabilities. A raw segment opens once the buffer sum
       exceeds ``start_ring_rate * ring_max_length``. It closes once the sum
       drops below ``end_ring_rate * ring_max_length``.
    2. ``vad_segments_generator`` post-processes the raw segments. It merges
       a segment into the pending one when the gap is shorter than
       ``min_silence_length``. It discards a pending segment shorter than
       ``min_speech_length`` when a non-adjacent segment arrives. It cuts
       pending segments longer than ``max_speech_length`` into chunks.

    Segments are ``[start, end]`` lists of frame indices into ``probs``.
    """

    def __init__(self,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 ring_max_length: int = 10,
                 min_silence_length: int = 6,
                 max_speech_length: float = 100000,
                 min_speech_length: float = 15,
                 ):
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.ring_max_length = ring_max_length
        self.max_speech_length = max_speech_length
        self.min_speech_length = min_speech_length
        self.min_silence_length = min_silence_length
        self.reset()

    def reset(self) -> None:
        """Clear all per-sequence state so a new probability sequence can be processed."""
        # stage 1 (raw segments): sliding window of (idx, prob) pairs
        self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
        self.triggered = False
        # stage 2 (vad segments): the pending segment; -1 means "none pending"
        self.is_first_segment = True
        self.start_idx: int = -1
        self.end_idx: int = -1
        # (idx, prob) frames of the raw segment that is currently open
        self.voiced_frames: List[Tuple[int, float]] = list()

    def segments_generator(self, probs: List[float]) -> Iterator[List[int]]:
        """Yield raw ``[first_idx, last_idx]`` segments found in ``probs``.

        When a segment opens, the frames already in the ring buffer become its
        first frames. A segment that is still open when ``probs`` runs out is
        left in ``self.voiced_frames`` for ``last_vad_segments`` to emit.
        Frame indices are positions in ``probs`` and restart at 0 on every call.
        """
        start_threshold = self.start_ring_rate * self.ring_buffer.maxlen
        end_threshold = self.end_ring_rate * self.ring_buffer.maxlen
        for idx, prob in enumerate(probs):
            idx_prob_t = (idx, prob)
            self.ring_buffer.append(idx_prob_t)
            num_voiced = sum(p for _, p in self.ring_buffer)
            if not self.triggered:
                if num_voiced > start_threshold:
                    self.triggered = True
                    self.voiced_frames.extend(self.ring_buffer)
                continue
            self.voiced_frames.append(idx_prob_t)
            if num_voiced < end_threshold:
                yield [self.voiced_frames[0][0], self.voiced_frames[-1][0]]
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = list()

    def _split_long_pending(self) -> Iterator[List[int]]:
        """Cut ``max_speech_length`` chunks off the front of the pending segment.

        Chunks are yielded until the remainder is no longer than the cap. The
        remainder stays pending in ``start_idx``/``end_idx``.
        """
        if self.max_speech_length <= 0:
            # A non-positive cap would never shrink the segment: loop forever.
            return
        while self.end_idx - self.start_idx > self.max_speech_length:
            end_ = self.start_idx + self.max_speech_length
            yield [self.start_idx, end_]
            self.start_idx = end_

    def vad_segments_generator(self, segments_generator: Iterable[List[int]]) -> Iterator[List[int]]:
        """Merge, filter and split raw segments into final VAD segments.

        The most recent segment is kept pending in ``start_idx``/``end_idx``.
        It is yielded only once a following segment proves it complete.
        Call ``last_vad_segments`` to flush the pending segment.
        """
        for segment in segments_generator:
            start, end = segment[0], segment[1]
            if self.start_idx == -1 and self.end_idx == -1:
                self.start_idx = start
                self.end_idx = end
                continue

            # enforce the length cap on the pending segment (possibly several cuts)
            yield from self._split_long_pending()

            silence_length = start - self.end_idx
            if silence_length < self.min_silence_length:
                # gap too short: extend the pending segment
                self.end_idx = end
                continue
            if self.end_idx - self.start_idx < self.min_speech_length:
                # pending segment too short: discard it, new one becomes pending
                self.start_idx = start
                self.end_idx = end
                continue
            yield [self.start_idx, self.end_idx]
            self.start_idx = start
            self.end_idx = end

    def vad(self, probs: List[float]) -> List[list]:
        """Return the completed VAD segments of ``probs``; the last one stays pending."""
        segments = self.segments_generator(probs)
        return list(self.vad_segments_generator(segments))

    def last_vad_segments(self) -> List[list]:
        """Flush the still-open raw segment and the pending VAD segment."""
        segments = []
        if self.voiced_frames:
            segments.append([self.voiced_frames[0][0], self.voiced_frames[-1][0]])

        vad_segments = list(self.vad_segments_generator(segments))
        # -1 is the "nothing pending" sentinel. Frame 0 is a legitimate start,
        # so test against the sentinel rather than a positive threshold.
        if self.start_idx != -1 and self.end_idx != -1:
            vad_segments.extend(self._split_long_pending())
            vad_segments.append([self.start_idx, self.end_idx])
        return vad_segments

    def get_vad_segments(self, probs: List[float]) -> List[list]:
        """Return all VAD segments of ``probs``, processed independently of earlier calls."""
        self.reset()
        # left operand is evaluated first, so vad() runs before the flush
        return self.vad(probs) + self.last_vad_segments()

    def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]) -> List[int]:
        """Return one 0/1 flag per frame; frames ``begin`` to ``end - 1`` of each segment are 1."""
        result = [0] * len(probs)
        for begin, end in vad_segments:
            result[begin: end] = [1] * (end - begin)
        return result

    def post_process(self, probs: List[float]) -> List[int]:
        """Return per-frame 0/1 VAD flags for ``probs``."""
        vad_segments = self.get_vad_segments(probs)
        return self.get_vad_flags(probs, vad_segments)


if __name__ == "__main__":
    pass