"""Gradio demo that detects stickers/labels with an LLMDet (GroundingDINO) model.

The user uploads an image and a text query. The app draws the detected boxes
on the image and shows every detected region as a crop in a gallery.
"""
import atexit
import os
import tempfile
from datetime import datetime
from itertools import cycle

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import GroundingDinoProcessor

from modeling_grounding_dino import GroundingDinoForObjectDetection

# Load model and processor.
# Alternative, larger checkpoint: "fushh7/llmdet_swin_large_hf"
model_id = "fushh7/llmdet_swin_tiny_hf"
DEVICE = "cpu"

print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")

processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
model.eval()

print("[INFO] Model loaded successfully.")

# Pre-defined palette (extend or tweak as you like)
BOX_COLORS = [
    "deepskyblue", "red", "lime", "dodgerblue",
    "cyan", "magenta", "yellow",
    "orange", "chartreuse"
]

# Every temporary crop file written by this process. The files are deleted
# when the interpreter exits (see the atexit registration below
# cleanup_temp_files), so a long-running demo does not leave them behind forever.
_GENERATED_TEMP_FILES = []


def save_cropped_images(original_image, boxes, labels, scores):
    """
    Save every region delimited by a bounding box to its own temporary JPEG file.

    Boxes are rounded to integer pixels and clamped to the image bounds. Boxes
    that end up with zero width or height are skipped, because an empty image
    cannot be written as JPEG.

    :param original_image: Original PIL image
    :param boxes: Iterable of bounding boxes [x_min, y_min, x_max, y_max]
    :param labels: Iterable of labels, one per box (only used to pair with boxes)
    :param scores: Iterable of confidence scores (only used to pair with boxes)
    :return: List of paths of the temporary files that were written
    """
    width, height = original_image.size
    saved_paths = []

    # zip() keeps the original pairing/truncation semantics of the three inputs.
    for box, _label, _score in zip(boxes, labels, scores):
        x_min, y_min, x_max, y_max = (int(round(float(v))) for v in box)

        # Clamp to the image so PIL does not pad the crop with black pixels.
        x_min, x_max = max(0, x_min), min(width, x_max)
        y_min, y_max = max(0, y_min), min(height, y_max)
        if x_max <= x_min or y_max <= y_min:
            continue

        cropped_img = original_image.crop((x_min, y_min, x_max, y_max))
        # JPEG cannot store alpha/palette modes (e.g. RGBA from PNG inputs).
        if cropped_img.mode != "RGB":
            cropped_img = cropped_img.convert("RGB")

        # delete=False: the file must outlive this function so Gradio can read it.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            cropped_img.save(tmp_file, format="JPEG")
            filepath = tmp_file.name

        _GENERATED_TEMP_FILES.append(filepath)
        saved_paths.append(filepath)

    return saved_paths


def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
    """
    Draw bounding boxes and labels on a PIL Image (in place).

    :param image: PIL Image object (modified in place and returned)
    :param boxes: Iterable of [x_min, y_min, x_max, y_max]
    :param labels: Iterable of label strings
    :param scores: Iterable of scalar confidences (0-1)
    :param colors: List/tuple of colour names or RGB tuples (falls back to
        BOX_COLORS when empty)
    :param font_path: Path to a TTF font for labels
    :param font_size: Int size of font to use, default 16
    :return: PIL Image with drawn boxes
    """
    # Ensure we can iterate colours indefinitely
    colour_cycle = cycle(colors or BOX_COLORS)
    draw = ImageDraw.Draw(image)

    # Pick a font (fallback to default if missing)
    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        try:
            # The ``size`` argument exists only in Pillow >= 10.1.
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()

    # Assign a consistent colour per label
    label_to_colour = {}

    for box, label, score in zip(boxes, labels, scores):
        # Only advance the palette for labels not seen before, so consecutive
        # new labels get consecutive palette colours.
        if label not in label_to_colour:
            label_to_colour[label] = next(colour_cycle)
        colour = label_to_colour[label]

        x_min, y_min, x_max, y_max = map(int, box)

        # Draw rectangle
        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        # Compose text
        text = f"{label} ({score:.3f})"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]

        # Put the label above the box; if that would leave the canvas
        # (box touching the top edge), put it just inside the box instead.
        label_top = y_min - text_h - 4
        if label_top < 0:
            label_top = y_min

        # Draw text background for legibility
        draw.rectangle(
            [x_min, label_top, x_min + text_w + 4, label_top + text_h + 4],
            fill=colour,
        )
        # Draw text
        draw.text((x_min + 2, label_top + 2), text, fill="black", font=font)

    return image


def resize_image_max_dimension(image, max_size=4096):
    """
    Resize an image so that the longest side is at most max_size pixels,
    while maintaining the aspect ratio.

    :param image: PIL Image object
    :param max_size: Maximum dimension in pixels (default: 4096)
    :return: PIL Image object (the same object if no resize was needed)
    """
    width, height = image.size

    # Check if resizing is needed
    if max(width, height) <= max_size:
        return image

    # Calculate new dimensions maintaining aspect ratio; never round a side
    # down to 0 px (very elongated images), which PIL cannot resize to.
    ratio = max_size / max(width, height)
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))

    # Resize the image using high-quality resampling
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def detect_and_draw(
    img: Image.Image,
    text_query: str,
    box_threshold: float = 0.14,
    text_threshold: float = 0.13,
    save_crops: bool = True
):
    """
    Detect objects described in `text_query`, draw boxes, return the image and crops.

    The query is lower-cased and a trailing '.' is appended if missing, because
    the model expects concepts separated and terminated by dots
    (e.g. 'a cat. a remote control.').

    :raises gr.Error: if no image was provided or the query is empty
    :return: (annotated image, list of temporary crop file paths)
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Normalise the query: lower-case, and make sure it ends with a dot.
    text_query = (text_query or "").strip().lower()
    if not text_query:
        raise gr.Error("Please enter a text query, e.g. 'stickers. labels.'")
    if not text_query.endswith("."):
        text_query += "."

    # The processor and the JPEG crops both expect 3-channel RGB input.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # If the image size is too large, we make it smaller
    img = resize_image_max_dimension(img, max_size=4096)

    # Preprocess the image
    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[img.size[::-1]]  # PIL gives (w, h); the processor wants (h, w)
    )[0]

    # Extract the detections once; both the drawing and the crops use them.
    boxes = results["boxes"].cpu().numpy()
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"]

    # Draw on a copy so `img` stays clean for cropping.
    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)

    # Paths of the crop files
    crop_paths = []
    if save_crops:
        crop_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
        print(f"Generated {len(crop_paths)} cropped images")

    return img_out, crop_paths


# Create example list
examples = [
    ["examples/stickers(1).jpg", "stickers. labels.", 0.24, 0.23],
]


def cleanup_temp_files(crop_paths):
    """Best-effort deletion of temporary files; files already gone are ignored."""
    for path in crop_paths:
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"[WARN] Could not delete temporary file {path}: {err}")


# Remove every crop written during this process when the interpreter exits.
atexit.register(cleanup_temp_files, _GENERATED_TEMP_FILES)

# Create Gradio demo
with gr.Blocks(title="Stikkiers", css=".gradio-container {max-width: 100% !important}") as demo:
    gr.Markdown("# Sticker Finder")
    gr.Markdown("Upload an image and adjust thresholds to see detections.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_query = gr.Textbox(
                value="stickers. labels. postcards.",
                label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"
            )
            box_threshold = gr.Slider(0.0, 1.0, 0.14, step=0.05, label="Box Threshold")
            text_threshold = gr.Slider(0.0, 1.0, 0.13, step=0.05, label="Text Threshold")
            submit_btn = gr.Button("Detect")

        with gr.Column():
            image_output = gr.Image(type="pil", label="Detections")
            # Gallery of the cropped detections
            gallery = gr.Gallery(
                label="Detected Crops",
                columns=[4],
                rows=[2],
                object_fit="contain",
                height="auto"
            )

    # Examples
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_query, box_threshold, text_threshold],
        outputs=[image_output, gallery],
        fn=detect_and_draw,
        cache_examples=True
    )

    # Submit button
    submit_btn.click(
        fn=detect_and_draw,
        inputs=[image_input, text_query, box_threshold, text_threshold],
        outputs=[image_output, gallery]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)