soraapi / src /sora_integration.py
anycallzhf's picture
Initial commit for Hugging Face Space deployment
b064311
import asyncio
import concurrent.futures
from typing import List, Dict, Any, Optional, Union
import json
import os
import base64
import tempfile
import uuid
# 导入原有的SoraImageGenerator类
from .sora_generator import SoraImageGenerator
class SoraClient:
def __init__(self, proxy_host=None, proxy_port=None, proxy_user=None, proxy_pass=None, auth_token=None):
"""初始化Sora客户端,使用cloudscraper绕过CF验证"""
self.generator = SoraImageGenerator(
proxy_host=proxy_host,
proxy_port=proxy_port,
proxy_user=proxy_user,
proxy_pass=proxy_pass,
auth_token=auth_token
)
# 保存原始的auth_token,用于检测是否已更新
self.auth_token = auth_token
# 创建线程池执行器
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
async def generate_image(self, prompt: str, num_images: int = 1,
width: int = 720, height: int = 480) -> List[str]:
"""异步包装SoraImageGenerator.generate_image方法"""
loop = asyncio.get_running_loop()
# 使用线程池执行同步方法(因为cloudscraper不是异步的)
result = await loop.run_in_executor(
self.executor,
lambda: self.generator.generate_image(prompt, num_images, width, height)
)
# 检查generator中的auth_token是否已经被更新(由自动切换密钥机制)
if self.generator.auth_token != self.auth_token:
self.auth_token = self.generator.auth_token
if isinstance(result, list):
return result
else:
raise Exception(f"图像生成失败: {result}")
async def upload_image(self, image_path: str) -> Dict:
"""异步包装上传图片方法"""
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
self.executor,
lambda: self.generator.upload_image(image_path)
)
# 检查generator中的auth_token是否已经被更新
if self.generator.auth_token != self.auth_token:
self.auth_token = self.generator.auth_token
if isinstance(result, dict) and 'id' in result:
return result
else:
raise Exception(f"图片上传失败: {result}")
async def generate_image_remix(self, prompt: str, media_id: str,
num_images: int = 1) -> List[str]:
"""异步包装remix方法"""
loop = asyncio.get_running_loop()
# 处理可能包含API密钥信息的media_id对象
if isinstance(media_id, dict) and 'id' in media_id:
# 如果上传时使用的密钥与当前不同,则先切换密钥
if 'used_auth_token' in media_id and media_id['used_auth_token'] != self.auth_token:
self.auth_token = media_id['used_auth_token']
# 同步更新generator的auth_token
self.generator.auth_token = self.auth_token
# 提取实际的media_id
media_id = media_id['id']
result = await loop.run_in_executor(
self.executor,
lambda: self.generator.generate_image_remix(prompt, media_id, num_images)
)
# 检查generator中的auth_token是否已经被更新
if self.generator.auth_token != self.auth_token:
self.auth_token = self.generator.auth_token
if isinstance(result, list):
return result
else:
raise Exception(f"Remix生成失败: {result}")
async def test_connection(self) -> Dict:
"""测试API连接是否有效"""
try:
# 简单测试上传功能,这个方法会调用API但不会真正上传文件
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
self.executor,
lambda: self.generator.test_connection()
)
# 检查generator中的auth_token是否已经被更新
if self.generator.auth_token != self.auth_token:
self.auth_token = self.generator.auth_token
# 直接返回generator.test_connection的结果,保留所有信息
return result
except Exception as e:
return {"status": "error", "message": f"API连接测试失败: {str(e)}"}
def close(self):
"""关闭线程池"""
self.executor.shutdown(wait=False)