Spaces:
Running
Running
import asyncio | |
import concurrent.futures | |
from typing import List, Dict, Any, Optional, Union | |
import json | |
import os | |
import base64 | |
import tempfile | |
import uuid | |
# 导入原有的SoraImageGenerator类 | |
from .sora_generator import SoraImageGenerator | |
class SoraClient: | |
def __init__(self, proxy_host=None, proxy_port=None, proxy_user=None, proxy_pass=None, auth_token=None): | |
"""初始化Sora客户端,使用cloudscraper绕过CF验证""" | |
self.generator = SoraImageGenerator( | |
proxy_host=proxy_host, | |
proxy_port=proxy_port, | |
proxy_user=proxy_user, | |
proxy_pass=proxy_pass, | |
auth_token=auth_token | |
) | |
# 保存原始的auth_token,用于检测是否已更新 | |
self.auth_token = auth_token | |
# 创建线程池执行器 | |
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10) | |
async def generate_image(self, prompt: str, num_images: int = 1, | |
width: int = 720, height: int = 480) -> List[str]: | |
"""异步包装SoraImageGenerator.generate_image方法""" | |
loop = asyncio.get_running_loop() | |
# 使用线程池执行同步方法(因为cloudscraper不是异步的) | |
result = await loop.run_in_executor( | |
self.executor, | |
lambda: self.generator.generate_image(prompt, num_images, width, height) | |
) | |
# 检查generator中的auth_token是否已经被更新(由自动切换密钥机制) | |
if self.generator.auth_token != self.auth_token: | |
self.auth_token = self.generator.auth_token | |
if isinstance(result, list): | |
return result | |
else: | |
raise Exception(f"图像生成失败: {result}") | |
async def upload_image(self, image_path: str) -> Dict: | |
"""异步包装上传图片方法""" | |
loop = asyncio.get_running_loop() | |
result = await loop.run_in_executor( | |
self.executor, | |
lambda: self.generator.upload_image(image_path) | |
) | |
# 检查generator中的auth_token是否已经被更新 | |
if self.generator.auth_token != self.auth_token: | |
self.auth_token = self.generator.auth_token | |
if isinstance(result, dict) and 'id' in result: | |
return result | |
else: | |
raise Exception(f"图片上传失败: {result}") | |
async def generate_image_remix(self, prompt: str, media_id: str, | |
num_images: int = 1) -> List[str]: | |
"""异步包装remix方法""" | |
loop = asyncio.get_running_loop() | |
# 处理可能包含API密钥信息的media_id对象 | |
if isinstance(media_id, dict) and 'id' in media_id: | |
# 如果上传时使用的密钥与当前不同,则先切换密钥 | |
if 'used_auth_token' in media_id and media_id['used_auth_token'] != self.auth_token: | |
self.auth_token = media_id['used_auth_token'] | |
# 同步更新generator的auth_token | |
self.generator.auth_token = self.auth_token | |
# 提取实际的media_id | |
media_id = media_id['id'] | |
result = await loop.run_in_executor( | |
self.executor, | |
lambda: self.generator.generate_image_remix(prompt, media_id, num_images) | |
) | |
# 检查generator中的auth_token是否已经被更新 | |
if self.generator.auth_token != self.auth_token: | |
self.auth_token = self.generator.auth_token | |
if isinstance(result, list): | |
return result | |
else: | |
raise Exception(f"Remix生成失败: {result}") | |
async def test_connection(self) -> Dict: | |
"""测试API连接是否有效""" | |
try: | |
# 简单测试上传功能,这个方法会调用API但不会真正上传文件 | |
loop = asyncio.get_running_loop() | |
result = await loop.run_in_executor( | |
self.executor, | |
lambda: self.generator.test_connection() | |
) | |
# 检查generator中的auth_token是否已经被更新 | |
if self.generator.auth_token != self.auth_token: | |
self.auth_token = self.generator.auth_token | |
# 直接返回generator.test_connection的结果,保留所有信息 | |
return result | |
except Exception as e: | |
return {"status": "error", "message": f"API连接测试失败: {str(e)}"} | |
def close(self): | |
"""关闭线程池""" | |
self.executor.shutdown(wait=False) |