# ----------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------
import os
import cv2
import logging
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForMaskGeneration
from simple_lama_inpainting import SimpleLama
# ----------------------------------------------------------------------
# MODEL REPOSITORY IDENTIFIERS
# ----------------------------------------------------------------------
SAM_REPO = "facebook/sam-vit-huge"
# ----------------------------------------------------------------------
# MODEL PRECISION SETTINGS
# ----------------------------------------------------------------------
SAM_FULL_PRECISION = True
LAMA_FULL_PRECISION = True
# ----------------------------------------------------------------------
# GLOBAL MODEL INSTANCES
# ----------------------------------------------------------------------
SAM_PROCESSOR = None
SAM_MODEL = None
SIMPLE_LAMA = None
# ----------------------------------------------------------------------
# INITIALIZE MODELS
# ----------------------------------------------------------------------
def initialize_sam_and_lama(device="cuda"):
    global SAM_PROCESSOR, SAM_MODEL, SIMPLE_LAMA
    if SAM_PROCESSOR is None or SAM_MODEL is None or SIMPLE_LAMA is None:
        logging.info("Loading SAM model...")
        SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_REPO)
        SAM_MODEL = load_sam_model(SAM_REPO, SAM_FULL_PRECISION)
        logging.info("Loading LaMa inpainting model...")
        lama_device = "cpu"
        logging.info("LaMa will use CPU - this is intentional for compatibility")
        SIMPLE_LAMA = SimpleLama(device=lama_device)
        logging.info(f"Successfully loaded LaMa model on {lama_device.upper()}")
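# A minimal usage sketch: loading is lazy and idempotent, so one call at
# worker startup is enough (repeat calls are no-ops once the globals are set):
#
#   initialize_sam_and_lama()
#   assert SAM_MODEL is not None and SIMPLE_LAMA is not None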
def load_sam_model(repo_id: str, full_precision: bool):
    try:
        torch.cuda.empty_cache()
        model = AutoModelForMaskGeneration.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=torch.float32 if full_precision else torch.float16
        )
        if not hasattr(model, 'hf_device_map'):
            model = model.cuda()
            if not full_precision:
                model = model.half()
        model.eval()
        with torch.no_grad():
            logging.info("Verifying SAM model is on CUDA")
            param = next(model.parameters())
            if not param.is_cuda:
                model = model.cuda()
                # Re-fetch the parameter so the device logged below reflects
                # the move, not the stale pre-move handle.
                param = next(model.parameters())
                logging.warning("Forced SAM model to CUDA")
            logging.info(f"SAM model device: {param.device}")
        return model
    except Exception as e:
        logging.error(f"Failed to load SAM model: {e}")
        raise
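# NOTE: device_map="auto" above relies on the `accelerate` package being
# installed (an assumption about the runtime environment, not something this
# section pins); without it, from_pretrained raises and the failure is logged.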
# ----------------------------------------------------------------------
# ARTIFACT UTILITIES
# ----------------------------------------------------------------------
ARTIFACTS_LIST = ["jewelry", "necklace", "bracelet", "ring", "earrings", "watch", "glasses"]
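# Keywords are matched verbatim against detector output (see
# remove_object_batch below): a detection labeled "necklace" counts as an
# artifact, while e.g. "pendant necklace" would not match.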
# ----------------------------------------------------------------------
# UNDER DEVELOPMENT
# ----------------------------------------------------------------------
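# `pipeline_step` and `ProcessingContext` are assumed to be provided by the
# surrounding pipeline framework; their definitions are not part of this
# section.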
@pipeline_step
def remove_object_batch(contexts: List[ProcessingContext], batch_logs: List[dict]) -> None:
    initialize_sam_and_lama()
    logging.info(f"[DEBUG] remove_object_batch => Starting with {len(contexts)} contexts.")
    for ctx_idx, ctx in enumerate(contexts):
        step_log = {
            "function": "remove_object_batch",
            "image_url": getattr(ctx, "url", "unknown"),
            "status": None,
            "artifacts_found": [],
            "image_dimensions": None,
            "artifact_boxes": []
        }
        if ctx.skip_run or ctx.skip_processing:
            step_log["status"] = "skipped"
            batch_logs.append(step_log)
            continue
        if "original" not in ctx.pil_img:
            logging.debug(f"(Context #{ctx_idx}) => RBC 'original' missing => {ctx.url}")
            step_log["status"] = "error"
            step_log["exception"] = "No RBC 'original' in ctx"
            ctx.skip_run = True
            batch_logs.append(step_log)
            continue
        dr = ctx.detection_result
        if not dr or dr.get("status") != "ok":
            logging.debug(f"(Context #{ctx_idx}) => No valid detection => {ctx.url}")
            step_log["status"] = "no_detection"
            batch_logs.append(step_log)
            continue
        boxes = dr.get("boxes", [])
        kws = dr.get("final_keywords", [])
        if len(boxes) != len(kws) or not boxes:
            logging.debug(f"(Context #{ctx_idx}) => mismatch or no boxes => {ctx.url}")
            step_log["status"] = "no_boxes_in_detection"
            batch_logs.append(step_log)
            continue
        artifact_indices = [i for i, kw_ in enumerate(kws) if kw_ in ARTIFACTS_LIST]
        if not artifact_indices:
            logging.debug(f"(Context #{ctx_idx}) => No artifacts found => {ctx.url}. Skipping flatten.")
            step_log["status"] = "no_artifacts_found"
            batch_logs.append(step_log)
            continue
        pil_rgba, orig_fname, _ = ctx.pil_img["original"]
        logging.debug(f"(Context #{ctx_idx}) Flattening RBC image to white background (since artifacts exist).")
        flattened = Image.new("RGB", pil_rgba.size, (255, 255, 255))
        flattened.paste(pil_rgba.convert("RGB"), mask=pil_rgba.getchannel("A"))
        logging.debug(f"(Context #{ctx_idx}) Background conversion complete.")
        updated_img = flattened
        found_labels = []
        for art_i in artifact_indices:
            box_ = boxes[art_i]
            kw_ = kws[art_i]
            step_log["artifact_boxes"].append({
                "original_box": box_,
                "label": kw_
            })
            w_img, h_img = updated_img.size
            expanded = expand_bbox(box_, w_img, h_img, pad=24)
            logging.debug(f"(Context #{ctx_idx}) Artifact {art_i}: Expanded box from {box_} to {expanded}.")
            step_log["artifact_boxes"][-1]["expanded_box"] = expanded
            logging.debug(f"(Context #{ctx_idx}) Removing object in region {expanded}.")
            try:
                updated_img = remove_object_inplace(
                    updated_img,
                    expanded,
                    SAM_PROCESSOR,
                    SAM_MODEL,
                    SIMPLE_LAMA,
                    device="cuda",
                    debug_save_prefix=f"dbg_ctx{ctx_idx}_artifact{art_i}",
                    dilate_mask=True,
                    dilate_kernel_size=40
                )
                logging.debug(f"(Context #{ctx_idx}) Object removal complete for artifact {art_i}.")
                found_labels.append(kw_)
            except RuntimeError as err:
                logging.warning(f"[WARNING] TorchScript inpainting failed for artifact {art_i}, skipping removal.\n{err}")
                step_log["artifact_boxes"][-1]["skipped_inpainting"] = True
        ctx.pil_img["original"] = [updated_img, orig_fname, None]
        step_log["artifacts_found"] = found_labels
        step_log["status"] = "artifacts_removed"
        step_log["image_dimensions"] = (updated_img.width, updated_img.height)
        logging.debug(f"(Context #{ctx_idx}) => Artifacts removed => {ctx.url}")
        batch_logs.append(step_log)
    logging.debug("[DEBUG] remove_object_batch => Finished.\n")
def expand_bbox(box, w, h, pad=24):
    x1, y1, x2, y2 = box
    expanded_box = [
        max(0, x1 - pad),
        max(0, y1 - pad),
        min(w, x2 + pad),
        min(h, y2 + pad)
    ]
    logging.debug(f"expand_bbox => Original: {box}, Expanded: {expanded_box}")
    return expanded_box
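# Worked example: a box near the top-left corner is clamped at the border:
#   expand_bbox([10, 10, 110, 110], 500, 500, pad=24) -> [0, 0, 134, 134]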
def remove_object_inplace(
    pil_rgb: Image.Image,
    bbox: List[int],
    sam_processor,
    sam_model,
    lama_model_jit,
    device="cuda",
    debug_save_prefix=None,
    dilate_mask=False,
    dilate_kernel_size=15
) -> Image.Image:
    logging.debug(f"remove_object_inplace => Processing bbox {bbox} on image size {pil_rgb.size}")
    image_rgb = pil_rgb.convert("RGB")
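    # SAM's processor takes input_boxes nested as [batch, boxes_per_image, 4];
    # here: one image, one (x1, y1, x2, y2) box in original-image pixels.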
    inputs = sam_processor(
        images=image_rgb,
        input_boxes=[[[bbox[0], bbox[1], bbox[2], bbox[3]]]],
        return_tensors="pt"
    ).to(device)
    if not SAM_FULL_PRECISION and sam_model.dtype == torch.float16:
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
    with torch.no_grad():
        out_sam = sam_model(**inputs)
    pred_masks = out_sam.pred_masks
    if pred_masks.ndim == 5 and pred_masks.shape[2] == 3:
        pred_masks = pred_masks[:, 0, 0, :, :]
    elif pred_masks.ndim == 4 and pred_masks.shape[1] == 3:
        pred_masks = pred_masks[:, 0, :, :]
    if pred_masks.ndim == 3:
        pred_masks = pred_masks.unsqueeze(1)
    # Project the low-resolution mask back onto the ORIGINAL image so it lines
    # up with bbox coordinates when cropped downstream: upsample to SAM's
    # padded square canvas, strip the bottom/right padding, then resize to the
    # original size. (An earlier draft stopped at "reshaped_input_sizes",
    # SAM's 1024-pixel working resolution, which only matches the image when
    # it is already that size.)
    if "reshaped_input_sizes" in inputs and "original_sizes" in inputs:
        t_h, t_w = inputs["reshaped_input_sizes"][0].tolist()
        o_h, o_w = inputs["original_sizes"][0].tolist()
        pad_size = getattr(sam_processor.image_processor, "pad_size", None) or {"height": 1024, "width": 1024}
        pred_masks = F.interpolate(
            pred_masks,
            size=(pad_size["height"], pad_size["width"]),
            mode="bilinear",
            align_corners=False
        )[..., :t_h, :t_w]
        pred_masks = F.interpolate(
            pred_masks,
            size=(o_h, o_w),
            mode="bilinear",
            align_corners=False
        )
    mask_bin = (pred_masks[0, 0] > 0.5).cpu().numpy().astype(np.uint8)
    if dilate_mask:
        kernel = np.ones((dilate_kernel_size, dilate_kernel_size), dtype=np.uint8)
        mask_bin = cv2.dilate(mask_bin, kernel, iterations=1)
        logging.debug(f"remove_object_inplace => Dilated mask mean: {mask_bin.mean():.6f}")
    updated_img = inpaint_region_with_lama_multi_fallback(
        image_rgb,
        mask_bin,
        bbox,
        lama_model_jit
    )
    logging.debug(f"remove_object_inplace => Inpainting complete for bbox {bbox}")
    return updated_img
def inpaint_region_with_lama_multi_fallback(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model_jit
) -> Image.Image:
    x1, y1, x2, y2 = bbox
    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    logging.debug(f"inpaint_region_with_lama_multi_fallback => Cropped region: w={orig_w}, h={orig_h}")
    if orig_w < 2 or orig_h < 2:
        logging.warning("Subregion too small for inpainting. Filling with white instead.")
        return fill_white(image_rgb, bbox)
    max_dim = max(orig_w, orig_h)
    target_size = 1024
    scale = 1.0
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        logging.debug(f"inpaint_region_with_lama_multi_fallback => scaled to {new_w}x{new_h} (factor={scale:.3f})")
    else:
        new_w, new_h = orig_w, orig_h
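    # LaMa needs spatial sizes divisible by its downsampling modulo; padding
    # up to the next multiple of 32 is a conservative choice (simple-lama
    # itself pads to a multiple of 8). Worked example:
    #   new_w=1000 -> pad_w = (32 - (1000 % 32)) % 32 = 24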
    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    logging.debug(f"inpaint_region_with_lama_multi_fallback => pad_w={pad_w}, pad_h={pad_h}")
    sub_tensor = (
        torch.from_numpy(np.array(subregion))
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float() / 255.0
    )
    mask_tensor = torch.from_numpy(mask_sub.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    original_f_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)
    original_reflection = None
    if hasattr(torch._C._nn, "reflection_pad2d"):
        original_reflection = torch._C._nn.reflection_pad2d

    def custom_f_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_f_pad(inp, pad_vals, mode=mode, value=value)

    def custom_torch_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(inp, pad_vals, mode=mode, value=value)

    def replicate_pad2d(*args, **kwargs):
        # torch.nn.functional does not expose replication_pad2d publicly, so
        # call the ATen op directly; only its sibling reflection_pad2d is
        # being patched, so this remains intact.
        return torch._C._nn.replication_pad2d(*args, **kwargs)

    try:
        F.pad = custom_f_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = replicate_pad2d
        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)
        result_tensor = None
        try:
            with torch.no_grad():
                sub_tensor_gpu = sub_tensor_padded.to("cuda")
                mask_tensor_gpu = mask_tensor_padded.to("cuda")
                result_tensor = lama_model_jit.model.forward(sub_tensor_gpu, mask_tensor_gpu)
        except RuntimeError as re_gpu:
            logging.warning(f"TorchScript GPU inpainting failed => {re_gpu}\nAttempting CPU fallback...")
            try:
                result_tensor = inpaint_torchscript_cpu_fallback(sub_tensor_padded, mask_tensor_padded, lama_model_jit)
            except RuntimeError as re_cpu:
                logging.warning(f"TorchScript CPU fallback also failed => {re_cpu}\nFilling with white region.")
                return fill_white(image_rgb, bbox)
    finally:
        F.pad = original_f_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = original_reflection
    if result_tensor is None:
        logging.warning("Result is None after fallback => filling with white region.")
        return fill_white(image_rgb, bbox)
    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    out_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .byte()
        .cpu()
        .numpy()
    )
    inpainted_pil = Image.fromarray(out_np)
    if scale != 1.0:
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)
    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    logging.debug(f"inpaint_region_with_lama_multi_fallback => done for region {bbox}")
    return out_img
def inpaint_torchscript_cpu_fallback(
    sub_tensor_padded: torch.Tensor,
    mask_tensor_padded: torch.Tensor,
    lama_model_jit
) -> torch.Tensor:
    orig_device = next(lama_model_jit.model.parameters()).device
    lama_model_jit.model.to("cpu")
    try:
        sub_cpu = sub_tensor_padded.cpu()
        mask_cpu = mask_tensor_padded.cpu()
        with torch.no_grad():
            result_cpu = lama_model_jit.model.forward(sub_cpu, mask_cpu)
    finally:
        # Restore the model to its original device even if the forward pass
        # raises, so the caller's state is not silently changed.
        lama_model_jit.model.to(orig_device)
    return result_cpu
def fill_white(image_rgb: Image.Image, bbox: List[int]) -> Image.Image:
    x1, y1, x2, y2 = bbox
    ret_img = image_rgb.copy()
    draw = ImageDraw.Draw(ret_img)
    draw.rectangle([x1, y1, x2, y2], fill=(255, 255, 255))
    return ret_img
def inpaint_region_with_lama_gpu_only(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model,
    debug_save_prefix: Optional[str] = None
) -> Image.Image:
    x1, y1, x2, y2 = bbox
    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    if orig_w < 2 or orig_h < 2:
        return image_rgb
    target_size = 1024
    scale = 1.0
    max_dim = max(orig_w, orig_h)
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    else:
        new_w, new_h = orig_w, orig_h
    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    sub_np = np.array(subregion)
    sub_tensor = (
        torch.from_numpy(sub_np)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float()
        .to("cuda")
        / 255.0
    ).contiguous()
    mask_tensor = (
        torch.from_numpy((mask_sub > 0).astype(np.float32))
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to("cuda")
    ).contiguous()
    original_F_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)

    def custom_F_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_F_pad(input, pad_vals, mode=mode, value=value)

    def custom_torch_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(input, pad_vals, mode=mode, value=value)

    original_reflection_pad2d = None
    if hasattr(torch._C._nn, 'reflection_pad2d'):
        original_reflection_pad2d = torch._C._nn.reflection_pad2d

    def no_reflection_pad2d(*args, **kwargs):
        # As above: call the ATen replication op directly, since
        # torch.nn.functional has no public replication_pad2d.
        return torch._C._nn.replication_pad2d(*args, **kwargs)

    try:
        F.pad = custom_F_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = no_reflection_pad2d
        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)
        try:
            with torch.no_grad():
                result_tensor = lama_model.model.forward(sub_tensor_padded, mask_tensor_padded)
        except RuntimeError as e:
            logging.warning(f"GPU-only inpainting failed => {e}\nFalling back to CPU.")
            result_tensor = run_lama_on_cpu_fallback(
                sub_tensor_padded.cpu(),
                mask_tensor_padded.cpu(),
                lama_model
            )
    finally:
        F.pad = original_F_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = original_reflection_pad2d
    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    result_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    inpainted_pil = Image.fromarray(result_np)
    if scale != 1.0:
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)
    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    torch.cuda.empty_cache()
    return out_img.convert("RGB")
def run_lama_on_cpu_fallback(
    sub_tensor_padded_cpu: torch.Tensor,
    mask_tensor_padded_cpu: torch.Tensor,
    lama_model
) -> torch.Tensor:
    orig_device = next(lama_model.model.parameters()).device
    lama_model.model.to("cpu")
    try:
        with torch.no_grad():
            result = lama_model.model.forward(sub_tensor_padded_cpu, mask_tensor_padded_cpu)
    finally:
        # Restore the original device even if the forward pass raises.
        lama_model.model.to(orig_device)
    return result
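# A minimal smoke test (hypothetical, not part of the pipeline): exercises
# the pure-PIL helpers expand_bbox and fill_white without loading SAM or
# LaMa; the output file name is illustrative only.
if __name__ == "__main__":
    _img = Image.new("RGB", (256, 256), (128, 128, 128))
    _box = expand_bbox([40, 40, 120, 120], 256, 256, pad=24)
    fill_white(_img, _box).save("smoke_test_fill_white.png")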
# ----------------------------------------------------------------------
# END UNDER DEVELOPMENT
# ----------------------------------------------------------------------