import html
import json
import logging
import re
from typing import Any, Dict, List, Optional

from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
from ..llm import llm_router

# Module-level logger used for all diagnostics in this file.
logger = logging.getLogger("speaker_identify")


class SpeakerIdentifier:
    """Map diarization speaker IDs to real names or nicknames.

    The mapping is produced by an LLM that is given the podcast/episode
    metadata and a short description of each speaker; a heuristic
    "Podcast Host" / "Guest N" mapping is used whenever the LLM call or the
    parsing of its answer fails.
    """

    # Speaker label that is ignored everywhere in this class.
    _UNKNOWN_SPEAKER = "UNKNOWN"

    # Hard cap on the number of show-notes characters placed in the prompt.
    _PROMPT_NOTES_LIMIT = 300

    # Number of characters of the first dialogue sample shown per speaker.
    _PROMPT_SAMPLE_LIMIT = 50

    # Lower-case phrases that suggest a (self-)introduction. They are plain
    # literals, so a substring test on the lower-cased text is sufficient.
    _INTRO_MARKERS = (
        "欢迎来到",
        "欢迎收听",
        "我是",
        "我叫",
        "大家好",
        "今天的嘉宾是",
        "我们请到了",
        "welcome to",
        "i'm your host",
        "this is",
        "today we have",
        "joining us",
        "our guest",
        "my name is",
    )

    _SYSTEM_PROMPT = (
        "You are a speaker identification expert. "
        "Return only a JSON object mapping speaker IDs to names. "
        "Start directly with { and end with }. "
        "No markdown, no explanations."
    )

    def __init__(self, llm_model_name: str, llm_provider: str, device: Optional[str] = None):
        """Store the LLM settings forwarded to ``llm_router.chat_completion``.

        Args:
            llm_model_name: Model name passed as ``model``.
            llm_provider: Provider name passed as ``provider``.
            device: Compute device passed as ``device``, e.g. "cpu", "cuda"
                or "mps".
        """
        self.llm_model_name = llm_model_name
        self.llm_provider = llm_provider
        self.device = device

    def _clean_html(self, html_string: Optional[str]) -> str:
        """Strip HTML tags, decode entities and collapse whitespace.

        Args:
            html_string: Raw text that may contain HTML; may be None or empty.

        Returns:
            The cleaned text, or "" when the input is falsy.
        """
        if not html_string:
            return ""

        # Tags become a space so that adjacent words do not get glued together.
        text = re.sub(r"<[^>]+>", " ", html_string)
        # Decode entities only after tag removal, so escaped markup such as
        # "&lt;b&gt;" survives as literal text instead of being stripped.
        text = html.unescape(text)
        # ``\s`` also matches the non-breaking space produced by "&nbsp;".
        return re.sub(r"\s+", " ", text).strip()

    def _get_dialogue_samples(
        self,
        segments: List[EnhancedSegment],
        max_samples_per_speaker: int = 3,
        max_length_per_sample: int = 200,
    ) -> Dict[str, List[str]]:
        """Collect the first few utterances of every known speaker.

        Segments whose speaker is "UNKNOWN" or whose text is blank are skipped.

        Args:
            segments: Transcript segments exposing ``speaker`` and ``text``.
            max_samples_per_speaker: Maximum number of samples kept per speaker.
            max_length_per_sample: Maximum characters per sample; longer
                texts are cut and suffixed with "...".

        Returns:
            Mapping of speaker ID to its list of sample strings.
        """
        speaker_dialogues: Dict[str, List[str]] = {}
        for segment in segments:
            if segment.speaker == self._UNKNOWN_SPEAKER:
                continue
            text = segment.text.strip()
            if not text:
                continue

            samples = speaker_dialogues.setdefault(segment.speaker, [])
            if len(samples) >= max_samples_per_speaker:
                continue

            sample = text[:max_length_per_sample]
            if len(text) > max_length_per_sample:
                sample += "..."
            samples.append(sample)
        return speaker_dialogues

    def _collect_speaker_stats(
        self, segments: List[EnhancedSegment]
    ) -> Dict[str, Dict[str, Any]]:
        """Count segments per speaker and flag introduction-like speech.

        Returns:
            Mapping of speaker ID to ``{"total_segments": int,
            "intro_likely": bool}``, in order of first appearance.
        """
        speaker_stats: Dict[str, Dict[str, Any]] = {}
        for segment in segments:
            if segment.speaker == self._UNKNOWN_SPEAKER or not segment.text.strip():
                continue

            entry = speaker_stats.setdefault(
                segment.speaker, {"total_segments": 0, "intro_likely": False}
            )
            entry["total_segments"] += 1
            if not entry["intro_likely"]:
                lowered = segment.text.lower()
                entry["intro_likely"] = any(
                    marker in lowered for marker in self._INTRO_MARKERS
                )
        return speaker_stats

    @staticmethod
    def _build_heuristic_map(
        speaker_ids: List[str], speaker_stats: Dict[str, Dict[str, Any]]
    ) -> Dict[str, str]:
        """Build the fallback mapping: one "Podcast Host", the rest "Guest N".

        The host is the speaker with introduction-like speech, ties broken by
        the number of segments and then by first appearance.
        """
        if not speaker_ids:
            return {}

        if speaker_stats:
            # ``max`` returns the first maximal key, i.e. the earliest speaker.
            host_id = max(
                speaker_stats,
                key=lambda sid: (
                    speaker_stats[sid]["intro_likely"],
                    speaker_stats[sid]["total_segments"],
                ),
            )
        else:
            host_id = speaker_ids[0]

        mapping: Dict[str, str] = {host_id: "Podcast Host"}
        guest_number = 1
        for speaker_id in speaker_ids:
            if speaker_id == host_id:
                continue
            mapping[speaker_id] = f"Guest {guest_number}"
            guest_number += 1
        return mapping

    def _build_messages(
        self,
        speaker_ids: List[str],
        speaker_stats: Dict[str, Dict[str, Any]],
        dialogue_samples: Dict[str, List[str]],
        podcast_info: Optional[PodcastChannel],
        episode_info: Optional[PodcastEpisode],
        max_shownotes_length: int,
    ) -> List[Dict[str, str]]:
        """Assemble the system and user chat messages sent to the LLM."""
        podcast_title = (
            podcast_info.title if podcast_info and podcast_info.title else "Unknown Podcast"
        )
        podcast_author = (
            podcast_info.author if podcast_info and podcast_info.author else "Unknown"
        )
        episode_title = (
            episode_info.title if episode_info and episode_info.title else "Unknown Episode"
        )

        raw_shownotes = (
            episode_info.shownotes if episode_info and episode_info.shownotes else ""
        )
        cleaned_shownotes = self._clean_html(raw_shownotes)
        shownotes = cleaned_shownotes[:max_shownotes_length]
        if len(cleaned_shownotes) > max_shownotes_length:
            shownotes += "..."

        speaker_lines = []
        for speaker_id in speaker_ids:
            stats = speaker_stats.get(
                speaker_id, {"total_segments": 0, "intro_likely": False}
            )
            parts = []
            if stats["intro_likely"]:
                parts.append("intro")
            if stats["total_segments"] > 0:
                parts.append(f"{stats['total_segments']}segs")

            samples = dialogue_samples.get(speaker_id, [])
            if samples:
                # Only the head of the first sample keeps the prompt short.
                snippet = samples[0][: self._PROMPT_SAMPLE_LIMIT].replace("\n", " ").strip()
                if snippet:
                    parts.append(f'"{snippet}"')
            speaker_lines.append(f"{speaker_id}: {', '.join(parts)}")

        user_prompt = "\n".join(
            [
                f"Podcast: {podcast_title}",
                f"Host: {podcast_author}",
                f"Episode: {episode_title}",
                f"Notes: {shownotes[: self._PROMPT_NOTES_LIMIT]}",
                "",
                "Speakers:",
                *speaker_lines,
                "",
                'Return JSON like: {"SPEAKER_00": "Name1", "SPEAKER_01": "Name2"}',
            ]
        )
        return [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

    @staticmethod
    def _extract_json_mapping(content: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from the LLM answer.

        Tries, in order: the whole answer, a Markdown code fence, and the span
        from the first "{" to the last "}".

        Returns:
            The parsed dict, or None when no JSON object could be parsed.
        """
        try:
            parsed = json.loads(content.strip())
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            logger.info("直接解析响应为JSON成功")
            return parsed

        fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if fenced:
            json_str = fenced.group(1)
            logger.info("从markdown代码块中提取JSON")
        else:
            first_brace = content.find("{")
            last_brace = content.rfind("}")
            if first_brace == -1 or last_brace <= first_brace:
                logger.warning("无法找到有效的JSON结构,使用默认映射")
                return None
            json_str = content[first_brace : last_brace + 1]
            logger.info("通过大括号位置提取JSON")

        # Collapse newlines and repeated blanks before parsing.
        json_str = re.sub(r"\s+", " ", json_str.strip())
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as exc:
            logger.warning("LLM返回的JSON解析失败: %s", exc)
            logger.warning("用于解析的字符串: '%s...'", json_str[:200])
            return None

        if not isinstance(parsed, dict):
            logger.warning("LLM返回的JSON不是一个字典: %s", parsed)
            return None

        logger.info("JSON解析成功")
        return parsed

    @staticmethod
    def _merge_llm_names(
        speaker_ids: List[str],
        llm_names: Dict[str, Any],
        speaker_stats: Dict[str, Dict[str, Any]],
    ) -> Dict[str, str]:
        """Turn the LLM answer into the final speaker mapping.

        Speakers the LLM did not name with a non-blank string become
        "Unknown Speaker N". If no resulting name mentions a host, the most
        active unknown speaker is promoted to "Podcast Host".
        """
        final_map: Dict[str, str] = {}
        unknown_number = 1
        for speaker_id in speaker_ids:
            name = llm_names.get(speaker_id)
            if isinstance(name, str) and name.strip():
                final_map[speaker_id] = name.strip()
            else:
                final_map[speaker_id] = f"Unknown Speaker {unknown_number}"
                unknown_number += 1

        has_host = any("主持人" in name or "Host" in name for name in final_map.values())
        if has_host:
            return final_map

        most_active_unknown = None
        max_segments = 0
        for speaker_id, name in final_map.items():
            if "Unknown Speaker" not in name or speaker_id not in speaker_stats:
                continue
            segment_count = speaker_stats[speaker_id]["total_segments"]
            if segment_count > max_segments:
                max_segments = segment_count
                most_active_unknown = speaker_id
        if most_active_unknown:
            final_map[most_active_unknown] = "Podcast Host"
        return final_map

    def recognize_speaker_names(
        self,
        segments: List[EnhancedSegment],
        podcast_info: Optional[PodcastChannel],
        episode_info: Optional[PodcastEpisode],
        max_shownotes_length: int = 1500,
        max_desc_length: int = 500,
    ) -> Dict[str, str]:
        """Identify speaker names from transcript segments and metadata.

        Args:
            segments: Transcript segments exposing ``speaker`` and ``text``.
            podcast_info: Podcast metadata; ``title`` and ``author`` are read.
                May be None.
            episode_info: Episode metadata; ``title`` and ``shownotes`` are
                read. May be None.
            max_shownotes_length: Maximum number of cleaned show-notes
                characters kept before the prompt-level cap is applied.
            max_desc_length: Currently unused because the podcast description
                and episode summary are not part of the prompt; retained so
                existing callers keep working.

        Returns:
            Mapping of each original speaker ID to a name. Empty when no
            usable speaker exists. When the LLM call fails or its answer
            contains no usable JSON object, the heuristic
            "Podcast Host" / "Guest N" mapping is returned instead.
        """
        speaker_stats = self._collect_speaker_stats(segments)
        unique_speaker_ids = sorted(speaker_stats)
        if not unique_speaker_ids:
            logger.warning("未能从 segments 中提取到有效的 speaker_ids。")
            return {}

        dialogue_samples = self._get_dialogue_samples(segments)
        messages = self._build_messages(
            unique_speaker_ids,
            speaker_stats,
            dialogue_samples,
            podcast_info,
            episode_info,
            max_shownotes_length,
        )
        fallback_map = self._build_heuristic_map(unique_speaker_ids, speaker_stats)

        # Deliberate best-effort boundary: any failure while calling the LLM
        # or handling its answer falls back to the heuristic mapping.
        try:
            response = llm_router.chat_completion(
                messages=messages,
                provider=self.llm_provider,
                model=self.llm_model_name,
                temperature=0.2,
                max_tokens=300,
                top_p=0.5,
                device=self.device,
                repetition_penalty=1.0,
                do_sample=True,
            )
            logger.info("LLM调用日志,请求参数:【%s】, 响应: 【%s】", messages, response)
            content = response["choices"][0]["message"]["content"]

            llm_names = self._extract_json_mapping(content)
            if not llm_names:
                return fallback_map

            final_map = self._merge_llm_names(unique_speaker_ids, llm_names, speaker_stats)
            logger.info("LLM识别结果: %s", final_map)
            return final_map
        except Exception as exc:
            logger.exception("调用LLM或处理响应时发生严重错误: %s", exc)
            return fallback_map