cc_vad / toolbox /vad /utils.py
HoneyTian's picture
update
bebc2b8
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import collections
from typing import List, Tuple
class PostProcess(object):
    """Turn per-frame speech probabilities into speech segments and 0/1 flags.

    Processing happens in two stages:

    1. ``segments_generator`` keeps the most recent ``ring_max_length`` frames
       in a ring buffer.  A raw segment opens when the sum of the buffered
       probabilities rises above ``start_ring_rate * ring_max_length`` and
       closes when that sum falls below ``end_ring_rate * ring_max_length``.
    2. ``vad_segments_generator`` merges raw segments whose gap is shorter
       than ``min_silence_length``, drops a pending segment shorter than
       ``min_speech_length`` when the next one arrives, and splits off a
       leading piece of ``max_speech_length`` frames when the pending segment
       is longer than that.

    The instance carries state (ring buffer, trigger flag, pending segment)
    across calls to ``vad``; ``last_vad_segments`` flushes that state and
    resets the instance so it can be reused for another sequence.

    All lengths and indices are expressed in frames.  Indices are the
    positions produced by ``enumerate(probs)``, i.e. they are relative to the
    list handed to each ``vad`` call.
    """

    def __init__(self,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 ring_max_length: int = 10,
                 min_silence_length: int = 6,
                 max_speech_length: float = 100000,
                 min_speech_length: float = 15,
                 ):
        """Store the thresholds and create an empty processing state.

        :param start_ring_rate: fraction of ``ring_max_length`` the buffered
            probability sum must exceed to open a segment.
        :param end_ring_rate: fraction of ``ring_max_length`` the buffered
            probability sum must fall below to close a segment.
        :param ring_max_length: number of frames kept in the ring buffer.
        :param min_silence_length: gaps shorter than this are merged away.
        :param max_speech_length: pending segments longer than this get a
            leading piece of this length split off.
        :param min_speech_length: pending segments shorter than this are
            dropped when the next segment arrives.
        """
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.ring_max_length = ring_max_length
        self.max_speech_length = max_speech_length
        self.min_speech_length = min_speech_length
        self.min_silence_length = min_silence_length

        # segments
        self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
        self.triggered = False

        # vad segments
        self.is_first_segment = True
        # -1 / -1 is the sentinel for "no pending vad segment".
        self.start_idx: int = -1
        self.end_idx: int = -1

        # speech probs
        self.voiced_frames: List[Tuple[int, float]] = list()

    def reset(self) -> None:
        """Clear all carried-over state so the instance can be reused."""
        self.ring_buffer.clear()
        self.triggered = False
        self.is_first_segment = True
        self.start_idx = -1
        self.end_idx = -1
        self.voiced_frames = list()

    def segments_generator(self, probs: List[float]):
        """Yield raw ``[first_idx, last_idx]`` segments found in ``probs``.

        ``last_idx`` is the index of the frame at which the buffered sum
        dropped below the end threshold.  A segment that is still open when
        ``probs`` is exhausted is not yielded; its frames stay in
        ``self.voiced_frames`` until ``last_vad_segments`` is called.
        """
        # Thresholds are relative to the buffer capacity, not its current
        # fill level, so a segment cannot open before enough frames arrived.
        start_threshold = self.start_ring_rate * self.ring_buffer.maxlen
        end_threshold = self.end_ring_rate * self.ring_buffer.maxlen

        for idx, prob in enumerate(probs):
            frame = (idx, prob)

            if not self.triggered:
                self.ring_buffer.append(frame)
                voiced_sum = sum(p for _, p in self.ring_buffer)
                if voiced_sum > start_threshold:
                    self.triggered = True
                    # Everything currently buffered belongs to the segment,
                    # which gives it a short lead-in before the trigger.
                    self.voiced_frames.extend(self.ring_buffer)
            else:
                self.voiced_frames.append(frame)
                self.ring_buffer.append(frame)
                voiced_sum = sum(p for _, p in self.ring_buffer)
                if voiced_sum < end_threshold:
                    yield [
                        self.voiced_frames[0][0],
                        self.voiced_frames[-1][0],
                    ]
                    self.triggered = False
                    self.ring_buffer.clear()
                    self.voiced_frames = list()

    def vad_segments_generator(self, segments_generator):
        """Merge / filter raw segments and yield final ``[start, end]`` pairs.

        The most recent segment is always held back in ``self.start_idx`` /
        ``self.end_idx`` because a later segment may still be merged into it;
        ``last_vad_segments`` emits it at the end.

        :param segments_generator: iterable of raw segments; only elements
            ``[0]`` (start) and ``[1]`` (end) of each item are used.
        """
        for segment in segments_generator:
            start = segment[0]
            end = segment[1]

            # First segment ever: nothing to compare against yet.
            if self.start_idx == -1 and self.end_idx == -1:
                self.start_idx = start
                self.end_idx = end
                continue

            # Pending segment too long: split off one leading piece.
            # Only a single piece is split off per incoming segment.
            if self.end_idx - self.start_idx > self.max_speech_length:
                end_ = self.start_idx + self.max_speech_length
                yield [self.start_idx, end_]
                self.start_idx = end_

            # Gap too short to count as silence: extend the pending segment.
            silence_length = start - self.end_idx
            if silence_length < self.min_silence_length:
                self.end_idx = end
                continue

            # Pending segment too short to count as speech: replace it.
            if self.end_idx - self.start_idx < self.min_speech_length:
                self.start_idx = start
                self.end_idx = end
                continue

            yield [self.start_idx, self.end_idx]
            self.start_idx = start
            self.end_idx = end

    def vad(self, probs: List[float]) -> List[list]:
        """Return the vad segments that are final after consuming ``probs``.

        The still-pending segment is not included; call
        ``last_vad_segments`` to obtain it.
        """
        segments = self.segments_generator(probs)
        return list(self.vad_segments_generator(segments))

    def last_vad_segments(self) -> List[list]:
        """Flush the remaining state and return the trailing vad segments.

        A raw segment that is still open is closed at its last seen frame,
        run through ``vad_segments_generator``, and the pending vad segment
        is then emitted as is (no minimum / maximum length check is applied
        to it).  Afterwards the instance is reset.
        """
        # last segments
        if self.voiced_frames:
            segments = [[
                self.voiced_frames[0][0],
                self.voiced_frames[-1][0],
            ]]
        else:
            segments = []

        # last vad segments
        vad_segments = list(self.vad_segments_generator(segments))

        # Compare against the -1 sentinel: a pending segment that starts at
        # frame 0 is valid and must not be dropped.
        if self.start_idx != -1 or self.end_idx != -1:
            vad_segments.append([self.start_idx, self.end_idx])

        # The sequence is finished; stale state must not leak into the next one.
        self.reset()
        return vad_segments

    def get_vad_segments(self, probs: List[float]):
        """Return every vad segment of ``probs``, including the trailing one."""
        vad_segments = list()
        vad_segments += self.vad(probs)
        vad_segments += self.last_vad_segments()
        return vad_segments

    def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]):
        """Return a 0/1 list of ``len(probs)`` items, 1 inside ``[begin, end)``.

        Segment bounds are converted to ``int`` and clamped to the valid
        index range so the result always has exactly ``len(probs)`` items.
        """
        total = len(probs)
        result = [0] * total
        for begin, end in vad_segments:
            lo = max(0, int(begin))
            hi = min(total, int(end))
            if lo < hi:
                result[lo:hi] = [1] * (hi - lo)
        return result

    def post_process(self, probs: List[float]):
        """Return the per-frame 0/1 speech flags for ``probs``."""
        vad_segments = self.get_vad_segments(probs)
        return self.get_vad_flags(probs, vad_segments)
# Script entry point: intentionally a no-op, running this module directly does nothing.
if __name__ == "__main__":
    pass