Spaces:
Running
Running
File size: 11,591 Bytes
b064311 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import json
import time
import asyncio
import logging
from typing import AsyncGenerator, List, Dict, Any
from ..sora_integration import SoraClient
from ..config import Config
from ..utils import localize_image_urls
from .image_service import format_think_block
logger = logging.getLogger("sora-api.streaming")
async def generate_streaming_response(
sora_client: SoraClient,
prompt: str,
n_images: int = 1
) -> AsyncGenerator[str, None]:
"""
文本到图像的流式响应生成器
Args:
sora_client: Sora客户端
prompt: 提示词
n_images: 生成图像数量
Yields:
SSE格式的响应数据
"""
request_id = f"chatcmpl-stream-{time.time()}-{hash(prompt) % 10000}"
# 发送开始事件
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
# 发送处理中的消息(放在代码块中)
start_msg = "```think\n正在生成图像,请稍候...\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': start_msg}, 'finish_reason': None}]})}\n\n"
# 创建一个后台任务来生成图像
logger.info(f"[流式响应 {request_id}] 开始生成图像, 提示词: {prompt}")
generation_task = asyncio.create_task(sora_client.generate_image(
prompt=prompt,
num_images=n_images,
width=720,
height=480
))
# 每5秒发送一条"仍在生成中"的消息,防止连接超时
progress_messages = [
"正在处理您的请求...",
"仍在生成图像中,请继续等待...",
"Sora正在创作您的图像作品...",
"图像生成需要一点时间,感谢您的耐心等待...",
"我们正在努力为您创作高质量图像..."
]
i = 0
while not generation_task.done():
# 每5秒发送一次进度消息
await asyncio.sleep(5)
progress_msg = progress_messages[i % len(progress_messages)]
i += 1
content = "\n" + progress_msg + "\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n"
try:
# 获取生成结果
image_urls = await generation_task
logger.info(f"[流式响应 {request_id}] 图像生成完成,获取到 {len(image_urls) if isinstance(image_urls, list) else '非列表'} 个URL")
# 本地化图片URL
if Config.IMAGE_LOCALIZATION and isinstance(image_urls, list) and image_urls:
logger.info(f"[流式响应 {request_id}] 准备进行图片本地化处理")
try:
localized_urls = await localize_image_urls(image_urls)
image_urls = localized_urls
logger.info(f"[流式响应 {request_id}] 图片本地化处理完成")
except Exception as e:
logger.error(f"[流式响应 {request_id}] 图片本地化过程中发生错误: {str(e)}", exc_info=True)
logger.info(f"[流式响应 {request_id}] 由于错误,将使用原始URL")
elif not Config.IMAGE_LOCALIZATION:
logger.info(f"[流式响应 {request_id}] 图片本地化功能未启用")
# 结束代码块
content_str = "\n```\n\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n"
# 添加生成的图片URLs
for i, url in enumerate(image_urls):
if i > 0:
content_str = "\n\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n"
image_markdown = f""
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n"
# 发送完成事件
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
# 发送结束标志
yield "data: [DONE]\n\n"
except Exception as e:
error_msg = f"图像生成失败: {str(e)}"
logger.error(f"[流式响应 {request_id}] 错误: {error_msg}", exc_info=True)
error_content = f"\n{error_msg}\n```"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n"
yield "data: [DONE]\n\n"
async def generate_streaming_remix_response(
sora_client: SoraClient,
prompt: str,
image_data: str,
n_images: int = 1
) -> AsyncGenerator[str, None]:
"""
图像到图像的流式响应生成器
Args:
sora_client: Sora客户端
prompt: 提示词
image_data: Base64编码的图像数据
n_images: 生成图像数量
Yields:
SSE格式的响应数据
"""
import os
import tempfile
import uuid
import base64
request_id = f"chatcmpl-stream-remix-{time.time()}-{hash(prompt) % 10000}"
# 发送开始事件
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
try:
# 保存base64图片到临时文件
temp_dir = tempfile.mkdtemp()
temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png")
try:
# 解码并保存图片
with open(temp_image_path, "wb") as f:
f.write(base64.b64decode(image_data))
# 上传图片
upload_msg = "```think\n上传图片中...\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': upload_msg}, 'finish_reason': None}]})}\n\n"
logger.info(f"[流式响应Remix {request_id}] 上传图片中")
upload_result = await sora_client.upload_image(temp_image_path)
media_id = upload_result['id']
# 发送生成中消息
generate_msg = "\n基于图片生成新图像中...\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': generate_msg}, 'finish_reason': None}]})}\n\n"
# 创建后台任务生成图像
logger.info(f"[流式响应Remix {request_id}] 开始生成图像,提示词: {prompt}")
generation_task = asyncio.create_task(sora_client.generate_image_remix(
prompt=prompt,
media_id=media_id,
num_images=n_images
))
# 每5秒发送一条"仍在生成中"的消息
progress_messages = [
"正在处理您的请求...",
"仍在生成图像中,请继续等待...",
"Sora正在基于您的图片创作新作品...",
"图像生成需要一点时间,感谢您的耐心等待...",
"正在融合您的风格和提示词,打造专属图像..."
]
i = 0
while not generation_task.done():
await asyncio.sleep(5)
progress_msg = progress_messages[i % len(progress_messages)]
i += 1
content = "\n" + progress_msg + "\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n"
# 获取生成结果
image_urls = await generation_task
logger.info(f"[流式响应Remix {request_id}] 图像生成完成")
# 本地化图片URL
if Config.IMAGE_LOCALIZATION:
logger.info(f"[流式响应Remix {request_id}] 进行图片本地化处理")
localized_urls = await localize_image_urls(image_urls)
image_urls = localized_urls
logger.info(f"[流式响应Remix {request_id}] 图片本地化处理完成")
else:
logger.info(f"[流式响应Remix {request_id}] 图片本地化功能未启用")
# 结束代码块
content_str = "\n```\n\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n"
# 发送图片URL作为Markdown
for i, url in enumerate(image_urls):
if i > 0:
newline_str = "\n\n"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': newline_str}, 'finish_reason': None}]})}\n\n"
image_markdown = f""
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n"
# 发送完成事件
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
# 发送结束标志
yield "data: [DONE]\n\n"
finally:
# 清理临时文件
if os.path.exists(temp_image_path):
os.remove(temp_image_path)
if os.path.exists(temp_dir):
os.rmdir(temp_dir)
except Exception as e:
error_msg = f"图像Remix失败: {str(e)}"
logger.error(f"[流式响应Remix {request_id}] 错误: {error_msg}", exc_info=True)
error_content = f"\n{error_msg}\n```"
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n"
# 结束流
yield "data: [DONE]\n\n" |