"""Chat completion route: text-to-image and image-to-image requests in an OpenAI-compatible format."""
import uuid | |
import time | |
import re | |
import logging | |
from typing import Dict, Any, List, Optional | |
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException | |
from fastapi.responses import StreamingResponse, JSONResponse | |
from ..models.schemas import ChatCompletionRequest | |
from ..api.dependencies import verify_api_key, get_sora_client_dep | |
from ..services.image_service import process_image_task, format_think_block | |
from ..services.streaming import generate_streaming_response, generate_streaming_remix_response | |
from ..key_manager import key_manager | |
# Module-level logger for the chat route.
logger = logging.getLogger("sora-api.chat")

# Router object created for this module.
router = APIRouter()
# Base64 image data URI embedded in free-form prompt text. The payload group is
# restricted to the base64 alphabet so that prompt text following the URI is
# not swallowed into the image data (and is not stripped from the prompt).
_INLINE_IMAGE_RE = re.compile(r'data:image\/[^;]+;base64,([A-Za-z0-9+/=]+)')

# Data URL supplied as the whole value of an ``image_url`` item; the payload
# runs to the end of the URL (or to the first double quote).
_IMAGE_URL_RE = re.compile(r'data:image\/[^;]+;base64,([^"]+)')

# Marker substituted into the prompt wherever an uploaded image was found.
_IMAGE_PLACEHOLDER = "[已上传图片]"


def _extract_prompt_and_image(content):
    """Split a user message's content into prompt text and base64 image data.

    Args:
        content: Either a plain string, or an iterable of items exposing
            ``type``, ``text`` and ``image_url`` attributes.

    Returns:
        A ``(prompt, image_data)`` tuple. ``image_data`` is the base64 payload
        of the image found in the content, or ``None`` when there is none.
    """
    if isinstance(content, str):
        match = _INLINE_IMAGE_RE.search(content)
        if match is None:
            return content, None
        # The first data URI provides the image; every data URI is replaced
        # by the placeholder so no base64 text is left in the prompt.
        return _INLINE_IMAGE_RE.sub(_IMAGE_PLACEHOLDER, content), match.group(1)

    image_data = None
    text_parts = []
    for item in content:
        if item.type == "text" and item.text:
            text_parts.append(item.text)
        elif item.type == "image_url" and item.image_url:
            url = item.image_url.get("url", "")
            # Only inline data URLs are handled; other URLs are ignored.
            if url.startswith("data:image/"):
                match = _IMAGE_URL_RE.search(url)
                if match:
                    # With several images, the last one wins.
                    image_data = match.group(1)
                    text_parts.append(_IMAGE_PLACEHOLDER)
    return " ".join(text_parts), image_data


# NOTE(review): no route decorator is visible on this function in this view;
# confirm that it is registered on `router` elsewhere.
async def chat_completions(
    request: ChatCompletionRequest,
    background_tasks: BackgroundTasks,
    client_info = Depends(get_sora_client_dep()),
    api_key: str = Depends(verify_api_key)
):
    """Chat completion endpoint for text-to-image and image-to-image requests.

    The prompt (and optional base64 image) is taken from the last message
    whose role is ``"user"``. When an image is present a remix is requested,
    otherwise a plain generation.

    Args:
        request: Chat completion request; ``messages``, ``stream`` and ``n``
            are read from it.
        background_tasks: Receives the image task for non-streaming requests.
        client_info: ``(sora_client, sora_auth_token)`` pair from the
            dependency.
        api_key: Result of the API-key dependency; not used in the body.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``request.stream``
        is truthy; otherwise a ``JSONResponse`` carrying a placeholder
        message while the image task runs in the background.

    Raises:
        HTTPException: 400 when the request has no user message; 500 for any
            other failure while handling the request.
    """
    sora_client, sora_auth_token = client_info
    start_time = time.time()

    def record_result(success):
        # Report the outcome and elapsed time for the token used by this request.
        key_manager.record_request_result(
            sora_auth_token, success, time.time() - start_time
        )

    try:
        user_messages = [m for m in request.messages if m.role == "user"]
        if not user_messages:
            raise HTTPException(status_code=400, detail="至少需要一条用户消息")

        prompt, image_data = _extract_prompt_and_image(user_messages[-1].content)

        if request.stream:
            if image_data:
                stream = generate_streaming_remix_response(
                    sora_client, prompt, image_data, request.n
                )
            else:
                stream = generate_streaming_response(sora_client, prompt, request.n)
            response = StreamingResponse(stream, media_type="text/event-stream")
            # Success is recorded once the response object exists, i.e.
            # before the stream itself has been consumed.
            record_result(True)
            return response

        # Non-streaming: schedule the work and answer immediately.
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        if image_data:
            background_tasks.add_task(
                process_image_task,
                request_id,
                sora_client,
                "remix",
                prompt,
                image_data=image_data,
                num_images=request.n
            )
        else:
            background_tasks.add_task(
                process_image_task,
                request_id,
                sora_client,
                "generation",
                prompt,
                num_images=request.n,
                width=720,
                height=480
            )

        processing_message = "正在准备生成任务,请稍候..."
        # Token usage is a rough estimate: prompt length divided by four,
        # plus a fixed completion count.
        prompt_tokens = len(prompt) // 4
        payload = {
            "id": request_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "sora-1.0",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": format_think_block(processing_message)
                    },
                    "finish_reason": "processing"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": 10,
                "total_tokens": prompt_tokens + 10
            }
        }
        record_result(True)
        return JSONResponse(content=payload)
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) keep their own status
        # code and detail instead of being rewrapped as a 500. The failure is
        # still recorded, as it was before.
        record_result(False)
        raise
    except Exception as e:
        logger.error(f"处理聊天完成请求失败: {str(e)}", exc_info=True)
        record_result(False)
        raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") from e