Spaces:
Sleeping
Sleeping
""" | |
工具函数 | |
""" | |
import os | |
import shutil | |
import tempfile | |
import logging | |
from pathlib import Path | |
from typing import Optional, Dict, Any, List, Tuple | |
import numpy as np | |
import trimesh | |
logger = logging.getLogger(__name__) | |
def validate_file(file_path: str, max_size_mb: int = 50) -> Tuple[bool, str]:
    """
    Validate an uploaded 3D model file.

    The checks run in a fixed order and stop at the first failure:
    existence, size limit, extension whitelist, then a trial load with
    trimesh.

    Args:
        file_path: Path of the file to check.
        max_size_mb: Largest accepted file size, in megabytes.

    Returns:
        A ``(is_valid, message)`` tuple. ``message`` describes the first
        failed check, or states that the file is valid.
    """
    bytes_per_mb = 1024 * 1024
    allowed_suffixes = ('.obj', '.glb', '.ply', '.stl')

    try:
        if not os.path.exists(file_path):
            return False, "文件不存在"

        # Size limit.
        size_in_mb = os.path.getsize(file_path) / bytes_per_mb
        if size_in_mb > max_size_mb:
            return False, f"文件太大: {size_in_mb:.1f}MB > {max_size_mb}MB"

        # Extension whitelist (compared case-insensitively).
        suffix = Path(file_path).suffix.lower()
        if suffix not in allowed_suffixes:
            return False, f"不支持的文件格式: {suffix}"

        # Trial load. The geometry check stays inside this try so that any
        # error it raises is reported as a format error, not a generic one.
        try:
            loaded = trimesh.load(file_path, force='mesh')
            has_geometry = hasattr(loaded, 'vertices') and len(loaded.vertices) != 0
            if not has_geometry:
                return False, "文件无法解析为有效的3D模型"
        except Exception as load_error:
            return False, f"文件格式错误: {str(load_error)}"

        return True, "文件有效"
    except Exception as unexpected:
        return False, f"文件验证失败: {str(unexpected)}"
def get_model_info(file_path: str) -> Dict[str, Any]:
    """
    Collect summary information about a 3D model file.

    Args:
        file_path: Path of the model file.

    Returns:
        A dict with the file name, file size in MB, vertex and face counts,
        bounding box (min, max, size, center), surface area, volume and the
        watertight/closed flags. If anything fails, the dict holds only
        ``file_name`` and ``error``.
    """
    try:
        model = trimesh.load(file_path, force='mesh')

        # Basic counts; a loaded object without these attributes counts as 0.
        n_vertices = len(model.vertices) if hasattr(model, 'vertices') else 0
        n_faces = len(model.faces) if hasattr(model, 'faces') else 0

        # Bounding box; a model with no vertices reports an all-zero box.
        if n_vertices > 0:
            box = model.bounds
            extent = (box[1] - box[0]).tolist()
            midpoint = ((box[0] + box[1]) / 2).tolist()
            box_min = box[0].tolist()
            box_max = box[1].tolist()
        else:
            extent = [0, 0, 0]
            midpoint = [0, 0, 0]
            box_min = [0, 0, 0]
            box_max = [0, 0, 0]

        # Surface area and volume default to 0 when the attribute is missing.
        area = getattr(model, 'area', 0)
        vol = getattr(model, 'volume', 0)

        info = {
            'file_name': os.path.basename(file_path),
            'file_size_mb': os.path.getsize(file_path) / (1024 * 1024),
            'vertex_count': n_vertices,
            'face_count': n_faces,
            'bounding_box': {
                'min': box_min,
                'max': box_max,
                'size': extent,
                'center': midpoint,
            },
            'surface_area': float(area),
            'volume': float(vol),
            'is_watertight': getattr(model, 'is_watertight', False),
            'is_closed': getattr(model, 'is_closed', False),
        }
        return info
    except Exception as failure:
        logger.error(f"Failed to get model info: {str(failure)}")
        return {
            'file_name': os.path.basename(file_path),
            'error': str(failure),
        }
def cleanup_temp_files(temp_dir: str, keep_files: Optional[List[str]] = None):
    """
    Remove the contents of a temporary directory.

    Deletion is best-effort: a failure on one entry is logged as a warning
    and the remaining entries are still processed. The directory itself is
    not removed.

    Args:
        temp_dir: Directory whose entries should be deleted.
        keep_files: Entry names (not full paths) that must be left in place.
    """
    try:
        if not os.path.exists(temp_dir):
            return

        preserved = keep_files or ()
        for entry in os.listdir(temp_dir):
            if entry in preserved:
                continue

            target = os.path.join(temp_dir, entry)
            try:
                if os.path.isfile(target):
                    os.remove(target)
                elif os.path.isdir(target):
                    shutil.rmtree(target)
            except Exception as removal_error:
                logger.warning(f"Failed to remove {target}: {str(removal_error)}")
    except Exception as cleanup_error:
        logger.error(f"Cleanup failed: {str(cleanup_error)}")
def format_processing_time(seconds: float) -> str:
    """
    Format a duration as a short human-readable string.

    Durations under a minute are shown in seconds, under an hour in
    minutes, and anything else in hours, each with one decimal place.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration string.
    """
    if seconds < 60:
        return f"{seconds:.1f}秒"
    if seconds < 3600:
        return f"{seconds / 60:.1f}分钟"
    return f"{seconds / 3600:.1f}小时"
def get_prompt_suggestions(model_info: Dict[str, Any]) -> List[str]:
    """
    Suggest rigging prompts based on model information.

    The lower-cased file name is matched against keyword groups by
    substring; the first matching group supplies three prompts, and a
    generic set is used when nothing matches. One extra prompt is added
    for high-poly (> 10000 vertices) or low-poly (< 1000 vertices) models.

    Args:
        model_info: Model information dict; the ``file_name`` and
            ``vertex_count`` keys are read, both optional.

    Returns:
        At most five prompt suggestions.
    """
    # Ordered (keywords, prompts) pairs: the first group with a hit wins.
    keyword_prompts = (
        (
            ('human', 'person', 'character', 'boy', 'girl'),
            (
                "realistic human skeleton for walking animations",
                "character with full body rig for game animation",
                "human bone structure suitable for motion capture",
            ),
        ),
        (
            ('dog', 'cat', 'animal', 'pet'),
            (
                "four-legged animal with spine and tail bones",
                "quadruped skeleton for natural movement",
                "animal bone structure with flexible spine",
            ),
        ),
        (
            ('bird', 'eagle', 'chicken'),
            (
                "bird skeleton with wing bones for flight",
                "avian bone structure with hollow bones",
                "bird with articulated wings and tail",
            ),
        ),
        (
            ('robot', 'mech', 'mechanical'),
            (
                "mechanical robot with joint articulation",
                "industrial robot with precise joint control",
                "mech suit with hydraulic joint system",
            ),
        ),
    )
    generic_prompts = (
        "articulated skeleton suitable for animation",
        "flexible bone structure for general movement",
        "skeleton with natural joint hierarchy",
    )

    # Suggestions driven by the file name.
    lowered_name = model_info.get('file_name', '').lower()
    for keywords, prompts in keyword_prompts:
        if any(word in lowered_name for word in keywords):
            picked = list(prompts)
            break
    else:
        picked = list(generic_prompts)

    # Suggestion driven by model complexity.
    n_vertices = model_info.get('vertex_count', 0)
    if n_vertices > 10000:
        picked.append("detailed skeleton for high-poly model")
    elif n_vertices < 1000:
        picked.append("simple skeleton for low-poly model")

    return picked[:5]  # cap the number of suggestions
def create_processing_status(stage: str, progress: float, message: str) -> Dict[str, Any]:
    """
    Build a processing-status record.

    Args:
        stage: Name of the current processing stage.
        progress: Progress fraction; values outside [0, 1] are clamped and
            NaN is reported as 0.0.
        message: Human-readable status message.

    Returns:
        A dict with the keys ``stage``, ``progress``, ``message`` and
        ``timestamp`` (seconds since the epoch, from ``time.time()``).
    """
    # Function-scoped imports keep this block self-contained without
    # relying on the file-level import list.
    import math
    import time

    # NaN compares false against everything, so it would slip through the
    # min/max clamp below and produce an invalid progress value.
    if isinstance(progress, float) and math.isnan(progress):
        progress = 0.0

    clamped = min(max(progress, 0.0), 1.0)
    return {
        'stage': stage,
        'progress': clamped,
        'message': message,
        'timestamp': time.time(),
    }
def estimate_processing_time(model_info: Dict[str, Any]) -> float:
    """
    Estimate how long processing a model will take.

    The estimate starts from a 30-second base and grows linearly with the
    combined vertex and face count, capped at 120 seconds. Missing counts
    default to 1000 each.

    Args:
        model_info: Model information dict; the ``vertex_count`` and
            ``face_count`` keys are read.

    Returns:
        Estimated processing time in seconds, always a float. 60.0 is
        returned when the estimate cannot be computed (for example when a
        count is not numeric).
    """
    base_time = 30.0      # base processing time, seconds
    max_time = 120.0      # upper bound on the estimate, seconds
    fallback_time = 60.0  # returned when the inputs are unusable

    try:
        vertex_count = model_info.get('vertex_count', 1000)
        face_count = model_info.get('face_count', 1000)

        # Simple complexity-based estimate. Negative counts are invalid, so
        # the factor is floored at zero to keep the result >= base_time.
        complexity_factor = max((vertex_count + face_count) / 10000, 0.0)
        estimated_time = base_time * (1 + complexity_factor * 0.5)
        return float(min(estimated_time, max_time))
    except Exception:
        # Deliberately best-effort: an estimate must never break the caller.
        return fallback_time
def generate_download_filename(original_name: str, suffix: str) -> str:
    """
    Build a download file name from an original file name.

    The last extension of ``original_name`` is dropped and ``suffix`` is
    appended after an underscore.

    Args:
        original_name: Original file name.
        suffix: Text to append after the underscore.

    Returns:
        The new file name.
    """
    stem, _extension = os.path.splitext(original_name)
    return "{}_{}".format(stem, suffix)
def safe_json_serialize(obj: Any) -> Any:
    """
    Convert an object into a form the ``json`` module can serialize.

    NumPy arrays become nested lists and NumPy boolean, floating and
    integer scalars become the matching Python built-ins. Dict values and
    list/tuple items are converted recursively; tuples come back as lists,
    which is how JSON represents them anyway. Anything else is returned
    unchanged.

    Args:
        obj: The object to convert.

    Returns:
        A JSON-serializable equivalent of ``obj``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.bool_ is neither np.integer nor np.floating, and json.dumps
    # rejects it, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, dict):
        return {key: safe_json_serialize(value) for key, value in obj.items()}
    # Tuples are recursed into as well, so NumPy values nested inside them
    # do not leak through unconverted.
    if isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    return obj