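"""Real-time speaker-colored captions from the microphone.

diart performs online speaker diarization on a sliding window of audio while
whisper_timestamped transcribes it; the two outputs are merged so that each
caption is printed with a per-speaker color using rich.
"""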
import logging
import os
import sys
import traceback
from contextlib import contextmanager
import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
import whisper_timestamped as whisper
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import MicrophoneAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
def concat(chunks, collar=0.05):
"""
Concatenate predictions and audio
given a list of `(diarization, waveform)` pairs
and merge contiguous single-speaker regions
with pauses shorter than `collar` seconds.
"""
first_annotation = chunks[0][0]
first_waveform = chunks[0][1]
annotation = Annotation(uri=first_annotation.uri)
data = []
for ann, wav in chunks:
annotation.update(ann)
data.append(wav.data)
annotation = annotation.support(collar)
window = SlidingWindow(
first_waveform.sliding_window.duration,
first_waveform.sliding_window.step,
first_waveform.sliding_window.start,
)
data = np.concatenate(data, axis=0)
return annotation, SlidingWindowFeature(data, window)
def colorize_transcription(transcription):
    """Color each caption with a rich color tag according to its speaker id."""
    # The palette is doubled so that up to 20 speaker ids can be colored
    colors = 2 * [
        "bright_red",
        "bright_blue",
        "bright_green",
        "orange3",
        "deep_pink1",
        "yellow2",
        "magenta",
        "cyan",
        "bright_magenta",
        "dodger_blue2",
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use the default terminal color
            result.append(text)
        else:
            result.append(f"[{colors[speaker]}]{text}")
    return "\n".join(result)
# @contextmanager
# def suppress_stdout():
#     with open(os.devnull, "w") as devnull:
#         old_stdout = sys.stdout
#         sys.stdout = devnull
#         try:
#             yield
#         finally:
#             sys.stdout = old_stdout
class WhisperTranscriber:
    def __init__(self, model="small", device=None):
        self.model = whisper.load_model(model, device=device)
        self._buffer = ""
    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)
        # Transcribe the given audio while suppressing logs
        transcription = whisper.transcribe(
            self.model,
            audio,
            # We use past transcriptions to condition the model
            initial_prompt=self._buffer,
            verbose=True,  # to avoid progress bar
        )
        return transcription
    def identify_speakers(self, transcription, diarization, time_shift):
        """Iterate over transcription segments to assign speakers"""
        speaker_captions = []
        for segment in transcription["segments"]:
            # Crop diarization to the segment timestamps
            start = time_shift + segment["words"][0]["start"]
            end = time_shift + segment["words"][-1]["end"]
            dia = diarization.crop(Segment(start, end))
            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment["text"])
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment["text"])
            else:
                # Multiple speakers: keep the label that speaks the most
                # and parse its numeric id, as in the single-speaker case
                max_label = speakers[
                    int(np.argmax([dia.label_duration(spk) for spk in speakers]))
                ]
                spk_id = int(max_label.split("speaker")[1])
                caption = (spk_id, segment["text"])
            speaker_captions.append(caption)
        return speaker_captions
    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Update transcription buffer
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions = self.identify_speakers(
            transcription, diarization, time_shift
        )
        return speaker_transcriptions
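# Suppress whisper_timestamped warnings so they don't interleave with the live captions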
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
config = PipelineConfig(
    duration=5, step=0.5, latency="min", tau_active=0.5, rho_update=0.1, delta_new=0.57
)
dia = OnlineSpeakerDiarization(config)
source = MicrophoneAudioSource(config.sample_rate)
asr = WhisperTranscriber(model="small")
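# Whisper runs on 2s batches: with a 500ms diarization step, 4 chunks are buffered per batch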
transcription_duration = 2
batch_size = int(transcription_duration // config.step)
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(config.duration, config.step, config.sample_rate),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color the captions according to their speaker id
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda _: traceback.print_exc(),  # print stacktrace if error
)
print("Listening...")
source.read()
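# read() blocks while audio is streamed from the microphone; interrupt with Ctrl+C to stop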