# ----------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------
import io
import os
import time
import numpy as np
import torch
import cv2
import logging
import sys
import traceback
from PIL import Image, ImageOps, ImageEnhance
from torchvision import transforms
from scipy.ndimage import label as scipy_label, find_objects as scipy_find_objects
from typing import List

from src.utils import ProcessingContext, create_pipeline_step, LOG_LEVEL_MAP, EMOJI_MAP

# ----------------------------------------------------------------------
# GLOBAL CONSTANTS
# ----------------------------------------------------------------------
RBC_CONTRAST_FACTOR = 1.25
RBC_SHARPNESS_FACTOR = 1.15
PAD_COLOR = "#ffffff"
THRESH = 0.42
RESCUE_THRESH = 0.20
MAX_IMAGES_PER_BATCH = 4

MORPH_KERNEL_SIZE = (3, 3)
MORPH_CLOSE_ITER = 1
MORPH_OPEN_ITER = 1
EROSION_ITER = 1
GAUSSIAN_KERNEL_SIZE = (7, 7)
DO_GUIDED_FILTER = True
FILL_HOLES = False
USE_BILATERAL = True

CALIBRATION_VERSION = "v16_balanced"

BLUE_BACKGROUND_HSV_LOWER = np.array([100, 100, 80])
BLUE_BACKGROUND_HSV_UPPER = np.array([130, 255, 255])
BLUE_BACKGROUND_THRESHOLD = 0.25

SKIN_HSV_LOWER_1 = np.array([0, 20, 70])
SKIN_HSV_UPPER_1 = np.array([30, 180, 255])
SKIN_HSV_LOWER_2 = np.array([0, 10, 40])
SKIN_HSV_UPPER_2 = np.array([25, 150, 200])
SKIN_THRESHOLD = 0.10

DENIM_HSV_LOWER_1 = np.array([80, 30, 30])
DENIM_HSV_UPPER_1 = np.array([130, 220, 230])
DENIM_HSV_LOWER_2 = np.array([85, 15, 100])
DENIM_HSV_UPPER_2 = np.array([115, 130, 230])
DENIM_THRESHOLD = 0.12

OUTDOOR_GREEN_HSV_LOWER = np.array([35, 40, 40])
OUTDOOR_GREEN_HSV_UPPER = np.array([85, 255, 255])
OUTDOOR_SKY_HSV_LOWER = np.array([90, 40, 150])
OUTDOOR_SKY_HSV_UPPER = np.array([130, 120, 255])
OUTDOOR_THRESHOLD = 0.15

DIFFICULTY_CONTRAST_THRESHOLD = 50
DIFFICULTY_EDGE_THRESHOLD = 0.05
EDGE_ENHANCE_WEIGHT_ORIG = 0.7
EDGE_ENHANCE_WEIGHT_SHARP = 0.3

GUIDED_FILTER_RADIUS = 4
GUIDED_FILTER_EPSILON = 0.01
BILATERAL_FILTER_D = 9
BILATERAL_FILTER_SIGMA_COLOR = 75
BILATERAL_FILTER_SIGMA_SPACE = 75
CANNY_THRESHOLD_LOW = 100
CANNY_THRESHOLD_HIGH = 200

TARGET_SIZE_SCALING_FACTOR = 0.98
ASPECT_RATIO_THRESHOLD = 4 / 3
PADDING_SCALE_FACTOR = 0.98
OUTDOOR_CONTRAST_MULTIPLIER = 1.1
PERSON_CONTRAST_MULTIPLIER = 0.75

LOW_LOAD_SURFACE_THRESHOLD = 12_000_000
MEDIUM_LOAD_SURFACE_THRESHOLD = 25_000_000
HIGH_LOAD_SURFACE_THRESHOLD = 45_000_000

# ----------------------------------------------------------------------
# DYNAMIC SIZING CONFIGURATION
# ----------------------------------------------------------------------
SIZE_CONFIGS = {
    "small": {"final_side": 1024, "rbc_scales": [1.25, 1.0]},
    "medium": {"final_side": 1536, "rbc_scales": [1.0, 0.75]},
    "large": {"final_side": 2048, "rbc_scales": [1.0]}
}

ADAPTIVE_SCALE_CONFIGS = {
    "small_fast": {"final_side": 1024, "rbc_scales": [1.0]},
    "small_minimal": {"final_side": 1024, "rbc_scales": [1.0]},
    "medium_fast": {"final_side": 1536, "rbc_scales": [1.0]},
    "medium_minimal": {"final_side": 1536, "rbc_scales": [1.0]},
    "large_fast": {"final_side": 2048, "rbc_scales": [1.0]},
    "large_minimal": {"final_side": 2048, "rbc_scales": [1.0]}
}

SIZE_CONFIG_SCALE = [1024, 1536]
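
# Illustrative examples (not executed) of how the sizing tables above are applied:
# a 1600x1200 source has max_dim = 1600 > SIZE_CONFIG_SCALE[1] (1536), so it maps to
# the "large" config: a 2048 px canvas run at a single 1.0x scale. A 1280x960 source
# (max_dim 1280, between 1024 and 1536) maps to "medium": a 1536 px canvas with
# scales [1.0, 0.75]. Under heavy batch load the adaptive "_fast"/"_minimal" variants
# keep the same canvas size but drop to one scale.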

# ----------------------------------------------------------------------
# PRE-COMPUTED OBJECTS
# ----------------------------------------------------------------------
morph_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, MORPH_KERNEL_SIZE)

rmbg_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

edge_enhance_kernel = np.array([[-1, -1, -1],
                                [-1, 9, -1],
                                [-1, -1, -1]])

# GPU-accelerated normalization if available
try:
    import torchvision.transforms.v2 as v2_transforms
    USE_V2_TRANSFORMS = True
except ImportError:
    USE_V2_TRANSFORMS = False


# ----------------------------------------------------------------------
# UTILS
# ----------------------------------------------------------------------
def label(mask):
    return scipy_label(mask)


def find_objects(lbl):
    return scipy_find_objects(lbl)


def guided_filter(I, p, r, eps):
    mean_I = cv2.boxFilter(I, cv2.CV_64F, (r, r))
    mean_p = cv2.boxFilter(p, cv2.CV_64F, (r, r))
    mean_Ip = cv2.boxFilter(I * p, cv2.CV_64F, (r, r))
    cov_Ip = mean_Ip - mean_I * mean_p
    mean_II = cv2.boxFilter(I * I, cv2.CV_64F, (r, r))
    var_I = mean_II - mean_I * mean_I
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    mean_a = cv2.boxFilter(a, cv2.CV_64F, (r, r))
    mean_b = cv2.boxFilter(b, cv2.CV_64F, (r, r))
    return mean_a * I + mean_b


def detect_blue_background(hsv):
    blue_mask = cv2.inRange(hsv, BLUE_BACKGROUND_HSV_LOWER, BLUE_BACKGROUND_HSV_UPPER)
    return float(np.mean(blue_mask > 0)) > BLUE_BACKGROUND_THRESHOLD


def detect_skin_tones(hsv):
    m1 = cv2.inRange(hsv, SKIN_HSV_LOWER_1, SKIN_HSV_UPPER_1)
    m2 = cv2.inRange(hsv, SKIN_HSV_LOWER_2, SKIN_HSV_UPPER_2)
    return float(np.mean(cv2.bitwise_or(m1, m2) > 0)) > SKIN_THRESHOLD


def detect_denim(hsv):
    m1 = cv2.inRange(hsv, DENIM_HSV_LOWER_1, DENIM_HSV_UPPER_1)
    m2 = cv2.inRange(hsv, DENIM_HSV_LOWER_2, DENIM_HSV_UPPER_2)
    return float(np.mean(cv2.bitwise_or(m1, m2) > 0)) > DENIM_THRESHOLD


def detect_outdoor_scene(hsv):
    g = cv2.inRange(hsv, OUTDOOR_GREEN_HSV_LOWER, OUTDOOR_GREEN_HSV_UPPER)
    s = cv2.inRange(hsv, OUTDOOR_SKY_HSV_LOWER, OUTDOOR_SKY_HSV_UPPER)
    return float(np.mean(cv2.bitwise_or(g, s) > 0)) > OUTDOOR_THRESHOLD


def analyze_image_difficulty(gray):
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
    hist /= hist.sum()
    contrast = float(np.std(hist * np.arange(256)))
    edge_ratio = float(np.mean(cv2.Canny(gray, CANNY_THRESHOLD_LOW, CANNY_THRESHOLD_HIGH) > 0))
    return bool(contrast < DIFFICULTY_CONTRAST_THRESHOLD and edge_ratio < DIFFICULTY_EDGE_THRESHOLD), contrast, edge_ratio


def enhance_edges(img_rgb):
    sharp = cv2.filter2D(img_rgb, -1, edge_enhance_kernel)
    return cv2.addWeighted(img_rgb, EDGE_ENHANCE_WEIGHT_ORIG, sharp, EDGE_ENHANCE_WEIGHT_SHARP, 0)


def smooth_mask(mask):
    if DO_GUIDED_FILTER:
        f = guided_filter(
            mask.astype(np.float64) / 255.0,
            mask.astype(np.float64) / 255.0,
            GUIDED_FILTER_RADIUS,
            GUIDED_FILTER_EPSILON
        )
        mask = (f * 255).clip(0, 255).astype(np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, morph_kernel, iterations=MORPH_CLOSE_ITER)
    if MORPH_OPEN_ITER:
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, morph_kernel, iterations=MORPH_OPEN_ITER)
    if EROSION_ITER:
        mask = cv2.erode(mask, morph_kernel, iterations=EROSION_ITER)
    if FILL_HOLES:
        contours, _ = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            cv2.drawContours(mask, [cnt], -1, 255, thickness=cv2.FILLED)
    if USE_BILATERAL:
        mask = cv2.bilateralFilter(mask, BILATERAL_FILTER_D, BILATERAL_FILTER_SIGMA_COLOR, BILATERAL_FILTER_SIGMA_SPACE)
    return cv2.GaussianBlur(mask, GAUSSIAN_KERNEL_SIZE, 0)


def calculate_total_surface_area(contexts):
    total_surface = 0
    for ctx in contexts:
        if ctx.skip_run or ctx.skip_processing or not hasattr(ctx, '_download_content'):
            continue
        try:
            img = Image.open(io.BytesIO(ctx._download_content))
            w, h = img.size
            total_surface += w * h
        except:
            continue
    return total_surface
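
# Worked example (illustrative only) for the adaptive selection below: with
# total_surface_area = 30_000_000 px (above MEDIUM_LOAD_SURFACE_THRESHOLD, below
# HIGH_LOAD_SURFACE_THRESHOLD), a 1400x1050 image resolves to base_config "medium"
# and config_key "medium_fast", i.e. a 1536 px canvas processed at the single 1.0x scale.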

def determine_adaptive_size_config(width, height, total_surface_area):
    max_dim = max(width, height)
    if max_dim < SIZE_CONFIG_SCALE[0]:
        base_config = "small"
    elif SIZE_CONFIG_SCALE[0] <= max_dim <= SIZE_CONFIG_SCALE[1]:
        base_config = "medium"
    else:
        base_config = "large"
    if total_surface_area >= HIGH_LOAD_SURFACE_THRESHOLD:
        config_key = f"{base_config}_minimal"
    elif total_surface_area >= MEDIUM_LOAD_SURFACE_THRESHOLD:
        config_key = f"{base_config}_fast"
    else:
        config_key = base_config
    if config_key in ADAPTIVE_SCALE_CONFIGS:
        return config_key, ADAPTIVE_SCALE_CONFIGS[config_key]
    else:
        return base_config, SIZE_CONFIGS[base_config]


def determine_size_config(width, height):
    max_dim = max(width, height)
    if max_dim < SIZE_CONFIG_SCALE[0]:
        return "small", SIZE_CONFIGS["small"]
    elif SIZE_CONFIG_SCALE[0] <= max_dim <= SIZE_CONFIG_SCALE[1]:
        return "medium", SIZE_CONFIGS["medium"]
    else:
        return "large", SIZE_CONFIGS["large"]


def calculate_optimal_scale(original_width, original_height, target_side):
    max_original = max(original_width, original_height)
    if max_original <= target_side:
        return 1.0
    return (target_side * TARGET_SIZE_SCALING_FACTOR) / max_original


def final_pad_sq(im):
    w, h = im.size
    if w / h > ASPECT_RATIO_THRESHOLD:
        side = int(round(h * PADDING_SCALE_FACTOR))
        l = (w - side) // 2
        return im.crop((l, (h - side) // 2, l + side, (h + side) // 2))
    side = max(w, h)
    new_im = Image.new("RGBA", (side, side), (0, 0, 0, 0))
    new_im.paste(im, ((side - w) // 2, (side - h) // 2))
    return new_im


# ----------------------------------------------------------------------
# CORE IMPLEMENTATION
# ----------------------------------------------------------------------
def preprocess_images_batch(contexts, batch_logs):
    function_name = "preprocess_images_batch"
    processed_count = 0
    skipped_count = 0
    error_count = 0
    total_surface_area = 0
    t0 = time.perf_counter()
    for ctx in contexts:
        log_item = {
            "image_url": ctx.url,
            "function": function_name,
            "data": {}
        }
        if ctx.skip_run or ctx.skip_processing:
            log_item["status"] = "skipped"
            log_item["data"]["reason"] = "marked_for_skip_or_no_downloaded_image"
            batch_logs.append(log_item)
            skipped_count += 1
            continue
        if not hasattr(ctx, '_download_content'):
            log_item["status"] = "error"
            log_item["exception"] = "No downloaded content found"
            log_item["data"]["reason"] = "missing_download_content"
            batch_logs.append(log_item)
            ctx.skip_run = True
            error_count += 1
            continue
        try:
            process_start = time.perf_counter()
            img = Image.open(io.BytesIO(ctx._download_content))
            img = ImageOps.exif_transpose(img).convert("RGB")
            ow, oh = img.size
            ctx._orig_size = (ow, oh)
            total_surface_area += ow * oh
            size_config_name, size_config = determine_size_config(ow, oh)
            final_side = size_config["final_side"]
            rbc_scales = size_config["rbc_scales"]
            scale = calculate_optimal_scale(ow, oh, final_side)
            if scale < 1:
                img = img.resize((int(ow * scale), int(oh * scale)), Image.Resampling.LANCZOS)
                ow, oh = img.size
            pad = Image.new("RGB", (final_side, final_side), PAD_COLOR)
            pad.paste(img, ((final_side - ow) // 2, (final_side - oh) // 2))
            hsv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2HSV)
            gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
            has_blue = detect_blue_background(hsv)
            has_person = detect_skin_tones(hsv)
            has_denim = detect_denim(hsv)
            is_outdoor = detect_outdoor_scene(hsv)
            is_diff, contrast, edge_ratio = analyze_image_difficulty(gray)
            contrast_multiplier = 1.0
            if is_outdoor:
                contrast_multiplier = OUTDOOR_CONTRAST_MULTIPLIER
            elif has_person:
                contrast_multiplier = PERSON_CONTRAST_MULTIPLIER
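            # Scene-aware enhancement: "difficult" images (low contrast, few edges) use the
            # edge-enhanced copy built below, while all other images get the contrast and
            # sharpness boost, attenuated for people and slightly amplified for outdoor scenes.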
            enhanced = ImageEnhance.Contrast(pad).enhance(RBC_CONTRAST_FACTOR * contrast_multiplier)
            enhanced = ImageEnhance.Sharpness(enhanced).enhance(RBC_SHARPNESS_FACTOR)
            edge_enh = Image.fromarray(enhance_edges(np.array(pad)))
            final_enhanced = edge_enh if is_diff else enhanced
            ctx._rmbg_meta = {
                "pad": pad,
                "enhanced": final_enhanced,
                "flags": (has_blue, has_person, has_denim, is_outdoor, is_diff),
                "size_config": size_config_name,
                "final_side": final_side,
                "rbc_scales": rbc_scales
            }
            process_time = time.perf_counter() - process_start
            log_item["status"] = "ok"
            log_item["data"].update({
                "orig_size": ctx._orig_size,
                "processed_size": (final_side, final_side),
                "size_config": size_config_name,
                "scale_applied": round(scale, 4),
                "contrast": contrast,
                "edge_ratio": edge_ratio,
                "processing_time": round(process_time, 4),
                "rbc_scales": rbc_scales,
                "flags": {
                    "has_blue": has_blue,
                    "has_person": has_person,
                    "has_denim": has_denim,
                    "is_outdoor": is_outdoor,
                    "is_difficult": is_diff
                }
            })
            processed_count += 1
            del ctx._download_content
        except Exception as e:
            log_item["status"] = "error"
            log_item["exception"] = str(e)
            log_item["data"]["processing_time"] = round(time.perf_counter() - process_start, 4) if 'process_start' in locals() else 0
            ctx.skip_run = True
            error_count += 1
        batch_logs.append(log_item)

    for ctx in contexts:
        if hasattr(ctx, '_rmbg_meta'):
            size_config_name, size_config = determine_adaptive_size_config(
                ctx._orig_size[0], ctx._orig_size[1], total_surface_area
            )
            ctx._rmbg_meta["size_config"] = size_config_name
            ctx._rmbg_meta["rbc_scales"] = size_config["rbc_scales"]

    preprocess_summary = {
        "function": "preprocess_summary",
        "status": "info",
        "data": {
            "total_time": round(time.perf_counter() - t0, 4),
            "processed_count": processed_count,
            "skipped_count": skipped_count,
            "error_count": error_count,
            "total_surface_area": total_surface_area,
            "surface_area_mb": round(total_surface_area / 1_000_000, 2),
            "success_rate": f"{processed_count/(processed_count+error_count):.2%}" if (processed_count + error_count) > 0 else "N/A"
        }
    }
    batch_logs.append(preprocess_summary)
    return batch_logs
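
# The inference stage below (process_background_removal) consumes the _rmbg_meta prepared
# above: images are grouped by padded resolution, run through the RMBG model in batches of
# at most MAX_IMAGES_PER_BATCH at one or more scales, and the per-scale sigmoid maps are
# fused into a single alpha mask per image.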

def process_background_removal(contexts, batch_logs):
    import app
    import contextlib

    function_name = "process_background_removal"

    from src.models import model_loader
    RMBG_MODEL = model_loader.RMBG_MODEL
    DEVICE = model_loader.DEVICE
    MODELS_LOADED = model_loader.MODELS_LOADED
    LOAD_ERROR = model_loader.LOAD_ERROR
    RMBG_FULL_PRECISION = model_loader.RMBG_FULL_PRECISION
    ENABLE_CUDA_GRAPHS = model_loader.ENABLE_CUDA_GRAPHS

    logging.info(f"Checking model state in {function_name}:")
    logging.info(f" MODELS_LOADED: {MODELS_LOADED}")
    logging.info(f" RMBG_MODEL is None: {RMBG_MODEL is None}")
    logging.info(f" DEVICE: {DEVICE}")

    cuda_graph_cache = {}
    stream = None
    if DEVICE == "cuda" and torch.cuda.is_available():
        try:
            stream = torch.cuda.Stream()
        except RuntimeError:
            stream = None

    torch.set_float32_matmul_precision('high')
    if DEVICE == "cuda" and torch.cuda.is_available():
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.enabled = True
            if hasattr(torch.backends.cuda, 'enable_math_sdp'):
                torch.backends.cuda.enable_math_sdp(True)
            if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
                torch.backends.cuda.enable_flash_sdp(True)
            if hasattr(torch.backends.cuda, 'enable_mem_efficient_sdp'):
                torch.backends.cuda.enable_mem_efficient_sdp(True)
            if hasattr(torch.backends.cudnn, 'preferred_backend'):
                torch.backends.cudnn.preferred_backend = 'cudnn'
        except (RuntimeError, AttributeError) as e:
            logging.debug(f"Some CUDA optimizations not available: {e}")

    if not MODELS_LOADED and not os.getenv("SPACE_ID"):
        error_msg = LOAD_ERROR or "Models not loaded"
        error_trace = traceback.format_exc()
        logging.error(f"CRITICAL: Model not loaded in {function_name}: {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")
        for ctx in contexts:
            ctx.skip_run = True
            ctx.error = error_msg
            ctx.error_traceback = error_trace
        batch_logs.append({
            "function": function_name,
            "status": "critical_error",
            "exception": error_msg,
            "traceback": error_trace
        })
        logging.critical("Terminating due to model loading failure")
        sys.exit(1)

    if RMBG_MODEL is None:
        logging.warning("RMBG model not available - skipping background removal")
        for ctx in contexts:
            if hasattr(ctx, '_download_content') and not ctx.skip_run:
                img = Image.open(io.BytesIO(ctx._download_content))
                img = ImageOps.exif_transpose(img).convert("RGB")
                ctx.pil_img["original"] = img
                ctx.pil_img["rmbg"] = img
                ctx.pil_img["rmbg_size"] = img.size
                ctx.color_flags["rmbg"] = "#ffffff"
        batch_logs.append({
            "function": function_name,
            "status": "skipped",
            "data": {
                "reason": "rmbg_model_not_available",
                "message": "Background removal skipped - RMBG model not loaded"
            }
        })
        return batch_logs

    valid_contexts = []
    valid_metas = []
    for ctx in contexts:
        if ctx.skip_run or ctx.skip_processing or not hasattr(ctx, '_rmbg_meta'):
            continue
        valid_contexts.append(ctx)
        valid_metas.append(ctx._rmbg_meta)

    if not valid_contexts:
        batch_logs.append({
            "function": function_name,
            "status": "error",
            "data": {
                "batch_size": 0,
                "device": DEVICE,
                "reason": "no_valid_contexts"
            }
        })
        return batch_logs

    model_device = next(RMBG_MODEL.parameters()).device if RMBG_MODEL else torch.device(DEVICE)
    model_dtype = torch.float16 if not RMBG_FULL_PRECISION and DEVICE == "cuda" else torch.float32
    total_pixels = sum(m['final_side'] ** 2 for m in valid_metas)
    total_megapixels = total_pixels / (1024 * 1024)

    logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Model Configuration:")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Device: {model_device} | Dtype: {model_dtype}")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Matmul precision: high | TF32: {torch.backends.cuda.matmul.allow_tf32 if DEVICE == 'cuda' and torch.cuda.is_available() else 'N/A'}")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Total images: {len(valid_contexts)} | Total megapixels: {total_megapixels:.2f} MP")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Model compiled: {hasattr(RMBG_MODEL, '_dynamo_orig_callable') or hasattr(RMBG_MODEL, 'graph')}")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - CUDA available: {torch.cuda.is_available()} | Current device: {torch.cuda.current_device() if torch.cuda.is_available() else 'N/A'}")

    RMBG_MODEL.eval()
    if not RMBG_FULL_PRECISION and DEVICE == "cuda":
        RMBG_MODEL = RMBG_MODEL.half()
        logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Model converted to FP16 for faster inference")
    if DEVICE == "cuda":
        try:
            RMBG_MODEL = RMBG_MODEL.to(memory_format=torch.channels_last)
            logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Model converted to channels_last memory format")
        except Exception as e:
            logging.debug(f"Could not convert to channels_last: {e}")
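
    # Group images by padded resolution so each forward pass can torch.stack() tensors of
    # identical shape; mixing 1024/1536/2048 canvases in a single batch would not stack.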
    size_groups = {}
    for ctx, meta in zip(valid_contexts, valid_metas):
        size_key = f"{meta['final_side']}x{meta['final_side']}"
        if size_key not in size_groups:
            size_groups[size_key] = []
        size_groups[size_key].append((ctx, meta))

    logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Image size distribution:")
    for size_key, items in size_groups.items():
        megapixels = (int(size_key.split('x')[0]) ** 2) / (1024 * 1024)
        logging.log(LOG_LEVEL_MAP["INFO"], f" - {size_key}: {len(items)} images ({megapixels:.2f} MP each)")

    total_processed = 0
    batch_number = 1
    cumulative_time = 0
    cumulative_inference_time = 0

    for size_key, group_items in size_groups.items():
        for i in range(0, len(group_items), MAX_IMAGES_PER_BATCH):
            batch_start_time = time.perf_counter()
            batch_items = group_items[i:i + MAX_IMAGES_PER_BATCH]
            batch_contexts = [item[0] for item in batch_items]
            batch_metas = [item[1] for item in batch_items]
            batch_size = len(batch_items)
            rbc_scales = batch_metas[0]["rbc_scales"]
            final_side = batch_metas[0]["final_side"]
            batch_log_item = {
                "function": f"{function_name}_batch_{batch_number}",
                "status": "processing",
                "data": {
                    "batch_number": batch_number,
                    "batch_size": batch_size,
                    "device": DEVICE,
                    "model_device": str(model_device),
                    "model_dtype": str(model_dtype),
                    "tensor_size": size_key,
                    "scales_processing": rbc_scales,
                    "scale_count": len(rbc_scales),
                    "precision": "fp16" if model_dtype == torch.float16 else "fp32",
                    "total_operations": batch_size * len(rbc_scales),
                    "megapixels_per_image": (final_side ** 2) / (1024 * 1024),
                    "total_megapixels": (batch_size * len(rbc_scales) * final_side ** 2) / (1024 * 1024)
                }
            }
            batch_logs.append(batch_log_item)
            batch_megapixels = (batch_size * final_side ** 2) / (1024 * 1024)
            logging.log(LOG_LEVEL_MAP["PROCESSING"], f"{EMOJI_MAP['PROCESSING']} Batch {batch_number} starting:")
            logging.log(LOG_LEVEL_MAP["PROCESSING"], f" - Images: {batch_size} | Resolution: {final_side}x{final_side} | Scales: {len(rbc_scales)}")
            logging.log(LOG_LEVEL_MAP["PROCESSING"], f" - Total operations: {batch_size * len(rbc_scales)} | Megapixels: {batch_megapixels:.2f} MP")
            try:
                inference_start = time.perf_counter()
                scale_results = {}
                with torch.no_grad():
                    use_amp = not RMBG_FULL_PRECISION and DEVICE == "cuda" and getattr(app, 'USE_MIXED_PRECISION', True)
                    autocast_context = torch.amp.autocast('cuda', dtype=torch.float16) if use_amp else contextlib.nullcontext()
                    with autocast_context:
                        for scale_idx, scale in enumerate(rbc_scales):
                            scale_start = time.perf_counter()
                            logging.log(LOG_LEVEL_MAP["PROCESSING"], f"{EMOJI_MAP['PROCESSING']} Processing scale {scale}x for {batch_size} images...")
                            scale_tensors = []
                            for ctx_idx, (ctx, meta) in enumerate(batch_items):
                                enhanced_img = meta["enhanced"]
                                if scale == 1.0:
                                    scaled_img = enhanced_img
                                else:
                                    new_size = int(final_side * scale)
                                    scaled_img = enhanced_img.resize(
                                        (new_size, new_size), Image.Resampling.LANCZOS
                                    )
                                    if new_size != final_side:
                                        pad_img = Image.new("RGB", (final_side, final_side), PAD_COLOR)
                                        offset = ((final_side - new_size) // 2, (final_side - new_size) // 2)
                                        pad_img.paste(scaled_img, offset)
                                        scaled_img = pad_img
                                tensor = rmbg_trans(scaled_img)
                                scale_tensors.append(tensor)
                            batch_tensor = torch.stack(scale_tensors).to(model_device, dtype=model_dtype)
                            if DEVICE == "cuda" and torch.cuda.is_available():
                                try:
                                    batch_tensor = batch_tensor.to(memory_format=torch.channels_last)
                                except RuntimeError:
                                    pass
                            model_start = time.perf_counter()
                            graph_key = (batch_size, final_side)
                            use_cuda_graph = (ENABLE_CUDA_GRAPHS and DEVICE == "cuda"
                                              and torch.cuda.is_available() and not RMBG_FULL_PRECISION)
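                            # CUDA graph fast path: once a (batch_size, resolution) shape has been
                            # captured, later batches with that shape copy their input into the
                            # captured static tensor and replay the recorded graph instead of
                            # re-launching kernels; the first occurrence of a shape runs normally
                            # and is captured afterwards (at most 5 cached shapes).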
                            if use_cuda_graph and graph_key in cuda_graph_cache and stream is not None:
                                graph, static_input, static_output = cuda_graph_cache[graph_key]
                                static_input.copy_(batch_tensor)
                                graph.replay()
                                logits = static_output.clone()
                            else:
                                if use_amp:
                                    logits = RMBG_MODEL(batch_tensor)
                                else:
                                    with torch.amp.autocast('cuda', enabled=False):
                                        logits = RMBG_MODEL(batch_tensor)
                                if use_cuda_graph and len(cuda_graph_cache) < 5 and stream is not None:
                                    try:
                                        torch.cuda.synchronize()
                                        static_input = batch_tensor.clone()
                                        graph = torch.cuda.CUDAGraph()
                                        with torch.cuda.graph(graph):
                                            static_output = RMBG_MODEL(static_input)
                                        cuda_graph_cache[graph_key] = (graph, static_input, static_output)
                                        logging.debug(f"CUDA graph cached for shape {graph_key}")
                                    except (RuntimeError, AttributeError) as e:
                                        logging.debug(f"CUDA graph capture failed: {e}")
                            if DEVICE == "cuda" and torch.cuda.is_available():
                                try:
                                    torch.cuda.synchronize()
                                except RuntimeError:
                                    pass
                            model_time = round(time.perf_counter() - model_start, 3)
                            if isinstance(logits, (list, tuple)):
                                logits = logits[-1]
                            if logits.shape[1] != 1:
                                logits = logits[:, 1:2]
                            probs = torch.sigmoid(logits).cpu().float().numpy()[:, 0]
                            for ctx_idx, prob in enumerate(probs):
                                if ctx_idx not in scale_results:
                                    scale_results[ctx_idx] = {}
                                scale_results[ctx_idx][scale] = prob
                            scale_time = round(time.perf_counter() - scale_start, 3)
                            images_per_second = batch_size / model_time if model_time > 0 else 0
                            megapixels_per_second = ((batch_size * final_side ** 2) / (1024 * 1024)) / model_time if model_time > 0 else 0
                            logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} Scale {scale}x completed:")
                            logging.log(LOG_LEVEL_MAP["INFO"], f" - Time: {scale_time}s (model: {model_time}s)")
                            logging.log(LOG_LEVEL_MAP["INFO"], f" - Speed: {images_per_second:.2f} img/s | {megapixels_per_second:.2f} MP/s")
                            del batch_tensor, logits
                            if DEVICE == "cuda" and torch.cuda.is_available() and len(rbc_scales) > 2 and scale_idx < len(rbc_scales) - 1:
                                try:
                                    torch.cuda.empty_cache()
                                except RuntimeError:
                                    pass

                inference_time = round(time.perf_counter() - inference_start, 3)
                avg_time_per_scale = inference_time / len(rbc_scales) if len(rbc_scales) > 0 else 0
                total_operations = batch_size * len(rbc_scales)
                ops_per_second = total_operations / inference_time if inference_time > 0 else 0
                logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} Inference completed:")
                logging.log(LOG_LEVEL_MAP["INFO"], f" - Total time: {inference_time}s | Avg per scale: {avg_time_per_scale:.3f}s")
                logging.log(LOG_LEVEL_MAP["INFO"], f" - Operations: {total_operations} | Speed: {ops_per_second:.2f} ops/s")

                postprocess_start = time.perf_counter()
                logging.log(LOG_LEVEL_MAP["PROCESSING"], f"{EMOJI_MAP['PROCESSING']} Post-processing {batch_size} masks...")
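                # Multi-scale fusion: per image, the per-scale probability maps are averaged with
                # weights favouring the larger scales (0.7 for >= 1.25x, 0.6 for 1.0x, 0.3 for
                # 0.75x, 0.1 otherwise), normalised to sum to 1. The fused map is thresholded at
                # THRESH; if that yields an empty mask, the looser RESCUE_THRESH mask is used.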
                for ctx_idx, (ctx, meta) in enumerate(batch_items):
                    if ctx_idx in scale_results:
                        scale_probs = []
                        scale_weights = []
                        for scale in sorted(rbc_scales, reverse=True):
                            if scale in scale_results[ctx_idx]:
                                scale_probs.append(scale_results[ctx_idx][scale])
                                if scale >= 1.25:
                                    scale_weights.append(0.7)
                                elif scale >= 1.0:
                                    scale_weights.append(0.6)
                                elif scale >= 0.75:
                                    scale_weights.append(0.3)
                                else:
                                    scale_weights.append(0.1)
                        if scale_probs and scale_weights:
                            weights = np.array(scale_weights)
                            weights = weights / weights.sum()
                            combined_prob = np.zeros_like(scale_probs[0])
                            for prob, weight in zip(scale_probs, weights):
                                combined_prob += prob * weight
                        else:
                            combined_prob = np.zeros((final_side, final_side))
                    else:
                        combined_prob = np.zeros((final_side, final_side))
                    mask = (combined_prob > THRESH).astype(np.uint8) * 255
                    rescue_mask = (combined_prob > RESCUE_THRESH).astype(np.uint8) * 255
                    if np.sum(mask) == 0 and np.sum(rescue_mask) > 0:
                        mask = rescue_mask
                        batch_log_item["data"][f"rescue_applied_{ctx_idx}"] = True
                    mask = smooth_mask(mask)
                    alpha = Image.fromarray(mask, "L")
                    r, g, b = meta["pad"].split()
                    rgba = Image.merge("RGBA", [r, g, b, alpha])
                    ctx.pil_img = {"original": [rgba, ctx.filename or "output.webp", None]}
                    total_processed += 1

                postprocess_time = round(time.perf_counter() - postprocess_start, 3)
                batch_processing_time = round(time.perf_counter() - batch_start_time, 4)
                cumulative_time += batch_processing_time
                cumulative_inference_time += inference_time
                batch_images_per_second = batch_size / batch_processing_time if batch_processing_time > 0 else 0
                batch_megapixels = (batch_size * final_side ** 2) / (1024 * 1024)
                batch_mp_per_second = batch_megapixels / batch_processing_time if batch_processing_time > 0 else 0

                batch_log_item["status"] = "completed"
                batch_log_item["data"]["processed_count"] = len(batch_contexts)
                batch_log_item["data"]["processing_time"] = batch_processing_time
                batch_log_item["data"]["inference_time"] = inference_time
                batch_log_item["data"]["postprocess_time"] = postprocess_time
                batch_log_item["data"]["images_per_second"] = batch_images_per_second
                batch_log_item["data"]["megapixels_per_second"] = batch_mp_per_second

                logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} Batch {batch_number} completed:")
                logging.log(LOG_LEVEL_MAP["INFO"], f" - Total time: {batch_processing_time}s")
                logging.log(LOG_LEVEL_MAP["INFO"], f" - Breakdown: inference={inference_time}s, post={postprocess_time}s")
                logging.log(LOG_LEVEL_MAP["INFO"], f" - Performance: {batch_images_per_second:.2f} img/s | {batch_mp_per_second:.2f} MP/s")

                if DEVICE == "cuda" and torch.cuda.is_available():
                    try:
                        allocated_mb = torch.cuda.memory_allocated(0) / (1024 * 1024)
                        reserved_mb = torch.cuda.memory_reserved(0) / (1024 * 1024)
                    except RuntimeError:
                        allocated_mb = 0
                        reserved_mb = 0
                    logging.log(LOG_LEVEL_MAP["INFO"], f" - GPU Memory: allocated={allocated_mb:.1f}MB, reserved={reserved_mb:.1f}MB")

                if DEVICE == "cuda" and torch.cuda.is_available() and batch_number % 5 == 0:
                    try:
                        torch.cuda.empty_cache()
                    except RuntimeError:
                        pass
            except Exception as e:
                batch_processing_time = round(time.perf_counter() - batch_start_time, 4)
                error_trace = traceback.format_exc()
                logging.error(f"CRITICAL: Background removal processing failed: {str(e)}")
                logging.error(f"Traceback:\n{error_trace}")
                for ctx in batch_contexts:
                    ctx.skip_run = True
                    ctx.error = str(e)
                    ctx.error_traceback = error_trace
                batch_log_item["status"] = "critical_error"
                batch_log_item["exception"] = str(e)
                batch_log_item["traceback"] = error_trace
                batch_log_item["data"]["processing_time"] = batch_processing_time
                logging.critical("Terminating due to background removal processing failure")
                sys.exit(1)

            batch_number += 1

    overall_images_per_second = total_processed / cumulative_time if cumulative_time > 0 else 0
    overall_mp_per_second = total_megapixels / cumulative_time if cumulative_time > 0 else 0
    avg_time_per_image = cumulative_time / total_processed if total_processed > 0 else 0

    logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} {function_name} - Final Summary:")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Total images processed: {total_processed} / {len(valid_contexts)}")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Total time: {cumulative_time:.2f}s | Inference time: {cumulative_inference_time:.2f}s")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Overall performance: {overall_images_per_second:.2f} img/s | {overall_mp_per_second:.2f} MP/s")
    logging.log(LOG_LEVEL_MAP["INFO"], f" - Average time per image: {avg_time_per_image:.3f}s")
"total_contexts": len(valid_contexts), "total_batches": batch_number - 1, "size_groups": {str(k): len(v) for k, v in size_groups.items()}, "total_time": cumulative_time, "total_inference_time": cumulative_inference_time, "overall_images_per_second": overall_images_per_second, "overall_megapixels_per_second": overall_mp_per_second, "average_time_per_image": avg_time_per_image, "optimizations_applied": { "fp16_inference": not RMBG_FULL_PRECISION and DEVICE == "cuda", "matmul_precision": "high", "tf32_enabled": torch.backends.cuda.matmul.allow_tf32 if DEVICE == "cuda" and torch.cuda.is_available() else False, "cudnn_benchmark": torch.backends.cudnn.benchmark if DEVICE == "cuda" and torch.cuda.is_available() else False } } }) return batch_logs def select_largest_component(contexts, batch_logs): function_name = "select_largest_component" processed_count = 0 skipped_count = 0 error_count = 0 for ctx in contexts: if ctx.skip_run or ctx.skip_processing: skipped_count += 1 continue if "original" not in ctx.pil_img: error_count += 1 continue try: final_rgba, filename, meta = ctx.pil_img["original"] alpha_np = np.array(final_rgba.getchannel("A")) labeled, num_components = label(alpha_np > 0) if num_components <= 1: processed_count += 1 continue regs = find_objects(labeled) best_component = None best_area = 0 for reg in regs: if reg is None: continue sy, ey = reg[0].start, reg[0].stop sx, ex = reg[1].start, reg[1].stop area = np.count_nonzero(alpha_np[sy:ey, sx:ex] > 0) if area > best_area: center_x = (sx + ex) / 2.0 if center_x < final_rgba.width / 2 or best_area == 0: best_area = area best_component = (sx, sy, ex, ey) if best_component: sx, sy, ex, ey = best_component final_image = final_rgba.crop((sx, sy, ex, ey)) ctx.pil_img["original"] = [final_image, filename, meta] processed_count += 1 except Exception as e: error_count += 1 ctx.skip_run = True batch_logs.append({ "function": function_name, "status": "info", "data": { "processed_count": processed_count, "skipped_count": skipped_count, "error_count": error_count } }) return batch_logs def final_pad_sq_batch(contexts, batch_logs): function_name = "final_pad_sq_batch" processed_count = 0 skipped_count = 0 error_count = 0 for ctx in contexts: if ctx.skip_run or ctx.skip_processing: skipped_count += 1 continue if "original" not in ctx.pil_img: error_count += 1 continue try: image, filename, meta = ctx.pil_img["original"] padded_image = final_pad_sq(image) ctx.pil_img["original"] = [padded_image, filename, meta] processed_count += 1 except Exception as e: error_count += 1 batch_logs.append({ "function": function_name, "image_url": ctx.url, "status": "error", "exception": str(e) }) batch_logs.append({ "function": function_name, "status": "info", "data": { "processed_count": processed_count, "skipped_count": skipped_count, "error_count": error_count } }) return batch_logs # ---------------------------------------------------------------------- # MAIN PIPELINE FUNCTION # ---------------------------------------------------------------------- def _ensure_models_loaded(): import app app.ensure_models_loaded() pipeline_step = create_pipeline_step(_ensure_models_loaded) @pipeline_step def remove_background( contexts: List[ProcessingContext], batch_logs: List[dict] | None = None ): if batch_logs is None: batch_logs = [] pipeline_start_time = time.perf_counter() calibration_info = { "function": "calibration_info", "status": "info", "data": { "version": CALIBRATION_VERSION, "contrast_factor": RBC_CONTRAST_FACTOR, "sharpness_factor": RBC_SHARPNESS_FACTOR, 
"size_configs": SIZE_CONFIGS, "adaptive_scale_configs": ADAPTIVE_SCALE_CONFIGS, "threshold": THRESH, "max_images_per_batch": MAX_IMAGES_PER_BATCH, "morph_kernel_size": MORPH_KERNEL_SIZE, "morph_close_iter": MORPH_CLOSE_ITER, "morph_open_iter": MORPH_OPEN_ITER, "erosion_iter": EROSION_ITER, "gaussian_kernel_size": GAUSSIAN_KERNEL_SIZE, "do_guided_filter": DO_GUIDED_FILTER, "fill_holes": FILL_HOLES, "use_bilateral": USE_BILATERAL, "low_load_surface_threshold": LOW_LOAD_SURFACE_THRESHOLD, "medium_load_surface_threshold": MEDIUM_LOAD_SURFACE_THRESHOLD, "high_load_surface_threshold": HIGH_LOAD_SURFACE_THRESHOLD } } batch_logs.append(calibration_info) logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Starting remove_background pipeline for {len(contexts)} items") valid_contexts = [ctx for ctx in contexts if not (ctx.skip_run or ctx.skip_processing) and hasattr(ctx, '_download_content')] total_batches = (len(valid_contexts) + MAX_IMAGES_PER_BATCH - 1) // MAX_IMAGES_PER_BATCH if valid_contexts else 0 batch_processing_log = { "function": "batch_processing_info", "status": "info", "data": { "total_contexts": len(contexts), "valid_contexts": len(valid_contexts), "max_images_per_batch": MAX_IMAGES_PER_BATCH, "estimated_batches": total_batches, "adaptive_scaling_enabled": len(valid_contexts) >= MEDIUM_LOAD_SURFACE_THRESHOLD, "next_step": "preprocess_images_batch" } } batch_logs.append(batch_processing_log) logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Processing {len(valid_contexts)} valid contexts in {total_batches} batches") preprocess_images_batch(contexts, batch_logs) process_background_removal(contexts, batch_logs) batch_completion_log = { "function": "background_removal_completed", "status": "info", "data": { "completed_contexts": len([ctx for ctx in contexts if "original" in ctx.pil_img]), "total_contexts": len(contexts), "next_step": "select_largest_component" } } batch_logs.append(batch_completion_log) select_largest_component(contexts, batch_logs) final_pad_sq_batch(contexts, batch_logs) total_pipeline_time = round(time.perf_counter() - pipeline_start_time, 4) logging.log(LOG_LEVEL_MAP["SUCCESS"], f"{EMOJI_MAP['SUCCESS']} Completed remove_background pipeline for {len(contexts)} items in {total_pipeline_time}s") return batch_logs