File size: 5,973 Bytes
2bad007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b7f2d0
 
 
 
 
 
 
2bad007
6b7f2d0
 
 
 
bae320c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b7f2d0
 
 
2bad007
 
 
 
 
 
 
6b7f2d0
 
 
 
2bad007
 
63f8f02
 
2bad007
 
 
63f8f02
 
 
 
 
2bad007
 
63f8f02
 
 
2bad007
 
 
 
6b7f2d0
 
 
 
2bad007
 
 
 
 
 
 
 
 
 
6b7f2d0
 
 
 
2bad007
 
 
 
 
 
 
 
 
 
 
 
6f044bd
 
 
 
 
 
 
 
 
 
 
 
 
 
bae320c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bad007
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# oss_utils.py
# OSS相关工具函数
import os
import oss2
from typing import List
import shutil

# OSS配置
OSS_CONFIG = {
    "access_key_id": os.getenv("OSS_ACCESS_KEY_ID"),
    "access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET"),
    "endpoint": os.getenv("OSS_ENDPOINT"),
    "bucket_name": os.getenv("OSS_BUCKET_NAME")
}

# 调试OSS配置信息
print(f"🔍 OSS CONFIG DEBUG:")
print(f"  - access_key_id: {'✅' if OSS_CONFIG['access_key_id'] else '❌'} ({'***' + OSS_CONFIG['access_key_id'][-4:] if OSS_CONFIG['access_key_id'] else 'None'})")
print(f"  - access_key_secret: {'✅' if OSS_CONFIG['access_key_secret'] else '❌'} ({'***' + OSS_CONFIG['access_key_secret'][-4:] if OSS_CONFIG['access_key_secret'] else 'None'})")
print(f"  - endpoint: {OSS_CONFIG['endpoint'] or '❌ None'}")
print(f"  - bucket_name: {OSS_CONFIG['bucket_name'] or '❌ None'}")

def _probe_oss_connection(oss_bucket, limit=5):
    """Best-effort connectivity check: list up to ``limit`` keys from the bucket.

    Prints the outcome and never raises, so a failed probe is reported but the
    client is kept. Running the probe inside a function keeps its temporaries
    out of the module namespace (and out of ``from oss_utils import *``).

    Args:
        oss_bucket: An initialized ``oss2.Bucket``.
        limit: Maximum number of keys to fetch for the probe.
    """
    try:
        sample_keys = []
        for obj in oss2.ObjectIterator(oss_bucket, max_keys=limit):
            sample_keys.append(obj.key)
            if len(sample_keys) >= limit:  # only need a few keys
                break
        print(f"✅ OSS connection test successful, found {len(sample_keys)} test files")
        if sample_keys:
            print(f"   Sample files: {sample_keys[:3]}")
    except Exception as probe_error:
        print(f"⚠️  OSS connection test failed: {probe_error}")


# Initialize the OSS client. If building the auth or bucket object fails
# (e.g. a credential or the endpoint is None), `bucket` is set to None; the
# functions in this module check for that before using it.
try:
    auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
    bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"])
except Exception as e:
    print(f"❌ OSS client initialization failed: {e}")
    bucket = None
else:
    print("✅ OSS client initialized successfully")
    # Network round-trip at import time to surface configuration problems early.
    _probe_oss_connection(bucket)

# Root directory for per-session temporary files: a "tmp" folder placed next
# to this module. It is created eagerly at import time.
TMP_ROOT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "tmp",
)
os.makedirs(TMP_ROOT, exist_ok=True)

def list_oss_files(folder_path: str) -> List[str]:
    """Return the object keys stored under ``folder_path`` in the bucket.

    Keys ending in '/' (directory placeholder objects) are skipped. The result
    is ordered by each key with its extension removed (``os.path.splitext``).
    The first five keys found are echoed as debug output.

    Args:
        folder_path: Key prefix to list (used verbatim as the OSS prefix).

    Returns:
        The sorted list of keys. It is empty when the bucket is not initialized
        or when listing raises; such failures are printed, not raised.
    """
    if bucket is None:
        print("❌ OSS DEBUG: Bucket not initialized, cannot list files")
        return []

    try:
        print(f"🔍 OSS DEBUG: Listing files with prefix: '{folder_path}'")
        keys = []
        for obj in oss2.ObjectIterator(bucket, prefix=folder_path):
            if obj.key.endswith('/'):
                continue  # directory marker, not a file
            keys.append(obj.key)
            if len(keys) <= 5:  # echo only the first few keys
                print(f"🔍 OSS DEBUG: Found file: {obj.key}")

        print(f"🔍 OSS DEBUG: Total files found: {len(keys)}")
        keys.sort(key=lambda key: os.path.splitext(key)[0])
        return keys
    except Exception as err:
        print(f"❌ OSS DEBUG: Error listing OSS files: {str(err)}")
        print(f"❌ OSS DEBUG: Exception type: {type(err).__name__}")
        print(f"❌ OSS DEBUG: folder_path was: '{folder_path}'")
        return []

def download_oss_file(oss_path: str, local_path: str):
    """Download one OSS object to a local file.

    Args:
        oss_path: Object key inside the configured bucket.
        local_path: Destination file path. Missing parent directories are
            created; a bare filename is written relative to the current
            working directory.

    Raises:
        Exception: If the OSS bucket was not initialized.
        Any exception raised while creating directories or downloading is
        printed and re-raised unchanged.
    """
    if bucket is None:
        print("❌ OSS DEBUG: Bucket not initialized, cannot download file")
        raise Exception("OSS bucket not initialized")

    try:
        parent_dir = os.path.dirname(local_path)
        # A bare filename has no parent component, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        bucket.get_object_to_file(oss_path, local_path)
    except Exception as e:
        print(f"Error downloading file {oss_path}: {str(e)}")
        raise

def oss_file_exists(oss_path: str) -> bool:
    """Report whether ``oss_path`` exists as an object in the bucket.

    Args:
        oss_path: Object key to look up.

    Returns:
        The result of ``bucket.object_exists``. It is False when the bucket is
        not initialized or when the lookup raises; in both cases a message is
        printed, so an error is indistinguishable from "missing" for callers.
    """
    if bucket is not None:
        try:
            return bucket.object_exists(oss_path)
        except Exception as err:
            print(f"Error checking if file exists in OSS: {str(err)}")
            return False

    print("❌ OSS DEBUG: Bucket not initialized, cannot check file existence")
    return False

def get_user_tmp_dir(session_hash: str) -> str:
    """Return the temp directory for one session, creating it if needed.

    Args:
        session_hash: Session identifier; ``str()`` is applied, and the
            result is joined onto ``TMP_ROOT``.

    Returns:
        The path ``TMP_ROOT/<session_hash>``. The directory exists afterwards.
    """
    session_dir = os.path.join(TMP_ROOT, str(session_hash))
    os.makedirs(session_dir, exist_ok=True)  # idempotent: reuses an existing dir
    return session_dir

def clean_oss_result_path(result_folder: str, task_id: str) -> str:
    """Normalize an OSS result folder into a bucket-relative key prefix.

    Leading and trailing slashes are stripped, and an ``oss-waic/`` bucket-name
    prefix is removed. If the remaining path is not under ``gradio_demo/``, it
    is replaced by the default ``gradio_demo/tasks/<task_id>``.

    Args:
        result_folder: Raw folder path. ``None`` or empty falls back to the
            default path.
        task_id: Task identifier used to build the fallback path.

    Returns:
        The cleaned folder path.
    """
    bucket_prefix = 'oss-waic/'
    # strip('/') already removes any leading slash, so this single check also
    # covers inputs written as '/oss-waic/...'.
    cleaned_result_folder = (result_folder or '').strip('/')
    if cleaned_result_folder.startswith(bucket_prefix):
        cleaned_result_folder = cleaned_result_folder[len(bucket_prefix):]

    # Anything outside gradio_demo/ is replaced by the task's default folder.
    if not cleaned_result_folder.startswith('gradio_demo/'):
        cleaned_result_folder = f"gradio_demo/tasks/{task_id}"

    return cleaned_result_folder

def test_oss_access(task_id: str = None):
    """Print a short listing of well-known OSS prefixes as a connectivity check.

    Always probes ``gradio_demo/`` and ``gradio_demo/tasks/``. When ``task_id``
    is truthy, it also probes that task's folder and its ``images/`` and
    ``image/`` sub-folders. For each prefix it prints how many keys were found
    (at most 10) and up to three sample keys. An error on one prefix is printed
    and does not stop the remaining probes. Returns None.

    NOTE(review): the ``test_`` prefix means pytest may collect this function
    if it is imported into a test module.
    """
    if bucket is None:
        print("❌ Cannot test OSS access - bucket not initialized")
        return

    prefixes = ["gradio_demo/", "gradio_demo/tasks/"]
    if task_id:
        task_prefix = f"gradio_demo/tasks/{task_id}/"
        prefixes += [task_prefix, f"{task_prefix}images/", f"{task_prefix}image/"]

    max_listed = 10
    for prefix in prefixes:
        try:
            print(f"🔍 Testing path: {prefix}")
            found = []
            for obj in oss2.ObjectIterator(bucket, prefix=prefix, max_keys=max_listed):
                found.append(obj.key)
                if len(found) == max_listed:
                    break
            print(f"   Found {len(found)} files")
            if found:
                print(f"   Sample: {found[:3]}")
        except Exception as err:
            print(f"   ❌ Error: {err}")

def cleanup_user_tmp_dir(session_hash: str):
    """Delete the temporary directory of one session, if it exists.

    Args:
        session_hash: Session identifier; ``str()`` is applied, and the target
            is ``TMP_ROOT/<session_hash>``.

    The target is resolved and must lie strictly inside ``TMP_ROOT`` before
    ``shutil.rmtree`` runs. Without this check, an empty identifier resolves to
    ``TMP_ROOT`` itself (wiping every session), and one containing ``..`` or an
    absolute path could delete directories outside it. Such requests are
    printed and ignored. A directory that disappears concurrently is not an
    error.
    """
    root = os.path.realpath(TMP_ROOT)
    user_dir = os.path.join(TMP_ROOT, str(session_hash))
    resolved = os.path.realpath(user_dir)
    try:
        inside_root = resolved != root and os.path.commonpath([root, resolved]) == root
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_root = False
    if not inside_root:
        print(f"⚠️  Refusing to remove {user_dir!r}: not inside TMP_ROOT")
        return

    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        pass  # already gone (never created, or removed concurrently)