Spaces:
Running
Running
import asyncio | |
import json | |
import time | |
import uuid | |
import base64 | |
import os | |
import tempfile | |
import threading | |
import dotenv | |
import logging | |
from typing import List, Dict, Any, Optional, Union | |
from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks, File, UploadFile, Form | |
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel, Field | |
import uvicorn | |
import re | |
from .key_manager import KeyManager | |
from .sora_integration import SoraClient | |
from .config import Config | |
from .utils import localize_image_urls # 导入新增的图片本地化功能 | |
# 日志系统配置 | |
class LogConfig: | |
LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper() | |
FORMAT = "%(asctime)s [%(levelname)s] %(message)s" | |
# 初始化日志 | |
logging.basicConfig( | |
level=getattr(logging, LogConfig.LEVEL), | |
format=LogConfig.FORMAT, | |
datefmt="%Y-%m-%d %H:%M:%S" | |
) | |
logger = logging.getLogger("sora-api") | |
# 打印日志级别信息 | |
logger.info(f"日志级别设置为: {LogConfig.LEVEL}") | |
logger.info(f"要调整日志级别,请设置环境变量 LOG_LEVEL=DEBUG|INFO|WARNING|ERROR") | |
# 创建FastAPI应用 | |
app = FastAPI(title="OpenAI Compatible Sora API") | |
# 添加CORS支持 | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# 确保静态文件目录存在 | |
os.makedirs(os.path.join(Config.STATIC_DIR, "admin"), exist_ok=True) | |
os.makedirs(os.path.join(Config.STATIC_DIR, "admin/js"), exist_ok=True) | |
os.makedirs(os.path.join(Config.STATIC_DIR, "admin/css"), exist_ok=True) | |
os.makedirs(os.path.join(Config.STATIC_DIR, "images"), exist_ok=True) # 确保图片目录存在 | |
# 打印配置信息 | |
Config.print_config() | |
# 挂载静态文件目录 | |
app.mount("/static", StaticFiles(directory=Config.STATIC_DIR), name="static") | |
# 初始化Key管理器 | |
key_manager = KeyManager(storage_file=Config.KEYS_STORAGE_FILE) | |
# 初始化时保存管理员密钥 | |
Config.save_admin_key() | |
# 创建Sora客户端池 | |
sora_clients = {} | |
# 存储生成结果的全局字典 | |
generation_results = {} | |
# 请求模型 | |
class ContentItem(BaseModel): | |
type: str | |
text: Optional[str] = None | |
image_url: Optional[Dict[str, str]] = None | |
class ChatMessage(BaseModel): | |
role: str | |
content: Union[str, List[ContentItem]] | |
class ChatCompletionRequest(BaseModel): | |
model: str | |
messages: List[ChatMessage] | |
temperature: Optional[float] = 1.0 | |
top_p: Optional[float] = 1.0 | |
n: Optional[int] = 1 | |
stream: Optional[bool] = False | |
max_tokens: Optional[int] = None | |
presence_penalty: Optional[float] = 0 | |
frequency_penalty: Optional[float] = 0 | |
# API密钥管理模型 | |
class ApiKeyCreate(BaseModel): | |
name: str = Field(..., description="密钥名称") | |
key_value: str = Field(..., description="Bearer Token值") | |
weight: int = Field(default=1, ge=1, le=10, description="权重值") | |
rate_limit: int = Field(default=60, description="每分钟最大请求数") | |
is_enabled: bool = Field(default=True, description="是否启用") | |
notes: Optional[str] = Field(default=None, description="备注信息") | |
class ApiKeyUpdate(BaseModel): | |
name: Optional[str] = None | |
key_value: Optional[str] = None | |
weight: Optional[int] = None | |
rate_limit: Optional[int] = None | |
is_enabled: Optional[bool] = None | |
notes: Optional[str] = None | |
# 获取Sora客户端 | |
def get_sora_client(auth_token: str) -> SoraClient: | |
if auth_token not in sora_clients: | |
proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
sora_clients[auth_token] = SoraClient( | |
proxy_host=proxy_host, | |
proxy_port=proxy_port, | |
auth_token=auth_token | |
) | |
return sora_clients[auth_token] | |
# 验证API key | |
async def verify_api_key(request: Request): | |
auth_header = request.headers.get("Authorization") | |
if not auth_header or not auth_header.startswith("Bearer "): | |
raise HTTPException(status_code=401, detail="缺少或无效的API key") | |
api_key = auth_header.replace("Bearer ", "") | |
# 在实际应用中,这里应该验证key的有效性 | |
# 这里简化处理,假设所有key都有效 | |
return api_key | |
# 验证管理员权限 | |
async def verify_admin(request: Request): | |
auth_header = request.headers.get("Authorization") | |
if not auth_header or not auth_header.startswith("Bearer "): | |
raise HTTPException(status_code=401, detail="未授权") | |
admin_key = auth_header.replace("Bearer ", "") | |
# 这里应该检查是否为管理员密钥 | |
# 简化处理,假设admin_key是预设的管理员密钥 | |
if admin_key != Config.ADMIN_KEY: | |
raise HTTPException(status_code=403, detail="没有管理员权限") | |
return admin_key | |
# 将处理中状态消息格式化为think代码块 | |
def format_think_block(message): | |
"""将消息放入```think代码块中""" | |
return f"```think\n{message}\n```" | |
# 后台任务处理函数 - 文本生成图像 | |
async def process_image_generation( | |
request_id: str, | |
sora_client: SoraClient, | |
prompt: str, | |
num_images: int = 1, | |
width: int = 720, | |
height: int = 480 | |
): | |
try: | |
# 更新状态为生成中 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在生成图像中,请耐心等待..."), | |
"timestamp": int(time.time()) | |
} | |
# 生成图像 | |
logger.info(f"[{request_id}] 开始生成图像, 提示词: {prompt}") | |
image_urls = await sora_client.generate_image( | |
prompt=prompt, | |
num_images=num_images, | |
width=width, | |
height=height | |
) | |
# 验证生成结果 | |
if isinstance(image_urls, str): | |
logger.warning(f"[{request_id}] 图像生成失败或返回了错误信息: {image_urls}") | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": image_urls, | |
"message": format_think_block(f"图像生成失败: {image_urls}"), | |
"timestamp": int(time.time()) | |
} | |
return | |
if not image_urls: | |
logger.warning(f"[{request_id}] 图像生成返回了空列表") | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": "图像生成返回了空结果", | |
"message": format_think_block("图像生成失败: 服务器返回了空结果"), | |
"timestamp": int(time.time()) | |
} | |
return | |
logger.info(f"[{request_id}] 成功生成 {len(image_urls)} 张图片") | |
if logger.isEnabledFor(logging.DEBUG): | |
for i, url in enumerate(image_urls): | |
logger.debug(f"[{request_id}] 图片 {i+1}: {url}") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION: | |
logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
logger.debug(f"[{request_id}] 图片本地化配置: 启用={Config.IMAGE_LOCALIZATION}, 保存目录={Config.IMAGE_SAVE_DIR}") | |
try: | |
localized_urls = await localize_image_urls(image_urls) | |
logger.info(f"[{request_id}] 图片本地化处理完成") | |
# 检查本地化结果 | |
if not localized_urls: | |
logger.warning(f"[{request_id}] 本地化处理返回了空列表,将使用原始URL") | |
localized_urls = image_urls | |
# 检查是否所有URL都被正确本地化 | |
local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
logger.info(f"[{request_id}] 本地化结果: 总计 {len(localized_urls)} 张图片,成功本地化 {local_count} 张") | |
if local_count == 0: | |
logger.warning(f"[{request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
localized_urls = image_urls | |
# 打印结果对比 | |
if logger.isEnabledFor(logging.DEBUG): | |
for i, (orig, local) in enumerate(zip(image_urls, localized_urls)): | |
logger.debug(f"[{request_id}] 图片 {i+1} 本地化结果: {orig} -> {local}") | |
image_urls = localized_urls | |
except Exception as e: | |
logger.error(f"[{request_id}] 图片本地化过程中发生错误: {str(e)}") | |
if logger.isEnabledFor(logging.DEBUG): | |
import traceback | |
logger.debug(traceback.format_exc()) | |
logger.info(f"[{request_id}] 由于错误,将使用原始URL") | |
else: | |
logger.info(f"[{request_id}] 图片本地化功能未启用,使用原始URL") | |
# 存储结果 | |
generation_results[request_id] = { | |
"status": "completed", | |
"image_urls": image_urls, | |
"timestamp": int(time.time()) | |
} | |
# 30分钟后自动清理结果 | |
threading.Timer(1800, lambda: generation_results.pop(request_id, None)).start() | |
except Exception as e: | |
error_message = f"图像生成失败: {str(e)}" | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": error_message, | |
"message": format_think_block(error_message), | |
"timestamp": int(time.time()) | |
} | |
logger.error(f"图像生成失败 (ID: {request_id}): {str(e)}") | |
if logger.isEnabledFor(logging.DEBUG): | |
import traceback | |
logger.debug(traceback.format_exc()) | |
# 后台任务处理函数 - 带图片的remix | |
async def process_image_remix( | |
request_id: str, | |
sora_client: SoraClient, | |
prompt: str, | |
image_data: str, | |
num_images: int = 1 | |
): | |
try: | |
# 更新状态为处理中 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在处理上传的图片..."), | |
"timestamp": int(time.time()) | |
} | |
# 保存base64图片到临时文件 | |
temp_dir = tempfile.mkdtemp() | |
temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
try: | |
# 解码并保存图片 | |
with open(temp_image_path, "wb") as f: | |
f.write(base64.b64decode(image_data)) | |
# 更新状态为上传中 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在上传图片到Sora服务..."), | |
"timestamp": int(time.time()) | |
} | |
# 上传图片 | |
upload_result = await sora_client.upload_image(temp_image_path) | |
media_id = upload_result['id'] | |
# 更新状态为生成中 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在基于图片生成新图像..."), | |
"timestamp": int(time.time()) | |
} | |
# 执行remix生成 | |
logger.info(f"[{request_id}] 开始生成Remix图像, 提示词: {prompt}") | |
image_urls = await sora_client.generate_image_remix( | |
prompt=prompt, | |
media_id=media_id, | |
num_images=num_images | |
) | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION: | |
logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
localized_urls = await localize_image_urls(image_urls) | |
image_urls = localized_urls | |
logger.info(f"[{request_id}] Remix图片本地化处理完成") | |
# 存储结果 | |
generation_results[request_id] = { | |
"status": "completed", | |
"image_urls": image_urls, | |
"timestamp": int(time.time()) | |
} | |
# 30分钟后自动清理结果 | |
threading.Timer(1800, lambda: generation_results.pop(request_id, None)).start() | |
finally: | |
# 清理临时文件 | |
if os.path.exists(temp_image_path): | |
os.remove(temp_image_path) | |
if os.path.exists(temp_dir): | |
os.rmdir(temp_dir) | |
except Exception as e: | |
error_message = f"图像Remix失败: {str(e)}" | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": error_message, | |
"message": format_think_block(error_message), | |
"timestamp": int(time.time()) | |
} | |
logger.error(f"图像Remix失败 (ID: {request_id}): {str(e)}") | |
# 添加一个新端点用于检查生成状态 | |
@app.get("/v1/generation/{request_id}") | |
async def check_generation_status(request_id: str, api_key: str = Depends(verify_api_key)): | |
""" | |
检查图像生成任务的状态 | |
""" | |
# 获取一个可用的key并记录开始时间 | |
sora_auth_token = key_manager.get_key() | |
if not sora_auth_token: | |
raise HTTPException(status_code=429, detail="所有API key都已达到速率限制") | |
start_time = time.time() | |
success = False | |
try: | |
if request_id not in generation_results: | |
raise HTTPException(status_code=404, detail=f"找不到生成任务: {request_id}") | |
result = generation_results[request_id] | |
if result["status"] == "completed": | |
image_urls = result["image_urls"] | |
# 构建OpenAI兼容的响应 | |
response = { | |
"id": request_id, | |
"object": "chat.completion", | |
"created": result["timestamp"], | |
"model": "sora-1.0", | |
"choices": [ | |
{ | |
"index": i, | |
"message": { | |
"role": "assistant", | |
"content": f"" | |
}, | |
"finish_reason": "stop" | |
} | |
for i, url in enumerate(image_urls) | |
], | |
"usage": { | |
"prompt_tokens": 0, # 简化的令牌计算 | |
"completion_tokens": 20, | |
"total_tokens": 20 | |
} | |
} | |
success = True | |
return JSONResponse(content=response) | |
elif result["status"] == "failed": | |
if "message" in result: | |
# 返回带有格式化错误消息的响应 | |
response = { | |
"id": request_id, | |
"object": "chat.completion", | |
"created": result["timestamp"], | |
"model": "sora-1.0", | |
"choices": [ | |
{ | |
"index": 0, | |
"message": { | |
"role": "assistant", | |
"content": result["message"] | |
}, | |
"finish_reason": "error" | |
} | |
], | |
"usage": { | |
"prompt_tokens": 0, | |
"completion_tokens": 10, | |
"total_tokens": 10 | |
} | |
} | |
success = False | |
return JSONResponse(content=response) | |
else: | |
# 向后兼容,使用老的方式 | |
raise HTTPException(status_code=500, detail=f"生成失败: {result['error']}") | |
else: # 处理中 | |
message = result.get("message", "```think\n正在生成图像,请稍候...\n```") | |
response = { | |
"id": request_id, | |
"object": "chat.completion", | |
"created": result["timestamp"], | |
"model": "sora-1.0", | |
"choices": [ | |
{ | |
"index": 0, | |
"message": { | |
"role": "assistant", | |
"content": message | |
}, | |
"finish_reason": "processing" | |
} | |
], | |
"usage": { | |
"prompt_tokens": 0, | |
"completion_tokens": 10, | |
"total_tokens": 10 | |
} | |
} | |
success = True | |
return JSONResponse(content=response) | |
except Exception as e: | |
success = False | |
raise HTTPException(status_code=500, detail=f"检查任务状态失败: {str(e)}") | |
finally: | |
# 记录请求结果 | |
response_time = time.time() - start_time | |
key_manager.record_request_result(sora_auth_token, success, response_time) | |
# 聊天完成端点 | |
@app.post("/v1/chat/completions") | |
async def chat_completions( | |
request: ChatCompletionRequest, | |
api_key: str = Depends(verify_api_key), | |
background_tasks: BackgroundTasks = None | |
): | |
# 获取一个可用的key | |
sora_auth_token = key_manager.get_key() | |
if not sora_auth_token: | |
raise HTTPException(status_code=429, detail="所有API key都已达到速率限制") | |
# 获取Sora客户端 | |
sora_client = get_sora_client(sora_auth_token) | |
# 分析最后一条用户消息以提取内容 | |
user_messages = [m for m in request.messages if m.role == "user"] | |
if not user_messages: | |
raise HTTPException(status_code=400, detail="至少需要一条用户消息") | |
last_user_message = user_messages[-1] | |
prompt = "" | |
image_data = None | |
# 提取提示词和图片数据 | |
if isinstance(last_user_message.content, str): | |
# 简单的字符串内容 | |
prompt = last_user_message.content | |
# 检查是否包含内嵌的base64图片 | |
pattern = r'data:image\/[^;]+;base64,([^"]+)' | |
match = re.search(pattern, prompt) | |
if match: | |
image_data = match.group(1) | |
# 从提示词中删除base64数据,以保持提示词的可读性 | |
prompt = re.sub(pattern, "[已上传图片]", prompt) | |
else: | |
# 多模态内容,提取文本和图片 | |
content_items = last_user_message.content | |
text_parts = [] | |
for item in content_items: | |
if item.type == "text" and item.text: | |
text_parts.append(item.text) | |
elif item.type == "image_url" and item.image_url: | |
# 如果有图片URL包含base64数据 | |
url = item.image_url.get("url", "") | |
if url.startswith("data:image/"): | |
pattern = r'data:image\/[^;]+;base64,([^"]+)' | |
match = re.search(pattern, url) | |
if match: | |
image_data = match.group(1) | |
text_parts.append("[已上传图片]") | |
prompt = " ".join(text_parts) | |
# 记录开始时间 | |
start_time = time.time() | |
success = False | |
# 处理图片生成 | |
try: | |
# 检查是否为流式响应 | |
if request.stream: | |
# 流式响应特殊处理文本+图片的情况 | |
if image_data: | |
response = StreamingResponse( | |
generate_streaming_remix_response(sora_client, prompt, image_data, request.n), | |
media_type="text/event-stream" | |
) | |
else: | |
response = StreamingResponse( | |
generate_streaming_response(sora_client, prompt, request.n), | |
media_type="text/event-stream" | |
) | |
success = True | |
# 记录请求结果(流式响应立即记录) | |
response_time = time.time() - start_time | |
key_manager.record_request_result(sora_auth_token, success, response_time) | |
return response | |
else: | |
# 对于非流式响应,返回一个即时响应,表示任务已接收 | |
# 创建一个唯一ID | |
request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
# 在结果字典中创建初始状态 | |
processing_message = "正在准备生成任务,请稍候..." | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block(processing_message), | |
"timestamp": int(time.time()) | |
} | |
# 添加后台任务 | |
if image_data: | |
background_tasks.add_task( | |
process_image_remix, | |
request_id, | |
sora_client, | |
prompt, | |
image_data, | |
request.n | |
) | |
else: | |
background_tasks.add_task( | |
process_image_generation, | |
request_id, | |
sora_client, | |
prompt, | |
request.n, | |
720, # width | |
480 # height | |
) | |
# 立即返回一个"正在处理中"的响应 | |
response = { | |
"id": request_id, | |
"object": "chat.completion", | |
"created": int(time.time()), | |
"model": "sora-1.0", | |
"choices": [ | |
{ | |
"index": 0, | |
"message": { | |
"role": "assistant", | |
"content": format_think_block(processing_message) | |
}, | |
"finish_reason": "processing" | |
} | |
], | |
"usage": { | |
"prompt_tokens": len(prompt) // 4, | |
"completion_tokens": 10, | |
"total_tokens": len(prompt) // 4 + 10 | |
} | |
} | |
success = True | |
# 记录请求结果(非流式响应立即记录) | |
response_time = time.time() - start_time | |
key_manager.record_request_result(sora_auth_token, success, response_time) | |
return JSONResponse(content=response) | |
except Exception as e: | |
success = False | |
# 记录请求结果(异常情况也记录) | |
response_time = time.time() - start_time | |
key_manager.record_request_result(sora_auth_token, success, response_time) | |
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") | |
# 流式响应生成器 - 普通文本到图像 | |
async def generate_streaming_response( | |
sora_client: SoraClient, | |
prompt: str, | |
n_images: int = 1 | |
): | |
request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
# 发送开始事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
# 发送处理中的消息(放在代码块中) | |
start_msg = "```think\n正在生成图像,请稍候...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': start_msg}, 'finish_reason': None}]})}\n\n" | |
# 创建一个后台任务来生成图像 | |
logger.info(f"[流式响应 {request_id}] 开始生成图像, 提示词: {prompt}") | |
generation_task = asyncio.create_task(sora_client.generate_image( | |
prompt=prompt, | |
num_images=n_images, | |
width=720, | |
height=480 | |
)) | |
# 每5秒发送一条"仍在生成中"的消息,防止连接超时 | |
progress_messages = [ | |
"正在处理您的请求...", | |
"仍在生成图像中,请继续等待...", | |
"Sora正在创作您的图像作品...", | |
"图像生成需要一点时间,感谢您的耐心等待...", | |
"我们正在努力为您创作高质量图像..." | |
] | |
i = 0 | |
while not generation_task.done(): | |
# 每5秒发送一次进度消息 | |
await asyncio.sleep(5) | |
progress_msg = progress_messages[i % len(progress_messages)] | |
i += 1 | |
content = "\n" + progress_msg + "\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
try: | |
# 获取生成结果 | |
image_urls = await generation_task | |
logger.info(f"[流式响应 {request_id}] 图像生成完成,获取到 {len(image_urls) if isinstance(image_urls, list) else '非列表'} 个URL") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION and isinstance(image_urls, list) and image_urls: | |
logger.info(f"[流式响应 {request_id}] 准备进行图片本地化处理") | |
try: | |
localized_urls = await localize_image_urls(image_urls) | |
logger.info(f"[流式响应 {request_id}] 图片本地化处理完成") | |
# 检查本地化结果 | |
if not localized_urls: | |
logger.warning(f"[流式响应 {request_id}] 本地化处理返回了空列表,将使用原始URL") | |
localized_urls = image_urls | |
# 检查是否所有URL都被正确本地化 | |
local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
if local_count == 0: | |
logger.warning(f"[流式响应 {request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
localized_urls = image_urls | |
else: | |
logger.info(f"[流式响应 {request_id}] 成功本地化 {local_count}/{len(localized_urls)} 张图片") | |
# 打印本地化对比结果 | |
if logger.isEnabledFor(logging.DEBUG): | |
for i, (orig, local) in enumerate(zip(image_urls, localized_urls)): | |
logger.debug(f"[流式响应 {request_id}] 图片 {i+1}: {orig} -> {local}") | |
image_urls = localized_urls | |
except Exception as e: | |
logger.error(f"[流式响应 {request_id}] 图片本地化过程中发生错误: {str(e)}") | |
if logger.isEnabledFor(logging.DEBUG): | |
import traceback | |
logger.debug(traceback.format_exc()) | |
logger.info(f"[流式响应 {request_id}] 由于错误,将使用原始URL") | |
elif not Config.IMAGE_LOCALIZATION: | |
logger.info(f"[流式响应 {request_id}] 图片本地化功能未启用") | |
elif not isinstance(image_urls, list) or not image_urls: | |
logger.warning(f"[流式响应 {request_id}] 无法进行本地化: 图像结果不是有效的URL列表") | |
# 结束代码块 | |
content_str = "\n```\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
# 添加生成的图片URLs | |
for i, url in enumerate(image_urls): | |
if i > 0: | |
content_str = "\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
image_markdown = f"" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
# 发送完成事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
# 发送结束标志 | |
yield "data: [DONE]\n\n" | |
except Exception as e: | |
error_msg = f"图像生成失败: {str(e)}" | |
logger.error(f"[流式响应 {request_id}] 错误: {error_msg}") | |
if logger.isEnabledFor(logging.DEBUG): | |
import traceback | |
logger.debug(traceback.format_exc()) | |
error_content = f"\n{error_msg}\n```" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
yield "data: [DONE]\n\n" | |
# 流式响应生成器 - 带图片的remix | |
async def generate_streaming_remix_response( | |
sora_client: SoraClient, | |
prompt: str, | |
image_data: str, | |
n_images: int = 1 | |
): | |
request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
# 发送开始事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
try: | |
# 保存base64图片到临时文件 | |
temp_dir = tempfile.mkdtemp() | |
temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
try: | |
# 解码并保存图片 | |
with open(temp_image_path, "wb") as f: | |
f.write(base64.b64decode(image_data)) | |
# 上传图片 | |
upload_msg = "```think\n上传图片中...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': upload_msg}, 'finish_reason': None}]})}\n\n" | |
logger.info(f"[流式响应Remix {request_id}] 上传图片中") | |
upload_result = await sora_client.upload_image(temp_image_path) | |
media_id = upload_result['id'] | |
# 发送生成中消息 | |
generate_msg = "\n基于图片生成新图像中...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': generate_msg}, 'finish_reason': None}]})}\n\n" | |
# 创建一个后台任务来生成图像 | |
logger.info(f"[流式响应Remix {request_id}] 开始生成图像,提示词: {prompt}") | |
generation_task = asyncio.create_task(sora_client.generate_image_remix( | |
prompt=prompt, | |
media_id=media_id, | |
num_images=n_images | |
)) | |
# 每5秒发送一条"仍在生成中"的消息,防止连接超时 | |
progress_messages = [ | |
"正在处理您的请求...", | |
"仍在生成图像中,请继续等待...", | |
"Sora正在基于您的图片创作新作品...", | |
"图像生成需要一点时间,感谢您的耐心等待...", | |
"正在努力融合您的风格和提示词,打造专属图像..." | |
] | |
i = 0 | |
while not generation_task.done(): | |
# 每5秒发送一次进度消息 | |
await asyncio.sleep(5) | |
progress_msg = progress_messages[i % len(progress_messages)] | |
i += 1 | |
content = "\n" + progress_msg + "\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
# 获取生成结果 | |
image_urls = await generation_task | |
logger.info(f"[流式响应Remix {request_id}] 图像生成完成") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION: | |
logger.info(f"[流式响应Remix {request_id}] 进行图片本地化处理") | |
localized_urls = await localize_image_urls(image_urls) | |
image_urls = localized_urls | |
logger.info(f"[流式响应Remix {request_id}] 图片本地化处理完成") | |
else: | |
logger.info(f"[流式响应Remix {request_id}] 图片本地化功能未启用") | |
# 结束代码块 | |
content_str = "\n```\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
# 发送图片URL作为Markdown | |
for i, url in enumerate(image_urls): | |
if i > 0: | |
newline_str = "\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': newline_str}, 'finish_reason': None}]})}\n\n" | |
image_markdown = f"" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
# 发送完成事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
# 发送结束标志 | |
yield "data: [DONE]\n\n" | |
finally: | |
# 清理临时文件 | |
if os.path.exists(temp_image_path): | |
os.remove(temp_image_path) | |
if os.path.exists(temp_dir): | |
os.rmdir(temp_dir) | |
except Exception as e: | |
error_msg = f"图像Remix失败: {str(e)}" | |
logger.error(f"[流式响应Remix {request_id}] 错误: {error_msg}") | |
if logger.isEnabledFor(logging.DEBUG): | |
import traceback | |
logger.debug(traceback.format_exc()) | |
error_content = f"\n{error_msg}\n```" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
# 结束流 | |
yield "data: [DONE]\n\n" | |
# API密钥管理端点 | |
@app.get("/api/keys") | |
async def get_all_keys(admin_key: str = Depends(verify_admin)): | |
"""获取所有API密钥""" | |
return key_manager.get_all_keys() | |
@app.get("/api/keys/{key_id}") | |
async def get_key(key_id: str, admin_key: str = Depends(verify_admin)): | |
"""获取单个API密钥详情""" | |
key = key_manager.get_key_by_id(key_id) | |
if not key: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
return key | |
@app.post("/api/keys") | |
async def create_key(key_data: ApiKeyCreate, admin_key: str = Depends(verify_admin)): | |
"""创建新API密钥""" | |
try: | |
# 确保密钥值包含 Bearer 前缀 | |
key_value = key_data.key_value | |
if not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
new_key = key_manager.add_key( | |
key_value, | |
name=key_data.name, | |
weight=key_data.weight, | |
rate_limit=key_data.rate_limit, | |
is_enabled=key_data.is_enabled, | |
notes=key_data.notes | |
) | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return new_key | |
except Exception as e: | |
raise HTTPException(status_code=400, detail=str(e)) | |
@app.put("/api/keys/{key_id}") | |
async def update_key(key_id: str, key_data: ApiKeyUpdate, admin_key: str = Depends(verify_admin)): | |
"""更新API密钥信息""" | |
try: | |
# 如果提供了新的密钥值,确保包含Bearer前缀 | |
key_value = key_data.key_value | |
if key_value and not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
key_data.key_value = key_value | |
updated_key = key_manager.update_key( | |
key_id, | |
key_value=key_data.key_value, | |
name=key_data.name, | |
weight=key_data.weight, | |
rate_limit=key_data.rate_limit, | |
is_enabled=key_data.is_enabled, | |
notes=key_data.notes | |
) | |
if not updated_key: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return updated_key | |
except Exception as e: | |
raise HTTPException(status_code=400, detail=str(e)) | |
@app.delete("/api/keys/{key_id}") | |
async def delete_key(key_id: str, admin_key: str = Depends(verify_admin)): | |
"""删除API密钥""" | |
success = key_manager.delete_key(key_id) | |
if not success: | |
raise HTTPException(status_code=404, detail="密钥不存在") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return {"status": "success", "message": "密钥已删除"} | |
@app.get("/api/stats") | |
async def get_usage_stats(admin_key: str = Depends(verify_admin)): | |
"""获取API使用统计""" | |
return key_manager.get_usage_stats() | |
@app.post("/api/keys/test") | |
async def test_key(key_data: ApiKeyCreate, admin_key: str = Depends(verify_admin)): | |
"""测试API密钥是否有效""" | |
try: | |
# 确保密钥值包含 Bearer 前缀 | |
key_value = key_data.key_value | |
if not key_value.startswith("Bearer "): | |
key_value = f"Bearer {key_value}" | |
# 获取代理配置 | |
proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
# 创建临时客户端测试连接 | |
test_client = SoraClient( | |
proxy_host=proxy_host, | |
proxy_port=proxy_port, | |
auth_token=key_value | |
) | |
# 执行简单API调用测试连接 | |
test_result = await test_client.test_connection() | |
return {"status": "success", "message": "API密钥测试成功", "details": test_result} | |
except Exception as e: | |
return {"status": "error", "message": f"API密钥测试失败: {str(e)}"} | |
@app.post("/api/keys/batch") | |
async def batch_operation(operation: Dict[str, Any], admin_key: str = Depends(verify_admin)): | |
"""批量操作API密钥""" | |
action = operation.get("action") | |
key_ids = operation.get("key_ids", []) | |
if not action or not key_ids: | |
raise HTTPException(status_code=400, detail="无效的请求参数") | |
# 确保key_ids是一个列表 | |
if isinstance(key_ids, str): | |
key_ids = [key_ids] | |
results = {} | |
if action == "enable": | |
for key_id in key_ids: | |
success = key_manager.update_key(key_id, is_enabled=True) | |
results[key_id] = "success" if success else "failed" | |
elif action == "disable": | |
for key_id in key_ids: | |
success = key_manager.update_key(key_id, is_enabled=False) | |
results[key_id] = "success" if success else "failed" | |
elif action == "delete": | |
for key_id in key_ids: | |
success = key_manager.delete_key(key_id) | |
results[key_id] = "success" if success else "failed" | |
else: | |
raise HTTPException(status_code=400, detail="不支持的操作类型") | |
# 通过Config永久保存所有密钥 | |
Config.save_api_keys(key_manager.keys) | |
return {"status": "success", "results": results} | |
# 健康检查端点 | |
@app.get("/health") | |
async def health_check(): | |
return {"status": "ok", "timestamp": time.time()} | |
# 管理界面路由 | |
@app.get("/admin") | |
async def admin_panel(): | |
return FileResponse(os.path.join(Config.STATIC_DIR, "admin/index.html")) | |
# 管理员密钥API | |
@app.get("/admin/key") | |
async def admin_key(): | |
return {"admin_key": Config.ADMIN_KEY} | |
# 挂载静态文件 | |
app.mount("/admin", StaticFiles(directory=os.path.join(Config.STATIC_DIR, "admin"), html=True), name="admin") | |
# 配置管理模型 | |
class ConfigUpdate(BaseModel): | |
IMAGE_LOCALIZATION: Optional[bool] = None | |
IMAGE_SAVE_DIR: Optional[str] = None | |
LOG_LEVEL: Optional[str] = None | |
# 配置管理页面 | |
@app.get("/admin/config") | |
async def config_panel(): | |
return FileResponse("src/static/admin/config.html") | |
# 获取当前配置 | |
@app.get("/api/config") | |
async def get_config(admin_key: str = Depends(verify_admin)): | |
"""获取当前系统配置""" | |
return { | |
"IMAGE_LOCALIZATION": Config.IMAGE_LOCALIZATION, | |
"IMAGE_SAVE_DIR": Config.IMAGE_SAVE_DIR, | |
"LOG_LEVEL": LogConfig.LEVEL | |
} | |
# 更新配置 | |
@app.post("/api/config") | |
async def update_config(config_data: ConfigUpdate, admin_key: str = Depends(verify_admin)): | |
"""更新系统配置""" | |
changes = {} | |
if config_data.IMAGE_LOCALIZATION is not None: | |
old_value = Config.IMAGE_LOCALIZATION | |
Config.IMAGE_LOCALIZATION = config_data.IMAGE_LOCALIZATION | |
os.environ["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
changes["IMAGE_LOCALIZATION"] = { | |
"old": old_value, | |
"new": Config.IMAGE_LOCALIZATION | |
} | |
if config_data.IMAGE_SAVE_DIR is not None: | |
old_value = Config.IMAGE_SAVE_DIR | |
Config.IMAGE_SAVE_DIR = config_data.IMAGE_SAVE_DIR | |
os.environ["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
# 确保目录存在 | |
os.makedirs(Config.IMAGE_SAVE_DIR, exist_ok=True) | |
changes["IMAGE_SAVE_DIR"] = { | |
"old": old_value, | |
"new": Config.IMAGE_SAVE_DIR | |
} | |
if config_data.LOG_LEVEL is not None: | |
old_value = LogConfig.LEVEL | |
level = config_data.LOG_LEVEL.upper() | |
# 验证日志级别是否有效 | |
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
if level not in valid_levels: | |
raise HTTPException(status_code=400, detail=f"无效的日志级别,有效值:{', '.join(valid_levels)}") | |
# 更新日志级别 | |
LogConfig.LEVEL = level | |
logging.getLogger("sora-api").setLevel(getattr(logging, level)) | |
os.environ["LOG_LEVEL"] = level | |
changes["LOG_LEVEL"] = { | |
"old": old_value, | |
"new": level | |
} | |
logger.info(f"日志级别已更改为: {level}") | |
# 保存到.env文件以持久化配置 | |
try: | |
dotenv_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env") | |
# 读取现有.env文件 | |
env_vars = {} | |
if os.path.exists(dotenv_file): | |
with open(dotenv_file, "r") as f: | |
for line in f: | |
if line.strip() and not line.startswith("#"): | |
key, value = line.strip().split("=", 1) | |
env_vars[key] = value | |
# 更新值 | |
if config_data.IMAGE_LOCALIZATION is not None: | |
env_vars["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
if config_data.IMAGE_SAVE_DIR is not None: | |
env_vars["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
if config_data.LOG_LEVEL is not None: | |
env_vars["LOG_LEVEL"] = config_data.LOG_LEVEL.upper() | |
# 写回文件 | |
with open(dotenv_file, "w") as f: | |
for key, value in env_vars.items(): | |
f.write(f"{key}={value}\n") | |
except Exception as e: | |
logger.error(f"保存配置到.env文件失败: {str(e)}") | |
return { | |
"success": True, | |
"message": "配置已更新", | |
"changes": changes | |
} | |
# 日志级别控制 | |
class LogLevelUpdate(BaseModel): | |
level: str = Field(..., description="日志级别") | |
@app.post("/api/logs/level") | |
async def update_log_level(data: LogLevelUpdate, admin_key: str = Depends(verify_admin)): | |
"""更新系统日志级别""" | |
level = data.level.upper() | |
# 验证日志级别是否有效 | |
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
if level not in valid_levels: | |
raise HTTPException(status_code=400, detail=f"无效的日志级别,有效值:{', '.join(valid_levels)}") | |
# 更新日志级别 | |
old_level = LogConfig.LEVEL | |
LogConfig.LEVEL = level | |
logging.getLogger("sora-api").setLevel(getattr(logging, level)) | |
os.environ["LOG_LEVEL"] = level | |
# 记录变更 | |
logger.info(f"日志级别已更改为: {level}") | |
# 更新.env文件 | |
try: | |
dotenv_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env") | |
# 读取现有.env文件 | |
env_vars = {} | |
if os.path.exists(dotenv_file): | |
with open(dotenv_file, "r") as f: | |
for line in f: | |
if line.strip() and not line.startswith("#"): | |
key, value = line.strip().split("=", 1) | |
env_vars[key] = value | |
# 更新日志级别 | |
env_vars["LOG_LEVEL"] = level | |
# 写回文件 | |
with open(dotenv_file, "w") as f: | |
for key, value in env_vars.items(): | |
f.write(f"{key}={value}\n") | |
except Exception as e: | |
logger.warning(f"保存日志级别到.env文件失败: {str(e)}") | |
return { | |
"success": True, | |
"message": f"日志级别已更改: {old_level} -> {level}" | |
} |