# app.py
"""Gradio app: segment a region with SAM from clicked points, then inpaint it.

The user uploads an image, clicks points on it, and SAM turns those points
into a mask. A Stable Diffusion inpainting pipeline then repaints the masked
region according to a text prompt. Everything runs on CPU.
"""
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor

# Constants
IMG_SIZE = 512

# Initialize SAM model and processor on CPU
sam_model = SamModel.from_pretrained(
    "facebook/sam-vit-huge", torch_dtype=torch.float32
).to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Initialize Inpainting pipeline on CPU with a compatible model
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32
).to("cpu")
# No need for model_cpu_offload on CPU

# Global variables to store points and the original image
input_points = []
input_image = None


def mask_to_rgba(mask):
    """
    Converts a binary mask to an RGBA image for visualization.

    Args:
        mask (np.ndarray): 2-D mask; pixels equal to 1 are highlighted.

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (4,)`` that is fully
        transparent except where ``mask == 1``, which is semi-transparent
        green.
    """
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]  # Green with transparency
    return bg_transparent


def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.

    The SAM prediction is inverted before it is returned, so the result
    marks everything *except* the segmented region. This matches
    ``run_inpaint``, which repaints the background by default and flips the
    mask when the user asks to infill the subject.

    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): ``[x, y]`` points selected by the user.

    Returns:
        np.ndarray or None: Integer mask with 0 on the region SAM segmented
        and 1 elsewhere, or None when no points or no masks are available.
    """
    if not input_points:
        return None

    # Convert image to RGB if not already
    image = image.convert("RGB")

    # SamProcessor takes `input_points` as one list of [x, y] points per
    # image, hence the extra outer list for our single image.
    points = [[list(point) for point in input_points]]

    # Prepare inputs for SAM
    inputs = sam_processor(
        image, input_points=points, return_tensors="pt"
    ).to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Post-process masks
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )

    if len(masks) == 0:
        return None

    # Select the mask with the highest IoU score
    best_index = int(outputs.iou_scores.argmax())
    best_mask = masks[0][0][best_index]

    # Invert while the array is still boolean. Applying `~` after the int
    # cast would be a bitwise NOT and yield -1/-2 instead of 1/0.
    segmented = best_mask.numpy().astype(bool)
    binary_mask = (~segmented).astype(int)

    return binary_mask


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Repaints the masked region of the image based on the prompt.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask; pixels equal to 1 are repainted.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility.
        guidance_scale (float): Guidance scale for the inpainting model.

    Returns:
        PIL.Image: The inpainted image. The input image is returned
        unchanged when ``mask`` is None or when the pipeline raises.
    """
    if mask is None:
        return image

    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
        return result
    except Exception as e:
        # Best effort: log the failure and fall back to the original image.
        print(f"Inpainting error: {e}")
        return image


def visualize_mask(image, mask):
    """
    Overlays the mask on the image for visualization.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask; pixels equal to 1 are tinted green.

    Returns:
        PIL.Image: RGB image with the mask overlay, or the input image
        unchanged when ``mask`` is None.
    """
    if mask is None:
        return image

    mask_rgba = mask_to_rgba(mask)
    mask_pil = Image.fromarray(mask_rgba)
    overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
    return overlay.convert("RGB")


def get_points(img, evt: gr.SelectData):
    """
    Captures points selected by the user on the image.

    Appends the clicked point to the global ``input_points``, regenerates
    the SAM mask from all points so far, and draws a crossmark at every
    selected point.

    Args:
        img (PIL.Image): The uploaded image.
        evt (gr.SelectData): Event data containing the point coordinates.

    Returns:
        Tuple: (Updated mask visualization, Updated image with crossmarks)
    """
    global input_points
    global input_image

    # The first time this is called, save the untouched input image
    if not input_points:
        input_image = img.copy()

    x = evt.index[0]
    y = evt.index[1]
    input_points.append([x, y])

    # Run SAM to generate mask
    mask = generate_mask(input_image, input_points)

    # Mark selected points with a green crossmark
    draw = ImageDraw.Draw(img)
    size = 10
    for px, py in input_points:
        draw.line((px - size, py, px + size, py), fill="green", width=5)
        draw.line((px, py - size, px, py + size), fill="green", width=5)

    # Visualize the mask overlay
    masked_image = visualize_mask(input_image, mask)

    return masked_image, img


def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.

    Args:
        prompt (str): Prompt for infill.
        negative_prompt (str): Negative prompt.
        cfg (float): Classifier-Free Guidance Scale.
        seed (int): Random seed.
        invert (bool): Whether to infill the subject instead of the
            background.

    Returns:
        PIL.Image: The inpainted image, resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: If no image/points have been selected, or if inpainting
            raises.
    """
    global input_image
    global input_points

    if input_image is None or not input_points:
        raise gr.Error(
            "No points provided. Click on the image to select the object to segment with SAM."
        )

    mask = generate_mask(input_image, input_points)

    # generate_mask returns 1 on the background. Flip it arithmetically to
    # target the subject; `~` on an int array would give negative values.
    if invert and mask is not None:
        mask = 1 - mask

    try:
        inpainted = replace_object(
            input_image, mask, prompt, negative_prompt, seed, cfg
        )
    except Exception as e:
        raise gr.Error(str(e)) from e

    return inpainted.resize((IMG_SIZE, IMG_SIZE))


def reset_points_func():
    """
    Resets the selected points and the input image.

    Returns:
        Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
    """
    global input_points
    global input_image

    input_points = []
    input_image = None
    return None, None, None


def preprocess(input_img):
    """
    Preprocesses the uploaded image to ensure it is square and resized.

    Non-square images are centred on a white square canvas before being
    resized to IMG_SIZE x IMG_SIZE.

    Args:
        input_img (PIL.Image): The uploaded image.

    Returns:
        PIL.Image: The preprocessed image, or None if no image was given.
    """
    if input_img is None:
        return None

    # Make sure the image is square
    width, height = input_img.size

    if width != height:
        # Add white padding to make the image square
        new_size = max(width, height)
        new_image = Image.new("RGB", (new_size, new_size), 'white')
        left = (new_size - width) // 2
        top = (new_size - height) // 2
        new_image.paste(input_img, (left, top))
        input_img = new_image

    return input_img.resize((IMG_SIZE, IMG_SIZE))


def build_app(get_processed_inputs, inpaint):
    """
    Builds and launches the Gradio app.

    Args:
        get_processed_inputs (function): Function to process inputs for SAM.
            Currently unused; the module-level functions are wired directly.
        inpaint (function): Function to perform inpainting. Currently unused.

    Returns:
        None
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Object Replacement App
            Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.

            **Instructions:**
            1. **Upload Image:** Click on the first image box to upload your image.
            2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
            3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
            4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
            5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
            6. **Reset:** Click the "Reset" button to clear selections and start over.
            """)

        with gr.Row():
            with gr.Column():
                # Image upload and point selection
                upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
                mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
                selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)

                # Capture points using the select event
                upload_image.select(
                    get_points,
                    inputs=[upload_image],
                    outputs=[mask_visualization, selected_image]
                )

                # Preprocess image on change
                # NOTE(review): this listener writes back to the component it
                # listens on; confirm it does not retrigger itself in the
                # Gradio version in use.
                upload_image.change(
                    preprocess,
                    inputs=[upload_image],
                    outputs=[upload_image]
                )

                # Text inputs and settings
                prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
                cfg = gr.Slider(
                    label="Classifier-Free Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5
                )
                seed = gr.Number(label="Seed", value=42, precision=0)
                invert = gr.Checkbox(label="Infill subject instead of background")

                # Buttons
                replace_button = gr.Button("Replace Object")
                reset_button = gr.Button("Reset")

            with gr.Column():
                # Output images
                augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)

        # Define button actions
        replace_button.click(
            fn=run_inpaint,
            inputs=[prompt, negative_prompt, cfg, seed, invert],
            outputs=[augmented_image]
        )

        reset_button.click(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image]
        )

        # Examples (optional)
        gr.Markdown(
            """
            ## EXAMPLES
            Click on an example to load it. Then, follow the instructions above.
            """)

        with gr.Row():
            gr.Examples(
                examples=[
                    ["car.png", "a red sports car", "blurry, low quality", 42],
                    ["house.jpg", "a modern villa", "dark, overexposed", 123],
                    ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
                ],
                inputs=[
                    upload_image,
                    prompt,
                    negative_prompt,
                    seed
                ],
                label="Click to load examples",
                # Caching needs `fn` and `outputs` to precompute results;
                # these examples only populate the inputs, so caching is off.
                cache_examples=False
            )

    demo.queue(max_size=10).launch()


# Launch the app
build_app(None, None)