File size: 6,230 Bytes
b064311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import uuid
import time
import re
import logging
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

from ..models.schemas import ChatCompletionRequest
from ..api.dependencies import verify_api_key, get_sora_client_dep
from ..services.image_service import process_image_task, format_think_block
from ..services.streaming import generate_streaming_response, generate_streaming_remix_response
from ..key_manager import key_manager

# 设置日志
logger = logging.getLogger("sora-api.chat")

# 创建路由
router = APIRouter()

@router.post("/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    background_tasks: BackgroundTasks,
    client_info = Depends(get_sora_client_dep()),
    api_key: str = Depends(verify_api_key)
):
    """
    聊天完成端点 - 处理文本到图像和图像到图像的请求
    兼容OpenAI API格式
    """
    # 解析客户端信息
    sora_client, sora_auth_token = client_info
    
    # 记录开始时间
    start_time = time.time()
    success = False
    
    try:
        # 分析用户消息
        user_messages = [m for m in request.messages if m.role == "user"]
        if not user_messages:
            raise HTTPException(status_code=400, detail="至少需要一条用户消息")
        
        last_user_message = user_messages[-1]
        prompt = ""
        image_data = None
        
        # 提取提示词和图片数据
        if isinstance(last_user_message.content, str):
            # 简单的字符串内容
            prompt = last_user_message.content
            
            # 检查是否包含内嵌的base64图片
            pattern = r'data:image\/[^;]+;base64,([^"]+)'
            match = re.search(pattern, prompt)
            if match:
                image_data = match.group(1)
                # 从提示词中删除base64数据
                prompt = re.sub(pattern, "[已上传图片]", prompt)
        else:
            # 多模态内容,提取文本和图片
            content_items = last_user_message.content
            text_parts = []
            
            for item in content_items:
                if item.type == "text" and item.text:
                    text_parts.append(item.text)
                elif item.type == "image_url" and item.image_url:
                    # 如果有图片URL包含base64数据
                    url = item.image_url.get("url", "")
                    if url.startswith("data:image/"):
                        pattern = r'data:image\/[^;]+;base64,([^"]+)'
                        match = re.search(pattern, url)
                        if match:
                            image_data = match.group(1)
                            text_parts.append("[已上传图片]")
            
            prompt = " ".join(text_parts)
        
        # 检查是否为流式响应
        if request.stream:
            # 流式响应处理
            if image_data:
                response = StreamingResponse(
                    generate_streaming_remix_response(sora_client, prompt, image_data, request.n),
                    media_type="text/event-stream"
                )
            else:
                response = StreamingResponse(
                    generate_streaming_response(sora_client, prompt, request.n),
                    media_type="text/event-stream"
                )
            success = True
            
            # 记录请求结果
            response_time = time.time() - start_time
            key_manager.record_request_result(sora_auth_token, success, response_time)
            
            return response
        else:
            # 非流式响应 - 返回一个即时响应,表示任务已接收
            request_id = f"chatcmpl-{uuid.uuid4().hex}"
            
            # 创建后台任务
            if image_data:
                background_tasks.add_task(
                    process_image_task,
                    request_id,
                    sora_client,
                    "remix",
                    prompt,
                    image_data=image_data,
                    num_images=request.n
                )
            else:
                background_tasks.add_task(
                    process_image_task,
                    request_id,
                    sora_client,
                    "generation",
                    prompt,
                    num_images=request.n,
                    width=720,
                    height=480
                )
                
            # 返回正在处理的响应
            processing_message = "正在准备生成任务,请稍候..."
            response = {
                "id": request_id,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "sora-1.0",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": format_think_block(processing_message)
                        },
                        "finish_reason": "processing"
                    }
                ],
                "usage": {
                    "prompt_tokens": len(prompt) // 4,
                    "completion_tokens": 10,
                    "total_tokens": len(prompt) // 4 + 10
                }
            }
            
            success = True
            
            # 记录请求结果
            response_time = time.time() - start_time
            key_manager.record_request_result(sora_auth_token, success, response_time)
            
            return JSONResponse(content=response)
            
    except Exception as e:
        success = False
        logger.error(f"处理聊天完成请求失败: {str(e)}", exc_info=True)
        
        # 记录请求结果
        response_time = time.time() - start_time
        key_manager.record_request_result(sora_auth_token, success, response_time)
        
        raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")