File size: 3,716 Bytes
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
说话人分离器基础类,包含可复用的方法
"""

import os
import logging
from abc import ABC, abstractmethod
from pydub import AudioSegment
from typing import Any, Dict, List, Union, Optional, Tuple

from ..schemas import DiarizationResult

# 配置日志
logger = logging.getLogger("diarization")


class BaseDiarizer(ABC):
    """说话人分离器基础类"""
    
    def __init__(
        self, 
        model_name: str,
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        初始化说话人分离器基础参数
        
        参数:
            model_name: 模型名称
            token: Hugging Face令牌,用于访问模型
            device: 推理设备,'cpu'或'cuda'
            segmentation_batch_size: 分割批处理大小,默认为32
        """
        self.model_name = model_name
        self.device = device
        self.segmentation_batch_size = segmentation_batch_size
        
        logger.info(f"初始化说话人分离器,模型: {model_name},设备: {device},分割批处理大小: {segmentation_batch_size}")
    
    @abstractmethod
    def _load_model(self):
        """加载模型,子类需要实现"""
        pass
    
    def _prepare_audio(self, audio: AudioSegment) -> str:
        """
        准备音频数据,保存为临时文件
        
        参数:
            audio: 输入的AudioSegment对象
            
        返回:
            临时音频文件的路径
        """
        logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
        
        # 确保采样率为16kHz (pyannote模型要求)
        if audio.frame_rate != 16000:
            logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
            audio = audio.set_frame_rate(16000)
            
        # 确保是单声道
        if audio.channels > 1:
            logger.debug(f"将{audio.channels}声道音频转换为单声道")
            audio = audio.set_channels(1)
            
        # 保存为临时文件
        temp_audio_path = "_temp_audio_for_diarization.wav"
        audio.export(temp_audio_path, format="wav")
        
        logger.debug(f"音频处理完成,保存至: {temp_audio_path}")
        
        return temp_audio_path
    
    def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
        """
        将pyannote的分段结果转换为所需格式
        
        参数:
            diarization: pyannote模型返回的分段结果
            
        返回:
            转换后的分段列表和说话人数量
        """
        segments = []
        speakers = set()
        
        # 遍历说话人分离结果
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker
            })
            speakers.add(speaker)
        
        # 按开始时间排序
        segments.sort(key=lambda x: x["start"])
        
        logger.debug(f"转换了 {len(segments)} 个分段,检测到 {len(speakers)} 个说话人")
        
        return segments, len(speakers)
    
    @abstractmethod
    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        对音频进行说话人分离,子类需要实现
        
        参数:
            audio: 要处理的AudioSegment对象
            
        返回:
            DiarizationResult对象,包含分段结果和说话人数量
        """
        pass