# --- START OF FILE app.py (Hugging Face Space code) ---
import torch
import numpy as np
from PIL import Image, ImageDraw  # ImageDraw is used for bounding-box visualization in the Gradio UI
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import traceback  # For better error logging

# --- Check for CUDA and print status ---
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA is available. Using GPU.")
else:
    device = "cpu"
    print("CUDA not available. Using CPU.")
# --- End Check ---
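
# --- Optional helper (hedged sketch): fetch the SAM checkpoint if it is not bundled ---
# Assumption: the file "sam_vit_b_01ec64.pth" may not be committed to the Space repo.
# The URL below is the official Segment Anything release location for the ViT-B checkpoint.
# Call ensure_sam_checkpoint() before the SAM loading code below if the file must be downloaded.
import os
import urllib.request

SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

def ensure_sam_checkpoint(path="sam_vit_b_01ec64.pth"):
    """Download the SAM ViT-B checkpoint to `path` if it does not already exist."""
    if not os.path.exists(path):
        print(f"Checkpoint '{path}' not found. Downloading from {SAM_CHECKPOINT_URL} ...")
        urllib.request.urlretrieve(SAM_CHECKPOINT_URL, path)
        print("Checkpoint download complete.")
    return path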
# --- Model Loading (with error handling) ---
sam = None
mask_generator = None
clip_model = None
clip_processor = None

try:
    print("Loading CLIP model...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded.")

    print("Loading SAM model...")
    # Use a smaller/faster model if performance is an issue, e.g. vit_t,
    # if available and a matching checkpoint exists:
    # sam_checkpoint = "sam_vit_t.pth"  # Example for a tiny model
    # sam_model_type = "vit_t"
    sam_checkpoint = "sam_vit_b_01ec64.pth"  # Original base model
    sam_model_type = "vit_b"
    sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(device).eval()
    # SamAutomaticMaskGenerator parameters can be adjusted if needed, e.g.:
    # points_per_side=32, pred_iou_thresh=0.88, stability_score_thresh=0.95,
    # crop_n_layers=0, crop_n_points_downscale_factor=1, min_mask_region_area=0
    mask_generator = SamAutomaticMaskGenerator(sam)
    print("SAM model loaded.")
except Exception as e:
    print(f"FATAL: Error loading models: {e}")
    print(traceback.format_exc())
    # If the models fail to load, the app should not run
    exit()
# --- End Model Loading ---
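
# Hedged example (not enabled): an explicitly tuned mask generator. The parameter names come
# from SamAutomaticMaskGenerator; the values below are illustrative assumptions, not settings
# used by this app. Note that min_mask_region_area > 0 requires opencv-python for post-processing.
# mask_generator = SamAutomaticMaskGenerator(
#     sam,
#     points_per_side=16,           # fewer prompt points -> faster, coarser coverage
#     pred_iou_thresh=0.88,         # keep only masks SAM itself rates as high quality
#     stability_score_thresh=0.95,  # keep only masks stable under threshold perturbation
#     crop_n_layers=0,              # no extra image crops
#     min_mask_region_area=100,     # drop tiny disconnected regions (needs cv2)
# )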
# Convert a PIL image to a numpy array
def pil_to_np(pil_img):
    return np.array(pil_img.convert("RGB"))

# Convert a numpy array to a PIL image
def np_to_pil(np_img):
    if np_img.dtype != np.uint8:
        if np_img.max() <= 1.0 and np_img.min() >= 0.0:
            np_img = (np_img * 255).astype(np.uint8)
        else:
            np_img = np.clip(np_img, 0, 255).astype(np.uint8)
    return Image.fromarray(np_img)
# --- EXISTING FUNCTION (for unmasking a single prompt) ---
def clip_guided_unmask(original_img, revealed_img, text_prompt):
    # --- (Function content remains exactly the same as before) ---
    if original_img is None:
        print("Error: Original image is required for unmasking.")
        return revealed_img
    if not isinstance(original_img, Image.Image):
        print(f"Error: original_img is not a PIL Image, type is {type(original_img)}")
        try:
            original_img = Image.open(original_img).convert("RGB")
        except Exception:
            return revealed_img
    if revealed_img is None:
        print("No revealed image provided, creating a black canvas.")
        revealed_img = Image.new("RGB", original_img.size, color="black")
    elif not isinstance(revealed_img, Image.Image):
        print(f"Error: revealed_img is not a PIL Image, type is {type(revealed_img)}")
        try:
            revealed_img = Image.open(revealed_img).convert("RGB")
        except Exception:
            print("Falling back to black canvas for revealed image.")
            revealed_img = Image.new("RGB", original_img.size, color="black")

    print(f"Processing unmask request for prompt: '{text_prompt}'")
    try:
        np_orig = pil_to_np(original_img)
        np_reveal = pil_to_np(revealed_img)
        print(f"Original image shape: {np_orig.shape}, Revealed image shape: {np_reveal.shape}")

        if np_orig.shape != np_reveal.shape:
            print(f"Warning: Shapes mismatch. Resizing revealed {np_reveal.shape} to {np_orig.shape}")
            revealed_img = revealed_img.resize(original_img.size)
            np_reveal = pil_to_np(revealed_img)

        print("Generating masks with SAM...")
        if np_orig.dtype != np.uint8:
            print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
            np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
        else:
            np_orig_sam = np_orig

        masks = mask_generator.generate(np_orig_sam)
        if not masks:
            print("SAM did not generate any masks.")
            return revealed_img
        print(f"Generated {len(masks)} masks.")

        print("Processing text prompt with CLIP...")
        prompt_for_clip = text_prompt if text_prompt else "object"  # Handle empty prompt
        text_inputs = clip_processor(text=[prompt_for_clip], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_feat = clip_model.get_text_features(**text_inputs)
            text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True)
        print("Text features generated.")

        best_score = -float('inf')
        best_mask_info = None
        print("Calculating CLIP scores for masks...")
        for i, m in enumerate(masks):
            seg = m["segmentation"]
            masked_for_clip = np_orig.copy()
            masked_for_clip[~seg] = 0
            pil_masked_for_clip = np_to_pil(masked_for_clip)
            inputs = clip_processor(images=pil_masked_for_clip, return_tensors="pt").to(device)
            with torch.no_grad():
                image_feat = clip_model.get_image_features(**inputs)
                image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True)
            sim = (image_feat @ text_feat.T).item()
            if sim > best_score:
                best_score = sim
                best_mask_info = m
        print(f"Best score found: {best_score:.4f}")

        if best_mask_info is not None:
            bbox = best_mask_info['bbox']
            x, y, w, h = map(int, bbox)
            y_start = max(0, y)
            y_end = min(np_orig.shape[0], y + h)
            x_start = max(0, x)
            x_end = min(np_orig.shape[1], x + w)
            print(f"Applying best mask's bounding box: [{x_start}:{x_end}, {y_start}:{y_end}]")
            if y_end > y_start and x_end > x_start:
                np_reveal[y_start:y_end, x_start:x_end] = np_orig[y_start:y_end, x_start:x_end]
            else:
                print(f"Warning: Invalid bounding box dimensions calculated ({w}x{h} at {x},{y}). Skipping reveal.")
        else:
            print("No suitable mask found based on the prompt.")

        final_revealed_pil = np_to_pil(np_reveal)
        print("Unmask processing complete.")
        return final_revealed_pil
    except Exception as e:
        print(f"Error during clip_guided_unmask: {e}")
        print(traceback.format_exc())
        return revealed_img
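
# --- Optional performance sketch (hedged, not wired into the UI) ---
# clip_guided_unmask above calls CLIP once per SAM mask. Assuming the masked crops fit in
# memory, they can be scored in a single batched forward pass instead. score_masks_batched
# is a hypothetical helper introduced here only for illustration.
def score_masks_batched(np_orig, masks, text_feat):
    """Return (index of best-matching mask, per-mask similarity scores)."""
    pil_crops = []
    for m in masks:
        crop = np_orig.copy()
        crop[~m["segmentation"]] = 0  # black out everything outside the mask, as in the loop above
        pil_crops.append(np_to_pil(crop))
    inputs = clip_processor(images=pil_crops, return_tensors="pt").to(device)
    with torch.no_grad():
        image_feats = clip_model.get_image_features(**inputs)
        image_feats /= image_feats.norm(p=2, dim=-1, keepdim=True)
    sims = (image_feats @ text_feat.T).squeeze(-1)  # cosine similarity per mask
    return int(sims.argmax().item()), sims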
# --- NEW FUNCTION (to get all SAM bounding boxes) ---
def get_all_sam_bboxes(original_img):
    """
    Generates all masks for an image using SAM and returns their bounding boxes.
    Input: PIL Image
    Output: List of bounding boxes [[x, y, w, h], ...] or None on error
    """
    if original_img is None:
        print("Error: Original image is required for getting bboxes.")
        return None  # Return None (or an empty list) to indicate an error

    # Ensure the input is a PIL Image
    if not isinstance(original_img, Image.Image):
        print(f"Error: get_all_sam_bboxes expects a PIL Image, got {type(original_img)}")
        try:
            original_img = Image.open(original_img).convert("RGB")
        except Exception as e:
            print(f"Error converting input to PIL Image: {e}")
            return None

    print("Processing request to get all SAM bounding boxes...")
    try:
        np_orig = pil_to_np(original_img)
        print(f"Original image shape for bbox generation: {np_orig.shape}")

        # Ensure uint8 for SAM
        if np_orig.dtype != np.uint8:
            print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
            np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
        else:
            np_orig_sam = np_orig

        print("Generating masks with SAM...")
        masks = mask_generator.generate(np_orig_sam)
        if not masks:
            print("SAM did not generate any masks.")
            return []  # Return an empty list if there are no masks
        print(f"Generated {len(masks)} masks.")

        # Extract bounding boxes [x, y, w, h]
        bboxes = [m['bbox'] for m in masks if 'bbox' in m]
        # Ensure all elements are standard Python ints for JSON serialization
        bboxes_serializable = [[int(b[0]), int(b[1]), int(b[2]), int(b[3])] for b in bboxes]
        print(f"Extracted {len(bboxes_serializable)} bounding boxes.")
        return bboxes_serializable  # Return the list of boxes
    except Exception as e:
        print(f"Error during get_all_sam_bboxes: {e}")
        print(traceback.format_exc())
        return None  # Indicate error
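
# --- Optional consumer-side sketch (hedged): converting and filtering the returned boxes ---
# Callers of get_all_sam_bboxes (or the /get_boxes endpoint) may want corner-format boxes, or
# to drop tiny regions. The min_area parameter is an illustrative assumption, not used by this app.
def xywh_to_xyxy(bboxes, min_area=0):
    """Convert [x, y, w, h] boxes to [x1, y1, x2, y2], optionally filtering by pixel area."""
    return [[x, y, x + w, y + h] for x, y, w, h in bboxes if w * h >= min_area]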
# --- Gradio Interface using Blocks to support multiple API endpoints ---
print("Setting up Gradio interface using Blocks...")

with gr.Blocks() as demo:
    gr.Markdown("# CLIP-SAM Guided Unmasking and BBox Extraction")

    with gr.Tab("Interactive Unmasking"):
        with gr.Row():
            img_input_unmask = gr.Image(type="pil", label="Original Image")
            img_revealed_input = gr.Image(type="pil", label="Current Revealed Image (leave empty on first run)")
            img_output_unmask = gr.Image(type="pil", label="Updated Reveal")
        prompt_input_unmask = gr.Textbox(label="Text Prompt (e.g., 'a red car', 'the dog')")
        unmask_button = gr.Button("Unmask Prompt")

    with gr.Tab("Get All Bounding Boxes"):
        with gr.Row():
            img_input_bbox = gr.Image(type="pil", label="Original Image")
            # Output for visualization in the UI (optional)
            img_output_bbox_viz = gr.Image(type="pil", label="Image with All BBoxes (Visualization)")
        # Output for API calls (JSON)
        json_output_bbox = gr.JSON(label="Bounding Boxes ([x, y, w, h])")
        bbox_button = gr.Button("Get All SAM Bounding Boxes")

    # --- Define API endpoints ---
    # Endpoint for the interactive unmasking function
    unmask_button.click(
        fn=clip_guided_unmask,
        inputs=[img_input_unmask, img_revealed_input, prompt_input_unmask],
        outputs=img_output_unmask,
        api_name="predict"  # Keep the original API name for compatibility
    )

    # Helper function to draw boxes for the UI visualization
    def draw_boxes_on_image_for_ui(original_img):
        bboxes = get_all_sam_bboxes(original_img)
        if bboxes is None or not bboxes:
            # Return the original image and an empty list if no boxes were found or an error occurred
            print("No bounding boxes found or error occurred.")
            return original_img, []
        img_copy = original_img.copy()
        draw = ImageDraw.Draw(img_copy)
        print(f"Drawing {len(bboxes)} boxes for UI preview...")
        for bbox in bboxes:
            x, y, w, h = map(int, bbox)
            x1, y1 = x + w, y + h
            # Draw the rectangle outline [x0, y0, x1, y1]
            draw.rectangle([x, y, x1, y1], outline="red", width=2)
        print("Finished drawing boxes for UI.")
        return img_copy, bboxes  # Return the image with boxes AND the JSON data

    # Endpoint for getting all bounding boxes (and visualizing them in the UI)
    bbox_button.click(
        fn=draw_boxes_on_image_for_ui,  # Use the helper that calls get_all_sam_bboxes
        inputs=[img_input_bbox],
        outputs=[img_output_bbox_viz, json_output_bbox],  # Output to both components
        api_name="get_boxes"  # New API name for this function
    )
# Launch the interface
print("Launching Gradio interface...")
# Consider adding share=False and potentially debug=True locally if needed
# demo.launch(share=True)  # Use share=True if needed for external access, e.g. from a local script
demo.launch()
print("Interface launched. Check the output URL.")
# --- END OF FILE app.py ---
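
# --- Appendix (hedged sketch, not part of app.py): calling the two endpoints from a local script ---
# Assumptions: the Space id below is a placeholder, a recent gradio_client is installed
# (older versions accept plain file paths instead of handle_file), and None is accepted
# for the optional "revealed" image on the first call.
#
# from gradio_client import Client, handle_file
#
# client = Client("your-username/your-space-name")  # placeholder Space id
#
# # Unmask the region best matching a prompt (api_name="predict" above)
# revealed = client.predict(
#     handle_file("photo.jpg"),   # original image
#     None,                       # current revealed image (None on the first run)
#     "a red car",                # text prompt
#     api_name="/predict",
# )
#
# # Fetch every SAM bounding box as JSON (api_name="get_boxes" above)
# boxes_img, boxes = client.predict(handle_file("photo.jpg"), api_name="/get_boxes")
# print(boxes)  # [[x, y, w, h], ...]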