|
|
|
|
|
import collections |
|
|
|
from typing import List, Tuple |
|
|
|
|
|
class PostProcess(object):
    """Convert a sequence of per-frame scores into merged VAD segments/flags.

    Processing happens in two stages:

    1. ``segments_generator`` runs a ring-buffer hysteresis over the scores
       and emits raw ``[first_idx, last_idx]`` index pairs.
    2. ``vad_segments_generator`` merges raw segments separated by a short
       gap, discards short ones and splits over-long ones.

    All lengths are differences of positions in the ``probs`` list, i.e.
    they are expressed in frames.

    The instance is stateful.  ``get_vad_segments`` / ``post_process`` are
    one-shot entry points and reset that state before running, so a single
    instance can be reused for several independent inputs.
    """

    def __init__(self,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 ring_max_length: int = 10,
                 min_silence_length: int = 6,
                 max_speech_length: float = 100000,
                 min_speech_length: float = 15,
                 ):
        """Store the thresholds and initialise the detector state.

        Args:
            start_ring_rate: a segment opens when the sum of the scores in
                the ring buffer exceeds ``start_ring_rate * ring_max_length``.
            end_ring_rate: an open segment closes when that sum drops below
                ``end_ring_rate * ring_max_length``.
            ring_max_length: capacity of the ring buffer, in frames.
            min_silence_length: raw segments separated by a gap shorter than
                this are merged.
            max_speech_length: a merged segment longer than this has a chunk
                of exactly this length split off its front.
            min_speech_length: a merged segment shorter than this is dropped
                when the next raw segment arrives.
        """
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.ring_max_length = ring_max_length
        self.max_speech_length = max_speech_length
        self.min_speech_length = min_speech_length
        self.min_silence_length = min_silence_length

        self.reset()

    def reset(self) -> None:
        """Discard all detector state so the next input starts from scratch."""
        # Sliding window of the most recent (index, score) pairs.
        self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
        self.triggered = False

        self.is_first_segment = True
        # -1 is the "no pending merged segment" sentinel for both bounds.
        self.start_idx: int = -1
        self.end_idx: int = -1

        # (index, score) pairs collected for the currently open raw segment.
        self.voiced_frames: List[Tuple[int, float]] = []

    def segments_generator(self, probs: List[float]):
        """Yield raw ``[first_idx, last_idx]`` segments found in ``probs``.

        While idle, frames only feed the ring buffer; once the buffered
        scores sum above the start threshold the whole buffer becomes the
        beginning of a segment.  While triggered, every frame is collected
        until the buffered sum falls below the end threshold, at which point
        the segment is emitted.  A segment still open when ``probs`` is
        exhausted stays in ``self.voiced_frames`` (see ``last_vad_segments``).

        Args:
            probs: per-frame scores; indices are positions in this list.

        Yields:
            ``[first_idx, last_idx]`` where both values are indices of
            collected frames.
        """
        for idx, prob in enumerate(probs):
            if not self.triggered:
                self.ring_buffer.append((idx, prob))
                num_voiced = sum(p for _, p in self.ring_buffer)

                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
                    self.triggered = True
                    # The buffered frames are the onset of the segment.
                    self.voiced_frames.extend(self.ring_buffer)
                continue

            idx_prob_t = (idx, prob)
            self.voiced_frames.append(idx_prob_t)
            self.ring_buffer.append(idx_prob_t)
            num_voiced = sum(p for _, p in self.ring_buffer)

            if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
                yield [
                    self.voiced_frames[0][0],
                    self.voiced_frames[-1][0],
                ]
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = []

    def vad_segments_generator(self, segments_generator):
        """Merge, filter and split raw segments into final VAD segments.

        The most recent merged segment is held back in
        ``self.start_idx`` / ``self.end_idx`` and is only emitted once a
        later raw segment proves it is complete.

        Args:
            segments_generator: iterable of ``[start, end]`` pairs.

        Yields:
            ``[start, end]`` lists.
        """
        for segment in segments_generator:
            start = segment[0]
            end = segment[1]

            # Nothing pending yet: just remember this segment.
            if self.start_idx == -1 and self.end_idx == -1:
                self.start_idx = start
                self.end_idx = end
                continue

            # NOTE: at most one max-length chunk is split off per incoming
            # raw segment; the remainder keeps going through the rules below.
            if self.end_idx - self.start_idx > self.max_speech_length:
                end_ = self.start_idx + self.max_speech_length
                yield [self.start_idx, end_]
                self.start_idx = end_

            # Gap too short: extend the pending segment over it.
            silence_length = start - self.end_idx
            if silence_length < self.min_silence_length:
                self.end_idx = end
                continue

            # Pending segment too short: drop it and start over.
            if self.end_idx - self.start_idx < self.min_speech_length:
                self.start_idx = start
                self.end_idx = end
                continue

            yield [self.start_idx, self.end_idx]
            self.start_idx = start
            self.end_idx = end

    def vad(self, probs: List[float]) -> List[list]:
        """Return the VAD segments that are already complete for ``probs``.

        The still-pending segment (if any) is not included; obtain it with
        ``last_vad_segments``.
        """
        segments = self.segments_generator(probs)
        return list(self.vad_segments_generator(segments))

    def last_vad_segments(self) -> List[list]:
        """Flush the raw segment still being collected and the pending one.

        Returns:
            Any segments emitted while merging the unfinished raw segment,
            followed by the pending ``[start_idx, end_idx]`` when one exists.
            The min/max length rules are not applied to that final segment.
        """
        if self.voiced_frames:
            segments = [[
                self.voiced_frames[0][0],
                self.voiced_frames[-1][0],
            ]]
        else:
            segments = []

        vad_segments = list(self.vad_segments_generator(segments))

        # Compare against the -1 sentinel: index 0 is a valid segment start,
        # so a "> 0"-style test would lose speech beginning at frame 0.
        if self.start_idx != -1 and self.end_idx != -1:
            vad_segments.append([self.start_idx, self.end_idx])
        return vad_segments

    def get_vad_segments(self, probs: List[float]):
        """Return every VAD segment of ``probs`` in one shot.

        State left over from an earlier input is cleared first so results
        never depend on what the instance processed before.
        """
        self.reset()

        vad_segments = self.vad(probs)
        vad_segments += self.last_vad_segments()
        return vad_segments

    def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]):
        """Return a 0/1 flag per frame of ``probs``.

        Args:
            probs: per-frame scores; only its length is used.
            vad_segments: ``(begin, end)`` pairs; ``end`` is exclusive here.

        Returns:
            A list of ``len(probs)`` ints, 1 inside any segment.
        """
        total = len(probs)
        result = [0] * total
        for begin, end in vad_segments:
            # Clamp so the output keeps exactly ``len(probs)`` entries, and
            # cast because bounds derived from a float ``max_speech_length``
            # are floats, which slicing rejects.
            lo = max(0, int(begin))
            hi = min(total, int(end))
            if hi > lo:
                result[lo:hi] = [1] * (hi - lo)
        return result

    def post_process(self, probs: List[float]):
        """Return the per-frame 0/1 VAD flags for ``probs``."""
        vad_segments = self.get_vad_segments(probs)
        return self.get_vad_flags(probs, vad_segments)
|
|
|
|
|
if __name__ == "__main__":
    # Running this module as a script intentionally does nothing; it only
    # provides the PostProcess class for import.
    pass
|
|