# ---------------------------------------------------------------------- # IMPORTS # ---------------------------------------------------------------------- import os import cv2 import logging import numpy as np import torch import torch.nn.functional as F from typing import List, Optional, Tuple from PIL import Image, ImageDraw from transformers import AutoProcessor, AutoModelForMaskGeneration from simple_lama_inpainting import SimpleLama # ---------------------------------------------------------------------- # MODEL REPOSITORY IDENTIFIERS # ---------------------------------------------------------------------- SAM_REPO = "facebook/sam-vit-huge" # ---------------------------------------------------------------------- # MODEL PRECISION SETTINGS # ---------------------------------------------------------------------- SAM_FULL_PRECISION = True LAMA_FULL_PRECISION = True # ---------------------------------------------------------------------- # GLOBAL MODEL INSTANCES # ---------------------------------------------------------------------- SAM_PROCESSOR = None SAM_MODEL = None SIMPLE_LAMA = None # ---------------------------------------------------------------------- # INITIALIZE MODELS # ---------------------------------------------------------------------- def initialize_sam_and_lama(device="cuda"): global SAM_PROCESSOR, SAM_MODEL, SIMPLE_LAMA if SAM_PROCESSOR is None or SAM_MODEL is None or SIMPLE_LAMA is None: logging.info("Loading SAM model...") SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_REPO) SAM_MODEL = load_sam_model(SAM_REPO, SAM_FULL_PRECISION) logging.info("Loading LaMa inpainting model...") lama_device = "cpu" logging.info("LAMA will use CPU - this is intentional for compatibility") SIMPLE_LAMA = SimpleLama(device=lama_device) logging.info(f"Successfully loaded LAMA model on {lama_device.upper()}") def load_sam_model(repo_id: str, full_precision: bool): try: torch.cuda.empty_cache() model = AutoModelForMaskGeneration.from_pretrained( repo_id, device_map="auto", torch_dtype=torch.float32 if full_precision else torch.float16 ) if not hasattr(model, 'hf_device_map'): model = model.cuda() if not full_precision: model = model.half() model.eval() with torch.no_grad(): logging.info(f"Verifying SAM model is on CUDA") param = next(model.parameters()) if not param.is_cuda: model = model.cuda() logging.warning(f"Forced SAM model to CUDA") logging.info(f"SAM model device: {param.device}") return model except Exception as e: logging.error(f"Failed to load SAM model: {e}") raise # ---------------------------------------------------------------------- # ARTIFACT UTILITIES # ---------------------------------------------------------------------- ARTIFACTS_LIST = ["jewelry", "necklace", "bracelet", "ring", "earrings", "watch", "glasses"] # ---------------------------------------------------------------------- # UNDER DEVELOPMENT # ---------------------------------------------------------------------- @pipeline_step def remove_object_batch(contexts: List[ProcessingContext], batch_logs: List[dict]) -> None: initialize_sam_and_lama() logging.info(f"[DEBUG] remove_object_batch => Starting with {len(contexts)} contexts.") for ctx_idx, ctx in enumerate(contexts): step_log = { "function": "remove_object_batch", "image_url": getattr(ctx, "url", "unknown"), "status": None, "artifacts_found": [], "image_dimensions": None, "artifact_boxes": [] } if ctx.skip_run or ctx.skip_processing: step_log["status"] = "skipped" batch_logs.append(step_log) continue if "original" not in ctx.pil_img: logging.debug(f"(Context #{ctx_idx}) => RBC 'original' missing => {ctx.url}") step_log["status"] = "error" step_log["exception"] = "No RBC 'original' in ctx" ctx.skip_run = True batch_logs.append(step_log) continue dr = ctx.detection_result if not dr or dr.get("status") != "ok": logging.debug(f"(Context #{ctx_idx}) => No valid detection => {ctx.url}") step_log["status"] = "no_detection" batch_logs.append(step_log) continue boxes = dr.get("boxes", []) kws = dr.get("final_keywords", []) if len(boxes) != len(kws) or not boxes: logging.debug(f"(Context #{ctx_idx}) => mismatch or no boxes => {ctx.url}") step_log["status"] = "no_boxes_in_detection" batch_logs.append(step_log) continue artifact_indices = [i for i, kw_ in enumerate(kws) if kw_ in ARTIFACTS_LIST] if not artifact_indices: logging.debug(f"(Context #{ctx_idx}) => No artifacts found => {ctx.url}. Skipping flatten.") step_log["status"] = "no_artifacts_found" batch_logs.append(step_log) continue pil_rgba, orig_fname, _ = ctx.pil_img["original"] logging.debug(f"(Context #{ctx_idx}) Flattening RBC image to white background (since artifacts exist).") flattened = Image.new("RGB", pil_rgba.size, (255, 255, 255)) flattened.paste(pil_rgba.convert("RGB"), mask=pil_rgba.getchannel("A")) logging.debug(f"(Context #{ctx_idx}) Background conversion complete.") updated_img = flattened found_labels = [] for art_i in artifact_indices: box_ = boxes[art_i] kw_ = kws[art_i] step_log["artifact_boxes"].append({ "original_box": box_, "label": kw_ }) w_img, h_img = updated_img.size expanded = expand_bbox(box_, w_img, h_img, pad=24) logging.debug(f"(Context #{ctx_idx}) Artifact {art_i}: Expanded box from {box_} to {expanded}.") step_log["artifact_boxes"][-1]["expanded_box"] = expanded logging.debug(f"(Context #{ctx_idx}) Removing object in region {expanded}.") try: updated_img = remove_object_inplace( updated_img, expanded, SAM_PROCESSOR, SAM_MODEL, SIMPLE_LAMA, device="cuda", debug_save_prefix=f"dbg_ctx{ctx_idx}_artifact{art_i}", dilate_mask=True, dilate_kernel_size=40 ) logging.debug(f"(Context #{ctx_idx}) Object removal complete for artifact {art_i}.") found_labels.append(kw_) except RuntimeError as re: logging.warning(f"[WARNING] TorchScript inpainting failed for artifact {art_i}, skipping removal.\n{re}") step_log["artifact_boxes"][-1]["skipped_inpainting"] = True ctx.pil_img["original"] = [updated_img, orig_fname, None] step_log["artifacts_found"] = found_labels step_log["status"] = "artifacts_removed" step_log["image_dimensions"] = (updated_img.width, updated_img.height) logging.debug(f"(Context #{ctx_idx}) => Artifacts removed => {ctx.url}") batch_logs.append(step_log) logging.debug("[DEBUG] remove_object_batch => Finished.\n") def expand_bbox(box, w, h, pad=24): x1, y1, x2, y2 = box expanded_box = [ max(0, x1 - pad), max(0, y1 - pad), min(w, x2 + pad), min(h, y2 + pad) ] logging.debug(f"expand_bbox => Original: {box}, Expanded: {expanded_box}") return expanded_box def remove_object_inplace( pil_rgb: Image.Image, bbox: List[int], sam_processor, sam_model, lama_model_jit, device="cuda", debug_save_prefix=None, dilate_mask=False, dilate_kernel_size=15 ) -> Image.Image: logging.debug(f"remove_object_inplace => Processing bbox {bbox} on image size {pil_rgb.size}") image_rgb = pil_rgb.convert("RGB") inputs = sam_processor( images=image_rgb, input_boxes=[[[bbox[0], bbox[1], bbox[2], bbox[3]]]], return_tensors="pt" ).to(device) if not SAM_FULL_PRECISION and sam_model.dtype == torch.float16: inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()} with torch.no_grad(): out_sam = sam_model(**inputs) pred_masks = out_sam.pred_masks if pred_masks.ndim == 5 and pred_masks.shape[2] == 3: pred_masks = pred_masks[:, 0, 0, :, :] elif pred_masks.ndim == 4 and pred_masks.shape[1] == 3: pred_masks = pred_masks[:, 0, :, :] if pred_masks.ndim == 3: pred_masks = pred_masks.unsqueeze(1) if "reshaped_input_sizes" in inputs: t_h, t_w = inputs["reshaped_input_sizes"][0].tolist() pred_masks = F.interpolate( pred_masks, size=(t_h, t_w), mode="bilinear", align_corners=False ) mask_bin = (pred_masks[0, 0] > 0.5).cpu().numpy().astype(np.uint8) if dilate_mask: kernel = np.ones((dilate_kernel_size, dilate_kernel_size), dtype=np.uint8) mask_bin = cv2.dilate(mask_bin, kernel, iterations=1) logging.debug(f"remove_object_inplace => Dilated mask mean: {mask_bin.mean():.6f}") updated_crop = inpaint_region_with_lama_multi_fallback( image_rgb, mask_bin, bbox, lama_model_jit ) logging.debug(f"remove_object_inplace => Inpainting complete for bbox {bbox}") return updated_crop def inpaint_region_with_lama_multi_fallback( image_rgb: Image.Image, mask_bin: np.ndarray, bbox: List[int], lama_model_jit ) -> Image.Image: x1, y1, x2, y2 = bbox subregion = image_rgb.crop((x1, y1, x2, y2)) mask_sub = mask_bin[y1:y2, x1:x2].copy() orig_w, orig_h = subregion.size logging.debug(f"inpaint_region_with_lama_multi_fallback => Cropped region: w={orig_w}, h={orig_h}") if orig_w < 2 or orig_h < 2: logging.warning("Subregion too small for inpainting. Filling with white instead.") return fill_white(image_rgb, bbox) max_dim = max(orig_w, orig_h) target_size = 1024 scale = 1.0 if max_dim > target_size: scale = target_size / float(max_dim) new_w = max(1, int(round(orig_w * scale))) new_h = max(1, int(round(orig_h * scale))) subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS) mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST) logging.debug(f"inpaint_region_with_lama_multi_fallback => scaled to {new_w}x{new_h} (factor={scale:.3f})") else: new_w, new_h = orig_w, orig_h pad_w = (32 - (new_w % 32)) % 32 pad_h = (32 - (new_h % 32)) % 32 logging.debug(f"inpaint_region_with_lama_multi_fallback => pad_w={pad_w}, pad_h={pad_h}") sub_tensor = ( torch.from_numpy(np.array(subregion)) .permute(2, 0, 1) .unsqueeze(0) .float() / 255.0 ) mask_tensor = torch.from_numpy(mask_sub.astype(np.float32)).unsqueeze(0).unsqueeze(0) original_f_pad = F.pad original_torch_pad = getattr(torch, "pad", None) original_reflection = None if hasattr(torch._C._nn, "reflection_pad2d"): original_reflection = torch._C._nn.reflection_pad2d def custom_f_pad(inp, pad_vals, mode="constant", value=0): if mode == "reflect": mode = "replicate" return original_f_pad(inp, pad_vals, mode=mode, value=value) def custom_torch_pad(inp, pad_vals, mode="constant", value=0): if mode == "reflect": mode = "replicate" return original_torch_pad(inp, pad_vals, mode=mode, value=value) def replicate_pad2d(*args, **kwargs): return F.replication_pad2d(*args, **kwargs) try: F.pad = custom_f_pad if original_torch_pad is not None: torch.pad = custom_torch_pad if original_reflection is not None: torch._C._nn.reflection_pad2d = replicate_pad2d sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect') mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0) result_tensor = None try: with torch.no_grad(): sub_tensor_gpu = sub_tensor_padded.to("cuda") mask_tensor_gpu = mask_tensor_padded.to("cuda") result_tensor = lama_model_jit.model.forward(sub_tensor_gpu, mask_tensor_gpu) except RuntimeError as re_gpu: logging.warning(f"TorchScript GPU inpainting failed => {re_gpu}\nAttempting CPU fallback...") try: result_tensor = inpaint_torchscript_cpu_fallback(sub_tensor_padded, mask_tensor_padded, lama_model_jit) except RuntimeError as re_cpu: logging.warning(f"TorchScript CPU fallback also failed => {re_cpu}\nFilling with white region.") return fill_white(image_rgb, bbox) finally: F.pad = original_f_pad if original_torch_pad is not None: torch.pad = original_torch_pad if original_reflection is not None: torch._C._nn.reflection_pad2d = original_reflection if result_tensor is None: logging.warning("Result is None after fallback => filling with white region.") return fill_white(image_rgb, bbox) result_tensor_cropped = result_tensor[:, :, :new_h, :new_w] out_np = ( result_tensor_cropped.squeeze(0) .permute(1, 2, 0) .mul(255) .clamp(0, 255) .byte() .cpu() .numpy() ) inpainted_pil = Image.fromarray(out_np) if scale != 1.0: inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS) final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255)) final_sub.paste(inpainted_pil, (0, 0)) out_img = image_rgb.copy() out_img.paste(final_sub, (x1, y1)) logging.debug(f"inpaint_region_with_lama_multi_fallback => done for region {bbox}") return out_img def inpaint_torchscript_cpu_fallback( sub_tensor_padded: torch.Tensor, mask_tensor_padded: torch.Tensor, lama_model_jit ) -> torch.Tensor: orig_device = next(lama_model_jit.model.parameters()).device lama_model_jit.model.to("cpu") sub_cpu = sub_tensor_padded.cpu() mask_cpu = mask_tensor_padded.cpu() with torch.no_grad(): result_cpu = lama_model_jit.model.forward(sub_cpu, mask_cpu) lama_model_jit.model.to(orig_device) return result_cpu def fill_white(image_rgb: Image.Image, bbox: List[int]) -> Image.Image: x1, y1, x2, y2 = bbox ret_img = image_rgb.copy() draw = ImageDraw.Draw(ret_img) draw.rectangle([x1, y1, x2, y2], fill=(255, 255, 255)) return ret_img def inpaint_region_with_lama_gpu_only( image_rgb: Image.Image, mask_bin: np.ndarray, bbox: List[int], lama_model, debug_save_prefix: Optional[str] = None ) -> Image.Image: x1, y1, x2, y2 = bbox subregion = image_rgb.crop((x1, y1, x2, y2)) mask_sub = mask_bin[y1:y2, x1:x2].copy() orig_w, orig_h = subregion.size if orig_w < 2 or orig_h < 2: return image_rgb target_size = 1024 scale = 1.0 max_dim = max(orig_w, orig_h) if max_dim > target_size: scale = target_size / float(max_dim) new_w = max(1, int(round(orig_w * scale))) new_h = max(1, int(round(orig_h * scale))) subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS) mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST) else: new_w, new_h = orig_w, orig_h pad_w = (32 - (new_w % 32)) % 32 pad_h = (32 - (new_h % 32)) % 32 sub_np = np.array(subregion) sub_tensor = ( torch.from_numpy(sub_np) .permute(2, 0, 1) .unsqueeze(0) .float() .to("cuda") / 255.0 ).contiguous() mask_tensor = ( torch.from_numpy((mask_sub > 0).astype(np.float32)) .unsqueeze(0) .unsqueeze(0) .float() .to("cuda") ).contiguous() original_F_pad = F.pad original_torch_pad = getattr(torch, "pad", None) def custom_F_pad(input, pad_vals, mode="constant", value=0): if mode == "reflect": mode = "replicate" return original_F_pad(input, pad_vals, mode=mode, value=value) def custom_torch_pad(input, pad_vals, mode="constant", value=0): if mode == "reflect": mode = "replicate" return original_torch_pad(input, pad_vals, mode=mode, value=value) original_reflection_pad2d = None if hasattr(torch._C._nn, 'reflection_pad2d'): original_reflection_pad2d = torch._C._nn.reflection_pad2d def no_reflection_pad2d(*args, **kwargs): return F.replication_pad2d(*args, **kwargs) try: F.pad = custom_F_pad if original_torch_pad is not None: torch.pad = custom_torch_pad if original_reflection_pad2d is not None: torch._C._nn.reflection_pad2d = no_reflection_pad2d sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect') mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0) try: with torch.no_grad(): result_tensor = lama_model.model.forward(sub_tensor_padded, mask_tensor_padded) except RuntimeError as e: result_tensor = run_lama_on_cpu_fallback( sub_tensor_padded.cpu(), mask_tensor_padded.cpu(), lama_model ) finally: F.pad = original_F_pad if original_torch_pad is not None: torch.pad = original_torch_pad if original_reflection_pad2d is not None: torch._C._nn.reflection_pad2d = original_reflection_pad2d result_tensor_cropped = result_tensor[:, :, :new_h, :new_w] result_np = ( result_tensor_cropped.squeeze(0) .permute(1, 2, 0) .mul(255) .clamp(0, 255) .cpu() .numpy() .astype(np.uint8) ) inpainted_pil = Image.fromarray(result_np) if scale != 1.0: inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS) final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255)) final_sub.paste(inpainted_pil, (0, 0)) out_img = image_rgb.copy() out_img.paste(final_sub, (x1, y1)) torch.cuda.empty_cache() return out_img.convert("RGB") def run_lama_on_cpu_fallback( sub_tensor_padded_cpu: torch.Tensor, mask_tensor_padded_cpu: torch.Tensor, lama_model ) -> torch.Tensor: with torch.no_grad(): orig_device = next(lama_model.model.parameters()).device lama_model.model.to("cpu") sub_t = sub_tensor_padded_cpu mask_t = mask_tensor_padded_cpu result = lama_model.model.forward(sub_t, mask_t) lama_model.model.to(orig_device) return result # ---------------------------------------------------------------------- # END UNDER DEVELOPMENT # ----------------------------------------------------------------------