File size: 1,261 Bytes
bef5729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import trimesh
import numpy as np
import argparse
import json
import torch
from huggingface_hub import snapshot_download

from src.utils.image_utils import prepare_image
from src.models.briarmbg import BriaRMBG

def main(argv=None):
    """Remove the background from one rendering and save the result.

    Parses ``--input`` (path to the source image) and ``--output`` (root
    directory for results), then writes ``rendering_rmbg.png`` into
    ``<output>/<mesh_name>/``, where ``mesh_name`` is the name of the
    directory that directly contains the input image.

    Args:
        argv: Optional list of command-line arguments. ``None`` makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point uses.

    Raises:
        FileNotFoundError: If the ``--input`` path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default='preprocessed_data/scissors/scissors.png')
    parser.add_argument('--output', type=str, default='preprocessed_data')
    args = parser.parse_args(argv)

    input_path = args.input

    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this input validation.
    if not os.path.exists(input_path):
        raise FileNotFoundError(f'{input_path} does not exist')

    # The mesh name is the input's parent directory name.
    # NOTE: a bare filename (no directory component) yields an empty mesh
    # name, so the result then lands directly in the --output directory.
    mesh_name = os.path.basename(os.path.dirname(input_path))
    output_path = os.path.join(args.output, mesh_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)

    # Fetch the RMBG-1.4 snapshot into the local weights directory and load
    # the network from it.
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)

    # bg_color of (1.0, 1.0, 1.0) is presumably a white background in
    # normalized RGB -- TODO confirm against prepare_image.
    rendering_rmbg = prepare_image(
        input_path,
        bg_color=np.array([1.0, 1.0, 1.0]),
        rmbg_net=rmbg_net,
        device=device,
    )
    rendering_rmbg.save(os.path.join(output_path, 'rendering_rmbg.png'))


if __name__ == '__main__':
    main()