import gc
import json
import os
import traceback
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm

from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
from fast_pointnet_v2 import load_pointnet_model
from predict import predict_wireframe
from utils import empty_solution


if __name__ == "__main__":
    # Inference entry point: read params.json, locate (or download) the
    # dataset, load the two PointNet models, run predict_wireframe on every
    # sample of every split and write the results to submission.parquet.
    print ("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug listing of the working directory and the data directory.
    # os.system() lets the child process write directly to stdout, bypassing
    # Python's output buffer, so every label is flushed before the command
    # runs; otherwise labels and listings interleave out of order whenever
    # stdout is redirected. The print(os.system(...)) calls also print the
    # value returned by os.system for that command.
    print('pwd:', flush=True)
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/generic/path/to/data_dir/', flush=True)
    print(os.system('ls -lahtr /generic/path/to/data_dir/'))
    print('/generic/path/to/data_dir/data', flush=True)
    print(os.system('ls -lahtrR /generic/path/to/data_dir/data'))

    data_path_test_server = Path('/generic/path/to/data_dir')

    # When the data directory is missing (presumably a local run rather than
    # the test server -- confirm) the dataset named in params['dataset'] is
    # downloaded into that same directory, so the rest of the script can use
    # a single path in both cases.
    TEST_ENV = data_path_test_server.exists()
    if not TEST_ENV:
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir=str(data_path_test_server),
            repo_type="dataset",
        )
    data_path = data_path_test_server
    print(data_path)

    # Tar shards found below a "*public*" directory form the validation split,
    # those below a "*private*" directory form the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    # NOTE(review): trust_remote_code=True runs the hoho25k_test_x.py loading
    # script that ships inside the data directory.
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )
    print('load with webdataset')
    print(dataset, flush=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
    pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
    # No voxel model is loaded; None is passed through to predict_wireframe.
    voxel_model = None

    config = {'vertex_threshold': 0.59, 'edge_threshold': 0.65, 'only_predicted_connections': True}

    print('------------ Now you can do your solution ---------------')
    solution = []

    def process_sample(sample, i):
        """Predict the wireframe for one sample and return a submission row.

        Args:
            sample: One dataset record; its ``'order_id'`` entry is copied
                into the result and the whole record is handed to
                ``predict_wireframe``.
            i: Position of the sample within its split; used for the
                periodic garbage collection and for the failure message.

        Returns:
            A dict with the keys ``'order_id'``, ``'wf_vertices'`` and
            ``'wf_edges'``.
        """
        try:
            pred_vertices, pred_edges = predict_wireframe(
                sample, pnet_model, voxel_model, pnet_class_model, config
            )
        except Exception:
            # Best effort: a single failing sample must not abort the whole
            # run, so fall back to the empty solution. The failure is logged
            # instead of being swallowed silently, and KeyboardInterrupt /
            # SystemExit are no longer intercepted.
            print(f"predict_wireframe failed on sample {i}; using empty_solution()", flush=True)
            traceback.print_exc()
            pred_vertices, pred_edges = empty_solution()
        # Force a garbage-collection pass on every tenth sample.
        if i % 10 == 0:
            gc.collect()
        return {
            'order_id': sample['order_id'],
            # assumes pred_vertices exposes .tolist() (e.g. a NumPy array) -- confirm
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges,
        }

    # The sample index restarts at 0 for every split.
    for subset_name, subset in dataset.items():
        print (f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(subset)):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")