import collections import logging from difflib import SequenceMatcher from itertools import chain from dataclasses import dataclass, field from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE from enum import Enum import wordninja import config import re logger = logging.getLogger("TranscriptionStrategy") class SplitMode(Enum): PUNCTUATION = "punctuation" PAUSE = "pause" END = "end" @dataclass class TranscriptResult: seg_id: int = 0 cut_index: int = 0 is_end_sentence: bool = False context: str = "" def partial(self): return not self.is_end_sentence @dataclass class TranscriptToken: """表示一个转录片段,包含文本和时间信息""" text: str # 转录的文本内容 t0: int # 开始时间(百分之一秒) t1: int # 结束时间(百分之一秒) def is_punctuation(self): """检查文本是否包含标点符号""" return REGEX_MARKERS.search(self.text.strip()) is not None def is_end(self): """检查文本是否为句子结束标记""" return SENTENCE_END_PATTERN.search(self.text.strip()) is not None def is_pause(self): """检查文本是否为暂停标记""" return PAUSEE_END_PATTERN.search(self.text.strip()) is not None def buffer_index(self) -> int: return int(self.t1 / 100 * SAMPLE_RATE) @dataclass class TranscriptChunk: """表示一组转录片段,支持分割和比较操作""" separator: str = "" # 用于连接片段的分隔符 items: list[TranscriptToken] = field(default_factory=list) # 转录片段列表 @staticmethod def _calculate_similarity(text1: str, text2: str) -> float: """计算两段文本的相似度""" return SequenceMatcher(None, text1, text2).ratio() def split_by(self, mode: SplitMode) -> list['TranscriptChunk']: """根据文本中的标点符号分割片段列表""" if mode == SplitMode.PUNCTUATION: indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()] elif mode == SplitMode.PAUSE: indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()] elif mode == SplitMode.END: indexes = [i for i, seg in enumerate(self.items) if seg.is_end()] else: raise ValueError(f"Unsupported mode: {mode}") # 每个切分点向后移一个索引,表示“分隔符归前段” cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)] chunks = [ TranscriptChunk(items=self.items[start:end], separator=self.separator) for start, end in zip(cut_points, cut_points[1:]) ] return [ ck for ck in chunks if not ck.only_punctuation() ] def get_split_first_rest(self, mode: SplitMode): chunks = self.split_by(mode) fisrt_chunk = chunks[0] if chunks else self rest_chunks = chunks[1:] if chunks else None return fisrt_chunk, rest_chunks def puncation_numbers(self) -> int: """计算片段中标点符号的数量""" return sum(1 for seg in self.items if seg.is_punctuation()) def length(self) -> int: """返回片段列表的长度""" return len(self.items) def join(self) -> str: """将片段连接为一个字符串""" return self.separator.join(seg.text for seg in self.items) def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float: """比较当前片段与另一个片段的相似度""" if not chunk: return 0 score = self._calculate_similarity(self.join(), chunk.join()) # logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}") return score def only_punctuation(self)->bool: return all(seg.is_punctuation() for seg in self.items) def has_punctuation(self) -> bool: return any(seg.is_punctuation() for seg in self.items) def get_buffer_index(self) -> int: return self.items[-1].buffer_index() def is_end_sentence(self) ->bool: return self.items[-1].is_end() class TranscriptHistory: """管理转录片段的历史记录""" def __init__(self) -> None: self.history = collections.deque(maxlen=2) # 存储最近的两个片段 def add(self, chunk: TranscriptChunk): """添加新的片段到历史记录""" self.history.appendleft(chunk) def previous_chunk(self) -> Optional[TranscriptChunk]: """获取上一个片段(如果存在)""" return self.history[1] if len(self.history) == 2 else None def lastest_chunk(self): """获取最后一个片段""" return self.history[-1] def clear(self): self.history.clear() class TranscriptBuffer: """ 管理转录文本的分级结构:临时字符串 -> 短句 -> 完整段落 |-- 已确认文本 --|-- 观察窗口 --|-- 新输入 --| 管理 pending -> line -> paragraph 的缓冲逻辑 """ def __init__(self, source_lang:str, separator:str): self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落 self._sentences: List[str] = collections.deque() # 当前段落中的短句 self._buffer: str = "" # 当前缓冲中的文本 self._current_seg_id: int = 0 self.source_language = source_lang self._separator = separator def get_seg_id(self) -> int: return self._current_seg_id @property def current_sentences_length(self) -> int: count = 0 for item in self._sentences: if self._separator: count += len(item.split(self._separator)) else: count += len(item) return count def update_pending_text(self, text: str) -> None: """更新临时缓冲字符串""" self._buffer = text def commit_line(self,) -> None: """将缓冲字符串提交为短句""" if self._buffer: self._sentences.append(self._buffer) self._buffer = "" def commit_paragraph(self) -> None: """ 提交当前短句为完整段落(如句子结束) Args: end_of_sentence: 是否为句子结尾(如检测到句号) """ count = 0 current_sentences = [] while len(self._sentences): # and count < 20: item = self._sentences.popleft() current_sentences.append(item) if self._separator: count += len(item.split(self._separator)) else: count += len(item) if current_sentences: self._segments.append("".join(current_sentences)) logger.debug(f"=== count to paragraph ===") logger.debug(f"push: {current_sentences}") logger.debug(f"rest: {self._sentences}") # if self._sentences: # self._segments.append("".join(self._sentences)) # self._sentences.clear() def rebuild(self, text): output = self.split_and_join( text.replace( self._separator, "")) logger.debug("==== rebuild string ====") logger.debug(text) logger.debug(output) return output @staticmethod def split_and_join(text): tokens = [] word_buf = '' for char in text: if char in ALL_MARKERS: if word_buf: tokens.extend(wordninja.split(word_buf)) word_buf = '' tokens.append(char) else: word_buf += char if word_buf: tokens.extend(wordninja.split(word_buf)) output = '' for i, token in enumerate(tokens): if i == 0: output += token elif token in ALL_MARKERS: output += (token + " ") else: output += ' ' + token return output def update_and_commit(self, stable_strings: List[str], remaining_strings:List[str], is_end_sentence=False): if self.source_language == "en": stable_strings = [self.rebuild(i) for i in stable_strings] remaining_strings =[self.rebuild(i) for i in remaining_strings] remaining_string = "".join(remaining_strings) logger.debug(f"{self.__dict__}") if is_end_sentence: for stable_str in stable_strings: self.update_pending_text(stable_str) self.commit_line() current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text) # current_text_len = len(self.current_not_commit_text.split(self._separator)) self.update_pending_text(remaining_string) if current_text_len >= config.TEXT_THREHOLD: self.commit_paragraph() self._current_seg_id += 1 return True else: for stable_str in stable_strings: self.update_pending_text(stable_str) self.commit_line() self.update_pending_text(remaining_string) return False @property def un_commit_paragraph(self) -> str: """当前短句组合""" return "".join([i for i in self._sentences]) @property def pending_text(self) -> str: """当前缓冲内容""" return self._buffer @property def latest_paragraph(self) -> str: """最新确认的段落""" return self._segments[-1] if self._segments else "" @property def current_not_commit_text(self) -> str: return self.un_commit_paragraph + self.pending_text class TranscriptStabilityAnalyzer: def __init__(self, source_lang, separator) -> None: self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator) self._transcript_history = TranscriptHistory() self._separator = separator logger.debug(f"Current separator: {self._separator}") def merge_chunks(self, chunks: List[TranscriptChunk])->str: if not chunks: return [""] output = list(r.join() for r in chunks if r) return output def analysis(self, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]: current = TranscriptChunk(items=current, separator=self._separator) self._transcript_history.add(current) prev = self._transcript_history.previous_chunk() self._transcript_buffer.update_pending_text(current.join()) if not prev: # 如果没有历史记录 那么就说明是新的语句 直接输出就行 yield TranscriptResult( context=self._transcript_buffer.current_not_commit_text, seg_id=self._transcript_buffer.get_seg_id() ) return # yield from self._handle_short_buffer(current, prev) if buffer_duration <= 4: yield from self._handle_short_buffer(current, prev) else: yield from self._handle_long_buffer(current) def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]: curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION) prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION) # logger.debug("==== Current cut item ====") # logger.debug(f"{curr.join()} ") # logger.debug(f"{prev.join()}") # logger.debug("==========================") if curr_first and prev_first: core = curr_first.compare(prev_first) has_punctuation = curr_first.has_punctuation() if core >= 0.8 and has_punctuation: yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence()) return yield TranscriptResult( seg_id=self._transcript_buffer.get_seg_id(), context=self._transcript_buffer.current_not_commit_text ) def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]: chunks = curr.split_by(SplitMode.PUNCTUATION) if len(chunks) > 1: stable, remaining = chunks[:-1], chunks[-1:] # stable_str = self.merge_chunks(stable) # remaining_str = self.merge_chunks(remaining) yield from self._yield_commit_results( stable, remaining, is_end_sentence=True # 暂时硬编码为True ) else: yield TranscriptResult( seg_id=self._transcript_buffer.get_seg_id(), context=self._transcript_buffer.current_not_commit_text ) def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]: stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk) remaining_str_list = self.merge_chunks(remaining_chunks) frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index() prev_seg_id = self._transcript_buffer.get_seg_id() commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence) logger.debug(f"current buffer: {self._transcript_buffer.__dict__}") if commit_paragraph: # 表示生成了一个新段落 换行 yield TranscriptResult( seg_id=prev_seg_id, cut_index=frame_cut_index, context=self._transcript_buffer.latest_paragraph, is_end_sentence=True ) if (context := self._transcript_buffer.current_not_commit_text.strip()): yield TranscriptResult( seg_id=self._transcript_buffer.get_seg_id(), context=context, ) else: yield TranscriptResult( seg_id=self._transcript_buffer.get_seg_id(), cut_index=frame_cut_index, context=self._transcript_buffer.current_not_commit_text, )