Spaces:
Running
Running
import asyncio | |
import base64 | |
import os | |
import tempfile | |
import time | |
import uuid | |
import logging | |
import threading | |
from typing import List, Dict, Any, Optional, Union, Tuple | |
from ..sora_integration import SoraClient | |
from ..config import Config | |
from ..utils import localize_image_urls | |
logger = logging.getLogger("sora-api.image_service") | |
# 存储生成结果的全局字典 | |
generation_results = {} | |
# 存储任务与API密钥的映射关系 | |
task_to_api_key = {} | |
# 将处理中状态消息格式化为think代码块 | |
def format_think_block(message: str) -> str: | |
"""将消息放入```think代码块中""" | |
return f"```think\n{message}\n```" | |
async def process_image_task( | |
request_id: str, | |
sora_client: SoraClient, | |
task_type: str, | |
prompt: str, | |
**kwargs | |
) -> None: | |
""" | |
统一的图像处理任务函数 | |
Args: | |
request_id: 请求ID | |
sora_client: Sora客户端实例 | |
task_type: 任务类型 ("generation" 或 "remix") | |
prompt: 提示词 | |
**kwargs: 其他参数,取决于任务类型 | |
""" | |
try: | |
# 保存当前任务使用的API密钥,以便后续使用同一密钥进行操作 | |
current_api_key = sora_client.auth_token | |
task_to_api_key[request_id] = current_api_key | |
# 更新状态为处理中 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在准备生成任务,请稍候..."), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key # 记录使用的API密钥 | |
} | |
# 根据任务类型执行不同操作 | |
if task_type == "generation": | |
# 文本到图像生成 | |
num_images = kwargs.get("num_images", 1) | |
width = kwargs.get("width", 720) | |
height = kwargs.get("height", 480) | |
# 更新状态 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在生成图像,请耐心等待..."), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
# 生成图像 | |
logger.info(f"[{request_id}] 开始生成图像, 提示词: {prompt}") | |
image_urls = await sora_client.generate_image( | |
prompt=prompt, | |
num_images=num_images, | |
width=width, | |
height=height | |
) | |
elif task_type == "remix": | |
# 图像到图像生成(Remix) | |
image_data = kwargs.get("image_data") | |
num_images = kwargs.get("num_images", 1) | |
if not image_data: | |
raise ValueError("缺少图像数据") | |
# 更新状态 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在处理上传的图片..."), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
# 保存base64图片到临时文件 | |
temp_dir = tempfile.mkdtemp() | |
temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
try: | |
# 解码并保存图片 | |
with open(temp_image_path, "wb") as f: | |
f.write(base64.b64decode(image_data)) | |
# 更新状态 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在上传图片到Sora服务..."), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
# 上传图片 - 确保使用与初始请求相同的API密钥 | |
upload_result = await sora_client.upload_image(temp_image_path) | |
media_id = upload_result['id'] | |
# 更新状态 | |
generation_results[request_id] = { | |
"status": "processing", | |
"message": format_think_block("正在基于图片生成新图像..."), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
# 执行remix生成 | |
logger.info(f"[{request_id}] 开始生成Remix图像, 提示词: {prompt}") | |
image_urls = await sora_client.generate_image_remix( | |
prompt=prompt, | |
media_id=media_id, | |
num_images=num_images | |
) | |
finally: | |
# 清理临时文件 | |
if os.path.exists(temp_image_path): | |
os.remove(temp_image_path) | |
if os.path.exists(temp_dir): | |
os.rmdir(temp_dir) | |
else: | |
raise ValueError(f"未知的任务类型: {task_type}") | |
# 验证生成结果 | |
if isinstance(image_urls, str): | |
logger.warning(f"[{request_id}] 图像生成失败或返回了错误信息: {image_urls}") | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": image_urls, | |
"message": format_think_block(f"图像生成失败: {image_urls}"), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
return | |
if not image_urls: | |
logger.warning(f"[{request_id}] 图像生成返回了空列表") | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": "图像生成返回了空结果", | |
"message": format_think_block("图像生成失败: 服务器返回了空结果"), | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
return | |
logger.info(f"[{request_id}] 成功生成 {len(image_urls)} 张图片") | |
# 本地化图片URL | |
if Config.IMAGE_LOCALIZATION: | |
logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
try: | |
localized_urls = await localize_image_urls(image_urls) | |
logger.info(f"[{request_id}] 图片本地化处理完成") | |
# 检查本地化结果 | |
if not localized_urls: | |
logger.warning(f"[{request_id}] 本地化处理返回了空列表,将使用原始URL") | |
localized_urls = image_urls | |
# 检查是否所有URL都被正确本地化 | |
local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
logger.info(f"[{request_id}] 本地化结果: 总计 {len(localized_urls)} 张图片,成功本地化 {local_count} 张") | |
if local_count == 0: | |
logger.warning(f"[{request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
localized_urls = image_urls | |
image_urls = localized_urls | |
except Exception as e: | |
logger.error(f"[{request_id}] 图片本地化过程中发生错误: {str(e)}", exc_info=True) | |
logger.info(f"[{request_id}] 由于错误,将使用原始URL") | |
else: | |
logger.info(f"[{request_id}] 图片本地化功能未启用,使用原始URL") | |
# 存储结果 | |
generation_results[request_id] = { | |
"status": "completed", | |
"image_urls": image_urls, | |
"timestamp": int(time.time()), | |
"api_key": current_api_key | |
} | |
# 30分钟后自动清理结果 | |
def cleanup_task(): | |
generation_results.pop(request_id, None) | |
task_to_api_key.pop(request_id, None) | |
threading.Timer(1800, cleanup_task).start() | |
except Exception as e: | |
error_message = f"图像生成失败: {str(e)}" | |
generation_results[request_id] = { | |
"status": "failed", | |
"error": error_message, | |
"message": format_think_block(error_message), | |
"timestamp": int(time.time()), | |
"api_key": sora_client.auth_token # 记录当前API密钥 | |
} | |
logger.error(f"图像生成失败 (ID: {request_id}): {str(e)}", exc_info=True) | |
def get_generation_result(request_id: str) -> Dict[str, Any]: | |
"""获取生成结果""" | |
if request_id not in generation_results: | |
return { | |
"status": "not_found", | |
"error": f"找不到生成任务: {request_id}", | |
"timestamp": int(time.time()) | |
} | |
return generation_results[request_id] | |
def get_task_api_key(request_id: str) -> Optional[str]: | |
"""获取任务对应的API密钥""" | |
return task_to_api_key.get(request_id) |