File size: 8,984 Bytes
7b127f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import random
import shutil

import numpy as np
import ray
import torch
import os
from tqdm import tqdm
from load_data.interface import LoadData
import pickle
from multiprocessing import Pool, cpu_count

def read_all_data(folder_list, load_data, add_model_str=True, add_ori_name=False, device='cuda'):
    """Load LFD descriptors for every folder and pack them into int32 tensors.

    For each folder ``f`` the mesh prefix handed to ``load_data.run`` is:

    * ``f/model/mesh`` when ``add_model_str`` is true (checked first),
    * ``f/<last '/'-separated component of f>/mesh`` when ``add_ori_name`` is true,
    * ``f/mesh`` otherwise.

    ``load_data.run`` must return an indexable result whose items 0 and 1 are
    the lookup tables (``q8_table``, ``align_10``) and whose items 2-5 are the
    per-shape Art, Fd, Cir and Ecc coefficient arrays. The lookup tables are
    taken from the FIRST folder's result only.

    Args:
        folder_list: Folders to read; output rows follow this order.
        load_data: Object exposing ``run(mesh_prefix)`` (e.g. ``LoadData``).
        add_model_str: Look for the mesh under a ``model`` sub-folder.
        add_ori_name: Look for the mesh under a sub-folder named like ``f``.
        device: Torch device the tensors are moved to. Defaults to ``'cuda'``,
            which is what the function always used before this parameter existed.

    Returns:
        Tuple of contiguous int32 tensors
        ``(q8_table, align_10, ArtCoeff, FdCoeff_q8, CirCoeff_q8, EccCoeff_q8)``
        with shapes ``(256, 256)``, ``(60, 20)``, ``(N, 10, 10, 35)``,
        ``(N, 10, 10, 10)``, ``(N, 10, 10)`` and ``(N, 10, 10)``, where
        ``N == len(folder_list)``.

    Raises:
        ValueError: If ``folder_list`` is empty (the lookup tables come from
            the first result, so at least one folder is required).
    """
    if not folder_list:
        raise ValueError("read_all_data() needs at least one folder")

    all_data = []
    for f in folder_list:
        if add_model_str:
            mesh_prefix = os.path.join(f, 'model', 'mesh')
        elif add_ori_name:
            mesh_prefix = os.path.join(f, f.split('/')[-1], 'mesh')
        else:
            mesh_prefix = os.path.join(f, 'mesh')
        all_data.append(load_data.run(mesh_prefix))

    # Descriptor layout used by the reshapes below.
    SRC_ANGLE = 10
    ANGLE = 10
    CAMNUM = 10
    ART_COEF = 35
    FD_COEF = 10
    n_shape = len(all_data)

    def _to_int_tensor(array, *shape):
        # int32 copy on `device`, reshaped and made contiguous.
        return torch.from_numpy(np.ascontiguousarray(array)).int().to(device).reshape(*shape).contiguous()

    def _stack_field(index):
        # Stack item `index` of every per-folder result along a new leading axis.
        return np.stack([r[index] for r in all_data], axis=0)

    q8_table = _to_int_tensor(all_data[0][0], 256, 256)
    align_10 = _to_int_tensor(all_data[0][1], 60, 20)
    dest_ArtCoeff = _to_int_tensor(_stack_field(2), n_shape, SRC_ANGLE, CAMNUM, ART_COEF)
    dest_FdCoeff_q8 = _to_int_tensor(_stack_field(3), n_shape, ANGLE, CAMNUM, FD_COEF)
    dest_CirCoeff_q8 = _to_int_tensor(_stack_field(4), n_shape, ANGLE, CAMNUM)
    dest_EccCoeff_q8 = _to_int_tensor(_stack_field(5), n_shape, ANGLE, CAMNUM)
    return q8_table, align_10, dest_ArtCoeff, dest_FdCoeff_q8, dest_CirCoeff_q8, dest_EccCoeff_q8


def compute_lfd_all(src_folder_list, tgt_folder_list, log):
    """Compute the LFD matrix between a set of source and target feature folders.

    Both lists are sorted IN PLACE before loading, so the result is indexed by
    the sorted order of each list. Features are read via ``read_all_data``
    with ``add_model_str=False`` and ``add_ori_name=False``, i.e. from
    ``<folder>/mesh`` for both sides.

    Args:
        src_folder_list: Source feature folders (sorted in place).
        tgt_folder_list: Target feature folders (sorted in place).
        log: Forwarded unchanged to ``LFD.forward``. The ``__main__`` driver
            passes ``True`` only for its first worker (presumably to limit
            logging to one worker -- confirm against ``LFD.forward``).

    Returns:
        The tensor returned by ``LFD.forward``, detached and converted to a
        NumPy array on the CPU. NOTE(review): the driver splits sources across
        workers and concatenates results along axis 0, which suggests rows
        index sources and columns index targets -- confirm in ``LFD.forward``.
    """
    load_data = LoadData()

    src_folder_list.sort()
    tgt_folder_list.sort()

    # The lookup tables from the target read are the ones passed to LFD
    # (matching the original, where the second read overwrote the first).
    _, _, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8 = read_all_data(
        src_folder_list, load_data, add_model_str=False)
    q8_table, align_10, tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8 = read_all_data(
        tgt_folder_list, load_data, add_model_str=False, add_ori_name=False)

    # Imported here rather than at module top, as in the original (presumably
    # so the extension is only loaded inside the worker that uses it -- confirm).
    from lfd_all_compute.lfd import LFD
    lfd = LFD()
    lfd_matrix = lfd.forward(
        q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
        tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log)
    return lfd_matrix.detach().cpu().numpy()

def get_file_size_kb(mesh_path):
    """Return the size of the file at ``mesh_path`` in KiB, truncated to an int.

    Defined at module level so ``multiprocessing.Pool`` can pickle it by name.
    """
    n_bytes = os.stat(mesh_path).st_size
    return int(n_bytes / 1024)


if __name__ == '__main__':
    # Driver: bucket meshes by file size (KiB) into overlapping windows, compute
    # the LFD matrix within each window on `num_workers` GPUs via Ray, and
    # pickle [folder_list, lfd_matrix] per window into `save_root`.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--mesh_path", type=str, required=True, help="path to the mesh folder")
    parser.add_argument("--lfd_feat", type=str, required=True, help="path to the preprocessed shapenet dataset")
    parser.add_argument("--save_root", type=str, required=True, help="path to the save resules shapenet dataset")
    parser.add_argument("--num_workers", type=int, default=1, help="number of workers to run in parallel")
    parser.add_argument("--list", type=str, default=None, help="list file in the training set")
    args = parser.parse_args()
    num_workers = args.num_workers
    if num_workers < 1:
        # Used as a divisor for the per-worker CPU share below.
        parser.error("--num_workers must be >= 1")
    listfile = args.list

    mesh_folder_path = args.mesh_path
    lfd_feat_path = args.lfd_feat
    save_root = args.save_root
    os.makedirs(save_root, exist_ok=True)

    print(f"mesh_path: {mesh_folder_path}")
    print(f"lfd_feat_path: {lfd_feat_path}")

    # One sub-folder per shape, sorted so every derived list is sorted too.
    mesh_folder_list = sorted(os.listdir(mesh_folder_path))
    if not mesh_folder_list:
        raise SystemExit(f"No mesh folders found in {mesh_folder_path}")

    print("Get mesh_size")
    mesh_path_list = [os.path.join(mesh_folder_path, f, "mesh.stl") for f in mesh_folder_list]
    with Pool(processes=cpu_count()) as pool:
        mesh_size_list = list(tqdm(pool.imap(get_file_size_kb, mesh_path_list), total=len(mesh_path_list)))
    assert len(mesh_size_list) == len(mesh_folder_list)
    mesh_size_list = np.array(mesh_size_list)
    print(f"Max size: {mesh_size_list.max()}")
    print(f"Min size: {mesh_size_list.min()}")
    print(f"Total {mesh_size_list.shape} mesh_folder to process")

    if listfile is not None:
        with open(listfile, 'r') as fh:
            valid_folders = {line.strip() for line in fh}
        # Filter folders AND sizes with the same indices: `hitted_idx` below is
        # computed on the sizes and then used to index the folder list.
        keep_idx = [i for i, name in enumerate(mesh_folder_list) if name in valid_folders]
        if not keep_idx:
            raise SystemExit(f"No folder listed in {listfile} was found in {mesh_folder_path}")
        mesh_folder_list = [mesh_folder_list[i] for i in keep_idx]
        mesh_size_list = mesh_size_list[keep_idx]

    tgt_folder_list = [os.path.join(lfd_feat_path, f) for f in mesh_folder_list]
    src_folder_list = tgt_folder_list

    # Resume knob: windows whose size_end is <= this value are skipped.
    start_from_size_end = 0
    print(f"Start from size_end: {start_from_size_end}")
    print((mesh_size_list > start_from_size_end).sum() / mesh_size_list.shape[0])

    ray.init(
        num_cpus=os.cpu_count(),
        num_gpus=num_workers,
    )

    compute_lfd_all_remote = ray.remote(num_gpus=1, num_cpus=os.cpu_count() // num_workers)(compute_lfd_all)

    print("Check data")
    print(f"len of src_folder_list: {len(src_folder_list)}")
    print(f"len of tgt_folder_list: {len(tgt_folder_list)}")
    print(src_folder_list[0])
    print(tgt_folder_list[0])

    batch_size = 1  # step between consecutive window starts (KiB)
    offset = 2  # window covers sizes in [size_start, size_start + offset], inclusive

    size_min = int(mesh_size_list.min())
    size_max = int(mesh_size_list.max())
    # range(size_min, size_max) is empty when every mesh has the same size, so
    # force at least one window then. With batch_size=1 and size_max > size_min
    # the window starting at size_max - 1 already reaches size_max, so the set
    # of windows is otherwise unchanged.
    for size_start in tqdm(range(size_min, max(size_max, size_min + 1), batch_size)):
        size_end = size_start + offset
        print(f"size_start: {size_start}, size_end: {size_end}, max_size: {size_max}")
        if size_end <= start_from_size_end:
            continue
        hitted_idx = np.where((mesh_size_list >= size_start) & (mesh_size_list <= size_end))[0]
        print(f"len of hitted folder: {len(hitted_idx)}")
        if len(hitted_idx) == 0:
            continue
        local_num_workers = min(num_workers, len(hitted_idx))
        # tgt_folder_list is sorted and hitted_idx is ascending, so this list is
        # already in the sorted order compute_lfd_all uses for its axes.
        local_tgt_folder_list = [tgt_folder_list[i] for i in hitted_idx]
        n_local = len(local_tgt_folder_list)
        # Each worker compares a contiguous slice of sources against all
        # targets; concatenating along axis 0 restores the full source order.
        futures = [
            compute_lfd_all_remote.remote(
                local_tgt_folder_list[i * n_local // local_num_workers:(i + 1) * n_local // local_num_workers],
                local_tgt_folder_list,
                i == 0)
            for i in range(local_num_workers)
        ]
        lfd_matrix = np.concatenate(ray.get(futures), axis=0)

        save_name = os.path.join(save_root, f"lfd_{size_start:07d}kb_{size_end:07d}kb.pkl")
        with open(save_name, 'wb') as fh:
            pickle.dump([local_tgt_folder_list, lfd_matrix], fh)
        print(f"pkl is saved to {save_name}\n\n")