daihui.zhang
commited on
Commit
·
98c9c23
1
Parent(s):
9775978
remove unused codes
Browse files- config.py +3 -2
- main.py +1 -5
- transcribe/client.py +0 -677
- transcribe/helpers/vadprocessor.py +1 -1
- transcribe/pipelines/pipe_vad.py +8 -15
- transcribe/server.py +0 -382
- transcribe/strategy.py +0 -405
- transcribe/transcription.py +0 -334
- transcribe/translatepipes.py +0 -1
- transcribe/whisper_llm_serve.py +21 -107
config.py
CHANGED
|
@@ -3,10 +3,11 @@ import re
|
|
| 3 |
import logging
|
| 4 |
|
| 5 |
DEBUG = True
|
|
|
|
| 6 |
|
| 7 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
| 8 |
logging.basicConfig(
|
| 9 |
-
level=
|
| 10 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 11 |
filename='translator.log',
|
| 12 |
datefmt="%H:%M:%S"
|
|
@@ -15,7 +16,7 @@ logging.basicConfig(
|
|
| 15 |
SAVE_DATA_SAVE = False
|
| 16 |
# Add terminal log
|
| 17 |
console_handler = logging.StreamHandler()
|
| 18 |
-
console_handler.setLevel(
|
| 19 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 20 |
console_handler.setFormatter(console_formatter)
|
| 21 |
logging.getLogger().addHandler(console_handler)
|
|
|
|
| 3 |
import logging
|
| 4 |
|
| 5 |
DEBUG = True
|
| 6 |
+
LOG_LEVEL = logging.WARNING if DEBUG else logging.INFO
|
| 7 |
|
| 8 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
| 9 |
logging.basicConfig(
|
| 10 |
+
level=LOG_LEVEL,
|
| 11 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 12 |
filename='translator.log',
|
| 13 |
datefmt="%H:%M:%S"
|
|
|
|
| 16 |
SAVE_DATA_SAVE = False
|
| 17 |
# Add terminal log
|
| 18 |
console_handler = logging.StreamHandler()
|
| 19 |
+
console_handler.setLevel(LOG_LEVEL)
|
| 20 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 21 |
console_handler.setFormatter(console_formatter)
|
| 22 |
logging.getLogger().addHandler(console_handler)
|
main.py
CHANGED
|
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
|
|
| 11 |
from fastapi.responses import RedirectResponse
|
| 12 |
import os
|
| 13 |
from transcribe.utils import pcm_bytes_to_np_array
|
|
|
|
| 14 |
logger = getLogger(__name__)
|
| 15 |
|
| 16 |
|
|
@@ -39,9 +40,6 @@ async def lifespan(app:FastAPI):
|
|
| 39 |
yield
|
| 40 |
|
| 41 |
|
| 42 |
-
# 获取当前文件所在目录的绝对路径
|
| 43 |
-
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 44 |
-
# 构建frontend目录的绝对路径
|
| 45 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
| 46 |
|
| 47 |
|
|
@@ -66,9 +64,7 @@ async def translate(websocket: WebSocket):
|
|
| 66 |
client_uid=f"{uuid1()}",
|
| 67 |
)
|
| 68 |
|
| 69 |
-
|
| 70 |
if from_lang and to_lang and client:
|
| 71 |
-
client.set_language(from_lang, to_lang)
|
| 72 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 73 |
await websocket.accept()
|
| 74 |
try:
|
|
|
|
| 11 |
from fastapi.responses import RedirectResponse
|
| 12 |
import os
|
| 13 |
from transcribe.utils import pcm_bytes_to_np_array
|
| 14 |
+
from config import BASE_DIR
|
| 15 |
logger = getLogger(__name__)
|
| 16 |
|
| 17 |
|
|
|
|
| 40 |
yield
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
| 44 |
|
| 45 |
|
|
|
|
| 64 |
client_uid=f"{uuid1()}",
|
| 65 |
)
|
| 66 |
|
|
|
|
| 67 |
if from_lang and to_lang and client:
|
|
|
|
| 68 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 69 |
await websocket.accept()
|
| 70 |
try:
|
transcribe/client.py
DELETED
|
@@ -1,677 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import threading
|
| 5 |
-
import time
|
| 6 |
-
import uuid
|
| 7 |
-
import wave
|
| 8 |
-
|
| 9 |
-
import av
|
| 10 |
-
import numpy as np
|
| 11 |
-
import pyaudio
|
| 12 |
-
import websocket
|
| 13 |
-
|
| 14 |
-
import transcribe.utils as utils
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class Client:
|
| 18 |
-
"""
|
| 19 |
-
Handles communication with a server using WebSocket.
|
| 20 |
-
"""
|
| 21 |
-
INSTANCES = {}
|
| 22 |
-
END_OF_AUDIO = "END_OF_AUDIO"
|
| 23 |
-
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
host=None,
|
| 27 |
-
port=None,
|
| 28 |
-
lang=None,
|
| 29 |
-
log_transcription=True,
|
| 30 |
-
max_clients=4,
|
| 31 |
-
max_connection_time=600,
|
| 32 |
-
dst_lang='zh',
|
| 33 |
-
):
|
| 34 |
-
"""
|
| 35 |
-
Initializes a Client instance for audio recording and streaming to a server.
|
| 36 |
-
|
| 37 |
-
If host and port are not provided, the WebSocket connection will not be established.
|
| 38 |
-
the audio recording starts immediately upon initialization.
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
host (str): The hostname or IP address of the server.
|
| 42 |
-
port (int): The port number for the WebSocket server.
|
| 43 |
-
lang (str, optional): The selected language for transcription. Default is None.
|
| 44 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
| 45 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
| 46 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
| 47 |
-
"""
|
| 48 |
-
self.recording = False
|
| 49 |
-
self.uid = str(uuid.uuid4())
|
| 50 |
-
self.waiting = False
|
| 51 |
-
self.last_response_received = None
|
| 52 |
-
self.disconnect_if_no_response_for = 15
|
| 53 |
-
self.language = lang
|
| 54 |
-
self.server_error = False
|
| 55 |
-
self.last_segment = None
|
| 56 |
-
self.last_received_segment = None
|
| 57 |
-
self.log_transcription = log_transcription
|
| 58 |
-
self.max_clients = max_clients
|
| 59 |
-
self.max_connection_time = max_connection_time
|
| 60 |
-
self.dst_lang = dst_lang
|
| 61 |
-
|
| 62 |
-
self.audio_bytes = None
|
| 63 |
-
|
| 64 |
-
if host is not None and port is not None:
|
| 65 |
-
socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
|
| 66 |
-
self.client_socket = websocket.WebSocketApp(
|
| 67 |
-
socket_url,
|
| 68 |
-
on_open=lambda ws: self.on_open(ws),
|
| 69 |
-
on_message=lambda ws, message: self.on_message(ws, message),
|
| 70 |
-
on_error=lambda ws, error: self.on_error(ws, error),
|
| 71 |
-
on_close=lambda ws, close_status_code, close_msg: self.on_close(
|
| 72 |
-
ws, close_status_code, close_msg
|
| 73 |
-
),
|
| 74 |
-
)
|
| 75 |
-
else:
|
| 76 |
-
print("[ERROR]: No host or port specified.")
|
| 77 |
-
return
|
| 78 |
-
|
| 79 |
-
Client.INSTANCES[self.uid] = self
|
| 80 |
-
|
| 81 |
-
# start websocket client in a thread
|
| 82 |
-
self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
|
| 83 |
-
self.ws_thread.daemon = True
|
| 84 |
-
self.ws_thread.start()
|
| 85 |
-
|
| 86 |
-
self.transcript = []
|
| 87 |
-
print("[INFO]: * recording")
|
| 88 |
-
|
| 89 |
-
def handle_status_messages(self, message_data):
|
| 90 |
-
"""Handles server status messages."""
|
| 91 |
-
status = message_data["status"]
|
| 92 |
-
if status == "WAIT":
|
| 93 |
-
self.waiting = True
|
| 94 |
-
print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
|
| 95 |
-
elif status == "ERROR":
|
| 96 |
-
print(f"Message from Server: {message_data['message']}")
|
| 97 |
-
self.server_error = True
|
| 98 |
-
elif status == "WARNING":
|
| 99 |
-
print(f"Message from Server: {message_data['message']}")
|
| 100 |
-
|
| 101 |
-
def process_segments(self, segments):
|
| 102 |
-
"""Processes transcript segments."""
|
| 103 |
-
text = []
|
| 104 |
-
for i, seg in enumerate(segments):
|
| 105 |
-
if not text or text[-1] != seg["text"]:
|
| 106 |
-
text.append(seg["text"])
|
| 107 |
-
if i == len(segments) - 1 and not seg.get("completed", False):
|
| 108 |
-
self.last_segment = seg
|
| 109 |
-
|
| 110 |
-
# update last received segment and last valid response time
|
| 111 |
-
if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
|
| 112 |
-
self.last_response_received = time.time()
|
| 113 |
-
self.last_received_segment = segments[-1]["text"]
|
| 114 |
-
|
| 115 |
-
if self.log_transcription:
|
| 116 |
-
# Truncate to last 3 entries for brevity.
|
| 117 |
-
text = text[-3:]
|
| 118 |
-
utils.clear_screen()
|
| 119 |
-
utils.print_transcript(text)
|
| 120 |
-
|
| 121 |
-
def on_message(self, ws, message):
|
| 122 |
-
"""
|
| 123 |
-
Callback function called when a message is received from the server.
|
| 124 |
-
|
| 125 |
-
It updates various attributes of the client based on the received message, including
|
| 126 |
-
recording status, language detection, and server messages. If a disconnect message
|
| 127 |
-
is received, it sets the recording status to False.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
| 131 |
-
message (str): The received message from the server.
|
| 132 |
-
|
| 133 |
-
"""
|
| 134 |
-
message = json.loads(message)
|
| 135 |
-
|
| 136 |
-
# if self.uid != message.get("uid"):
|
| 137 |
-
# print("[ERROR]: invalid client uid")
|
| 138 |
-
# return
|
| 139 |
-
|
| 140 |
-
if "status" in message.keys():
|
| 141 |
-
self.handle_status_messages(message)
|
| 142 |
-
return
|
| 143 |
-
|
| 144 |
-
if "message" in message.keys() and message["message"] == "DISCONNECT":
|
| 145 |
-
print("[INFO]: Server disconnected due to overtime.")
|
| 146 |
-
self.recording = False
|
| 147 |
-
|
| 148 |
-
if "message" in message.keys() and message["message"] == "SERVER_READY":
|
| 149 |
-
self.last_response_received = time.time()
|
| 150 |
-
self.recording = True
|
| 151 |
-
self.server_backend = message["backend"]
|
| 152 |
-
print(f"[INFO]: Server Running with backend {self.server_backend}")
|
| 153 |
-
return
|
| 154 |
-
|
| 155 |
-
if "language" in message.keys():
|
| 156 |
-
self.language = message.get("language")
|
| 157 |
-
lang_prob = message.get("language_prob")
|
| 158 |
-
print(
|
| 159 |
-
f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
|
| 160 |
-
)
|
| 161 |
-
return
|
| 162 |
-
|
| 163 |
-
if "segments" in message.keys():
|
| 164 |
-
self.process_segments(message["segments"])
|
| 165 |
-
|
| 166 |
-
def on_error(self, ws, error):
|
| 167 |
-
print(f"[ERROR] WebSocket Error: {error}")
|
| 168 |
-
self.server_error = True
|
| 169 |
-
self.error_message = error
|
| 170 |
-
|
| 171 |
-
def on_close(self, ws, close_status_code, close_msg):
|
| 172 |
-
print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
|
| 173 |
-
self.recording = False
|
| 174 |
-
self.waiting = False
|
| 175 |
-
|
| 176 |
-
def on_open(self, ws):
|
| 177 |
-
"""
|
| 178 |
-
Callback function called when the WebSocket connection is successfully opened.
|
| 179 |
-
|
| 180 |
-
Sends an initial configuration message to the server, including client UID,
|
| 181 |
-
language selection, and task type.
|
| 182 |
-
|
| 183 |
-
Args:
|
| 184 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
| 185 |
-
|
| 186 |
-
"""
|
| 187 |
-
print("[INFO]: Opened connection")
|
| 188 |
-
ws.send(
|
| 189 |
-
json.dumps(
|
| 190 |
-
{
|
| 191 |
-
"uid": self.uid,
|
| 192 |
-
"language": self.language,
|
| 193 |
-
"max_clients": self.max_clients,
|
| 194 |
-
"max_connection_time": self.max_connection_time,
|
| 195 |
-
}
|
| 196 |
-
)
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def send_packet_to_server(self, message):
|
| 200 |
-
"""
|
| 201 |
-
Send an audio packet to the server using WebSocket.
|
| 202 |
-
|
| 203 |
-
Args:
|
| 204 |
-
message (bytes): The audio data packet in bytes to be sent to the server.
|
| 205 |
-
|
| 206 |
-
"""
|
| 207 |
-
try:
|
| 208 |
-
self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
|
| 209 |
-
except Exception as e:
|
| 210 |
-
print(e)
|
| 211 |
-
|
| 212 |
-
def close_websocket(self):
|
| 213 |
-
"""
|
| 214 |
-
Close the WebSocket connection and join the WebSocket thread.
|
| 215 |
-
|
| 216 |
-
First attempts to close the WebSocket connection using `self.client_socket.close()`. After
|
| 217 |
-
closing the connection, it joins the WebSocket thread to ensure proper termination.
|
| 218 |
-
|
| 219 |
-
"""
|
| 220 |
-
try:
|
| 221 |
-
self.client_socket.close()
|
| 222 |
-
except Exception as e:
|
| 223 |
-
print("[ERROR]: Error closing WebSocket:", e)
|
| 224 |
-
|
| 225 |
-
try:
|
| 226 |
-
self.ws_thread.join()
|
| 227 |
-
except Exception as e:
|
| 228 |
-
print("[ERROR:] Error joining WebSocket thread:", e)
|
| 229 |
-
|
| 230 |
-
def get_client_socket(self):
|
| 231 |
-
"""
|
| 232 |
-
Get the WebSocket client socket instance.
|
| 233 |
-
|
| 234 |
-
Returns:
|
| 235 |
-
WebSocketApp: The WebSocket client socket instance currently in use by the client.
|
| 236 |
-
"""
|
| 237 |
-
return self.client_socket
|
| 238 |
-
|
| 239 |
-
def wait_before_disconnect(self):
|
| 240 |
-
"""Waits a bit before disconnecting in order to process pending responses."""
|
| 241 |
-
assert self.last_response_received
|
| 242 |
-
while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
|
| 243 |
-
continue
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
class TranscriptionTeeClient:
|
| 247 |
-
"""
|
| 248 |
-
Client for handling audio recording, streaming, and transcription tasks via one or more
|
| 249 |
-
WebSocket connections.
|
| 250 |
-
|
| 251 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
| 252 |
-
to send audio data for transcription to one or more servers, and receive transcribed text segments.
|
| 253 |
-
Args:
|
| 254 |
-
clients (list): one or more previously initialized Client instances
|
| 255 |
-
|
| 256 |
-
Attributes:
|
| 257 |
-
clients (list): the underlying Client instances responsible for handling WebSocket connections.
|
| 258 |
-
"""
|
| 259 |
-
|
| 260 |
-
def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
|
| 261 |
-
mute_audio_playback=False):
|
| 262 |
-
self.clients = clients
|
| 263 |
-
if not self.clients:
|
| 264 |
-
raise Exception("At least one client is required.")
|
| 265 |
-
self.chunk = 4096
|
| 266 |
-
self.format = pyaudio.paInt16
|
| 267 |
-
self.channels = 1
|
| 268 |
-
self.rate = 16000
|
| 269 |
-
self.record_seconds = 60000
|
| 270 |
-
self.save_output_recording = save_output_recording
|
| 271 |
-
self.output_recording_filename = output_recording_filename
|
| 272 |
-
self.mute_audio_playback = mute_audio_playback
|
| 273 |
-
self.frames = b""
|
| 274 |
-
self.p = pyaudio.PyAudio()
|
| 275 |
-
try:
|
| 276 |
-
self.stream = self.p.open(
|
| 277 |
-
format=self.format,
|
| 278 |
-
channels=self.channels,
|
| 279 |
-
rate=self.rate,
|
| 280 |
-
input=True,
|
| 281 |
-
frames_per_buffer=self.chunk,
|
| 282 |
-
)
|
| 283 |
-
except OSError as error:
|
| 284 |
-
print(f"[WARN]: Unable to access microphone. {error}")
|
| 285 |
-
self.stream = None
|
| 286 |
-
|
| 287 |
-
def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
|
| 288 |
-
"""
|
| 289 |
-
Start the transcription process.
|
| 290 |
-
|
| 291 |
-
Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
|
| 292 |
-
to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
|
| 293 |
-
will be played and streamed to the server; otherwise, it will perform live recording.
|
| 294 |
-
|
| 295 |
-
Args:
|
| 296 |
-
audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
|
| 297 |
-
|
| 298 |
-
"""
|
| 299 |
-
assert sum(
|
| 300 |
-
source is not None for source in [audio, rtsp_url, hls_url]
|
| 301 |
-
) <= 1, 'You must provide only one selected source'
|
| 302 |
-
|
| 303 |
-
print("[INFO]: Waiting for server ready ...")
|
| 304 |
-
for client in self.clients:
|
| 305 |
-
while not client.recording:
|
| 306 |
-
if client.waiting or client.server_error:
|
| 307 |
-
self.close_all_clients()
|
| 308 |
-
return
|
| 309 |
-
|
| 310 |
-
print("[INFO]: Server Ready!")
|
| 311 |
-
if hls_url is not None:
|
| 312 |
-
self.process_hls_stream(hls_url, save_file)
|
| 313 |
-
elif audio is not None:
|
| 314 |
-
resampled_file = utils.resample(audio)
|
| 315 |
-
self.play_file(resampled_file)
|
| 316 |
-
elif rtsp_url is not None:
|
| 317 |
-
self.process_rtsp_stream(rtsp_url)
|
| 318 |
-
else:
|
| 319 |
-
self.record()
|
| 320 |
-
|
| 321 |
-
def close_all_clients(self):
|
| 322 |
-
"""Closes all client websockets."""
|
| 323 |
-
for client in self.clients:
|
| 324 |
-
client.close_websocket()
|
| 325 |
-
|
| 326 |
-
def multicast_packet(self, packet, unconditional=False):
|
| 327 |
-
"""
|
| 328 |
-
Sends an identical packet via all clients.
|
| 329 |
-
|
| 330 |
-
Args:
|
| 331 |
-
packet (bytes): The audio data packet in bytes to be sent.
|
| 332 |
-
unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
|
| 333 |
-
"""
|
| 334 |
-
for client in self.clients:
|
| 335 |
-
if (unconditional or client.recording):
|
| 336 |
-
client.send_packet_to_server(packet)
|
| 337 |
-
|
| 338 |
-
def play_file(self, filename):
|
| 339 |
-
"""
|
| 340 |
-
Play an audio file and send it to the server for processing.
|
| 341 |
-
|
| 342 |
-
Reads an audio file, plays it through the audio output, and simultaneously sends
|
| 343 |
-
the audio data to the server for processing. It uses PyAudio to create an audio
|
| 344 |
-
stream for playback. The audio data is read from the file in chunks, converted to
|
| 345 |
-
floating-point format, and sent to the server using WebSocket communication.
|
| 346 |
-
This method is typically used when you want to process pre-recorded audio and send it
|
| 347 |
-
to the server in real-time.
|
| 348 |
-
|
| 349 |
-
Args:
|
| 350 |
-
filename (str): The path to the audio file to be played and sent to the server.
|
| 351 |
-
"""
|
| 352 |
-
|
| 353 |
-
# read audio and create pyaudio stream
|
| 354 |
-
with wave.open(filename, "rb") as wavfile:
|
| 355 |
-
self.stream = self.p.open(
|
| 356 |
-
format=self.p.get_format_from_width(wavfile.getsampwidth()),
|
| 357 |
-
channels=wavfile.getnchannels(),
|
| 358 |
-
rate=wavfile.getframerate(),
|
| 359 |
-
input=True,
|
| 360 |
-
output=True,
|
| 361 |
-
frames_per_buffer=self.chunk,
|
| 362 |
-
)
|
| 363 |
-
chunk_duration = self.chunk / float(wavfile.getframerate())
|
| 364 |
-
try:
|
| 365 |
-
while any(client.recording for client in self.clients):
|
| 366 |
-
data = wavfile.readframes(self.chunk)
|
| 367 |
-
if data == b"":
|
| 368 |
-
break
|
| 369 |
-
|
| 370 |
-
audio_array = self.bytes_to_float_array(data)
|
| 371 |
-
self.multicast_packet(audio_array.tobytes())
|
| 372 |
-
if self.mute_audio_playback:
|
| 373 |
-
time.sleep(chunk_duration)
|
| 374 |
-
else:
|
| 375 |
-
self.stream.write(data)
|
| 376 |
-
|
| 377 |
-
wavfile.close()
|
| 378 |
-
|
| 379 |
-
for client in self.clients:
|
| 380 |
-
client.wait_before_disconnect()
|
| 381 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 382 |
-
self.stream.close()
|
| 383 |
-
self.close_all_clients()
|
| 384 |
-
|
| 385 |
-
except KeyboardInterrupt:
|
| 386 |
-
wavfile.close()
|
| 387 |
-
self.stream.stop_stream()
|
| 388 |
-
self.stream.close()
|
| 389 |
-
self.p.terminate()
|
| 390 |
-
self.close_all_clients()
|
| 391 |
-
print("[INFO]: Keyboard interrupt.")
|
| 392 |
-
|
| 393 |
-
def process_rtsp_stream(self, rtsp_url):
|
| 394 |
-
"""
|
| 395 |
-
Connect to an RTSP source, process the audio stream, and send it for transcription.
|
| 396 |
-
|
| 397 |
-
Args:
|
| 398 |
-
rtsp_url (str): The URL of the RTSP stream source.
|
| 399 |
-
"""
|
| 400 |
-
print("[INFO]: Connecting to RTSP stream...")
|
| 401 |
-
try:
|
| 402 |
-
container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
|
| 403 |
-
self.process_av_stream(container, stream_type="RTSP")
|
| 404 |
-
except Exception as e:
|
| 405 |
-
print(f"[ERROR]: Failed to process RTSP stream: {e}")
|
| 406 |
-
finally:
|
| 407 |
-
for client in self.clients:
|
| 408 |
-
client.wait_before_disconnect()
|
| 409 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 410 |
-
self.close_all_clients()
|
| 411 |
-
print("[INFO]: RTSP stream processing finished.")
|
| 412 |
-
|
| 413 |
-
def process_hls_stream(self, hls_url, save_file=None):
|
| 414 |
-
"""
|
| 415 |
-
Connect to an HLS source, process the audio stream, and send it for transcription.
|
| 416 |
-
|
| 417 |
-
Args:
|
| 418 |
-
hls_url (str): The URL of the HLS stream source.
|
| 419 |
-
save_file (str, optional): Local path to save the network stream.
|
| 420 |
-
"""
|
| 421 |
-
print("[INFO]: Connecting to HLS stream...")
|
| 422 |
-
try:
|
| 423 |
-
container = av.open(hls_url, format="hls")
|
| 424 |
-
self.process_av_stream(container, stream_type="HLS", save_file=save_file)
|
| 425 |
-
except Exception as e:
|
| 426 |
-
print(f"[ERROR]: Failed to process HLS stream: {e}")
|
| 427 |
-
finally:
|
| 428 |
-
for client in self.clients:
|
| 429 |
-
client.wait_before_disconnect()
|
| 430 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 431 |
-
self.close_all_clients()
|
| 432 |
-
print("[INFO]: HLS stream processing finished.")
|
| 433 |
-
|
| 434 |
-
def process_av_stream(self, container, stream_type, save_file=None):
|
| 435 |
-
"""
|
| 436 |
-
Process an AV container stream and send audio packets to the server.
|
| 437 |
-
|
| 438 |
-
Args:
|
| 439 |
-
container (av.container.InputContainer): The input container to process.
|
| 440 |
-
stream_type (str): The type of stream being processed ("RTSP" or "HLS").
|
| 441 |
-
save_file (str, optional): Local path to save the stream. Default is None.
|
| 442 |
-
"""
|
| 443 |
-
audio_stream = next((s for s in container.streams if s.type == "audio"), None)
|
| 444 |
-
if not audio_stream:
|
| 445 |
-
print(f"[ERROR]: No audio stream found in {stream_type} source.")
|
| 446 |
-
return
|
| 447 |
-
|
| 448 |
-
output_container = None
|
| 449 |
-
if save_file:
|
| 450 |
-
output_container = av.open(save_file, mode="w")
|
| 451 |
-
output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
|
| 452 |
-
|
| 453 |
-
try:
|
| 454 |
-
for packet in container.demux(audio_stream):
|
| 455 |
-
for frame in packet.decode():
|
| 456 |
-
audio_data = frame.to_ndarray().tobytes()
|
| 457 |
-
self.multicast_packet(audio_data)
|
| 458 |
-
|
| 459 |
-
if save_file:
|
| 460 |
-
output_container.mux(frame)
|
| 461 |
-
except Exception as e:
|
| 462 |
-
print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
|
| 463 |
-
finally:
|
| 464 |
-
# Wait for server to send any leftover transcription.
|
| 465 |
-
time.sleep(5)
|
| 466 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 467 |
-
if output_container:
|
| 468 |
-
output_container.close()
|
| 469 |
-
container.close()
|
| 470 |
-
|
| 471 |
-
def save_chunk(self, n_audio_file):
|
| 472 |
-
"""
|
| 473 |
-
Saves the current audio frames to a WAV file in a separate thread.
|
| 474 |
-
|
| 475 |
-
Args:
|
| 476 |
-
n_audio_file (int): The index of the audio file which determines the filename.
|
| 477 |
-
This helps in maintaining the order and uniqueness of each chunk.
|
| 478 |
-
"""
|
| 479 |
-
t = threading.Thread(
|
| 480 |
-
target=self.write_audio_frames_to_file,
|
| 481 |
-
args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
|
| 482 |
-
)
|
| 483 |
-
t.start()
|
| 484 |
-
|
| 485 |
-
def finalize_recording(self, n_audio_file):
|
| 486 |
-
"""
|
| 487 |
-
Finalizes the recording process by saving any remaining audio frames,
|
| 488 |
-
closing the audio stream, and terminating the process.
|
| 489 |
-
|
| 490 |
-
Args:
|
| 491 |
-
n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
|
| 492 |
-
This index is incremented before use if the last chunk is saved.
|
| 493 |
-
"""
|
| 494 |
-
if self.save_output_recording and len(self.frames):
|
| 495 |
-
self.write_audio_frames_to_file(
|
| 496 |
-
self.frames[:], f"chunks/{n_audio_file}.wav"
|
| 497 |
-
)
|
| 498 |
-
n_audio_file += 1
|
| 499 |
-
self.stream.stop_stream()
|
| 500 |
-
self.stream.close()
|
| 501 |
-
self.p.terminate()
|
| 502 |
-
self.close_all_clients()
|
| 503 |
-
if self.save_output_recording:
|
| 504 |
-
self.write_output_recording(n_audio_file)
|
| 505 |
-
|
| 506 |
-
def record(self):
|
| 507 |
-
"""
|
| 508 |
-
Record audio data from the input stream and save it to a WAV file.
|
| 509 |
-
|
| 510 |
-
Continuously records audio data from the input stream, sends it to the server via a WebSocket
|
| 511 |
-
connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
|
| 512 |
-
the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
|
| 513 |
-
|
| 514 |
-
Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
|
| 515 |
-
The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
|
| 516 |
-
The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
|
| 517 |
-
the method combines all the saved audio chunks into the specified `out_file`.
|
| 518 |
-
"""
|
| 519 |
-
n_audio_file = 0
|
| 520 |
-
if self.save_output_recording:
|
| 521 |
-
if os.path.exists("chunks"):
|
| 522 |
-
shutil.rmtree("chunks")
|
| 523 |
-
os.makedirs("chunks")
|
| 524 |
-
try:
|
| 525 |
-
for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
|
| 526 |
-
if not any(client.recording for client in self.clients):
|
| 527 |
-
break
|
| 528 |
-
data = self.stream.read(self.chunk, exception_on_overflow=False)
|
| 529 |
-
self.frames += data
|
| 530 |
-
|
| 531 |
-
audio_array = self.bytes_to_float_array(data)
|
| 532 |
-
|
| 533 |
-
self.multicast_packet(audio_array.tobytes())
|
| 534 |
-
|
| 535 |
-
# save frames if more than a minute
|
| 536 |
-
if len(self.frames) > 60 * self.rate:
|
| 537 |
-
if self.save_output_recording:
|
| 538 |
-
self.save_chunk(n_audio_file)
|
| 539 |
-
n_audio_file += 1
|
| 540 |
-
self.frames = b""
|
| 541 |
-
|
| 542 |
-
except KeyboardInterrupt:
|
| 543 |
-
self.finalize_recording(n_audio_file)
|
| 544 |
-
|
| 545 |
-
def write_audio_frames_to_file(self, frames, file_name):
|
| 546 |
-
"""
|
| 547 |
-
Write audio frames to a WAV file.
|
| 548 |
-
|
| 549 |
-
The WAV file is created or overwritten with the specified name. The audio frames should be
|
| 550 |
-
in the correct format and match the specified channel, sample width, and sample rate.
|
| 551 |
-
|
| 552 |
-
Args:
|
| 553 |
-
frames (bytes): The audio frames to be written to the file.
|
| 554 |
-
file_name (str): The name of the WAV file to which the frames will be written.
|
| 555 |
-
|
| 556 |
-
"""
|
| 557 |
-
with wave.open(file_name, "wb") as wavfile:
|
| 558 |
-
wavfile: wave.Wave_write
|
| 559 |
-
wavfile.setnchannels(self.channels)
|
| 560 |
-
wavfile.setsampwidth(2)
|
| 561 |
-
wavfile.setframerate(self.rate)
|
| 562 |
-
wavfile.writeframes(frames)
|
| 563 |
-
|
| 564 |
-
def write_output_recording(self, n_audio_file):
|
| 565 |
-
"""
|
| 566 |
-
Combine and save recorded audio chunks into a single WAV file.
|
| 567 |
-
|
| 568 |
-
The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
|
| 569 |
-
file, appends its audio data to the final recording, and then deletes the chunk file. After combining
|
| 570 |
-
and saving, the final recording is stored in the specified `out_file`.
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
Args:
|
| 574 |
-
n_audio_file (int): The number of audio chunk files to combine.
|
| 575 |
-
out_file (str): The name of the output WAV file to save the final recording.
|
| 576 |
-
|
| 577 |
-
"""
|
| 578 |
-
input_files = [
|
| 579 |
-
f"chunks/{i}.wav"
|
| 580 |
-
for i in range(n_audio_file)
|
| 581 |
-
if os.path.exists(f"chunks/{i}.wav")
|
| 582 |
-
]
|
| 583 |
-
with wave.open(self.output_recording_filename, "wb") as wavfile:
|
| 584 |
-
wavfile: wave.Wave_write
|
| 585 |
-
wavfile.setnchannels(self.channels)
|
| 586 |
-
wavfile.setsampwidth(2)
|
| 587 |
-
wavfile.setframerate(self.rate)
|
| 588 |
-
for in_file in input_files:
|
| 589 |
-
with wave.open(in_file, "rb") as wav_in:
|
| 590 |
-
while True:
|
| 591 |
-
data = wav_in.readframes(self.chunk)
|
| 592 |
-
if data == b"":
|
| 593 |
-
break
|
| 594 |
-
wavfile.writeframes(data)
|
| 595 |
-
# remove this file
|
| 596 |
-
os.remove(in_file)
|
| 597 |
-
wavfile.close()
|
| 598 |
-
# clean up temporary directory to store chunks
|
| 599 |
-
if os.path.exists("chunks"):
|
| 600 |
-
shutil.rmtree("chunks")
|
| 601 |
-
|
| 602 |
-
@staticmethod
|
| 603 |
-
def bytes_to_float_array(audio_bytes):
|
| 604 |
-
"""
|
| 605 |
-
Convert audio data from bytes to a NumPy float array.
|
| 606 |
-
|
| 607 |
-
It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
|
| 608 |
-
have values between -1 and 1.
|
| 609 |
-
|
| 610 |
-
Args:
|
| 611 |
-
audio_bytes (bytes): Audio data in bytes.
|
| 612 |
-
|
| 613 |
-
Returns:
|
| 614 |
-
np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
|
| 615 |
-
"""
|
| 616 |
-
raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
|
| 617 |
-
return raw_data.astype(np.float32) / 32768.0
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
class TranscriptionClient(TranscriptionTeeClient):
|
| 621 |
-
"""
|
| 622 |
-
Client for handling audio transcription tasks via a single WebSocket connection.
|
| 623 |
-
|
| 624 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
| 625 |
-
to send audio data for transcription to a server and receive transcribed text segments.
|
| 626 |
-
|
| 627 |
-
Args:
|
| 628 |
-
host (str): The hostname or IP address of the server.
|
| 629 |
-
port (int): The port number to connect to on the server.
|
| 630 |
-
lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
|
| 631 |
-
save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
|
| 632 |
-
output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
|
| 633 |
-
output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
|
| 634 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
| 635 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
| 636 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
| 637 |
-
mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
|
| 638 |
-
|
| 639 |
-
Attributes:
|
| 640 |
-
client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
|
| 641 |
-
|
| 642 |
-
Example:
|
| 643 |
-
To create a TranscriptionClient and start transcription on microphone audio:
|
| 644 |
-
```python
|
| 645 |
-
transcription_client = TranscriptionClient(host="localhost", port=9090)
|
| 646 |
-
transcription_client()
|
| 647 |
-
```
|
| 648 |
-
"""
|
| 649 |
-
|
| 650 |
-
def __init__(
|
| 651 |
-
self,
|
| 652 |
-
host,
|
| 653 |
-
port,
|
| 654 |
-
lang=None,
|
| 655 |
-
save_output_recording=False,
|
| 656 |
-
output_recording_filename="./output_recording.wav",
|
| 657 |
-
log_transcription=True,
|
| 658 |
-
max_clients=4,
|
| 659 |
-
max_connection_time=600,
|
| 660 |
-
mute_audio_playback=False,
|
| 661 |
-
dst_lang='en',
|
| 662 |
-
):
|
| 663 |
-
self.client = Client(
|
| 664 |
-
host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
|
| 665 |
-
max_connection_time=max_connection_time, dst_lang=dst_lang
|
| 666 |
-
)
|
| 667 |
-
|
| 668 |
-
if save_output_recording and not output_recording_filename.endswith(".wav"):
|
| 669 |
-
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
|
| 670 |
-
|
| 671 |
-
TranscriptionTeeClient.__init__(
|
| 672 |
-
self,
|
| 673 |
-
[self.client],
|
| 674 |
-
save_output_recording=save_output_recording,
|
| 675 |
-
output_recording_filename=output_recording_filename,
|
| 676 |
-
mute_audio_playback=mute_audio_playback,
|
| 677 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/helpers/vadprocessor.py
CHANGED
|
@@ -36,7 +36,7 @@ class AdaptiveSilenceController:
|
|
| 36 |
speed_factor = 0.5
|
| 37 |
elif avg_speech < 600:
|
| 38 |
speed_factor = 0.8
|
| 39 |
-
|
| 40 |
# 3. silence 的变化趋势也考虑进去
|
| 41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
| 42 |
|
|
|
|
| 36 |
speed_factor = 0.5
|
| 37 |
elif avg_speech < 600:
|
| 38 |
speed_factor = 0.8
|
| 39 |
+
logging.warning(f"Avg speech :{avg_speech}, Avg silence: {avg_silence}")
|
| 40 |
# 3. silence 的变化趋势也考虑进去
|
| 41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
| 42 |
|
transcribe/pipelines/pipe_vad.py
CHANGED
|
@@ -3,10 +3,8 @@ from .base import MetaItem, BasePipe
|
|
| 3 |
from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
-
from silero_vad import get_speech_timestamps
|
| 7 |
-
from typing import List
|
| 8 |
import logging
|
| 9 |
-
|
| 10 |
# import noisereduce as nr
|
| 11 |
|
| 12 |
|
|
@@ -60,27 +58,22 @@ class VadPipe(BasePipe):
|
|
| 60 |
|
| 61 |
def update_silence_ms(self):
|
| 62 |
min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
|
| 63 |
-
<<<<<<< HEAD
|
| 64 |
min_silence_samples = self.sample_rate * min_silence / 1000
|
| 65 |
-
self.vac.min_silence_samples
|
| 66 |
-
logging.warning(f"🫠 update_silence_ms :{
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
logging.warning(f"🫠 update_silence_ms :{min_silence} => current: {self.vac.min_silence_duration_ms} ")
|
| 70 |
-
self.vac.min_silence_duration_ms = min_silence
|
| 71 |
-
|
| 72 |
-
>>>>>>> efad27a (add log to debug silence ms)
|
| 73 |
def process(self, in_data: MetaItem) -> MetaItem:
|
| 74 |
if self._offset == 0:
|
| 75 |
self.vac.reset_states()
|
| 76 |
-
|
| 77 |
# silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
|
| 78 |
source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
|
| 79 |
speech_data = self._process_speech_chunk(source_audio)
|
| 80 |
|
| 81 |
if speech_data: # 表示有音频的变化点出现
|
| 82 |
-
|
| 83 |
-
rel_start_frame, rel_end_frame = speech_data
|
| 84 |
if rel_start_frame is not None and rel_end_frame is None:
|
| 85 |
self._status = "START" # 语音开始
|
| 86 |
target_audio = source_audio[rel_start_frame:]
|
|
|
|
| 3 |
from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
import logging
|
| 7 |
+
|
| 8 |
# import noisereduce as nr
|
| 9 |
|
| 10 |
|
|
|
|
| 58 |
|
| 59 |
def update_silence_ms(self):
|
| 60 |
min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
|
|
|
|
| 61 |
min_silence_samples = self.sample_rate * min_silence / 1000
|
| 62 |
+
old_silence_samples = self.vac.min_silence_samples
|
| 63 |
+
logging.warning(f"🫠 update_silence_ms :{old_silence_samples * 1000 / self.sample_rate :.2f}ms => current: {min_silence}ms ")
|
| 64 |
+
# self.vac.min_silence_samples = min_silence_samples
|
| 65 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def process(self, in_data: MetaItem) -> MetaItem:
|
| 67 |
if self._offset == 0:
|
| 68 |
self.vac.reset_states()
|
| 69 |
+
|
| 70 |
# silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
|
| 71 |
source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
|
| 72 |
speech_data = self._process_speech_chunk(source_audio)
|
| 73 |
|
| 74 |
if speech_data: # 表示有音频的变化点出现
|
| 75 |
+
|
| 76 |
+
rel_start_frame, rel_end_frame = speech_data
|
| 77 |
if rel_start_frame is not None and rel_end_frame is None:
|
| 78 |
self._status = "START" # 语音开始
|
| 79 |
target_audio = source_audio[rel_start_frame:]
|
transcribe/server.py
DELETED
|
@@ -1,382 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import threading
|
| 5 |
-
import time
|
| 6 |
-
import config
|
| 7 |
-
import librosa
|
| 8 |
-
import numpy as np
|
| 9 |
-
import soundfile
|
| 10 |
-
from pywhispercpp.model import Model
|
| 11 |
-
|
| 12 |
-
logging.basicConfig(level=logging.INFO)
|
| 13 |
-
|
| 14 |
-
class ServeClientBase(object):
|
| 15 |
-
RATE = 16000
|
| 16 |
-
SERVER_READY = "SERVER_READY"
|
| 17 |
-
DISCONNECT = "DISCONNECT"
|
| 18 |
-
|
| 19 |
-
def __init__(self, client_uid, websocket):
|
| 20 |
-
self.client_uid = client_uid
|
| 21 |
-
self.websocket = websocket
|
| 22 |
-
self.frames = b""
|
| 23 |
-
self.timestamp_offset = 0.0
|
| 24 |
-
self.frames_np = None
|
| 25 |
-
self.frames_offset = 0.0
|
| 26 |
-
self.text = []
|
| 27 |
-
self.current_out = ''
|
| 28 |
-
self.prev_out = ''
|
| 29 |
-
self.t_start = None
|
| 30 |
-
self.exit = False
|
| 31 |
-
self.same_output_count = 0
|
| 32 |
-
self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
|
| 33 |
-
self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
|
| 34 |
-
self.transcript = []
|
| 35 |
-
self.send_last_n_segments = 10
|
| 36 |
-
|
| 37 |
-
# text formatting
|
| 38 |
-
self.pick_previous_segments = 2
|
| 39 |
-
|
| 40 |
-
# threading
|
| 41 |
-
self.lock = threading.Lock()
|
| 42 |
-
|
| 43 |
-
def speech_to_text(self):
|
| 44 |
-
raise NotImplementedError
|
| 45 |
-
|
| 46 |
-
def transcribe_audio(self):
|
| 47 |
-
raise NotImplementedError
|
| 48 |
-
|
| 49 |
-
def handle_transcription_output(self):
|
| 50 |
-
raise NotImplementedError
|
| 51 |
-
|
| 52 |
-
def add_frames(self, frame_np):
|
| 53 |
-
"""
|
| 54 |
-
Add audio frames to the ongoing audio stream buffer.
|
| 55 |
-
|
| 56 |
-
This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
|
| 57 |
-
of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
|
| 58 |
-
to prevent excessive memory usage.
|
| 59 |
-
|
| 60 |
-
If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
|
| 61 |
-
of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
|
| 62 |
-
audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
|
| 63 |
-
|
| 64 |
-
Args:
|
| 65 |
-
frame_np (numpy.ndarray): The audio frame data as a NumPy array.
|
| 66 |
-
|
| 67 |
-
"""
|
| 68 |
-
self.lock.acquire()
|
| 69 |
-
if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
|
| 70 |
-
self.frames_offset += 30.0
|
| 71 |
-
self.frames_np = self.frames_np[int(30 * self.RATE):]
|
| 72 |
-
# check timestamp offset(should be >= self.frame_offset)
|
| 73 |
-
# this basically means that there is no speech as timestamp offset hasnt updated
|
| 74 |
-
# and is less than frame_offset
|
| 75 |
-
if self.timestamp_offset < self.frames_offset:
|
| 76 |
-
self.timestamp_offset = self.frames_offset
|
| 77 |
-
if self.frames_np is None:
|
| 78 |
-
self.frames_np = frame_np.copy()
|
| 79 |
-
else:
|
| 80 |
-
self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
|
| 81 |
-
self.lock.release()
|
| 82 |
-
|
| 83 |
-
def clip_audio_if_no_valid_segment(self):
|
| 84 |
-
"""
|
| 85 |
-
Update the timestamp offset based on audio buffer status.
|
| 86 |
-
Clip audio if the current chunk exceeds 30 seconds, this basically implies that
|
| 87 |
-
no valid segment for the last 30 seconds from whisper
|
| 88 |
-
"""
|
| 89 |
-
with self.lock:
|
| 90 |
-
if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
|
| 91 |
-
duration = self.frames_np.shape[0] / self.RATE
|
| 92 |
-
self.timestamp_offset = self.frames_offset + duration - 5
|
| 93 |
-
|
| 94 |
-
def get_audio_chunk_for_processing(self):
|
| 95 |
-
"""
|
| 96 |
-
Retrieves the next chunk of audio data for processing based on the current offsets.
|
| 97 |
-
|
| 98 |
-
Calculates which part of the audio data should be processed next, based on
|
| 99 |
-
the difference between the current timestamp offset and the frame's offset, scaled by
|
| 100 |
-
the audio sample rate (RATE). It then returns this chunk of audio data along with its
|
| 101 |
-
duration in seconds.
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
tuple: A tuple containing:
|
| 105 |
-
- input_bytes (np.ndarray): The next chunk of audio data to be processed.
|
| 106 |
-
- duration (float): The duration of the audio chunk in seconds.
|
| 107 |
-
"""
|
| 108 |
-
with self.lock:
|
| 109 |
-
samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
|
| 110 |
-
input_bytes = self.frames_np[int(samples_take):].copy()
|
| 111 |
-
duration = input_bytes.shape[0] / self.RATE
|
| 112 |
-
return input_bytes, duration
|
| 113 |
-
|
| 114 |
-
def prepare_segments(self, last_segment=None):
|
| 115 |
-
"""
|
| 116 |
-
Prepares the segments of transcribed text to be sent to the client.
|
| 117 |
-
|
| 118 |
-
This method compiles the recent segments of transcribed text, ensuring that only the
|
| 119 |
-
specified number of the most recent segments are included. It also appends the most
|
| 120 |
-
recent segment of text if provided (which is considered incomplete because of the possibility
|
| 121 |
-
of the last word being truncated in the audio chunk).
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
last_segment (str, optional): The most recent segment of transcribed text to be added
|
| 125 |
-
to the list of segments. Defaults to None.
|
| 126 |
-
|
| 127 |
-
Returns:
|
| 128 |
-
list: A list of transcribed text segments to be sent to the client.
|
| 129 |
-
"""
|
| 130 |
-
segments = []
|
| 131 |
-
if len(self.transcript) >= self.send_last_n_segments:
|
| 132 |
-
segments = self.transcript[-self.send_last_n_segments:].copy()
|
| 133 |
-
else:
|
| 134 |
-
segments = self.transcript.copy()
|
| 135 |
-
if last_segment is not None:
|
| 136 |
-
segments = segments + [last_segment]
|
| 137 |
-
logging.info(f"{segments}")
|
| 138 |
-
return segments
|
| 139 |
-
|
| 140 |
-
def get_audio_chunk_duration(self, input_bytes):
|
| 141 |
-
"""
|
| 142 |
-
Calculates the duration of the provided audio chunk.
|
| 143 |
-
|
| 144 |
-
Args:
|
| 145 |
-
input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
|
| 146 |
-
|
| 147 |
-
Returns:
|
| 148 |
-
float: The duration of the audio chunk in seconds.
|
| 149 |
-
"""
|
| 150 |
-
return input_bytes.shape[0] / self.RATE
|
| 151 |
-
|
| 152 |
-
def send_transcription_to_client(self, segments):
|
| 153 |
-
"""
|
| 154 |
-
Sends the specified transcription segments to the client over the websocket connection.
|
| 155 |
-
|
| 156 |
-
This method formats the transcription segments into a JSON object and attempts to send
|
| 157 |
-
this object to the client. If an error occurs during the send operation, it logs the error.
|
| 158 |
-
|
| 159 |
-
Returns:
|
| 160 |
-
segments (list): A list of transcription segments to be sent to the client.
|
| 161 |
-
"""
|
| 162 |
-
try:
|
| 163 |
-
self.websocket.send(
|
| 164 |
-
json.dumps({
|
| 165 |
-
"uid": self.client_uid,
|
| 166 |
-
"segments": segments,
|
| 167 |
-
})
|
| 168 |
-
)
|
| 169 |
-
except Exception as e:
|
| 170 |
-
logging.error(f"[ERROR]: Sending data to client: {e}")
|
| 171 |
-
|
| 172 |
-
def disconnect(self):
|
| 173 |
-
"""
|
| 174 |
-
Notify the client of disconnection and send a disconnect message.
|
| 175 |
-
|
| 176 |
-
This method sends a disconnect message to the client via the WebSocket connection to notify them
|
| 177 |
-
that the transcription service is disconnecting gracefully.
|
| 178 |
-
|
| 179 |
-
"""
|
| 180 |
-
self.websocket.send(json.dumps({
|
| 181 |
-
"uid": self.client_uid,
|
| 182 |
-
"message": self.DISCONNECT
|
| 183 |
-
}))
|
| 184 |
-
|
| 185 |
-
def cleanup(self):
|
| 186 |
-
"""
|
| 187 |
-
Perform cleanup tasks before exiting the transcription service.
|
| 188 |
-
|
| 189 |
-
This method performs necessary cleanup tasks, including stopping the transcription thread, marking
|
| 190 |
-
the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
|
| 191 |
-
associated with the transcription process.
|
| 192 |
-
|
| 193 |
-
"""
|
| 194 |
-
logging.info("Cleaning up.")
|
| 195 |
-
self.exit = True
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class ServeClientWhisperCPP(ServeClientBase):
|
| 199 |
-
SINGLE_MODEL = None
|
| 200 |
-
SINGLE_MODEL_LOCK = threading.Lock()
|
| 201 |
-
|
| 202 |
-
def __init__(self, websocket, language=None, client_uid=None,
|
| 203 |
-
single_model=False):
|
| 204 |
-
"""
|
| 205 |
-
Initialize a ServeClient instance.
|
| 206 |
-
The Whisper model is initialized based on the client's language and device availability.
|
| 207 |
-
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
|
| 208 |
-
to the client to indicate that the server is ready.
|
| 209 |
-
|
| 210 |
-
Args:
|
| 211 |
-
websocket (WebSocket): The WebSocket connection for the client.
|
| 212 |
-
language (str, optional): The language for transcription. Defaults to None.
|
| 213 |
-
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
| 214 |
-
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
|
| 215 |
-
|
| 216 |
-
"""
|
| 217 |
-
super().__init__(client_uid, websocket)
|
| 218 |
-
self.language = language
|
| 219 |
-
self.eos = False
|
| 220 |
-
|
| 221 |
-
if single_model:
|
| 222 |
-
if ServeClientWhisperCPP.SINGLE_MODEL is None:
|
| 223 |
-
self.create_model()
|
| 224 |
-
ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
|
| 225 |
-
else:
|
| 226 |
-
self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
|
| 227 |
-
else:
|
| 228 |
-
self.create_model()
|
| 229 |
-
|
| 230 |
-
# threading
|
| 231 |
-
logging.info('Create a thread to process audio.')
|
| 232 |
-
self.trans_thread = threading.Thread(target=self.speech_to_text)
|
| 233 |
-
self.trans_thread.start()
|
| 234 |
-
|
| 235 |
-
self.websocket.send(json.dumps({
|
| 236 |
-
"uid": self.client_uid,
|
| 237 |
-
"message": self.SERVER_READY,
|
| 238 |
-
"backend": "pywhispercpp"
|
| 239 |
-
}))
|
| 240 |
-
|
| 241 |
-
def create_model(self, warmup=True):
|
| 242 |
-
"""
|
| 243 |
-
Instantiates a new model, sets it as the transcriber and does warmup if desired.
|
| 244 |
-
"""
|
| 245 |
-
|
| 246 |
-
self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
|
| 247 |
-
if warmup:
|
| 248 |
-
self.warmup()
|
| 249 |
-
|
| 250 |
-
def warmup(self, warmup_steps=1):
|
| 251 |
-
"""
|
| 252 |
-
Warmup TensorRT since first few inferences are slow.
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
warmup_steps (int): Number of steps to warm up the model for.
|
| 256 |
-
"""
|
| 257 |
-
logging.info("[INFO:] Warming up whisper.cpp engine..")
|
| 258 |
-
mel, _, = soundfile.read("assets/jfk.flac")
|
| 259 |
-
for i in range(warmup_steps):
|
| 260 |
-
self.transcriber.transcribe(mel, print_progress=False)
|
| 261 |
-
|
| 262 |
-
def set_eos(self, eos):
|
| 263 |
-
"""
|
| 264 |
-
Sets the End of Speech (EOS) flag.
|
| 265 |
-
|
| 266 |
-
Args:
|
| 267 |
-
eos (bool): The value to set for the EOS flag.
|
| 268 |
-
"""
|
| 269 |
-
self.lock.acquire()
|
| 270 |
-
self.eos = eos
|
| 271 |
-
self.lock.release()
|
| 272 |
-
|
| 273 |
-
def handle_transcription_output(self, last_segment, duration):
|
| 274 |
-
"""
|
| 275 |
-
Handle the transcription output, updating the transcript and sending data to the client.
|
| 276 |
-
|
| 277 |
-
Args:
|
| 278 |
-
last_segment (str): The last segment from the whisper output which is considered to be incomplete because
|
| 279 |
-
of the possibility of word being truncated.
|
| 280 |
-
duration (float): Duration of the transcribed audio chunk.
|
| 281 |
-
"""
|
| 282 |
-
segments = self.prepare_segments({"text": last_segment})
|
| 283 |
-
self.send_transcription_to_client(segments)
|
| 284 |
-
if self.eos:
|
| 285 |
-
self.update_timestamp_offset(last_segment, duration)
|
| 286 |
-
|
| 287 |
-
def transcribe_audio(self, input_bytes):
|
| 288 |
-
"""
|
| 289 |
-
Transcribe the audio chunk and send the results to the client.
|
| 290 |
-
|
| 291 |
-
Args:
|
| 292 |
-
input_bytes (np.array): The audio chunk to transcribe.
|
| 293 |
-
"""
|
| 294 |
-
if ServeClientWhisperCPP.SINGLE_MODEL:
|
| 295 |
-
ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
|
| 296 |
-
logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
|
| 297 |
-
mel = input_bytes
|
| 298 |
-
duration = librosa.get_duration(y=input_bytes, sr=self.RATE)
|
| 299 |
-
|
| 300 |
-
if self.language == "zh":
|
| 301 |
-
prompt = '以下是简体中文普通话的句子。'
|
| 302 |
-
else:
|
| 303 |
-
prompt = 'The following is an English sentence.'
|
| 304 |
-
|
| 305 |
-
segments = self.transcriber.transcribe(
|
| 306 |
-
mel,
|
| 307 |
-
language=self.language,
|
| 308 |
-
initial_prompt=prompt,
|
| 309 |
-
token_timestamps=True,
|
| 310 |
-
# max_len=max_len,
|
| 311 |
-
print_progress=False
|
| 312 |
-
)
|
| 313 |
-
text = []
|
| 314 |
-
for segment in segments:
|
| 315 |
-
content = segment.text
|
| 316 |
-
text.append(content)
|
| 317 |
-
last_segment = ' '.join(text)
|
| 318 |
-
|
| 319 |
-
logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
|
| 320 |
-
|
| 321 |
-
if ServeClientWhisperCPP.SINGLE_MODEL:
|
| 322 |
-
ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
|
| 323 |
-
if last_segment:
|
| 324 |
-
self.handle_transcription_output(last_segment, duration)
|
| 325 |
-
|
| 326 |
-
def update_timestamp_offset(self, last_segment, duration):
|
| 327 |
-
"""
|
| 328 |
-
Update timestamp offset and transcript.
|
| 329 |
-
|
| 330 |
-
Args:
|
| 331 |
-
last_segment (str): Last transcribed audio from the whisper model.
|
| 332 |
-
duration (float): Duration of the last audio chunk.
|
| 333 |
-
"""
|
| 334 |
-
if not len(self.transcript):
|
| 335 |
-
self.transcript.append({"text": last_segment + " "})
|
| 336 |
-
elif self.transcript[-1]["text"].strip() != last_segment:
|
| 337 |
-
self.transcript.append({"text": last_segment + " "})
|
| 338 |
-
|
| 339 |
-
logging.info(f'Transcript list context: {self.transcript}')
|
| 340 |
-
|
| 341 |
-
with self.lock:
|
| 342 |
-
self.timestamp_offset += duration
|
| 343 |
-
|
| 344 |
-
def speech_to_text(self):
|
| 345 |
-
"""
|
| 346 |
-
Process an audio stream in an infinite loop, continuously transcribing the speech.
|
| 347 |
-
|
| 348 |
-
This method continuously receives audio frames, performs real-time transcription, and sends
|
| 349 |
-
transcribed segments to the client via a WebSocket connection.
|
| 350 |
-
|
| 351 |
-
If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
|
| 352 |
-
It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
|
| 353 |
-
are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech
|
| 354 |
-
(no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
|
| 355 |
-
there is no speech for a specified duration to indicate a pause.
|
| 356 |
-
|
| 357 |
-
Raises:
|
| 358 |
-
Exception: If there is an issue with audio processing or WebSocket communication.
|
| 359 |
-
|
| 360 |
-
"""
|
| 361 |
-
while True:
|
| 362 |
-
if self.exit:
|
| 363 |
-
logging.info("Exiting speech to text thread")
|
| 364 |
-
break
|
| 365 |
-
|
| 366 |
-
if self.frames_np is None:
|
| 367 |
-
time.sleep(0.02) # wait for any audio to arrive
|
| 368 |
-
continue
|
| 369 |
-
|
| 370 |
-
self.clip_audio_if_no_valid_segment()
|
| 371 |
-
|
| 372 |
-
input_bytes, duration = self.get_audio_chunk_for_processing()
|
| 373 |
-
if duration < 1:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
try:
|
| 377 |
-
input_sample = input_bytes.copy()
|
| 378 |
-
logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
|
| 379 |
-
self.transcribe_audio(input_sample)
|
| 380 |
-
|
| 381 |
-
except Exception as e:
|
| 382 |
-
logging.error(f"[ERROR]: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/strategy.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import collections
|
| 3 |
-
import logging
|
| 4 |
-
from difflib import SequenceMatcher
|
| 5 |
-
from itertools import chain
|
| 6 |
-
from dataclasses import dataclass, field
|
| 7 |
-
from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
|
| 8 |
-
from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
|
| 9 |
-
from enum import Enum
|
| 10 |
-
import wordninja
|
| 11 |
-
import config
|
| 12 |
-
import re
|
| 13 |
-
logger = logging.getLogger("TranscriptionStrategy")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class SplitMode(Enum):
|
| 17 |
-
PUNCTUATION = "punctuation"
|
| 18 |
-
PAUSE = "pause"
|
| 19 |
-
END = "end"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@dataclass
|
| 24 |
-
class TranscriptResult:
|
| 25 |
-
seg_id: int = 0
|
| 26 |
-
cut_index: int = 0
|
| 27 |
-
is_end_sentence: bool = False
|
| 28 |
-
context: str = ""
|
| 29 |
-
|
| 30 |
-
def partial(self):
|
| 31 |
-
return not self.is_end_sentence
|
| 32 |
-
|
| 33 |
-
@dataclass
|
| 34 |
-
class TranscriptToken:
|
| 35 |
-
"""表示一个转录片段,包含文本和时间信息"""
|
| 36 |
-
text: str # 转录的文本内容
|
| 37 |
-
t0: int # 开始时间(百分之一秒)
|
| 38 |
-
t1: int # 结束时间(百分之一秒)
|
| 39 |
-
|
| 40 |
-
def is_punctuation(self):
|
| 41 |
-
"""检查文本是否包含标点符号"""
|
| 42 |
-
return REGEX_MARKERS.search(self.text.strip()) is not None
|
| 43 |
-
|
| 44 |
-
def is_end(self):
|
| 45 |
-
"""检查文本是否为句子结束标记"""
|
| 46 |
-
return SENTENCE_END_PATTERN.search(self.text.strip()) is not None
|
| 47 |
-
|
| 48 |
-
def is_pause(self):
|
| 49 |
-
"""检查文本是否为暂停标记"""
|
| 50 |
-
return PAUSEE_END_PATTERN.search(self.text.strip()) is not None
|
| 51 |
-
|
| 52 |
-
def buffer_index(self) -> int:
|
| 53 |
-
return int(self.t1 / 100 * SAMPLE_RATE)
|
| 54 |
-
|
| 55 |
-
@dataclass
|
| 56 |
-
class TranscriptChunk:
|
| 57 |
-
"""表示一组转录片段,支持分割和比较操作"""
|
| 58 |
-
separator: str = "" # 用于连接片段的分隔符
|
| 59 |
-
items: list[TranscriptToken] = field(default_factory=list) # 转录片段列表
|
| 60 |
-
|
| 61 |
-
@staticmethod
|
| 62 |
-
def _calculate_similarity(text1: str, text2: str) -> float:
|
| 63 |
-
"""计算两段文本的相似度"""
|
| 64 |
-
return SequenceMatcher(None, text1, text2).ratio()
|
| 65 |
-
|
| 66 |
-
def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
|
| 67 |
-
"""根据文本中的标点符号分割片段列表"""
|
| 68 |
-
if mode == SplitMode.PUNCTUATION:
|
| 69 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
|
| 70 |
-
elif mode == SplitMode.PAUSE:
|
| 71 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
|
| 72 |
-
elif mode == SplitMode.END:
|
| 73 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
|
| 74 |
-
else:
|
| 75 |
-
raise ValueError(f"Unsupported mode: {mode}")
|
| 76 |
-
|
| 77 |
-
# 每个切分点向后移一个索引,表示“分隔符归前段”
|
| 78 |
-
cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
|
| 79 |
-
chunks = [
|
| 80 |
-
TranscriptChunk(items=self.items[start:end], separator=self.separator)
|
| 81 |
-
for start, end in zip(cut_points, cut_points[1:])
|
| 82 |
-
]
|
| 83 |
-
return [
|
| 84 |
-
ck
|
| 85 |
-
for ck in chunks
|
| 86 |
-
if not ck.only_punctuation()
|
| 87 |
-
]
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def get_split_first_rest(self, mode: SplitMode):
|
| 91 |
-
chunks = self.split_by(mode)
|
| 92 |
-
fisrt_chunk = chunks[0] if chunks else self
|
| 93 |
-
rest_chunks = chunks[1:] if chunks else None
|
| 94 |
-
return fisrt_chunk, rest_chunks
|
| 95 |
-
|
| 96 |
-
def puncation_numbers(self) -> int:
|
| 97 |
-
"""计算片段中标点符号的数量"""
|
| 98 |
-
return sum(1 for seg in self.items if seg.is_punctuation())
|
| 99 |
-
|
| 100 |
-
def length(self) -> int:
|
| 101 |
-
"""返回片段列表的长度"""
|
| 102 |
-
return len(self.items)
|
| 103 |
-
|
| 104 |
-
def join(self) -> str:
|
| 105 |
-
"""将片段连接为一个字符串"""
|
| 106 |
-
return self.separator.join(seg.text for seg in self.items)
|
| 107 |
-
|
| 108 |
-
def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
|
| 109 |
-
"""比较当前片段与另一个片段的相似度"""
|
| 110 |
-
if not chunk:
|
| 111 |
-
return 0
|
| 112 |
-
|
| 113 |
-
score = self._calculate_similarity(self.join(), chunk.join())
|
| 114 |
-
# logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}")
|
| 115 |
-
return score
|
| 116 |
-
|
| 117 |
-
def only_punctuation(self)->bool:
|
| 118 |
-
return all(seg.is_punctuation() for seg in self.items)
|
| 119 |
-
|
| 120 |
-
def has_punctuation(self) -> bool:
|
| 121 |
-
return any(seg.is_punctuation() for seg in self.items)
|
| 122 |
-
|
| 123 |
-
def get_buffer_index(self) -> int:
|
| 124 |
-
return self.items[-1].buffer_index()
|
| 125 |
-
|
| 126 |
-
def is_end_sentence(self) ->bool:
|
| 127 |
-
return self.items[-1].is_end()
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
class TranscriptHistory:
|
| 131 |
-
"""管理转录片段的历史记录"""
|
| 132 |
-
|
| 133 |
-
def __init__(self) -> None:
|
| 134 |
-
self.history = collections.deque(maxlen=2) # 存储最近的两个片段
|
| 135 |
-
|
| 136 |
-
def add(self, chunk: TranscriptChunk):
|
| 137 |
-
"""添加新的片段到历史记录"""
|
| 138 |
-
self.history.appendleft(chunk)
|
| 139 |
-
|
| 140 |
-
def previous_chunk(self) -> Optional[TranscriptChunk]:
|
| 141 |
-
"""获取上一个片段(如果存在)"""
|
| 142 |
-
return self.history[1] if len(self.history) == 2 else None
|
| 143 |
-
|
| 144 |
-
def lastest_chunk(self):
|
| 145 |
-
"""获取最后一个片段"""
|
| 146 |
-
return self.history[-1]
|
| 147 |
-
|
| 148 |
-
def clear(self):
|
| 149 |
-
self.history.clear()
|
| 150 |
-
|
| 151 |
-
class TranscriptBuffer:
|
| 152 |
-
"""
|
| 153 |
-
管理转录文本的分级结构:临时字符串 -> 短句 -> 完整段落
|
| 154 |
-
|
| 155 |
-
|-- 已确认文本 --|-- 观察窗口 --|-- 新输入 --|
|
| 156 |
-
|
| 157 |
-
管理 pending -> line -> paragraph 的缓冲逻辑
|
| 158 |
-
|
| 159 |
-
"""
|
| 160 |
-
|
| 161 |
-
def __init__(self, source_lang:str, separator:str):
|
| 162 |
-
self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
|
| 163 |
-
self._sentences: List[str] = collections.deque() # 当前段落中的短句
|
| 164 |
-
self._buffer: str = "" # 当前缓冲中的文本
|
| 165 |
-
self._current_seg_id: int = 0
|
| 166 |
-
self.source_language = source_lang
|
| 167 |
-
self._separator = separator
|
| 168 |
-
|
| 169 |
-
def get_seg_id(self) -> int:
|
| 170 |
-
return self._current_seg_id
|
| 171 |
-
|
| 172 |
-
@property
|
| 173 |
-
def current_sentences_length(self) -> int:
|
| 174 |
-
count = 0
|
| 175 |
-
for item in self._sentences:
|
| 176 |
-
if self._separator:
|
| 177 |
-
count += len(item.split(self._separator))
|
| 178 |
-
else:
|
| 179 |
-
count += len(item)
|
| 180 |
-
return count
|
| 181 |
-
|
| 182 |
-
def update_pending_text(self, text: str) -> None:
|
| 183 |
-
"""更新临时缓冲字符串"""
|
| 184 |
-
self._buffer = text
|
| 185 |
-
|
| 186 |
-
def commit_line(self,) -> None:
|
| 187 |
-
"""将缓冲字符串提交为短句"""
|
| 188 |
-
if self._buffer:
|
| 189 |
-
self._sentences.append(self._buffer)
|
| 190 |
-
self._buffer = ""
|
| 191 |
-
|
| 192 |
-
def commit_paragraph(self) -> None:
|
| 193 |
-
"""
|
| 194 |
-
提交当前短句为完整段落(如句子结束)
|
| 195 |
-
|
| 196 |
-
Args:
|
| 197 |
-
end_of_sentence: 是否为句子结尾(如检测到句号)
|
| 198 |
-
"""
|
| 199 |
-
|
| 200 |
-
count = 0
|
| 201 |
-
current_sentences = []
|
| 202 |
-
while len(self._sentences): # and count < 20:
|
| 203 |
-
item = self._sentences.popleft()
|
| 204 |
-
current_sentences.append(item)
|
| 205 |
-
if self._separator:
|
| 206 |
-
count += len(item.split(self._separator))
|
| 207 |
-
else:
|
| 208 |
-
count += len(item)
|
| 209 |
-
if current_sentences:
|
| 210 |
-
self._segments.append("".join(current_sentences))
|
| 211 |
-
logger.debug(f"=== count to paragraph ===")
|
| 212 |
-
logger.debug(f"push: {current_sentences}")
|
| 213 |
-
logger.debug(f"rest: {self._sentences}")
|
| 214 |
-
# if self._sentences:
|
| 215 |
-
# self._segments.append("".join(self._sentences))
|
| 216 |
-
# self._sentences.clear()
|
| 217 |
-
|
| 218 |
-
def rebuild(self, text):
|
| 219 |
-
output = self.split_and_join(
|
| 220 |
-
text.replace(
|
| 221 |
-
self._separator, ""))
|
| 222 |
-
|
| 223 |
-
logger.debug("==== rebuild string ====")
|
| 224 |
-
logger.debug(text)
|
| 225 |
-
logger.debug(output)
|
| 226 |
-
|
| 227 |
-
return output
|
| 228 |
-
|
| 229 |
-
@staticmethod
|
| 230 |
-
def split_and_join(text):
|
| 231 |
-
tokens = []
|
| 232 |
-
word_buf = ''
|
| 233 |
-
|
| 234 |
-
for char in text:
|
| 235 |
-
if char in ALL_MARKERS:
|
| 236 |
-
if word_buf:
|
| 237 |
-
tokens.extend(wordninja.split(word_buf))
|
| 238 |
-
word_buf = ''
|
| 239 |
-
tokens.append(char)
|
| 240 |
-
else:
|
| 241 |
-
word_buf += char
|
| 242 |
-
if word_buf:
|
| 243 |
-
tokens.extend(wordninja.split(word_buf))
|
| 244 |
-
|
| 245 |
-
output = ''
|
| 246 |
-
for i, token in enumerate(tokens):
|
| 247 |
-
if i == 0:
|
| 248 |
-
output += token
|
| 249 |
-
elif token in ALL_MARKERS:
|
| 250 |
-
output += (token + " ")
|
| 251 |
-
else:
|
| 252 |
-
output += ' ' + token
|
| 253 |
-
return output
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def update_and_commit(self, stable_strings: List[str], remaining_strings:List[str], is_end_sentence=False):
|
| 257 |
-
if self.source_language == "en":
|
| 258 |
-
stable_strings = [self.rebuild(i) for i in stable_strings]
|
| 259 |
-
remaining_strings =[self.rebuild(i) for i in remaining_strings]
|
| 260 |
-
remaining_string = "".join(remaining_strings)
|
| 261 |
-
|
| 262 |
-
logger.debug(f"{self.__dict__}")
|
| 263 |
-
if is_end_sentence:
|
| 264 |
-
for stable_str in stable_strings:
|
| 265 |
-
self.update_pending_text(stable_str)
|
| 266 |
-
self.commit_line()
|
| 267 |
-
|
| 268 |
-
current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)
|
| 269 |
-
# current_text_len = len(self.current_not_commit_text.split(self._separator))
|
| 270 |
-
self.update_pending_text(remaining_string)
|
| 271 |
-
if current_text_len >= config.TEXT_THREHOLD:
|
| 272 |
-
self.commit_paragraph()
|
| 273 |
-
self._current_seg_id += 1
|
| 274 |
-
return True
|
| 275 |
-
else:
|
| 276 |
-
for stable_str in stable_strings:
|
| 277 |
-
self.update_pending_text(stable_str)
|
| 278 |
-
self.commit_line()
|
| 279 |
-
self.update_pending_text(remaining_string)
|
| 280 |
-
return False
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
@property
|
| 284 |
-
def un_commit_paragraph(self) -> str:
|
| 285 |
-
"""当前短句组合"""
|
| 286 |
-
return "".join([i for i in self._sentences])
|
| 287 |
-
|
| 288 |
-
@property
|
| 289 |
-
def pending_text(self) -> str:
|
| 290 |
-
"""当前缓冲内容"""
|
| 291 |
-
return self._buffer
|
| 292 |
-
|
| 293 |
-
@property
|
| 294 |
-
def latest_paragraph(self) -> str:
|
| 295 |
-
"""最新确认的段落"""
|
| 296 |
-
return self._segments[-1] if self._segments else ""
|
| 297 |
-
|
| 298 |
-
@property
|
| 299 |
-
def current_not_commit_text(self) -> str:
|
| 300 |
-
return self.un_commit_paragraph + self.pending_text
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
class TranscriptStabilityAnalyzer:
|
| 305 |
-
def __init__(self, source_lang, separator) -> None:
|
| 306 |
-
self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator)
|
| 307 |
-
self._transcript_history = TranscriptHistory()
|
| 308 |
-
self._separator = separator
|
| 309 |
-
logger.debug(f"Current separator: {self._separator}")
|
| 310 |
-
|
| 311 |
-
def merge_chunks(self, chunks: List[TranscriptChunk])->str:
|
| 312 |
-
if not chunks:
|
| 313 |
-
return [""]
|
| 314 |
-
output = list(r.join() for r in chunks if r)
|
| 315 |
-
return output
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
def analysis(self, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]:
|
| 319 |
-
current = TranscriptChunk(items=current, separator=self._separator)
|
| 320 |
-
self._transcript_history.add(current)
|
| 321 |
-
|
| 322 |
-
prev = self._transcript_history.previous_chunk()
|
| 323 |
-
self._transcript_buffer.update_pending_text(current.join())
|
| 324 |
-
if not prev: # 如果没有历史记录 那么就说明是新的语句 直接输出就行
|
| 325 |
-
yield TranscriptResult(
|
| 326 |
-
context=self._transcript_buffer.current_not_commit_text,
|
| 327 |
-
seg_id=self._transcript_buffer.get_seg_id()
|
| 328 |
-
)
|
| 329 |
-
return
|
| 330 |
-
|
| 331 |
-
# yield from self._handle_short_buffer(current, prev)
|
| 332 |
-
if buffer_duration <= 4:
|
| 333 |
-
yield from self._handle_short_buffer(current, prev)
|
| 334 |
-
else:
|
| 335 |
-
yield from self._handle_long_buffer(current)
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
|
| 339 |
-
curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
|
| 340 |
-
prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
|
| 341 |
-
|
| 342 |
-
# logger.debug("==== Current cut item ====")
|
| 343 |
-
# logger.debug(f"{curr.join()} ")
|
| 344 |
-
# logger.debug(f"{prev.join()}")
|
| 345 |
-
# logger.debug("==========================")
|
| 346 |
-
|
| 347 |
-
if curr_first and prev_first:
|
| 348 |
-
|
| 349 |
-
core = curr_first.compare(prev_first)
|
| 350 |
-
has_punctuation = curr_first.has_punctuation()
|
| 351 |
-
if core >= 0.8 and has_punctuation:
|
| 352 |
-
yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
|
| 353 |
-
return
|
| 354 |
-
|
| 355 |
-
yield TranscriptResult(
|
| 356 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
| 357 |
-
context=self._transcript_buffer.current_not_commit_text
|
| 358 |
-
)
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
|
| 362 |
-
chunks = curr.split_by(SplitMode.PUNCTUATION)
|
| 363 |
-
if len(chunks) > 1:
|
| 364 |
-
stable, remaining = chunks[:-1], chunks[-1:]
|
| 365 |
-
# stable_str = self.merge_chunks(stable)
|
| 366 |
-
# remaining_str = self.merge_chunks(remaining)
|
| 367 |
-
yield from self._yield_commit_results(
|
| 368 |
-
stable, remaining, is_end_sentence=True # 暂时硬编码为True
|
| 369 |
-
)
|
| 370 |
-
else:
|
| 371 |
-
yield TranscriptResult(
|
| 372 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
| 373 |
-
context=self._transcript_buffer.current_not_commit_text
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
|
| 378 |
-
stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
|
| 379 |
-
remaining_str_list = self.merge_chunks(remaining_chunks)
|
| 380 |
-
frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
|
| 381 |
-
|
| 382 |
-
prev_seg_id = self._transcript_buffer.get_seg_id()
|
| 383 |
-
commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
|
| 384 |
-
logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
|
| 385 |
-
|
| 386 |
-
if commit_paragraph:
|
| 387 |
-
# 表示生成了一个新段落 换行
|
| 388 |
-
yield TranscriptResult(
|
| 389 |
-
seg_id=prev_seg_id,
|
| 390 |
-
cut_index=frame_cut_index,
|
| 391 |
-
context=self._transcript_buffer.latest_paragraph,
|
| 392 |
-
is_end_sentence=True
|
| 393 |
-
)
|
| 394 |
-
if (context := self._transcript_buffer.current_not_commit_text.strip()):
|
| 395 |
-
yield TranscriptResult(
|
| 396 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
| 397 |
-
context=context,
|
| 398 |
-
)
|
| 399 |
-
else:
|
| 400 |
-
yield TranscriptResult(
|
| 401 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
| 402 |
-
cut_index=frame_cut_index,
|
| 403 |
-
context=self._transcript_buffer.current_not_commit_text,
|
| 404 |
-
)
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/transcription.py
DELETED
|
@@ -1,334 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import time
|
| 3 |
-
import functools
|
| 4 |
-
import json
|
| 5 |
-
import logging
|
| 6 |
-
import time
|
| 7 |
-
from enum import Enum
|
| 8 |
-
from typing import List, Optional
|
| 9 |
-
import numpy as np
|
| 10 |
-
from .server import ServeClientBase
|
| 11 |
-
from .whisper_llm_serve import PyWhiperCppServe
|
| 12 |
-
from .vad import VoiceActivityDetector
|
| 13 |
-
from urllib.parse import urlparse, parse_qsl
|
| 14 |
-
from websockets.exceptions import ConnectionClosed
|
| 15 |
-
from websockets.sync.server import serve
|
| 16 |
-
from uuid import uuid1
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
logging.basicConfig(level=logging.INFO)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class ClientManager:
|
| 23 |
-
def __init__(self, max_clients=4, max_connection_time=600):
|
| 24 |
-
"""
|
| 25 |
-
Initializes the ClientManager with specified limits on client connections and connection durations.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
|
| 29 |
-
max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
|
| 30 |
-
to 600 seconds (10 minutes).
|
| 31 |
-
"""
|
| 32 |
-
self.clients = {}
|
| 33 |
-
self.start_times = {}
|
| 34 |
-
self.max_clients = max_clients
|
| 35 |
-
self.max_connection_time = max_connection_time
|
| 36 |
-
|
| 37 |
-
def add_client(self, websocket, client):
|
| 38 |
-
"""
|
| 39 |
-
Adds a client and their connection start time to the tracking dictionaries.
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
websocket: The websocket associated with the client to add.
|
| 43 |
-
client: The client object to be added and tracked.
|
| 44 |
-
"""
|
| 45 |
-
self.clients[websocket] = client
|
| 46 |
-
self.start_times[websocket] = time.time()
|
| 47 |
-
|
| 48 |
-
def get_client(self, websocket):
|
| 49 |
-
"""
|
| 50 |
-
Retrieves a client associated with the given websocket.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
websocket: The websocket associated with the client to retrieve.
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
The client object if found, False otherwise.
|
| 57 |
-
"""
|
| 58 |
-
if websocket in self.clients:
|
| 59 |
-
return self.clients[websocket]
|
| 60 |
-
return False
|
| 61 |
-
|
| 62 |
-
def remove_client(self, websocket):
|
| 63 |
-
"""
|
| 64 |
-
Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
|
| 65 |
-
client if necessary.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
websocket: The websocket associated with the client to be removed.
|
| 69 |
-
"""
|
| 70 |
-
client = self.clients.pop(websocket, None)
|
| 71 |
-
if client:
|
| 72 |
-
client.cleanup()
|
| 73 |
-
self.start_times.pop(websocket, None)
|
| 74 |
-
|
| 75 |
-
def get_wait_time(self):
|
| 76 |
-
"""
|
| 77 |
-
Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
|
| 78 |
-
|
| 79 |
-
Returns:
|
| 80 |
-
The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
|
| 81 |
-
"""
|
| 82 |
-
wait_time = None
|
| 83 |
-
for start_time in self.start_times.values():
|
| 84 |
-
current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
|
| 85 |
-
if wait_time is None or current_client_time_remaining < wait_time:
|
| 86 |
-
wait_time = current_client_time_remaining
|
| 87 |
-
return wait_time / 60 if wait_time is not None else 0
|
| 88 |
-
|
| 89 |
-
def is_server_full(self, websocket, options):
|
| 90 |
-
"""
|
| 91 |
-
Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
|
| 92 |
-
|
| 93 |
-
Args:
|
| 94 |
-
websocket: The websocket of the client attempting to connect.
|
| 95 |
-
options: A dictionary of options that may include the client's unique identifier.
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
True if the server is full, False otherwise.
|
| 99 |
-
"""
|
| 100 |
-
if len(self.clients) >= self.max_clients:
|
| 101 |
-
wait_time = self.get_wait_time()
|
| 102 |
-
response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
|
| 103 |
-
websocket.send(json.dumps(response))
|
| 104 |
-
return True
|
| 105 |
-
return False
|
| 106 |
-
|
| 107 |
-
def is_client_timeout(self, websocket):
|
| 108 |
-
"""
|
| 109 |
-
Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
|
| 110 |
-
|
| 111 |
-
Args:
|
| 112 |
-
websocket: The websocket associated with the client to check.
|
| 113 |
-
|
| 114 |
-
Returns:
|
| 115 |
-
True if the client's connection time has exceeded the maximum limit, False otherwise.
|
| 116 |
-
"""
|
| 117 |
-
elapsed_time = time.time() - self.start_times[websocket]
|
| 118 |
-
if elapsed_time >= self.max_connection_time:
|
| 119 |
-
self.clients[websocket].disconnect()
|
| 120 |
-
logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
|
| 121 |
-
return True
|
| 122 |
-
return False
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class BackendType(Enum):
|
| 126 |
-
PYWHISPERCPP = "pywhispercpp"
|
| 127 |
-
|
| 128 |
-
@staticmethod
|
| 129 |
-
def valid_types() -> List[str]:
|
| 130 |
-
return [backend_type.value for backend_type in BackendType]
|
| 131 |
-
|
| 132 |
-
@staticmethod
|
| 133 |
-
def is_valid(backend: str) -> bool:
|
| 134 |
-
return backend in BackendType.valid_types()
|
| 135 |
-
|
| 136 |
-
def is_pywhispercpp(self) -> bool:
|
| 137 |
-
return self == BackendType.PYWHISPERCPP
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
class TranscriptionServer:
|
| 141 |
-
RATE = 16000
|
| 142 |
-
|
| 143 |
-
def __init__(self):
|
| 144 |
-
self.client_manager = None
|
| 145 |
-
self.no_voice_activity_chunks = 0
|
| 146 |
-
self.single_model = False
|
| 147 |
-
|
| 148 |
-
def initialize_client(
|
| 149 |
-
self, websocket, options
|
| 150 |
-
):
|
| 151 |
-
client: Optional[ServeClientBase] = None
|
| 152 |
-
|
| 153 |
-
if self.backend.is_pywhispercpp():
|
| 154 |
-
client = PyWhiperCppServe(
|
| 155 |
-
websocket,
|
| 156 |
-
language=options["language"],
|
| 157 |
-
client_uid=options["uid"],
|
| 158 |
-
)
|
| 159 |
-
logging.info("Running pywhispercpp backend.")
|
| 160 |
-
|
| 161 |
-
if client is None:
|
| 162 |
-
raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
|
| 163 |
-
|
| 164 |
-
self.client_manager.add_client(websocket, client)
|
| 165 |
-
|
| 166 |
-
def get_audio_from_websocket(self, websocket):
|
| 167 |
-
"""
|
| 168 |
-
Receives audio buffer from websocket and creates a numpy array out of it.
|
| 169 |
-
|
| 170 |
-
Args:
|
| 171 |
-
websocket: The websocket to receive audio from.
|
| 172 |
-
|
| 173 |
-
Returns:
|
| 174 |
-
A numpy array containing the audio.
|
| 175 |
-
"""
|
| 176 |
-
frame_data = websocket.recv()
|
| 177 |
-
if frame_data == b"END_OF_AUDIO":
|
| 178 |
-
return False
|
| 179 |
-
return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0
|
| 180 |
-
# return np.frombuffer(frame_data, dtype=np.float32)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
def handle_new_connection(self, websocket):
|
| 184 |
-
query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
|
| 185 |
-
from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
|
| 186 |
-
|
| 187 |
-
try:
|
| 188 |
-
logging.info("New client connected")
|
| 189 |
-
options = websocket.recv()
|
| 190 |
-
try:
|
| 191 |
-
options = json.loads(options)
|
| 192 |
-
except Exception as e:
|
| 193 |
-
options = {"language": from_lang, "uid": str(uuid1())}
|
| 194 |
-
if self.client_manager is None:
|
| 195 |
-
max_clients = options.get('max_clients', 4)
|
| 196 |
-
max_connection_time = options.get('max_connection_time', 600)
|
| 197 |
-
self.client_manager = ClientManager(max_clients, max_connection_time)
|
| 198 |
-
|
| 199 |
-
if self.client_manager.is_server_full(websocket, options):
|
| 200 |
-
websocket.close()
|
| 201 |
-
return False # Indicates that the connection should not continue
|
| 202 |
-
|
| 203 |
-
if self.backend.is_pywhispercpp():
|
| 204 |
-
self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
|
| 205 |
-
|
| 206 |
-
self.initialize_client(websocket, options)
|
| 207 |
-
if from_lang and to_lang:
|
| 208 |
-
self.set_lang(websocket, from_lang, to_lang)
|
| 209 |
-
logging.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 210 |
-
return True
|
| 211 |
-
except json.JSONDecodeError:
|
| 212 |
-
logging.error("Failed to decode JSON from client")
|
| 213 |
-
return False
|
| 214 |
-
except ConnectionClosed:
|
| 215 |
-
logging.info("Connection closed by client")
|
| 216 |
-
return False
|
| 217 |
-
except Exception as e:
|
| 218 |
-
logging.error(f"Error during new connection initialization: {str(e)}")
|
| 219 |
-
return False
|
| 220 |
-
|
| 221 |
-
def process_audio_frames(self, websocket):
|
| 222 |
-
frame_np = self.get_audio_from_websocket(websocket)
|
| 223 |
-
client = self.client_manager.get_client(websocket)
|
| 224 |
-
|
| 225 |
-
# TODO Vad has some problem, it will be blocking process loop
|
| 226 |
-
# if frame_np is False:
|
| 227 |
-
# if self.backend.is_pywhispercpp():
|
| 228 |
-
# client.set_eos(True)
|
| 229 |
-
# return False
|
| 230 |
-
|
| 231 |
-
# if self.backend.is_pywhispercpp():
|
| 232 |
-
# voice_active = self.voice_activity(websocket, frame_np)
|
| 233 |
-
# if voice_active:
|
| 234 |
-
# self.no_voice_activity_chunks = 0
|
| 235 |
-
# client.set_eos(False)
|
| 236 |
-
# if self.use_vad and not voice_active:
|
| 237 |
-
# return True
|
| 238 |
-
|
| 239 |
-
client.add_frames(frame_np)
|
| 240 |
-
return True
|
| 241 |
-
|
| 242 |
-
def set_lang(self, websocket, src_lang, dst_lang):
|
| 243 |
-
client = self.client_manager.get_client(websocket)
|
| 244 |
-
if isinstance(client, PyWhiperCppServe):
|
| 245 |
-
client.set_lang(src_lang, dst_lang)
|
| 246 |
-
|
| 247 |
-
def recv_audio(self,
|
| 248 |
-
websocket,
|
| 249 |
-
backend: BackendType = BackendType.PYWHISPERCPP):
|
| 250 |
-
|
| 251 |
-
self.backend = backend
|
| 252 |
-
if not self.handle_new_connection(websocket):
|
| 253 |
-
return
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
try:
|
| 257 |
-
while not self.client_manager.is_client_timeout(websocket):
|
| 258 |
-
if not self.process_audio_frames(websocket):
|
| 259 |
-
break
|
| 260 |
-
except ConnectionClosed:
|
| 261 |
-
logging.info("Connection closed by client")
|
| 262 |
-
except Exception as e:
|
| 263 |
-
logging.error(f"Unexpected error: {str(e)}")
|
| 264 |
-
finally:
|
| 265 |
-
if self.client_manager.get_client(websocket):
|
| 266 |
-
self.cleanup(websocket)
|
| 267 |
-
websocket.close()
|
| 268 |
-
del websocket
|
| 269 |
-
|
| 270 |
-
def run(self,
|
| 271 |
-
host,
|
| 272 |
-
port=9090,
|
| 273 |
-
backend="pywhispercpp"):
|
| 274 |
-
"""
|
| 275 |
-
Run the transcription server.
|
| 276 |
-
|
| 277 |
-
Args:
|
| 278 |
-
host (str): The host address to bind the server.
|
| 279 |
-
port (int): The port number to bind the server.
|
| 280 |
-
"""
|
| 281 |
-
|
| 282 |
-
if not BackendType.is_valid(backend):
|
| 283 |
-
raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
|
| 284 |
-
|
| 285 |
-
with serve(
|
| 286 |
-
functools.partial(
|
| 287 |
-
self.recv_audio,
|
| 288 |
-
backend=BackendType(backend),
|
| 289 |
-
),
|
| 290 |
-
host,
|
| 291 |
-
port
|
| 292 |
-
) as server:
|
| 293 |
-
server.serve_forever()
|
| 294 |
-
|
| 295 |
-
def voice_activity(self, websocket, frame_np):
|
| 296 |
-
"""
|
| 297 |
-
Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
|
| 298 |
-
|
| 299 |
-
This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
|
| 300 |
-
contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
|
| 301 |
-
it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
|
| 302 |
-
speech detection to improve subsequent processing steps.
|
| 303 |
-
|
| 304 |
-
Args:
|
| 305 |
-
websocket: The websocket associated with the current client. Used to retrieve the client object
|
| 306 |
-
from the client manager for state management.
|
| 307 |
-
frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
|
| 308 |
-
the audio data for the current frame.
|
| 309 |
-
|
| 310 |
-
Returns:
|
| 311 |
-
bool: True if voice activity is detected in the current frame, False otherwise. When returning False
|
| 312 |
-
after detecting no voice activity for more than three consecutive frames, it also triggers the
|
| 313 |
-
end-of-speech (EOS) flag for the client.
|
| 314 |
-
"""
|
| 315 |
-
if not self.vad_detector(frame_np):
|
| 316 |
-
self.no_voice_activity_chunks += 1
|
| 317 |
-
if self.no_voice_activity_chunks > 3:
|
| 318 |
-
client = self.client_manager.get_client(websocket)
|
| 319 |
-
if not client.eos:
|
| 320 |
-
client.set_eos(True)
|
| 321 |
-
time.sleep(0.1) # Sleep 100m; wait some voice activity.
|
| 322 |
-
return False
|
| 323 |
-
return True
|
| 324 |
-
|
| 325 |
-
def cleanup(self, websocket):
|
| 326 |
-
"""
|
| 327 |
-
Cleans up resources associated with a given client's websocket.
|
| 328 |
-
|
| 329 |
-
Args:
|
| 330 |
-
websocket: The websocket associated with the client to be cleaned up.
|
| 331 |
-
"""
|
| 332 |
-
if self.client_manager.get_client(websocket):
|
| 333 |
-
self.client_manager.remove_client(websocket)
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/translatepipes.py
CHANGED
|
@@ -14,7 +14,6 @@ class TranslatePipes:
|
|
| 14 |
|
| 15 |
# llm 翻译
|
| 16 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
| 17 |
-
|
| 18 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
| 19 |
# vad
|
| 20 |
self._vad_pipe = self._launch_process(VadPipe())
|
|
|
|
| 14 |
|
| 15 |
# llm 翻译
|
| 16 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
|
|
|
| 17 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
| 18 |
# vad
|
| 19 |
self._vad_pipe = self._launch_process(VadPipe())
|
transcribe/whisper_llm_serve.py
CHANGED
|
@@ -4,7 +4,6 @@ import queue
|
|
| 4 |
import threading
|
| 5 |
import time
|
| 6 |
from logging import getLogger
|
| 7 |
-
from typing import List, Optional, Iterator, Tuple, Any
|
| 8 |
import asyncio
|
| 9 |
import numpy as np
|
| 10 |
import config
|
|
@@ -45,8 +44,6 @@ class WhisperTranscriptionService:
|
|
| 45 |
self.sample_rate = 16000
|
| 46 |
|
| 47 |
self.lock = threading.Lock()
|
| 48 |
-
|
| 49 |
-
|
| 50 |
# 文本分隔符,根据语言设置
|
| 51 |
self.text_separator = self._get_text_separator(language)
|
| 52 |
self.loop = asyncio.get_event_loop()
|
|
@@ -54,7 +51,7 @@ class WhisperTranscriptionService:
|
|
| 54 |
# 原始音频队列
|
| 55 |
self._frame_queue = queue.Queue()
|
| 56 |
# 音频队列缓冲区
|
| 57 |
-
self.frames_np =
|
| 58 |
# 完整音频队列
|
| 59 |
self.segments_queue = collections.deque()
|
| 60 |
self._temp_string = ""
|
|
@@ -100,21 +97,6 @@ class WhisperTranscriptionService:
|
|
| 100 |
"""根据语言返回适当的文本分隔符"""
|
| 101 |
return "" if language == "zh" else " "
|
| 102 |
|
| 103 |
-
async def send_ready_state(self) -> None:
|
| 104 |
-
"""发送服务就绪状态消息"""
|
| 105 |
-
await self.websocket.send(json.dumps({
|
| 106 |
-
"uid": self.client_uid,
|
| 107 |
-
"message": self.SERVER_READY,
|
| 108 |
-
"backend": "whisper_transcription"
|
| 109 |
-
}))
|
| 110 |
-
|
| 111 |
-
def set_language(self, source_lang: str, target_lang: str) -> None:
|
| 112 |
-
"""设置源语言和目标语言"""
|
| 113 |
-
self.source_language = source_lang
|
| 114 |
-
self.target_language = target_lang
|
| 115 |
-
self.text_separator = self._get_text_separator(source_lang)
|
| 116 |
-
# self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
|
| 117 |
-
|
| 118 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
| 119 |
"""添加音频帧到处理队列"""
|
| 120 |
self._frame_queue.put(frame_np)
|
|
@@ -135,60 +117,35 @@ class WhisperTranscriptionService:
|
|
| 135 |
if frame_np is None or len(frame_np) == 0:
|
| 136 |
continue
|
| 137 |
with self.lock:
|
| 138 |
-
|
| 139 |
-
self.frames_np = frame_np.copy()
|
| 140 |
-
else:
|
| 141 |
-
self.frames_np = np.append(self.frames_np, frame_np)
|
| 142 |
if speech_status == "END" and len(self.frames_np) > 0:
|
| 143 |
self.segments_queue.appendleft(self.frames_np.copy())
|
| 144 |
self.frames_np = np.array([], dtype=np.float32)
|
| 145 |
except queue.Empty:
|
| 146 |
pass
|
| 147 |
|
| 148 |
-
def _process_transcription_results_2(self, seg_text:str,partial):
|
| 149 |
-
|
| 150 |
-
item = TransResult(
|
| 151 |
-
seg_id=self.row_number,
|
| 152 |
-
context=seg_text,
|
| 153 |
-
from_=self.source_language,
|
| 154 |
-
to=self.target_language,
|
| 155 |
-
tran_content=self._translate_text_large(seg_text),
|
| 156 |
-
partial=partial
|
| 157 |
-
)
|
| 158 |
-
if partial == False:
|
| 159 |
-
self.row_number += 1
|
| 160 |
-
return item
|
| 161 |
-
|
| 162 |
def _transcription_processing_loop(self) -> None:
|
| 163 |
"""主转录处理循环"""
|
| 164 |
frame_epoch = 1
|
| 165 |
while not self._translate_thread_stop.is_set():
|
| 166 |
-
|
| 167 |
-
if self.frames_np is None:
|
| 168 |
-
time.sleep(0.01)
|
| 169 |
-
continue
|
| 170 |
-
|
| 171 |
|
| 172 |
-
if len(self.
|
| 173 |
-
audio_buffer = self.segments_queue.pop()
|
| 174 |
-
partial = False
|
| 175 |
-
else:
|
| 176 |
-
with self.lock:
|
| 177 |
-
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
| 178 |
-
partial = True
|
| 179 |
-
|
| 180 |
-
if len(audio_buffer) ==0:
|
| 181 |
time.sleep(0.01)
|
| 182 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if len(audio_buffer) < int(self.sample_rate):
|
| 185 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
| 186 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
| 187 |
audio_buffer = silence_audio
|
| 188 |
|
| 189 |
-
|
| 190 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
| 191 |
-
# try:
|
| 192 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 193 |
segments = meta_item.segments
|
| 194 |
logger.debug(f"Segments: {segments}")
|
|
@@ -205,22 +162,24 @@ class WhisperTranscriptionService:
|
|
| 205 |
else:
|
| 206 |
self._temp_string = ""
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
result = self._process_transcription_results_2(seg_text, partial)
|
| 210 |
self._send_result_to_client(result)
|
| 211 |
-
time.sleep(0.1)
|
| 212 |
|
| 213 |
if partial == False:
|
| 214 |
frame_epoch = 1
|
| 215 |
else:
|
| 216 |
frame_epoch += 1
|
| 217 |
-
|
| 218 |
-
# for result in self._process_transcription_results(segments, audio_buffer):
|
| 219 |
-
# self._send_result_to_client(result)
|
| 220 |
-
|
| 221 |
-
# except Exception as e:
|
| 222 |
-
# logger.error(f"Error processing audio: {e}")
|
| 223 |
-
|
| 224 |
|
| 225 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
| 226 |
"""转录音频并返回转录片段"""
|
|
@@ -270,51 +229,6 @@ class WhisperTranscriptionService:
|
|
| 270 |
return translated_text
|
| 271 |
|
| 272 |
|
| 273 |
-
|
| 274 |
-
def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
|
| 275 |
-
"""
|
| 276 |
-
处理转录结果,生成翻译结果
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
TransResult对象的迭代器
|
| 280 |
-
"""
|
| 281 |
-
|
| 282 |
-
if not segments:
|
| 283 |
-
return
|
| 284 |
-
start_time = time.perf_counter()
|
| 285 |
-
for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
|
| 286 |
-
if (cut_index :=ana_result.cut_index)>0:
|
| 287 |
-
# 更新音频缓冲区,移除已处理部分
|
| 288 |
-
self._update_audio_buffer(cut_index)
|
| 289 |
-
if ana_result.partial():
|
| 290 |
-
translated_context = self._translate_text(ana_result.context)
|
| 291 |
-
else:
|
| 292 |
-
translated_context = self._translate_text_large(ana_result.context)
|
| 293 |
-
|
| 294 |
-
yield TransResult(
|
| 295 |
-
seg_id=ana_result.seg_id,
|
| 296 |
-
context=ana_result.context,
|
| 297 |
-
from_=self.source_language,
|
| 298 |
-
to=self.target_language,
|
| 299 |
-
tran_content=translated_context,
|
| 300 |
-
partial=ana_result.partial()
|
| 301 |
-
)
|
| 302 |
-
current_time = time.perf_counter()
|
| 303 |
-
time_diff = current_time - start_time
|
| 304 |
-
if config.SAVE_DATA_SAVE:
|
| 305 |
-
self._save_queue.put(DebugResult(
|
| 306 |
-
seg_id=ana_result.seg_id,
|
| 307 |
-
transcrible_time=self._transcrible_time_cost,
|
| 308 |
-
translate_time=self._translate_time_cost,
|
| 309 |
-
context=ana_result.context,
|
| 310 |
-
from_=self.source_language,
|
| 311 |
-
to=self.target_language,
|
| 312 |
-
tran_content=translated_context,
|
| 313 |
-
partial=ana_result.partial()
|
| 314 |
-
))
|
| 315 |
-
log_block("🚦 Traffic times diff", round(time_diff, 2), 's')
|
| 316 |
-
|
| 317 |
-
|
| 318 |
def _send_result_to_client(self, result: TransResult) -> None:
|
| 319 |
"""发送翻译结果到客户端"""
|
| 320 |
try:
|
|
|
|
| 4 |
import threading
|
| 5 |
import time
|
| 6 |
from logging import getLogger
|
|
|
|
| 7 |
import asyncio
|
| 8 |
import numpy as np
|
| 9 |
import config
|
|
|
|
| 44 |
self.sample_rate = 16000
|
| 45 |
|
| 46 |
self.lock = threading.Lock()
|
|
|
|
|
|
|
| 47 |
# 文本分隔符,根据语言设置
|
| 48 |
self.text_separator = self._get_text_separator(language)
|
| 49 |
self.loop = asyncio.get_event_loop()
|
|
|
|
| 51 |
# 原始音频队列
|
| 52 |
self._frame_queue = queue.Queue()
|
| 53 |
# 音频队列缓冲区
|
| 54 |
+
self.frames_np = np.array([], dtype=np.float32)
|
| 55 |
# 完整音频队列
|
| 56 |
self.segments_queue = collections.deque()
|
| 57 |
self._temp_string = ""
|
|
|
|
| 97 |
"""根据语言返回适当的文本分隔符"""
|
| 98 |
return "" if language == "zh" else " "
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
| 101 |
"""添加音频帧到处理队列"""
|
| 102 |
self._frame_queue.put(frame_np)
|
|
|
|
| 117 |
if frame_np is None or len(frame_np) == 0:
|
| 118 |
continue
|
| 119 |
with self.lock:
|
| 120 |
+
self.frames_np = np.append(self.frames_np, frame_np)
|
|
|
|
|
|
|
|
|
|
| 121 |
if speech_status == "END" and len(self.frames_np) > 0:
|
| 122 |
self.segments_queue.appendleft(self.frames_np.copy())
|
| 123 |
self.frames_np = np.array([], dtype=np.float32)
|
| 124 |
except queue.Empty:
|
| 125 |
pass
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
def _transcription_processing_loop(self) -> None:
|
| 128 |
"""主转录处理循环"""
|
| 129 |
frame_epoch = 1
|
| 130 |
while not self._translate_thread_stop.is_set():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
if len(self.frames_np) ==0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
time.sleep(0.01)
|
| 134 |
continue
|
| 135 |
+
with self.lock:
|
| 136 |
+
if len(self.segments_queue) >0:
|
| 137 |
+
audio_buffer = self.segments_queue.pop()
|
| 138 |
+
partial = False
|
| 139 |
+
else:
|
| 140 |
+
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
| 141 |
+
partial = True
|
| 142 |
|
| 143 |
if len(audio_buffer) < int(self.sample_rate):
|
| 144 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
| 145 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
| 146 |
audio_buffer = silence_audio
|
| 147 |
|
|
|
|
| 148 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
|
|
|
| 149 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 150 |
segments = meta_item.segments
|
| 151 |
logger.debug(f"Segments: {segments}")
|
|
|
|
| 162 |
else:
|
| 163 |
self._temp_string = ""
|
| 164 |
|
| 165 |
+
result = TransResult(
|
| 166 |
+
seg_id=self.row_number,
|
| 167 |
+
context=seg_text,
|
| 168 |
+
from_=self.source_language,
|
| 169 |
+
to=self.target_language,
|
| 170 |
+
tran_content=self._translate_text_large(seg_text),
|
| 171 |
+
partial=partial
|
| 172 |
+
)
|
| 173 |
+
if partial == False:
|
| 174 |
+
self.row_number += 1
|
| 175 |
|
|
|
|
| 176 |
self._send_result_to_client(result)
|
|
|
|
| 177 |
|
| 178 |
if partial == False:
|
| 179 |
frame_epoch = 1
|
| 180 |
else:
|
| 181 |
frame_epoch += 1
|
| 182 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
| 185 |
"""转录音频并返回转录片段"""
|
|
|
|
| 229 |
return translated_text
|
| 230 |
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
def _send_result_to_client(self, result: TransResult) -> None:
|
| 233 |
"""发送翻译结果到客户端"""
|
| 234 |
try:
|