# Page-header residue from the hosting site: "Spaces: Running on Zero".
# ---------------------------------------------------------------------- | |
# IMPORTS | |
# ---------------------------------------------------------------------- | |
import os | |
import cv2 | |
import logging | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from typing import List, Optional, Tuple | |
from PIL import Image, ImageDraw | |
from transformers import AutoProcessor, AutoModelForMaskGeneration | |
from simple_lama_inpainting import SimpleLama | |
# ----------------------------------------------------------------------
# MODEL REPOSITORY IDENTIFIERS
# ----------------------------------------------------------------------
# Repository id passed to AutoProcessor / AutoModelForMaskGeneration
# ``from_pretrained`` when the SAM components are loaded.
SAM_REPO = "facebook/sam-vit-huge"
# ----------------------------------------------------------------------
# MODEL PRECISION SETTINGS
# ----------------------------------------------------------------------
# True  -> SAM weights are requested as float32.
# False -> SAM weights are requested as float16 (see load_sam_model).
SAM_FULL_PRECISION = True
# NOTE(review): not referenced by any code visible in this section.
LAMA_FULL_PRECISION = True
# ----------------------------------------------------------------------
# GLOBAL MODEL INSTANCES
# ----------------------------------------------------------------------
# Process-wide singletons; they stay None until initialize_sam_and_lama()
# populates them.
SAM_PROCESSOR = None
SAM_MODEL = None
SIMPLE_LAMA = None
# ---------------------------------------------------------------------- | |
# INITIALIZE MODELS | |
# ---------------------------------------------------------------------- | |
def initialize_sam_and_lama(device="cuda"):
    """
    Lazily load the SAM processor/model and the LaMa inpainting model.

    The loaded objects are cached in the module-level globals
    ``SAM_PROCESSOR``, ``SAM_MODEL`` and ``SIMPLE_LAMA``. Each component is
    loaded only if it is still ``None``, so a failure while loading one of
    them does not force the (large) others to be loaded again on the next
    call.

    Args:
        device: Currently unused. SAM placement is handled inside
            ``load_sam_model`` and LaMa is always created on the CPU. The
            parameter is kept so existing callers keep working.

    Raises:
        Exception: Whatever the underlying loaders raise; nothing is
            swallowed here.
    """
    global SAM_PROCESSOR, SAM_MODEL, SIMPLE_LAMA

    if SAM_PROCESSOR is None or SAM_MODEL is None:
        logging.info("Loading SAM model...")
        if SAM_PROCESSOR is None:
            SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_REPO)
        if SAM_MODEL is None:
            SAM_MODEL = load_sam_model(SAM_REPO, SAM_FULL_PRECISION)

    if SIMPLE_LAMA is None:
        logging.info("Loading LaMa inpainting model...")
        # LaMa is deliberately kept on the CPU (see the log message below).
        lama_device = "cpu"
        logging.info("LAMA will use CPU - this is intentional for compatibility")
        SIMPLE_LAMA = SimpleLama(device=lama_device)
        logging.info(f"Successfully loaded LAMA model on {lama_device.upper()}")
def load_sam_model(repo_id: str, full_precision: bool):
    """
    Load a SAM mask-generation model and make sure it lives on CUDA.

    Args:
        repo_id: Repository id handed to
            ``AutoModelForMaskGeneration.from_pretrained``.
        full_precision: ``True`` requests float32 weights, ``False`` float16.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        Exception: Any loading error is logged and then re-raised unchanged.
    """
    try:
        torch.cuda.empty_cache()
        weights_dtype = torch.float32 if full_precision else torch.float16
        sam = AutoModelForMaskGeneration.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=weights_dtype,
        )

        # No ``hf_device_map`` attribute: place (and, if requested, cast) the
        # model by hand.
        placed_by_loader = hasattr(sam, 'hf_device_map')
        if not placed_by_loader:
            sam = sam.cuda()
            if not full_precision:
                sam = sam.half()

        sam.eval()

        with torch.no_grad():
            logging.info("Verifying SAM model is on CUDA")
            first_param = next(sam.parameters())
            if not first_param.is_cuda:
                sam = sam.cuda()
                logging.warning("Forced SAM model to CUDA")
            logging.info(f"SAM model device: {first_param.device}")

        return sam
    except Exception as e:
        logging.error(f"Failed to load SAM model: {e}")
        raise
# ----------------------------------------------------------------------
# ARTIFACT UTILITIES
# ----------------------------------------------------------------------
# Detection keywords treated as removable artifacts: remove_object_batch only
# processes boxes whose keyword is an exact member of this list.
ARTIFACTS_LIST = ["jewelry", "necklace", "bracelet", "ring", "earrings", "watch", "glasses"]
# ---------------------------------------------------------------------- | |
# UNDER DEVELOPMENT | |
# ---------------------------------------------------------------------- | |
def remove_object_batch(contexts: List[ProcessingContext], batch_logs: List[dict]) -> None:
    """
    Remove detected artifacts (see ``ARTIFACTS_LIST``) from each context image.

    For every context that is not skipped and has a usable detection result,
    the ``"original"`` image is flattened onto a white background and each
    artifact box is expanded, segmented and inpainted via
    ``remove_object_inplace``. The result replaces ``ctx.pil_img["original"]``.

    Args:
        contexts: Processing contexts. The attributes read here are
            ``skip_run``, ``skip_processing``, ``pil_img``,
            ``detection_result`` and ``url``.
        batch_logs: List that receives exactly one log dict per context.

    Side effects:
        * ``ctx.pil_img["original"]`` is overwritten with
          ``[flattened_rgb_image, original_filename, None]``.
        * ``ctx.skip_run`` is set to ``True`` when the ``"original"`` image
          is missing.
    """
    initialize_sam_and_lama()
    logging.info(f"[DEBUG] remove_object_batch => Starting with {len(contexts)} contexts.")

    for ctx_idx, ctx in enumerate(contexts):
        step_log = {
            "function": "remove_object_batch",
            "image_url": getattr(ctx, "url", "unknown"),
            "status": None,
            "artifacts_found": [],
            "image_dimensions": None,
            "artifact_boxes": []
        }

        if ctx.skip_run or ctx.skip_processing:
            step_log["status"] = "skipped"
            batch_logs.append(step_log)
            continue

        if "original" not in ctx.pil_img:
            logging.debug(f"(Context #{ctx_idx}) => RBC 'original' missing => {ctx.url}")
            step_log["status"] = "error"
            step_log["exception"] = "No RBC 'original' in ctx"
            ctx.skip_run = True
            batch_logs.append(step_log)
            continue

        dr = ctx.detection_result
        if not dr or dr.get("status") != "ok":
            logging.debug(f"(Context #{ctx_idx}) => No valid detection => {ctx.url}")
            step_log["status"] = "no_detection"
            batch_logs.append(step_log)
            continue

        boxes = dr.get("boxes", [])
        kws = dr.get("final_keywords", [])
        if len(boxes) != len(kws) or not boxes:
            logging.debug(f"(Context #{ctx_idx}) => mismatch or no boxes => {ctx.url}")
            step_log["status"] = "no_boxes_in_detection"
            batch_logs.append(step_log)
            continue

        artifact_indices = [i for i, kw_ in enumerate(kws) if kw_ in ARTIFACTS_LIST]
        if not artifact_indices:
            logging.debug(f"(Context #{ctx_idx}) => No artifacts found => {ctx.url}. Skipping flatten.")
            step_log["status"] = "no_artifacts_found"
            batch_logs.append(step_log)
            continue

        pil_rgba, orig_fname, _ = ctx.pil_img["original"]
        logging.debug(f"(Context #{ctx_idx}) Flattening RBC image to white background (since artifacts exist).")
        # Normalise to RGBA first. This function itself stores an RGB image
        # back into ctx.pil_img["original"], so a second pass (or any input
        # without an alpha band) would otherwise fail in getchannel("A").
        # For an image without transparency the alpha is fully opaque and the
        # flatten below is an identity operation.
        rgba = pil_rgba.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba.convert("RGB"), mask=rgba.getchannel("A"))
        logging.debug(f"(Context #{ctx_idx}) Background conversion complete.")

        updated_img = flattened
        found_labels = []
        for art_i in artifact_indices:
            box_ = boxes[art_i]
            kw_ = kws[art_i]
            step_log["artifact_boxes"].append({
                "original_box": box_,
                "label": kw_
            })
            w_img, h_img = updated_img.size
            expanded = expand_bbox(box_, w_img, h_img, pad=24)
            logging.debug(f"(Context #{ctx_idx}) Artifact {art_i}: Expanded box from {box_} to {expanded}.")
            step_log["artifact_boxes"][-1]["expanded_box"] = expanded
            logging.debug(f"(Context #{ctx_idx}) Removing object in region {expanded}.")
            try:
                updated_img = remove_object_inplace(
                    updated_img,
                    expanded,
                    SAM_PROCESSOR,
                    SAM_MODEL,
                    SIMPLE_LAMA,
                    device="cuda",
                    debug_save_prefix=f"dbg_ctx{ctx_idx}_artifact{art_i}",
                    dilate_mask=True,
                    dilate_kernel_size=40
                )
                logging.debug(f"(Context #{ctx_idx}) Object removal complete for artifact {art_i}.")
                found_labels.append(kw_)
            except RuntimeError as err:
                # Best effort: keep the image as it was before this artifact
                # and carry on with the remaining ones.
                logging.warning(f"[WARNING] TorchScript inpainting failed for artifact {art_i}, skipping removal.\n{err}")
                step_log["artifact_boxes"][-1]["skipped_inpainting"] = True

        ctx.pil_img["original"] = [updated_img, orig_fname, None]
        step_log["artifacts_found"] = found_labels
        step_log["status"] = "artifacts_removed"
        step_log["image_dimensions"] = (updated_img.width, updated_img.height)
        logging.debug(f"(Context #{ctx_idx}) => Artifacts removed => {ctx.url}")
        batch_logs.append(step_log)

    logging.debug("[DEBUG] remove_object_batch => Finished.\n")
def expand_bbox(box, w, h, pad=24):
    """
    Grow a bounding box by ``pad`` on every side, clamped to the image.

    Args:
        box: Four values ``(x1, y1, x2, y2)``.
        w: Upper limit for the x2 coordinate.
        h: Upper limit for the y2 coordinate.
        pad: Amount added on each side.

    Returns:
        A new list ``[x1, y1, x2, y2]``; the top-left corner is never below
        0 and the bottom-right corner never exceeds ``(w, h)``.
    """
    left, top, right, bottom = box
    top_left = [max(0, coord - pad) for coord in (left, top)]
    bottom_right = [
        min(limit, coord + pad)
        for coord, limit in ((right, w), (bottom, h))
    ]
    grown = top_left + bottom_right
    logging.debug(f"expand_bbox => Original: {box}, Expanded: {grown}")
    return grown
def remove_object_inplace(
    pil_rgb: Image.Image,
    bbox: List[int],
    sam_processor,
    sam_model,
    lama_model_jit,
    device="cuda",
    debug_save_prefix=None,
    dilate_mask=False,
    dilate_kernel_size=15
) -> Image.Image:
    """
    Segment the object inside ``bbox`` with SAM and inpaint it with LaMa.

    Despite the name, the input image is not modified; a new image is
    returned by ``inpaint_region_with_lama_multi_fallback``.

    Args:
        pil_rgb: Source image (converted to RGB internally).
        bbox: ``[x1, y1, x2, y2]`` box prompt in pixel coordinates of
            ``pil_rgb``.
        sam_processor: Processor that builds the SAM inputs.
        sam_model: SAM model; its output must expose ``pred_masks``.
        lama_model_jit: LaMa wrapper forwarded to the inpainting helper.
        device: Device the SAM inputs are moved to.
        debug_save_prefix: Accepted for API compatibility; not used here.
        dilate_mask: When true, dilate the binary mask before inpainting.
        dilate_kernel_size: Side length of the square dilation kernel.

    Returns:
        A new image with the ``bbox`` region inpainted.
    """
    logging.debug(f"remove_object_inplace => Processing bbox {bbox} on image size {pil_rgb.size}")
    image_rgb = pil_rgb.convert("RGB")
    inputs = sam_processor(
        images=image_rgb,
        input_boxes=[[[bbox[0], bbox[1], bbox[2], bbox[3]]]],
        return_tensors="pt"
    ).to(device)
    if not SAM_FULL_PRECISION and sam_model.dtype == torch.float16:
        # Match float32 inputs to the half-precision model; integer tensors
        # (sizes) are left untouched.
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

    with torch.no_grad():
        out_sam = sam_model(**inputs)

    # Reduce to (batch, 1, h, w), keeping the first prompt and the first
    # candidate mask regardless of how many candidates were produced.
    pred_masks = out_sam.pred_masks
    if pred_masks.ndim == 5:
        pred_masks = pred_masks[:, 0, 0, :, :]
    elif pred_masks.ndim == 4:
        pred_masks = pred_masks[:, 0, :, :]
    if pred_masks.ndim == 3:
        pred_masks = pred_masks.unsqueeze(1)

    # float32 keeps bilinear interpolation usable on every device/dtype.
    masks = pred_masks.float()

    # Map the low-resolution mask back to ORIGINAL image coordinates, because
    # the inpainting helper slices the mask with ``bbox``, which is expressed
    # in original-image pixels. This follows the steps of the SAM image
    # processor's mask post-processing:
    #   1. upsample to the padded input the model actually saw,
    #   2. crop away the padding (keep the resized-image area),
    #   3. resize to the original image size.
    if "reshaped_input_sizes" in inputs and "pixel_values" in inputs:
        t_h, t_w = inputs["reshaped_input_sizes"][0].tolist()
        padded_h, padded_w = inputs["pixel_values"].shape[-2:]
        masks = F.interpolate(
            masks,
            size=(int(padded_h), int(padded_w)),
            mode="bilinear",
            align_corners=False
        )
        masks = masks[..., :int(t_h), :int(t_w)]

    img_w, img_h = image_rgb.size
    if tuple(masks.shape[-2:]) != (img_h, img_w):
        masks = F.interpolate(
            masks,
            size=(img_h, img_w),
            mode="bilinear",
            align_corners=False
        )

    # NOTE(review): the 0.5 cut-off is applied to the raw ``pred_masks``
    # values and is kept from the previous implementation; confirm it is the
    # intended threshold for this model's output scale.
    mask_bin = (masks[0, 0] > 0.5).cpu().numpy().astype(np.uint8)

    if dilate_mask:
        kernel = np.ones((dilate_kernel_size, dilate_kernel_size), dtype=np.uint8)
        mask_bin = cv2.dilate(mask_bin, kernel, iterations=1)
        logging.debug(f"remove_object_inplace => Dilated mask mean: {mask_bin.mean():.6f}")

    updated_crop = inpaint_region_with_lama_multi_fallback(
        image_rgb,
        mask_bin,
        bbox,
        lama_model_jit
    )
    logging.debug(f"remove_object_inplace => Inpainting complete for bbox {bbox}")
    return updated_crop
def inpaint_region_with_lama_multi_fallback(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model_jit
) -> Image.Image:
    """
    Inpaint the ``bbox`` region of ``image_rgb`` with LaMa, with fallbacks.

    The region is cropped, downscaled so its longest side is at most 1024,
    padded to a multiple of 32 and passed to ``lama_model_jit.model.forward``
    on the device the model's weights live on. If that fails with a
    ``RuntimeError`` and the model is not on the CPU, a CPU run is attempted
    via ``inpaint_torchscript_cpu_fallback``. If everything fails, the region
    is filled with white.

    Args:
        image_rgb: Full RGB image.
        mask_bin: 2-D array indexed as ``[row, column]`` in the pixel
            coordinates of ``image_rgb``; non-zero marks pixels to inpaint.
        bbox: ``[x1, y1, x2, y2]`` region; values are rounded to integers.
        lama_model_jit: Wrapper whose ``.model`` provides ``forward`` and
            ``parameters``.

    Returns:
        A new image; ``image_rgb`` itself is never modified.
    """
    # NumPy slicing needs integer indices, so normalise the box once here.
    x1, y1, x2, y2 = (int(round(v)) for v in bbox)
    int_box = [x1, y1, x2, y2]

    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    logging.debug(f"inpaint_region_with_lama_multi_fallback => Cropped region: w={orig_w}, h={orig_h}")
    if orig_w < 2 or orig_h < 2:
        logging.warning("Subregion too small for inpainting. Filling with white instead.")
        return fill_white(image_rgb, int_box)

    # Cap the working resolution at 1024 on the longest side.
    max_dim = max(orig_w, orig_h)
    target_size = 1024
    scale = 1.0
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        logging.debug(f"inpaint_region_with_lama_multi_fallback => scaled to {new_w}x{new_h} (factor={scale:.3f})")
    else:
        new_w, new_h = orig_w, orig_h

    # Pad right/bottom so both sides become multiples of 32.
    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    logging.debug(f"inpaint_region_with_lama_multi_fallback => pad_w={pad_w}, pad_h={pad_h}")

    sub_tensor = (
        torch.from_numpy(np.array(subregion))
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float() / 255.0
    )
    mask_tensor = torch.from_numpy(mask_sub.astype(np.float32)).unsqueeze(0).unsqueeze(0)

    # Run where the weights are. Sending inputs to "cuda" unconditionally
    # guarantees a device-mismatch failure whenever LaMa was created on the
    # CPU, and raises outright on builds without CUDA support.
    lama_net = lama_model_jit.model
    first_param = next(iter(lama_net.parameters()), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # --- Temporary, process-wide patch: reflect padding -> replicate -------
    # NOTE(review): this swaps module attributes globally and is therefore
    # not thread-safe; whether a TorchScript model observes these Python-level
    # patches at all should be confirmed.
    original_f_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)
    original_reflection = None
    if hasattr(torch._C._nn, "reflection_pad2d"):
        original_reflection = torch._C._nn.reflection_pad2d

    def custom_f_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_f_pad(inp, pad_vals, mode=mode, value=value)

    def custom_torch_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(inp, pad_vals, mode=mode, value=value)

    def replicate_pad2d(*args, **kwargs):
        # Use the low-level op directly so this does not depend on
        # torch.nn.functional re-exporting ``replication_pad2d``.
        return torch._C._nn.replication_pad2d(*args, **kwargs)

    result_tensor = None
    try:
        F.pad = custom_f_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = replicate_pad2d

        # With the patch active, 'reflect' below resolves to 'replicate'.
        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)

        try:
            with torch.no_grad():
                result_tensor = lama_net.forward(
                    sub_tensor_padded.to(run_device),
                    mask_tensor_padded.to(run_device)
                )
        except RuntimeError as re_primary:
            if run_device.type == "cpu":
                # The failed attempt already was the CPU run; repeating it
                # cannot succeed.
                logging.warning(f"TorchScript CPU inpainting failed => {re_primary}\nFilling with white region.")
                return fill_white(image_rgb, int_box)
            logging.warning(f"TorchScript GPU inpainting failed => {re_primary}\nAttempting CPU fallback...")
            try:
                result_tensor = inpaint_torchscript_cpu_fallback(sub_tensor_padded, mask_tensor_padded, lama_model_jit)
            except RuntimeError as re_cpu:
                logging.warning(f"TorchScript CPU fallback also failed => {re_cpu}\nFilling with white region.")
                return fill_white(image_rgb, int_box)
    finally:
        # Always undo the global patch, including on the early returns above.
        F.pad = original_f_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = original_reflection

    if result_tensor is None:
        logging.warning("Result is None after fallback => filling with white region.")
        return fill_white(image_rgb, int_box)

    # Drop the padding, then convert (1, C, H, W) in [0, 1] to HxWxC uint8.
    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    out_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .byte()
        .cpu()
        .numpy()
    )
    inpainted_pil = Image.fromarray(out_np)
    if scale != 1.0:
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)

    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    logging.debug(f"inpaint_region_with_lama_multi_fallback => done for region {bbox}")
    return out_img
def inpaint_torchscript_cpu_fallback(
    sub_tensor_padded: torch.Tensor,
    mask_tensor_padded: torch.Tensor,
    lama_model_jit
) -> torch.Tensor:
    """
    Run the LaMa model on the CPU and put it back on its original device.

    Args:
        sub_tensor_padded: Image tensor passed as first ``forward`` argument.
        mask_tensor_padded: Mask tensor passed as second ``forward`` argument.
        lama_model_jit: Wrapper whose ``.model`` provides ``parameters``,
            ``to`` and ``forward``.

    Returns:
        The tensor returned by ``forward`` for the CPU copies of the inputs.

    Raises:
        Whatever ``forward`` raises. The model is moved back to its original
        device in that case as well.
    """
    model = lama_model_jit.model
    orig_device = next(model.parameters()).device
    model.to("cpu")
    try:
        sub_cpu = sub_tensor_padded.cpu()
        mask_cpu = mask_tensor_padded.cpu()
        with torch.no_grad():
            return model.forward(sub_cpu, mask_cpu)
    finally:
        # Restore the device even on failure, so a failed fallback does not
        # leave the shared model stranded on the CPU.
        model.to(orig_device)
def fill_white(image_rgb: Image.Image, bbox: List[int]) -> Image.Image:
    """
    Return a copy of ``image_rgb`` with the ``bbox`` rectangle painted white.

    Args:
        image_rgb: Source image; it is left unmodified.
        bbox: Four values ``(x1, y1, x2, y2)`` handed to
            ``ImageDraw.rectangle``.

    Returns:
        The painted copy.
    """
    left, top, right, bottom = bbox
    canvas = image_rgb.copy()
    ImageDraw.Draw(canvas).rectangle([left, top, right, bottom], fill=(255, 255, 255))
    return canvas
def inpaint_region_with_lama_gpu_only(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model,
    debug_save_prefix: Optional[str] = None
) -> Image.Image:
    """
    Inpaint the ``bbox`` region of ``image_rgb`` with LaMa using CUDA tensors.

    The region is cropped, downscaled so its longest side is at most 1024,
    padded on the right/bottom to a multiple of 32 and passed to
    ``lama_model.model.forward``. On ``RuntimeError`` the call is retried
    through ``run_lama_on_cpu_fallback``.

    Args:
        image_rgb: Full image to inpaint.
        mask_bin: 2-D mask indexed as ``[row, column]``; values greater than
            0 are treated as "inpaint here".
        bbox: ``[x1, y1, x2, y2]`` region. Used directly as slice indices, so
            the values must be integers.
        lama_model: Wrapper whose ``.model`` provides ``forward``.
        debug_save_prefix: Not used anywhere in this function.

    Returns:
        A new RGB image with the region replaced, or ``image_rgb`` itself
        (not a copy) when the region is narrower or shorter than 2 pixels.

    Raises:
        Any error from the CPU fallback propagates to the caller; unlike
        ``inpaint_region_with_lama_multi_fallback`` there is no white-fill
        fallback here.
    """
    x1, y1, x2, y2 = bbox
    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    if orig_w < 2 or orig_h < 2:
        # Too small to inpaint: hand back the input unchanged.
        return image_rgb
    # Cap the working resolution at 1024 on the longest side.
    target_size = 1024
    scale = 1.0
    max_dim = max(orig_w, orig_h)
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Nearest-neighbour keeps the mask strictly binary.
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    else:
        new_w, new_h = orig_w, orig_h
    # Amount of right/bottom padding needed to reach a multiple of 32.
    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    sub_np = np.array(subregion)
    # HxWxC uint8 -> 1xCxHxW float in [0, 1], on CUDA.
    sub_tensor = (
        torch.from_numpy(sub_np)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float()
        .to("cuda")
        / 255.0
    ).contiguous()
    # HxW mask -> 1x1xHxW float of 0.0 / 1.0, on CUDA.
    mask_tensor = (
        torch.from_numpy((mask_sub > 0).astype(np.float32))
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to("cuda")
    ).contiguous()
    # --- Temporary, process-wide patch: reflect padding -> replicate -------
    # NOTE(review): these assignments replace module attributes globally for
    # the duration of the call and are not thread-safe; confirm that the
    # model actually goes through these Python-level entry points.
    original_F_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)
    def custom_F_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_F_pad(input, pad_vals, mode=mode, value=value)
    def custom_torch_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(input, pad_vals, mode=mode, value=value)
    original_reflection_pad2d = None
    if hasattr(torch._C._nn, 'reflection_pad2d'):
        original_reflection_pad2d = torch._C._nn.reflection_pad2d
    def no_reflection_pad2d(*args, **kwargs):
        # NOTE(review): verify that torch.nn.functional exposes
        # ``replication_pad2d`` in the torch version in use.
        return F.replication_pad2d(*args, **kwargs)
    try:
        F.pad = custom_F_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = no_reflection_pad2d
        # F.pad is already patched here, so 'reflect' resolves to 'replicate'.
        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)
        try:
            with torch.no_grad():
                result_tensor = lama_model.model.forward(sub_tensor_padded, mask_tensor_padded)
        except RuntimeError as e:
            # The caught error is neither logged nor re-raised; the same
            # inputs are retried on the CPU.
            result_tensor = run_lama_on_cpu_fallback(
                sub_tensor_padded.cpu(),
                mask_tensor_padded.cpu(),
                lama_model
            )
    finally:
        # Always undo the global patch.
        F.pad = original_F_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = original_reflection_pad2d
    # Drop the padding, then convert 1xCxHxW in [0, 1] to HxWxC uint8.
    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    result_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    inpainted_pil = Image.fromarray(result_np)
    if scale != 1.0:
        # Undo the earlier downscale.
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)
    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    torch.cuda.empty_cache()
    return out_img.convert("RGB")
def run_lama_on_cpu_fallback(
    sub_tensor_padded_cpu: torch.Tensor,
    mask_tensor_padded_cpu: torch.Tensor,
    lama_model
) -> torch.Tensor:
    """
    Run the LaMa model on the CPU and put it back on its original device.

    Args:
        sub_tensor_padded_cpu: Image tensor, passed to ``forward`` as is.
        mask_tensor_padded_cpu: Mask tensor, passed to ``forward`` as is.
        lama_model: Wrapper whose ``.model`` provides ``parameters``, ``to``
            and ``forward``.

    Returns:
        The tensor returned by ``forward``.

    Raises:
        Whatever ``forward`` raises. The model is moved back to its original
        device in that case as well.
    """
    model = lama_model.model
    with torch.no_grad():
        orig_device = next(model.parameters()).device
        model.to("cpu")
        try:
            return model.forward(sub_tensor_padded_cpu, mask_tensor_padded_cpu)
        finally:
            # Restore the device even on failure, so a failed fallback does
            # not leave the shared model stranded on the CPU.
            model.to(orig_device)
# ---------------------------------------------------------------------- | |
# END UNDER DEVELOPMENT | |
# ---------------------------------------------------------------------- | |