import json
import logging
import threading
import time

import librosa
import numpy as np
import soundfile
from pywhispercpp.model import Model

import config

logging.basicConfig(level=logging.INFO)
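
# Note: `config` is a project-local settings module; this file reads config.WHISPER_MODEL
# and config.MODEL_DIR in ServeClientWhisperCPP.create_model() below.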


class ServeClientBase(object):
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, client_uid, websocket):
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_count = 0
        self.show_prev_out_thresh = 5  # seconds to keep showing the previous output when whisper returns nothing new
        self.add_pause_thresh = 3  # seconds of silence after which a blank segment is appended as a pause
        self.transcript = []
        self.send_last_n_segments = 10
        self.pick_previous_segments = 2
        self.lock = threading.Lock()
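
    # Buffer bookkeeping used by add_frames() and the chunking helpers below:
    #   frames_np         rolling mono audio buffer sampled at RATE (16 kHz)
    #   frames_offset     seconds already discarded from the front of frames_np
    #   timestamp_offset  seconds of the stream that have been transcribed so far; the next
    #                     chunk handed to whisper starts at this offset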

    def speech_to_text(self):
        raise NotImplementedError

    def transcribe_audio(self):
        raise NotImplementedError

    def handle_transcription_output(self):
        raise NotImplementedError

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
        of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
        to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
        of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the
        provided audio frame. The audio stream buffer is used for real-time processing of audio data for
        transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        with self.lock:
            if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30 * self.RATE):]
                # If the transcription pointer fell behind the trimmed region (no speech was
                # transcribed there), snap it forward to the new start of the buffer.
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
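
    # The buffer is expected to hold 16 kHz mono samples (RATE above). A websocket handler
    # feeding this class would typically convert each incoming audio message to a NumPy
    # array first, e.g. np.frombuffer(message, dtype=np.float32); the exact dtype depends
    # on what the client sends, so treat that call as illustrative.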

    def clip_audio_if_no_valid_segment(self):
        """
        Update the timestamp offset based on audio buffer status.

        If the unprocessed audio grows beyond 25 seconds, whisper has produced no valid
        segment for that span, so skip ahead and keep only the last 5 seconds for processing.
        """
        with self.lock:
            if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Retrieves the next chunk of audio data for processing based on the current offsets.

        Calculates which part of the audio data should be processed next, based on
        the difference between the current timestamp offset and the frame's offset, scaled by
        the audio sample rate (RATE). It then returns this chunk of audio data along with its
        duration in seconds.

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
            duration = input_bytes.shape[0] / self.RATE
            return input_bytes, duration
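
    # Worked example of the offset arithmetic: with frames_offset = 30.0 s and
    # timestamp_offset = 42.0 s, samples_take = (42.0 - 30.0) * 16000 = 192000, so the
    # returned chunk starts 12 s into the current buffer and runs to its end.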

    def prepare_segments(self, last_segment=None):
        """
        Prepares the segments of transcribed text to be sent to the client.

        This method compiles the recent segments of transcribed text, ensuring that only the
        specified number of the most recent segments are included. It also appends the most
        recent segment of text if provided (which is considered incomplete because of the possibility
        of the last word being truncated in the audio chunk).

        Args:
            last_segment (str, optional): The most recent segment of transcribed text to be added
                to the list of segments. Defaults to None.

        Returns:
            list: A list of transcribed text segments to be sent to the client.
        """
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        else:
            segments = self.transcript.copy()
        if last_segment is not None:
            segments = segments + [last_segment]
        logging.info(f"{segments}")
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculates the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Sends the specified transcription segments to the client over the websocket connection.

        This method formats the transcription segments into a JSON object and attempts to send
        this object to the client. If an error occurs during the send operation, it logs the error.

        Args:
            segments (list): A list of transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")

    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to notify them
        that the transcription service is disconnecting gracefully.
        """
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method sets the exit flag so that the transcription thread can notice it and shut
        down gracefully, releasing the resources associated with this client.
        """
        logging.info("Cleaning up.")
        self.exit = True


class ServeClientWhisperCPP(ServeClientBase):
    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()
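
    # SINGLE_MODEL holds one shared pywhispercpp Model instance when clients are created
    # with single_model=True; SINGLE_MODEL_LOCK serializes transcription calls against it
    # (see transcribe_audio below).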

    def __init__(self, websocket, language=None, client_uid=None,
                 single_model=False):
        """
        Initialize a ServeClient instance.

        A whisper.cpp model is created (or a shared single-model instance is reused), the
        transcription thread is started, and a "SERVER_READY" message is sent to the client
        to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            single_model (bool, optional): Whether to share a single model instance across all
                client connections instead of creating one per client. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language
        self.eos = False

        if single_model:
            if ServeClientWhisperCPP.SINGLE_MODEL is None:
                self.create_model()
                ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
        else:
            self.create_model()

        logging.info('Create a thread to process audio.')
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "pywhispercpp"
        }))

    def create_model(self, warmup=True):
        """
        Instantiate a new model, set it as the transcriber, and optionally warm it up.
        """
        self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=1):
        """
        Warm up the whisper.cpp engine, since the first few inferences are slow.

        Args:
            warmup_steps (int): Number of steps to warm up the model for.
        """
        logging.info("Warming up whisper.cpp engine..")
        audio, _ = soundfile.read("assets/jfk.flac")
        for _ in range(warmup_steps):
            self.transcriber.transcribe(audio, print_progress=False)

    def set_eos(self, eos):
        """
        Sets the End of Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            last_segment (str): The last segment from the whisper output, which is considered
                incomplete because the last word may be truncated in the audio chunk.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """
        Transcribe the audio chunk and send the results to the client.

        Args:
            input_bytes (numpy.ndarray): The audio chunk to transcribe.
        """
        if ServeClientWhisperCPP.SINGLE_MODEL:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
        logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
        audio = input_bytes
        duration = librosa.get_duration(y=input_bytes, sr=self.RATE)

        if self.language == "zh":
            # "The following is a sentence in simplified Mandarin Chinese."
            prompt = '以下是简体中文普通话的句子。'
        else:
            prompt = 'The following is an English sentence.'

        segments = self.transcriber.transcribe(
            audio,
            language=self.language,
            initial_prompt=prompt,
            token_timestamps=True,
            print_progress=False
        )
        text = []
        for segment in segments:
            text.append(segment.text)
        last_segment = ' '.join(text)

        logging.info(f"[pywhispercpp:] Last segment: {last_segment}")

        if ServeClientWhisperCPP.SINGLE_MODEL:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """
        Update timestamp offset and transcript.

        Args:
            last_segment (str): Last transcribed text from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        if not self.transcript:
            self.transcript.append({"text": last_segment + " "})
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        logging.info(f'Transcript list context: {self.transcript}')

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously takes audio from the shared buffer, performs real-time
        transcription with the whisper.cpp model, and sends transcribed segments to the client
        via the WebSocket connection. A history of recent segments is maintained to provide context.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)  # wait for audio to arrive
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1:
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)

            except Exception as e:
                logging.error(f"[ERROR]: {e}")