Spaces:
Build error
Build error
File size: 5,949 Bytes
f7e8228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import logging
import os
import sys
import traceback
from contextlib import contextmanager
import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
import whisper_timestamped as whisper
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import MicrophoneAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
def concat(chunks, collar=0.05):
    """
    Fuse a batch of `(diarization, waveform)` pairs into one pair.

    The diarizations are merged into a single annotation in which
    contiguous single-speaker regions separated by pauses shorter
    than `collar` seconds are joined. The waveform data is stacked
    along the first axis, and the result keeps the sliding window
    (duration, step, start) of the first waveform in the batch.
    """
    head_prediction, head_audio = chunks[0]

    # Accumulate every prediction under the URI of the first one
    merged = Annotation(uri=head_prediction.uri)
    for prediction, _ in chunks:
        merged.update(prediction)
    merged = merged.support(collar)

    # Stack the audio samples in batch order
    samples = np.concatenate([audio.data for _, audio in chunks], axis=0)

    reference = head_audio.sliding_window
    window = SlidingWindow(reference.duration, reference.step, reference.start)
    return merged, SlidingWindowFeature(samples, window)
def colorize_transcription(transcription):
    """Format speaker-labelled captions as color-tagged lines.

    Args:
        transcription: iterable of `(speaker, text)` pairs, where
            `speaker` is an integer id, or -1 when no speaker was
            found for the text.

    Returns:
        A single string with one caption per line. Captions that have
        a speaker are prefixed with a `[color]` tag; captions without
        one are left untagged so they use the default terminal color.
    """
    colors = (
        "bright_red",
        "bright_blue",
        "bright_green",
        "orange3",
        "deep_pink1",
        "yellow2",
        "magenta",
        "cyan",
        "bright_magenta",
        "dodger_blue2",
    )
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use default terminal color
            result.append(text)
        else:
            # Wrap around the palette so any speaker id maps to a color
            # instead of raising IndexError once ids exceed the palette
            color = colors[speaker % len(colors)]
            result.append(f"[{color}]{text}")
    return "\n".join(result)
# @contextmanager
# def suppress_stdout():
# with open(os.devnull, "w") as devnull:
# old_stdout = sys.stdout
# sys.stdout = devnull
# try:
# yield
# finally:
# sys.stdout = old_stdout
class WhisperTranscriber:
    """Speaker-aware transcriber built on a `whisper_timestamped` model.

    Calling an instance with `(diarization, waveform)` transcribes the
    waveform and returns a list of `(speaker_id, text)` pairs, one per
    transcribed segment.
    """

    def __init__(self, model="small", device=None):
        # `model` and `device` are passed straight to `whisper.load_model`
        self.model = whisper.load_model(model, device=device)
        # Text of all previous transcriptions, used as prompt for the next one
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Flatten the samples into a 1-D float32 array
        audio = waveform.data.astype("float32").reshape(-1)
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = whisper.pad_or_trim(audio)
        # NOTE: nothing here redirects stdout, so whatever the library
        # prints with `verbose=True` reaches the console.
        return whisper.transcribe(
            self.model,
            audio,
            # We use past transcriptions to condition the model
            initial_prompt=self._buffer,
            verbose=True,  # to avoid progress bar
        )

    def identify_speakers(self, transcription, diarization, time_shift):
        """Iterate over transcription segments to assign speakers

        Args:
            transcription: dict with a "segments" list; each segment has
                a "text" and word-level timestamps under "words".
            diarization: annotation cropped to each segment to find the
                speakers active in it.
            time_shift: offset in seconds added to the segment
                timestamps before cropping the diarization.

        Returns:
            A list of `(speaker_id, text)` pairs, one per segment, where
            `speaker_id` is -1 if no speaker is active in the segment.
            Speaker ids are parsed from labels assumed to look like
            "speaker<N>".
        """
        speaker_captions = []
        for segment in transcription["segments"]:
            words = segment.get("words")
            if words:
                start, end = words[0]["start"], words[-1]["end"]
            else:
                # No word-level timestamps: fall back to segment boundaries
                start, end = segment["start"], segment["end"]
            # Crop diarization to the segment timestamps
            dia = diarization.crop(Segment(time_shift + start, time_shift + end))

            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            if not speakers:
                # No speakers were detected
                spk_id = -1
            else:
                # Pick the label that speaks the most (the only one if there
                # is a single speaker) and parse the id from that label, so
                # the id never depends on the label's position in `speakers`.
                dominant = max(speakers, key=dia.label_duration)
                spk_id = int(dominant.split("speaker")[1])
            speaker_captions.append((spk_id, segment["text"]))
        return speaker_captions

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Update transcription buffer
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        return self.identify_speakers(transcription, diarization, time_shift)
# Only show whisper_timestamped log records of level ERROR and above
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

# Diarization runs on 5s windows with a step of 500ms
config = PipelineConfig(
    duration=5, step=0.5, latency="min", tau_active=0.5, rho_update=0.1, delta_new=0.57
)
dia = OnlineSpeakerDiarization(config)
source = MicrophoneAudioSource(config.sample_rate)

asr = WhisperTranscriber(model="small")

# Number of consecutive steps grouped into one transcription batch
transcription_duration = 2
batch_size = int(transcription_duration // config.step)

source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(config.duration, config.step, config.sample_rate),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    # Print the stacktrace of the error that was actually delivered,
    # rather than whatever exception happens to be in flight (if any)
    on_error=lambda err: traceback.print_exception(
        type(err), err, err.__traceback__
    ),
)

print("Listening...")
source.read()
|