Spaces:
Running
Running
import json | |
import time | |
import asyncio | |
import logging | |
from typing import AsyncGenerator, List, Dict, Any | |
from ..sora_integration import SoraClient | |
from ..config import Config | |
from ..utils import localize_image_urls | |
from .image_service import format_think_block | |
logger = logging.getLogger("sora-api.streaming") | |
async def generate_streaming_response( | |
sora_client: SoraClient, | |
prompt: str, | |
n_images: int = 1 | |
) -> AsyncGenerator[str, None]: | |
""" | |
文本到图像的流式响应生成器 | |
Args: | |
sora_client: Sora客户端 | |
prompt: 提示词 | |
n_images: 生成图像数量 | |
Yields: | |
SSE格式的响应数据 | |
""" | |
request_id = f"chatcmpl-stream-{time.time()}-{hash(prompt) % 10000}" | |
# 发送开始事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
# 发送处理中的消息(放在代码块中) | |
start_msg = "```think\n正在生成图像,请稍候...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': start_msg}, 'finish_reason': None}]})}\n\n" | |
# 创建一个后台任务来生成图像 | |
logger.info(f"[流式响应 {request_id}] 开始生成图像, 提示词: {prompt}") | |
generation_task = asyncio.create_task(sora_client.generate_image( | |
prompt=prompt, | |
num_images=n_images, | |
width=720, | |
height=480 | |
)) | |
# 每5秒发送一条"仍在生成中"的消息,防止连接超时 | |
progress_messages = [ | |
"正在处理您的请求...", | |
"仍在生成图像中,请继续等待...", | |
"Sora正在创作您的图像作品...", | |
"图像生成需要一点时间,感谢您的耐心等待...", | |
"我们正在努力为您创作高质量图像..." | |
] | |
i = 0 | |
while not generation_task.done(): | |
# 每5秒发送一次进度消息 | |
await asyncio.sleep(5) | |
progress_msg = progress_messages[i % len(progress_messages)] | |
i += 1 | |
content = "\n" + progress_msg + "\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
try: | |
# 获取生成结果 | |
image_urls = await generation_task | |
logger.info(f"[流式响应 {request_id}] 图像生成完成,获取到 {len(image_urls) if isinstance(image_urls, list) else '非列表'} 个URL") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION and isinstance(image_urls, list) and image_urls: | |
logger.info(f"[流式响应 {request_id}] 准备进行图片本地化处理") | |
try: | |
localized_urls = await localize_image_urls(image_urls) | |
image_urls = localized_urls | |
logger.info(f"[流式响应 {request_id}] 图片本地化处理完成") | |
except Exception as e: | |
logger.error(f"[流式响应 {request_id}] 图片本地化过程中发生错误: {str(e)}", exc_info=True) | |
logger.info(f"[流式响应 {request_id}] 由于错误,将使用原始URL") | |
elif not Config.IMAGE_LOCALIZATION: | |
logger.info(f"[流式响应 {request_id}] 图片本地化功能未启用") | |
# 结束代码块 | |
content_str = "\n```\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
# 添加生成的图片URLs | |
for i, url in enumerate(image_urls): | |
if i > 0: | |
content_str = "\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
image_markdown = f"" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
# 发送完成事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
# 发送结束标志 | |
yield "data: [DONE]\n\n" | |
except Exception as e: | |
error_msg = f"图像生成失败: {str(e)}" | |
logger.error(f"[流式响应 {request_id}] 错误: {error_msg}", exc_info=True) | |
error_content = f"\n{error_msg}\n```" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
yield "data: [DONE]\n\n" | |
async def generate_streaming_remix_response( | |
sora_client: SoraClient, | |
prompt: str, | |
image_data: str, | |
n_images: int = 1 | |
) -> AsyncGenerator[str, None]: | |
""" | |
图像到图像的流式响应生成器 | |
Args: | |
sora_client: Sora客户端 | |
prompt: 提示词 | |
image_data: Base64编码的图像数据 | |
n_images: 生成图像数量 | |
Yields: | |
SSE格式的响应数据 | |
""" | |
import os | |
import tempfile | |
import uuid | |
import base64 | |
request_id = f"chatcmpl-stream-remix-{time.time()}-{hash(prompt) % 10000}" | |
# 发送开始事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
try: | |
# 保存base64图片到临时文件 | |
temp_dir = tempfile.mkdtemp() | |
temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
try: | |
# 解码并保存图片 | |
with open(temp_image_path, "wb") as f: | |
f.write(base64.b64decode(image_data)) | |
# 上传图片 | |
upload_msg = "```think\n上传图片中...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': upload_msg}, 'finish_reason': None}]})}\n\n" | |
logger.info(f"[流式响应Remix {request_id}] 上传图片中") | |
upload_result = await sora_client.upload_image(temp_image_path) | |
media_id = upload_result['id'] | |
# 发送生成中消息 | |
generate_msg = "\n基于图片生成新图像中...\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': generate_msg}, 'finish_reason': None}]})}\n\n" | |
# 创建后台任务生成图像 | |
logger.info(f"[流式响应Remix {request_id}] 开始生成图像,提示词: {prompt}") | |
generation_task = asyncio.create_task(sora_client.generate_image_remix( | |
prompt=prompt, | |
media_id=media_id, | |
num_images=n_images | |
)) | |
# 每5秒发送一条"仍在生成中"的消息 | |
progress_messages = [ | |
"正在处理您的请求...", | |
"仍在生成图像中,请继续等待...", | |
"Sora正在基于您的图片创作新作品...", | |
"图像生成需要一点时间,感谢您的耐心等待...", | |
"正在融合您的风格和提示词,打造专属图像..." | |
] | |
i = 0 | |
while not generation_task.done(): | |
await asyncio.sleep(5) | |
progress_msg = progress_messages[i % len(progress_messages)] | |
i += 1 | |
content = "\n" + progress_msg + "\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
# 获取生成结果 | |
image_urls = await generation_task | |
logger.info(f"[流式响应Remix {request_id}] 图像生成完成") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION: | |
logger.info(f"[流式响应Remix {request_id}] 进行图片本地化处理") | |
localized_urls = await localize_image_urls(image_urls) | |
image_urls = localized_urls | |
logger.info(f"[流式响应Remix {request_id}] 图片本地化处理完成") | |
else: | |
logger.info(f"[流式响应Remix {request_id}] 图片本地化功能未启用") | |
# 结束代码块 | |
content_str = "\n```\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
# 发送图片URL作为Markdown | |
for i, url in enumerate(image_urls): | |
if i > 0: | |
newline_str = "\n\n" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': newline_str}, 'finish_reason': None}]})}\n\n" | |
image_markdown = f"" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
# 发送完成事件 | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
# 发送结束标志 | |
yield "data: [DONE]\n\n" | |
finally: | |
# 清理临时文件 | |
if os.path.exists(temp_image_path): | |
os.remove(temp_image_path) | |
if os.path.exists(temp_dir): | |
os.rmdir(temp_dir) | |
except Exception as e: | |
error_msg = f"图像Remix失败: {str(e)}" | |
logger.error(f"[流式响应Remix {request_id}] 错误: {error_msg}", exc_info=True) | |
error_content = f"\n{error_msg}\n```" | |
yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
# 结束流 | |
yield "data: [DONE]\n\n" |