import spaces
import os
import pickle
from time import perf_counter
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler

from utils.drag import bi_warp

__all__ = [
    'clear_all', 'resize', 'visualize_user_drag', 'preview_out_image',
    'inpaint', 'add_point', 'undo_point', 'clear_point',
]

# Global variables for lazy loading
pipe = None


# UI functions
def clear_all(length):
    """Reset UI by clearing all input images and parameters."""
    return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 5, None)


def resize(canvas, gen_length, canvas_length):
    """Resize canvas while maintaining aspect ratio."""
    if not canvas:
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    result = process_canvas(canvas)
    if result[0] is None:  # Check if image is None
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    image = result[0]
    aspect_ratio = image.shape[1] / image.shape[0]
    is_landscape = aspect_ratio >= 1

    new_dims = (
        (gen_length, round(gen_length / aspect_ratio / 8) * 8) if is_landscape
        else (round(gen_length * aspect_ratio / 8) * 8, gen_length)
    )
    canvas_dims = (
        (canvas_length, round(canvas_length / aspect_ratio)) if is_landscape
        else (round(canvas_length * aspect_ratio), canvas_length)
    )

    return (gr.Image(value=cv2.resize(image, new_dims), width=canvas_dims[0], height=canvas_dims[1]),) * 3


def process_canvas(canvas):
    """Extract the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
    # Handle None canvas
    if canvas is None:
        return None, None

    if isinstance(canvas, dict):
        if 'background' in canvas and 'layers' in canvas:
            # New ImageEditor format
            if canvas["background"] is None:
                return None, None
            image = canvas["background"].copy()
            # Ensure image is 3-channel RGB
            if len(image.shape) == 3 and image.shape[2] == 4:
                image = image[:, :, :3]  # Remove alpha channel
            elif len(image.shape) == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            # Try to extract mask from layers
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            if canvas["layers"]:
                for layer in canvas["layers"]:
                    if isinstance(layer, np.ndarray) and len(layer.shape) >= 2:
                        layer_mask = np.uint8(layer[:, :, 0] > 0) if len(layer.shape) == 3 else np.uint8(layer > 0)
                        mask = np.logical_or(mask, layer_mask).astype(np.uint8)
        elif 'image' in canvas and 'mask' in canvas:
            # Old sketch-tool format
            if canvas["image"] is None:
                return None, None
            image = canvas["image"].copy()
            # Ensure image is 3-channel RGB
            if len(image.shape) == 3 and image.shape[2] == 4:
                image = image[:, :, :3]  # Remove alpha channel
            elif len(image.shape) == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy() if canvas["mask"] is not None else np.zeros(image.shape[:2], dtype=np.uint8)
        else:
            # Unrecognized dict layout
            return None, None
    else:
        # Direct numpy array (no mask available)
        image = canvas.copy() if isinstance(canvas, np.ndarray) else np.array(canvas)
        # Ensure image is 3-channel RGB
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = image[:, :, :3]  # Remove alpha channel
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

    return image, mask


# Point manipulation functions
def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
    """Add selected point to points list and update image."""
    if canvas is None:
        return None
    points.append(evt.index)
    return visualize_user_drag(canvas, points)
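

# NOTE: `points` is a flat, click-ordered list of [x, y] pixel coordinates.
# Odd-numbered clicks (1st, 3rd, ...) are handle/start points and the click
# that follows each of them is its target/end point; visualize_user_drag draws
# the list with this alternating convention, and the same list is passed
# (presumably with the same interpretation) to bi_warp.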


def undo_point(canvas, points, inpaint_ks):
    """Remove last point and update image."""
    if canvas is None:
        return None
    if len(points) > 0:
        points.pop()
    return visualize_user_drag(canvas, points)


def clear_point(canvas, points, inpaint_ks):
    """Clear all points and update image."""
    if canvas is None:
        return None
    points.clear()
    return visualize_user_drag(canvas, points)


# Visualization tools
def visualize_user_drag(canvas, points):
    """Visualize control points and motion vectors on the input image."""
    if canvas is None:
        return None
    result = process_canvas(canvas)
    if result[0] is None:  # Check if image is None
        return None
    image, mask = result

    # Ensure image is uint8 and 3-channel
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
    if len(image.shape) != 3 or image.shape[2] != 3:
        return None

    # Apply colored mask overlay
    result_img = image.copy()
    if np.any(mask == 1):
        result_img[mask == 1] = [255, 0, 0]  # Red color
    image = cv2.addWeighted(result_img, 0.3, image, 0.7, 0)

    # Draw mask outline
    if np.any(mask > 0):
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)

    # Draw control points and motion vectors
    prev_point = None
    for idx, point in enumerate(points, 1):
        if idx % 2 == 0:
            cv2.circle(image, tuple(point), 10, (0, 0, 255), -1)  # End point
            if prev_point is not None:
                cv2.arrowedLine(image, tuple(prev_point), tuple(point), (255, 255, 255), 4, tipLength=0.5)
        else:
            cv2.circle(image, tuple(point), 10, (255, 0, 0), -1)  # Start point
        prev_point = point

    return image


def preview_out_image(canvas, points, inpaint_ks):
    """Preview warped image result and generate inpainting mask."""
    if canvas is None:
        return None, None
    result = process_canvas(canvas)
    if result[0] is None:  # Check if image is None
        return None, None
    image, mask = result

    # Ensure image is uint8 and 3-channel
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
    if len(image.shape) != 3 or image.shape[2] != 3:
        return image, None

    if len(points) < 2:
        return image, None

    # Ensure H and W are divisible by 8 and the longer edge is 512
    shapes_valid = all(s % 8 == 0 for s in mask.shape + image.shape[:2])
    size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
    if not (shapes_valid and size_valid):
        gr.Warning('Click Resize Image Button first.')
        return image, None

    try:
        handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
        image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
        # Add grid pattern to highlight inpainting regions
        background = np.ones_like(mask) * 255
        background[::10] = background[:, ::10] = 0
        image = np.where(inpaint_mask[..., np.newaxis] == 1, background[..., np.newaxis], image)
        return image, (inpaint_mask * 255).astype(np.uint8)
    except Exception as e:
        gr.Warning(f"Preview failed: {str(e)}")
        return image, None


# Inpaint tools
@spaces.GPU
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Initialize optimized inpainting pipeline with specified model configuration."""
    MODEL_CONFIGS = {
        'v1-5': ('runwayml/stable-diffusion-inpainting',
                 'latent-consistency/lcm-lora-sdv1-5',
                 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
               'latent-consistency/lcm-lora-sdxl',
               'madebyollin/taesdxl'),
    }
    model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]

    # Check if CUDA is available, fall back to CPU if not
    if not torch.cuda.is_available():
        device = 'cpu'
        torch_dtype = torch.float32
        variant = None
    else:
        torch_dtype = torch.float16
        variant = "fp16"

    gr.Info('Loading inpainting pipeline...')
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id, torch_dtype=torch_dtype, variant=variant, safety_checker=None
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)

    # Pre-compute prompt embeddings during setup
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    else:
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0::2]

    return pipe


def get_pipeline():
    """Lazily load the pipeline the first time it is needed."""
    global pipe
    if pipe is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pipe = setup_pipeline(device=device, model_version='v1-5')
        # Re-encode the empty prompt on the resolved device
        pipe.cached_prompt_embeds = pipe.encode_prompt('', device, 1, False)[0]
    return pipe


@spaces.GPU
def inpaint(image, inpaint_mask):
    """Perform efficient inpainting on masked regions using Stable Diffusion."""
    if image is None:
        return None
    if inpaint_mask is None:
        return image

    start = perf_counter()

    # Get pipeline (lazy loading)
    pipe = get_pipeline()
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    # Convert inputs to PIL; resize if dimensions are not divisible by 8
    image_pil = Image.fromarray(image)
    inpaint_mask_pil = Image.fromarray(inpaint_mask)
    width, height = inpaint_mask_pil.size
    if width % 8 != 0 or height % 8 != 0:
        width, height = round(width / 8) * 8, round(height / 8) * 8
        image_pil = image_pil.resize((width, height))
        image = np.array(image_pil)
        inpaint_mask_pil = inpaint_mask_pil.resize((width, height), Image.NEAREST)
        inpaint_mask = np.array(inpaint_mask_pil)

    # Common pipeline parameters
    common_params = {
        'image': image_pil,
        'mask_image': inpaint_mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np',
    }

    # Run pipeline
    try:
        if pipe_id == 'v1-5':
            inpainted = pipe(
                prompt_embeds=pipe.cached_prompt_embeds,
                **common_params
            ).images[0]
        else:
            inpainted = pipe(
                prompt_embeds=pipe.cached_prompt_embeds,
                pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
                **common_params
            ).images[0]
    except Exception as e:
        gr.Warning(f"Inpainting failed: {str(e)}")
        return image

    # Post-process: keep original pixels outside the mask, diffusion output inside it
    inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
    return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
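

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): one way these
# callbacks could be wired into a Gradio Blocks demo. The component names,
# the slider range for `inpaint_ks`, and the layout are assumptions made for
# illustration; the real app that imports these helpers may differ, and the
# clear_all / resize / undo_point / clear_point handlers would be wired to
# buttons in the same way.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    with gr.Blocks() as demo:
        points = gr.State([])  # flat list of clicked [x, y] points (see note above)
        with gr.Row():
            canvas = gr.ImageEditor(label='Input (draw the editable region)')
            drag_view = gr.Image(label='Drag points')
            preview = gr.Image(label='Warp preview')
        inpaint_ks = gr.Slider(1, 21, value=5, step=2, label='Inpaint kernel size (assumed range)')
        inpaint_mask = gr.Image(visible=False)  # hidden buffer for the mask from preview_out_image
        output = gr.Image(label='Inpainted result')
        run_btn = gr.Button('Preview and inpaint')

        # Clicking the drag view appends alternating handle/target points.
        drag_view.select(add_point, [canvas, points, inpaint_ks], drag_view)
        # Warp preview first, then feed the warped image and mask to inpaint().
        run_btn.click(preview_out_image, [canvas, points, inpaint_ks], [preview, inpaint_mask]).then(
            inpaint, [preview, inpaint_mask], output)

    demo.launch()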