Spaces:
Running
on
Zero
Running
on
Zero
"""Remove the background of object images with rembg and optionally center them.

For every input image an ``<name>_rgba.png`` file is written next to the
input (for a single file) or into the input directory, containing the
cut-out object with an alpha channel. With ``--center`` the object is
rescaled and pasted in the middle of a transparent ``size x size`` canvas.
"""
import os
import glob
import argparse
import logging

import numpy as np
import cv2
import rembg

# Suffix appended to every output file; also used to skip earlier outputs
# when a directory is processed again.
OUTPUT_SUFFIX = "_rgba.png"

logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line arguments (``argv`` defaults to ``sys.argv[1:]``).

    Exits via ``parser.error`` when ``--size`` is not positive or
    ``--border_ratio`` is outside ``[0, 1)``.
    """
    parser = argparse.ArgumentParser(
        description="Remove background and center the image of an object"
    )
    parser.add_argument(
        "dir_or_path",
        type=str,
        help="Directory or path to images (png, jpeg, webp, etc.)",
    )
    parser.add_argument(
        "--model_name",
        default="u2net",  # "isnet-general-use", "birefnet-general", "birefnet-dis", "birefnet-massive"
        type=str,
        help="Rembg model, see https://github.com/danielgatis/rembg#models",
    )
    parser.add_argument(
        "--size",
        default=512,
        type=int,
        help="Output resolution",
    )
    parser.add_argument(
        "--border_ratio",
        default=0.2,
        type=float,
        help="Output border ratio",
    )
    parser.add_argument(
        "--center",
        action="store_true",
        help="Center the object, potentially not helpful for multiview zero123",
    )
    args = parser.parse_args(argv)
    if args.size <= 0:
        parser.error("--size must be a positive integer")
    # A negative ratio would make the pasted object larger than the canvas;
    # a ratio >= 1 leaves no room for the object at all.
    if not 0.0 <= args.border_ratio < 1.0:
        parser.error("--border_ratio must be in [0, 1)")
    return args


def _collect_inputs(dir_or_path):
    """Return ``(files, out_dir)`` for a directory or a single image path.

    Directory entries are sorted for a deterministic processing order;
    sub-directories and previously written ``*_rgba.png`` outputs are skipped.
    """
    if os.path.isdir(dir_or_path):
        logger.info(f"Processing directory [{dir_or_path}]...")
        # glob.escape: the directory name itself may contain glob metacharacters.
        pattern = os.path.join(glob.escape(dir_or_path), "*")
        files = [
            path
            for path in sorted(glob.glob(pattern))
            if os.path.isfile(path) and not path.endswith(OUTPUT_SUFFIX)
        ]
        return files, dir_or_path
    # Single file: write the output next to it.
    return [dir_or_path], os.path.dirname(dir_or_path)


def _resize_longest_side(image, size):
    """Resize ``image`` so its longest side is ``size``, keeping the aspect ratio."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    # max(1, ...) keeps a very thin image from collapsing to a 0-pixel axis,
    # which cv2.resize rejects.
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)


def _center_object(carved_image, size, border_ratio):
    """Paste the visible object centered on a transparent ``size x size`` canvas.

    The object's bounding box is taken from the pixels whose last channel
    (alpha) is non-zero. The box is scaled so its longest side spans
    ``size * (1 - border_ratio)`` pixels.

    Returns:
        A ``(size, size, 4)`` uint8 array, or ``None`` when no pixel has a
        non-zero alpha (nothing to center).
    """
    mask = carved_image[..., -1] > 0
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None
    # max() is an inclusive index; +1 turns it into an exclusive slice end
    # so the object's last row/column is not cropped off.
    r_min, r_max = rows.min(), rows.max() + 1
    c_min, c_max = cols.min(), cols.max() + 1
    h = r_max - r_min
    w = c_max - c_min

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)  # h, w >= 1, so no division by zero
    h2 = max(1, int(h * scale))
    w2 = max(1, int(w * scale))
    r2_min = (size - h2) // 2
    c2_min = (size - w2) // 2

    final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
    final_rgba[r2_min:r2_min + h2, c2_min:c2_min + w2] = cv2.resize(
        carved_image[r_min:r_max, c_min:c_max],
        (w2, h2),
        interpolation=cv2.INTER_AREA,
    )
    return final_rgba


def main(argv=None):
    """Run background removal (and optional centering) for every input image."""
    args = parse_args(argv)

    logging.basicConfig(
        format="%(asctime)s - REMBG&CENTER - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO,
    )

    # One rembg session is reused for all images.
    session = rembg.new_session(model_name=args.model_name)

    files, out_dir = _collect_inputs(args.dir_or_path)
    for file in files:
        # splitext keeps inner dots ("a.b.png" -> "a.b"), unlike split(".")[0].
        out_base = os.path.splitext(os.path.basename(file))[0]
        out_rgba = os.path.join(out_dir, out_base + OUTPUT_SUFFIX)

        # Load image and resize
        logger.info(f"Loading image [{file}]...")
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals unreadable / non-image files by returning None.
            logger.warning(f"Skipping [{file}]: not a readable image")
            continue
        image = _resize_longest_side(image, args.size)

        # Remove background
        # NOTE(review): cv2.imread yields BGR(A) channel order and the result is
        # written back with cv2.imwrite. This presumes rembg keeps the input
        # channel order; confirm, and check whether the model expects RGB input.
        logger.info("Removing background...")
        carved_image = rembg.remove(image, session=session)  # (H, W, 4)

        final_rgba = carved_image
        if args.center:
            logger.info("Centering object...")
            centered = _center_object(carved_image, args.size, args.border_ratio)
            if centered is None:
                logger.warning(
                    f"No foreground found in [{file}]; saving without centering"
                )
            else:
                final_rgba = centered

        # Save image; cv2.imwrite reports failure via its return value only.
        if not cv2.imwrite(out_rgba, final_rgba):
            logger.error(f"Failed to write [{out_rgba}]")
    print()  # newline after the process


if __name__ == "__main__":
    main()