File size: 4,021 Bytes
e7b9fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
import os
import argparse
from omegaconf import OmegaConf, DictConfig, ListConfig
import numpy as np
import torch
from .michelangelo.utils.misc import instantiate_from_config

def load_surface(fp, num_points=4096, device=None):
    """Load a point cloud from an ``.npz`` archive and sample a surface tensor.

    The archive must provide a ``points`` array and a ``normals`` array. A
    random subset of rows is drawn (same indices for both arrays), and the
    selected points and normals are concatenated along the last axis.

    Args:
        fp: Path or file object accepted by ``numpy.load``.
        num_points: Number of rows to sample. Defaults to 4096, the value the
            original implementation hard-coded.
        device: Torch device for the result. ``None`` selects CUDA when it is
            available and falls back to the CPU otherwise.

    Returns:
        A float32 tensor of shape ``(1, num_points, P + N)`` where ``P`` and
        ``N`` are the column counts of ``points`` and ``normals``
        (presumably 3 + 3 = 6 for xyz positions and normals -- confirm
        against the data files).

    Raises:
        KeyError: If ``points`` or ``normals`` is missing from the archive.
        ValueError: If the point cloud contains no points.
    """
    with np.load(fp) as input_pc:
        surface = input_pc['points']
        normal = input_pc['normals']

    total = surface.shape[0]
    if total == 0:
        raise ValueError(f"Point cloud {fp!r} contains no points")

    # Sampling without replacement needs at least `num_points` rows; for
    # smaller clouds fall back to sampling with replacement so the output
    # size stays fixed instead of raising.
    rng = np.random.default_rng()
    ind = rng.choice(total, num_points, replace=total < num_points)

    surface = torch.as_tensor(surface[ind], dtype=torch.float32)
    normal = torch.as_tensor(normal[ind], dtype=torch.float32)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Add a leading batch dimension of size 1.
    return torch.cat([surface, normal], dim=-1).unsqueeze(0).to(device)

def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
    """Run an encode/decode round trip of a point cloud through the shape model.

    The surface is loaded from ``args.pointcloud_path``, channel 2 of every
    point is negated, and the result is pushed through the model's shape
    encoder, KL embedding and decoder. The decoded latents are not used
    further; the function only exercises the pipeline.

    Args:
        args: Namespace-like object exposing ``pointcloud_path``.
        model: Wrapper exposing ``model.encode_shape_embed`` and
            ``model.shape_model.encode_kl_embed`` / ``decode``.
        bounds: Currently unused; kept for interface compatibility.
        octree_depth: Currently unused; kept for interface compatibility.
        num_chunks: Currently unused; kept for interface compatibility.

    Returns:
        Always ``0``.
    """
    surface = load_surface(args.pointcloud_path)

    # Negate channel 2 in place (the z coordinate, assuming the first three
    # channels are xyz positions -- TODO confirm).
    # NOTE(review): the normal channels are left untouched by this flip --
    # confirm that is intended.
    surface[0, :, 2] *= -1

    # Every result below is discarded, so skip autograd bookkeeping.
    with torch.no_grad():
        # Encoding
        _, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
        shape_zq, _ = model.model.shape_model.encode_kl_embed(shape_latents)

        # Decoding
        model.model.shape_model.decode(shape_zq)

    return 0

def load_model(
    ckpt_path="third_party/Michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt",
    config_path="third_party/Michelangelo/configs/shapevae-256.yaml",
    download_url="https://huggingface.co/Maikou/Michelangelo/resolve/main/checkpoints/aligned_shape_latents/shapevae-256.ckpt",
):
    """Instantiate the shape model, downloading its checkpoint when missing.

    If ``ckpt_path`` does not exist, the checkpoint is fetched from
    ``download_url``. The download goes to a temporary ``.part`` file and is
    moved into place only on success, so an interrupted transfer never leaves
    a truncated checkpoint that a later run would mistake for a valid one.
    If the download fails, a small placeholder module is returned instead of
    raising.

    Args:
        ckpt_path: Location of the checkpoint file on disk.
        config_path: OmegaConf YAML file describing the model.
        download_url: URL the checkpoint is fetched from when it is missing.

    Returns:
        The object built by ``instantiate_from_config``, or a ``DummyModel``
        instance when the checkpoint could not be downloaded. The dummy only
        implements ``forward`` and ``encode``.
    """
    import urllib.request

    # Download the checkpoint automatically if it does not exist.
    if not os.path.exists(ckpt_path):
        print(f"Downloading checkpoint to {ckpt_path}...")

        # A bare filename has an empty dirname, which os.makedirs rejects.
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        partial_path = f"{ckpt_path}.part"
        try:
            print("正在从HuggingFace下载模型文件...")
            urllib.request.urlretrieve(download_url, partial_path)
            os.replace(partial_path, ckpt_path)
            print(f"✅ 模型文件下载完成: {ckpt_path}")
        except Exception as e:
            print(f"❌ 模型文件下载失败: {e}")

            # Drop whatever was written before the failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)

            # Deliberate best-effort: fall back to a simplified model.
            class DummyModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.dummy = torch.nn.Linear(1, 1)

                def forward(self, x):
                    return x

                def encode(self, x):
                    # Returns a random tensor with the expected feature dimension.
                    return torch.randn(1, 768)

            print("⚠️ 使用简化模型替代")
            return DummyModel()

    model_config = OmegaConf.load(config_path)
    # Some configs nest the model definition under a top-level "model" key.
    if hasattr(model_config, "model"):
        model_config = model_config.model

    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)

    return model
if __name__ == "__main__":
    # Command-line entry point. The original header listed three modes:
    #   1. Reconstruct point cloud
    #   2. Image-conditioned generation
    #   3. Text-conditioned generation
    # Only the point-cloud reconstruction path is invoked below.
    parser = argparse.ArgumentParser()
    # NOTE(review): --config_path, --image_path, --text and --seed are parsed
    # but not consumed anywhere in this script.
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
    parser.add_argument("--image_path", type=str, help='Path to the input image')
    parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
    parser.add_argument("--output_dir", type=str, default='./output')
    parser.add_argument("-s", "--seed", type=int, default=0)
    args = parser.parse_args()

    rule = '-----------------------------------------------------------------------------'
    print(rule)
    print(f'>>> Output directory: {args.output_dir}')
    print(rule)

    # load_model expects the checkpoint path, not the whole argparse namespace.
    reconstruction(args, load_model(args.ckpt_path))