# Diffsplat / extensions / rembg_and_center.py
# Uploaded by paulpanwang using huggingface_hub (commit 476e0f0, verified).
import os
import glob
import argparse
import logging
import numpy as np
import cv2
import rembg
def parse_args(argv=None):
    """Parse the command-line arguments of the rembg-and-center script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            fall back to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Remove background and center the image of an object"
    )
    parser.add_argument(
        "dir_or_path",
        type=str,
        help="Directory or path to images (png, jpeg, webp, etc.)"
    )
    parser.add_argument(
        "--model_name",
        default="u2net",  # "isnet-general-use", "birefnet-general", "birefnet-dis", "birefnet-massive"
        type=str,
        help="Rembg model, see https://github.com/danielgatis/rembg#models"
    )
    parser.add_argument(
        "--size",
        default=512,
        type=int,
        help="Output resolution"
    )
    parser.add_argument(
        "--border_ratio",
        default=0.2,
        type=float,
        help="Output border ratio"
    )
    parser.add_argument(
        "--center",
        action="store_true",
        help="Center the object, potentially not helpful for multiview zero123"
    )
    args = parser.parse_args(argv)
    # A non-positive size would produce a zero scale / empty canvas later on.
    if args.size <= 0:
        parser.error("--size must be a positive integer")
    return args


def compute_center_layout(mask, size, border_ratio):
    """Compute how to place the masked object in the middle of a square canvas.

    Args:
        mask: 2-D array whose non-zero entries mark foreground pixels.
        size: Side length (in pixels) of the square output canvas.
        border_ratio: Fraction of ``size`` left as border; the longest side of
            the object is scaled to ``int(size * (1 - border_ratio))``, clamped
            to the range ``[1, size]`` so the result always fits the canvas.

    Returns:
        ``None`` when the mask contains no foreground pixel. Otherwise a pair
        ``(src_box, dst_box)`` of half-open ``(row0, row1, col0, col1)`` boxes:
        ``src_box`` is the tight bounding box of the object inside ``mask`` and
        ``dst_box`` is the region of the canvas it should be resized into.
    """
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None

    # Half-open bounds: +1 so the last foreground row/column is kept when slicing.
    row0, row1 = int(rows.min()), int(rows.max()) + 1
    col0, col1 = int(cols.min()), int(cols.max()) + 1
    obj_h = row1 - row0
    obj_w = col1 - col0
    longest = max(obj_h, obj_w)  # >= 1 because the mask is non-empty

    desired = min(size, max(1, int(size * (1 - border_ratio))))
    # Integer arithmetic keeps the aspect ratio without float round-off and
    # the max(1, ...) guard keeps very thin objects from collapsing to 0 px.
    new_h = max(1, obj_h * desired // longest)
    new_w = max(1, obj_w * desired // longest)

    top = (size - new_h) // 2
    left = (size - new_w) // 2
    return (row0, row1, col0, col1), (top, top + new_h, left, left + new_w)


def main(argv=None):
    """Remove the background of every input image and optionally center the object.

    For each input ``<name>.<ext>`` an RGBA result is written next to it as
    ``<name>_rgba.png``. Files that OpenCV cannot decode are skipped with a
    warning instead of aborting the whole run.

    Args:
        argv: Optional list of argument strings forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - REMBG&CENTER - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO
    )
    logger = logging.getLogger(__name__)
    logger.propagate = True  # propagate to the root logger (console)

    # Create a session for rembg
    session = rembg.new_session(model_name=args.model_name)

    if os.path.isdir(args.dir_or_path):
        logger.info(f"Processing directory [{args.dir_or_path}]...")
        # Escape the directory so names containing glob metacharacters still work.
        files = sorted(glob.glob(os.path.join(glob.escape(args.dir_or_path), "*")))
        out_dir = args.dir_or_path
    else:  # single file
        files = [args.dir_or_path]
        out_dir = os.path.dirname(args.dir_or_path)

    for file in files:
        # splitext keeps inner dots, so "a.1.png" and "a.2.png" do not collide.
        out_base = os.path.splitext(os.path.basename(file))[0]
        out_rgba = os.path.join(out_dir, out_base + "_rgba.png")

        # Load image and resize
        logger.info(f"Loading image [{file}]...")
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # imread signals failure (directory, non-image, corrupt file) with None.
            logger.warning(f"Skipping [{file}]: not a readable image")
            continue
        src_h, src_w = image.shape[:2]
        scale = args.size / max(src_h, src_w)
        # Guard against a 0-pixel side for extreme aspect ratios.
        dst_h = max(1, int(src_h * scale))
        dst_w = max(1, int(src_w * scale))
        image = cv2.resize(image, (dst_w, dst_h), interpolation=cv2.INTER_AREA)

        # Remove background
        # NOTE(review): cv2.imread yields BGR(A) channel order and the array is
        # handed to rembg unchanged -- confirm whether the model expects RGB.
        logger.info("Removing background...")
        carved_image = rembg.remove(image, session=session)  # (H, W, 4)
        mask = carved_image[..., -1] > 0

        # Center the object
        if args.center:
            logger.info("Centering object...")
            final_rgba = np.zeros((args.size, args.size, 4), dtype=np.uint8)
            layout = compute_center_layout(mask, args.size, args.border_ratio)
            if layout is None:
                logger.warning(
                    f"No foreground found in [{file}]; saving an empty canvas"
                )
            else:
                (row0, row1, col0, col1), (top, bottom, left, right) = layout
                final_rgba[top:bottom, left:right] = cv2.resize(
                    carved_image[row0:row1, col0:col1],
                    (right - left, bottom - top),
                    interpolation=cv2.INTER_AREA
                )
        else:
            final_rgba = carved_image

        # Save image
        if not cv2.imwrite(out_rgba, final_rgba):
            logger.error(f"Failed to write [{out_rgba}]")
        print()  # newline after the process


if __name__ == "__main__":
    main()