File size: 5,973 Bytes
2bad007 6b7f2d0 2bad007 6b7f2d0 bae320c 6b7f2d0 2bad007 6b7f2d0 2bad007 63f8f02 2bad007 63f8f02 2bad007 63f8f02 2bad007 6b7f2d0 2bad007 6b7f2d0 2bad007 6f044bd bae320c 2bad007 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# oss_utils.py
# OSS相关工具函数
import os
import oss2
from typing import List
import shutil
# OSS connection settings, read from the environment once at import time.
# A variable that is not set yields None for its setting.
OSS_CONFIG = {
    setting: os.getenv(env_var)
    for setting, env_var in (
        ("access_key_id", "OSS_ACCESS_KEY_ID"),
        ("access_key_secret", "OSS_ACCESS_KEY_SECRET"),
        ("endpoint", "OSS_ENDPOINT"),
        ("bucket_name", "OSS_BUCKET_NAME"),
    )
}
# Print a startup summary of the OSS configuration.
# The access key id is masked down to its last 4 characters. No part of the
# secret key is ever printed: even a short suffix should not end up in logs.
print("🔍 OSS CONFIG DEBUG:")
print(f" - access_key_id: {'✅' if OSS_CONFIG['access_key_id'] else '❌'} ({'***' + OSS_CONFIG['access_key_id'][-4:] if OSS_CONFIG['access_key_id'] else 'None'})")
print(f" - access_key_secret: {'✅' if OSS_CONFIG['access_key_secret'] else '❌'} ({'***' if OSS_CONFIG['access_key_secret'] else 'None'})")
print(f" - endpoint: {OSS_CONFIG['endpoint'] or '❌ None'}")
print(f" - bucket_name: {OSS_CONFIG['bucket_name'] or '❌ None'}")
# Initialize the OSS client. On any failure ``bucket`` is set to None so the
# helpers in this module can detect that OSS is unavailable.
_missing_oss_settings = [name for name, value in OSS_CONFIG.items() if not value]
if _missing_oss_settings:
    # All four settings are passed to oss2.Auth / oss2.Bucket below; without
    # them the SDK would presumably fail with a less specific error. Report
    # exactly what is missing and skip the network probe.
    print(f"❌ OSS client initialization failed: missing settings {_missing_oss_settings}")
    bucket = None
else:
    try:
        auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
        bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"])
        print("✅ OSS client initialized successfully")
        # Connectivity probe: list up to 5 keys from the bucket root. A probe
        # failure is only reported; the client object is kept.
        try:
            test_files = []
            for i, obj in enumerate(oss2.ObjectIterator(bucket, max_keys=5)):
                test_files.append(obj.key)
                if i >= 4:  # stop after the first 5 keys
                    break
            print(f"✅ OSS connection test successful, found {len(test_files)} test files")
            if test_files:
                print(f" Sample files: {test_files[:3]}")
        except Exception as test_e:
            print(f"⚠️ OSS connection test failed: {test_e}")
    except Exception as e:
        print(f"❌ OSS client initialization failed: {e}")
        bucket = None
# Root for temporary files: a "tmp" directory next to this module,
# created at import time if it does not already exist.
TMP_ROOT = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    "tmp",
)
os.makedirs(name=TMP_ROOT, exist_ok=True)
def list_oss_files(folder_path: str) -> List[str]:
    """List every object stored under an OSS "folder".

    Args:
        folder_path: Key prefix of the folder, e.g. ``"gradio_demo/tasks/123"``.
            A trailing ``/`` is appended when missing, so ``.../tasks/1`` does not
            also match objects belonging to ``.../tasks/12``. An empty string lists
            the whole bucket.

    Returns:
        Keys of the objects under the folder, excluding keys that end with ``/``
        (directory placeholders), sorted by key with the file extension removed.
        An empty list is returned when the bucket is not initialized or the
        listing raises.
    """
    if bucket is None:
        print("❌ OSS DEBUG: Bucket not initialized, cannot list files")
        return []
    # OSS keys are flat strings: without a trailing slash the prefix would also
    # match sibling "folders" whose names merely start with the same text.
    prefix = folder_path
    if prefix and not prefix.endswith('/'):
        prefix += '/'
    files = []
    try:
        print(f"🔍 OSS DEBUG: Listing files with prefix: '{prefix}'")
        for obj in oss2.ObjectIterator(bucket, prefix=prefix):
            if obj.key.endswith('/'):  # skip directory placeholder objects
                continue
            files.append(obj.key)
            if len(files) <= 5:  # only echo the first few keys for debugging
                print(f"🔍 OSS DEBUG: Found file: {obj.key}")
        print(f"🔍 OSS DEBUG: Total files found: {len(files)}")
        return sorted(files, key=lambda x: os.path.splitext(x)[0])
    except Exception as e:
        print(f"❌ OSS DEBUG: Error listing OSS files: {str(e)}")
        print(f"❌ OSS DEBUG: Exception type: {type(e).__name__}")
        print(f"❌ OSS DEBUG: folder_path was: '{folder_path}'")
        return []
def download_oss_file(oss_path: str, local_path: str):
    """Download one OSS object to a local file.

    Args:
        oss_path: Key of the object in the bucket.
        local_path: Destination file path. Missing parent directories are created.

    Raises:
        Exception: if the OSS bucket was not initialized.
        Any exception raised while creating directories or downloading is
        printed and re-raised unchanged.
    """
    if bucket is None:
        print("❌ OSS DEBUG: Bucket not initialized, cannot download file")
        raise Exception("OSS bucket not initialized")
    try:
        parent_dir = os.path.dirname(local_path)
        # A bare filename has no parent component; os.makedirs("") would raise
        # FileNotFoundError, so only create directories when there are any.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        bucket.get_object_to_file(oss_path, local_path)
    except Exception as e:
        print(f"Error downloading file {oss_path}: {str(e)}")
        raise
def oss_file_exists(oss_path: str) -> bool:
    """Report whether ``oss_path`` exists in the bucket.

    Returns whatever ``bucket.object_exists`` returns. Returns False instead of
    raising when the bucket is not initialized or the lookup raises an
    ``Exception``; in both cases a message is printed.
    """
    if bucket is not None:
        try:
            return bucket.object_exists(oss_path)
        except Exception as err:
            print(f"Error checking if file exists in OSS: {str(err)}")
            return False
    print(f"❌ OSS DEBUG: Bucket not initialized, cannot check file existence")
    return False
def get_user_tmp_dir(session_hash: str) -> str:
    """Return the temporary directory for one session, creating it if needed.

    Args:
        session_hash: Session identifier. It is converted with ``str()`` and used
            as a single directory name directly under ``TMP_ROOT``.

    Returns:
        ``os.path.join(TMP_ROOT, str(session_hash))``.

    Raises:
        ValueError: if the identifier is empty, ``"."``/``".."``, or contains a
            path separator or drive. Such a value would make the joined path
            resolve to ``TMP_ROOT`` itself or to a location outside it.
    """
    name = str(session_hash)
    # os.path.join discards TMP_ROOT for an absolute component, and "" / "." /
    # ".." / "../x" resolve to TMP_ROOT itself or above it. Only accept a single
    # plain path component so each session stays inside its own folder.
    separators = [sep for sep in (os.sep, os.altsep) if sep]
    if (name in ("", ".", "..")
            or any(sep in name for sep in separators)
            or os.path.splitdrive(name)[0]):
        raise ValueError(f"Invalid session hash for temp directory: {session_hash!r}")
    user_dir = os.path.join(TMP_ROOT, name)
    os.makedirs(user_dir, exist_ok=True)
    return user_dir
def clean_oss_result_path(result_folder: str, task_id: str) -> str:
    """Normalize a task result folder into a bucket-relative OSS folder path.

    Leading/trailing slashes and a single leading ``oss-waic/`` segment are
    removed. If what remains does not start with ``gradio_demo/``, the default
    ``gradio_demo/tasks/{task_id}`` is returned instead.

    Args:
        result_folder: Raw folder path. ``None`` or an empty string falls back to
            the default path.
        task_id: Task identifier used to build the default path.

    Returns:
        A path starting with ``gradio_demo/``.
    """
    cleaned_result_folder = (result_folder or '').strip('/')
    # strip('/') above already removed any leading slash, so only the
    # slash-less form of the segment can occur here. The segment presumably
    # names the bucket; verify if the bucket name ever changes.
    bucket_segment = 'oss-waic/'
    if cleaned_result_folder.startswith(bucket_segment):
        cleaned_result_folder = cleaned_result_folder[len(bucket_segment):]
    # Anything outside gradio_demo/ is replaced by the task's default folder.
    if not cleaned_result_folder.startswith('gradio_demo/'):
        cleaned_result_folder = f"gradio_demo/tasks/{task_id}"
    return cleaned_result_folder
def test_oss_access(task_id: str = None):
    """Print a short listing for a handful of OSS prefixes (diagnostics only).

    Always probes ``gradio_demo/`` and ``gradio_demo/tasks/``. When ``task_id``
    is truthy, it also probes that task's folder and its ``images/`` and
    ``image/`` subfolders. At most 10 keys are collected per prefix. An
    ``Exception`` raised for a prefix is printed, and probing continues with
    the next prefix. Returns None.
    """
    if bucket is None:
        print("❌ Cannot test OSS access - bucket not initialized")
        return

    prefixes = ["gradio_demo/", "gradio_demo/tasks/"]
    if task_id:
        task_dir = f"gradio_demo/tasks/{task_id}/"
        prefixes += [task_dir, task_dir + "images/", task_dir + "image/"]

    limit = 10
    for prefix in prefixes:
        try:
            print(f"🔍 Testing path: {prefix}")
            keys = []
            for obj in oss2.ObjectIterator(bucket, prefix=prefix, max_keys=limit):
                keys.append(obj.key)
                if len(keys) >= limit:
                    break
            print(f" Found {len(keys)} files")
            if keys:
                print(f" Sample: {keys[:3]}")
        except Exception as exc:
            print(f" ❌ Error: {exc}")
def cleanup_user_tmp_dir(session_hash: str):
    """Delete a session's temporary directory under ``TMP_ROOT`` if it exists.

    The identifier is converted with ``str()`` and must be a single plain path
    component. Anything else (empty, ``"."``/``".."``, containing a separator or
    drive) is refused: a message is printed and nothing is deleted, rather
    than risking removal of ``TMP_ROOT`` itself or a directory outside it.
    """
    name = str(session_hash)
    separators = [sep for sep in (os.sep, os.altsep) if sep]
    if (name in ("", ".", "..")
            or any(sep in name for sep in separators)
            or os.path.splitdrive(name)[0]):
        # os.path.join(TMP_ROOT, "") or "." is TMP_ROOT itself, and "../x" or an
        # absolute path points elsewhere entirely; never rmtree those.
        print(f"⚠️ Refusing to clean up temp dir for invalid session hash: {session_hash!r}")
        return
    user_dir = os.path.join(TMP_ROOT, name)
    if os.path.exists(user_dir):
        shutil.rmtree(user_dir)
|