# GitHub Actions
# Deploy to Hugging Face Space: product-image-update-port-10
# 18faf97
# ----------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------
import io
import os
import time
import numpy as np
import torch
import cv2
import logging
import sys
import traceback
from PIL import Image, ImageOps, ImageEnhance
from torchvision import transforms
from scipy.ndimage import label as scipy_label, find_objects as scipy_find_objects
from typing import List
from src.utils import ProcessingContext, create_pipeline_step, LOG_LEVEL_MAP, EMOJI_MAP
# ----------------------------------------------------------------------
# GLOBAL CONSTANTS
# ----------------------------------------------------------------------
# Base contrast / sharpness factors applied to the padded image in
# preprocess_images_batch before it is handed to the model.
RBC_CONTRAST_FACTOR = 1.25
RBC_SHARPNESS_FACTOR = 1.15
# Fill colour of the square canvas images are pasted onto.
PAD_COLOR = "#ffffff"
# Probability cut-off used to binarise the combined model output.
THRESH = 0.42
# Lower cut-off, used only when THRESH produces a completely empty mask.
RESCUE_THRESH = 0.20
# Maximum number of images stacked into one model batch.
MAX_IMAGES_PER_BATCH = 4
# Mask clean-up settings consumed by smooth_mask().
MORPH_KERNEL_SIZE = (3, 3)
MORPH_CLOSE_ITER = 1
MORPH_OPEN_ITER = 1
EROSION_ITER = 1
GAUSSIAN_KERNEL_SIZE = (7, 7)
DO_GUIDED_FILTER = True
FILL_HOLES = False
USE_BILATERAL = True
# Free-form tag reported in the "calibration_info" log entry.
CALIBRATION_VERSION = "v16_balanced"
# HSV ranges passed to cv2.inRange by the detect_* helpers. Each *_THRESHOLD
# is the fraction of pixels that must fall inside the range(s) for the
# corresponding flag to be raised.
BLUE_BACKGROUND_HSV_LOWER = np.array([100, 100, 80])
BLUE_BACKGROUND_HSV_UPPER = np.array([130, 255, 255])
BLUE_BACKGROUND_THRESHOLD = 0.25
SKIN_HSV_LOWER_1 = np.array([0, 20, 70])
SKIN_HSV_UPPER_1 = np.array([30, 180, 255])
SKIN_HSV_LOWER_2 = np.array([0, 10, 40])
SKIN_HSV_UPPER_2 = np.array([25, 150, 200])
SKIN_THRESHOLD = 0.10
DENIM_HSV_LOWER_1 = np.array([80, 30, 30])
DENIM_HSV_UPPER_1 = np.array([130, 220, 230])
DENIM_HSV_LOWER_2 = np.array([85, 15, 100])
DENIM_HSV_UPPER_2 = np.array([115, 130, 230])
DENIM_THRESHOLD = 0.12
OUTDOOR_GREEN_HSV_LOWER = np.array([35, 40, 40])
OUTDOOR_GREEN_HSV_UPPER = np.array([85, 255, 255])
OUTDOOR_SKY_HSV_LOWER = np.array([90, 40, 150])
OUTDOOR_SKY_HSV_UPPER = np.array([130, 120, 255])
OUTDOOR_THRESHOLD = 0.15
# Upper bounds used by analyze_image_difficulty(): an image is "difficult"
# when both its contrast and its edge ratio are below these values.
DIFFICULTY_CONTRAST_THRESHOLD = 50
DIFFICULTY_EDGE_THRESHOLD = 0.05
# Blend weights (original, sharpened) used by enhance_edges().
EDGE_ENHANCE_WEIGHT_ORIG = 0.7
EDGE_ENHANCE_WEIGHT_SHARP = 0.3
# Parameters forwarded to guided_filter() from smooth_mask().
GUIDED_FILTER_RADIUS = 4
GUIDED_FILTER_EPSILON = 0.01
# Parameters forwarded to cv2.bilateralFilter from smooth_mask().
BILATERAL_FILTER_D = 9
BILATERAL_FILTER_SIGMA_COLOR = 75
BILATERAL_FILTER_SIGMA_SPACE = 75
# Hysteresis thresholds forwarded to cv2.Canny in analyze_image_difficulty().
CANNY_THRESHOLD_LOW = 100
CANNY_THRESHOLD_HIGH = 200
# Images larger than the target side are shrunk so their longest side is
# target_side * TARGET_SIZE_SCALING_FACTOR (see calculate_optimal_scale).
TARGET_SIZE_SCALING_FACTOR = 0.98
# final_pad_sq(): images wider than this width/height ratio are centre-cropped
# to a square of side height * PADDING_SCALE_FACTOR instead of being padded.
ASPECT_RATIO_THRESHOLD = 4/3
PADDING_SCALE_FACTOR = 0.98
# Multipliers applied on top of RBC_CONTRAST_FACTOR; the outdoor flag takes
# precedence over the person flag in preprocess_images_batch.
OUTDOOR_CONTRAST_MULTIPLIER = 1.1
PERSON_CONTRAST_MULTIPLIER = 0.75
# Total pixel counts (sum of width * height over a batch) used by
# determine_adaptive_size_config() to pick the "_fast" / "_minimal" variants.
# LOW_LOAD_SURFACE_THRESHOLD is only reported in logs in this file.
LOW_LOAD_SURFACE_THRESHOLD = 12_000_000
MEDIUM_LOAD_SURFACE_THRESHOLD = 25_000_000
HIGH_LOAD_SURFACE_THRESHOLD = 45_000_000
# ----------------------------------------------------------------------
# DYNAMIC SIZING CONFIGURATION
# ----------------------------------------------------------------------
# Base configuration per size tier: "final_side" is the side of the square
# canvas fed to the model, "rbc_scales" the list of scales run through it.
SIZE_CONFIGS = {
    "small": {
        "final_side": 1024,
        "rbc_scales": [1.25, 1.0]
    },
    "medium": {
        "final_side": 1536,
        "rbc_scales": [1.0, 0.75]
    },
    "large": {
        "final_side": 2048,
        "rbc_scales": [1.0]
    }
}
# Reduced-work variants chosen by determine_adaptive_size_config() when the
# batch's total pixel count crosses the medium ("_fast") or high ("_minimal")
# load threshold. All variants currently run a single 1.0 scale.
ADAPTIVE_SCALE_CONFIGS = {
    "small_fast": {"final_side": 1024, "rbc_scales": [1.0]},
    "small_minimal": {"final_side": 1024, "rbc_scales": [1.0]},
    "medium_fast": {"final_side": 1536, "rbc_scales": [1.0]},
    "medium_minimal": {"final_side": 1536, "rbc_scales": [1.0]},
    "large_fast": {"final_side": 2048, "rbc_scales": [1.0]},
    "large_minimal": {"final_side": 2048, "rbc_scales": [1.0]}
}
# Tier boundaries on the longest image side: below [0] is "small",
# [0]..[1] inclusive is "medium", above [1] is "large".
SIZE_CONFIG_SCALE = [1024, 1536]
# ----------------------------------------------------------------------
# PRE-COMPUTED OBJECTS
# ----------------------------------------------------------------------
# Rectangular structuring element shared by every morphology call in smooth_mask().
morph_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, MORPH_KERNEL_SIZE)
# PIL image -> normalised tensor transform applied to every model input.
rmbg_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# 3x3 sharpening kernel used by enhance_edges().
edge_enhance_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
# GPU-accelerated normalization if available
# NOTE(review): USE_V2_TRANSFORMS is only set here; nothing in the visible
# code reads it - confirm whether it is still needed.
try:
    import torchvision.transforms.v2 as v2_transforms
    USE_V2_TRANSFORMS = True
except ImportError:
    USE_V2_TRANSFORMS = False
# ----------------------------------------------------------------------
# UTILS
# ----------------------------------------------------------------------
def label(mask):
    """Label the connected regions of ``mask`` with ``scipy.ndimage.label``.

    Returns:
        The ``(labeled_array, num_features)`` pair produced by SciPy.
    """
    labeled_array, num_features = scipy_label(mask)
    return labeled_array, num_features
def find_objects(lbl):
    """Return the bounding slices of each label in ``lbl``.

    Thin wrapper around ``scipy.ndimage.find_objects``; entry ``i`` of the
    returned list belongs to label ``i + 1``.
    """
    bounding_slices = scipy_find_objects(lbl)
    return bounding_slices
def guided_filter(I, p, r, eps):
    """Apply a guided filter to ``p`` using ``I`` as the guidance image.

    Args:
        I: Guidance image array.
        p: Input array to be filtered.
        r: Side of the box-filter window; it is passed to ``cv2.boxFilter``
            as the kernel size ``(r, r)``.
        eps: Regularisation term added to the local variance of ``I``.

    Returns:
        The filtered array ``mean(a) * I + mean(b)``.
    """
    def box_mean(arr):
        # Local mean over an (r, r) window, computed in float64.
        return cv2.boxFilter(arr, cv2.CV_64F, (r, r))

    guide_mean = box_mean(I)
    input_mean = box_mean(p)
    covariance = box_mean(I * p) - guide_mean * input_mean
    variance = box_mean(I * I) - guide_mean * guide_mean
    # Per-pixel linear model p ~= a * I + b.
    a = covariance / (variance + eps)
    b = input_mean - a * guide_mean
    return box_mean(a) * I + box_mean(b)
def detect_blue_background(hsv):
    """Return True when the share of pixels inside the blue-background HSV
    range exceeds ``BLUE_BACKGROUND_THRESHOLD``."""
    in_range = cv2.inRange(hsv, BLUE_BACKGROUND_HSV_LOWER, BLUE_BACKGROUND_HSV_UPPER)
    blue_fraction = float(np.mean(in_range > 0))
    return blue_fraction > BLUE_BACKGROUND_THRESHOLD
def detect_skin_tones(hsv):
    """Return True when the share of pixels inside either skin HSV range
    exceeds ``SKIN_THRESHOLD``."""
    first_range = cv2.inRange(hsv, SKIN_HSV_LOWER_1, SKIN_HSV_UPPER_1)
    second_range = cv2.inRange(hsv, SKIN_HSV_LOWER_2, SKIN_HSV_UPPER_2)
    either = cv2.bitwise_or(first_range, second_range)
    skin_fraction = float(np.mean(either > 0))
    return skin_fraction > SKIN_THRESHOLD
def detect_denim(hsv):
    """Return True when the share of pixels inside either denim HSV range
    exceeds ``DENIM_THRESHOLD``."""
    first_range = cv2.inRange(hsv, DENIM_HSV_LOWER_1, DENIM_HSV_UPPER_1)
    second_range = cv2.inRange(hsv, DENIM_HSV_LOWER_2, DENIM_HSV_UPPER_2)
    either = cv2.bitwise_or(first_range, second_range)
    denim_fraction = float(np.mean(either > 0))
    return denim_fraction > DENIM_THRESHOLD
def detect_outdoor_scene(hsv):
    """Return True when the share of pixels inside the green or sky HSV
    ranges exceeds ``OUTDOOR_THRESHOLD``."""
    green_mask = cv2.inRange(hsv, OUTDOOR_GREEN_HSV_LOWER, OUTDOOR_GREEN_HSV_UPPER)
    sky_mask = cv2.inRange(hsv, OUTDOOR_SKY_HSV_LOWER, OUTDOOR_SKY_HSV_UPPER)
    outdoor_mask = cv2.bitwise_or(green_mask, sky_mask)
    outdoor_fraction = float(np.mean(outdoor_mask > 0))
    return outdoor_fraction > OUTDOOR_THRESHOLD
def analyze_image_difficulty(gray):
    """Classify a grayscale image as "difficult" (low contrast AND few edges).

    Contrast is the standard deviation of the pixel intensities. The previous
    formula, ``std(normalised_histogram * arange(256))``, is bounded by
    roughly 16 for any image, so it could never reach
    ``DIFFICULTY_CONTRAST_THRESHOLD`` and the contrast test was always true.

    Args:
        gray: Single-channel image array as accepted by ``cv2.Canny``.

    Returns:
        ``(is_difficult, contrast, edge_ratio)`` where ``contrast`` is the
        intensity standard deviation and ``edge_ratio`` the fraction of
        pixels marked as edges by Canny.
    """
    contrast = float(np.std(gray))
    edges = cv2.Canny(gray, CANNY_THRESHOLD_LOW, CANNY_THRESHOLD_HIGH)
    edge_ratio = float(np.mean(edges > 0))
    is_difficult = (
        contrast < DIFFICULTY_CONTRAST_THRESHOLD
        and edge_ratio < DIFFICULTY_EDGE_THRESHOLD
    )
    return bool(is_difficult), contrast, edge_ratio
def enhance_edges(img_rgb):
    """Blend ``img_rgb`` with a sharpened copy of itself.

    The sharpened copy is produced with ``edge_enhance_kernel``; the blend
    uses ``EDGE_ENHANCE_WEIGHT_ORIG`` for the input and
    ``EDGE_ENHANCE_WEIGHT_SHARP`` for the sharpened copy.
    """
    sharpened = cv2.filter2D(img_rgb, -1, edge_enhance_kernel)
    blended = cv2.addWeighted(
        img_rgb, EDGE_ENHANCE_WEIGHT_ORIG,
        sharpened, EDGE_ENHANCE_WEIGHT_SHARP,
        0,
    )
    return blended
def smooth_mask(mask):
    """Clean up and soften a binary uint8 mask.

    Steps, each controlled by the module-level settings: self-guided filter,
    morphological close, optional open, optional erosion, optional contour
    fill, optional bilateral filter, and a final Gaussian blur.

    Returns:
        The smoothed mask produced by the final ``cv2.GaussianBlur``.
    """
    result = mask
    if DO_GUIDED_FILTER:
        # The mask serves as both guidance and input, scaled to [0, 1].
        unit_mask = result.astype(np.float64) / 255.0
        filtered = guided_filter(
            unit_mask,
            result.astype(np.float64) / 255.0,
            GUIDED_FILTER_RADIUS,
            GUIDED_FILTER_EPSILON,
        )
        result = (filtered * 255).clip(0, 255).astype(np.uint8)
    result = cv2.morphologyEx(result, cv2.MORPH_CLOSE, morph_kernel, iterations=MORPH_CLOSE_ITER)
    if MORPH_OPEN_ITER:
        result = cv2.morphologyEx(result, cv2.MORPH_OPEN, morph_kernel, iterations=MORPH_OPEN_ITER)
    if EROSION_ITER:
        result = cv2.erode(result, morph_kernel, iterations=EROSION_ITER)
    if FILL_HOLES:
        found_contours, _ = cv2.findContours(result, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        for contour in found_contours:
            cv2.drawContours(result, [contour], -1, 255, thickness=cv2.FILLED)
    if USE_BILATERAL:
        result = cv2.bilateralFilter(
            result,
            BILATERAL_FILTER_D,
            BILATERAL_FILTER_SIGMA_COLOR,
            BILATERAL_FILTER_SIGMA_SPACE,
        )
    blurred = cv2.GaussianBlur(result, GAUSSIAN_KERNEL_SIZE, 0)
    return blurred
def calculate_total_surface_area(contexts):
    """Sum ``width * height`` over every context that still has raw bytes.

    Contexts flagged ``skip_run`` / ``skip_processing`` or lacking a
    ``_download_content`` attribute are ignored. Images that cannot be opened
    are skipped as well: this is a best-effort estimate, so a single bad
    download must not abort it.

    Args:
        contexts: Iterable of processing contexts.

    Returns:
        Total pixel count as an int (0 when nothing is measurable).
    """
    total_surface = 0
    for ctx in contexts:
        if ctx.skip_run or ctx.skip_processing or not hasattr(ctx, '_download_content'):
            continue
        try:
            # Only the image size is read; the context manager releases the
            # image as soon as it has been measured.
            with Image.open(io.BytesIO(ctx._download_content)) as img:
                w, h = img.size
        except Exception:
            # Narrower than the former bare ``except:`` so KeyboardInterrupt
            # and SystemExit are no longer swallowed.
            continue
        total_surface += w * h
    return total_surface
def determine_adaptive_size_config(width, height, total_surface_area):
    """Pick a size configuration that also accounts for batch load.

    The size tier comes from the longest image side; the load suffix
    ("_minimal", "_fast" or none) comes from ``total_surface_area``.

    Returns:
        ``(config_name, config_dict)``. Falls back to the base tier from
        ``SIZE_CONFIGS`` when no adaptive variant exists for the key.
    """
    longest_side = max(width, height)
    small_limit = SIZE_CONFIG_SCALE[0]
    medium_limit = SIZE_CONFIG_SCALE[1]
    if longest_side < small_limit:
        tier = "small"
    elif small_limit <= longest_side <= medium_limit:
        tier = "medium"
    else:
        tier = "large"

    if total_surface_area >= HIGH_LOAD_SURFACE_THRESHOLD:
        candidate = f"{tier}_minimal"
    elif total_surface_area >= MEDIUM_LOAD_SURFACE_THRESHOLD:
        candidate = f"{tier}_fast"
    else:
        candidate = tier

    if candidate not in ADAPTIVE_SCALE_CONFIGS:
        return tier, SIZE_CONFIGS[tier]
    return candidate, ADAPTIVE_SCALE_CONFIGS[candidate]
def determine_size_config(width, height):
    """Return ``(tier_name, SIZE_CONFIGS[tier_name])`` for an image.

    The tier is chosen from the longest side using the boundaries in
    ``SIZE_CONFIG_SCALE``.
    """
    longest_side = max(width, height)
    small_limit = SIZE_CONFIG_SCALE[0]
    medium_limit = SIZE_CONFIG_SCALE[1]
    if longest_side < small_limit:
        tier = "small"
    elif small_limit <= longest_side <= medium_limit:
        tier = "medium"
    else:
        tier = "large"
    return tier, SIZE_CONFIGS[tier]
def calculate_optimal_scale(original_width, original_height, target_side):
    """Return the resize factor that fits an image inside ``target_side``.

    Images whose longest side already fits are left untouched (factor 1.0);
    larger ones are scaled so the longest side becomes
    ``target_side * TARGET_SIZE_SCALING_FACTOR``.
    """
    longest_side = max(original_width, original_height)
    already_fits = longest_side <= target_side
    if already_fits:
        return 1.0
    scaled_target = target_side * TARGET_SIZE_SCALING_FACTOR
    return scaled_target / longest_side
def final_pad_sq(im):
    """Make ``im`` square.

    Images whose width/height ratio exceeds ``ASPECT_RATIO_THRESHOLD`` are
    centre-cropped to a square of side ``height * PADDING_SCALE_FACTOR``.
    All other images are pasted centred onto a transparent RGBA square whose
    side equals the longest dimension.
    """
    width, height = im.size
    is_wide = width / height > ASPECT_RATIO_THRESHOLD
    if is_wide:
        crop_side = int(round(height * PADDING_SCALE_FACTOR))
        left = (width - crop_side) // 2
        top = (height - crop_side) // 2
        right = left + crop_side
        bottom = (height + crop_side) // 2
        return im.crop((left, top, right, bottom))

    canvas_side = max(width, height)
    canvas = Image.new("RGBA", (canvas_side, canvas_side), (0, 0, 0, 0))
    offset = ((canvas_side - width) // 2, (canvas_side - height) // 2)
    canvas.paste(im, offset)
    return canvas
# ----------------------------------------------------------------------
# CORE IMPLEMENTATION
# ----------------------------------------------------------------------
def preprocess_images_batch(contexts, batch_logs):
    """Decode, resize, pad and enhance every image before model inference.

    For each usable context the downloaded bytes are decoded, EXIF-rotated,
    shrunk to fit the tier's ``final_side`` and pasted centred on a
    ``PAD_COLOR`` square. Colour/scene flags are computed on the resized
    image and used to pick the enhancement variant. The result is stored in
    ``ctx._rmbg_meta`` and ``ctx._download_content`` is deleted.

    Once all images are known, the size configuration of every prepared
    context is re-evaluated against the batch's total pixel count (adaptive
    scaling), and the per-image log entries are updated to match.

    Args:
        contexts: Processing contexts; read for ``url``, ``skip_run``,
            ``skip_processing`` and ``_download_content``.
        batch_logs: List that receives one entry per context plus a summary.

    Returns:
        The same ``batch_logs`` list.
    """
    function_name = "preprocess_images_batch"
    processed_count = 0
    skipped_count = 0
    error_count = 0
    total_surface_area = 0
    # Log entry per successfully prepared context, keyed by id(ctx), so the
    # adaptive pass below can keep the logs in sync with the metadata.
    ok_logs = {}
    t0 = time.perf_counter()
    for ctx in contexts:
        log_item = {
            "image_url": ctx.url,
            "function": function_name,
            "data": {}
        }
        if ctx.skip_run or ctx.skip_processing:
            log_item["status"] = "skipped"
            log_item["data"]["reason"] = "marked_for_skip_or_no_downloaded_image"
            batch_logs.append(log_item)
            skipped_count += 1
            continue
        if not hasattr(ctx, '_download_content'):
            log_item["status"] = "error"
            log_item["exception"] = "No downloaded content found"
            log_item["data"]["reason"] = "missing_download_content"
            batch_logs.append(log_item)
            ctx.skip_run = True
            error_count += 1
            continue
        process_start = time.perf_counter()
        try:
            img = Image.open(io.BytesIO(ctx._download_content))
            img = ImageOps.exif_transpose(img).convert("RGB")
            ow, oh = img.size
            ctx._orig_size = (ow, oh)
            total_surface_area += ow * oh
            size_config_name, size_config = determine_size_config(ow, oh)
            final_side = size_config["final_side"]
            rbc_scales = size_config["rbc_scales"]
            scale = calculate_optimal_scale(ow, oh, final_side)
            if scale < 1:
                # Clamp to 1 px: for extreme aspect ratios the short side
                # would otherwise truncate to 0 and make resize() raise.
                new_w = max(1, int(ow * scale))
                new_h = max(1, int(oh * scale))
                img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
                ow, oh = img.size
            pad = Image.new("RGB", (final_side, final_side), PAD_COLOR)
            pad.paste(img, ((final_side - ow) // 2, (final_side - oh) // 2))
            # Flags are computed on the resized image, not on the padded canvas.
            img_np = np.array(img)
            hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            has_blue = detect_blue_background(hsv)
            has_person = detect_skin_tones(hsv)
            has_denim = detect_denim(hsv)
            is_outdoor = detect_outdoor_scene(hsv)
            is_diff, contrast, edge_ratio = analyze_image_difficulty(gray)
            if is_diff:
                # Difficult images get edge enhancement instead of the
                # contrast/sharpness boost.
                final_enhanced = Image.fromarray(enhance_edges(np.array(pad)))
            else:
                contrast_multiplier = 1.0
                if is_outdoor:
                    contrast_multiplier = OUTDOOR_CONTRAST_MULTIPLIER
                elif has_person:
                    contrast_multiplier = PERSON_CONTRAST_MULTIPLIER
                final_enhanced = ImageEnhance.Contrast(pad).enhance(
                    RBC_CONTRAST_FACTOR * contrast_multiplier
                )
                final_enhanced = ImageEnhance.Sharpness(final_enhanced).enhance(
                    RBC_SHARPNESS_FACTOR
                )
            ctx._rmbg_meta = {
                "pad": pad,
                "enhanced": final_enhanced,
                "flags": (has_blue, has_person, has_denim, is_outdoor, is_diff),
                "size_config": size_config_name,
                "final_side": final_side,
                "rbc_scales": rbc_scales
            }
            process_time = time.perf_counter() - process_start
            log_item["status"] = "ok"
            log_item["data"].update({
                "orig_size": ctx._orig_size,
                "processed_size": (final_side, final_side),
                "size_config": size_config_name,
                "scale_applied": round(scale, 4),
                "contrast": contrast,
                "edge_ratio": edge_ratio,
                "processing_time": round(process_time, 4),
                "rbc_scales": rbc_scales,
                "flags": {
                    "has_blue": has_blue,
                    "has_person": has_person,
                    "has_denim": has_denim,
                    "is_outdoor": is_outdoor,
                    "is_difficult": is_diff
                }
            })
            processed_count += 1
            ok_logs[id(ctx)] = log_item
            # Raw bytes are no longer needed once the padded image exists.
            del ctx._download_content
        except Exception as e:
            log_item["status"] = "error"
            log_item["exception"] = str(e)
            log_item["data"]["processing_time"] = round(time.perf_counter() - process_start, 4)
            ctx.skip_run = True
            error_count += 1
        batch_logs.append(log_item)
    # Adaptive pass: the total pixel count is only known now, so the scale
    # list chosen per image above may be replaced by a lighter variant.
    for ctx in contexts:
        if hasattr(ctx, '_rmbg_meta'):
            size_config_name, size_config = determine_adaptive_size_config(
                ctx._orig_size[0], ctx._orig_size[1], total_surface_area
            )
            ctx._rmbg_meta["size_config"] = size_config_name
            ctx._rmbg_meta["rbc_scales"] = size_config["rbc_scales"]
            logged = ok_logs.get(id(ctx))
            if logged is not None:
                # Keep the log consistent with what inference will really run.
                logged["data"]["size_config"] = size_config_name
                logged["data"]["rbc_scales"] = size_config["rbc_scales"]
    attempted = processed_count + error_count
    preprocess_summary = {
        "function": "preprocess_summary",
        "status": "info",
        "data": {
            "total_time": round(time.perf_counter() - t0, 4),
            "processed_count": processed_count,
            "skipped_count": skipped_count,
            "error_count": error_count,
            "total_surface_area": total_surface_area,
            "surface_area_mb": round(total_surface_area / 1_000_000, 2),
            "success_rate": f"{processed_count/attempted:.2%}" if attempted > 0 else "N/A"
        }
    }
    batch_logs.append(preprocess_summary)
    return batch_logs
def process_background_removal(contexts, batch_logs):
    """Run the RMBG model on every prepared context and attach an RGBA result.

    Contexts are grouped by ``final_side`` and processed in batches of at
    most ``MAX_IMAGES_PER_BATCH``. Each batch is run once per scale in its
    ``rbc_scales``; the per-scale probabilities are combined with fixed
    weights, thresholded (``THRESH``, falling back to ``RESCUE_THRESH`` when
    the mask is empty), smoothed and used as the alpha channel of the padded
    image. The result is stored as
    ``ctx.pil_img = {"original": [rgba, filename, None]}``.

    Args:
        contexts: Processing contexts; only those with ``_rmbg_meta`` that
            are not flagged for skipping are processed.
        batch_logs: List that receives per-batch and summary log entries.

    Returns:
        The same ``batch_logs`` list.

    Raises:
        SystemExit: When the models are not loaded (and ``SPACE_ID`` is not
            set) or when any batch fails during inference/post-processing.
    """
    import app
    import contextlib
    function_name = "process_background_removal"
    from src.models import model_loader
    RMBG_MODEL = model_loader.RMBG_MODEL
    DEVICE = model_loader.DEVICE
    MODELS_LOADED = model_loader.MODELS_LOADED
    LOAD_ERROR = model_loader.LOAD_ERROR
    RMBG_FULL_PRECISION = model_loader.RMBG_FULL_PRECISION
    ENABLE_CUDA_GRAPHS = model_loader.ENABLE_CUDA_GRAPHS
    logging.info(f"Checking model state in {function_name}:")
    logging.info(f" MODELS_LOADED: {MODELS_LOADED}")
    logging.info(f" RMBG_MODEL is None: {RMBG_MODEL is None}")
    logging.info(f" DEVICE: {DEVICE}")
    # Captured CUDA graphs keyed by (batch_size, final_side); local to this call.
    cuda_graph_cache = {}
    # The stream is only used as an "a CUDA stream could be created" flag below.
    stream = None
    if DEVICE == "cuda" and torch.cuda.is_available():
        try:
            stream = torch.cuda.Stream()
        except RuntimeError:
            stream = None
    torch.set_float32_matmul_precision('high')
    if DEVICE == "cuda" and torch.cuda.is_available():
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.enabled = True
            if hasattr(torch.backends.cuda, 'enable_math_sdp'):
                torch.backends.cuda.enable_math_sdp(True)
            if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
                torch.backends.cuda.enable_flash_sdp(True)
            if hasattr(torch.backends.cuda, 'enable_mem_efficient_sdp'):
                torch.backends.cuda.enable_mem_efficient_sdp(True)
            if hasattr(torch.backends.cudnn, 'preferred_backend'):
                torch.backends.cudnn.preferred_backend = 'cudnn'
        except (RuntimeError, AttributeError) as e:
            logging.debug(f"Some CUDA optimizations not available: {e}")
    if not MODELS_LOADED and not os.getenv("SPACE_ID"):
        error_msg = LOAD_ERROR or "Models not loaded"
        # No exception is active here, so traceback.format_exc() would only
        # yield "NoneType: None"; record the current call stack instead.
        error_trace = "".join(traceback.format_stack())
        logging.error(f"CRITICAL: Model not loaded in {function_name}: {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")
        for ctx in contexts:
            ctx.skip_run = True
            ctx.error = error_msg
            ctx.error_traceback = error_trace
        batch_logs.append({
            "function": function_name,
            "status": "critical_error",
            "exception": error_msg,
            "traceback": error_trace
        })
        logging.critical("Terminating due to model loading failure")
        sys.exit(1)
    if RMBG_MODEL is None:
        logging.warning("RMBG model not available - skipping background removal")
        # NOTE(review): preprocess_images_batch deletes _download_content for
        # every image it prepared, so this fallback only reaches contexts
        # that were not preprocessed. It also stores a bare image under
        # "original", whereas the later steps in this file unpack
        # [image, filename, meta] - confirm the intended fallback format.
        for ctx in contexts:
            if hasattr(ctx, '_download_content') and not ctx.skip_run:
                img = Image.open(io.BytesIO(ctx._download_content))
                img = ImageOps.exif_transpose(img).convert("RGB")
                ctx.pil_img["original"] = img
                ctx.pil_img["rmbg"] = img
                ctx.pil_img["rmbg_size"] = img.size
                ctx.color_flags["rmbg"] = "#ffffff"
        batch_logs.append({
            "function": function_name,
            "status": "skipped",
            "data": {
                "reason": "rmbg_model_not_available",
                "message": "Background removal skipped - RMBG model not loaded"
            }
        })
        return batch_logs
    valid_contexts = []
    valid_metas = []
    for ctx in contexts:
        if ctx.skip_run or ctx.skip_processing or not hasattr(ctx, '_rmbg_meta'):
            continue
        valid_contexts.append(ctx)
        valid_metas.append(ctx._rmbg_meta)
    if not valid_contexts:
        batch_logs.append({
            "function": function_name,
            "status": "error",
            "data": {
                "batch_size": 0,
                "device": DEVICE,
                "reason": "no_valid_contexts"
            }
        })
        return batch_logs
    model_device = next(RMBG_MODEL.parameters()).device if RMBG_MODEL else torch.device(DEVICE)
    model_dtype = torch.float16 if not RMBG_FULL_PRECISION and DEVICE == "cuda" else torch.float32
    total_pixels = sum(m['final_side'] ** 2 for m in valid_metas)
    total_megapixels = total_pixels / (1024 * 1024)
    logging.log(LOG_LEVEL_MAP["INFO"],
                f"{EMOJI_MAP['INFO']} Model Configuration:")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Device: {model_device} | Dtype: {model_dtype}")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Matmul precision: high | TF32: {torch.backends.cuda.matmul.allow_tf32 if DEVICE == 'cuda' and torch.cuda.is_available() else 'N/A'}")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Total images: {len(valid_contexts)} | Total megapixels: {total_megapixels:.2f} MP")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Model compiled: {hasattr(RMBG_MODEL, '_dynamo_orig_callable') or hasattr(RMBG_MODEL, 'graph')}")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - CUDA available: {torch.cuda.is_available()} | Current device: {torch.cuda.current_device() if torch.cuda.is_available() else 'N/A'}")
    RMBG_MODEL.eval()
    if not RMBG_FULL_PRECISION and DEVICE == "cuda":
        RMBG_MODEL = RMBG_MODEL.half()
        logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Model converted to FP16 for faster inference")
    if DEVICE == "cuda":
        try:
            RMBG_MODEL = RMBG_MODEL.to(memory_format=torch.channels_last)
            logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Model converted to channels_last memory format")
        except Exception as e:
            logging.debug(f"Could not convert to channels_last: {e}")
    # Group by canvas size so every batch can be stacked into one tensor.
    size_groups = {}
    for ctx, meta in zip(valid_contexts, valid_metas):
        size_key = f"{meta['final_side']}x{meta['final_side']}"
        if size_key not in size_groups:
            size_groups[size_key] = []
        size_groups[size_key].append((ctx, meta))
    logging.log(LOG_LEVEL_MAP["INFO"],
                f"{EMOJI_MAP['INFO']} Image size distribution:")
    for size_key, items in size_groups.items():
        megapixels = (int(size_key.split('x')[0]) ** 2) / (1024 * 1024)
        logging.log(LOG_LEVEL_MAP["INFO"],
                    f" - {size_key}: {len(items)} images ({megapixels:.2f} MP each)")
    total_processed = 0
    batch_number = 1
    cumulative_time = 0
    cumulative_inference_time = 0
    for size_key, group_items in size_groups.items():
        for i in range(0, len(group_items), MAX_IMAGES_PER_BATCH):
            batch_start_time = time.perf_counter()
            batch_items = group_items[i:i + MAX_IMAGES_PER_BATCH]
            batch_contexts = [item[0] for item in batch_items]
            batch_metas = [item[1] for item in batch_items]
            batch_size = len(batch_items)
            # The first item's settings are applied to the whole batch.
            rbc_scales = batch_metas[0]["rbc_scales"]
            final_side = batch_metas[0]["final_side"]
            batch_log_item = {
                "function": f"{function_name}_batch_{batch_number}",
                "status": "processing",
                "data": {
                    "batch_number": batch_number,
                    "batch_size": batch_size,
                    "device": DEVICE,
                    "model_device": str(model_device),
                    "model_dtype": str(model_dtype),
                    "tensor_size": size_key,
                    "scales_processing": rbc_scales,
                    "scale_count": len(rbc_scales),
                    "precision": "fp16" if model_dtype == torch.float16 else "fp32",
                    "total_operations": batch_size * len(rbc_scales),
                    "megapixels_per_image": (final_side ** 2) / (1024 * 1024),
                    "total_megapixels": (batch_size * len(rbc_scales) * final_side ** 2) / (1024 * 1024)
                }
            }
            batch_logs.append(batch_log_item)
            batch_megapixels = (batch_size * final_side ** 2) / (1024 * 1024)
            logging.log(LOG_LEVEL_MAP["PROCESSING"],
                        f"{EMOJI_MAP['PROCESSING']} Batch {batch_number} starting:")
            logging.log(LOG_LEVEL_MAP["PROCESSING"],
                        f" - Images: {batch_size} | Resolution: {final_side}x{final_side} | Scales: {len(rbc_scales)}")
            logging.log(LOG_LEVEL_MAP["PROCESSING"],
                        f" - Total operations: {batch_size * len(rbc_scales)} | Megapixels: {batch_megapixels:.2f} MP")
            try:
                inference_start = time.perf_counter()
                # scale_results[index_in_batch][scale] -> 2-D probability map
                scale_results = {}
                with torch.no_grad():
                    use_amp = not RMBG_FULL_PRECISION and DEVICE == "cuda" and getattr(app, 'USE_MIXED_PRECISION', True)
                    autocast_context = torch.amp.autocast('cuda', dtype=torch.float16) if use_amp else contextlib.nullcontext()
                    with autocast_context:
                        for scale_idx, scale in enumerate(rbc_scales):
                            scale_start = time.perf_counter()
                            logging.log(LOG_LEVEL_MAP["PROCESSING"],
                                        f"{EMOJI_MAP['PROCESSING']} Processing scale {scale}x for {batch_size} images...")
                            scale_tensors = []
                            for ctx_idx, (ctx, meta) in enumerate(batch_items):
                                enhanced_img = meta["enhanced"]
                                if scale == 1.0:
                                    scaled_img = enhanced_img
                                else:
                                    # Resize, then paste centred on a
                                    # final_side canvas: scales < 1 are padded,
                                    # scales > 1 are centre-cropped.
                                    # NOTE(review): the resulting probability
                                    # map is later blended pixel-for-pixel
                                    # with the 1.0x map without being mapped
                                    # back to the 1.0x frame - confirm this
                                    # misalignment is intended.
                                    new_size = int(final_side * scale)
                                    scaled_img = enhanced_img.resize(
                                        (new_size, new_size),
                                        Image.Resampling.LANCZOS
                                    )
                                    if new_size != final_side:
                                        pad_img = Image.new("RGB", (final_side, final_side), PAD_COLOR)
                                        offset = ((final_side - new_size) // 2,
                                                  (final_side - new_size) // 2)
                                        pad_img.paste(scaled_img, offset)
                                        scaled_img = pad_img
                                tensor = rmbg_trans(scaled_img)
                                scale_tensors.append(tensor)
                            batch_tensor = torch.stack(scale_tensors).to(model_device, dtype=model_dtype)
                            if DEVICE == "cuda" and torch.cuda.is_available():
                                try:
                                    batch_tensor = batch_tensor.to(memory_format=torch.channels_last)
                                except RuntimeError:
                                    pass
                            model_start = time.perf_counter()
                            graph_key = (batch_size, final_side)
                            use_cuda_graph = (ENABLE_CUDA_GRAPHS and DEVICE == "cuda" and
                                              torch.cuda.is_available() and not RMBG_FULL_PRECISION)
                            if use_cuda_graph and graph_key in cuda_graph_cache and stream is not None:
                                graph, static_input, static_output = cuda_graph_cache[graph_key]
                                static_input.copy_(batch_tensor)
                                graph.replay()
                                # The captured output may be a list/tuple of
                                # tensors (the eager path below handles that
                                # case too); pick the last tensor before
                                # cloning so replay does not fail on a
                                # non-tensor container.
                                replayed = static_output
                                if isinstance(replayed, (list, tuple)):
                                    replayed = replayed[-1]
                                logits = replayed.clone()
                            else:
                                if use_amp:
                                    logits = RMBG_MODEL(batch_tensor)
                                else:
                                    with torch.amp.autocast('cuda', enabled=False):
                                        logits = RMBG_MODEL(batch_tensor)
                                # Capture a graph for later batches of the
                                # same shape; at most 5 graphs are kept.
                                if use_cuda_graph and len(cuda_graph_cache) < 5 and stream is not None:
                                    try:
                                        torch.cuda.synchronize()
                                        static_input = batch_tensor.clone()
                                        graph = torch.cuda.CUDAGraph()
                                        with torch.cuda.graph(graph):
                                            static_output = RMBG_MODEL(static_input)
                                        cuda_graph_cache[graph_key] = (graph, static_input, static_output)
                                        logging.debug(f"CUDA graph cached for shape {graph_key}")
                                    except (RuntimeError, AttributeError) as e:
                                        logging.debug(f"CUDA graph capture failed: {e}")
                            if DEVICE == "cuda" and torch.cuda.is_available():
                                try:
                                    # Wait for the GPU so model_time is meaningful.
                                    torch.cuda.synchronize()
                                except RuntimeError:
                                    pass
                            model_time = round(time.perf_counter() - model_start, 3)
                            if isinstance(logits, (list, tuple)):
                                logits = logits[-1]
                            if logits.shape[1] != 1:
                                logits = logits[:, 1:2]
                            probs = torch.sigmoid(logits).cpu().float().numpy()[:, 0]
                            for ctx_idx, prob in enumerate(probs):
                                if ctx_idx not in scale_results:
                                    scale_results[ctx_idx] = {}
                                scale_results[ctx_idx][scale] = prob
                            scale_time = round(time.perf_counter() - scale_start, 3)
                            images_per_second = batch_size / model_time if model_time > 0 else 0
                            megapixels_per_second = ((batch_size * final_side ** 2) / (1024 * 1024)) / model_time if model_time > 0 else 0
                            logging.log(LOG_LEVEL_MAP["SUCCESS"],
                                        f"{EMOJI_MAP['SUCCESS']} Scale {scale}x completed:")
                            logging.log(LOG_LEVEL_MAP["INFO"],
                                        f" - Time: {scale_time}s (model: {model_time}s)")
                            logging.log(LOG_LEVEL_MAP["INFO"],
                                        f" - Speed: {images_per_second:.2f} img/s | {megapixels_per_second:.2f} MP/s")
                            del batch_tensor, logits
                            if DEVICE == "cuda" and torch.cuda.is_available() and len(rbc_scales) > 2 and scale_idx < len(rbc_scales) - 1:
                                try:
                                    torch.cuda.empty_cache()
                                except RuntimeError:
                                    pass
                inference_time = round(time.perf_counter() - inference_start, 3)
                avg_time_per_scale = inference_time / len(rbc_scales) if len(rbc_scales) > 0 else 0
                total_operations = batch_size * len(rbc_scales)
                ops_per_second = total_operations / inference_time if inference_time > 0 else 0
                logging.log(LOG_LEVEL_MAP["SUCCESS"],
                            f"{EMOJI_MAP['SUCCESS']} Inference completed:")
                logging.log(LOG_LEVEL_MAP["INFO"],
                            f" - Total time: {inference_time}s | Avg per scale: {avg_time_per_scale:.3f}s")
                logging.log(LOG_LEVEL_MAP["INFO"],
                            f" - Operations: {total_operations} | Speed: {ops_per_second:.2f} ops/s")
                postprocess_start = time.perf_counter()
                logging.log(LOG_LEVEL_MAP["PROCESSING"],
                            f"{EMOJI_MAP['PROCESSING']} Post-processing {batch_size} masks...")
                for ctx_idx, (ctx, meta) in enumerate(batch_items):
                    if ctx_idx in scale_results:
                        scale_probs = []
                        scale_weights = []
                        for scale in sorted(rbc_scales, reverse=True):
                            if scale in scale_results[ctx_idx]:
                                scale_probs.append(scale_results[ctx_idx][scale])
                                # Fixed per-scale weights, normalised below.
                                if scale >= 1.25:
                                    scale_weights.append(0.7)
                                elif scale >= 1.0:
                                    scale_weights.append(0.6)
                                elif scale >= 0.75:
                                    scale_weights.append(0.3)
                                else:
                                    scale_weights.append(0.1)
                        if scale_probs and scale_weights:
                            weights = np.array(scale_weights)
                            weights = weights / weights.sum()
                            combined_prob = np.zeros_like(scale_probs[0])
                            for prob, weight in zip(scale_probs, weights):
                                combined_prob += prob * weight
                        else:
                            combined_prob = np.zeros((final_side, final_side))
                    else:
                        combined_prob = np.zeros((final_side, final_side))
                    mask = (combined_prob > THRESH).astype(np.uint8) * 255
                    rescue_mask = (combined_prob > RESCUE_THRESH).astype(np.uint8) * 255
                    # Fall back to the looser threshold only when the main
                    # threshold kept nothing at all.
                    if np.sum(mask) == 0 and np.sum(rescue_mask) > 0:
                        mask = rescue_mask
                        batch_log_item["data"][f"rescue_applied_{ctx_idx}"] = True
                    mask = smooth_mask(mask)
                    alpha = Image.fromarray(mask, "L")
                    # Alpha is applied to the un-enhanced padded image.
                    r, g, b = meta["pad"].split()
                    rgba = Image.merge("RGBA", [r, g, b, alpha])
                    ctx.pil_img = {"original": [rgba, ctx.filename or "output.webp", None]}
                    total_processed += 1
                postprocess_time = round(time.perf_counter() - postprocess_start, 3)
                batch_processing_time = round(time.perf_counter() - batch_start_time, 4)
                cumulative_time += batch_processing_time
                cumulative_inference_time += inference_time
                batch_images_per_second = batch_size / batch_processing_time if batch_processing_time > 0 else 0
                batch_megapixels = (batch_size * final_side ** 2) / (1024 * 1024)
                batch_mp_per_second = batch_megapixels / batch_processing_time if batch_processing_time > 0 else 0
                batch_log_item["status"] = "completed"
                batch_log_item["data"]["processed_count"] = len(batch_contexts)
                batch_log_item["data"]["processing_time"] = batch_processing_time
                batch_log_item["data"]["inference_time"] = inference_time
                batch_log_item["data"]["postprocess_time"] = postprocess_time
                batch_log_item["data"]["images_per_second"] = batch_images_per_second
                batch_log_item["data"]["megapixels_per_second"] = batch_mp_per_second
                logging.log(LOG_LEVEL_MAP["SUCCESS"],
                            f"{EMOJI_MAP['SUCCESS']} Batch {batch_number} completed:")
                logging.log(LOG_LEVEL_MAP["INFO"],
                            f" - Total time: {batch_processing_time}s")
                logging.log(LOG_LEVEL_MAP["INFO"],
                            f" - Breakdown: inference={inference_time}s, post={postprocess_time}s")
                logging.log(LOG_LEVEL_MAP["INFO"],
                            f" - Performance: {batch_images_per_second:.2f} img/s | {batch_mp_per_second:.2f} MP/s")
                if DEVICE == "cuda" and torch.cuda.is_available():
                    try:
                        allocated_mb = torch.cuda.memory_allocated(0) / (1024 * 1024)
                        reserved_mb = torch.cuda.memory_reserved(0) / (1024 * 1024)
                    except RuntimeError:
                        allocated_mb = 0
                        reserved_mb = 0
                    logging.log(LOG_LEVEL_MAP["INFO"],
                                f" - GPU Memory: allocated={allocated_mb:.1f}MB, reserved={reserved_mb:.1f}MB")
                if DEVICE == "cuda" and torch.cuda.is_available() and batch_number % 5 == 0:
                    try:
                        torch.cuda.empty_cache()
                    except RuntimeError:
                        pass
            except Exception as e:
                # Any failure inside a batch is treated as fatal for the process.
                batch_processing_time = round(time.perf_counter() - batch_start_time, 4)
                error_trace = traceback.format_exc()
                logging.error(f"CRITICAL: Background removal processing failed: {str(e)}")
                logging.error(f"Traceback:\n{error_trace}")
                for ctx in batch_contexts:
                    ctx.skip_run = True
                    ctx.error = str(e)
                    ctx.error_traceback = error_trace
                batch_log_item["status"] = "critical_error"
                batch_log_item["exception"] = str(e)
                batch_log_item["traceback"] = error_trace
                batch_log_item["data"]["processing_time"] = batch_processing_time
                logging.critical("Terminating due to background removal processing failure")
                sys.exit(1)
            batch_number += 1
    overall_images_per_second = total_processed / cumulative_time if cumulative_time > 0 else 0
    overall_mp_per_second = total_megapixels / cumulative_time if cumulative_time > 0 else 0
    avg_time_per_image = cumulative_time / total_processed if total_processed > 0 else 0
    logging.log(LOG_LEVEL_MAP["SUCCESS"],
                f"{EMOJI_MAP['SUCCESS']} {function_name} - Final Summary:")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Total images processed: {total_processed} / {len(valid_contexts)}")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Total time: {cumulative_time:.2f}s | Inference time: {cumulative_inference_time:.2f}s")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Overall performance: {overall_images_per_second:.2f} img/s | {overall_mp_per_second:.2f} MP/s")
    logging.log(LOG_LEVEL_MAP["INFO"],
                f" - Average time per image: {avg_time_per_image:.3f}s")
    batch_logs.append({
        "function": f"{function_name}_summary",
        "status": "info",
        "data": {
            "total_processed": total_processed,
            "total_contexts": len(valid_contexts),
            "total_batches": batch_number - 1,
            "size_groups": {str(k): len(v) for k, v in size_groups.items()},
            "total_time": cumulative_time,
            "total_inference_time": cumulative_inference_time,
            "overall_images_per_second": overall_images_per_second,
            "overall_megapixels_per_second": overall_mp_per_second,
            "average_time_per_image": avg_time_per_image,
            "optimizations_applied": {
                "fp16_inference": not RMBG_FULL_PRECISION and DEVICE == "cuda",
                "matmul_precision": "high",
                "tf32_enabled": torch.backends.cuda.matmul.allow_tf32 if DEVICE == "cuda" and torch.cuda.is_available() else False,
                "cudnn_benchmark": torch.backends.cudnn.benchmark if DEVICE == "cuda" and torch.cuda.is_available() else False
            }
        }
    })
    return batch_logs
def select_largest_component(contexts, batch_logs):
    """Crop each result image to the bounding box of its chosen component.

    The alpha channel is split into connected components. Images with at
    most one component are left untouched. Otherwise components are visited
    in label order and a component replaces the current choice when its
    pixel area is larger AND (it is the first candidate or the centre of its
    bounding box lies in the left half of the image).

    Area is the number of pixels that carry the component's own label.
    Counting every non-zero pixel inside the bounding box, as before, also
    counted pixels of other components that overlap the box.

    Args:
        contexts: Processing contexts whose ``pil_img["original"]`` holds
            ``[rgba_image, filename, meta]``.
        batch_logs: List that receives one error entry per failed image and
            a final counter summary.

    Returns:
        The same ``batch_logs`` list.
    """
    function_name = "select_largest_component"
    processed_count = 0
    skipped_count = 0
    error_count = 0
    for ctx in contexts:
        if ctx.skip_run or ctx.skip_processing:
            skipped_count += 1
            continue
        if "original" not in ctx.pil_img:
            error_count += 1
            continue
        try:
            final_rgba, filename, meta = ctx.pil_img["original"]
            alpha_np = np.array(final_rgba.getchannel("A"))
            labeled, num_components = label(alpha_np > 0)
            if num_components <= 1:
                processed_count += 1
                continue
            regs = find_objects(labeled)
            best_component = None
            best_area = 0
            # find_objects() entry i belongs to label i + 1.
            for label_id, reg in enumerate(regs, start=1):
                if reg is None:
                    continue
                sy, ey = reg[0].start, reg[0].stop
                sx, ex = reg[1].start, reg[1].stop
                area = int(np.count_nonzero(labeled[reg] == label_id))
                if area > best_area:
                    center_x = (sx + ex) / 2.0
                    if center_x < final_rgba.width / 2 or best_area == 0:
                        best_area = area
                        best_component = (sx, sy, ex, ey)
            if best_component:
                sx, sy, ex, ey = best_component
                final_image = final_rgba.crop((sx, sy, ex, ey))
                ctx.pil_img["original"] = [final_image, filename, meta]
            processed_count += 1
        except Exception as e:
            error_count += 1
            ctx.skip_run = True
            # Record the failure, in the same shape final_pad_sq_batch uses,
            # instead of dropping it silently.
            batch_logs.append({
                "function": function_name,
                "image_url": ctx.url,
                "status": "error",
                "exception": str(e)
            })
    batch_logs.append({
        "function": function_name,
        "status": "info",
        "data": {
            "processed_count": processed_count,
            "skipped_count": skipped_count,
            "error_count": error_count
        }
    })
    return batch_logs
def final_pad_sq_batch(contexts, batch_logs):
    """Apply ``final_pad_sq`` to the "original" image of every context.

    Contexts flagged for skipping are counted as skipped; contexts without
    an "original" entry are counted as errors. A failure while squaring an
    image is logged with the image URL and counted as an error.

    Returns:
        ``batch_logs`` with any error entries and a counter summary appended.
    """
    function_name = "final_pad_sq_batch"
    counters = {
        "processed_count": 0,
        "skipped_count": 0,
        "error_count": 0
    }
    for context in contexts:
        if context.skip_run or context.skip_processing:
            counters["skipped_count"] += 1
        elif "original" not in context.pil_img:
            counters["error_count"] += 1
        else:
            try:
                image, filename, meta = context.pil_img["original"]
                squared = final_pad_sq(image)
                context.pil_img["original"] = [squared, filename, meta]
                counters["processed_count"] += 1
            except Exception as exc:
                counters["error_count"] += 1
                error_entry = {
                    "function": function_name,
                    "image_url": context.url,
                    "status": "error",
                    "exception": str(exc)
                }
                batch_logs.append(error_entry)
    summary_entry = {
        "function": function_name,
        "status": "info",
        "data": counters
    }
    batch_logs.append(summary_entry)
    return batch_logs
# ----------------------------------------------------------------------
# MAIN PIPELINE FUNCTION
# ----------------------------------------------------------------------
def _ensure_models_loaded():
    """Call ``app.ensure_models_loaded()``.

    ``app`` is imported inside the function rather than at module level.
    """
    import app
    app.ensure_models_loaded()
# Decorator built from the loader hook above and applied to remove_background.
# NOTE(review): presumably create_pipeline_step runs the hook before the
# wrapped step - confirm against src.utils.
pipeline_step = create_pipeline_step(_ensure_models_loaded)
@pipeline_step
def remove_background(
    contexts: List[ProcessingContext],
    batch_logs: List[dict] | None = None
):
    """Run the full background-removal pipeline on a list of contexts.

    Stages, in order: ``preprocess_images_batch``,
    ``process_background_removal``, ``select_largest_component`` and
    ``final_pad_sq_batch``. Calibration settings and batch statistics are
    appended to ``batch_logs`` along the way.

    Args:
        contexts: Processing contexts to run through the pipeline.
        batch_logs: Optional list to append log entries to; a new list is
            created when omitted.

    Returns:
        The list of log entries.
    """
    if batch_logs is None:
        batch_logs = []
    pipeline_start_time = time.perf_counter()
    calibration_info = {
        "function": "calibration_info",
        "status": "info",
        "data": {
            "version": CALIBRATION_VERSION,
            "contrast_factor": RBC_CONTRAST_FACTOR,
            "sharpness_factor": RBC_SHARPNESS_FACTOR,
            "size_configs": SIZE_CONFIGS,
            "adaptive_scale_configs": ADAPTIVE_SCALE_CONFIGS,
            "threshold": THRESH,
            "max_images_per_batch": MAX_IMAGES_PER_BATCH,
            "morph_kernel_size": MORPH_KERNEL_SIZE,
            "morph_close_iter": MORPH_CLOSE_ITER,
            "morph_open_iter": MORPH_OPEN_ITER,
            "erosion_iter": EROSION_ITER,
            "gaussian_kernel_size": GAUSSIAN_KERNEL_SIZE,
            "do_guided_filter": DO_GUIDED_FILTER,
            "fill_holes": FILL_HOLES,
            "use_bilateral": USE_BILATERAL,
            "low_load_surface_threshold": LOW_LOAD_SURFACE_THRESHOLD,
            "medium_load_surface_threshold": MEDIUM_LOAD_SURFACE_THRESHOLD,
            "high_load_surface_threshold": HIGH_LOAD_SURFACE_THRESHOLD
        }
    }
    batch_logs.append(calibration_info)
    logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Starting remove_background pipeline for {len(contexts)} items")
    valid_contexts = [
        ctx for ctx in contexts
        if not (ctx.skip_run or ctx.skip_processing) and hasattr(ctx, '_download_content')
    ]
    # Estimate only: the real batching also groups images by canvas size.
    total_batches = (len(valid_contexts) + MAX_IMAGES_PER_BATCH - 1) // MAX_IMAGES_PER_BATCH if valid_contexts else 0
    # Adaptive scaling is driven by total pixel surface, not by image count.
    # The former check compared len(valid_contexts) with a threshold of
    # 25,000,000 pixels and was therefore always False.
    estimated_surface_area = calculate_total_surface_area(valid_contexts)
    batch_processing_log = {
        "function": "batch_processing_info",
        "status": "info",
        "data": {
            "total_contexts": len(contexts),
            "valid_contexts": len(valid_contexts),
            "max_images_per_batch": MAX_IMAGES_PER_BATCH,
            "estimated_batches": total_batches,
            "adaptive_scaling_enabled": estimated_surface_area >= MEDIUM_LOAD_SURFACE_THRESHOLD,
            "next_step": "preprocess_images_batch"
        }
    }
    batch_logs.append(batch_processing_log)
    logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Processing {len(valid_contexts)} valid contexts in {total_batches} batches")
    preprocess_images_batch(contexts, batch_logs)
    process_background_removal(contexts, batch_logs)
    batch_completion_log = {
        "function": "background_removal_completed",
        "status": "info",
        "data": {
            "completed_contexts": sum(1 for ctx in contexts if "original" in ctx.pil_img),
            "total_contexts": len(contexts),
            "next_step": "select_largest_component"
        }
    }
    batch_logs.append(batch_completion_log)
    select_largest_component(contexts, batch_logs)
    final_pad_sq_batch(contexts, batch_logs)
    total_pipeline_time = round(time.perf_counter() - pipeline_start_time, 4)
    logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} Completed remove_background pipeline for {len(contexts)} items in {total_pipeline_time}s")
    return batch_logs