import time import random import uuid import json import os import logging import threading from typing import Dict, List, Optional, Any, Union, Tuple, Callable # 初始化日志 logger = logging.getLogger("sora-api.key_manager") class KeyManager: def __init__(self, storage_file: str = "api_keys.json"): """ 初始化密钥管理器 Args: storage_file: 密钥存储文件路径 """ self.keys = [] # 密钥列表 self.storage_file = storage_file self.usage_stats = {} # 使用统计 self._lock = threading.RLock() # 添加可重入锁以支持并发访问 self._working_keys = {} # 新增:记录正在工作中的密钥 {key_value: task_id} self._load_keys() def _load_keys(self) -> None: """从环境变量或文件加载密钥""" keys_loaded = False # 先尝试从环境变量加载 api_keys_str = os.getenv("API_KEYS", "") if api_keys_str: try: env_data = json.loads(api_keys_str) self._process_keys_data(env_data) if len(self.keys) > 0: logger.info(f"已从环境变量加载 {len(self.keys)} 个密钥") keys_loaded = True else: logger.warning("环境变量API_KEYS存在但未包含有效密钥") except json.JSONDecodeError as e: logger.error(f"解析环境变量API keys失败: {str(e)}") # 如果环境变量未设置、解析失败或未加载到密钥,从文件加载 if not keys_loaded: try: if os.path.exists(self.storage_file): logger.info(f"尝试从文件加载密钥: {self.storage_file}") with open(self.storage_file, 'r', encoding='utf-8') as f: data = json.load(f) keys_before = len(self.keys) self._process_keys_data(data) keys_loaded = len(self.keys) > keys_before logger.info(f"已从文件加载 {len(self.keys) - keys_before} 个密钥") else: logger.warning(f"密钥文件不存在: {self.storage_file}") except Exception as e: logger.error(f"加载密钥失败: {str(e)}") if len(self.keys) == 0: logger.warning("未能从环境变量或文件加载任何密钥") def _process_keys_data(self, data): """处理不同格式的密钥数据""" # 处理不同的数据格式 if isinstance(data, list): # 旧版格式:直接是密钥列表 raw_keys = data self.keys = [] self.usage_stats = {} # 为每个密钥创建完整的记录 for key_info in raw_keys: if isinstance(key_info, dict): key_value = key_info.get("key") if not key_value: logger.warning(f"忽略无效密钥配置: {key_info}") continue # 确保有ID key_id = key_info.get("id") or str(uuid.uuid4()) # 构建完整的密钥记录 key_record = { "id": key_id, "name": key_info.get("name", ""), "key": key_value, "weight": key_info.get("weight", 1), "max_rpm": key_info.get("max_rpm", 60), "requests": 0, "last_reset": time.time(), "available": key_info.get("is_enabled", True), "is_enabled": key_info.get("is_enabled", True), "created_at": key_info.get("created_at", time.time()), "last_used": key_info.get("last_used"), "notes": key_info.get("notes") } self.keys.append(key_record) # 初始化使用统计 self.usage_stats[key_id] = { "total_requests": 0, "successful_requests": 0, "failed_requests": 0, "daily_usage": {}, "average_response_time": 0 } elif isinstance(key_info, str): # 如果是字符串,直接作为密钥值 key_id = str(uuid.uuid4()) self.keys.append({ "id": key_id, "name": "", "key": key_info, "weight": 1, "max_rpm": 60, "requests": 0, "last_reset": time.time(), "available": True, "is_enabled": True, "created_at": time.time(), "last_used": None, "notes": None }) # 初始化使用统计 self.usage_stats[key_id] = { "total_requests": 0, "successful_requests": 0, "failed_requests": 0, "daily_usage": {}, "average_response_time": 0 } else: # 新版格式:包含keys和usage_stats的字典 self.keys = data.get('keys', []) self.usage_stats = data.get('usage_stats', {}) def _save_keys(self) -> None: """保存密钥到文件""" try: with open(self.storage_file, 'w', encoding='utf-8') as f: json.dump({ 'keys': self.keys, 'usage_stats': self.usage_stats }, f, ensure_ascii=False, indent=2) # 同时更新Config中的API_KEYS try: from .config import Config Config.API_KEYS = self.keys except (ImportError, AttributeError): logger.debug("无法更新Config中的API_KEYS") except Exception as e: logger.error(f"保存密钥失败: {str(e)}") def add_key(self, key_value: str, name: str = "", weight: int = 1, rate_limit: int = 60, is_enabled: bool = True, notes: str = None) -> Dict[str, Any]: """ 添加密钥 Args: key_value: 密钥值 name: 密钥名称 weight: 权重 rate_limit: 速率限制(每分钟请求数) is_enabled: 是否启用 notes: 备注 Returns: 添加的密钥信息 """ with self._lock: # 使用锁保护添加过程 # 检查密钥是否已存在 for key in self.keys: if key.get("key") == key_value: return key key_id = str(uuid.uuid4()) new_key = { "id": key_id, "name": name, "key": key_value, "weight": weight, "max_rpm": rate_limit, "requests": 0, "last_reset": time.time(), "available": is_enabled, "is_enabled": is_enabled, "created_at": time.time(), "last_used": None, "notes": notes } self.keys.append(new_key) # 初始化使用统计 self.usage_stats[key_id] = { "total_requests": 0, "successful_requests": 0, "failed_requests": 0, "daily_usage": {}, "average_response_time": 0 } self._save_keys() logger.info(f"已添加密钥: {name or key_id}") return new_key def get_all_keys(self) -> List[Dict[str, Any]]: """获取所有密钥信息(已隐藏完整密钥值)""" with self._lock: # 使用锁保护读取过程 result = [] for key in self.keys: key_copy = key.copy() if "key" in key_copy: # 只显示密钥前6位和后4位 full_key = key_copy["key"] if len(full_key) > 10: key_copy["key"] = full_key[:6] + "..." + full_key[-4:] # 增加临时禁用信息的处理 if key_copy.get("temp_disabled_until"): temp_disabled_until = key_copy["temp_disabled_until"] # 确保temp_disabled_until是时间戳格式 if isinstance(temp_disabled_until, (int, float)): # 转换为可读格式,但保留原始时间戳,让前端可以自行处理 disabled_until_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(temp_disabled_until)) key_copy["temp_disabled_until_formatted"] = disabled_until_date key_copy["temp_disabled_remaining"] = int(temp_disabled_until - time.time()) result.append(key_copy) return result def get_key_by_id(self, key_id: str) -> Optional[Dict[str, Any]]: """根据ID获取密钥信息""" with self._lock: # 使用锁保护读取过程 for key in self.keys: if key.get("id") == key_id: return key return None def update_key(self, key_id: str, **kwargs) -> Optional[Dict[str, Any]]: """ 更新密钥信息 Args: key_id: 密钥ID **kwargs: 要更新的字段 Returns: 更新后的密钥信息,未找到则返回None """ with self._lock: # 使用锁保护更新过程 for key in self.keys: if key.get("id") == key_id: # 更新提供的字段 for field, value in kwargs.items(): if value is not None: if field == "is_enabled": key["available"] = value # 同步更新available字段 key[field] = value self._save_keys() logger.info(f"已更新密钥: {key.get('name') or key_id}") return key logger.warning(f"未找到密钥: {key_id}") return None def delete_key(self, key_id: str) -> bool: """ 删除密钥 Args: key_id: 密钥ID Returns: 是否成功删除 """ with self._lock: # 使用锁保护删除过程 original_length = len(self.keys) self.keys = [key for key in self.keys if key.get("id") != key_id] # 如果成功删除,保存密钥 if len(self.keys) < original_length: self._save_keys() return True return False def batch_import_keys(self, keys_data: List[Dict[str, Any]]) -> Dict[str, int]: """ 批量导入密钥 Args: keys_data: 密钥数据列表,每个元素为包含密钥信息的字典 Returns: 导入结果统计 """ with self._lock: # 使用锁保护导入过程 imported_count = 0 skipped_count = 0 # 获取现有密钥值 existing_keys = {key.get("key") for key in self.keys} for key_data in keys_data: key_value = key_data.get("key") if not key_value: continue # 检查密钥是否已存在 if key_value in existing_keys: skipped_count += 1 continue # 添加新密钥 key_id = str(uuid.uuid4()) new_key = { "id": key_id, "name": key_data.get("name", ""), "key": key_value, "weight": key_data.get("weight", 1), "max_rpm": key_data.get("rate_limit", 60), "requests": 0, "last_reset": time.time(), "available": key_data.get("enabled", True), "is_enabled": key_data.get("enabled", True), "created_at": time.time(), "last_used": None, "notes": key_data.get("notes") } self.keys.append(new_key) existing_keys.add(key_value) # 添加到已存在集合中 # 初始化使用统计 self.usage_stats[key_id] = { "total_requests": 0, "successful_requests": 0, "failed_requests": 0, "daily_usage": {}, "average_response_time": 0 } imported_count += 1 # 保存密钥 if imported_count > 0: self._save_keys() return { "imported": imported_count, "skipped": skipped_count } def get_key(self) -> Optional[str]: """获取下一个可用的密钥""" with self._lock: # 使用锁保护整个获取密钥过程 if not self.keys: logger.warning("没有可用的密钥") return None # 重置计数器(如果需要) current_time = time.time() temporary_disabled_updated = False for key in self.keys: # 检查是否有被临时禁用的密钥需要重新启用 if key.get("temp_disabled_until") and current_time > key.get("temp_disabled_until"): key["is_enabled"] = True key["available"] = True key["temp_disabled_until"] = None temporary_disabled_updated = True logger.info(f"密钥 {key.get('name') or key.get('id')} 的临时禁用已解除") if current_time - key["last_reset"] >= 60: key["requests"] = 0 key["last_reset"] = current_time if not key.get("temp_disabled_until"): # 只有未被临时禁用的密钥才会被重新激活 key["available"] = key.get("is_enabled", True) # 如果有任何临时禁用的密钥被更新,保存变更 if temporary_disabled_updated: self._save_keys() # 筛选可用的密钥,排除工作中的密钥 available_keys = [] for k in self.keys: key_value = k.get("key", "") clean_key = key_value.replace("Bearer ", "") if key_value.startswith("Bearer ") else key_value # 检查此密钥是否在工作中 is_working = clean_key in self._working_keys if k.get("available", False) and not is_working: available_keys.append(k) if not available_keys: logger.warning("没有可用的密钥(所有密钥都达到速率限制、被禁用或正在工作中)") return None # 根据权重选择密钥 weights = [k.get("weight", 1) for k in available_keys] selected_idx = random.choices(range(len(available_keys)), weights=weights, k=1)[0] selected_key = available_keys[selected_idx] # 更新使用统计 selected_key["requests"] += 1 selected_key["last_used"] = current_time # 检查是否达到速率限制 if selected_key["requests"] >= selected_key.get("max_rpm", 60): selected_key["available"] = False # 保存数据 - 并发环境下调整为每次都保存,避免状态不一致 # 原来是随机保存(10%的概率) self._save_keys() # 确保返回的密钥包含"Bearer "前缀 key_value = selected_key["key"] if not key_value.startswith("Bearer "): key_value = f"Bearer {key_value}" return key_value def record_request_result(self, key: str, success: bool, response_time: float = 0) -> None: """ 记录请求结果 Args: key: 密钥值 success: 请求是否成功 response_time: 响应时间(秒) """ if not key: logger.warning("记录请求结果失败:密钥为空") return with self._lock: # 使用锁保护记录过程 # 去掉可能的Bearer前缀 key_for_search = key.replace("Bearer ", "") if key.startswith("Bearer ") else key # 查找对应的密钥ID key_id = None key_info = None for k in self.keys: stored_key = k.get("key", "").replace("Bearer ", "") if k.get("key", "").startswith("Bearer ") else k.get("key", "") if stored_key == key_for_search: key_id = k.get("id") key_info = k break if not key_id: logger.warning(f"记录请求结果失败:未找到密钥 {key_for_search[:6]}...") return # 初始化usage_stats如果该密钥还没有统计数据 if key_id not in self.usage_stats: self.usage_stats[key_id] = { "total_requests": 0, "successful_requests": 0, "failed_requests": 0, "daily_usage": {}, "average_response_time": 0 } # 记录请求结果 stats = self.usage_stats[key_id] stats["total_requests"] += 1 if success: stats["successful_requests"] += 1 else: stats["failed_requests"] += 1 # 记录响应时间 if response_time > 0: if stats["average_response_time"] == 0: stats["average_response_time"] = response_time else: # 使用加权平均 old_avg = stats["average_response_time"] total = stats["total_requests"] # 避免 total 为0或1时产生问题,尽管前面 total_requests 已经增加了 if total > 0: stats["average_response_time"] = ((old_avg * (total - 1)) + response_time) / total else: # 理论上不应该发生,因为 total_requests 已经增加了 stats["average_response_time"] = response_time # 记录每日使用情况 today = time.strftime("%Y-%m-%d") if today not in stats["daily_usage"]: stats["daily_usage"][today] = {"successful": 0, "failed": 0} # 初始化每日统计 # 根据成功与否更新每日统计 if success: stats["daily_usage"][today]["successful"] += 1 else: stats["daily_usage"][today]["failed"] += 1 # 保留最近30天的数据 if len(stats["daily_usage"]) > 30: # 获取所有日期并排序,然后删除最早的 sorted_dates = sorted(stats["daily_usage"].keys()) if sorted_dates: # 确保列表不为空 oldest_date = sorted_dates[0] del stats["daily_usage"][oldest_date] # 更新密钥的最后使用时间 if key_info and "last_used" in key_info: key_info["last_used"] = time.time() # 并发环境下每次都保存,确保统计准确性 self._save_keys() def get_usage_stats(self) -> Dict[str, Any]: """获取使用统计信息""" with self._lock: # 使用锁保护读取过程 total_keys = len(self.keys) active_keys = sum(1 for k in self.keys if k.get("is_enabled", False)) available_keys = sum(1 for k in self.keys if k.get("available", False)) total_requests = sum(stats.get("total_requests", 0) for stats in self.usage_stats.values()) successful_requests = sum(stats.get("successful_requests", 0) for stats in self.usage_stats.values()) # 计算成功率 success_rate = (successful_requests / total_requests * 100) if total_requests > 0 else 0 # 计算每个密钥的平均响应时间 avg_response_times = [stats.get("average_response_time", 0) for stats in self.usage_stats.values() if stats.get("average_response_time", 0) > 0] overall_avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0 # 获取过去7天的使用情况 past_7_days = {} for key_id, stats in self.usage_stats.items(): daily_usage = stats.get("daily_usage", {}) for date, count_data in daily_usage.items(): if date not in past_7_days: past_7_days[date] = {"successful": 0, "failed": 0} # 正确处理字典类型的count_data past_7_days[date]["successful"] += count_data.get("successful", 0) past_7_days[date]["failed"] += count_data.get("failed", 0) # 只保留最近7天 dates = sorted(past_7_days.keys(), reverse=True)[:7] past_7_days = {date: past_7_days[date] for date in dates} return { "total_keys": total_keys, "active_keys": active_keys, "available_keys": available_keys, "total_requests": total_requests, "successful_requests": successful_requests, "failed_requests": total_requests - successful_requests, "success_rate": success_rate, "average_response_time": overall_avg_response_time, "past_7_days": past_7_days } def mark_key_as_working(self, key: str, task_id: str) -> None: """ 将密钥标记为工作中状态 Args: key: API密钥值(可能包含Bearer前缀) task_id: 关联的任务ID """ with self._lock: clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key self._working_keys[clean_key] = task_id logger.debug(f"密钥已标记为工作中,关联任务ID: {task_id}") def release_key(self, key: str) -> None: """ 释放工作中的密钥 Args: key: API密钥值(可能包含Bearer前缀) """ with self._lock: clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key if clean_key in self._working_keys: del self._working_keys[clean_key] logger.debug(f"密钥已释放") def is_key_working(self, key: str) -> bool: """ 检查密钥是否正在工作中 Args: key: API密钥值(可能包含Bearer前缀) Returns: bool: 是否在工作中 """ with self._lock: clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key return clean_key in self._working_keys def mark_key_invalid(self, key: str) -> Optional[str]: """ 将指定的密钥标记为无效(临时禁用而不是永久禁用),并返回一个新的可用密钥 Args: key: API密钥值(可能包含Bearer前缀) Returns: Optional[str]: 新的可用密钥,如果没有可用密钥则返回None """ # 调用临时禁用方法,设置24小时禁用时间 return self.mark_key_temp_disabled(key, hours=24.0) def mark_key_temp_disabled(self, key: str, hours: float = 12.0) -> Optional[str]: """ 将指定的密钥临时禁用指定小时数,并返回一个新的可用密钥 Args: key: API密钥值(可能包含Bearer前缀) hours: 禁用小时数 Returns: Optional[str]: 新的可用密钥,如果没有可用密钥则返回None """ with self._lock: # 使用锁保护临时禁用过程 # 去掉可能的Bearer前缀 key_for_search = key.replace("Bearer ", "") if key.startswith("Bearer ") else key # 检查是否是因为密钥在工作中导致的错误 if key_for_search in self._working_keys: logger.warning(f"尝试禁用正在工作中的密钥(任务ID: {self._working_keys[key_for_search]}),跳过禁用操作") # 获取一个新密钥返回,但不禁用当前密钥 new_key = self.get_key() if new_key: logger.info(f"已返回新密钥,但未禁用工作中的密钥") return new_key else: logger.warning("没有可用的备用密钥") return None # 查找对应的密钥 key_found = False disabled_key_id = None for key_info in self.keys: stored_key = key_info.get("key", "").replace("Bearer ", "") if key_info.get("key", "").startswith("Bearer ") else key_info.get("key", "") if stored_key == key_for_search: # 标记密钥为临时禁用 disabled_until = time.time() + (hours * 3600) # 当前时间加上禁用小时数 key_info["available"] = False key_info["temp_disabled_until"] = disabled_until key_info["notes"] = (key_info.get("notes") or "") + f"\n[自动] 在 {time.strftime('%Y-%m-%d %H:%M:%S')} 被临时禁用{hours}小时" key_found = True disabled_key_id = key_info.get("id") logger.warning(f"密钥 {key_info.get('name') or key_info.get('id')} 被临时禁用{hours}小时") break if key_found: # 保存更改 self._save_keys() # 获取新的密钥,排除已禁用的 new_key = self.get_key() if new_key: logger.info(f"已自动切换到新的密钥") return new_key else: logger.warning("没有可用的备用密钥") return None else: logger.warning(f"未找到要临时禁用的密钥") return None def retry_request(self, original_key: str, request_func: Callable, max_retries: int = 1, max_key_switches: int = 3) -> Tuple[bool, Any, str]: """ 出错时自动重试请求,并在需要时切换密钥 Args: original_key: 原始API密钥(可能包含Bearer前缀) request_func: 执行请求的函数,接受一个参数(密钥)并返回(成功标志, 结果) max_retries: 使用同一密钥的最大重试次数 max_key_switches: 最大密钥切换次数 Returns: Tuple[bool, Any, str]: (是否成功, 请求结果, 使用的密钥) """ current_key = original_key current_key_switches = 0 # 首先用原始密钥尝试 for attempt in range(max_retries + 1): # +1是因为第一次不算重试 try: success, result = request_func(current_key) # 成功的请求不应该导致密钥被禁用 if success: # 记录请求成功,避免不必要的密钥禁用 with self._lock: self.record_request_result(current_key, True) return True, result, current_key logger.warning(f"请求失败(尝试 {attempt+1}/{max_retries+1}): {result}") except Exception as e: logger.error(f"请求异常(尝试 {attempt+1}/{max_retries+1}): {str(e)}") # 如果这不是最后一次尝试,等待一秒后重试 if attempt < max_retries: time.sleep(1) # 如果原始密钥的所有重试都失败,尝试切换密钥 tried_keys = set([current_key.replace("Bearer ", "") if current_key.startswith("Bearer ") else current_key]) while current_key_switches < max_key_switches: # 获取新的密钥 with self._lock: new_key = self.get_key() if not new_key: logger.warning("没有更多可用的密钥") break # 确保不使用已经尝试过的密钥 clean_new_key = new_key.replace("Bearer ", "") if new_key.startswith("Bearer ") else new_key if clean_new_key in tried_keys: continue tried_keys.add(clean_new_key) current_key = new_key current_key_switches += 1 logger.info(f"切换到新密钥 (切换 {current_key_switches}/{max_key_switches})") # 用新密钥尝试 for attempt in range(max_retries + 1): try: success, result = request_func(current_key) if success: # 记录请求成功 with self._lock: self.record_request_result(current_key, True) return True, result, current_key logger.warning(f"使用新密钥请求失败(尝试 {attempt+1}/{max_retries+1}): {result}") except Exception as e: logger.error(f"使用新密钥请求异常(尝试 {attempt+1}/{max_retries+1}): {str(e)}") # 如果这不是最后一次尝试,等待一秒后重试 if attempt < max_retries: time.sleep(1) # 所有尝试都失败,临时禁用原始密钥 # 但是在并发环境下,这可能是因为网络或服务问题,而非密钥问题 # 增加额外检查以减少不必要的密钥禁用 should_disable = True # 在临时禁用前,确认是否是密钥问题而非服务或网络问题 # 此处可以添加额外逻辑来判断是否应该禁用密钥 if should_disable: logger.error(f"所有重试和密钥切换尝试都失败,临时禁用原始密钥") with self._lock: self.mark_key_temp_disabled(original_key, hours=6.0) # 减少禁用时间,避免资源浪费 else: logger.warning(f"所有重试和密钥切换尝试都失败,但可能是服务问题而非密钥问题,不禁用密钥") # 返回最后一次尝试的结果 return False, result, current_key # 创建全局密钥管理器实例 storage_file = os.getenv("KEYS_STORAGE_FILE", "api_keys.json") # 如果提供了绝对路径则直接使用,否则使用相对路径 if not os.path.isabs(storage_file): base_dir = os.getenv("BASE_DIR", os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) storage_file = os.path.join(base_dir, storage_file) key_manager = KeyManager(storage_file=storage_file) logger.info(f"初始化全局密钥管理器,存储文件: {storage_file}")