File size: 3,572 Bytes
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import glob
import argparse
import logging

import numpy as np
import cv2
import rembg


# Suffix of every image this script writes; also used to skip earlier outputs
# when a directory is processed again.
OUTPUT_SUFFIX = "_rgba.png"

logger = logging.getLogger(__name__)


def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate values.

    Returns:
        argparse.Namespace with ``dir_or_path``, ``model_name``, ``size``,
        ``border_ratio`` and ``center``. Invalid values terminate the program
        through ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(
        description="Remove background and center the image of an object"
    )

    parser.add_argument(
        "dir_or_path",
        type=str,
        help="Directory or path to images (png, jpeg, webp, etc.)"
    )
    parser.add_argument(
        "--model_name",
        default="u2net",  # "isnet-general-use", "birefnet-general", "birefnet-dis", "birefnet-massive"
        type=str,
        help="Rembg model, see https://github.com/danielgatis/rembg#models"
    )
    parser.add_argument(
        "--size",
        default=512,
        type=int,
        help="Output resolution"
    )
    parser.add_argument(
        "--border_ratio",
        default=0.2,
        type=float,
        help="Output border ratio"
    )
    parser.add_argument(
        "--center",
        action="store_true",
        help="Center the object, potentially not helpful for multiview zero123"
    )

    args = parser.parse_args()

    # Reject values that would break the resize/centering arithmetic later:
    # a non-positive size gives an empty canvas, and a negative border ratio
    # makes the pasted crop larger than the canvas.
    if args.size <= 0:
        parser.error(f"--size must be positive, got {args.size}")
    if not 0.0 <= args.border_ratio < 1.0:
        parser.error(f"--border_ratio must be in [0, 1), got {args.border_ratio}")
    if not os.path.exists(args.dir_or_path):
        parser.error(f"No such file or directory: {args.dir_or_path}")
    return args


def _collect_inputs(dir_or_path):
    """Return ``(files, out_dir)`` for a directory or a single image path.

    For a directory, every regular (non-hidden) file directly inside it is
    returned in sorted order, except files ending in ``OUTPUT_SUFFIX``: those
    are outputs of an earlier run and would otherwise be re-processed into
    ``*_rgba_rgba.png``. A single path is returned as-is, even if it ends in
    the suffix. Outputs are written next to the inputs.
    """
    if os.path.isdir(dir_or_path):
        logger.info(f"Processing directory [{dir_or_path}]...")
        # glob.escape keeps characters such as "[" in the directory name from
        # being interpreted as a pattern.
        pattern = os.path.join(glob.escape(dir_or_path), "*")
        files = sorted(
            path
            for path in glob.glob(pattern)
            if os.path.isfile(path) and not path.endswith(OUTPUT_SUFFIX)
        )
        return files, dir_or_path
    # Single file: dirname is "" for a bare filename, and os.path.join("", x)
    # then yields a path relative to the current directory, as intended.
    return [dir_or_path], os.path.dirname(dir_or_path)


def _output_path(file, out_dir):
    """Return ``<out_dir>/<stem>_rgba.png`` for the input path ``file``."""
    # splitext strips only the final extension, so "cat.v1.png" and
    # "cat.v2.png" get distinct outputs instead of both becoming "cat_rgba.png".
    stem = os.path.splitext(os.path.basename(file))[0]
    return os.path.join(out_dir, stem + OUTPUT_SUFFIX)


def _interpolation(scale):
    """Return the OpenCV interpolation flag for a resize by factor ``scale``."""
    # INTER_AREA is the recommended filter for shrinking but behaves like
    # nearest-neighbour when enlarging; INTER_CUBIC is recommended there.
    return cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC


def _resize_longest_side(image, size):
    """Resize ``image`` so its longest side is ``size`` px, keeping aspect ratio."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    # max(1, ...) stops a very thin image from collapsing one side to 0 px,
    # which cv2.resize rejects.
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return cv2.resize(image, (new_w, new_h), interpolation=_interpolation(scale))


def _center_object(carved_image, size, border_ratio):
    """Paste the foreground of ``carved_image`` centered on a square canvas.

    The bounding box of all pixels with alpha > 0 is cropped and scaled so
    that its longer side spans ``int(size * (1 - border_ratio))`` pixels. It
    is then placed in the middle of a transparent ``size`` x ``size`` canvas.

    Args:
        carved_image: (H, W, 4) uint8 array as returned by ``rembg.remove``;
            the last channel is read as alpha.
        size: Side length of the square output, > 0.
        border_ratio: Fraction of ``size`` left as margin, in [0, 1).

    Returns:
        (size, size, 4) uint8 array. If no pixel has alpha > 0, the canvas is
        returned fully transparent.
    """
    canvas = np.zeros((size, size, 4), dtype=np.uint8)

    rows, cols = np.nonzero(carved_image[..., -1] > 0)
    if rows.size == 0:
        logger.warning("No foreground found, writing an empty canvas")
        return canvas

    # .max() is the index of the last foreground row/column, so +1 is needed
    # for the slice to include it. This also keeps one-pixel-thick objects
    # from producing a zero-size crop.
    top, bottom = rows.min(), rows.max() + 1
    left, right = cols.min(), cols.max() + 1
    crop = carved_image[top:bottom, left:right]
    crop_h, crop_w = crop.shape[:2]

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(crop_h, crop_w)
    # Never below 1 px (cv2.resize rejects 0) and never above ``size``, because
    # int(crop_side * scale) <= desired_size <= size.
    new_h = max(1, int(crop_h * scale))
    new_w = max(1, int(crop_w * scale))

    dst_top = (size - new_h) // 2
    dst_left = (size - new_w) // 2
    canvas[dst_top:dst_top + new_h, dst_left:dst_left + new_w] = cv2.resize(
        crop, (new_w, new_h), interpolation=_interpolation(scale)
    )
    return canvas


def main():
    """Remove the background of each input image and write ``*_rgba.png``.

    Each image is first resized so its longest side is ``--size`` px. With
    ``--center``, the object is also cropped and centered on a square canvas.
    Images OpenCV cannot decode are skipped with a warning.
    """
    args = _parse_args()

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - REMBG&CENTER - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO
    )

    files, out_dir = _collect_inputs(args.dir_or_path)
    if not files:
        logger.warning(f"No input images found in [{args.dir_or_path}]")
        return

    # A single session is reused for every image, so the model loads once.
    session = rembg.new_session(model_name=args.model_name)

    for file in files:
        out_rgba = _output_path(file, out_dir)

        # Load image and resize
        logger.info(f"Loading image [{file}]...")
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:  # imread returns None instead of raising
            logger.warning(f"Skipping unreadable image [{file}]")
            continue
        image = _resize_longest_side(image, args.size)

        # Remove background
        logger.info("Removing background...")
        carved_image = rembg.remove(image, session=session)  # (H, W, 4)

        # Center the object
        if args.center:
            logger.info("Centering object...")
            final_rgba = _center_object(carved_image, args.size, args.border_ratio)
        else:
            final_rgba = carved_image

        # Save image; imwrite reports failure through its return value only.
        if not cv2.imwrite(out_rgba, final_rgba):
            logger.warning(f"Failed to write [{out_rgba}]")

    print()  # newline after the process


if __name__ == "__main__":
    main()