soraapi / src /api /chat.py
anycallzhf's picture
Initial commit for Hugging Face Space deployment
b064311
import uuid
import time
import re
import logging
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from ..models.schemas import ChatCompletionRequest
from ..api.dependencies import verify_api_key, get_sora_client_dep
from ..services.image_service import process_image_task, format_think_block
from ..services.streaming import generate_streaming_response, generate_streaming_remix_response
from ..key_manager import key_manager
# Module-level logger for the chat endpoint
logger = logging.getLogger("sora-api.chat")
# Router on which the endpoints defined in this module are registered
router = APIRouter()
# Placeholder inserted into the prompt wherever an inline image was found.
_IMAGE_PLACEHOLDER = "[已上传图片]"

# Data URL embedded in free text. The payload is restricted to the base64
# alphabet so that prompt text following the image is not swallowed into
# the image data (the previous `[^"]+` ran on until the next double quote).
_INLINE_IMAGE_RE = re.compile(r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)')

# Data URL carried in a dedicated `image_url` field: the payload is simply
# the remainder of the value, so the original permissive pattern is kept.
_IMAGE_URL_RE = re.compile(r'data:image\/[^;]+;base64,([^"]+)')


def _extract_prompt_and_image(content) -> tuple:
    """Split a user message's content into ``(prompt, image_data)``.

    Args:
        content: Either a plain string, or an iterable of multimodal items
            exposing ``type``, ``text`` and ``image_url`` attributes
            (``image_url`` is read with ``.get("url")``).

    Returns:
        A 2-tuple ``(prompt, image_data)``. ``image_data`` is the base64
        payload of the inline image, or ``None`` when no inline image was
        found. In the prompt, each inline image is replaced by a placeholder.
    """
    if isinstance(content, str):
        found = _INLINE_IMAGE_RE.search(content)
        if found is None:
            return content, None
        # The first image is used; every inline image is replaced in the text.
        return _INLINE_IMAGE_RE.sub(_IMAGE_PLACEHOLDER, content), found.group(1)

    text_parts = []
    image_data = None
    for item in content:
        if item.type == "text" and item.text:
            text_parts.append(item.text)
        elif item.type == "image_url" and item.image_url:
            # `or ""` guards against an explicit null url value.
            url = item.image_url.get("url", "") or ""
            if url.startswith("data:image/"):
                found = _IMAGE_URL_RE.search(url)
                if found:
                    # With several images, the last one wins.
                    image_data = found.group(1)
                    text_parts.append(_IMAGE_PLACEHOLDER)
    return " ".join(text_parts), image_data


@router.post("/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    background_tasks: BackgroundTasks,
    client_info = Depends(get_sora_client_dep()),
    api_key: str = Depends(verify_api_key)
):
    """
    Chat completion endpoint using an OpenAI-style request/response format.

    Handles text-to-image requests and, when the last user message carries
    an inline base64 image, image-to-image ("remix") requests.

    - Streaming requests return a ``text/event-stream`` response.
    - Non-streaming requests schedule a background task and immediately
      return a "processing" completion carrying the request id.

    The outcome is reported to ``key_manager`` before the response is
    returned, so for streaming requests it is recorded before the stream
    body is consumed.

    Raises:
        HTTPException: 400 when the request contains no user message;
            500 when anything else fails while preparing the response.
    """
    sora_client, sora_auth_token = client_info
    start_time = time.time()

    try:
        user_messages = [m for m in request.messages if m.role == "user"]
        if not user_messages:
            raise HTTPException(status_code=400, detail="至少需要一条用户消息")

        prompt, image_data = _extract_prompt_and_image(user_messages[-1].content)

        if request.stream:
            if image_data:
                stream = generate_streaming_remix_response(
                    sora_client, prompt, image_data, request.n
                )
            else:
                stream = generate_streaming_response(sora_client, prompt, request.n)
            response = StreamingResponse(stream, media_type="text/event-stream")
        else:
            # Non-streaming: acknowledge immediately, do the work in background.
            request_id = f"chatcmpl-{uuid.uuid4().hex}"
            if image_data:
                background_tasks.add_task(
                    process_image_task,
                    request_id,
                    sora_client,
                    "remix",
                    prompt,
                    image_data=image_data,
                    num_images=request.n
                )
            else:
                background_tasks.add_task(
                    process_image_task,
                    request_id,
                    sora_client,
                    "generation",
                    prompt,
                    num_images=request.n,
                    width=720,
                    height=480
                )

            processing_message = "正在准备生成任务,请稍候..."
            # Rough token estimate: about four characters per token.
            prompt_tokens = len(prompt) // 4
            payload = {
                "id": request_id,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "sora-1.0",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": format_think_block(processing_message)
                        },
                        "finish_reason": "processing"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": 10,
                    "total_tokens": prompt_tokens + 10
                }
            }
            response = JSONResponse(content=payload)

        key_manager.record_request_result(
            sora_auth_token, True, time.time() - start_time
        )
        return response
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status
        # code instead of being rewrapped as a 500 by the handler below.
        # The failure is still recorded, as it was before this fix.
        key_manager.record_request_result(
            sora_auth_token, False, time.time() - start_time
        )
        raise
    except Exception as e:
        logger.error("处理聊天完成请求失败: %s", e, exc_info=True)
        key_manager.record_request_result(
            sora_auth_token, False, time.time() - start_time
        )
        raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") from e