File size: 4,890 Bytes
00e4381
 
 
 
 
 
 
 
 
 
 
bebc2b8
 
 
 
00e4381
 
 
bebc2b8
00e4381
 
 
 
 
bebc2b8
00e4381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bebc2b8
00e4381
 
 
 
 
 
bebc2b8
 
 
00e4381
 
 
 
 
 
bebc2b8
 
 
 
 
00e4381
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import collections

from typing import List, Tuple


class PostProcess(object):
    """Turn per-frame speech probabilities into VAD segments and flags.

    Processing happens in two stages:

    1. ``segments_generator`` runs a ring-buffer hysteresis over the
       probabilities: a segment opens when the summed probability in the
       ring buffer exceeds ``start_ring_rate * ring_max_length`` and closes
       when it falls below ``end_ring_rate * ring_max_length``.
    2. ``vad_segments_generator`` merges segments separated by fewer than
       ``min_silence_length`` frames, drops merged segments shorter than
       ``min_speech_length`` and splits ones longer than
       ``max_speech_length``.

    Segments are ``[start, end]`` lists of frame indices into ``probs``.
    The instance keeps streaming state between method calls;
    ``get_vad_segments`` / ``post_process`` reset it first, so each of
    those calls processes one independent sequence of probabilities.
    """

    def __init__(self,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 ring_max_length: int = 10,
                 min_silence_length: int = 6,
                 max_speech_length: float = 100000,
                 min_speech_length: float = 15,
                 ):
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.ring_max_length = ring_max_length
        self.max_speech_length = max_speech_length
        self.min_speech_length = min_speech_length
        self.min_silence_length = min_silence_length

        self.reset()

    def reset(self) -> None:
        """Clear all streaming state so the instance can process new input."""
        # segments: sliding window of (frame_idx, prob) and the open/closed flag
        self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
        self.triggered = False

        # vad segments: pending merged segment; -1 means "none pending"
        self.is_first_segment = True
        self.start_idx: int = -1
        self.end_idx: int = -1

        # speech probs: frames of the currently open raw segment
        self.voiced_frames: List[Tuple[int, float]] = list()

    def segments_generator(self, probs: List[float]):
        """Yield raw ``[start, end]`` speech segments found in ``probs``.

        Indices are positions within this ``probs`` list. ``end`` is the
        index of the last frame added to the segment, i.e. the frame at which
        the ring-buffer sum dropped below the end threshold. A segment that
        is still open when ``probs`` runs out stays in ``self.voiced_frames``
        for ``last_vad_segments`` to flush.
        """
        for idx, prob in enumerate(probs):
            idx_prob_t = (idx, prob)
            self.ring_buffer.append(idx_prob_t)
            num_voiced = sum(p for _, p in self.ring_buffer)

            if not self.triggered:
                # Open a segment once enough of the window looks like speech;
                # the buffered frames become the segment's leading context.
                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
                    self.triggered = True
                    self.voiced_frames.extend(self.ring_buffer)
                continue

            self.voiced_frames.append(idx_prob_t)
            if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
                yield [self.voiced_frames[0][0], self.voiced_frames[-1][0]]
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = list()

    def _split_long_segment(self):
        """Yield chunks off the pending segment until it fits ``max_speech_length``.

        Each yielded chunk is ``[start_idx, start_idx + step]`` and
        ``self.start_idx`` advances to the chunk's end, so consecutive chunks
        share a boundary index.
        """
        while self.end_idx - self.start_idx > self.max_speech_length:
            # Integer step keeps indices usable as slice bounds; the floor of
            # 1 guarantees progress even for max_speech_length < 1.
            end_ = self.start_idx + max(1, int(self.max_speech_length))
            yield [self.start_idx, end_]
            self.start_idx = end_

    def vad_segments_generator(self, segments_generator):
        """Merge, filter and split raw segments into final VAD segments.

        Consumes ``[start, end]`` pairs and yields completed segments. The most
        recent segment is kept pending in ``self.start_idx``/``self.end_idx``
        because a later segment may still be merged into it.
        """
        for segment in segments_generator:
            start, end = segment[0], segment[1]

            if self.start_idx == -1 and self.end_idx == -1:
                self.start_idx = start
                self.end_idx = end
                continue

            yield from self._split_long_segment()

            # Short gap: extend the pending segment instead of closing it.
            silence_length = start - self.end_idx
            if silence_length < self.min_silence_length:
                self.end_idx = end
                continue

            # Emit the pending segment only if it is long enough; either way
            # the new segment becomes the pending one.
            if self.end_idx - self.start_idx >= self.min_speech_length:
                yield [self.start_idx, self.end_idx]
            self.start_idx = start
            self.end_idx = end

    def vad(self, probs: List[float]) -> List[list]:
        """Return the VAD segments completed while processing ``probs``.

        Segments still pending at the end are not included; see
        ``last_vad_segments``.
        """
        return list(self.vad_segments_generator(self.segments_generator(probs)))

    def last_vad_segments(self) -> List[list]:
        """Flush the open raw segment and the pending merged segment."""
        if self.voiced_frames:
            segments = [[self.voiced_frames[0][0], self.voiced_frames[-1][0]]]
        else:
            segments = []

        vad_segments = list(self.vad_segments_generator(segments))

        # -1 marks "no pending segment". 0 is a valid frame index, so test
        # >= 0 (a "> 0"-style check dropped segments starting at frame 0).
        if self.start_idx >= 0 and self.end_idx >= 0:
            vad_segments.extend(self._split_long_segment())
            # NOTE: the trailing segment is not subject to min_speech_length.
            vad_segments.append([self.start_idx, self.end_idx])
        return vad_segments

    def get_vad_segments(self, probs: List[float]) -> List[list]:
        """Return all VAD segments for ``probs``, treated as one full input.

        State from any earlier call is cleared first so results do not leak
        between independent inputs.
        """
        self.reset()
        vad_segments = self.vad(probs)
        vad_segments += self.last_vad_segments()
        return vad_segments

    def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]):
        """Return a 0/1 list, one entry per frame of ``probs``.

        Frames covered by ``probs[begin:end]`` for each segment are set to 1.
        NOTE(review): ``end`` is used as an exclusive bound, while
        ``segments_generator`` reports the index of a segment's last frame,
        so that frame is left at 0 — confirm this is intended.
        """
        result = [0] * len(probs)
        for begin, end in vad_segments:
            # Size the replacement to the actual slice so out-of-range or
            # negative indices can never change the list's length.
            span = len(result[begin:end])
            result[begin:end] = [1] * span
        return result

    def post_process(self, probs: List[float]):
        """Return per-frame 0/1 speech flags for ``probs``."""
        vad_segments = self.get_vad_segments(probs)
        return self.get_vad_flags(probs, vad_segments)


if __name__ == "__main__":
    # The module only defines PostProcess; running it directly does nothing.
    ...