#!/usr/bin/env python3
"""
MagicArticulate API - Enhanced Version

Supports user-uploaded 3D model files and per-user result management.
"""
import os
import sys
import uuid
import json
import time
import shutil
import logging
import tempfile
import traceback
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple

import torch
import trimesh
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs

# Add the parent directory to sys.path so the project modules resolve correctly
parent_dir = str(Path(__file__).parent.parent)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
print(f"🔍 ARTICULATE_API DEBUG: Current working directory: {os.getcwd()}") | |
print(f"🔍 ARTICULATE_API DEBUG: Script file path: {__file__}") | |
print(f"🔍 ARTICULATE_API DEBUG: Parent directory: {parent_dir}") | |
print(f"🔍 ARTICULATE_API DEBUG: sys.path includes:") | |
for i, path in enumerate(sys.path[:10]): # 只显示前10个避免太长 | |
print(f" {i}: {path}") | |
# 检查目录结构 | |
utils_path = os.path.join(parent_dir, 'utils') | |
skeleton_path = os.path.join(parent_dir, 'skeleton_models') | |
print(f"🔍 ARTICULATE_API DEBUG: utils path exists: {os.path.exists(utils_path)}") | |
print(f"🔍 ARTICULATE_API DEBUG: skeleton_models path exists: {os.path.exists(skeleton_path)}") | |
if os.path.exists(utils_path): | |
print(f"🔍 ARTICULATE_API DEBUG: utils contents: {os.listdir(utils_path)}") | |

from skeleton_models.skeletongen import SkeletonGPT
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
    save_mesh, pred_joints_and_bones, save_skeleton_to_txt,
    save_args, remove_duplicate_joints, save_skeleton_obj,
    render_mesh_with_skeleton
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ModelValidator:
    """Validates and repairs 3D model files."""

    SUPPORTED_FORMATS = {'.obj', '.glb', '.gltf', '.ply', '.stl', '.fbx', '.dae'}
    MAX_VERTICES = 100000      # maximum vertex count
    MIN_VERTICES = 100         # minimum vertex count
    MAX_FILE_SIZE_MB = 100     # maximum file size in MB

    @staticmethod
    def validate_file(file_path: str) -> Tuple[bool, str, Dict[str, Any]]:
        """
        Validate a 3D model file.

        Returns:
            (is_valid, error_message, model_info)
        """
        try:
            # Check that the file exists
            if not os.path.exists(file_path):
                return False, "File does not exist", {}

            # Check the file size
            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            if file_size_mb > ModelValidator.MAX_FILE_SIZE_MB:
                return False, f"File too large: {file_size_mb:.1f}MB > {ModelValidator.MAX_FILE_SIZE_MB}MB", {}

            # Check the file format
            file_ext = Path(file_path).suffix.lower()
            if file_ext not in ModelValidator.SUPPORTED_FORMATS:
                return False, f"Unsupported file format: {file_ext}", {}

            # Try to load the model
            mesh = trimesh.load(file_path, force='mesh')

            # Check that it contains a valid mesh
            if not hasattr(mesh, 'vertices') or not hasattr(mesh, 'faces'):
                return False, "File contains no valid mesh data", {}

            # Check the vertex count
            vertex_count = len(mesh.vertices)
            if vertex_count < ModelValidator.MIN_VERTICES:
                return False, f"Too few vertices: {vertex_count} < {ModelValidator.MIN_VERTICES}", {}
            if vertex_count > ModelValidator.MAX_VERTICES:
                return False, f"Too many vertices: {vertex_count} > {ModelValidator.MAX_VERTICES}", {}

            # Collect model information
            model_info = {
                'file_name': os.path.basename(file_path),
                'file_size_mb': file_size_mb,
                'format': file_ext,
                'vertex_count': vertex_count,
                'face_count': len(mesh.faces) if hasattr(mesh, 'faces') else 0,
                'bounds': mesh.bounds.tolist() if hasattr(mesh, 'bounds') else None,
                'is_watertight': mesh.is_watertight if hasattr(mesh, 'is_watertight') else False,
                'volume': float(mesh.volume) if hasattr(mesh, 'volume') else 0.0,
                'area': float(mesh.area) if hasattr(mesh, 'area') else 0.0,
            }

            return True, "Validation passed", model_info
        except Exception as e:
            return False, f"Model validation failed: {str(e)}", {}

    @staticmethod
    def auto_repair_mesh(mesh: trimesh.Trimesh) -> Tuple[trimesh.Trimesh, List[str]]:
        """
        Automatically repair common mesh problems.

        Returns:
            (repaired_mesh, repair_log)
        """
        repair_log = []
        try:
            # Merge duplicate vertices (only attempted on volume meshes)
            if mesh.is_volume:
                original_vertices = len(mesh.vertices)
                mesh.merge_vertices()
                if len(mesh.vertices) < original_vertices:
                    repair_log.append(f"Removed {original_vertices - len(mesh.vertices)} duplicate vertices")

            # Fix vertex normals
            if not hasattr(mesh, 'vertex_normals') or mesh.vertex_normals is None:
                mesh.fix_normals()
                repair_log.append("Fixed vertex normals")

            # Remove degenerate faces
            original_faces = len(mesh.faces)
            mesh.remove_degenerate_faces()
            if len(mesh.faces) < original_faces:
                repair_log.append(f"Removed {original_faces - len(mesh.faces)} degenerate faces")

            # Fill holes (if needed)
            if not mesh.is_watertight and hasattr(mesh, 'fill_holes'):
                try:
                    mesh.fill_holes()
                    repair_log.append("Filled mesh holes")
                except Exception:
                    repair_log.append("Hole filling failed; continuing anyway")

            return mesh, repair_log
        except Exception as e:
            logger.warning(f"Error during mesh repair: {str(e)}")
            return mesh, repair_log + [f"Repair error: {str(e)}"]

class ModelPreprocessor:
    """Preprocesses models before skeleton generation."""

    @staticmethod
    def convert_format(input_path: str, output_format: str = '.obj') -> str:
        """
        Convert a model to another format.

        Args:
            input_path: input file path
            output_format: output format (defaults to .obj)

        Returns:
            Output file path.
        """
        try:
            mesh = trimesh.load(input_path, force='mesh')

            # Build the output path
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(
                os.path.dirname(input_path),
                f"{base_name}_converted{output_format}"
            )

            # Export in the requested format
            mesh.export(output_path)
            logger.info(f"Format conversion finished: {input_path} -> {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Format conversion failed: {str(e)}")
            raise

    @staticmethod
    def simplify_mesh(mesh: trimesh.Trimesh, target_faces: int = 5000) -> trimesh.Trimesh:
        """
        Simplify a mesh.

        Args:
            mesh: input mesh
            target_faces: target face count

        Returns:
            The simplified mesh.
        """
        try:
            if len(mesh.faces) <= target_faces:
                return mesh

            # Simplify via quadric decimation
            simplified = mesh.simplify_quadratic_decimation(target_faces)
            logger.info(f"Mesh simplified: {len(mesh.faces)} -> {len(simplified.faces)} faces")
            return simplified
        except Exception as e:
            logger.warning(f"Mesh simplification failed: {str(e)}; using the original mesh")
            return mesh
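
    # Decimation sketch: target_faces is a budget, not a guarantee, since the
    # exact count after quadric decimation depends on the trimesh backend.
    #
    #   dense = trimesh.creation.icosphere(subdivisions=5)  # ~20k faces
    #   light = ModelPreprocessor.simplify_mesh(dense, target_faces=2000)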

    @staticmethod
    def normalize_mesh(mesh: trimesh.Trimesh, scale_factor: float = 0.95) -> Tuple[trimesh.Trimesh, Dict[str, Any]]:
        """
        Normalize a mesh into a canonical coordinate space.

        Args:
            mesh: input mesh
            scale_factor: scaling factor

        Returns:
            (normalized_mesh, transform_info)
        """
        try:
            # Compute the bounding box
            bounds = mesh.bounds
            center = (bounds[0] + bounds[1]) / 2
            size = bounds[1] - bounds[0]
            max_size = size.max()

            # Compute the transform parameters
            scale = (2.0 * scale_factor) / max_size
            translation = -center

            # Apply the transform
            vertices = mesh.vertices.copy()
            vertices = (vertices + translation) * scale

            # Build a new mesh
            normalized_mesh = trimesh.Trimesh(vertices=vertices, faces=mesh.faces)

            # Record the transform so it can be inverted later
            transform_info = {
                'original_center': center.tolist(),
                'original_size': size.tolist(),
                'scale': float(scale),
                'translation': translation.tolist()
            }

            logger.info(f"Mesh normalization finished: scale={scale:.4f}")
            return normalized_mesh, transform_info
        except Exception as e:
            logger.error(f"Mesh normalization failed: {str(e)}")
            raise
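
# normalize_mesh applies v' = (v + translation) * scale, so results can be
# mapped back to the original frame with the inverse (a sketch using the
# returned transform_info):
#
#   normalized, info = ModelPreprocessor.normalize_mesh(mesh)
#   restored = normalized.vertices / info['scale'] - np.array(info['translation'])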

class UserSessionManager:
    """Manages per-user sessions."""

    def __init__(self, base_dir: str = "user_sessions"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(exist_ok=True)

        # Metadata file
        self.metadata_file = self.base_dir / "sessions_metadata.json"
        self.load_metadata()

    def load_metadata(self):
        """Load session metadata."""
        if self.metadata_file.exists():
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                self.sessions = json.load(f)
        else:
            self.sessions = {}

    def save_metadata(self):
        """Persist session metadata."""
        with open(self.metadata_file, 'w', encoding='utf-8') as f:
            json.dump(self.sessions, f, indent=2, ensure_ascii=False)

    def create_session(self, user_id: Optional[str] = None) -> str:
        """
        Create a new user session.

        Args:
            user_id: user ID (optional)

        Returns:
            session_id
        """
        session_id = str(uuid.uuid4())
        session_dir = self.base_dir / session_id
        session_dir.mkdir(exist_ok=True)

        # Create subdirectories
        (session_dir / "uploads").mkdir(exist_ok=True)
        (session_dir / "outputs").mkdir(exist_ok=True)
        (session_dir / "temp").mkdir(exist_ok=True)

        # Record the session
        self.sessions[session_id] = {
            'user_id': user_id,
            'created_at': datetime.now().isoformat(),
            'status': 'active',
            'processed_models': [],
            'last_activity': datetime.now().isoformat()
        }
        self.save_metadata()

        logger.info(f"Created new session: {session_id}")
        return session_id
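
    # Session bookkeeping sketch (the user_id "alice" is illustrative):
    #
    #   manager = UserSessionManager("user_sessions")
    #   sid = manager.create_session(user_id="alice")
    #   uploads_dir = manager.get_session_dir(sid) / "uploads"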

    def get_session_dir(self, session_id: str) -> Path:
        """Return the session directory."""
        session_dir = self.base_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Session does not exist: {session_id}")
        return session_dir

    def update_activity(self, session_id: str):
        """Update the session's last-activity timestamp."""
        if session_id in self.sessions:
            self.sessions[session_id]['last_activity'] = datetime.now().isoformat()
            self.save_metadata()

    def add_processed_model(self, session_id: str, model_info: Dict[str, Any]):
        """Record a processed model in the session."""
        if session_id in self.sessions:
            self.sessions[session_id]['processed_models'].append(model_info)
            self.update_activity(session_id)

    def cleanup_old_sessions(self, max_age_days: int = 7):
        """Delete stale sessions."""
        cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 3600)
        sessions_to_remove = []

        for session_id, session_info in self.sessions.items():
            last_activity = datetime.fromisoformat(session_info['last_activity'])
            if last_activity.timestamp() < cutoff_time:
                sessions_to_remove.append(session_id)

        for session_id in sessions_to_remove:
            try:
                session_dir = self.base_dir / session_id
                if session_dir.exists():
                    shutil.rmtree(session_dir)
                del self.sessions[session_id]
                logger.info(f"Cleaned up stale session: {session_id}")
            except Exception as e:
                logger.error(f"Failed to clean up session {session_id}: {str(e)}")

        if sessions_to_remove:
            self.save_metadata()


class MagicArticulateAPI:
    """Main MagicArticulate API class."""

    def __init__(self,
                 model_weights_path: Optional[str] = None,
                 device: str = "auto",
                 session_base_dir: str = "user_sessions"):
        self.device = self._setup_device(device)
        self.model = None
        self.accelerator = None
        self.model_weights_path = model_weights_path

        # Initialize the session manager
        self.session_manager = UserSessionManager(session_base_dir)

        # Default processing parameters, matching the original demo.py settings
        self.default_args = {
            'input_pc_num': 8192,
            'num_beams': 1,
            'n_discrete_size': 128,
            'n_max_bones': 100,
            'pad_id': -1,
            'precision': 'fp16',
            'batchsize_per_gpu': 1,
            'apply_marching_cubes': False,
            'octree_depth': 7,
            'hier_order': False,        # matches the demo.py default
            'save_render': False,
            'llm': 'facebook/opt-350m'  # matches the demo.py default
        }

        self.initialized = False
        logger.info("MagicArticulate API initialized")

    def _setup_device(self, device: str) -> torch.device:
        """Select the compute device."""
        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name()}")
            else:
                device = "cpu"
                logger.info("Using CPU")
        return torch.device(device)

    def initialize_model(self) -> bool:
        """Initialize the model."""
        try:
            if self.initialized:
                return True

            logger.info("Initializing the MagicArticulate model...")

            # Set up the accelerator
            kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            self.accelerator = Accelerator(
                kwargs_handlers=[kwargs],
                mixed_precision=self.default_args['precision'],
            )

            # Build the model
            args = self._create_args_object()
            self.model = SkeletonGPT(args)
            if self.device.type == "cuda":
                self.model = self.model.cuda()

            # Load pretrained weights
            if self.model_weights_path and os.path.exists(self.model_weights_path):
                logger.info(f"Loading model weights: {self.model_weights_path}")
                pkg = torch.load(self.model_weights_path, map_location=self.device)
                self.model.load_state_dict(pkg["model"])
            else:
                error_msg = "Pretrained weights are required! Running with random initialization; results will be inaccurate."
                logger.error(error_msg)
                # Do not raise, but warn loudly
                logger.error("⚠️ WARNING: without pretrained weights the generated skeletons will be inaccurate!")

            self.model.eval()
            set_seed(0)

            # Prepare the model
            if self.accelerator:
                self.model = self.accelerator.prepare(self.model)

            self.initialized = True
            logger.info("✅ Model initialized successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Model initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    def process_uploaded_model(self,
                               file_path: str,
                               session_id: Optional[str] = None,
                               user_prompt: str = "",
                               processing_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Process a user-uploaded 3D model.

        Args:
            file_path: model file path
            session_id: session ID (optional)
            user_prompt: user prompt
            processing_options: processing options

        Returns:
            A result dictionary.
        """
        start_time = time.time()
        try:
            # Create a session if none was provided
            if not session_id:
                session_id = self.session_manager.create_session()

            logger.info(f"Processing model: {file_path}, session: {session_id}")

            # Step 1: validate the model file
            is_valid, error_msg, model_info = ModelValidator.validate_file(file_path)
            if not is_valid:
                return self._create_error_result(error_msg, session_id, start_time)
            logger.info(f"Model validation passed: {model_info}")

            # Step 2: copy the file into the session directory
            session_dir = self.session_manager.get_session_dir(session_id)
            uploaded_file = session_dir / "uploads" / os.path.basename(file_path)
            shutil.copy2(file_path, uploaded_file)

            # Step 3: preprocess the model
            processed_mesh, preprocessing_log = self._preprocess_model(
                str(uploaded_file),
                processing_options or {}
            )

            # Step 4: generate the skeleton
            if not self.initialized:
                if not self.initialize_model():
                    return self._create_error_result("Model initialization failed", session_id, start_time)

            skeleton_result = self._generate_skeleton(
                processed_mesh,
                model_info['file_name'],
                user_prompt
            )

            # Step 5: save the results
            output_files = self._save_results(
                skeleton_result,
                processed_mesh,
                model_info,
                session_dir,
                user_prompt
            )

            # Step 6: assemble the result record
            processing_time = time.time() - start_time
            result = {
                'success': True,
                'session_id': session_id,
                'processing_time': processing_time,
                'model_info': model_info,
                'preprocessing_log': preprocessing_log,
                'skeleton_data': skeleton_result,
                'output_files': output_files,
                'user_prompt': user_prompt,
                'timestamp': datetime.now().isoformat()
            }

            # Update the session record
            self.session_manager.add_processed_model(session_id, {
                'file_name': model_info['file_name'],
                'processing_time': processing_time,
                'timestamp': datetime.now().isoformat(),
                'success': True
            })

            logger.info(f"✅ Model processed in {processing_time:.2f}s")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            error_msg = f"Error during processing: {str(e)}"
            logger.error(f"❌ {error_msg}")
            logger.error(traceback.format_exc())
            return self._create_error_result(error_msg, session_id, start_time)

    def _preprocess_model(self, file_path: str, options: Dict[str, Any]) -> Tuple[trimesh.Trimesh, List[str]]:
        """Preprocess a model."""
        preprocessing_log = []
        try:
            # Load the model
            mesh = trimesh.load(file_path, force='mesh')
            preprocessing_log.append(f"Loaded model: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")

            # Automatic repair
            if options.get('auto_repair', True):
                mesh, repair_log = ModelValidator.auto_repair_mesh(mesh)
                preprocessing_log.extend(repair_log)

            # Simplify the mesh if needed
            target_faces = options.get('target_faces', 10000)
            if len(mesh.faces) > target_faces:
                mesh = ModelPreprocessor.simplify_mesh(mesh, target_faces)
                preprocessing_log.append(f"Simplified mesh to {len(mesh.faces)} faces")

            # Normalize the mesh
            mesh, transform_info = ModelPreprocessor.normalize_mesh(mesh)
            preprocessing_log.append(f"Normalized mesh: scale={transform_info['scale']:.4f}")

            return mesh, preprocessing_log
        except Exception as e:
            error_msg = f"Preprocessing failed: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

    def _generate_skeleton(self, mesh: trimesh.Trimesh, file_name: str, user_prompt: str) -> Dict[str, Any]:
        """Generate the skeleton."""
        try:
            # Convert the mesh to a point cloud
            points_per_mesh = self.default_args['input_pc_num']
            apply_marching_cubes = self.default_args['apply_marching_cubes']
            octree_depth = self.default_args['octree_depth']

            point_clouds = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh],
                points_per_mesh,
                apply_marching_cubes,
                octree_depth
            )
            pc_data = point_clouds[0]

            # Normalize exactly as the original demo does
            pc_coor = pc_data[:, :3]
            normals = pc_data[:, 3:]
            bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])

            # Keep the transform so the output can be denormalized later
            trans = (bounds[0] + bounds[1])[None, :] / 2
            scale = ((bounds[1] - bounds[0]).max() + 1e-5)

            # Normalize coordinates, identical to the original demo
            pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
            pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995

            # Combine coordinates and normals
            pc_coor = pc_coor.astype(np.float32)
            normals = normals.astype(np.float32)
            pc_normal_data = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

            # Prepare the model input
            pc_normal = torch.from_numpy(pc_normal_data).unsqueeze(0)
            if self.device.type == "cuda":
                pc_normal = pc_normal.cuda()

            # Record the mesh's own transform
            mesh_bounds = np.array([mesh.vertices.min(axis=0), mesh.vertices.max(axis=0)])
            mesh_trans = (mesh_bounds[0] + mesh_bounds[1])[None, :] / 2
            mesh_scale = ((mesh_bounds[1] - mesh_bounds[0]).max() + 1e-5)

            batch_data = {
                'pc_normal': pc_normal,
                'file_name': [os.path.splitext(file_name)[0]],
                'trans': torch.from_numpy(mesh_trans).unsqueeze(0),
                'scale': torch.tensor(mesh_scale).unsqueeze(0),
                'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0),
                'faces': torch.from_numpy(mesh.faces).unsqueeze(0)
            }

            # Generate the skeleton
            with torch.no_grad():
                if self.accelerator:
                    with self.accelerator.autocast():
                        pred_bone_coords = self.model.generate(batch_data)
                else:
                    pred_bone_coords = self.model.generate(batch_data)

            # Post-process the output, following the original demo exactly
            trans = batch_data['trans'][0].cpu().numpy()
            scale = batch_data['scale'][0].cpu().numpy()
            vertices = batch_data['vertices'][0].cpu().numpy()
            faces = batch_data['faces'][0].cpu().numpy()

            skeleton = pred_bone_coords[0].cpu().numpy().squeeze()
            pred_joints, pred_bones = pred_joints_and_bones(skeleton)

            # Deduplicate joints
            if self.default_args['hier_order']:
                pred_joints, pred_bones, pred_root_index = remove_duplicate_joints(
                    pred_joints, pred_bones, root_index=pred_bones[0][0]
                )
            else:
                pred_joints, pred_bones = remove_duplicate_joints(pred_joints, pred_bones)
                pred_root_index = 0

            # Important: denormalize the joints back into the original model's coordinate frame
            pred_joints_denorm = pred_joints * scale + trans

            return {
                'joints': pred_joints_denorm.tolist(),      # denormalized joints
                'joints_normalized': pred_joints.tolist(),  # normalized joints, kept for visualization
                'bones': pred_bones,
                'root_index': pred_root_index,
                'joint_count': len(pred_joints),
                'bone_count': len(pred_bones),
                'raw_skeleton': skeleton.tolist(),
                'user_prompt': user_prompt,
                'transform_info': {
                    'trans': trans.tolist(),
                    'scale': float(scale)
                }
            }
        except Exception as e:
            error_msg = f"Skeleton generation failed: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
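
    # Coordinate round trip used above, as a sketch: the model sees a point
    # cloud in a normalized frame, and the predicted joints are mapped back
    # with the per-mesh trans/scale recorded in batch_data:
    #
    #   joints_world = joints_normalized * scale + trans
    #
    # Note the forward normalization divides by the largest absolute
    # coordinate (times 0.9995), mirroring the original demo.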

    def _save_results(self,
                      skeleton_result: Dict[str, Any],
                      mesh: trimesh.Trimesh,
                      model_info: Dict[str, Any],
                      session_dir: Path,
                      user_prompt: str) -> Dict[str, str]:
        """Save the processing results."""
        try:
            output_dir = session_dir / "outputs"
            base_name = os.path.splitext(model_info['file_name'])[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            output_files = {}

            # JSON output was removed to avoid serialization issues.

            # Save the skeleton OBJ, using the denormalized joints
            obj_file = output_dir / f"{base_name}_{timestamp}_skeleton.obj"
            save_skeleton_obj(
                np.array(skeleton_result['joints']),
                skeleton_result['bones'],
                str(obj_file),
                skeleton_result.get('root_index', 0),
                use_cone=self.default_args['hier_order']
            )
            output_files['skeleton_obj'] = str(obj_file)

            # Save the skeleton TXT
            txt_file = output_dir / f"{base_name}_{timestamp}_rig.txt"
            save_skeleton_to_txt(
                np.array(skeleton_result['joints']),
                skeleton_result['bones'],
                skeleton_result.get('root_index', 0),
                self.default_args['hier_order'],
                mesh.vertices,
                str(txt_file)
            )
            output_files['skeleton_txt'] = str(txt_file)

            # Save the processed mesh
            mesh_file = output_dir / f"{base_name}_{timestamp}_processed.obj"
            mesh.export(str(mesh_file))
            output_files['processed_mesh'] = str(mesh_file)

            # Save a plain-text processing report
            report_file = output_dir / f"{base_name}_{timestamp}_report.txt"
            report_content = f"""MagicArticulate Processing Report
=====================================
File: {model_info['file_name']}
Processed At: {datetime.now().isoformat()}
User Prompt: {user_prompt}
Model Information:
- Vertices: {model_info.get('vertex_count', 'N/A')}
- Faces: {model_info.get('face_count', 'N/A')}
- File Size: {model_info.get('file_size_mb', 'N/A')} MB
- Format: {model_info.get('format', 'N/A')}
Skeleton Results:
- Joints: {skeleton_result.get('joint_count', 'N/A')}
- Bones: {skeleton_result.get('bone_count', 'N/A')}
- Root Index: {skeleton_result.get('root_index', 'N/A')}
Generated Files:
- Skeleton OBJ: {base_name}_{timestamp}_skeleton.obj
- Skeleton TXT: {base_name}_{timestamp}_rig.txt
- Processed Mesh: {base_name}_{timestamp}_processed.obj
"""
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(report_content)
            output_files['report'] = str(report_file)

            logger.info(f"Saved {len(output_files)} result files")
            return output_files
        except Exception as e:
            error_msg = f"Failed to save results: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

    def _create_error_result(self, error_message: str, session_id: str, start_time: float) -> Dict[str, Any]:
        """Build an error result."""
        processing_time = time.time() - start_time
        return {
            'success': False,
            'session_id': session_id,
            'error': error_message,
            'processing_time': processing_time,
            'timestamp': datetime.now().isoformat()
        }

    def _make_json_serializable(self, obj):
        """Convert an object into a JSON-serializable form."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {key: self._make_json_serializable(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        else:
            return obj
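
    # Behavior sketch: numpy scalars and arrays become plain Python types so
    # the result can be passed straight to json.dumps.
    #
    #   api._make_json_serializable({'a': np.float32(1.5), 'b': np.arange(3)})
    #   # -> {'a': 1.5, 'b': [0, 1, 2]}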

    def _create_args_object(self):
        """Build an argument object from the default parameters."""
        class Args:
            def __init__(self, **kwargs):
                for key, value in kwargs.items():
                    setattr(self, key, value)
        return Args(**self.default_args)

    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Return information about a session."""
        if session_id not in self.session_manager.sessions:
            raise ValueError(f"Session does not exist: {session_id}")
        return self.session_manager.sessions[session_id].copy()

    def list_user_sessions(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """List a user's sessions."""
        sessions = []
        for session_id, session_info in self.session_manager.sessions.items():
            if user_id is None or session_info.get('user_id') == user_id:
                sessions.append({
                    'session_id': session_id,
                    **session_info
                })
        return sorted(sessions, key=lambda x: x['created_at'], reverse=True)

    def cleanup_sessions(self, max_age_days: int = 7):
        """Clean up stale sessions."""
        self.session_manager.cleanup_old_sessions(max_age_days)


# Simplified convenience interface
def process_model_file(file_path: str,
                       user_prompt: str = "",
                       model_weights_path: Optional[str] = None,
                       output_dir: Optional[str] = None) -> Dict[str, Any]:
    """
    Simplified model-processing interface.

    Args:
        file_path: model file path
        user_prompt: user prompt
        model_weights_path: model weights path
        output_dir: output directory

    Returns:
        The processing result.
    """
    api = MagicArticulateAPI(
        model_weights_path=model_weights_path,
        session_base_dir=output_dir or "temp_sessions"
    )

    result = api.process_uploaded_model(
        file_path=file_path,
        user_prompt=user_prompt
    )
    return result
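
# End-to-end sketch (both paths are illustrative; without weights the
# skeleton comes from randomly initialized parameters and is not meaningful):
#
#   result = process_model_file("examples/chair.obj",
#                               model_weights_path="checkpoints/magicarticulate.pth")
#   if result['success']:
#       print(result['output_files']['skeleton_obj'])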


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MagicArticulate API test")
    parser.add_argument("--input", required=True, help="input model file path")
    parser.add_argument("--prompt", default="", help="user prompt")
    parser.add_argument("--weights", help="model weights path")
    parser.add_argument("--output", default="api_outputs", help="output directory")
    args = parser.parse_args()

    # Exercise the API
    result = process_model_file(
        file_path=args.input,
        user_prompt=args.prompt,
        model_weights_path=args.weights,
        output_dir=args.output
    )

    if result['success']:
        print("✅ Processing succeeded!")
        print(f"Session ID: {result['session_id']}")
        print(f"Processing time: {result['processing_time']:.2f}s")
        print(f"Joint count: {result['skeleton_data']['joint_count']}")
        print(f"Bone count: {result['skeleton_data']['bone_count']}")
        print(f"Output files: {len(result['output_files'])}")
    else:
        print("❌ Processing failed!")
        print(f"Error: {result['error']}")