HF User
🚀 Fresh deploy of Magic Articulate Enhanced MVP
e7b9fb6
#!/usr/bin/env python3
"""
MagicArticulate API - Enhanced Version
支持用户上传的3D模型文件和多用户结果管理
"""
import os
import sys
import uuid
import json
import time
import shutil
import logging
import tempfile
import traceback
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple
import torch
import trimesh
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs
# Put the project root (two levels above this file) at the front of sys.path
# so that the `skeleton_models` and `utils` packages imported below resolve
# regardless of the directory the process was started from.
parent_dir = str(Path(__file__).parent.parent)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Start-up diagnostics printed to stdout: working directory, script location
# and the head of sys.path.
for _label, _value in (
    ("Current working directory", os.getcwd()),
    ("Script file path", __file__),
    ("Parent directory", parent_dir),
):
    print(f"🔍 ARTICULATE_API DEBUG: {_label}: {_value}")
print("🔍 ARTICULATE_API DEBUG: sys.path includes:")
for _idx, _entry in enumerate(sys.path[:10]):  # first 10 entries only, to keep the output short
    print(f" {_idx}: {_entry}")

# Report whether the expected project packages sit next to this file.
utils_path = os.path.join(parent_dir, 'utils')
skeleton_path = os.path.join(parent_dir, 'skeleton_models')
_utils_present = os.path.exists(utils_path)
print(f"🔍 ARTICULATE_API DEBUG: utils path exists: {_utils_present}")
print(f"🔍 ARTICULATE_API DEBUG: skeleton_models path exists: {os.path.exists(skeleton_path)}")
if _utils_present:
    print(f"🔍 ARTICULATE_API DEBUG: utils contents: {os.listdir(utils_path)}")
from skeleton_models.skeletongen import SkeletonGPT
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
save_mesh, pred_joints_and_bones, save_skeleton_to_txt,
save_args, remove_duplicate_joints, save_skeleton_obj,
render_mesh_with_skeleton
)
# Root logging configuration: INFO level with timestamped, named records.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every class in this file.
logger = logging.getLogger(__name__)
class ModelValidator:
    """Validate uploaded 3D model files and apply best-effort mesh repairs."""

    # File extensions accepted by validate_file (compared lower-cased).
    SUPPORTED_FORMATS = {'.obj', '.glb', '.gltf', '.ply', '.stl', '.fbx', '.dae'}
    MAX_VERTICES = 100000  # maximum accepted vertex count
    MIN_VERTICES = 100  # minimum accepted vertex count
    MAX_FILE_SIZE_MB = 100  # maximum accepted file size (MiB)

    @staticmethod
    def validate_file(file_path: str) -> Tuple[bool, str, Dict[str, Any]]:
        """
        Validate a 3D model file before it is processed.

        Checks, in order: the file exists, its size is within
        MAX_FILE_SIZE_MB, its extension is in SUPPORTED_FORMATS, trimesh can
        load it as a mesh that has vertices and at least one face, and its
        vertex count lies in [MIN_VERTICES, MAX_VERTICES].

        Args:
            file_path: Path of the model file to check.

        Returns:
            (is_valid, message, model_info). ``model_info`` is an empty dict
            whenever validation fails. Load errors are reported as an invalid
            result rather than raised.
        """
        try:
            if not os.path.exists(file_path):
                return False, "文件不存在", {}

            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            if file_size_mb > ModelValidator.MAX_FILE_SIZE_MB:
                return False, f"文件过大: {file_size_mb:.1f}MB > {ModelValidator.MAX_FILE_SIZE_MB}MB", {}

            file_ext = Path(file_path).suffix.lower()
            if file_ext not in ModelValidator.SUPPORTED_FORMATS:
                return False, f"不支持的文件格式: {file_ext}", {}

            # force='mesh' asks trimesh to collapse scenes into a single mesh.
            mesh = trimesh.load(file_path, force='mesh')

            # A mesh with vertices but zero faces (e.g. a point cloud) has no
            # surface, so it cannot be processed downstream either.
            if (not hasattr(mesh, 'vertices') or not hasattr(mesh, 'faces')
                    or len(mesh.faces) == 0):
                return False, "文件不包含有效的网格数据", {}

            vertex_count = len(mesh.vertices)
            if vertex_count < ModelValidator.MIN_VERTICES:
                return False, f"顶点数量过少: {vertex_count} < {ModelValidator.MIN_VERTICES}", {}
            if vertex_count > ModelValidator.MAX_VERTICES:
                return False, f"顶点数量过多: {vertex_count} > {ModelValidator.MAX_VERTICES}", {}

            model_info = {
                'file_name': os.path.basename(file_path),
                'file_size_mb': file_size_mb,
                'format': file_ext,
                'vertex_count': vertex_count,
                'face_count': len(mesh.faces),  # presence of faces checked above
                'bounds': mesh.bounds.tolist() if hasattr(mesh, 'bounds') else None,
                'is_watertight': mesh.is_watertight if hasattr(mesh, 'is_watertight') else False,
                'volume': float(mesh.volume) if hasattr(mesh, 'volume') else 0.0,
                'area': float(mesh.area) if hasattr(mesh, 'area') else 0.0,
            }
            return True, "验证通过", model_info
        except Exception as e:
            return False, f"模型验证失败: {str(e)}", {}

    @staticmethod
    def auto_repair_mesh(mesh: trimesh.Trimesh) -> Tuple[trimesh.Trimesh, List[str]]:
        """
        Apply best-effort repairs to ``mesh`` in place.

        Steps: merge duplicate vertices, fix inconsistent face winding /
        normals, drop degenerate faces, then try to fill holes if the mesh is
        not watertight. Each step that changes something adds a log entry.

        Args:
            mesh: Mesh to repair; it is mutated in place.

        Returns:
            (mesh, repair_log): the same mesh object and the repair messages.
            Errors are logged and appended to repair_log instead of raised.
        """
        repair_log = []
        try:
            # Always merge. Duplicate vertices are exactly what makes a mesh
            # fail is_volume / is_watertight, so gating this step on is_volume
            # (as before) skipped precisely the meshes that needed it.
            original_vertices = len(mesh.vertices)
            mesh.merge_vertices()
            if len(mesh.vertices) < original_vertices:
                repair_log.append(f"移除了 {original_vertices - len(mesh.vertices)} 个重复顶点")

            # Trimesh always exposes (lazily computed) vertex_normals, so a
            # None check never fires; test the winding that fix_normals repairs.
            if not mesh.is_winding_consistent:
                mesh.fix_normals()
                repair_log.append("修复了顶点法向量")

            original_faces = len(mesh.faces)
            if hasattr(mesh, 'nondegenerate_faces'):
                # trimesh >= 4: remove_degenerate_faces is deprecated/removed.
                mesh.update_faces(mesh.nondegenerate_faces())
            else:
                mesh.remove_degenerate_faces()
            if len(mesh.faces) < original_faces:
                repair_log.append(f"移除了 {original_faces - len(mesh.faces)} 个退化面")

            # Hole filling is optional; a failure here must not stop processing.
            if not mesh.is_watertight and hasattr(mesh, 'fill_holes'):
                try:
                    mesh.fill_holes()
                    repair_log.append("填充了网格孔洞")
                except Exception as fill_error:
                    logger.warning(f"填充孔洞失败: {fill_error}")
                    repair_log.append("尝试填充孔洞失败,但继续处理")

            return mesh, repair_log
        except Exception as e:
            logger.warning(f"网格修复过程中出现错误: {str(e)}")
            return mesh, repair_log + [f"修复过程出错: {str(e)}"]
class ModelPreprocessor:
    """Mesh preprocessing helpers: format conversion, decimation, normalization."""

    @staticmethod
    def convert_format(input_path: str, output_format: str = '.obj') -> str:
        """
        Convert a model file to another format via trimesh.

        The output is written next to the input as
        ``<stem>_converted<output_format>``; trimesh picks the exporter from
        that extension.

        Args:
            input_path: Path of the source model.
            output_format: Target extension, with or without the leading dot
                (default ``.obj``).

        Returns:
            Path of the written file.

        Raises:
            Whatever trimesh raises on load/export (logged, then re-raised).
        """
        try:
            # Accept "obj" as well as ".obj"; otherwise the file would be
            # named "<stem>_convertedobj" and export would fail to infer a type.
            if not output_format.startswith('.'):
                output_format = '.' + output_format

            mesh = trimesh.load(input_path, force='mesh')

            base_name = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(
                os.path.dirname(input_path),
                f"{base_name}_converted{output_format}"
            )

            mesh.export(output_path)
            logger.info(f"格式转换完成: {input_path} -> {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"格式转换失败: {str(e)}")
            raise

    @staticmethod
    def simplify_mesh(mesh: trimesh.Trimesh, target_faces: int = 5000) -> trimesh.Trimesh:
        """
        Reduce ``mesh`` to roughly ``target_faces`` faces by quadric decimation.

        Args:
            mesh: Input mesh.
            target_faces: Desired face count.

        Returns:
            The simplified mesh, or ``mesh`` itself if it is already small
            enough or decimation fails (a warning is logged in that case).
        """
        try:
            if len(mesh.faces) <= target_faces:
                return mesh

            if hasattr(mesh, 'simplify_quadric_decimation'):
                # trimesh >= 4 renamed the method and takes keyword arguments.
                simplified = mesh.simplify_quadric_decimation(face_count=target_faces)
            else:
                # trimesh 3.x API.
                simplified = mesh.simplify_quadratic_decimation(target_faces)
            logger.info(f"网格简化: {len(mesh.faces)} -> {len(simplified.faces)} 面")
            return simplified
        except Exception as e:
            logger.warning(f"网格简化失败: {str(e)}, 使用原始网格")
            return mesh

    @staticmethod
    def normalize_mesh(mesh: trimesh.Trimesh, scale_factor: float = 0.95) -> Tuple[trimesh.Trimesh, Dict[str, Any]]:
        """
        Center the mesh on its bounding-box midpoint and scale it uniformly.

        After the transform the longest bounding-box side has length
        ``2 * scale_factor``, so the mesh fits in [-scale_factor, scale_factor]
        along that axis.

        Args:
            mesh: Input mesh (not modified).
            scale_factor: Half-length of the longest side after scaling.

        Returns:
            (normalized_mesh, transform_info). ``normalized_mesh`` is a new
            Trimesh built from the transformed vertices and the original faces
            only (other mesh attributes are not carried over).
            ``transform_info`` holds original_center, original_size, scale and
            translation, where new = (old + translation) * scale.

        Raises:
            ValueError: if the mesh has zero or non-finite extent (logged and
                re-raised).
        """
        try:
            bounds = mesh.bounds
            center = (bounds[0] + bounds[1]) / 2
            size = bounds[1] - bounds[0]
            max_size = size.max()
            # A zero-extent mesh would otherwise yield inf/NaN vertices.
            if not np.isfinite(max_size) or max_size <= 0:
                raise ValueError(f"网格尺寸无效: {max_size}")

            scale = (2.0 * scale_factor) / max_size
            translation = -center

            vertices = (mesh.vertices.copy() + translation) * scale
            normalized_mesh = trimesh.Trimesh(vertices=vertices, faces=mesh.faces)

            transform_info = {
                'original_center': center.tolist(),
                'original_size': size.tolist(),
                'scale': float(scale),
                'translation': translation.tolist()
            }

            logger.info(f"网格标准化完成: scale={scale:.4f}")
            return normalized_mesh, transform_info
        except Exception as e:
            logger.error(f"网格标准化失败: {str(e)}")
            raise
class UserSessionManager:
    """Manage per-session working directories and a JSON metadata index.

    Layout on disk: ``<base_dir>/<session_id>/{uploads,outputs,temp}`` plus
    ``<base_dir>/sessions_metadata.json`` mapping session_id -> info dict.
    """

    def __init__(self, base_dir: str = "user_sessions"):
        """Create ``base_dir`` (including parents) and load the metadata index."""
        self.base_dir = Path(base_dir)
        # parents=True so nested locations such as "data/sessions" work.
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.metadata_file = self.base_dir / "sessions_metadata.json"
        self.load_metadata()

    def load_metadata(self):
        """Load the metadata index into ``self.sessions``.

        A missing file yields an empty index. An unparseable file is logged
        and also yields an empty index, so the service can still start.
        """
        if not self.metadata_file.exists():
            self.sessions = {}
            return
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                self.sessions = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            logger.error(f"会话元数据文件损坏,使用空索引: {e}")
            self.sessions = {}

    def save_metadata(self):
        """Persist ``self.sessions`` by writing a temp file and renaming it.

        os.replace makes the final rename atomic on POSIX, so a crash
        mid-write leaves the previous index intact instead of a truncated one.
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.sessions, f, indent=2, ensure_ascii=False)
        os.replace(tmp_file, self.metadata_file)

    def create_session(self, user_id: Optional[str] = None) -> str:
        """
        Create a new session with its uploads/outputs/temp subdirectories.

        Args:
            user_id: Optional owner identifier stored in the metadata.

        Returns:
            The new session_id (a random UUID4 string).
        """
        session_id = str(uuid.uuid4())
        session_dir = self.base_dir / session_id
        session_dir.mkdir(exist_ok=True)

        for sub_dir in ("uploads", "outputs", "temp"):
            (session_dir / sub_dir).mkdir(exist_ok=True)

        now = datetime.now().isoformat()
        self.sessions[session_id] = {
            'user_id': user_id,
            'created_at': now,
            'status': 'active',
            'processed_models': [],
            'last_activity': now
        }
        self.save_metadata()

        logger.info(f"创建新会话: {session_id}")
        return session_id

    def get_session_dir(self, session_id: str) -> Path:
        """Return the directory of an existing session.

        Raises:
            ValueError: if the session directory does not exist, or if
                ``session_id`` does not name a direct child of ``base_dir``
                (e.g. "..", "a/b" or an absolute path). The id may come from a
                client, so this check blocks path traversal.
        """
        if not session_id:
            raise ValueError(f"会话不存在: {session_id}")
        session_dir = self.base_dir / session_id
        if session_dir.resolve().parent != self.base_dir.resolve():
            raise ValueError(f"会话不存在: {session_id}")
        if not session_dir.is_dir():
            raise ValueError(f"会话不存在: {session_id}")
        return session_dir

    def update_activity(self, session_id: str):
        """Stamp the session's last_activity with now and persist (unknown ids are ignored)."""
        if session_id in self.sessions:
            self.sessions[session_id]['last_activity'] = datetime.now().isoformat()
            self.save_metadata()

    def add_processed_model(self, session_id: str, model_info: Dict[str, Any]):
        """Append a processed-model record to the session and refresh its activity."""
        if session_id in self.sessions:
            self.sessions[session_id]['processed_models'].append(model_info)
            self.update_activity(session_id)

    def cleanup_old_sessions(self, max_age_days: int = 7):
        """Delete sessions whose last_activity is older than ``max_age_days``.

        Entries with a missing or malformed last_activity are skipped (with a
        warning) instead of aborting the whole cleanup. A session whose
        directory cannot be removed stays in the index.
        """
        cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 3600)

        sessions_to_remove = []
        for session_id, session_info in self.sessions.items():
            try:
                last_activity = datetime.fromisoformat(session_info['last_activity'])
            except (KeyError, TypeError, ValueError) as e:
                logger.warning(f"会话时间戳无效,跳过 {session_id}: {e}")
                continue
            if last_activity.timestamp() < cutoff_time:
                sessions_to_remove.append(session_id)

        for session_id in sessions_to_remove:
            try:
                session_dir = self.base_dir / session_id
                if session_dir.exists():
                    shutil.rmtree(session_dir)
                del self.sessions[session_id]
                logger.info(f"清理旧会话: {session_id}")
            except Exception as e:
                logger.error(f"清理会话失败 {session_id}: {str(e)}")

        if sessions_to_remove:
            self.save_metadata()
class MagicArticulateAPI:
    """High-level pipeline: validate, preprocess, rig and save uploaded models.

    Every call to process_uploaded_model runs inside a session managed by
    UserSessionManager, so each user's uploads and outputs live in their own
    directory. The SkeletonGPT model is created lazily on first use.
    """

    def __init__(self,
                 model_weights_path: Optional[str] = None,
                 device: str = "auto",
                 session_base_dir: str = "user_sessions"):
        """
        Args:
            model_weights_path: Checkpoint file; initialize_model loads its
                ``"model"`` entry as the state dict. Without it the model keeps
                random weights (an error is logged, processing continues).
            device: "auto" (CUDA if available, else CPU) or a torch device string.
            session_base_dir: Root directory for per-session files.
        """
        self.device = self._setup_device(device)
        self.model = None
        self.accelerator = None
        self.model_weights_path = model_weights_path

        self.session_manager = UserSessionManager(session_base_dir)

        # Default generation parameters, matching the original demo.py settings.
        self.default_args = {
            'input_pc_num': 8192,
            'num_beams': 1,
            'n_discrete_size': 128,
            'n_max_bones': 100,
            'pad_id': -1,
            'precision': 'fp16',
            'batchsize_per_gpu': 1,
            'apply_marching_cubes': False,
            'octree_depth': 7,
            'hier_order': False,  # demo.py default
            'save_render': False,
            'llm': 'facebook/opt-350m'  # demo.py default
        }

        self.initialized = False
        logger.info("MagicArticulate API 初始化完成")

    def _setup_device(self, device: str) -> torch.device:
        """Resolve "auto" to "cuda" when available, else "cpu"; other strings pass through."""
        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                logger.info(f"使用GPU: {torch.cuda.get_device_name()}")
            else:
                device = "cpu"
                logger.info("使用CPU")
        return torch.device(device)

    def initialize_model(self) -> bool:
        """
        Build SkeletonGPT, load weights if available and wrap it with Accelerate.

        Idempotent: returns True immediately once initialized.

        Returns:
            True on success, False on any failure (the traceback is logged).
        """
        try:
            if self.initialized:
                return True

            logger.info("正在初始化MagicArticulate模型...")

            kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            self.accelerator = Accelerator(
                kwargs_handlers=[kwargs],
                mixed_precision=self.default_args['precision'],
            )

            args = self._create_args_object()
            self.model = SkeletonGPT(args)

            if self.device.type == "cuda":
                self.model = self.model.cuda()

            if self.model_weights_path and os.path.exists(self.model_weights_path):
                logger.info(f"加载模型权重: {self.model_weights_path}")
                # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
                # confirm the checkpoint format loads under the installed version.
                pkg = torch.load(self.model_weights_path, map_location=self.device)
                self.model.load_state_dict(pkg["model"])
            else:
                error_msg = "预训练权重必须提供!当前使用随机初始化,结果将不准确。"
                logger.error(error_msg)
                # Deliberately not raising: continue, but warn loudly.
                logger.error("⚠️ WARNING: 没有预训练权重,生成的骨骼结构将不准确!")

            self.model.eval()
            set_seed(0)

            if self.accelerator:
                self.model = self.accelerator.prepare(self.model)

            self.initialized = True
            logger.info("✅ 模型初始化成功")
            return True

        except Exception as e:
            logger.error(f"❌ 模型初始化失败: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    def process_uploaded_model(self,
                               file_path: str,
                               session_id: Optional[str] = None,
                               user_prompt: str = "",
                               processing_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run the full pipeline on an uploaded model file.

        Steps: validate -> copy into the session's uploads dir -> preprocess
        -> generate skeleton -> save outputs -> record in session metadata.

        Args:
            file_path: Path of the uploaded model.
            session_id: Existing session to use; a new one is created if falsy.
            user_prompt: Free-text prompt, stored with the results.
            processing_options: Options for _preprocess_model
                ('auto_repair', 'target_faces').

        Returns:
            A result dict with 'success' True plus model/skeleton/output data,
            or an error dict from _create_error_result. Never raises.
        """
        start_time = time.time()

        try:
            if not session_id:
                session_id = self.session_manager.create_session()

            logger.info(f"开始处理模型: {file_path}, 会话: {session_id}")

            # Step 1: validate the file.
            is_valid, error_msg, model_info = ModelValidator.validate_file(file_path)
            if not is_valid:
                return self._create_error_result(error_msg, session_id, start_time)

            logger.info(f"模型验证通过: {model_info}")

            # Step 2: copy the file into the session directory.
            session_dir = self.session_manager.get_session_dir(session_id)
            uploaded_file = session_dir / "uploads" / os.path.basename(file_path)
            # shutil.copy2 raises SameFileError if the caller already placed
            # the upload at its destination; skip the copy in that case.
            if not (uploaded_file.exists() and os.path.samefile(file_path, uploaded_file)):
                shutil.copy2(file_path, uploaded_file)

            # Step 3: preprocess.
            processed_mesh, preprocessing_log = self._preprocess_model(
                str(uploaded_file),
                processing_options or {}
            )

            # Step 4: generate the skeleton (initialize_model is a no-op once done).
            if not self.initialize_model():
                return self._create_error_result("模型初始化失败", session_id, start_time)

            skeleton_result = self._generate_skeleton(
                processed_mesh,
                model_info['file_name'],
                user_prompt
            )

            # Step 5: save outputs.
            output_files = self._save_results(
                skeleton_result,
                processed_mesh,
                model_info,
                session_dir,
                user_prompt
            )

            # Step 6: build the result and record it in the session.
            processing_time = time.time() - start_time
            result = {
                'success': True,
                'session_id': session_id,
                'processing_time': processing_time,
                'model_info': model_info,
                'preprocessing_log': preprocessing_log,
                'skeleton_data': skeleton_result,
                'output_files': output_files,
                'user_prompt': user_prompt,
                'timestamp': datetime.now().isoformat()
            }

            self.session_manager.add_processed_model(session_id, {
                'file_name': model_info['file_name'],
                'processing_time': processing_time,
                'timestamp': datetime.now().isoformat(),
                'success': True
            })

            logger.info(f"✅ 模型处理完成,耗时: {processing_time:.2f}秒")
            return result

        except Exception as e:
            error_msg = f"处理过程中发生错误: {str(e)}"
            logger.error(f"❌ {error_msg}")
            logger.error(traceback.format_exc())
            return self._create_error_result(error_msg, session_id, start_time)

    def _preprocess_model(self, file_path: str, options: Dict[str, Any]) -> Tuple[trimesh.Trimesh, List[str]]:
        """
        Load, optionally repair, optionally simplify, then normalize a mesh.

        Args:
            file_path: Model file to load.
            options: 'auto_repair' (default True) and 'target_faces'
                (default 10000).

        Returns:
            (normalized_mesh, preprocessing_log).

        Raises:
            RuntimeError: wrapping any failure (original exception chained).
        """
        preprocessing_log = []

        try:
            mesh = trimesh.load(file_path, force='mesh')
            preprocessing_log.append(f"加载模型: {len(mesh.vertices)} 顶点, {len(mesh.faces)} 面")

            if options.get('auto_repair', True):
                mesh, repair_log = ModelValidator.auto_repair_mesh(mesh)
                preprocessing_log.extend(repair_log)

            target_faces = options.get('target_faces', 10000)
            if len(mesh.faces) > target_faces:
                faces_before = len(mesh.faces)
                mesh = ModelPreprocessor.simplify_mesh(mesh, target_faces)
                # simplify_mesh returns the input unchanged when decimation
                # fails, so only report a simplification that actually happened.
                if len(mesh.faces) < faces_before:
                    preprocessing_log.append(f"简化网格到 {len(mesh.faces)} 面")

            mesh, transform_info = ModelPreprocessor.normalize_mesh(mesh)
            preprocessing_log.append(f"标准化网格: scale={transform_info['scale']:.4f}")

            return mesh, preprocessing_log

        except Exception as e:
            error_msg = f"预处理失败: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _generate_skeleton(self, mesh: trimesh.Trimesh, file_name: str, user_prompt: str) -> Dict[str, Any]:
        """
        Sample a point cloud from ``mesh`` and run SkeletonGPT on it.

        Predicted joints are mapped back into the coordinate frame of ``mesh``
        (via its bounding-box center and max extent). In process_uploaded_model
        ``mesh`` is the already-normalized preprocessed mesh, not the raw upload.

        Returns:
            Dict with 'joints' (in mesh frame), 'joints_normalized', 'bones',
            'root_index', counts, 'raw_skeleton', 'user_prompt' and
            'transform_info' ({'trans', 'scale'}).

        Raises:
            RuntimeError: wrapping any failure (original exception chained).
        """
        try:
            point_clouds = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh],
                self.default_args['input_pc_num'],
                self.default_args['apply_marching_cubes'],
                self.default_args['octree_depth']
            )
            pc_data = point_clouds[0]

            # Assumes pc_data columns are xyz followed by normals, as in the
            # original demo — TODO confirm against MeshProcessor.
            pc_coor = pc_data[:, :3]
            normals = pc_data[:, 3:]

            # Center on the bounding-box midpoint, then scale so the largest
            # absolute coordinate becomes 0.9995 (identical to the demo).
            bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
            pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
            pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995

            pc_coor = pc_coor.astype(np.float32)
            normals = normals.astype(np.float32)
            pc_normal_data = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

            pc_normal = torch.from_numpy(pc_normal_data).unsqueeze(0)
            if self.device.type == "cuda":
                pc_normal = pc_normal.cuda()

            # Bounding-box transform of the mesh, used to denormalize joints.
            mesh_bounds = np.array([mesh.vertices.min(axis=0), mesh.vertices.max(axis=0)])
            mesh_trans = (mesh_bounds[0] + mesh_bounds[1])[None, :] / 2
            mesh_scale = ((mesh_bounds[1] - mesh_bounds[0]).max() + 1e-5)

            batch_data = {
                'pc_normal': pc_normal,
                'file_name': [os.path.splitext(file_name)[0]],
                'trans': torch.from_numpy(mesh_trans).unsqueeze(0),
                'scale': torch.tensor(mesh_scale).unsqueeze(0),
                'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0),
                'faces': torch.from_numpy(mesh.faces).unsqueeze(0)
            }

            with torch.no_grad():
                if self.accelerator:
                    with self.accelerator.autocast():
                        pred_bone_coords = self.model.generate(batch_data)
                else:
                    pred_bone_coords = self.model.generate(batch_data)

            # Post-processing follows the original demo.
            trans = batch_data['trans'][0].cpu().numpy()
            scale = batch_data['scale'][0].cpu().numpy()

            skeleton = pred_bone_coords[0].cpu().numpy().squeeze()
            pred_joints, pred_bones = pred_joints_and_bones(skeleton)

            if self.default_args['hier_order']:
                pred_joints, pred_bones, pred_root_index = remove_duplicate_joints(
                    pred_joints, pred_bones, root_index=pred_bones[0][0]
                )
            else:
                pred_joints, pred_bones = remove_duplicate_joints(pred_joints, pred_bones)
                pred_root_index = 0

            # Map joints back into the frame of `mesh`.
            pred_joints_denorm = pred_joints * scale + trans

            return {
                'joints': pred_joints_denorm.tolist(),
                'joints_normalized': pred_joints.tolist(),
                'bones': pred_bones,
                'root_index': pred_root_index,
                'joint_count': len(pred_joints),
                'bone_count': len(pred_bones),
                'raw_skeleton': skeleton.tolist(),
                'user_prompt': user_prompt,
                'transform_info': {
                    'trans': trans.tolist(),
                    'scale': float(scale)
                }
            }

        except Exception as e:
            error_msg = f"骨骼生成失败: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _save_results(self,
                      skeleton_result: Dict[str, Any],
                      mesh: trimesh.Trimesh,
                      model_info: Dict[str, Any],
                      session_dir: Path,
                      user_prompt: str) -> Dict[str, str]:
        """
        Write skeleton OBJ/TXT, the processed mesh and a text report.

        Files go to ``<session_dir>/outputs`` and are named
        ``<stem>_<YYYYmmdd_HHMMSS>_*``. JSON output is intentionally omitted.

        Returns:
            Mapping of output kind ('skeleton_obj', 'skeleton_txt',
            'processed_mesh', 'report') to file path.

        Raises:
            RuntimeError: wrapping any failure (original exception chained).
        """
        try:
            output_dir = session_dir / "outputs"
            base_name = os.path.splitext(model_info['file_name'])[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            output_files = {}
            joints = np.array(skeleton_result['joints'])
            root_index = skeleton_result.get('root_index', 0)

            obj_file = output_dir / f"{base_name}_{timestamp}_skeleton.obj"
            save_skeleton_obj(
                joints,
                skeleton_result['bones'],
                str(obj_file),
                root_index,
                use_cone=self.default_args['hier_order']
            )
            output_files['skeleton_obj'] = str(obj_file)

            txt_file = output_dir / f"{base_name}_{timestamp}_rig.txt"
            save_skeleton_to_txt(
                joints,
                skeleton_result['bones'],
                root_index,
                self.default_args['hier_order'],
                mesh.vertices,
                str(txt_file)
            )
            output_files['skeleton_txt'] = str(txt_file)

            mesh_file = output_dir / f"{base_name}_{timestamp}_processed.obj"
            mesh.export(str(mesh_file))
            output_files['processed_mesh'] = str(mesh_file)

            # Plain-text report. The timestamp line is labelled "Processed At"
            # because it is a wall-clock time, not a duration.
            report_file = output_dir / f"{base_name}_{timestamp}_report.txt"
            report_content = f"""MagicArticulate Processing Report
=====================================
File: {model_info['file_name']}
Processed At: {datetime.now().isoformat()}
User Prompt: {user_prompt}
Model Information:
- Vertices: {model_info.get('vertex_count', 'N/A')}
- Faces: {model_info.get('face_count', 'N/A')}
- File Size: {model_info.get('file_size_mb', 'N/A')} MB
- Format: {model_info.get('format', 'N/A')}
Skeleton Results:
- Joints: {skeleton_result.get('joint_count', 'N/A')}
- Bones: {skeleton_result.get('bone_count', 'N/A')}
- Root Index: {skeleton_result.get('root_index', 'N/A')}
Generated Files:
- Skeleton OBJ: {base_name}_{timestamp}_skeleton.obj
- Skeleton TXT: {base_name}_{timestamp}_rig.txt
- Processed Mesh: {base_name}_{timestamp}_processed.obj
"""
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(report_content)
            output_files['report'] = str(report_file)

            logger.info(f"结果保存完成: {len(output_files)} 个文件")
            return output_files

        except Exception as e:
            error_msg = f"保存结果失败: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _create_error_result(self, error_message: str, session_id: Optional[str], start_time: float) -> Dict[str, Any]:
        """Build the standard failure dict ('success' False, error, elapsed time)."""
        processing_time = time.time() - start_time
        return {
            'success': False,
            'session_id': session_id,
            'error': error_message,
            'processing_time': processing_time,
            'timestamp': datetime.now().isoformat()
        }

    def _make_json_serializable(self, obj):
        """Recursively convert numpy values and containers into JSON-friendly types.

        ndarray -> list, numpy bool/int/float -> bool/int/float, tuples ->
        lists (as JSON would encode them anyway); dicts and lists are
        converted element-wise; anything else is returned unchanged.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, dict):
            return {key: self._make_json_serializable(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        return obj

    def _create_args_object(self):
        """Expose default_args as attributes on a simple object for SkeletonGPT."""
        class Args:
            def __init__(self, **kwargs):
                for key, value in kwargs.items():
                    setattr(self, key, value)

        return Args(**self.default_args)

    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Return a deep copy of a session's metadata.

        A deep copy is used so callers cannot mutate the internal
        processed_models list.

        Raises:
            ValueError: if the session is unknown.
        """
        import copy

        if session_id not in self.session_manager.sessions:
            raise ValueError(f"会话不存在: {session_id}")
        return copy.deepcopy(self.session_manager.sessions[session_id])

    def list_user_sessions(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """List sessions (all, or only those of ``user_id``), newest first by created_at."""
        sessions = [
            {'session_id': session_id, **session_info}
            for session_id, session_info in self.session_manager.sessions.items()
            if user_id is None or session_info.get('user_id') == user_id
        ]
        return sorted(sessions, key=lambda x: x['created_at'], reverse=True)

    def cleanup_sessions(self, max_age_days: int = 7):
        """Delete sessions idle for more than ``max_age_days`` days."""
        self.session_manager.cleanup_old_sessions(max_age_days)
# Convenience one-shot entry point.
def process_model_file(file_path: str,
                       user_prompt: str = "",
                       model_weights_path: Optional[str] = None,
                       output_dir: Optional[str] = None) -> Dict[str, Any]:
    """
    Process a single model file with a freshly constructed API instance.

    Args:
        file_path: Path of the model file.
        user_prompt: Free-text prompt stored with the results.
        model_weights_path: Optional checkpoint passed to MagicArticulateAPI.
        output_dir: Session root directory; "temp_sessions" when empty/None.

    Returns:
        The result dict from MagicArticulateAPI.process_uploaded_model.
    """
    session_root = output_dir if output_dir else "temp_sessions"
    api = MagicArticulateAPI(
        model_weights_path=model_weights_path,
        session_base_dir=session_root,
    )
    return api.process_uploaded_model(file_path=file_path, user_prompt=user_prompt)
if __name__ == "__main__":
    # Command-line smoke test: process one file and print a summary.
    import argparse

    parser = argparse.ArgumentParser(description="MagicArticulate API 测试")
    parser.add_argument("--input", required=True, help="输入模型文件路径")
    parser.add_argument("--prompt", default="", help="用户提示词")
    parser.add_argument("--weights", help="模型权重路径")
    parser.add_argument("--output", default="api_outputs", help="输出目录")

    args = parser.parse_args()

    result = process_model_file(
        file_path=args.input,
        user_prompt=args.prompt,
        model_weights_path=args.weights,
        output_dir=args.output
    )

    if result['success']:
        print("✅ 处理成功!")
        print(f"会话ID: {result['session_id']}")
        print(f"处理时间: {result['processing_time']:.2f}秒")
        print(f"关节数量: {result['skeleton_data']['joint_count']}")
        print(f"骨骼数量: {result['skeleton_data']['bone_count']}")
        print(f"输出文件: {len(result['output_files'])} 个")
    else:
        print("❌ 处理失败!")
        print(f"错误: {result['error']}")
        # Exit non-zero so shell scripts / CI can detect the failure.
        sys.exit(1)