Spaces:
Sleeping
Sleeping
# grounding_dino2.py
# Lightweight Grounding DINO wrapper for box detection, cropping, and visualization.
from __future__ import annotations | |
import os | |
import threading | |
from pathlib import Path | |
from typing import List, Dict, Any, Tuple, Optional | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection | |
# ---- Writable caches (HF Spaces / containers) ----
# Cache root comes from HF_CACHE_DIR when set, otherwise a /tmp location.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/tmp/hf-cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Each variable is only filled in when the environment has not already set it.
for _env_name, _env_default in (
    ("HOME", "/tmp"),
    ("XDG_CACHE_HOME", CACHE_DIR),
    ("HF_HOME", CACHE_DIR),
    ("HUGGINGFACE_HUB_CACHE", CACHE_DIR),
    ("TRANSFORMERS_CACHE", CACHE_DIR),
    ("HF_DATASETS_CACHE", f"{CACHE_DIR}/datasets"),
    ("TORCH_HOME", CACHE_DIR),
    ("PYTHONPYCACHEPREFIX", "/tmp/pycache"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
def _clamp_xyxy(box: List[float], w: int, h: int) -> Tuple[int, int, int, int]: | |
x0, y0, x1, y1 = box | |
x0 = max(0, min(int(round(x0)), w - 1)) | |
y0 = max(0, min(int(round(y0)), h - 1)) | |
x1 = max(0, min(int(round(x1)), w - 1)) | |
y1 = max(0, min(int(round(y1)), h - 1)) | |
if x1 < x0: | |
x0, x1 = x1, x0 | |
if y1 < y0: | |
y0, y1 = y1, y0 | |
return x0, y0, x1, y1 | |
def _pad_box(box: Tuple[int, int, int, int], w: int, h: int, frac: float = 0.06) -> Tuple[int, int, int, int]: | |
x0, y0, x1, y1 = box | |
bw, bh = x1 - x0, y1 - y0 | |
dx, dy = int(bw * frac), int(bh * frac) | |
return max(0, x0 - dx), max(0, y0 - dy), min(w - 1, x1 + dx), min(h - 1, y1 + dy) | |
def crop_from_box(img: Image.Image, box_xyxy: Tuple[int, int, int, int]) -> Image.Image:
    """
    Return the region of `img` described by `box_xyxy`, given as (x0, y0, x1, y1).

    The box is handed to `img.crop` unchanged.
    """
    region = img.crop(box_xyxy)
    return region
def _parse_to_flat_labels(labels: List[str] | str) -> List[str]: | |
""" | |
Accepts a comma-separated string or a list of strings and returns a flat list of non-empty labels. | |
""" | |
if isinstance(labels, str): | |
items = [x.strip() for x in labels.split(",") if x.strip()] | |
else: | |
items = [str(x).strip() for x in labels if str(x).strip()] | |
if not items: | |
raise ValueError("No labels provided.") | |
return items | |
def _build_dot_separated_prompt(items: List[str]) -> str: | |
""" | |
Builds the recommended GroundingDINO text prompt: "a man . a dog ." | |
""" | |
return " . ".join(items) + " ." | |
class GroundingDINORunner:
    """
    Minimal wrapper around a Grounding DINO zero-shot object detector.

    The processor and model are loaded once at construction. `detect` runs the
    model on one image and returns labelled, scored boxes with matching crops.
    The model forward pass is executed while holding a lock.
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        """
        Load the processor and the model and switch the model to eval mode.

        Args:
            model_id: Model identifier. Falls back to the GDINO_MODEL_ID
                environment variable, then "IDEA-Research/grounding-dino-tiny".
            device: Torch device string. Defaults to "cuda" when
                `torch.cuda.is_available()` is true, otherwise "cpu".
        """
        self.model_id = model_id or os.getenv("GDINO_MODEL_ID", "IDEA-Research/grounding-dino-tiny")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._lock = threading.Lock()
        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=CACHE_DIR)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.model_id, cache_dir=CACHE_DIR
        ).to(self.device)
        self.model.eval()

    def detect(
        self,
        image: Image.Image,
        labels: List[str] | str,
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
        pad_frac: float = 0.06,
    ) -> List[Dict[str, Any]]:
        """
        Run zero-shot detection on `image` for the given labels.

        Args:
            image: Input PIL image. Non-RGB images are converted to RGB for the
                model only; crops are still taken from `image` itself.
            labels: Comma-separated string or list of label strings.
            box_threshold: Box score threshold passed to post-processing.
            text_threshold: Text score threshold passed to post-processing.
            pad_frac: Fraction of the box size added as padding on each side.

        Returns:
            [{ 'label': str, 'score': float, 'box_xyxy': (x0,y0,x1,y1), 'crop': PIL.Image }, ...]

        Raises:
            ValueError: if `labels` contains no non-empty label.
        """
        w, h = image.size

        # NOTE(review): the image processor is assumed to need 3-channel input;
        # RGBA / grayscale / palette images are converted here so they do not
        # fail during preprocessing. The size is unchanged, so boxes still map
        # onto the original image.
        model_image = image if image.mode == "RGB" else image.convert("RGB")

        # Flat label list -> single dot-separated prompt string (no nested lists).
        items = _parse_to_flat_labels(labels)
        text_prompt = _build_dot_separated_prompt(items)

        # Prepare inputs
        inputs = self.processor(images=model_image, text=text_prompt, return_tensors="pt").to(self.device)

        # Inference: one forward pass at a time, without gradient tracking.
        with self._lock, torch.no_grad():
            outputs = self.model(**inputs)

        # Arguments shared by both post-processing call signatures.
        shared_kwargs = {
            "outputs": outputs,
            "input_ids": inputs.input_ids,
            "text_threshold": float(text_threshold),
            "target_sizes": [(h, w)],
        }
        # transformers>=4.51 uses "threshold", older expects "box_threshold"
        try:
            post = self.processor.post_process_grounded_object_detection(
                threshold=float(box_threshold), **shared_kwargs
            )
        except TypeError:
            post = self.processor.post_process_grounded_object_detection(
                box_threshold=float(box_threshold), **shared_kwargs
            )

        det = post[0]
        boxes = det.get("boxes", [])
        scores = det.get("scores", [])
        # Newer transformers populate "text_labels"; else "labels"
        labels_out = det.get("text_labels", det.get("labels", []))

        results: List[Dict[str, Any]] = []
        for b, s, lab in zip(boxes, scores, labels_out):
            coords = b.tolist() if hasattr(b, "tolist") else list(b)
            # Clamp into the image first, then pad (padding re-limits to the image).
            bx = _pad_box(_clamp_xyxy(coords, w, h), w, h, pad_frac)
            crop = crop_from_box(image, bx)
            score = float(s.item()) if torch.is_tensor(s) else float(s)
            results.append({"label": lab, "score": score, "box_xyxy": bx, "crop": crop})
        return results
# --- Visualization helper ------------------------------------------------------ | |
def visualize_detections(
    image: Image.Image,
    detections: list[dict],
    *,
    box_color: tuple[int, int, int] = (0, 255, 0),
    text_color: tuple[int, int, int] = (0, 0, 0),
    box_width: int = 3,
) -> Image.Image:
    """
    Return a copy of `image` with each detection's box and caption drawn on it.

    Each detection item expects: {'label': str, 'score': float, 'box_xyxy': (x0,y0,x1,y1)}
    'label' and 'score' are optional and default to "" and 0.0. The caption is
    "<label> <score to 2 decimals>" on a filled strip just above the box.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)

    try:
        caption_font = ImageFont.truetype("DejaVuSans.ttf", 16)
    except Exception:
        # Font could not be loaded: the draw calls below receive font=None.
        caption_font = None

    inner_pad = 4
    for item in detections:
        left, top, right, bottom = item["box_xyxy"]
        name = item.get("label", "")
        confidence = item.get("score", 0.0)

        pen.rectangle((left, top, right, bottom), outline=box_color, width=box_width)

        caption = f"{name} {confidence:.2f}"
        # Measure the caption; fall back to a rough 8 px per character estimate.
        try:
            caption_w = pen.textlength(caption, font=caption_font)  # type: ignore[attr-defined]
        except Exception:
            caption_w = len(caption) * 8

        # Strip sits in the 20 px above the box, never starting above the image top.
        strip_top = max(0, top - 20)
        strip_right = left + int(caption_w) + inner_pad * 2
        pen.rectangle((left, strip_top, strip_right, top), fill=box_color)
        pen.text((left + inner_pad, max(0, top - 18)), caption, fill=text_color, font=caption_font)

    return canvas
# Convenience singleton: one shared runner, created on first use.
_runner_singleton: GroundingDINORunner | None = None
# Held while the singleton is being constructed, so concurrent first callers
# cannot each load their own copy of the model.
_runner_singleton_lock = threading.Lock()


def get_runner() -> GroundingDINORunner:
    """
    Return the shared `GroundingDINORunner`, constructing it on the first call.

    The runner is built with default arguments. Construction happens under a
    module-level lock and the `None` check is repeated inside it, so threads
    that race on the first call end up sharing a single instance.
    """
    global _runner_singleton
    if _runner_singleton is None:
        with _runner_singleton_lock:
            # Another thread may have finished construction while we waited.
            if _runner_singleton is None:
                _runner_singleton = GroundingDINORunner()
    return _runner_singleton