import os import glob import argparse import logging import numpy as np import cv2 import rembg if __name__ == "__main__": parser = argparse.ArgumentParser( description="Remove background and center the image of an object" ) parser.add_argument( "dir_or_path", type=str, help="Directory or path to images (png, jpeg, webp, etc.)" ) parser.add_argument( "--model_name", default="u2net", # "isnet-general-use", "birefnet-general", "birefnet-dis", "birefnet-massive" type=str, help="Rembg model, see https://github.com/danielgatis/rembg#models" ) parser.add_argument( "--size", default=512, type=int, help="Output resolution" ) parser.add_argument( "--border_ratio", default=0.2, type=float, help="Output border ratio" ) parser.add_argument( "--center", action="store_true", help="Center the object, potentially not helpful for multiview zero123" ) # Parse the arguments args = parser.parse_args() # Initialize the logger logging.basicConfig( format="%(asctime)s - REMBG&CENTER - %(message)s", datefmt="%Y/%m/%d %H:%M:%S", level=logging.INFO ) logger = logging.getLogger(__name__) logger.propagate = True # propagate to the root logger (console) # Create a session for rembg session = rembg.new_session(model_name=args.model_name) if os.path.isdir(args.dir_or_path): logger.info(f"Processing directory [{args.dir_or_path}]...") files = glob.glob(f"{args.dir_or_path}/*") out_dir = args.dir_or_path else: # single file files = [args.dir_or_path] out_dir = os.path.dirname(args.dir_or_path) for file in files: out_base = os.path.basename(file).split(".")[0] out_rgba = os.path.join(out_dir, out_base + "_rgba.png") # Load image and resize logger.info(f"Loading image [{file}]...") image = cv2.imread(file, cv2.IMREAD_UNCHANGED) _h, _w = image.shape[:2] scale = args.size / max(_h, _w) _h, _w = int(_h * scale), int(_w * scale) image = cv2.resize(image, (_w, _h), interpolation=cv2.INTER_AREA) # Remove background logger.info("Removing background...") carved_image = rembg.remove(image, session=session) # (H, W, 4) mask = carved_image[..., -1] > 0 # Center the object if args.center: logger.info("Centering object...") final_rgba = np.zeros((args.size, args.size, 4), dtype=np.uint8) coords = np.nonzero(mask) x_min, x_max = coords[0].min(), coords[0].max() y_min, y_max = coords[1].min(), coords[1].max() h = x_max - x_min w = y_max - y_min desired_size = int(args.size * (1 - args.border_ratio)) scale = desired_size / max(h, w) h2 = int(h * scale) w2 = int(w * scale) x2_min = (args.size - h2) // 2 x2_max = x2_min + h2 y2_min = (args.size - w2) // 2 y2_max = y2_min + w2 final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize( carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA ) else: final_rgba = carved_image # Save image cv2.imwrite(out_rgba, final_rgba) print() # newline after the process