# soraapi/src/services/image_service.py
# anycallzhf's picture
# Initial commit for Hugging Face Space deployment
# b064311
import asyncio
import base64
import os
import tempfile
import time
import uuid
import logging
import threading
from typing import List, Dict, Any, Optional, Union, Tuple
from ..sora_integration import SoraClient
from ..config import Config
from ..utils import localize_image_urls
# Module-scoped logger for the image service.
logger: logging.Logger = logging.getLogger("sora-api.image_service")

# In-memory store of task records, keyed by request ID.
generation_results: Dict[str, Dict[str, Any]] = {}

# Maps each request ID to the API key that was used for that task.
task_to_api_key: Dict[str, str] = {}
# Render an in-progress status message as a fenced "think" code block.
def format_think_block(message: str) -> str:
    """Wrap *message* in a fenced code block tagged ``think``."""
    fence = "```"
    return "\n".join((f"{fence}think", message, fence))
# Seconds a task record is kept after the task reaches a final state (30 minutes).
_RESULT_TTL_SECONDS = 1800


def _set_status(request_id: str, status: str, api_key: Optional[str], **fields: Any) -> None:
    """Replace the stored record for *request_id* with a fresh snapshot.

    The record always carries ``status``, then any extra *fields* in the order
    given, then ``timestamp`` (current epoch seconds) and ``api_key``.
    """
    generation_results[request_id] = {
        "status": status,
        **fields,
        "timestamp": int(time.time()),
        "api_key": api_key,
    }


def _schedule_cleanup(request_id: str) -> None:
    """Purge the record and key mapping for *request_id* after the TTL elapses."""

    def cleanup_task() -> None:
        generation_results.pop(request_id, None)
        task_to_api_key.pop(request_id, None)

    timer = threading.Timer(_RESULT_TTL_SECONDS, cleanup_task)
    # Daemon thread: a pending purge must not keep the interpreter alive at shutdown.
    timer.daemon = True
    timer.start()


async def _generate_remix(
    request_id: str,
    sora_client: SoraClient,
    prompt: str,
    api_key: Optional[str],
    image_data: Optional[str],
    num_images: int,
) -> Any:
    """Upload the base64-encoded image and run a remix generation from it.

    Returns whatever ``sora_client.generate_image_remix`` returns.

    Raises:
        ValueError: if *image_data* is missing or empty.
    """
    if not image_data:
        raise ValueError("缺少图像数据")

    _set_status(
        request_id, "processing", api_key,
        message=format_think_block("正在处理上传的图片..."),
    )

    # The upload API takes a file path, so the decoded image is written to a
    # scratch directory that is removed (with its contents) on exit, even on error.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png")
        with open(temp_image_path, "wb") as f:
            f.write(base64.b64decode(image_data))

        _set_status(
            request_id, "processing", api_key,
            message=format_think_block("正在上传图片到Sora服务..."),
        )
        # Upload through the same client (and therefore the same API key)
        # that was recorded for this task.
        upload_result = await sora_client.upload_image(temp_image_path)
        media_id = upload_result['id']

        _set_status(
            request_id, "processing", api_key,
            message=format_think_block("正在基于图片生成新图像..."),
        )
        logger.info(f"[{request_id}] 开始生成Remix图像, 提示词: {prompt}")
        return await sora_client.generate_image_remix(
            prompt=prompt,
            media_id=media_id,
            num_images=num_images,
        )


async def _localize_results(request_id: str, image_urls: List[str]) -> List[str]:
    """Best-effort localization of *image_urls*; fall back to the originals.

    The original URLs are returned when localization raises, returns an empty
    result, or yields no URL containing ``/static/``.
    """
    logger.info(f"[{request_id}] 准备进行图片本地化处理")
    try:
        localized_urls = await localize_image_urls(image_urls)
        logger.info(f"[{request_id}] 图片本地化处理完成")

        if not localized_urls:
            logger.warning(f"[{request_id}] 本地化处理返回了空列表,将使用原始URL")
            localized_urls = image_urls

        # A URL counts as localized when it contains the "/static/" segment.
        local_count = sum(1 for url in localized_urls if "/static/" in url)
        logger.info(f"[{request_id}] 本地化结果: 总计 {len(localized_urls)} 张图片,成功本地化 {local_count} 张")

        if local_count == 0:
            logger.warning(f"[{request_id}] 警告:没有一个URL被成功本地化,将使用原始URL")
            localized_urls = image_urls
        return localized_urls
    except Exception as e:
        # Localization is optional: log the failure and keep the remote URLs.
        logger.error(f"[{request_id}] 图片本地化过程中发生错误: {str(e)}", exc_info=True)
        logger.info(f"[{request_id}] 由于错误,将使用原始URL")
        return image_urls


async def process_image_task(
    request_id: str,
    sora_client: SoraClient,
    task_type: str,
    prompt: str,
    **kwargs
) -> None:
    """Run one image task and record its progress in ``generation_results``.

    Nothing is returned and no exception from the task body is propagated;
    the outcome is stored under *request_id* with status ``"processing"``,
    ``"completed"`` or ``"failed"``. Whatever the outcome, the stored record
    and the request-to-key mapping are purged after ``_RESULT_TTL_SECONDS``.

    Args:
        request_id: Identifier under which status and results are stored.
        sora_client: Client used for upload and generation; its ``auth_token``
            is recorded as the task's API key.
        task_type: ``"generation"`` (text to image) or ``"remix"`` (image to image).
        prompt: Text prompt passed to the client.
        **kwargs: Task-specific options.
            generation: ``num_images`` (default 1), ``width`` (default 720),
            ``height`` (default 480).
            remix: ``image_data`` (base64-encoded image, required),
            ``num_images`` (default 1).
    """
    try:
        # Remember which API key this task runs under so that follow-up
        # operations can reuse the same key.
        current_api_key = sora_client.auth_token
        task_to_api_key[request_id] = current_api_key

        _set_status(
            request_id, "processing", current_api_key,
            message=format_think_block("正在准备生成任务,请稍候..."),
        )

        if task_type == "generation":
            num_images = kwargs.get("num_images", 1)
            width = kwargs.get("width", 720)
            height = kwargs.get("height", 480)

            _set_status(
                request_id, "processing", current_api_key,
                message=format_think_block("正在生成图像,请耐心等待..."),
            )
            logger.info(f"[{request_id}] 开始生成图像, 提示词: {prompt}")
            image_urls = await sora_client.generate_image(
                prompt=prompt,
                num_images=num_images,
                width=width,
                height=height,
            )
        elif task_type == "remix":
            image_urls = await _generate_remix(
                request_id,
                sora_client,
                prompt,
                current_api_key,
                kwargs.get("image_data"),
                kwargs.get("num_images", 1),
            )
        else:
            raise ValueError(f"未知的任务类型: {task_type}")

        # A string result is treated as an error message from the client
        # rather than a list of URLs.
        if isinstance(image_urls, str):
            logger.warning(f"[{request_id}] 图像生成失败或返回了错误信息: {image_urls}")
            _set_status(
                request_id, "failed", current_api_key,
                error=image_urls,
                message=format_think_block(f"图像生成失败: {image_urls}"),
            )
            return

        if not image_urls:
            logger.warning(f"[{request_id}] 图像生成返回了空列表")
            _set_status(
                request_id, "failed", current_api_key,
                error="图像生成返回了空结果",
                message=format_think_block("图像生成失败: 服务器返回了空结果"),
            )
            return

        logger.info(f"[{request_id}] 成功生成 {len(image_urls)} 张图片")

        if Config.IMAGE_LOCALIZATION:
            image_urls = await _localize_results(request_id, image_urls)
        else:
            logger.info(f"[{request_id}] 图片本地化功能未启用,使用原始URL")

        _set_status(request_id, "completed", current_api_key, image_urls=image_urls)
    except Exception as e:
        error_message = f"图像生成失败: {str(e)}"
        _set_status(
            request_id, "failed",
            # getattr: the handler must not raise again if the client itself is broken.
            getattr(sora_client, "auth_token", None),
            error=error_message,
            message=format_think_block(error_message),
        )
        logger.error(f"图像生成失败 (ID: {request_id}): {str(e)}", exc_info=True)
    finally:
        # Schedule the purge for every outcome, not only success, so failed
        # tasks do not accumulate in memory.
        _schedule_cleanup(request_id)
def get_generation_result(request_id: str) -> Dict[str, Any]:
    """Return the stored record for *request_id*.

    When no record exists, a ``{"status": "not_found", ...}`` dict carrying an
    error message and the current timestamp is returned instead.
    """
    # Single lookup instead of check-then-read: the cleanup timer runs on
    # another thread and may remove the entry between the two steps.
    try:
        return generation_results[request_id]
    except KeyError:
        return {
            "status": "not_found",
            "error": f"找不到生成任务: {request_id}",
            "timestamp": int(time.time()),
        }
def get_task_api_key(request_id: str) -> Optional[str]:
    """Return the API key recorded for *request_id*, or ``None`` if there is none."""
    try:
        return task_to_api_key[request_id]
    except KeyError:
        return None