# --- START OF FILE app.py (Hugging Face Space code) --- import torch import numpy as np from PIL import Image, ImageDraw # Added ImageDraw for potential visualization within Gradio UI import gradio as gr from transformers import CLIPProcessor, CLIPModel from segment_anything import sam_model_registry, SamAutomaticMaskGenerator import traceback # For better error logging # --- Check for CUDA and print status --- if torch.cuda.is_available(): device = "cuda" print("CUDA is available. Using GPU.") else: device = "cpu" print("CUDA not available. Using CPU.") # --- End Check --- # --- Model Loading (add error handling) --- sam = None mask_generator = None clip_model = None clip_processor = None try: print("Loading CLIP model...") clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval() clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") print("CLIP model loaded.") print("Loading SAM model...") # Use a smaller/faster model if performance is an issue, e.g., vit_t, if available and checkpoint exists # sam_checkpoint = "sam_vit_t.pth" # Example for tiny model # sam_model_type = "vit_t" sam_checkpoint = "sam_vit_b_01ec64.pth" # Original base model sam_model_type = "vit_b" sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(device).eval() # You might adjust SamAutomaticMaskGenerator parameters if needed # points_per_side=32, pred_iou_thresh=0.88, stability_score_thresh=0.95, crop_n_layers=0, crop_n_points_downscale_factor=1, min_mask_region_area=0 mask_generator = SamAutomaticMaskGenerator(sam) print("SAM model loaded.") except Exception as e: print(f"FATAL: Error loading models: {e}") print(traceback.format_exc()) # If models fail to load, the app shouldn't run exit() # --- End Model Loading --- # Convert PIL image to numpy def pil_to_np(pil_img): return np.array(pil_img.convert("RGB")) # Convert numpy array to PIL image def np_to_pil(np_img): if np_img.dtype != np.uint8: if np_img.max() <= 1.0 and np_img.min() >= 0.0: np_img = (np_img * 255).astype(np.uint8) else: np_img = np.clip(np_img, 0, 255).astype(np.uint8) return Image.fromarray(np_img) # --- EXISTING FUNCTION (for unmasking single prompt) --- def clip_guided_unmask(original_img, revealed_img, text_prompt): # --- (Function content remains exactly the same as before) --- if original_img is None: print("Error: Original image is required for unmasking.") return revealed_img if not isinstance(original_img, Image.Image): print(f"Error: original_img is not a PIL Image, type is {type(original_img)}") try: original_img = Image.open(original_img).convert("RGB") except Exception: return revealed_img if revealed_img is None: print("No revealed image provided, creating a black canvas.") revealed_img = Image.new("RGB", original_img.size, color="black") elif not isinstance(revealed_img, Image.Image): print(f"Error: revealed_img is not a PIL Image, type is {type(revealed_img)}") try: revealed_img = Image.open(revealed_img).convert("RGB") except Exception: print("Falling back to black canvas for revealed image.") revealed_img = Image.new("RGB", original_img.size, color="black") print(f"Processing unmask request for prompt: '{text_prompt}'") try: np_orig = pil_to_np(original_img) np_reveal = pil_to_np(revealed_img) print(f"Original image shape: {np_orig.shape}, Revealed image shape: {np_reveal.shape}") if np_orig.shape != np_reveal.shape: print(f"Warning: Shapes mismatch. Resizing revealed {np_reveal.shape} to {np_orig.shape}") revealed_img = revealed_img.resize(original_img.size) np_reveal = pil_to_np(revealed_img) print("Generating masks with SAM...") if np_orig.dtype != np.uint8: print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})") np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8) else: np_orig_sam = np_orig masks = mask_generator.generate(np_orig_sam) if not masks: print("SAM did not generate any masks.") return revealed_img print(f"Generated {len(masks)} masks.") print("Processing text prompt with CLIP...") prompt_for_clip = text_prompt if text_prompt else "object" # Handle empty prompt text_inputs = clip_processor(text=[prompt_for_clip], return_tensors="pt", padding=True).to(device) with torch.no_grad(): text_feat = clip_model.get_text_features(**text_inputs) text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True) print("Text features generated.") best_score = -float('inf') best_mask_info = None print("Calculating CLIP scores for masks...") for i, m in enumerate(masks): seg = m["segmentation"] masked_for_clip = np_orig.copy() masked_for_clip[~seg] = 0 pil_masked_for_clip = np_to_pil(masked_for_clip) inputs = clip_processor(images=pil_masked_for_clip, return_tensors="pt").to(device) with torch.no_grad(): image_feat = clip_model.get_image_features(**inputs) image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True) sim = (image_feat @ text_feat.T).item() if sim > best_score: best_score = sim best_mask_info = m print(f"Best score found: {best_score:.4f}") if best_mask_info is not None: bbox = best_mask_info['bbox'] x, y, w, h = map(int, bbox) y_start = max(0, y) y_end = min(np_orig.shape[0], y + h) x_start = max(0, x) x_end = min(np_orig.shape[1], x + w) print(f"Applying best mask's bounding box: [{x_start}:{x_end}, {y_start}:{y_end}]") if y_end > y_start and x_end > x_start: np_reveal[y_start:y_end, x_start:x_end] = np_orig[y_start:y_end, x_start:x_end] else: print(f"Warning: Invalid bounding box dimensions calculated ({w}x{h} at {x},{y}). Skipping reveal.") else: print("No suitable mask found based on the prompt.") final_revealed_pil = np_to_pil(np_reveal) print("Unmask processing complete.") return final_revealed_pil except Exception as e: print(f"Error during clip_guided_unmask: {e}") print(traceback.format_exc()) return revealed_img # --- NEW FUNCTION (to get all SAM bounding boxes) --- def get_all_sam_bboxes(original_img): """ Generates all masks for an image using SAM and returns their bounding boxes. Input: PIL Image Output: List of bounding boxes [[x, y, w, h], ...] or None on error """ if original_img is None: print("Error: Original image is required for getting bboxes.") return None # Return None or empty list to indicate error # Ensure input is PIL Image if not isinstance(original_img, Image.Image): print(f"Error: get_all_sam_bboxes expects a PIL Image, got {type(original_img)}") try: original_img = Image.open(original_img).convert("RGB") except Exception as e: print(f"Error converting input to PIL Image: {e}") return None print("Processing request to get all SAM bounding boxes...") try: np_orig = pil_to_np(original_img) print(f"Original image shape for bbox generation: {np_orig.shape}") # Ensure uint8 for SAM if np_orig.dtype != np.uint8: print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})") np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8) else: np_orig_sam = np_orig print("Generating masks with SAM...") masks = mask_generator.generate(np_orig_sam) if not masks: print("SAM did not generate any masks.") return [] # Return empty list if no masks print(f"Generated {len(masks)} masks.") # Extract bounding boxes [x, y, w, h] bboxes = [m['bbox'] for m in masks if 'bbox' in m] # Ensure all elements are standard Python ints/floats for JSON serialization bboxes_serializable = [[int(b[0]), int(b[1]), int(b[2]), int(b[3])] for b in bboxes] print(f"Extracted {len(bboxes_serializable)} bounding boxes.") return bboxes_serializable # Return the list of boxes except Exception as e: print(f"Error during get_all_sam_bboxes: {e}") print(traceback.format_exc()) return None # Indicate error # --- Gradio Interface using Blocks to support multiple API endpoints --- print("Setting up Gradio interface using Blocks...") with gr.Blocks() as demo: gr.Markdown("# CLIP-SAM Guided Unmasking and BBox Extraction") with gr.Tab("Interactive Unmasking"): with gr.Row(): img_input_unmask = gr.Image(type="pil", label="Original Image") img_revealed_input = gr.Image(type="pil", label="Current Revealed Image (leave empty on first run)") img_output_unmask = gr.Image(type="pil", label="Updated Reveal") prompt_input_unmask = gr.Textbox(label="Text Prompt (e.g., 'a red car', 'the dog')") unmask_button = gr.Button("Unmask Prompt") with gr.Tab("Get All Bounding Boxes"): with gr.Row(): img_input_bbox = gr.Image(type="pil", label="Original Image") # Output for visualization in UI (optional) img_output_bbox_viz = gr.Image(type="pil", label="Image with All BBoxes (Visualization)") # Output for API call (JSON) json_output_bbox = gr.JSON(label="Bounding Boxes ([x, y, w, h])") bbox_button = gr.Button("Get All SAM Bounding Boxes") # --- Define API endpoints --- # Endpoint for the interactive unmasking function unmask_button.click( fn=clip_guided_unmask, inputs=[img_input_unmask, img_revealed_input, prompt_input_unmask], outputs=img_output_unmask, api_name="predict" # Keep the original API name for compatibility ) # Helper function to draw boxes for the UI visualization part def draw_boxes_on_image_for_ui(original_img): bboxes = get_all_sam_bboxes(original_img) if bboxes is None or not bboxes: # Return original image or error message if desired print("No bounding boxes found or error occurred.") return original_img, [] # Return original image and empty list img_copy = original_img.copy() draw = ImageDraw.Draw(img_copy) print(f"Drawing {len(bboxes)} boxes for UI preview...") for bbox in bboxes: x, y, w, h = map(int, bbox) x1, y1 = x + w, y + h # Draw rectangle outline [x0, y0, x1, y1] draw.rectangle([x, y, x1, y1], outline="red", width=2) print("Finished drawing boxes for UI.") return img_copy, bboxes # Return image with boxes AND the json data # Endpoint for getting all bounding boxes (and visualizing in UI) bbox_button.click( fn=draw_boxes_on_image_for_ui, # Use helper that calls get_all_sam_bboxes inputs=[img_input_bbox], outputs=[img_output_bbox_viz, json_output_bbox], # Output to both components api_name="get_boxes" # New API name for this function ) # Launch the interface print("Launching Gradio interface...") # Consider adding share=False and potentially debug=True locally if needed # demo.launch(share=True) # Use share=True if needed for external access like from your local script demo.launch() print("Interface launched. Check the output URL.") # --- END OF FILE app.py ---