|
from io import IOBase |
|
from pathlib import Path |
|
from typing import Mapping, Text |
|
from typing import Optional |
|
from typing import Union |
|
|
|
import torch |
|
|
|
from ..diarize import Segment as SegmentX |
|
from .vad import Vad |
|
|
|
# Accepted audio inputs: a path (str or Path), an open binary stream, or a
# pre-loaded mapping (e.g. {"waveform": tensor, "sample_rate": int}).
AudioFile = Union[Text, Path, IOBase, Mapping]
|
|
|
|
|
class Silero(Vad):
    """Voice-activity detection backed by the Silero VAD model.

    Loads the model from torch.hub on construction and returns speech
    segments (as ``SegmentX`` instances) when called on audio.
    """

    def __init__(self, **kwargs):
        """Load the Silero VAD model from torch.hub.

        Expected kwargs:
            vad_onset (float): speech-probability threshold passed to the
                detector (and to the ``Vad`` base class).
            chunk_size: maximum duration of a single speech chunk, in seconds.
        """
        print(">>Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']
        # force_reload=False reuses a cached checkout of the repo when present.
        self.vad_pipeline, vad_utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False,
            trust_repo=True,
        )
        # vad_utils is a 5-tuple of helpers; only the timestamp extractor and
        # the audio reader are kept here.
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def __call__(self, audio: AudioFile, **kwargs):
        """Use silero to get segments of speech.

        Args:
            audio: mapping with a "waveform" entry and a "sample_rate" int.
                Only 16 kHz audio is supported.

        Returns:
            List of ``SegmentX`` with start/end times in seconds and an
            "UNKNOWN" speaker label.

        Raises:
            ValueError: if the sample rate is not 16000 Hz.
        """
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(
            audio["waveform"],
            model=self.vad_pipeline,
            sampling_rate=sample_rate,
            max_speech_duration_s=self.chunk_size,
            threshold=self.vad_onset,
        )
        # Timestamps come back as sample indices; convert to seconds.
        # Speaker identity is not known at this stage.
        return [
            SegmentX(ts['start'] / sample_rate, ts['end'] / sample_rate, "UNKNOWN")
            for ts in timestamps
        ]

    @staticmethod
    def preprocess_audio(audio):
        """Silero requires no preprocessing; return the audio unchanged."""
        return audio

    @staticmethod
    def merge_chunks(segments_list,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        """Merge VAD segments into chunks of at most ``chunk_size`` seconds.

        Delegates to ``Vad.merge_chunks``. Returns an empty list (after
        printing a notice) when no speech segments were detected.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under -O
        # and must not be used for input validation.
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not segments_list:
            print("No active speech found in audio")
            return []
        # NOTE(review): the original also asserted non-emptiness here, which
        # was unreachable after the early return above; removed as dead code.
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
|