|
|
import asyncio |
|
|
import json |
|
|
import queue |
|
|
import threading |
|
|
import time |
|
|
from logging import getLogger |
|
|
from typing import List, Optional, Iterator, Tuple, Any |
|
|
import asyncio |
|
|
import numpy as np |
|
|
import config |
|
|
|
|
|
from api_model import TransResult, Message, DebugResult |
|
|
|
|
|
from .utils import log_block, save_to_wave, TestDataWriter |
|
|
from .translatepipes import TranslatePipes |
|
|
from .strategy import ( |
|
|
TranscriptStabilityAnalyzer, TranscriptToken) |
|
|
from transcribe.helpers.vadprocessor import VadProcessor |
|
|
from transcribe.pipelines import MetaItem |
|
|
|
|
|
logger = getLogger("TranscriptionService") |
|
|
|
|
|
|
|
|
class WhisperTranscriptionService:
    """
    Whisper speech transcription service.

    Consumes raw audio frames pushed from a websocket client, segments
    them with VAD, transcribes the speech via the shared pipeline,
    translates the text, and streams ``TransResult`` messages back to
    the client from background daemon threads.
    """

    # Protocol message constants exchanged with the client.
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
        """
        Args:
            websocket: client connection used to push results.
            pipe: shared transcription/translation pipeline.
            language: source language code (e.g. "zh"); may be None.
            dst_lang: target language code for translation.
            client_uid: identifier echoed back to the client in messages.
        """
        # Use the module logger rather than bare print() for startup tracing.
        logger.info("init service: src_lang=%s dst_lang=%s client=%s",
                    language, dst_lang, client_uid)
        self.source_language = language
        self.target_language = dst_lang
        self.client_uid = client_uid

        self.websocket = websocket
        self._translate_pipe = pipe

        self.sample_rate = 16000

        self.lock = threading.Lock()
        # Raw frames from the client -> VAD; VAD output -> transcription loop.
        self._frame_queue = queue.Queue()
        self._vad_frame_queue = queue.Queue()

        self.text_separator = self._get_text_separator(language)
        # Event loop of the serving thread; used to schedule websocket sends
        # from the worker threads via run_coroutine_threadsafe.
        self.loop = asyncio.get_event_loop()

        # Stability analyzer (TranscriptStabilityAnalyzer) is created
        # elsewhere; None until then.
        self._transcrible_analysis = None

        # Stop flags for the worker threads; set() in stop().
        self._translate_thread_stop = threading.Event()
        self._frame_processing_thread_stop = threading.Event()

        self.translate_thread = self._start_thread(self._transcription_processing_loop)
        self.frame_processing_thread = self._start_thread(self._frame_processing_loop)

        # Chinese speech uses a stricter VAD probability threshold.
        if language == "zh":
            self._vad = VadProcessor(prob_threshold=0.8, silence_s=0.2, cache_s=0.15)
        else:
            self._vad = VadProcessor(prob_threshold=0.7, silence_s=0.2, cache_s=0.15)
        # Monotonically increasing segment id for emitted results.
        self.row_number = 0

        # Timing metrics (seconds) of the most recent transcription /
        # translation, reported through DebugResult in test mode.
        self._transcrible_time_cost = 0.
        self._translate_time_cost = 0.
        if config.TEST:
            self._test_task_stop = threading.Event()
            self._test_queue = queue.Queue()
            self._test_thread = self._start_thread(self.test_data_loop)

    def test_data_loop(self):
        """Drain debug records from the test queue and persist them.

        get() uses a timeout so the loop can observe the stop event; a
        blocking get() would hang this thread forever after stop() when
        the queue is idle.
        """
        writer = TestDataWriter()
        while not self._test_task_stop.is_set():
            try:
                test_data = self._test_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            writer.write(test_data)

    def _start_thread(self, target_function) -> threading.Thread:
        """Start and return a daemon thread running *target_function*."""
        thread = threading.Thread(target=target_function, daemon=True)
        thread.start()
        return thread

    def _get_text_separator(self, language: Optional[str]) -> str:
        """Return the token separator for *language* ("" for Chinese, else a space)."""
        return "" if language == "zh" else " "

    async def send_ready_state(self) -> None:
        """Notify the client that the service is ready to receive audio."""
        await self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "whisper_transcription"
        }))

    def set_language(self, source_lang: str, target_lang: str) -> None:
        """Set the source and target languages and refresh the separator."""
        self.source_language = source_lang
        self.target_language = target_lang
        self.text_separator = self._get_text_separator(source_lang)

    def add_frames(self, frame_np: np.ndarray) -> None:
        """Enqueue an audio frame for VAD processing."""
        self._frame_queue.put(frame_np)

    def _frame_processing_loop(self) -> None:
        """Pull raw frames from the queue, run VAD, and forward speech chunks."""
        while not self._frame_processing_thread_stop.is_set():
            try:
                audio = self._frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            processed_audio = self._vad.process_audio(audio)
            # Only forward chunks in which the VAD actually found speech.
            if processed_audio.shape[0] > 0:
                logger.debug(f"Vad frame: {processed_audio.shape[0]/self.sample_rate:.2f}")
                self._vad_frame_queue.put(processed_audio)

    def _transcribe_audio(self, audio_buffer: np.ndarray) -> MetaItem:
        """Transcribe *audio_buffer* and return the pipeline result.

        Also records the elapsed time in ``_transcrible_time_cost``.
        ("transcrible" is the pipeline's API spelling and is kept as-is.)
        """
        log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
        start_time = time.perf_counter()

        result = self._translate_pipe.transcrible(audio_buffer.tobytes(), self.source_language)
        segments = result.segments
        time_diff = (time.perf_counter() - start_time)
        logger.debug(f"📝 Transcrible Segments: {segments} ")

        log_block("📝 Transcrible output", f"{self.text_separator.join(seg.text for seg in segments)}", "")
        log_block("📝 Transcrible time", f"{time_diff:.3f}", "s")
        self._transcrible_time_cost = round(time_diff, 3)
        return result

    def _translate_text(self, text: str) -> str:
        """Translate *text* to the target language with the fast model.

        Returns "" for blank input without invoking the pipeline.
        """
        if not text.strip():
            return ""

        log_block("🐧 Translation input ", f"{text}")
        start_time = time.perf_counter()

        result = self._translate_pipe.translate(text, self.source_language, self.target_language)
        translated_text = result.translate_content
        time_diff = (time.perf_counter() - start_time)
        log_block("🐧 Translation time ", f"{time_diff:.3f}", "s")
        log_block("🐧 Translation out ", f"{translated_text}")
        self._translate_time_cost = round(time_diff, 3)
        return translated_text

    def _translate_text_large(self, text: str) -> str:
        """Translate *text* to the target language with the large model.

        Returns "" for blank input without invoking the pipeline.
        """
        if not text.strip():
            return ""

        log_block("Translation input", f"{text}")
        start_time = time.perf_counter()

        result = self._translate_pipe.translate_large(text, self.source_language, self.target_language)
        translated_text = result.translate_content
        time_diff = (time.perf_counter() - start_time)
        log_block("Translation large model time ", f"{time_diff:.3f}", "s")
        log_block("Translation large model output", f"{translated_text}")
        self._translate_time_cost = round(time_diff, 3)
        return translated_text

    def _transcription_processing_loop(self) -> None:
        """Main loop: consume VAD speech chunks, transcribe, translate, send.

        get() uses a timeout so the loop can observe the stop event; a
        blocking get() would keep this thread alive after stop() when no
        audio is arriving.
        """
        while not self._translate_thread_stop.is_set():
            try:
                audio_buffer = self._vad_frame_queue.get(timeout=0.2)
            except queue.Empty:
                continue
            if audio_buffer is None:
                continue
            # Whisper degrades on very short clips: left-pad with silence
            # up to one second of audio.
            if len(audio_buffer) < int(self.sample_rate):
                silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
                silence_audio[-len(audio_buffer):] = audio_buffer
                audio_buffer = silence_audio

            logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")

            meta_item = self._transcribe_audio(audio_buffer)
            segments = meta_item.segments
            logger.debug(f"Segments: {segments}")
            if len(segments):
                result = self._process_transcription_results_2(segments)
                self._send_result_to_client(result)
            # Brief pause to avoid hammering the pipeline back-to-back.
            time.sleep(0.1)

    def _process_transcription_results_2(self, segments: List[TranscriptToken]) -> TransResult:
        """Build a final (non-partial) TransResult from *segments*.

        Joins the segment texts with the language-appropriate separator,
        translates the joined text with the large model, and assigns the
        next sequential segment id.
        """
        seg_text = self.text_separator.join(seg.text for seg in segments)
        item = TransResult(
            seg_id=self.row_number,
            context=seg_text,
            from_=self.source_language,
            to=self.target_language,
            tran_content=self._translate_text_large(seg_text),
            partial=False
        )
        self.row_number += 1
        return item

    def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
        """Analyze segment stability and yield partial/final TransResults.

        NOTE(review): this path appears unused by the current processing
        loop (which calls _process_transcription_results_2). It calls
        self._update_audio_buffer, which is not defined in this class,
        and relies on self._transcrible_analysis being initialized
        elsewhere — confirm before re-enabling.

        Yields:
            TransResult objects, partial or final per the analyzer.
        """
        if not segments:
            return
        start_time = time.perf_counter()
        for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
            # A positive cut index means the analyzer committed audio up to
            # that point; drop it from the buffer.
            if (cut_index := ana_result.cut_index) > 0:
                self._update_audio_buffer(cut_index)
            # Partial hypotheses use the fast model; stable text gets the
            # large model for better quality.
            if ana_result.partial():
                translated_context = self._translate_text(ana_result.context)
            else:
                translated_context = self._translate_text_large(ana_result.context)

            yield TransResult(
                seg_id=ana_result.seg_id,
                context=ana_result.context,
                from_=self.source_language,
                to=self.target_language,
                tran_content=translated_context,
                partial=ana_result.partial()
            )
            current_time = time.perf_counter()
            time_diff = current_time - start_time
            if config.TEST:
                self._test_queue.put(DebugResult(
                    seg_id=ana_result.seg_id,
                    transcrible_time=self._transcrible_time_cost,
                    translate_time=self._translate_time_cost,
                    context=ana_result.context,
                    from_=self.source_language,
                    to=self.target_language,
                    tran_content=translated_context,
                    partial=ana_result.partial()
                ))
            log_block("🚦 Traffic times diff", round(time_diff, 2), 's')

    def _send_result_to_client(self, result: TransResult) -> None:
        """Serialize *result* and push it to the client via the event loop.

        Runs on a worker thread, so the websocket send is scheduled onto
        the serving loop with run_coroutine_threadsafe.
        """
        try:
            message = Message(result=result, request_id=self.client_uid).model_dump_json(by_alias=True)
            coro = self.websocket.send_text(message)
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
            future.add_done_callback(self._on_send_done)
        except RuntimeError:
            # Event loop is gone (e.g. already closed): shut down cleanly.
            self.stop()
        except Exception as e:
            logger.error(f"Error sending result to client: {e}")

    def _on_send_done(self, fut) -> None:
        """Stop the service if pushing a message to the client failed.

        Future.exception() raises CancelledError when the future was
        cancelled, so that case is handled explicitly instead of letting
        it escape inside the callback.
        """
        try:
            failed = fut.exception() is not None
        except asyncio.CancelledError:
            failed = True
        if failed:
            self.stop()

    def stop(self) -> None:
        """Signal all worker threads to stop and log the shutdown."""
        self._translate_thread_stop.set()
        self._frame_processing_thread_stop.set()
        if config.TEST:
            self._test_task_stop.set()
        logger.info(f"Stopping transcription service for client: {self.client_uid}")
|
|
|