# --- START OF FILE app.py (Hugging Face Space code) ---

import torch
import numpy as np
from PIL import Image, ImageDraw # Added ImageDraw for potential visualization within Gradio UI
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import traceback # For better error logging

# --- Check for CUDA and print status ---
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA is available. Using GPU.")
else:
    device = "cpu"
    print("CUDA not available. Using CPU.")
# --- End Check ---

# --- Model Loading (add error handling) ---
sam = None
mask_generator = None
clip_model = None
clip_processor = None
try:
    print("Loading CLIP model...")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded.")

    print("Loading SAM model...")
    # Use a smaller/faster model if performance is an issue (e.g., vit_t), provided its checkpoint is available
    # sam_checkpoint = "sam_vit_t.pth" # Example for tiny model
    # sam_model_type = "vit_t"
    sam_checkpoint = "sam_vit_b_01ec64.pth" # Original base model
    sam_model_type = "vit_b"

    sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(device).eval()
    # You might adjust SamAutomaticMaskGenerator parameters if needed
    # points_per_side=32, pred_iou_thresh=0.88, stability_score_thresh=0.95, crop_n_layers=0, crop_n_points_downscale_factor=1, min_mask_region_area=0
    mask_generator = SamAutomaticMaskGenerator(sam)
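    # A possible tuned configuration (commented out). The parameter values below are
    # illustrative assumptions, not the defaults used above:
    # mask_generator = SamAutomaticMaskGenerator(
    #     sam,
    #     points_per_side=16,            # fewer sample points -> faster, coarser masks
    #     pred_iou_thresh=0.88,          # discard masks with low predicted IoU
    #     stability_score_thresh=0.95,   # discard unstable masks
    #     min_mask_region_area=100,      # drop very small regions (post-processing needs opencv)
    # )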
    print("SAM model loaded.")
except Exception as e:
    print(f"FATAL: Error loading models: {e}")
    print(traceback.format_exc())
    # If models fail to load, the app shouldn't run
    exit()
# --- End Model Loading ---

# Convert PIL image to numpy
def pil_to_np(pil_img):
    return np.array(pil_img.convert("RGB"))

# Convert numpy array to PIL image
def np_to_pil(np_img):
    if np_img.dtype != np.uint8:
        if np_img.max() <= 1.0 and np_img.min() >= 0.0:
            np_img = (np_img * 255).astype(np.uint8)
        else:
            np_img = np.clip(np_img, 0, 255).astype(np.uint8)
    return Image.fromarray(np_img)


# --- EXISTING FUNCTION (for unmasking single prompt) ---
def clip_guided_unmask(original_img, revealed_img, text_prompt):
    if original_img is None:
        print("Error: Original image is required for unmasking.")
        return revealed_img

    if not isinstance(original_img, Image.Image):
        print(f"Error: original_img is not a PIL Image, type is {type(original_img)}")
        try:
            original_img = Image.open(original_img).convert("RGB")
        except Exception:
            return revealed_img

    if revealed_img is None:
        print("No revealed image provided, creating a black canvas.")
        revealed_img = Image.new("RGB", original_img.size, color="black")
    elif not isinstance(revealed_img, Image.Image):
         print(f"Error: revealed_img is not a PIL Image, type is {type(revealed_img)}")
         try: revealed_img = Image.open(revealed_img).convert("RGB")
         except Exception:
              print("Falling back to black canvas for revealed image.")
              revealed_img = Image.new("RGB", original_img.size, color="black")

    print(f"Processing unmask request for prompt: '{text_prompt}'")
    try:
        np_orig = pil_to_np(original_img)
        np_reveal = pil_to_np(revealed_img)
        print(f"Original image shape: {np_orig.shape}, Revealed image shape: {np_reveal.shape}")

        if np_orig.shape != np_reveal.shape:
            print(f"Warning: Shapes mismatch. Resizing revealed {np_reveal.shape} to {np_orig.shape}")
            revealed_img = revealed_img.resize(original_img.size)
            np_reveal = pil_to_np(revealed_img)

        print("Generating masks with SAM...")
        if np_orig.dtype != np.uint8:
            print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
            np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
        else:
            np_orig_sam = np_orig

        masks = mask_generator.generate(np_orig_sam)
        if not masks:
            print("SAM did not generate any masks.")
            return revealed_img
        print(f"Generated {len(masks)} masks.")

        print("Processing text prompt with CLIP...")
        prompt_for_clip = text_prompt if text_prompt else "object" # Handle empty prompt
        text_inputs = clip_processor(text=[prompt_for_clip], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_feat = clip_model.get_text_features(**text_inputs)
            text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True)
        print("Text features generated.")

        best_score = -float('inf')
        best_mask_info = None

        print("Calculating CLIP scores for masks...")
        for i, m in enumerate(masks):
            seg = m["segmentation"]
            masked_for_clip = np_orig.copy()
            masked_for_clip[~seg] = 0
            pil_masked_for_clip = np_to_pil(masked_for_clip)
            inputs = clip_processor(images=pil_masked_for_clip, return_tensors="pt").to(device)
            with torch.no_grad():
                image_feat = clip_model.get_image_features(**inputs)
                image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True)
                sim = (image_feat @ text_feat.T).item()
            if sim > best_score:
                best_score = sim
                best_mask_info = m

        print(f"Best score found: {best_score:.4f}")

        if best_mask_info is not None:
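            # SAM's 'bbox' is in XYWH format; clamp it to the image bounds before copying pixels.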
            bbox = best_mask_info['bbox']
            x, y, w, h = map(int, bbox)
            y_start = max(0, y)
            y_end = min(np_orig.shape[0], y + h)
            x_start = max(0, x)
            x_end = min(np_orig.shape[1], x + w)
            print(f"Applying best mask's bounding box: [{x_start}:{x_end}, {y_start}:{y_end}]")
            if y_end > y_start and x_end > x_start:
                np_reveal[y_start:y_end, x_start:x_end] = np_orig[y_start:y_end, x_start:x_end]
            else:
                 print(f"Warning: Invalid bounding box dimensions calculated ({w}x{h} at {x},{y}). Skipping reveal.")
        else:
            print("No suitable mask found based on the prompt.")

        final_revealed_pil = np_to_pil(np_reveal)
        print("Unmask processing complete.")
        return final_revealed_pil

    except Exception as e:
        print(f"Error during clip_guided_unmask: {e}")
        print(traceback.format_exc())
        return revealed_img

# --- NEW FUNCTION (to get all SAM bounding boxes) ---
def get_all_sam_bboxes(original_img):
    """
    Generates all masks for an image using SAM and returns their bounding boxes.
    Input: PIL Image
    Output: List of bounding boxes [[x, y, w, h], ...] or None on error
    """
    if original_img is None:
        print("Error: Original image is required for getting bboxes.")
        return None # Return None or empty list to indicate error

    # Ensure input is PIL Image
    if not isinstance(original_img, Image.Image):
        print(f"Error: get_all_sam_bboxes expects a PIL Image, got {type(original_img)}")
        try:
            original_img = Image.open(original_img).convert("RGB")
        except Exception as e:
            print(f"Error converting input to PIL Image: {e}")
            return None

    print("Processing request to get all SAM bounding boxes...")
    try:
        np_orig = pil_to_np(original_img)
        print(f"Original image shape for bbox generation: {np_orig.shape}")

        # Ensure uint8 for SAM
        if np_orig.dtype != np.uint8:
            print(f"Warning: Converting original image to uint8 for SAM (original type: {np_orig.dtype})")
            np_orig_sam = np.clip(np_orig, 0, 255).astype(np.uint8)
        else:
            np_orig_sam = np_orig

        print("Generating masks with SAM...")
        masks = mask_generator.generate(np_orig_sam)

        if not masks:
            print("SAM did not generate any masks.")
            return [] # Return empty list if no masks

        print(f"Generated {len(masks)} masks.")

        # Extract bounding boxes [x, y, w, h]
        bboxes = [m['bbox'] for m in masks if 'bbox' in m]
        # Ensure all elements are standard Python ints/floats for JSON serialization
        bboxes_serializable = [[int(b[0]), int(b[1]), int(b[2]), int(b[3])] for b in bboxes]


        print(f"Extracted {len(bboxes_serializable)} bounding boxes.")
        return bboxes_serializable # Return the list of boxes

    except Exception as e:
        print(f"Error during get_all_sam_bboxes: {e}")
        print(traceback.format_exc())
        return None # Indicate error


# --- Gradio Interface using Blocks to support multiple API endpoints ---
print("Setting up Gradio interface using Blocks...")

with gr.Blocks() as demo:
    gr.Markdown("# CLIP-SAM Guided Unmasking and BBox Extraction")

    with gr.Tab("Interactive Unmasking"):
        with gr.Row():
            img_input_unmask = gr.Image(type="pil", label="Original Image")
            img_revealed_input = gr.Image(type="pil", label="Current Revealed Image (leave empty on first run)")
            img_output_unmask = gr.Image(type="pil", label="Updated Reveal")
        prompt_input_unmask = gr.Textbox(label="Text Prompt (e.g., 'a red car', 'the dog')")
        unmask_button = gr.Button("Unmask Prompt")

    with gr.Tab("Get All Bounding Boxes"):
        with gr.Row():
            img_input_bbox = gr.Image(type="pil", label="Original Image")
            # Output for visualization in UI (optional)
            img_output_bbox_viz = gr.Image(type="pil", label="Image with All BBoxes (Visualization)")
            # Output for API call (JSON)
            json_output_bbox = gr.JSON(label="Bounding Boxes ([x, y, w, h])")
        bbox_button = gr.Button("Get All SAM Bounding Boxes")

    # --- Define API endpoints ---

    # Endpoint for the interactive unmasking function
    unmask_button.click(
        fn=clip_guided_unmask,
        inputs=[img_input_unmask, img_revealed_input, prompt_input_unmask],
        outputs=img_output_unmask,
        api_name="predict" # Keep the original API name for compatibility
    )

    # Helper function to draw boxes for the UI visualization part
    def draw_boxes_on_image_for_ui(original_img):
        bboxes = get_all_sam_bboxes(original_img)
        if bboxes is None or not bboxes:
            # Return original image or error message if desired
            print("No bounding boxes found or error occurred.")
            return original_img, [] # Return original image and empty list
        img_copy = original_img.copy()
        draw = ImageDraw.Draw(img_copy)
        print(f"Drawing {len(bboxes)} boxes for UI preview...")
        for bbox in bboxes:
            x, y, w, h = map(int, bbox)
            x1, y1 = x + w, y + h
            # Draw rectangle outline [x0, y0, x1, y1]
            draw.rectangle([x, y, x1, y1], outline="red", width=2)
        print("Finished drawing boxes for UI.")
        return img_copy, bboxes # Return image with boxes AND the json data

    # Endpoint for getting all bounding boxes (and visualizing in UI)
    bbox_button.click(
        fn=draw_boxes_on_image_for_ui, # Use helper that calls get_all_sam_bboxes
        inputs=[img_input_bbox],
        outputs=[img_output_bbox_viz, json_output_bbox], # Output to both components
        api_name="get_boxes" # New API name for this function
    )


# Launch the interface
print("Launching Gradio interface...")
# Consider adding share=False and potentially debug=True locally if needed
# demo.launch(share=True) # Use share=True if needed for external access like from your local script
demo.launch()
print("Interface launched. Check the output URL.")


# --- END OF FILE app.py ---