# --- START OF FILE app.py (Hugging Face Space code) ---
import torch
import numpy as np
from PIL import Image, ImageDraw # Added ImageDraw for potential visualization within Gradio UI
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import traceback # For better error logging
# --- Check for CUDA and print status ---
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA is available. Using GPU.")
else:
    device = "cpu"
    print("CUDA not available. Using CPU.")
# --- End Check ---
# --- Model Loading (add error handling) ---
sam = None
mask_generator = None
clip_model = None
clip_processor = None
try:
    print("Loading CLIP model...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded.")

    print("Loading SAM model...")
    # Use a smaller/faster model if performance is an issue, e.g., vit_t, if available and the checkpoint exists
    # sam_checkpoint = "sam_vit_t.pth"  # Example for tiny model
    # sam_model_type = "vit_t"
    sam_checkpoint = "sam_vit_b_01ec64.pth"  # Original base model
    sam_model_type = "vit_b"
    sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(device).eval()
    # You might adjust SamAutomaticMaskGenerator parameters if needed, e.g.:
    # points_per_side=32, pred_iou_thresh=0.88, stability_score_thresh=0.95,
    # crop_n_layers=0, crop_n_points_downscale_factor=1, min_mask_region_area=0
    mask_generator = SamAutomaticMaskGenerator(sam)
    print("SAM model loaded.")
except Exception as e:
    print(f"FATAL: Error loading models: {e}")
    print(traceback.format_exc())
    # If models fail to load, the app shouldn't run
    exit()
# --- End Model Loading ---
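# Note: the SAM checkpoint (sam_vit_b_01ec64.pth) is assumed to sit next to app.py in the
# Space repository. If it is missing, a minimal download guard (a sketch, not part of the
# original app) could be placed before sam_model_registry[...] is called above:
#
#   import os, urllib.request
#   if not os.path.exists(sam_checkpoint):
#       urllib.request.urlretrieve(
#           "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
#           sam_checkpoint,
#       )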
# Convert PIL image to numpy
def pil_to_np(pil_img):
    return np.array(pil_img.convert("RGB"))
# Convert numpy array to PIL image
def np_to_pil(np_img):
    if np_img.dtype != np.uint8:
        if np_img.max() <= 1.0 and np_img.min() >= 0.0:
            np_img = (np_img * 255).astype(np.uint8)
        else:
            np_img = np.clip(np_img, 0, 255).astype(np.uint8)
    return Image.fromarray(np_img)
# --- EXISTING FUNCTION (for unmasking single prompt) ---
def clip_guided_unmask(original_img, revealed_img, text_prompt):
    # --- (Function content remains exactly the same as before) ---
    if original_img is None:
        print("Error: Original image is required for unmasking.")
        return revealed_img
    if not isinstance(original_img, Image.Image):
        print(f"Error: original_img is not a PIL Image, type is {type(original_img)}")
        try:
            original_img = Image.open(original_img).convert("RGB")
        except Exception:
            return revealed_img
    if revealed_img is None:
        print("No revealed image provided, creating a black canvas.")
        revealed_img = Image.new("RGB", original_img.size, color="black")
    elif not isinstance(revealed_img, Image.Image):
        print(f"Error: revealed_img is not a PIL Image, type is {type(revealed_img)}")
        try:
            revealed_img = Image.open(revealed_img).convert("RGB")
        except Exception:
            print("Falling back to black canvas for revealed image.")
            revealed_img = Image.new("RGB", original_img.size, color="black")

    print(f"Processing unmask request for prompt: '{text_prompt}'")
    try:
        np_orig = pil_to_np(original_img)
        np_reveal = pil_to_np(revealed_img)
        print(f"Original image shape: {np_orig.shape}, Revealed image shape: {np_reveal.shape}")
        if np_orig.shape != np_reveal.shape:
            print(f"Warning: Shapes mismatch. Resizing revealed {np_reveal.shape} to {np_orig.shape}")
            revealed_img = revealed_img.resize(original_img.size)
            np_reveal = pil_to_np(revealed_img)

        print("Generating masks with SAM...")
        if np_orig.dtype != np.uint8:
            print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
            np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
        else:
            np_orig_sam = np_orig
        masks = mask_generator.generate(np_orig_sam)
        if not masks:
            print("SAM did not generate any masks.")
            return revealed_img
        print(f"Generated {len(masks)} masks.")

        print("Processing text prompt with CLIP...")
        prompt_for_clip = text_prompt if text_prompt else "object"  # Handle empty prompt
        text_inputs = clip_processor(text=[prompt_for_clip], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_feat = clip_model.get_text_features(**text_inputs)
            text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True)
        print("Text features generated.")

        best_score = -float('inf')
        best_mask_info = None
        print("Calculating CLIP scores for masks...")
        for i, m in enumerate(masks):
            seg = m["segmentation"]
            masked_for_clip = np_orig.copy()
            masked_for_clip[~seg] = 0
            pil_masked_for_clip = np_to_pil(masked_for_clip)
            inputs = clip_processor(images=pil_masked_for_clip, return_tensors="pt").to(device)
            with torch.no_grad():
                image_feat = clip_model.get_image_features(**inputs)
                image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True)
            sim = (image_feat @ text_feat.T).item()
            if sim > best_score:
                best_score = sim
                best_mask_info = m
        print(f"Best score found: {best_score:.4f}")

        if best_mask_info is not None:
            bbox = best_mask_info['bbox']
            x, y, w, h = map(int, bbox)
            y_start = max(0, y)
            y_end = min(np_orig.shape[0], y + h)
            x_start = max(0, x)
            x_end = min(np_orig.shape[1], x + w)
            print(f"Applying best mask's bounding box: [{x_start}:{x_end}, {y_start}:{y_end}]")
            if y_end > y_start and x_end > x_start:
                np_reveal[y_start:y_end, x_start:x_end] = np_orig[y_start:y_end, x_start:x_end]
            else:
                print(f"Warning: Invalid bounding box dimensions calculated ({w}x{h} at {x},{y}). Skipping reveal.")
        else:
            print("No suitable mask found based on the prompt.")

        final_revealed_pil = np_to_pil(np_reveal)
        print("Unmask processing complete.")
        return final_revealed_pil
    except Exception as e:
        print(f"Error during clip_guided_unmask: {e}")
        print(traceback.format_exc())
        return revealed_img
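# Note: clip_guided_unmask runs one CLIP image forward pass per SAM mask. If latency becomes an
# issue, a possible optimization (a sketch, not applied here) is to batch the masked crops:
#
#   pil_crops = [np_to_pil(crop) for crop in masked_crops]   # masked_crops: per-mask arrays from the loop
#   inputs = clip_processor(images=pil_crops, return_tensors="pt").to(device)
#   with torch.no_grad():
#       image_feats = clip_model.get_image_features(**inputs)
#       image_feats /= image_feats.norm(p=2, dim=-1, keepdim=True)
#   sims = (image_feats @ text_feat.T).squeeze(-1)  # one similarity score per mask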
# --- NEW FUNCTION (to get all SAM bounding boxes) ---
def get_all_sam_bboxes(original_img):
"""
Generates all masks for an image using SAM and returns their bounding boxes.
Input: PIL Image
Output: List of bounding boxes [[x, y, w, h], ...] or None on error
"""
if original_img is None:
print("Error: Original image is required for getting bboxes.")
return None # Return None or empty list to indicate error
# Ensure input is PIL Image
if not isinstance(original_img, Image.Image):
print(f"Error: get_all_sam_bboxes expects a PIL Image, got {type(original_img)}")
try: original_img = Image.open(original_img).convert("RGB")
except Exception as e:
print(f"Error converting input to PIL Image: {e}")
return None
print("Processing request to get all SAM bounding boxes...")
try:
np_orig = pil_to_np(original_img)
print(f"Original image shape for bbox generation: {np_orig.shape}")
# Ensure uint8 for SAM
if np_orig.dtype != np.uint8:
print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
else:
np_orig_sam = np_orig
print("Generating masks with SAM...")
masks = mask_generator.generate(np_orig_sam)
if not masks:
print("SAM did not generate any masks.")
return [] # Return empty list if no masks
print(f"Generated {len(masks)} masks.")
# Extract bounding boxes [x, y, w, h]
bboxes = [m['bbox'] for m in masks if 'bbox' in m]
# Ensure all elements are standard Python ints/floats for JSON serialization
bboxes_serializable = [[int(b[0]), int(b[1]), int(b[2]), int(b[3])] for b in bboxes]
print(f"Extracted {len(bboxes_serializable)} bounding boxes.")
return bboxes_serializable # Return the list of boxes
except Exception as e:
print(f"Error during get_all_sam_bboxes: {e}")
print(traceback.format_exc())
return None # Indicate error
# --- Gradio Interface using Blocks to support multiple API endpoints ---
print("Setting up Gradio interface using Blocks...")
with gr.Blocks() as demo:
    gr.Markdown("# CLIP-SAM Guided Unmasking and BBox Extraction")

    with gr.Tab("Interactive Unmasking"):
        with gr.Row():
            img_input_unmask = gr.Image(type="pil", label="Original Image")
            img_revealed_input = gr.Image(type="pil", label="Current Revealed Image (leave empty on first run)")
            img_output_unmask = gr.Image(type="pil", label="Updated Reveal")
        prompt_input_unmask = gr.Textbox(label="Text Prompt (e.g., 'a red car', 'the dog')")
        unmask_button = gr.Button("Unmask Prompt")

    with gr.Tab("Get All Bounding Boxes"):
        with gr.Row():
            img_input_bbox = gr.Image(type="pil", label="Original Image")
            # Output for visualization in UI (optional)
            img_output_bbox_viz = gr.Image(type="pil", label="Image with All BBoxes (Visualization)")
        # Output for API call (JSON)
        json_output_bbox = gr.JSON(label="Bounding Boxes ([x, y, w, h])")
        bbox_button = gr.Button("Get All SAM Bounding Boxes")

    # --- Define API endpoints ---
    # Endpoint for the interactive unmasking function
    unmask_button.click(
        fn=clip_guided_unmask,
        inputs=[img_input_unmask, img_revealed_input, prompt_input_unmask],
        outputs=img_output_unmask,
        api_name="predict"  # Keep the original API name for compatibility
    )

    # Helper function to draw boxes for the UI visualization part
    def draw_boxes_on_image_for_ui(original_img):
        bboxes = get_all_sam_bboxes(original_img)
        if bboxes is None or not bboxes:
            # Return original image or error message if desired
            print("No bounding boxes found or error occurred.")
            return original_img, []  # Return original image and empty list
        img_copy = original_img.copy()
        draw = ImageDraw.Draw(img_copy)
        print(f"Drawing {len(bboxes)} boxes for UI preview...")
        for bbox in bboxes:
            x, y, w, h = map(int, bbox)
            x1, y1 = x + w, y + h
            # Draw rectangle outline [x0, y0, x1, y1]
            draw.rectangle([x, y, x1, y1], outline="red", width=2)
        print("Finished drawing boxes for UI.")
        return img_copy, bboxes  # Return image with boxes AND the JSON data

    # Endpoint for getting all bounding boxes (and visualizing in UI)
    bbox_button.click(
        fn=draw_boxes_on_image_for_ui,  # Use helper that calls get_all_sam_bboxes
        inputs=[img_input_bbox],
        outputs=[img_output_bbox_viz, json_output_bbox],  # Output to both components
        api_name="get_boxes"  # New API name for this function
    )
# Launch the interface
print("Launching Gradio interface...")
# Consider adding share=False and potentially debug=True locally if needed
# demo.launch(share=True) # Use share=True if needed for external access like from your local script
demo.launch()
print("Interface launched. Check the output URL.")
# --- END OF FILE app.py ---