File size: 4,890 Bytes
00e4381 bebc2b8 00e4381 bebc2b8 00e4381 bebc2b8 00e4381 bebc2b8 00e4381 bebc2b8 00e4381 bebc2b8 00e4381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import collections
from typing import List, Tuple
class PostProcess(object):
    """Turn per-frame speech probabilities into speech segments and 0/1 flags.

    Processing happens in two stages:

    1. ``segments_generator`` slides a fixed-size window (a ring buffer) over
       the probabilities.  A raw segment opens when the sum of the
       probabilities in the window exceeds
       ``start_ring_rate * ring_max_length`` and closes when that sum drops
       below ``end_ring_rate * ring_max_length``.
    2. ``vad_segments_generator`` merges raw segments separated by a short
       gap, discards short ones and splits over-long ones.

    All positions are frame indices into the ``probs`` list.

    Args:
        start_ring_rate: Fraction of the window size that the summed
            probabilities must exceed to open a raw segment.
        end_ring_rate: Fraction of the window size that the summed
            probabilities must fall below to close a raw segment.
        ring_max_length: Window (ring buffer) size in frames.
        min_silence_length: Raw segments whose gap to the pending segment is
            smaller than this are merged into it.
        max_speech_length: When the pending segment is longer than this at
            the moment the next raw segment arrives, its leading
            ``max_speech_length`` frames are emitted as their own segment
            (at most one such split per incoming raw segment).
        min_speech_length: A pending segment shorter than this is discarded
            when the next raw segment does not get merged into it.
    """

    def __init__(self,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 ring_max_length: int = 10,
                 min_silence_length: int = 6,
                 max_speech_length: float = 100000,
                 min_speech_length: float = 15,
                 ):
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.ring_max_length = ring_max_length
        self.max_speech_length = max_speech_length
        self.min_speech_length = min_speech_length
        self.min_silence_length = min_silence_length

        # Raw-segment detection state: sliding window of (index, prob) pairs
        # and whether a raw segment is currently open.
        self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
        self.triggered = False

        # Merge-stage state: the pending segment that has not been emitted
        # yet; -1 means "no pending segment".
        # ``is_first_segment`` is not read anywhere in this class; it is kept
        # so that external code accessing the attribute keeps working.
        self.is_first_segment = True
        self.start_idx: int = -1
        self.end_idx: int = -1

        # (index, prob) pairs collected for the raw segment currently open.
        self.voiced_frames: List[Tuple[int, float]] = list()

    def reset(self) -> None:
        """Restore the detection and merge state to its initial values.

        The configuration passed to ``__init__`` is left untouched.
        """
        self.ring_buffer.clear()
        self.triggered = False
        self.is_first_segment = True
        self.start_idx = -1
        self.end_idx = -1
        self.voiced_frames = []

    def segments_generator(self, probs: List[float]):
        """Yield raw ``[first_idx, last_idx]`` segments found in ``probs``.

        ``first_idx`` and ``last_idx`` are the indices of the first and last
        frame collected for the segment.  A segment that is still open when
        ``probs`` is exhausted is not yielded; its frames stay in
        ``self.voiced_frames`` until ``last_vad_segments`` flushes them.

        Args:
            probs: Per-frame speech probabilities.

        Yields:
            Two-element lists ``[first_idx, last_idx]``.
        """
        # Both thresholds are loop-invariant, so compute them once.
        window_size = self.ring_buffer.maxlen
        start_threshold = self.start_ring_rate * window_size
        end_threshold = self.end_ring_rate * window_size

        for idx, prob in enumerate(probs):
            frame = (idx, prob)
            self.ring_buffer.append(frame)
            window_score = sum(p for _, p in self.ring_buffer)

            if not self.triggered:
                if window_score > start_threshold:
                    self.triggered = True
                    # The frames already in the window belong to the segment.
                    # The window is deliberately not cleared here, so these
                    # frames keep contributing to the end-of-segment test.
                    self.voiced_frames.extend(self.ring_buffer)
                continue

            self.voiced_frames.append(frame)
            if window_score < end_threshold:
                yield [self.voiced_frames[0][0], self.voiced_frames[-1][0]]
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = []

    def vad_segments_generator(self, segments_generator):
        """Merge, filter and split raw segments into final VAD segments.

        One pending segment is kept in ``self.start_idx`` / ``self.end_idx``.
        Each incoming raw segment is either merged into it (gap smaller than
        ``min_silence_length``), replaces it (pending segment shorter than
        ``min_speech_length``), or causes it to be emitted.  The last pending
        segment is never emitted here; ``last_vad_segments`` does that.

        Args:
            segments_generator: Iterable of ``[start, end]`` raw segments.

        Yields:
            Two-element lists ``[start, end]``.
        """
        for segment in segments_generator:
            start = segment[0]
            end = segment[1]

            # Nothing pending yet: this raw segment becomes the pending one.
            if self.start_idx == -1 and self.end_idx == -1:
                self.start_idx = start
                self.end_idx = end
                continue

            # Pending segment is too long: emit its leading part and keep the
            # remainder pending.
            if self.end_idx - self.start_idx > self.max_speech_length:
                split_end = self.start_idx + self.max_speech_length
                yield [self.start_idx, split_end]
                self.start_idx = split_end

            # Short gap: extend the pending segment over the new one.
            if start - self.end_idx < self.min_silence_length:
                self.end_idx = end
                continue

            # Pending segment is too short to keep: drop it silently.
            if self.end_idx - self.start_idx < self.min_speech_length:
                self.start_idx = start
                self.end_idx = end
                continue

            yield [self.start_idx, self.end_idx]
            self.start_idx = start
            self.end_idx = end

    def vad(self, probs: List[float]) -> List[list]:
        """Return the VAD segments that are complete after consuming ``probs``.

        The pending segment and any raw segment still open are kept in the
        instance state; call ``last_vad_segments`` to obtain them.
        """
        raw_segments = self.segments_generator(probs)
        return list(self.vad_segments_generator(raw_segments))

    def last_vad_segments(self) -> List[list]:
        """Flush and return the segments still held in the instance state.

        The raw segment that is still open (if any) is pushed through the
        merge stage, then the pending segment is emitted as-is (no minimum or
        maximum length check is applied to it).  Afterwards the state is
        reset, so the instance can process another input and a second call
        does not return the same segment again.
        """
        trailing = []
        if self.voiced_frames:
            trailing.append([
                self.voiced_frames[0][0],
                self.voiced_frames[-1][0],
            ])

        vad_segments = list(self.vad_segments_generator(trailing))

        # -1 is the "nothing pending" sentinel.  Comparing against it (rather
        # than requiring a strictly positive value) keeps a pending segment
        # that starts at frame 0.
        if self.start_idx != -1 and self.end_idx != -1:
            vad_segments.append([self.start_idx, self.end_idx])

        self.reset()
        return vad_segments

    def get_vad_segments(self, probs: List[float]):
        """Return every VAD segment of ``probs``, including the flushed tail."""
        vad_segments = self.vad(probs)
        vad_segments += self.last_vad_segments()
        return vad_segments

    def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]):
        """Return a 0/1 flag per frame of ``probs``.

        Frames ``begin .. end - 1`` of every segment are set to 1 (``end`` is
        treated as exclusive).  Segment bounds are converted to ``int`` and
        clamped to ``[0, len(probs)]`` so the result always has exactly
        ``len(probs)`` entries.

        Args:
            probs: Per-frame speech probabilities; only its length is used.
            vad_segments: ``(begin, end)`` pairs of frame indices.

        Returns:
            A list of ``len(probs)`` integers, each 0 or 1.
        """
        total = len(probs)
        flags = [0] * total
        for begin, end in vad_segments:
            lo = max(int(begin), 0)
            hi = min(int(end), total)
            if hi > lo:
                flags[lo:hi] = [1] * (hi - lo)
        return flags

    def post_process(self, probs: List[float]):
        """Return the 0/1 speech flag of every frame in ``probs``."""
        vad_segments = self.get_vad_segments(probs)
        return self.get_vad_flags(probs, vad_segments)
if __name__ == "__main__":
    # No action is taken when the module is run as a script.
    pass
|