"""
说话人分离器基础类,包含可复用的方法
"""
import os
import logging
from abc import ABC, abstractmethod
from pydub import AudioSegment
from typing import Dict, List, Optional, Tuple, Union
from ..schemas import DiarizationResult

# Configure logging
logger = logging.getLogger("diarization")


class BaseDiarizer(ABC):
    """Base class for speaker diarizers."""

    def __init__(
        self,
        model_name: str,
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        Initialize the base parameters of the speaker diarizer.

        Args:
            model_name: Name of the model.
            token: Hugging Face token used to access the model.
            device: Inference device, 'cpu' or 'cuda'.
            segmentation_batch_size: Segmentation batch size, defaults to 32.
        """
        self.model_name = model_name
        self.token = token  # stored so subclasses can authenticate when loading the model
        self.device = device
        self.segmentation_batch_size = segmentation_batch_size
        logger.info(f"Initializing speaker diarizer, model: {model_name}, device: {device}, segmentation batch size: {segmentation_batch_size}")

    @abstractmethod
    def _load_model(self):
        """Load the model; subclasses must implement this."""
        pass

    def _prepare_audio(self, audio: AudioSegment) -> str:
        """
        Prepare the audio data and save it to a temporary file.

        Args:
            audio: Input AudioSegment object.

        Returns:
            Path to the temporary audio file.
        """
        logger.debug(f"Preparing audio: duration={len(audio)/1000:.2f}s, sample rate={audio.frame_rate}Hz, channels={audio.channels}")

        # Ensure a 16 kHz sample rate (required by pyannote models)
        if audio.frame_rate != 16000:
            logger.debug(f"Resampling audio from {audio.frame_rate}Hz to 16000Hz")
            audio = audio.set_frame_rate(16000)

        # Ensure mono audio
        if audio.channels > 1:
            logger.debug(f"Downmixing {audio.channels}-channel audio to mono")
            audio = audio.set_channels(1)

        # Save to a temporary file
        temp_audio_path = "_temp_audio_for_diarization.wav"
        audio.export(temp_audio_path, format="wav")
        logger.debug(f"Audio preparation complete, saved to: {temp_audio_path}")
        return temp_audio_path
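
    # Design note on the temporary file above: the fixed filename keeps things
    # simple, but concurrent diarization calls in the same working directory
    # would overwrite each other's audio. If that ever becomes a concern,
    # tempfile.NamedTemporaryFile(suffix=".wav", delete=False) would yield a
    # collision-free path at the cost of explicit cleanup by the caller.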

    def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
        """
        Convert pyannote's diarization output into the required format.

        Args:
            diarization: Diarization result returned by the pyannote model.

        Returns:
            The converted list of segments and the number of speakers.
        """
        segments = []
        speakers = set()

        # Iterate over the diarization result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker
            })
            speakers.add(speaker)

        # Sort by start time
        segments.sort(key=lambda x: x["start"])
        logger.debug(f"Converted {len(segments)} segments, detected {len(speakers)} speakers")
        return segments, len(speakers)
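
    # Shape of the returned segment list (illustrative values only):
    #   [{"start": 0.0, "end": 2.5, "speaker": "SPEAKER_00"},
    #    {"start": 2.5, "end": 4.1, "speaker": "SPEAKER_01"}, ...]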

    @abstractmethod
    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        Run speaker diarization on the audio; subclasses must implement this.

        Args:
            audio: AudioSegment object to process.

        Returns:
            A DiarizationResult containing the segments and the speaker count.
        """
        pass
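

# --- Illustrative sketch, not part of the original API ---
# A minimal concrete subclass showing how the abstract hooks are typically
# filled in with a pyannote.audio Pipeline. The class name is hypothetical,
# and the DiarizationResult fields (segments, num_speakers) are assumed from
# how _convert_segments produces them; adjust to the actual schema.
class PyannotePipelineDiarizer(BaseDiarizer):
    """Hypothetical diarizer backed by a pyannote.audio pipeline."""

    def _load_model(self):
        # Lazy import so the base module stays importable without pyannote.
        # Moving the pipeline to self.device (pipeline.to(torch.device(...)))
        # is omitted here for brevity.
        from pyannote.audio import Pipeline
        self.pipeline = Pipeline.from_pretrained(
            self.model_name, use_auth_token=self.token
        )

    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        if not hasattr(self, "pipeline"):
            self._load_model()
        temp_path = self._prepare_audio(audio)
        try:
            diarization = self.pipeline(temp_path)
            segments, num_speakers = self._convert_segments(diarization)
            return DiarizationResult(segments=segments, num_speakers=num_speakers)
        finally:
            os.remove(temp_path)  # clean up the temporary WAV file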