File size: 4,666 Bytes
b064311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import asyncio
import concurrent.futures
from typing import List, Dict, Any, Optional, Union
import json
import os
import base64
import tempfile
import uuid

# 导入原有的SoraImageGenerator类
from .sora_generator import SoraImageGenerator

class SoraClient:
    def __init__(self, proxy_host=None, proxy_port=None, proxy_user=None, proxy_pass=None, auth_token=None):
        """初始化Sora客户端,使用cloudscraper绕过CF验证"""
        self.generator = SoraImageGenerator(
            proxy_host=proxy_host, 
            proxy_port=proxy_port,
            proxy_user=proxy_user,
            proxy_pass=proxy_pass,
            auth_token=auth_token
        )
        # 保存原始的auth_token,用于检测是否已更新
        self.auth_token = auth_token
        # 创建线程池执行器
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
        
    async def generate_image(self, prompt: str, num_images: int = 1, 
                           width: int = 720, height: int = 480) -> List[str]:
        """异步包装SoraImageGenerator.generate_image方法"""
        loop = asyncio.get_running_loop()
        # 使用线程池执行同步方法(因为cloudscraper不是异步的)
        result = await loop.run_in_executor(
            self.executor, 
            lambda: self.generator.generate_image(prompt, num_images, width, height)
        )
        
        # 检查generator中的auth_token是否已经被更新(由自动切换密钥机制)
        if self.generator.auth_token != self.auth_token:
            self.auth_token = self.generator.auth_token
        
        if isinstance(result, list):
            return result
        else:
            raise Exception(f"图像生成失败: {result}")
    
    async def upload_image(self, image_path: str) -> Dict:
        """异步包装上传图片方法"""
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            self.executor,
            lambda: self.generator.upload_image(image_path)
        )
        
        # 检查generator中的auth_token是否已经被更新
        if self.generator.auth_token != self.auth_token:
            self.auth_token = self.generator.auth_token
        
        if isinstance(result, dict) and 'id' in result:
            return result
        else:
            raise Exception(f"图片上传失败: {result}")
            
    async def generate_image_remix(self, prompt: str, media_id: str, 
                                 num_images: int = 1) -> List[str]:
        """异步包装remix方法"""
        loop = asyncio.get_running_loop()
        
        # 处理可能包含API密钥信息的media_id对象
        if isinstance(media_id, dict) and 'id' in media_id:
            # 如果上传时使用的密钥与当前不同,则先切换密钥
            if 'used_auth_token' in media_id and media_id['used_auth_token'] != self.auth_token:
                self.auth_token = media_id['used_auth_token']
                # 同步更新generator的auth_token
                self.generator.auth_token = self.auth_token
            # 提取实际的media_id
            media_id = media_id['id']
            
        result = await loop.run_in_executor(
            self.executor,
            lambda: self.generator.generate_image_remix(prompt, media_id, num_images)
        )
        
        # 检查generator中的auth_token是否已经被更新
        if self.generator.auth_token != self.auth_token:
            self.auth_token = self.generator.auth_token
        
        if isinstance(result, list):
            return result
        else:
            raise Exception(f"Remix生成失败: {result}")
            
    async def test_connection(self) -> Dict:
        """测试API连接是否有效"""
        try:
            # 简单测试上传功能,这个方法会调用API但不会真正上传文件
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                self.executor,
                lambda: self.generator.test_connection()
            )
            
            # 检查generator中的auth_token是否已经被更新
            if self.generator.auth_token != self.auth_token:
                self.auth_token = self.generator.auth_token
            
            # 直接返回generator.test_connection的结果,保留所有信息
            return result
        except Exception as e:
            return {"status": "error", "message": f"API连接测试失败: {str(e)}"}
            
    def close(self):
        """关闭线程池"""
        self.executor.shutdown(wait=False)