Magic-plus-1 / src /utils.py
HF User
🚀 Fresh deploy of Magic Articulate Enhanced MVP
e7b9fb6
"""
工具函数
"""
import logging
import math
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

import numpy as np
import trimesh
# Module-level logger; the helpers below use it to report failures.
logger = logging.getLogger(__name__)
def validate_file(file_path: str, max_size_mb: int = 50) -> Tuple[bool, str]:
    """
    Validate an uploaded 3D model file.

    Checks run in order and stop at the first failure:
      1. the path exists and is a regular file;
      2. its size (bytes / 1024**2) does not exceed ``max_size_mb``;
      3. its extension (case-insensitive) is .obj, .glb, .ply or .stl;
      4. trimesh loads it as a mesh with at least one vertex and one face.

    Args:
        file_path: path of the uploaded file.
        max_size_mb: maximum allowed file size in MB.

    Returns:
        (is_valid, message). ``message`` is a user-facing string describing
        the result or the reason for rejection. No exception escapes: any
        unexpected error becomes (False, "文件验证失败: ...").
    """
    try:
        if not os.path.exists(file_path):
            return False, "文件不存在"
        # A directory (possibly named like "model.obj") would otherwise pass
        # the size/extension checks and fail later with a misleading
        # "format error" message.
        if not os.path.isfile(file_path):
            return False, "路径不是文件"

        file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
        if file_size_mb > max_size_mb:
            return False, f"文件太大: {file_size_mb:.1f}MB > {max_size_mb}MB"

        file_ext = Path(file_path).suffix.lower()
        if file_ext not in ('.obj', '.glb', '.ply', '.stl'):
            return False, f"不支持的文件格式: {file_ext}"

        # Loading is the only way to confirm the content really is a mesh.
        try:
            mesh = trimesh.load(file_path, force='mesh')
            # Require both vertices and faces: vertex-only geometry (e.g. a
            # point cloud) has no surface to rig.
            vertex_count = len(getattr(mesh, 'vertices', ()))
            face_count = len(getattr(mesh, 'faces', ()))
            if vertex_count == 0 or face_count == 0:
                return False, "文件无法解析为有效的3D模型"
        except Exception as e:
            return False, f"文件格式错误: {e}"

        return True, "文件有效"
    except Exception as e:
        return False, f"文件验证失败: {e}"
def get_model_info(file_path: str) -> Dict[str, Any]:
    """
    Load a model with trimesh and summarize its geometry.

    Args:
        file_path: path of the model file.

    Returns:
        On success, a dict with:
          file_name, file_size_mb (bytes / 1024**2), vertex_count, face_count,
          bounding_box ({min, max, size, center} as lists; all zeros when the
          mesh has no vertices), surface_area, volume (0 unless the mesh is
          watertight), is_watertight, is_closed.
        On any exception, a dict with only file_name and error (the
        exception text); the failure is also logged.
    """
    try:
        mesh = trimesh.load(file_path, force='mesh')

        vertex_count = len(mesh.vertices) if hasattr(mesh, 'vertices') else 0
        face_count = len(mesh.faces) if hasattr(mesh, 'faces') else 0

        # Axis-aligned bounding box; zero placeholders for an empty mesh.
        if vertex_count > 0:
            bbox_min, bbox_max = mesh.bounds
            size = bbox_max - bbox_min
            center = (bbox_min + bbox_max) / 2
        else:
            bbox_min = bbox_max = size = center = np.array([0, 0, 0])

        surface_area = mesh.area if hasattr(mesh, 'area') else 0
        # Cast to a plain bool so the result serializes cleanly as JSON.
        is_watertight = bool(getattr(mesh, 'is_watertight', False))
        # trimesh derives volume from a surface integral that is only
        # meaningful for a closed surface; report 0 instead of an arbitrary
        # number for open meshes.
        volume = mesh.volume if is_watertight and hasattr(mesh, 'volume') else 0

        return {
            'file_name': os.path.basename(file_path),
            'file_size_mb': os.path.getsize(file_path) / (1024 * 1024),
            'vertex_count': vertex_count,
            'face_count': face_count,
            'bounding_box': {
                'min': bbox_min.tolist(),
                'max': bbox_max.tolist(),
                'size': size.tolist(),
                'center': center.tolist(),
            },
            'surface_area': float(surface_area),
            'volume': float(volume),
            'is_watertight': is_watertight,
            'is_closed': bool(getattr(mesh, 'is_closed', False)),
        }
    except Exception as e:
        logger.error(f"Failed to get model info: {str(e)}")
        return {
            'file_name': os.path.basename(file_path),
            'error': str(e)
        }
def cleanup_temp_files(temp_dir: str, keep_files: Optional[List[str]] = None):
    """
    Remove the contents of a temporary directory on a best-effort basis.

    Every direct entry of ``temp_dir`` is deleted, except entries whose name
    appears in ``keep_files``. Regular files and symlinks (including dangling
    ones) are unlinked; real directories are removed recursively. The
    directory itself is kept. A missing ``temp_dir`` is a no-op. Failures are
    logged, never raised.

    Args:
        temp_dir: directory whose entries should be removed.
        keep_files: entry names (not paths) to preserve.
    """
    try:
        if not os.path.exists(temp_dir):
            return
        keep = set(keep_files or ())
        for file_name in os.listdir(temp_dir):
            if file_name in keep:
                continue
            file_path = os.path.join(temp_dir, file_name)
            try:
                # Test for links first: isfile()/isdir() follow symlinks. A
                # link to a directory would reach rmtree, which refuses
                # symlinks, and a dangling link would match neither branch.
                # Unlinking removes the link without touching its target.
                if os.path.islink(file_path) or os.path.isfile(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except OSError as e:
                logger.warning(f"Failed to remove {file_path}: {str(e)}")
    except Exception as e:
        logger.error(f"Cleanup failed: {str(e)}")
def format_processing_time(seconds: float) -> str:
    """
    Format a duration for display, with one decimal place.

    Durations under a minute are shown in seconds (秒), under an hour in
    minutes (分钟), otherwise in hours (小时).

    Args:
        seconds: duration in seconds.

    Returns:
        The formatted string, e.g. "42.3秒", "1.5分钟", "2.0小时".
    """
    # Compare using the value as it will be displayed (rounded to one
    # decimal), so 59.96 s reads "1.0分钟" rather than "60.0秒", and 3599 s
    # reads "1.0小时" rather than "60.0分钟".
    if round(seconds, 1) < 60:
        return f"{seconds:.1f}秒"
    minutes = seconds / 60
    if round(minutes, 1) < 60:
        return f"{minutes:.1f}分钟"
    return f"{seconds / 3600:.1f}小时"
def get_prompt_suggestions(model_info: Dict[str, Any]) -> List[str]:
    """
    Suggest skeleton-generation prompts for a model.

    The lower-cased ``file_name`` is matched by substring against an ordered
    set of keyword rules; the first matching rule supplies three prompts, and
    a generic set is used when none match. One complexity hint is then added
    when ``vertex_count`` is above 10000 or below 1000.

    Args:
        model_info: model description; reads 'file_name' (default '') and
            'vertex_count' (default 0).

    Returns:
        At most five prompt strings.
    """
    name = model_info.get('file_name', '').lower()

    # Ordered rules: the first whose keywords occur in the name wins.
    rules = (
        (('human', 'person', 'character', 'boy', 'girl'), (
            "realistic human skeleton for walking animations",
            "character with full body rig for game animation",
            "human bone structure suitable for motion capture",
        )),
        (('dog', 'cat', 'animal', 'pet'), (
            "four-legged animal with spine and tail bones",
            "quadruped skeleton for natural movement",
            "animal bone structure with flexible spine",
        )),
        (('bird', 'eagle', 'chicken'), (
            "bird skeleton with wing bones for flight",
            "avian bone structure with hollow bones",
            "bird with articulated wings and tail",
        )),
        (('robot', 'mech', 'mechanical'), (
            "mechanical robot with joint articulation",
            "industrial robot with precise joint control",
            "mech suit with hydraulic joint system",
        )),
    )
    fallback = (
        "articulated skeleton suitable for animation",
        "flexible bone structure for general movement",
        "skeleton with natural joint hierarchy",
    )

    chosen = next(
        (prompts for words, prompts in rules if any(w in name for w in words)),
        fallback,
    )
    prompts_out = list(chosen)

    # Complexity hint driven by the vertex count.
    n_vertices = model_info.get('vertex_count', 0)
    if n_vertices > 10000:
        prompts_out.append("detailed skeleton for high-poly model")
    elif n_vertices < 1000:
        prompts_out.append("simple skeleton for low-poly model")

    return prompts_out[:5]
def create_processing_status(stage: str, progress: float, message: str) -> Dict[str, Any]:
    """
    Build a processing-status record.

    Args:
        stage: name of the current processing stage.
        progress: completion fraction. It is converted to float and clamped
            to [0.0, 1.0]; NaN is treated as 0.0.
        message: status message.

    Returns:
        Dict with 'stage', 'progress', 'message' and 'timestamp'
        (the value of time.time() when the record was created).
    """
    progress = float(progress)
    if math.isnan(progress):
        # NaN passes through min/max unchanged (every comparison with it is
        # False) and is not valid JSON, so map it to "no progress".
        progress = 0.0
    return {
        'stage': stage,
        'progress': min(max(progress, 0.0), 1.0),
        'message': message,
        'timestamp': time.time(),
    }
def estimate_processing_time(model_info: Dict[str, Any]) -> float:
    """
    Roughly estimate how long processing a model will take.

    The estimate is 30 s plus 15 s per 10,000 combined vertices and faces,
    capped at 120 s. Missing counts default to 1000 each.

    Args:
        model_info: model description; reads 'vertex_count' and 'face_count'.

    Returns:
        Estimated seconds, always as a float. Returns 60.0 when the counts
        cannot be used (e.g. ``model_info`` is None or a count is None or
        non-numeric).
    """
    try:
        vertex_count = model_info.get('vertex_count', 1000)
        face_count = model_info.get('face_count', 1000)
        complexity_factor = (vertex_count + face_count) / 10000
        base_time = 30  # base processing time in seconds
        estimated_time = base_time * (1 + complexity_factor * 0.5)
        # float() keeps the return type consistent when the 120 s cap applies.
        return float(min(estimated_time, 120))
    except (AttributeError, TypeError, ValueError, ArithmeticError):
        return 60.0
def generate_download_filename(original_name: str, suffix: str) -> str:
    """
    Derive a download file name from an original one.

    The last extension of ``original_name`` is dropped (any directory part is
    kept), then "_" and ``suffix`` are appended.

    Args:
        original_name: original file name or path.
        suffix: text to append (may itself carry the new extension).

    Returns:
        e.g. ("model.obj", "rigged.glb") -> "model_rigged.glb".
    """
    stem, _extension = os.path.splitext(original_name)
    return "{}_{}".format(stem, suffix)
def safe_json_serialize(obj: Any) -> Any:
    """
    Recursively convert numpy values into JSON-serializable Python objects.

    - numpy arrays become (nested) lists via ``ndarray.tolist()``;
    - numpy floating / integer / bool scalars become float / int / bool;
    - dicts are rebuilt with converted values; numpy-scalar keys are converted
      too, because ``json`` rejects them as keys;
    - lists and tuples become lists of converted items (``json`` encodes
      tuples as arrays anyway).
    Any other object is returned unchanged.

    Args:
        obj: object to convert.

    Returns:
        The converted object.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        # json.dumps raises TypeError on np.bool_.
        return bool(obj)
    if isinstance(obj, dict):
        return {
            (safe_json_serialize(k)
             if isinstance(k, (np.floating, np.integer, np.bool_)) else k):
                safe_json_serialize(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    return obj