|
|
import logging |
|
|
import time |
|
|
import functools |
|
|
import json |
|
|
import logging |
|
|
import time |
|
|
from enum import Enum |
|
|
from typing import List, Optional |
|
|
import numpy as np |
|
|
from .server import ServeClientBase |
|
|
from .whisper_llm_serve import PyWhiperCppServe |
|
|
from .vad import VoiceActivityDetector |
|
|
from urllib.parse import urlparse, parse_qsl |
|
|
from websockets.exceptions import ConnectionClosed |
|
|
from websockets.sync.server import serve |
|
|
from uuid import uuid1 |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
|
|
|
class ClientManager:
    """Tracks connected clients by websocket and enforces capacity and session-length limits."""

    def __init__(self, max_clients=4, max_connection_time=600):
        """
        Initializes the ClientManager with limits on client count and connection duration.

        Args:
            max_clients (int, optional): The maximum number of simultaneous client connections allowed.
                Defaults to 4.
            max_connection_time (int, optional): The maximum duration (in seconds) a client can stay
                connected. Defaults to 600 seconds (10 minutes).
        """
        self.clients = {}      # websocket -> client object
        self.start_times = {}  # websocket -> time.time() captured when the client was added
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

    def add_client(self, websocket, client):
        """
        Adds a client and its connection start time to the tracking dictionaries.

        Args:
            websocket: The websocket associated with the client to add.
            client: The client object to be added and tracked.
        """
        self.clients[websocket] = client
        self.start_times[websocket] = time.time()

    def get_client(self, websocket):
        """
        Retrieves the client associated with the given websocket.

        Args:
            websocket: The websocket associated with the client to retrieve.

        Returns:
            The client object if found, False otherwise.
        """
        return self.clients.get(websocket, False)

    def remove_client(self, websocket):
        """
        Stops tracking a client and its start time, calling the client's ``cleanup()`` if it was tracked.

        Args:
            websocket: The websocket associated with the client to be removed.
        """
        client = self.clients.pop(websocket, None)
        if client:
            client.cleanup()
        self.start_times.pop(websocket, None)

    def get_wait_time(self):
        """
        Estimates how long a new client must wait until the next tracked session reaches its time limit.

        Returns:
            The smallest remaining connection time among tracked clients, in minutes. Returns 0 when
            no clients are tracked, and never returns a negative value: a client that is already past
            its limit yields a wait of 0.
        """
        if not self.start_times:
            return 0
        now = time.time()
        soonest_remaining = min(
            self.max_connection_time - (now - started)
            for started in self.start_times.values()
        )
        # A session that is already overdue would produce a negative estimate;
        # report "no wait" instead.
        return max(soonest_remaining, 0) / 60

    def is_server_full(self, websocket, options):
        """
        Checks whether the server is at capacity and, if so, sends a WAIT message to the new client.

        Args:
            websocket: The websocket of the client attempting to connect.
            options: A dictionary of options that may include the client's unique identifier
                under ``"uid"``. A missing uid is reported back as ``null``.

        Returns:
            True if the server is full, False otherwise.
        """
        if len(self.clients) < self.max_clients:
            return False
        response = {
            "uid": options.get("uid"),
            "status": "WAIT",
            "message": self.get_wait_time(),
        }
        websocket.send(json.dumps(response))
        return True

    def is_client_timeout(self, websocket):
        """
        Checks whether a client has exceeded the maximum connection time and disconnects it if so.

        Args:
            websocket: The websocket associated with the client to check. It must currently be tracked.

        Returns:
            True if the client's connection time has reached the maximum limit, False otherwise.

        Raises:
            KeyError: If the websocket is not tracked by this manager.
        """
        elapsed_time = time.time() - self.start_times[websocket]
        if elapsed_time < self.max_connection_time:
            return False
        client = self.clients[websocket]
        client.disconnect()
        logging.warning(f"Client with uid '{client.client_uid}' disconnected due to overtime.")
        return True
|
|
|
|
|
|
|
|
class BackendType(Enum):
    """Enumeration of the transcription backends the server knows how to run."""

    PYWHISPERCPP = "pywhispercpp"

    @staticmethod
    def valid_types() -> List[str]:
        """Return the string value of every member, in declaration order."""
        return [member.value for member in BackendType]

    @staticmethod
    def is_valid(backend: str) -> bool:
        """Tell whether ``backend`` equals the value of one of the members."""
        known_values = BackendType.valid_types()
        return backend in known_values

    def is_pywhispercpp(self) -> bool:
        """Tell whether this member is the pywhispercpp backend."""
        return BackendType.PYWHISPERCPP == self
|
|
|
|
|
|
|
|
class TranscriptionServer:
    """
    Websocket server that accepts client audio streams and forwards them to a transcription client.

    Each connection is handled by ``recv_audio``: it performs the option handshake, registers a
    backend client with the ``ClientManager``, then reads audio frames until the client signals
    END_OF_AUDIO, disconnects, or exceeds its maximum connection time.
    """

    RATE = 16000  # frame rate handed to the VoiceActivityDetector

    def __init__(self):
        """Create a server; the client manager is built lazily when the first client connects."""
        self.client_manager = None
        # Count of consecutive-or-not VAD misses; only incremented in voice_activity().
        self.no_voice_activity_chunks = 0
        # Not read anywhere in this class.
        self.single_model = False

    def initialize_client(self, websocket, options):
        """
        Build the backend client for a connection and register it with the client manager.

        Args:
            websocket: The websocket of the connecting client.
            options (dict): Client options; ``"language"`` and ``"uid"`` are required.

        Raises:
            ValueError: If ``self.backend`` is not a backend this method knows how to build.
        """
        client: Optional[ServeClientBase] = None

        if self.backend.is_pywhispercpp():
            client = PyWhiperCppServe(
                websocket,
                language=options["language"],
                client_uid=options["uid"],
            )
            logging.info("Running pywhispercpp backend.")

        if client is None:
            raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")

        self.client_manager.add_client(websocket, client)

    def get_audio_from_websocket(self, websocket):
        """
        Receives one audio buffer from the websocket and converts it to a numpy array.

        The payload is interpreted as 16-bit integer samples and scaled by 1/32768 into float32.

        Args:
            websocket: The websocket to receive audio from.

        Returns:
            A float32 numpy array containing the audio, or False if the client sent the
            ``b"END_OF_AUDIO"`` sentinel.
        """
        frame_data = websocket.recv()
        if frame_data == b"END_OF_AUDIO":
            return False
        return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0

    @staticmethod
    def _parse_options(raw_options, default_language):
        """Decode the handshake message into an options dict that always has "language" and "uid"."""
        try:
            options = json.loads(raw_options)
        except (ValueError, TypeError):
            # Not JSON (ValueError covers JSONDecodeError and UnicodeDecodeError):
            # fall back to defaults instead of rejecting the client.
            options = {}
        if not isinstance(options, dict):
            options = {}
        # Valid JSON that omits a key gets the same defaults as a non-JSON handshake,
        # so later lookups of options["language"] / options["uid"] cannot fail.
        options.setdefault("language", default_language)
        if "uid" not in options:
            options["uid"] = str(uuid1())
        return options

    def handle_new_connection(self, websocket):
        """
        Perform the handshake for a new connection and register its client.

        Reads the optional ``from`` / ``to`` language parameters from the request query string,
        receives the first message as the client's options, creates the client manager on first
        use, rejects the client if the server is full, and otherwise initializes the client.

        Args:
            websocket: The websocket of the connecting client.

        Returns:
            bool: True if the client was registered, False if the connection was rejected,
            closed, or failed during initialization.
        """
        query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
        from_lang = query_parameters_dict.get('from')
        to_lang = query_parameters_dict.get('to')

        try:
            logging.info("New client connected")
            options = self._parse_options(websocket.recv(), from_lang)

            if self.client_manager is None:
                # The limits are taken from the first client's options and then kept
                # for the lifetime of this server instance.
                max_clients = options.get('max_clients', 4)
                max_connection_time = options.get('max_connection_time', 600)
                self.client_manager = ClientManager(max_clients, max_connection_time)

            if self.client_manager.is_server_full(websocket, options):
                websocket.close()
                return False

            if self.backend.is_pywhispercpp():
                self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)

            self.initialize_client(websocket, options)
            if from_lang and to_lang:
                self.set_lang(websocket, from_lang, to_lang)
                logging.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
            return True
        except ConnectionClosed:
            logging.info("Connection closed by client")
            return False
        except Exception as e:
            logging.error(f"Error during new connection initialization: {str(e)}")
            return False

    def process_audio_frames(self, websocket):
        """
        Receive one audio frame and hand it to the client registered for this websocket.

        Args:
            websocket: The websocket to read from.

        Returns:
            bool: True if a frame was added and reading should continue; False if the client
            sent END_OF_AUDIO or is no longer tracked by the client manager.
        """
        frame_np = self.get_audio_from_websocket(websocket)
        if frame_np is False:
            # END_OF_AUDIO sentinel: stop reading instead of passing False on as audio.
            return False

        client = self.client_manager.get_client(websocket)
        if not client:
            return False

        client.add_frames(frame_np)
        return True

    def set_lang(self, websocket, src_lang, dst_lang):
        """
        Forward source/destination languages to the websocket's client if it is a PyWhiperCppServe.

        Args:
            websocket: The websocket whose client should be configured.
            src_lang: Source language passed through to the client's ``set_lang``.
            dst_lang: Destination language passed through to the client's ``set_lang``.
        """
        client = self.client_manager.get_client(websocket)
        if isinstance(client, PyWhiperCppServe):
            client.set_lang(src_lang, dst_lang)

    def recv_audio(self, websocket, backend: BackendType = BackendType.PYWHISPERCPP):
        """
        Connection handler: set up the client, then read audio frames until the stream ends.

        The loop stops when the client times out, sends END_OF_AUDIO, or the connection
        closes. The client is always cleaned up afterwards if it is still tracked.

        Args:
            websocket: The websocket of the connected client.
            backend (BackendType): Backend used to build the client. Stored on the server instance.
        """
        self.backend = backend
        if not self.handle_new_connection(websocket):
            return

        try:
            while not self.client_manager.is_client_timeout(websocket):
                if not self.process_audio_frames(websocket):
                    break
        except ConnectionClosed:
            logging.info("Connection closed by client")
        except Exception as e:
            logging.error(f"Unexpected error: {str(e)}")
        finally:
            if self.client_manager.get_client(websocket):
                self.cleanup(websocket)
                websocket.close()

    def run(self, host, port=9090, backend="pywhispercpp"):
        """
        Run the transcription server until it is stopped.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
            backend (str): Backend name; must be one of ``BackendType.valid_types()``.

        Raises:
            ValueError: If ``backend`` is not a valid backend type.
        """
        if not BackendType.is_valid(backend):
            raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")

        handler = functools.partial(self.recv_audio, backend=BackendType(backend))
        with serve(handler, host, port) as server:
            server.serve_forever()

    def voice_activity(self, websocket, frame_np):
        """
        Evaluates the voice activity in a given audio frame and manages the end-of-speech state.

        The frame is passed to ``self.vad_detector`` (created in ``handle_new_connection``). Each
        frame without voice increments ``self.no_voice_activity_chunks``; once that counter exceeds
        three, the client's end-of-speech (EOS) flag is set if it is not already, and the call
        sleeps for 100 ms. This method never resets the counter, and nothing else in this class does.

        Args:
            websocket: The websocket associated with the current client. Used to retrieve the client
                object from the client manager.
            frame_np (numpy.ndarray): The audio frame to be analyzed.

        Returns:
            bool: True if voice activity is detected in the current frame, False otherwise.
        """
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > 3:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)
            return False
        return True

    def cleanup(self, websocket):
        """
        Cleans up resources associated with a given client's websocket.

        Args:
            websocket: The websocket associated with the client to be cleaned up.
        """
        if self.client_manager.get_client(websocket):
            self.client_manager.remove_client(websocket)
|
|
|
|
|
|