Spaces:
Running
Running
import os | |
import logging | |
import dotenv | |
from typing import Dict, Any, List | |
from fastapi import APIRouter, Depends, HTTPException | |
from fastapi.responses import FileResponse | |
from ..models.schemas import ApiKeyCreate, ApiKeyUpdate, ConfigUpdate, LogLevelUpdate, BatchOperation, BatchImportOperation | |
from ..api.dependencies import verify_admin, verify_admin_jwt | |
from ..config import Config | |
from ..key_manager import key_manager | |
from ..sora_integration import SoraClient | |
# 设置日志 | |
logger = logging.getLogger("sora-api.admin") | |
# 日志系统配置 | |
class LogConfig: | |
LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper() | |
FORMAT = "%(asctime)s [%(levelname)s] %(message)s" | |
# 创建路由 | |
router = APIRouter(prefix="/api") | |
# 密钥管理API | |
async def get_all_keys(admin_token = Depends(verify_admin_jwt)): | |
"""获取所有API密钥""" | |
return key_manager.get_all_keys() | |
async def get_key(key_id: str, admin_token = Depends(verify_admin_jwt)): | |
"""获取单个API密钥详情""" | |
key = key_manager.get_key_by_id(key_id) | |
if not key: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
return key | |
async def create_key(key_data: ApiKeyCreate, admin_token = Depends(verify_admin_jwt)): | |
"""创建新API密钥""" | |
try: | |
# 确保密钥值包含 Bearer 前缀 | |
key_value = key_data.key_value | |
if not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
new_key = key_manager.add_key( | |
key_value, | |
name=key_data.name, | |
weight=key_data.weight, | |
rate_limit=key_data.rate_limit, | |
is_enabled=key_data.is_enabled, | |
notes=key_data.notes | |
) | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return new_key | |
except Exception as e: | |
logger.error(f"创建密钥失败: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=400, detail=str(e)) | |
async def update_key(key_id: str, key_data: ApiKeyUpdate, admin_token = Depends(verify_admin_jwt)): | |
"""更新API密钥信息""" | |
try: | |
# 如果提供了新的密钥值,确保包含Bearer前缀 | |
key_value = key_data.key_value | |
if key_value and not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
key_data.key_value = key_value | |
updated_key = key_manager.update_key( | |
key_id, | |
key_value=key_data.key_value, | |
name=key_data.name, | |
weight=key_data.weight, | |
rate_limit=key_data.rate_limit, | |
is_enabled=key_data.is_enabled, | |
notes=key_data.notes | |
) | |
if not updated_key: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return updated_key | |
except HTTPException: | |
raise | |
except Exception as e: | |
logger.error(f"更新密钥失败: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=400, detail=str(e)) | |
async def delete_key(key_id: str, admin_token = Depends(verify_admin_jwt)): | |
"""删除API密钥""" | |
success = key_manager.delete_key(key_id) | |
if not success: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return {"status": "success", "message": "密钥已删除"} | |
async def get_usage_stats(admin_token = Depends(verify_admin_jwt)): | |
"""获取API使用统计""" | |
stats = key_manager.get_usage_stats() | |
# 处理daily_usage数据,确保前端能够正确显示 | |
daily_usage = {} | |
keys_usage = {} | |
# 从past_7_days数据转换为daily_usage格式 | |
for date, counts in stats.get("past_7_days", {}).items(): | |
daily_usage[date] = counts.get("successful", 0) + counts.get("failed", 0) | |
# 获取每个密钥的使用情况 | |
for key in key_manager.keys: | |
key_id = key.get("id") | |
key_name = key.get("name") or f"密钥_{key_id[:6]}" | |
# 获取该密钥的使用统计 | |
if key_id in key_manager.usage_stats: | |
key_stats = key_manager.usage_stats[key_id] | |
total_requests = key_stats.get("total_requests", 0) | |
if total_requests > 0: | |
keys_usage[key_name] = total_requests | |
# 添加到返回数据中 | |
stats["daily_usage"] = daily_usage | |
stats["keys_usage"] = keys_usage | |
return stats | |
async def test_key(key_data: ApiKeyCreate, admin_token = Depends(verify_admin_jwt)): | |
"""测试API密钥是否有效""" | |
try: | |
# 获取密钥值 | |
key_value = key_data.key_value.strip() | |
# 确保密钥格式正确 | |
if not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
# 获取代理配置 | |
proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
proxy_user = Config.PROXY_USER if Config.PROXY_USER and Config.PROXY_USER.strip() else None | |
proxy_pass = Config.PROXY_PASS if Config.PROXY_PASS and Config.PROXY_PASS.strip() else None | |
test_client = SoraClient( | |
proxy_host=proxy_host, | |
proxy_port=proxy_port, | |
proxy_user=proxy_user, | |
proxy_pass=proxy_pass, | |
auth_token=key_value | |
) | |
# 执行简单API调用测试连接 | |
test_result = await test_client.test_connection() | |
logger.info(f"API密钥测试结果: {test_result}") | |
# 检查底层测试结果的状态 | |
if test_result.get("status") == "success": | |
# API连接测试成功 | |
return { | |
"status": "success", | |
"message": "API密钥测试成功", | |
"details": test_result, | |
"success": True | |
} | |
else: | |
# API连接测试失败 | |
return { | |
"status": "error", | |
"message": f"API密钥测试失败: {test_result.get('message', '连接失败')}", | |
"details": test_result, | |
"success": False | |
} | |
except Exception as e: | |
logger.error(f"测试密钥失败: {str(e)}", exc_info=True) | |
return { | |
"status": "error", | |
"message": f"API密钥测试失败: {str(e)}", | |
"success": False | |
} | |
async def batch_operation(operation: Dict[str, Any], admin_token = Depends(verify_admin_jwt)): | |
"""批量操作API密钥""" | |
try: | |
action = operation.get("action") | |
logger.info(f"接收到批量操作请求: {action}") | |
if not action: | |
logger.warning("批量操作缺少action参数") | |
raise HTTPException(status_code=400, detail="缺少必要参数: action") | |
logger.info(f"批量操作类型: {action}") | |
if action == "import": | |
# 批量导入API密钥 | |
keys_data = operation.get("keys", []) | |
if not keys_data: | |
logger.warning("批量导入缺少keys数据") | |
raise HTTPException(status_code=400, detail="未提供密钥数据") | |
logger.info(f"准备导入 {len(keys_data)} 个密钥") | |
# 对每个密钥处理Bearer前缀 | |
for key_data in keys_data: | |
if isinstance(key_data, dict): | |
key_value = key_data.get("key", "").strip() | |
if key_value and not key_value.startswith("Bearer "): | |
key_data["key"] = f"Bearer {key_value}" | |
# 执行批量导入 | |
try: | |
result = key_manager.batch_import_keys(keys_data) | |
logger.info(f"导入结果: 成功={result['imported']}, 跳过={result['skipped']}") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return { | |
"success": True, | |
"message": f"成功导入 {result['imported']} 个密钥,跳过 {result['skipped']} 个重复密钥", | |
"imported": result["imported"], | |
"skipped": result["skipped"] | |
} | |
except Exception as e: | |
logger.error(f"批量导入密钥错误: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=500, detail=f"导入密钥失败: {str(e)}") | |
elif action not in ["enable", "disable", "delete"]: | |
logger.warning(f"不支持的批量操作: {action}") | |
raise HTTPException(status_code=400, detail=f"不支持的操作: {action}") | |
# 对于非导入操作,需要提供key_ids | |
key_ids = operation.get("key_ids", []) | |
if not key_ids: | |
logger.warning(f"{action}操作缺少key_ids参数") | |
raise HTTPException(status_code=400, detail="缺少必要参数: key_ids") | |
# 确保key_ids是一个列表 | |
if isinstance(key_ids, str): | |
key_ids = [key_ids] | |
logger.info(f"批量{action}操作 {len(key_ids)} 个密钥") | |
if action == "enable": | |
# 批量启用 | |
success_count = 0 | |
for key_id in key_ids: | |
updated = key_manager.update_key(key_id, is_enabled=True) | |
if updated: | |
success_count += 1 | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
logger.info(f"成功启用 {success_count} 个密钥") | |
return { | |
"success": True, | |
"message": f"已成功启用 {success_count} 个密钥", | |
"affected": success_count | |
} | |
elif action == "disable": | |
# 批量禁用 | |
success_count = 0 | |
for key_id in key_ids: | |
updated = key_manager.update_key(key_id, is_enabled=False) | |
if updated: | |
success_count += 1 | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
logger.info(f"成功禁用 {success_count} 个密钥") | |
return { | |
"success": True, | |
"message": f"已成功禁用 {success_count} 个密钥", | |
"affected": success_count | |
} | |
elif action == "delete": | |
# 批量删除 | |
success_count = 0 | |
for key_id in key_ids: | |
if key_manager.delete_key(key_id): | |
success_count += 1 | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
logger.info(f"成功删除 {success_count} 个密钥") | |
return { | |
"success": True, | |
"message": f"已成功删除 {success_count} 个密钥", | |
"affected": success_count | |
} | |
except HTTPException: | |
raise | |
except Exception as e: | |
logger.error(f"批量操作失败: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=500, detail=str(e)) | |
# 配置管理API | |
async def get_config(admin_token = Depends(verify_admin_jwt)): | |
"""获取当前系统配置""" | |
return { | |
"HOST": Config.HOST, | |
"PORT": Config.PORT, | |
"BASE_URL": Config.BASE_URL, | |
"PROXY_HOST": Config.PROXY_HOST, | |
"PROXY_PORT": Config.PROXY_PORT, | |
"PROXY_USER": Config.PROXY_USER, | |
"PROXY_PASS": "******" if Config.PROXY_PASS else "", | |
"IMAGE_LOCALIZATION": Config.IMAGE_LOCALIZATION, | |
"IMAGE_SAVE_DIR": Config.IMAGE_SAVE_DIR, | |
"API_AUTH_TOKEN": bool(Config.API_AUTH_TOKEN) # 只返回是否设置,不返回实际值 | |
} | |
async def update_config(config_data: ConfigUpdate, admin_token = Depends(verify_admin_jwt)): | |
"""更新系统配置""" | |
try: | |
changes = [] | |
# 更新代理设置 | |
if config_data.PROXY_HOST is not None: | |
Config.PROXY_HOST = config_data.PROXY_HOST | |
changes.append("PROXY_HOST") | |
# 更新环境变量 | |
os.environ["PROXY_HOST"] = config_data.PROXY_HOST | |
if config_data.PROXY_PORT is not None: | |
Config.PROXY_PORT = config_data.PROXY_PORT | |
changes.append("PROXY_PORT") | |
# 更新环境变量 | |
os.environ["PROXY_PORT"] = config_data.PROXY_PORT | |
# 更新代理认证设置 | |
if config_data.PROXY_USER is not None: | |
Config.PROXY_USER = config_data.PROXY_USER | |
changes.append("PROXY_USER") | |
# 更新环境变量 | |
os.environ["PROXY_USER"] = config_data.PROXY_USER | |
if config_data.PROXY_PASS is not None: | |
Config.PROXY_PASS = config_data.PROXY_PASS | |
changes.append("PROXY_PASS") | |
# 更新环境变量 | |
os.environ["PROXY_PASS"] = config_data.PROXY_PASS | |
# 更新基础URL设置 | |
if config_data.BASE_URL is not None: | |
Config.BASE_URL = config_data.BASE_URL | |
changes.append("BASE_URL") | |
# 更新环境变量 | |
os.environ["BASE_URL"] = config_data.BASE_URL | |
# 更新图片本地化设置 | |
if config_data.IMAGE_LOCALIZATION is not None: | |
Config.IMAGE_LOCALIZATION = config_data.IMAGE_LOCALIZATION | |
changes.append("IMAGE_LOCALIZATION") | |
# 更新环境变量 | |
os.environ["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
if config_data.IMAGE_SAVE_DIR is not None: | |
Config.IMAGE_SAVE_DIR = config_data.IMAGE_SAVE_DIR | |
changes.append("IMAGE_SAVE_DIR") | |
# 更新环境变量 | |
os.environ["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
# 确保目录存在 | |
os.makedirs(Config.IMAGE_SAVE_DIR, exist_ok=True) | |
# 保存到.env文件 | |
if config_data.save_to_env and changes: | |
env_file = os.path.join(Config.BASE_DIR, '.env') | |
env_data = {} | |
# 先读取现有的.env文件 | |
if os.path.exists(env_file): | |
env_data = dotenv.dotenv_values(env_file) | |
# 更新环境变量 | |
for field in changes: | |
value = getattr(Config, field) | |
env_data[field] = str(value) | |
# 写入.env文件 | |
with open(env_file, 'w') as f: | |
for key, value in env_data.items(): | |
f.write(f"{key}={value}\n") | |
logger.info(f"已将配置保存到.env文件: {changes}") | |
return { | |
"status": "success", | |
"message": f"配置已更新: {', '.join(changes) if changes else '无变更'}" | |
} | |
except Exception as e: | |
logger.error(f"更新配置失败: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=400, detail=f"更新配置失败: {str(e)}") | |
async def update_log_level(data: LogLevelUpdate, admin_token = Depends(verify_admin_jwt)): | |
"""更新日志级别""" | |
try: | |
# 验证日志级别 | |
level = data.level.upper() | |
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
if level not in valid_levels: | |
raise HTTPException(status_code=400, detail=f"无效的日志级别: {level}") | |
# 更新根日志记录器的级别 | |
root_logger = logging.getLogger() | |
root_logger.setLevel(getattr(logging, level)) | |
# 同时更新sora-api模块的日志级别 | |
sora_logger = logging.getLogger("sora-api") | |
sora_logger.setLevel(getattr(logging, level)) | |
# 记录日志级别变更 | |
logger.info(f"日志级别已更新为: {level}") | |
# 如果需要,保存到环境变量 | |
if data.save_to_env: | |
env_file = os.path.join(Config.BASE_DIR, '.env') | |
env_data = {} | |
# 先读取现有的.env文件 | |
if os.path.exists(env_file): | |
env_data = dotenv.dotenv_values(env_file) | |
# 更新LOG_LEVEL环境变量 | |
env_data["LOG_LEVEL"] = level | |
# 写入.env文件 | |
with open(env_file, 'w') as f: | |
for key, value in env_data.items(): | |
f.write(f"{key}={value}\n") | |
# 记录配置保存 | |
logger.info(f"已将日志级别保存到.env文件: LOG_LEVEL={level}") | |
return {"status": "success", "message": f"日志级别已更新为: {level}"} | |
except HTTPException: | |
raise | |
except Exception as e: | |
logger.error(f"更新日志级别失败: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=400, detail=f"更新日志级别失败: {str(e)}") | |
# 管理员密钥API - 已移至app.py中 |