Spaces:
Sleeping
Sleeping
""" | |
MagicArticulate MVP - 增强版Gradio应用 | |
支持多格式文件下载和预览 | |
""" | |
import os | |
import sys | |
import time | |
import logging | |
import tempfile | |
import traceback | |
from pathlib import Path | |
from typing import Optional, Dict, Any, List, Tuple | |
import shutil | |
import zipfile | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
# 添加src目录到路径 | |
sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) | |
from enhanced_magic_wrapper import EnhancedMagicWrapper | |
from config import get_config, DEMO_PROMPTS, EXAMPLE_MODELS | |
from src.utils import ( | |
validate_file, get_model_info, cleanup_temp_files, | |
format_processing_time, get_prompt_suggestions, | |
create_processing_status, estimate_processing_time, | |
generate_download_filename, safe_json_serialize | |
) | |
# 配置日志 | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
) | |
logger = logging.getLogger(__name__) | |
# 获取配置 | |
config = get_config() | |
# 全局变量 | |
magic_wrapper = None | |
processing_status = {} | |
session_results = {} # 存储处理结果 | |
def initialize_app(): | |
"""初始化应用""" | |
global magic_wrapper | |
try: | |
logger.info("🚀 Initializing MagicArticulate MVP LFG...") | |
logger.info(f"🔍 Current working directory: {os.getcwd()}") | |
logger.info(f"🔍 Script directory: {os.path.dirname(__file__)}") | |
# 检查关键目录结构 | |
directories = ['src', 'utils', 'skeleton_models', 'magic_articulate_plus', 'third_party'] | |
for dir_name in directories: | |
dir_path = os.path.join(os.getcwd(), dir_name) | |
exists = os.path.exists(dir_path) | |
logger.info(f"🔍 Directory {dir_name}: exists={exists}") | |
if exists and os.path.isdir(dir_path): | |
try: | |
contents = os.listdir(dir_path)[:5] # 只显示前5个文件 | |
logger.info(f"🔍 Contents (first 5): {contents}") | |
except Exception as e: | |
logger.warning(f"🔍 Could not list contents: {e}") | |
# 首先下载所需的模型文件 | |
try: | |
logger.info("📥 开始下载模型文件...") | |
from download_models import download_models | |
download_models() | |
except Exception as e: | |
logger.warning(f"⚠️ 模型下载过程中出现问题: {e}") | |
import traceback | |
logger.warning(f"⚠️ Download traceback: {traceback.format_exc()}") | |
# 创建增强版包装器实例(支持真实3D模型处理) | |
logger.info("🔧 Creating EnhancedMagicWrapper instance...") | |
magic_wrapper = EnhancedMagicWrapper() | |
# 初始化包装器 | |
logger.info("🔧 Initializing wrapper...") | |
if magic_wrapper.initialize(): | |
logger.info("✅ MagicArticulate MVP initialized successfully") | |
return True | |
else: | |
logger.error("❌ Failed to initialize MagicArticulate wrapper") | |
return False | |
except Exception as e: | |
logger.error(f"💥 App initialization failed: {str(e)}") | |
logger.error(traceback.format_exc()) | |
return False | |
def create_download_package(output_files: Dict[str, str], session_id: str) -> str: | |
""" | |
创建包含所有输出文件的ZIP包 | |
Args: | |
output_files: 输出文件路径字典 | |
session_id: 会话ID | |
Returns: | |
ZIP文件路径 | |
""" | |
try: | |
# 创建临时目录 | |
temp_dir = Path(tempfile.mkdtemp()) | |
zip_path = temp_dir / f"skeleton_results_{session_id}.zip" | |
# 创建ZIP文件 | |
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: | |
for file_type, file_path in output_files.items(): | |
if os.path.exists(file_path): | |
# 使用描述性的文件名 | |
if 'skeleton_json' in file_type: | |
arcname = "skeleton_data.json" | |
elif 'skeleton_obj' in file_type: | |
arcname = "skeleton_model.obj" | |
elif 'skeleton_txt' in file_type: | |
arcname = "skeleton_rig.txt" | |
elif 'processed_mesh' in file_type: | |
arcname = "processed_mesh.obj" | |
elif 'report' in file_type: | |
arcname = "processing_report.json" | |
else: | |
arcname = os.path.basename(file_path) | |
zipf.write(file_path, arcname) | |
logger.info(f"Added {arcname} to ZIP") | |
logger.info(f"Created download package: {zip_path}") | |
return str(zip_path) | |
except Exception as e: | |
logger.error(f"Failed to create download package: {str(e)}") | |
return None | |
def process_3d_model_gpu( | |
model_file: gr.File, | |
prompt: str, | |
confidence_threshold: float, | |
generate_preview: bool | |
) -> Tuple[str, str, Any, Any, Any, Any, str, str]: | |
""" | |
GPU处理函数 - 使用ZeroGPU | |
返回多个文件供下载 | |
Returns: | |
(状态, 文本展示, OBJ下载, TXT下载, ZIP下载, 处理信息, 错误信息, 骨骼数据) | |
""" | |
global magic_wrapper, session_results | |
start_time = time.time() | |
session_id = f"session_{int(start_time)}" | |
try: | |
logger.info(f"🔄 Starting GPU processing for session: {session_id}") | |
# 验证输入 | |
if model_file is None: | |
return "❌ 错误", "", None, None, None, None, "", "请上传3D模型文件" | |
if not prompt.strip(): | |
prompt = DEMO_PROMPTS['generic'] | |
logger.info(f"Using default prompt: {prompt}") | |
# 验证文件 | |
file_path = model_file.name | |
is_valid, error_msg = validate_file(file_path, config['file_limits']['max_size_mb']) | |
if not is_valid: | |
return "❌ 错误", "", None, None, None, None, "", f"文件验证失败: {error_msg}" | |
# 获取模型信息 | |
model_info = get_model_info(file_path) | |
logger.info(f"📊 Model info: {model_info}") | |
# 估算处理时间 | |
estimated_time = estimate_processing_time(model_info) | |
logger.info(f"⏱️ Estimated processing time: {estimated_time:.1f}s") | |
# 更新处理状态 | |
processing_status[session_id] = create_processing_status( | |
"preparing", 0.1, "准备处理3D模型..." | |
) | |
# 调用MagicArticulate处理 | |
if magic_wrapper is None: | |
logger.error("MagicArticulate wrapper not initialized") | |
return "❌ 错误", "", None, None, None, None, "", "AI模型未初始化" | |
processing_status[session_id] = create_processing_status( | |
"processing", 0.3, "正在生成骨骼结构..." | |
) | |
# 执行处理 | |
result = magic_wrapper.process_3d_model( | |
model_file_path=file_path, | |
prompt=prompt, | |
confidence_threshold=confidence_threshold, | |
generate_preview=generate_preview | |
) | |
processing_status[session_id] = create_processing_status( | |
"finalizing", 0.9, "正在准备输出文件..." | |
) | |
# 处理结果 | |
if not result['success']: | |
error_msg = result.get('error', 'Unknown error') | |
logger.error(f"Processing failed: {error_msg}") | |
return "❌ 处理失败", "", None, None, None, None, "", error_msg | |
# 保存结果到会话 | |
session_results[session_id] = result | |
# 准备输出数据 | |
skeleton_data = result['skeleton_data'] | |
output_files = result['output_files'] | |
processing_info = result['processing_info'] | |
# 格式化骨骼数据为文本显示 | |
skeleton_json = f"""骨骼结构数据预览 | |
=================== | |
关节数量: {skeleton_data.get('joint_count', 0)} | |
骨骼数量: {skeleton_data.get('bone_count', 0)} | |
根节点索引: {skeleton_data.get('root_index', 0)} | |
关节坐标 (前10个): | |
{str(skeleton_data.get('joints', [])[:10])} | |
骨骼连接 (前10个): | |
{str(skeleton_data.get('bones', [])[:10])} | |
用户提示: {skeleton_data.get('user_prompt', 'N/A')} | |
""" | |
# 准备各个文件供下载 | |
obj_file = output_files.get('skeleton_obj', None) | |
txt_file = output_files.get('skeleton_txt', None) | |
# 创建ZIP包含所有文件 | |
zip_file = create_download_package(output_files, session_id) | |
# 处理时间 | |
processing_time = time.time() - start_time | |
# 准备处理信息 | |
info_text = f""" | |
## 处理完成! ✅ | |
### 📊 处理统计 | |
- **输入文件**: {processing_info.get('input_file', 'Unknown')} | |
- **处理时间**: {format_processing_time(processing_time)} | |
- **提示词**: {processing_info.get('prompt', 'None')} | |
### 🦴 骨骼数据 | |
- **关节数量**: {processing_info.get('joint_count', 0)} | |
- **骨骼数量**: {processing_info.get('bone_count', 0)} | |
- **根节点索引**: {skeleton_data.get('root_index', 0)} | |
### 📁 可下载文件 | |
1. **骨骼模型 (OBJ)** - 3D骨骼的几何表示,可在3D软件中查看 | |
2. **绑定数据 (TXT)** - 传统的骨骼绑定格式,适合导入到动画软件 | |
3. **完整包 (ZIP)** - 包含所有输出文件的压缩包 | |
### 💡 使用建议 | |
- OBJ格式可以直接在Blender、Maya等3D软件中查看 | |
- TXT格式符合传统骨骼绑定标准,便于集成到现有工作流程 | |
- ZIP包含所有文件和处理报告,方便归档和分享 | |
""" | |
processing_status[session_id] = create_processing_status( | |
"completed", 1.0, "处理完成!" | |
) | |
logger.info(f"✅ Processing completed successfully in {processing_time:.1f}s") | |
return ( | |
"✅ 处理完成", | |
skeleton_json, | |
obj_file, | |
txt_file, | |
zip_file, | |
info_text, | |
"", | |
skeleton_data # 添加原始skeleton_data用于3D预览 | |
) | |
except Exception as e: | |
processing_time = time.time() - start_time | |
error_msg = f"处理过程中发生错误: {str(e)}" | |
logger.error(f"💥 Processing error: {error_msg}") | |
logger.error(traceback.format_exc()) | |
processing_status[session_id] = create_processing_status( | |
"error", 0.0, error_msg | |
) | |
return ( | |
"❌ 处理失败", | |
"", | |
None, | |
None, | |
None, | |
f"处理时间: {format_processing_time(processing_time)}", | |
error_msg, | |
None # 空的skeleton_data | |
) | |
def create_visualization_html(skeleton_data: Dict[str, Any]) -> str: | |
""" | |
创建骨骼可视化的HTML | |
使用Three.js进行简单的3D展示 | |
""" | |
joints = skeleton_data.get('joints', []) | |
bones = skeleton_data.get('bones', []) | |
html_content = f""" | |
<div id="skeleton-viewer" style="width: 100%; height: 400px; border: 1px solid #ddd;"> | |
<canvas id="three-canvas" style="width: 100%; height: 100%;"></canvas> | |
</div> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script> | |
<script> | |
// 简单的Three.js骨骼可视化 | |
const scene = new THREE.Scene(); | |
scene.background = new THREE.Color(0xf0f0f0); | |
const camera = new THREE.PerspectiveCamera(75, 1, 0.1, 1000); | |
camera.position.set(2, 2, 2); | |
camera.lookAt(0, 0, 0); | |
const renderer = new THREE.WebGLRenderer({{canvas: document.getElementById('three-canvas')}}); | |
renderer.setSize(400, 400); | |
// 添加光源 | |
const light = new THREE.DirectionalLight(0xffffff, 1); | |
light.position.set(1, 1, 1); | |
scene.add(light); | |
// 添加网格 | |
const gridHelper = new THREE.GridHelper(4, 10); | |
scene.add(gridHelper); | |
// 骨骼数据 | |
const joints = {json.dumps(joints)}; | |
const bones = {json.dumps(bones)}; | |
// 创建关节球体 | |
joints.forEach((joint, index) => {{ | |
const geometry = new THREE.SphereGeometry(0.05); | |
const material = new THREE.MeshPhongMaterial({{color: 0xff0000}}); | |
const sphere = new THREE.Mesh(geometry, material); | |
sphere.position.set(joint[0], joint[1], joint[2]); | |
scene.add(sphere); | |
}}); | |
// 创建骨骼线条 | |
bones.forEach(bone => {{ | |
const start = joints[bone[0]]; | |
const end = joints[bone[1]]; | |
const points = []; | |
points.push(new THREE.Vector3(start[0], start[1], start[2])); | |
points.push(new THREE.Vector3(end[0], end[1], end[2])); | |
const geometry = new THREE.BufferGeometry().setFromPoints(points); | |
const material = new THREE.LineBasicMaterial({{color: 0x0000ff}}); | |
const line = new THREE.Line(geometry, material); | |
scene.add(line); | |
}}); | |
// 动画循环 | |
function animate() {{ | |
requestAnimationFrame(animate); | |
scene.rotation.y += 0.01; | |
renderer.render(scene, camera); | |
}} | |
animate(); | |
</script> | |
""" | |
return html_content | |
def create_demo_interface(): | |
"""创建增强版Gradio界面""" | |
# 自定义CSS | |
custom_css = """ | |
.gradio-container { | |
max-width: 1400px; | |
margin: 0 auto; | |
} | |
.download-section { | |
border: 2px solid #e0e0e0; | |
border-radius: 10px; | |
padding: 20px; | |
margin: 10px 0; | |
background-color: #f9f9f9; | |
} | |
.status-box { | |
border: 1px solid #ddd; | |
border-radius: 8px; | |
padding: 15px; | |
margin: 10px 0; | |
background-color: #f8f9fa; | |
} | |
.success-status { | |
border-color: #28a745; | |
background-color: #d4edda; | |
} | |
.error-status { | |
border-color: #dc3545; | |
background-color: #f8d7da; | |
} | |
.info-panel { | |
font-family: monospace; | |
font-size: 14px; | |
line-height: 1.4; | |
} | |
.file-download-btn { | |
margin: 5px; | |
min-width: 200px; | |
} | |
""" | |
# 创建界面 | |
with gr.Blocks( | |
title=config['ui']['title'] + " Enhanced", | |
theme=gr.themes.Soft(), | |
css=custom_css | |
) as demo: | |
# 标题和描述 | |
gr.Markdown(f""" | |
# {config['ui']['title']} - 增强版 | |
{config['ui']['description']} | |
### ✨ 增强功能 | |
- 📁 **多格式下载** - OBJ, TXT, ZIP | |
- 👁️ **骨骼预览** - 3D可视化展示 | |
- 📊 **详细统计** - 完整的处理信息 | |
- 🚀 **批量下载** - 一键下载所有文件 | |
""") | |
# 主界面 | |
with gr.Row(): | |
# 左侧 - 输入 | |
with gr.Column(scale=1): | |
gr.Markdown("### 📤 输入设置") | |
# 文件上传 | |
model_file = gr.File( | |
label="上传3D模型", | |
file_types=['.obj', '.glb', '.ply', '.stl'], | |
file_count="single" | |
) | |
# 提示词输入 | |
prompt_input = gr.Textbox( | |
label="提示词", | |
placeholder="描述你想要的骨骼类型,例如:realistic human skeleton for animation", | |
lines=3, | |
value=DEMO_PROMPTS['generic'] | |
) | |
# 提示词建议 | |
with gr.Accordion("💡 提示词建议", open=False): | |
for key, prompt in DEMO_PROMPTS.items(): | |
gr.Button( | |
f"{key.title()}: {prompt}", | |
size="sm" | |
).click( | |
lambda p=prompt: p, | |
outputs=prompt_input | |
) | |
# 高级选项 | |
with gr.Accordion("⚙️ 高级选项", open=False): | |
confidence_threshold = gr.Slider( | |
label="置信度阈值", | |
minimum=0.1, | |
maximum=1.0, | |
value=0.8, | |
step=0.1 | |
) | |
generate_preview = gr.Checkbox( | |
label="生成预览图", | |
value=True | |
) | |
# 处理按钮 | |
process_btn = gr.Button( | |
"🎯 生成骨骼", | |
variant="primary", | |
size="lg" | |
) | |
# 右侧 - 输出 | |
with gr.Column(scale=2): | |
gr.Markdown("### 📥 处理结果") | |
# 状态显示 | |
status_text = gr.Textbox( | |
label="处理状态", | |
value="等待处理...", | |
interactive=False | |
) | |
# 标签页组织输出 | |
with gr.Tabs(): | |
# 数据展示标签 | |
with gr.TabItem("📊 骨骼数据"): | |
skeleton_data_json = gr.Textbox( | |
label="骨骼数据预览", | |
lines=15, | |
interactive=False, | |
show_copy_button=True | |
) | |
# 3D预览标签 | |
with gr.TabItem("👁️ 3D预览"): | |
skeleton_preview = gr.HTML( | |
label="骨骼可视化", | |
value="<p>等待处理...</p>" | |
) | |
# 下载标签 | |
with gr.TabItem("📁 文件下载"): | |
gr.Markdown("### 下载骨骼文件") | |
with gr.Row(): | |
download_obj = gr.File( | |
label="🎨 OBJ格式", | |
visible=True | |
) | |
download_txt = gr.File( | |
label="📝 TXT格式", | |
visible=True | |
) | |
with gr.Row(): | |
download_zip = gr.File( | |
label="📦 完整包(ZIP)", | |
visible=True | |
) | |
# 处理信息标签 | |
with gr.TabItem("ℹ️ 处理信息"): | |
processing_info = gr.Markdown( | |
value="等待处理..." | |
) | |
# 错误信息(通常隐藏) | |
error_info = gr.Textbox( | |
label="错误信息", | |
visible=False, | |
interactive=False | |
) | |
# 示例模型已移除 - 直接上传您的3D模型开始使用 | |
# 使用说明 | |
with gr.Accordion("📖 使用说明", open=False): | |
gr.Markdown(""" | |
## 🎯 如何使用 | |
1. **上传模型** - 支持OBJ, GLB, PLY, STL格式 | |
2. **输入提示词** - 描述期望的骨骼类型 | |
3. **点击生成** - 等待30-120秒 | |
4. **查看结果** - 在不同标签页查看数据、预览和下载 | |
## 📁 输出文件说明 | |
- **OBJ** - 可在3D软件中查看的骨骼模型 | |
- **TXT** - 传统骨骼绑定格式 | |
- **ZIP** - 包含所有文件的压缩包 | |
## 💡 提示 | |
- 模型应该是封闭的网格以获得最佳效果 | |
- 复杂模型可能需要更长处理时间 | |
- 使用具体的提示词可以获得更好的结果 | |
""") | |
# 事件绑定 | |
def process_and_update_ui(model_file, prompt, confidence, preview): | |
# 处理模型 | |
status, json_data, obj_file, txt_file, zip_file, info, error, skeleton_data = process_3d_model_gpu( | |
model_file, prompt, confidence, preview | |
) | |
# 生成3D预览 | |
preview_html = "<p>暂无预览</p>" | |
if status == "✅ 处理完成" and skeleton_data: | |
try: | |
preview_html = create_visualization_html(skeleton_data) | |
except Exception as e: | |
preview_html = f"<p>预览生成失败: {str(e)}</p>" | |
# 更新可见性 | |
error_visible = status.startswith("❌") | |
return ( | |
status, # 状态 | |
json_data, # JSON展示 | |
obj_file, # OBJ下载 | |
txt_file, # TXT下载 | |
zip_file, # ZIP下载 | |
preview_html, # 3D预览 | |
info, # 处理信息 | |
error, # 错误信息 | |
gr.update(visible=error_visible) # 错误框可见性 | |
) | |
# 绑定处理按钮 | |
process_btn.click( | |
fn=process_and_update_ui, | |
inputs=[ | |
model_file, | |
prompt_input, | |
confidence_threshold, | |
generate_preview | |
], | |
api_name="predict", | |
outputs=[ | |
status_text, | |
skeleton_data_json, | |
download_obj, | |
download_txt, | |
download_zip, | |
skeleton_preview, | |
processing_info, | |
error_info, | |
error_info # 控制可见性 | |
] | |
) | |
# 页脚 | |
gr.Markdown(""" | |
--- | |
## 🔗 相关链接 | |
- [MagicArticulate Paper](https://github.com/Seed3D/MagicArticulate) | |
- [ArticulateHub Project](https://github.com/your-repo) | |
- [Hugging Face Spaces](https://huggingface.co/spaces) | |
**Made with ❤️ using Gradio and ZeroGPU** | |
""") | |
return demo | |
def main(): | |
"""主函数""" | |
try: | |
logger.info("🚀 Starting Enhanced MagicArticulate MVP...") | |
# 初始化应用 | |
if not initialize_app(): | |
logger.error("❌ Failed to initialize app") | |
return | |
# 创建界面 | |
demo = create_demo_interface() | |
# 启动应用 | |
logger.info("🌟 Launching Enhanced Gradio interface...") | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
show_api=True, | |
share=False, | |
debug=False | |
) | |
except Exception as e: | |
logger.error(f"💥 Main function failed: {str(e)}") | |
logger.error(traceback.format_exc()) | |
if __name__ == "__main__": | |
main() |