"""
整合ASR和说话人分离的转录器模块,支持流式处理长语音对话
"""
import os
from pydub import AudioSegment
from typing import Dict, List, Union, Optional, Any
import logging
from concurrent.futures import ThreadPoolExecutor
import re
from .summary.speaker_identify import SpeakerIdentifier
# Import the ASR and speaker-diarization modules via relative imports
from .asr import asr_router
from .asr.asr_base import TranscriptionResult
from .diarization import diarizer_router
from .schemas import EnhancedSegment, CombinedTranscriptionResult, PodcastChannel, PodcastEpisode, DiarizationResult
# Configure logging
logger = logging.getLogger("podcast_transcribe")
class CombinedTranscriber:
"""整合ASR和说话人分离的转录器"""
def __init__(
self,
asr_model_name: str,
asr_provider: str,
diarization_provider: str,
diarization_model_name: str,
llm_model_name: str,
llm_provider: str,
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
):
"""
初始化转录器
参数:
asr_model_name: ASR模型名称
asr_provider: ASR提供者名称
diarization_provider: 说话人分离提供者名称
diarization_model_name: 说话人分离模型名称
llm_model_name: LLM模型名称
llm_provider: LLM提供者名称
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为64
parallel: 是否并行执行ASR和说话人分离,默认为False
"""
if not device:
import torch
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
self.asr_model_name = asr_model_name
self.asr_provider = asr_provider
self.diarization_provider = diarization_provider
self.diarization_model_name = diarization_model_name
self.device = device
self.segmentation_batch_size = segmentation_batch_size
self.parallel = parallel
self.speaker_identifier = SpeakerIdentifier(
llm_model_name=llm_model_name,
llm_provider=llm_provider,
device=device
)
        logger.info(f"Initialized combined transcriber, ASR provider: {asr_provider}, ASR model: {asr_model_name}, diarization provider: {diarization_provider}, diarization model: {diarization_model_name}, segmentation batch size: {segmentation_batch_size}, parallel execution: {parallel}, inference device: {device}")
def _merge_adjacent_text_segments(self, segments: List[EnhancedSegment]) -> List[EnhancedSegment]:
"""
合并相邻的、可能属于同一句子的 EnhancedSegment。
合并条件:同一说话人,时间基本连续,文本内容可拼接。
"""
if not segments:
return []
merged_segments: List[EnhancedSegment] = []
current_merged_segment = segments[0]
for i in range(1, len(segments)):
next_segment = segments[i]
time_gap_seconds = next_segment.start - current_merged_segment.end
can_merge_text = False
if current_merged_segment.text and next_segment.text:
current_text_stripped = current_merged_segment.text.strip()
if current_text_stripped and not current_text_stripped[-1] in ".。?!?!":
can_merge_text = True
if (current_merged_segment.speaker == next_segment.speaker and
0 <= time_gap_seconds < 0.75 and
can_merge_text):
current_merged_segment = EnhancedSegment(
start=current_merged_segment.start,
end=next_segment.end,
text=(current_merged_segment.text.strip() + " " + next_segment.text.strip()).strip(),
speaker=current_merged_segment.speaker,
language=current_merged_segment.language
)
else:
merged_segments.append(current_merged_segment)
current_merged_segment = next_segment
merged_segments.append(current_merged_segment)
return merged_segments
def _run_asr(self, audio: AudioSegment) -> TranscriptionResult:
"""执行ASR处理"""
logger.debug("执行ASR...")
return asr_router.transcribe_audio(
audio,
provider=self.asr_provider,
model_name=self.asr_model_name,
device=self.device
)
def _run_diarization(self, audio: AudioSegment) -> DiarizationResult:
"""执行说话人分离处理"""
logger.debug("执行说话人分离...")
return diarizer_router.diarize_audio(
audio,
provider=self.diarization_provider,
model_name=self.diarization_model_name,
device=self.device,
segmentation_batch_size=self.segmentation_batch_size
)
def transcribe(self, audio: AudioSegment) -> CombinedTranscriptionResult:
"""
转录整个音频 (新的非流式逻辑将在这里实现)
参数:
audio: 要转录的AudioSegment对象
返回:
包含完整转录和说话人信息的结果
"""
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频 (非流式)")
if self.parallel:
            # Run ASR and speaker diarization in parallel
            logger.info("Running ASR and speaker diarization in parallel")
with ThreadPoolExecutor(max_workers=2) as executor:
asr_future = executor.submit(self._run_asr, audio)
diarization_future = executor.submit(self._run_diarization, audio)
asr_result: TranscriptionResult = asr_future.result()
diarization_result: DiarizationResult = diarization_future.result()
            logger.debug(f"ASR finished, detected language: {asr_result.language}, produced {len(asr_result.segments)} segments")
            logger.debug(f"Speaker diarization finished, produced {len(diarization_result.segments)} speaker segments, detected {diarization_result.num_speakers} speakers")
else:
            # Run ASR and speaker diarization sequentially
            # Step 1: run ASR on the entire audio
            logger.debug("Running ASR...")
asr_result: TranscriptionResult = asr_router.transcribe_audio(
audio,
provider=self.asr_provider,
model_name=self.asr_model_name,
device=self.device
)
            logger.debug(f"ASR finished, detected language: {asr_result.language}, produced {len(asr_result.segments)} segments")
            # Step 2: run speaker diarization on the entire audio
            logger.debug("Running speaker diarization...")
diarization_result: DiarizationResult = diarizer_router.diarize_audio(
audio,
provider=self.diarization_provider,
model_name=self.diarization_model_name,
device=self.device,
segmentation_batch_size=self.segmentation_batch_size
)
            logger.debug(f"Speaker diarization finished, produced {len(diarization_result.segments)} speaker segments, detected {diarization_result.num_speakers} speakers")
        # Step 3: create enhanced segments
all_enhanced_segments: List[EnhancedSegment] = self._create_enhanced_segments_with_splitting(
asr_result.segments,
diarization_result.segments,
asr_result.language
)
        # Step 4: (optional) merge adjacent text segments
if all_enhanced_segments:
            logger.debug(f"Before merging there are {len(all_enhanced_segments)} enhanced segments, attempting to merge adjacent segments...")
final_segments = self._merge_adjacent_text_segments(all_enhanced_segments)
            logger.debug(f"After merging there are {len(final_segments)} enhanced segments")
else:
final_segments = []
logger.debug("没有增强分段可供合并。")
# 整理合并的文本
full_text = " ".join([segment.text for segment in final_segments]).strip()
        # Count the final number of speakers
num_speakers_set = set(s.speaker for s in final_segments if s.speaker != "UNKNOWN")
return CombinedTranscriptionResult(
segments=final_segments,
text=full_text,
language=asr_result.language or "unknown",
num_speakers=len(num_speakers_set) if num_speakers_set else diarization_result.num_speakers
)
    # Split an ASR text segment at punctuation marks
def _split_asr_segment_by_punctuation(
self,
asr_seg_text: str,
asr_seg_start: float,
asr_seg_end: float
) -> List[Dict[str, Any]]:
"""
根据标点符号分割ASR文本片段,并按字符比例估算子片段的时间戳。
返回: 字典列表,每个字典包含 'text', 'start', 'end'。
"""
sentence_terminators = ".。?!?!;;"
        # Regular expression: match sentence content plus the punctuation that follows it (if any).
        # re.split with a capturing group keeps the delimiters, which are then re-attached.
parts = re.split(f'([{sentence_terminators}])', asr_seg_text)
sub_texts_final = []
current_s = ""
for s_part in parts:
if not s_part:
continue
current_s += s_part
if s_part in sentence_terminators:
if current_s.strip():
sub_texts_final.append(current_s.strip())
current_s = ""
if current_s.strip():
sub_texts_final.append(current_s.strip())
if not sub_texts_final or (len(sub_texts_final) == 1 and sub_texts_final[0] == asr_seg_text.strip()):
            # No effective split, or the split produced a single sentence equal to the original text
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
output_sub_segments = []
        total_text_len = len(asr_seg_text)  # use the original text length for proportional timing
if total_text_len == 0:
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
current_time = asr_seg_start
original_duration = asr_seg_end - asr_seg_start
for i, sub_text in enumerate(sub_texts_final):
sub_len = len(sub_text)
sub_duration = (sub_len / total_text_len) * original_duration
sub_start_time = current_time
sub_end_time = current_time + sub_duration
            # For the last piece, pin the end time to the original segment's end to avoid accumulated rounding error
if i == len(sub_texts_final) - 1:
sub_end_time = asr_seg_end
            # Make sure the end time does not exceed the original end, and the start is not later than the end
sub_end_time = min(sub_end_time, asr_seg_end)
            if sub_start_time >= sub_end_time and sub_start_time == asr_seg_end:  # allow a tiny piece when the start equals the original end
                if sub_text:  # only when there is text
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
            elif sub_start_time < sub_end_time:
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
current_time = sub_end_time
            if current_time >= asr_seg_end and i < len(sub_texts_final) - 1:  # time budget exhausted but sentences remain
                # Attach the remaining sentences at the original end as zero-length pieces
                logger.warning(f"Time was exhausted during splitting but some text is still unassigned. Original segment: [{asr_seg_start}-{asr_seg_end}], current sub-sentence: '{sub_text}'")
                # Create zero-duration (or near-zero) pieces for the text that got no time, anchored at the end
for k in range(i + 1, len(sub_texts_final)):
remaining_text = sub_texts_final[k]
if remaining_text:
output_sub_segments.append({"text": remaining_text, "start": asr_seg_end, "end": asr_seg_end})
break
        # If no sub-segments were produced (e.g. empty original text or a splitting edge case), return the original info as one segment
if not output_sub_segments and asr_seg_text.strip():
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
elif not output_sub_segments and not asr_seg_text.strip():
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
return output_sub_segments
    # Core method: create enhanced segments, assigning speakers and splitting ASR segments on demand
def _create_enhanced_segments_with_splitting(
self,
asr_segments: List[Dict[str, Union[float, str]]],
diarization_segments: List[Dict[str, Union[float, str, int]]],
language: str
) -> List[EnhancedSegment]:
"""
为ASR分段分配说话人,如果ASR分段跨越多个说话人,则尝试按标点分裂。
"""
final_enhanced_segments: List[EnhancedSegment] = []
if not asr_segments:
return []
        # For fast lookups diarization_segments could be pre-processed, but with modest counts a direct scan is fine
        # diarization_segments.sort(key=lambda x: x['start'])  # ensure chronological order
for asr_seg in asr_segments:
asr_start = float(asr_seg["start"])
asr_end = float(asr_seg["end"])
asr_text = str(asr_seg["text"]).strip()
            if not asr_text or asr_start >= asr_end:  # skip invalid ASR segments
continue
            # Find all diarization segments that overlap the current ASR segment in time
overlapping_diar_segs = []
for diar_seg in diarization_segments:
diar_start = float(diar_seg["start"])
diar_end = float(diar_seg["end"])
overlap_start = max(asr_start, diar_start)
overlap_end = min(asr_end, diar_end)
                if overlap_end > overlap_start:  # overlap exists
overlapping_diar_segs.append({
"speaker": str(diar_seg["speaker"]),
"start": diar_start,
"end": diar_end,
"overlap_duration": overlap_end - overlap_start
})
distinct_speakers_in_overlap = set(d['speaker'] for d in overlapping_diar_segs)
segments_to_process_further: List[Dict[str, Any]] = []
if len(distinct_speakers_in_overlap) > 1:
                logger.debug(f"ASR segment [{asr_start:.2f}-{asr_end:.2f}] \"{asr_text[:50]}...\" spans {len(distinct_speakers_in_overlap)} speakers. Trying to split at punctuation.")
                # Spans multiple speakers: try to split the ASR segment at punctuation
sub_asr_segments_data = self._split_asr_segment_by_punctuation(
asr_text,
asr_start,
asr_end
)
                if len(sub_asr_segments_data) > 1:
                    logger.debug(f"Successfully split the ASR segment into {len(sub_asr_segments_data)} sub-sentences.")
                # Extend unconditionally: when no split was possible the helper returns the original
                # segment, which would otherwise be dropped here
                segments_to_process_further.extend(sub_asr_segments_data)
else:
                # Single speaker, or no speaker overlap at all (treated as a single processing unit)
segments_to_process_further.append({"text": asr_text, "start": asr_start, "end": asr_end})
            # Assign a speaker to each original or split ASR (sub-)segment
for current_proc_seg_data in segments_to_process_further:
proc_text = current_proc_seg_data["text"].strip()
proc_start = current_proc_seg_data["start"]
proc_end = current_proc_seg_data["end"]
                if not proc_text or proc_start >= proc_end:  # skip invalid sub-segments
continue
                # Determine the best speaker for the current (possibly sub-)segment
speaker_overlaps_for_proc_seg = {}
                for diar_seg_info in overlapping_diar_segs:  # reuse the diarization segments that overlap the original ASR segment
                    # Compute the overlap between this diar_seg_info and the current sub-segment
overlap_start = max(proc_start, diar_seg_info["start"])
overlap_end = min(proc_end, diar_seg_info["end"])
if overlap_end > overlap_start:
overlap_duration = overlap_end - overlap_start
speaker = diar_seg_info["speaker"]
speaker_overlaps_for_proc_seg[speaker] = \
speaker_overlaps_for_proc_seg.get(speaker, 0) + overlap_duration
best_speaker = "UNKNOWN"
if speaker_overlaps_for_proc_seg:
best_speaker = max(speaker_overlaps_for_proc_seg.items(), key=lambda x: x[1])[0]
                elif overlapping_diar_segs:  # the sub-segment itself has no overlap, but the original ASR segment does
                    # We could fall back to the speaker with the longest overlap in the original ASR segment,
                    # or to the nearest diarization segment. For simplicity, a sub-segment with no direct
                    # overlap is left as UNKNOWN here (the fallback below still handles the single-speaker case).
                    pass  # best_speaker stays UNKNOWN
                # If best_speaker is still UNKNOWN but the original ASR segment has exactly one speaker, use that speaker
                if best_speaker == "UNKNOWN" and len(distinct_speakers_in_overlap) == 1:
                    best_speaker = list(distinct_speakers_in_overlap)[0]
                elif best_speaker == "UNKNOWN" and not overlapping_diar_segs:
                    # If the whole ASR segment has no speaker information at all, it genuinely is UNKNOWN
                    pass
final_enhanced_segments.append(
EnhancedSegment(
start=proc_start,
end=proc_end,
text=proc_text,
speaker=best_speaker,
                        language=language  # all sub-segments inherit the language of the original ASR segment
)
)
        # Sort the final results by start time
final_enhanced_segments.sort(key=lambda seg: seg.start)
return final_enhanced_segments
def transcribe_podcast(
self,
audio: AudioSegment,
podcast_info: PodcastChannel,
episode_info: PodcastEpisode,
) -> CombinedTranscriptionResult:
"""
专门针对播客剧集的音频转录方法
参数:
audio: 要转录的AudioSegment对象
podcast_info: 播客频道信息
episode_info: 播客剧集信息
返回:
包含完整转录和识别后说话人名称的结果
"""
logger.info(f"开始转录播客剧集 {len(audio)/1000:.2f} 秒的音频")
        # 1. Run the base transcription pipeline first
transcription_result = self.transcribe(audio)
        # 2. Identify speaker names
        logger.info("Identifying speaker names...")
speaker_name_map = self.speaker_identifier.recognize_speaker_names(
transcription_result.segments,
podcast_info,
episode_info
)
        # 3. Attach the identified speaker names to the transcription result
enhanced_segments_with_names = []
for segment in transcription_result.segments:
            # Copy the original segment and attach the speaker name
speaker_id = segment.speaker
speaker_name = speaker_name_map.get(speaker_id, None)
            # Create a new segment object that includes the speaker name
new_segment = EnhancedSegment(
start=segment.start,
end=segment.end,
text=segment.text,
speaker=speaker_id,
language=segment.language,
speaker_name=speaker_name
)
enhanced_segments_with_names.append(new_segment)
        # 4. Build and return the new transcription result
return CombinedTranscriptionResult(
segments=enhanced_segments_with_names,
text=transcription_result.text,
language=transcription_result.language,
num_speakers=transcription_result.num_speakers
)
def transcribe_audio(
audio_segment: AudioSegment,
asr_model_name: str = "distil-whisper/distil-large-v3.5",
asr_provider: str = "distil_whisper_transformers",
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
diarization_provider: str = "pyannote_transformers",
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
) -> CombinedTranscriptionResult:
"""
整合ASR和说话人分离的音频转录函数 (仅支持非流式)
参数:
audio_segment: 输入的AudioSegment对象
asr_model_name: ASR模型名称
asr_provider: ASR提供者名称
diarization_model_name: 说话人分离模型名称
diarization_provider: 说话人分离提供者名称
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为64
parallel: 是否并行执行ASR和说话人分离,默认为False
返回:
完整转录结果
"""
logger.info(f"调用transcribe_audio函数 (非流式),音频长度: {len(audio_segment)/1000:.2f}秒")
transcriber = CombinedTranscriber(
asr_model_name=asr_model_name,
asr_provider=asr_provider,
diarization_model_name=diarization_model_name,
diarization_provider=diarization_provider,
llm_model_name="",
llm_provider="",
device=device,
segmentation_batch_size=segmentation_batch_size,
parallel=parallel
)
    # Call the transcribe method directly
return transcriber.transcribe(audio_segment)
def transcribe_podcast_audio(
audio_segment: AudioSegment,
podcast_info: PodcastChannel,
episode_info: PodcastEpisode,
asr_model_name: str = "distil-whisper/distil-large-v3.5",
asr_provider: str = "distil_whisper_transformers",
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
diarization_provider: str = "pyannote_transformers",
llm_model_name: str = "google/gemma-3-4b-it",
llm_provider: str = "gemma-transformers",
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
) -> CombinedTranscriptionResult:
"""
针对播客剧集的音频转录函数,包含说话人名称识别
参数:
audio_segment: 输入的AudioSegment对象
podcast_info: 播客频道信息
episode_info: 播客剧集信息
asr_model_name: ASR模型名称
asr_provider: ASR提供者名称
diarization_provider: 说话人分离提供者名称
diarization_model_name: 说话人分离模型名称
llm_model_name: LLM模型名称
llm_provider: LLM提供者名称
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为64
parallel: 是否并行执行ASR和说话人分离,默认为False
返回:
包含说话人名称的完整转录结果
"""
logger.info(f"调用transcribe_podcast_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
transcriber = CombinedTranscriber(
asr_model_name=asr_model_name,
asr_provider=asr_provider,
diarization_provider=diarization_provider,
diarization_model_name=diarization_model_name,
llm_model_name=llm_model_name,
llm_provider=llm_provider,
device=device,
segmentation_batch_size=segmentation_batch_size,
parallel=parallel
)
    # Call the podcast-specific transcription method
return transcriber.transcribe_podcast(
audio=audio_segment,
podcast_info=podcast_info,
episode_info=episode_info,
)
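

# A minimal usage sketch (not part of the library API): loads a local audio file with pydub and runs the
# non-streaming pipeline with the default models. The file path is hypothetical, and downloading the ASR,
# diarization, and LLM weights is assumed to be handled by the respective providers.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sample_audio = AudioSegment.from_file("example_episode.mp3")  # hypothetical local file
    result = transcribe_audio(sample_audio, parallel=True)
    print(f"Detected {result.num_speakers} speakers, language: {result.language}")
    for seg in result.segments[:10]:  # print the first few segments
        print(f"[{seg.start:7.2f}-{seg.end:7.2f}] {seg.speaker}: {seg.text}")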