# util.py (patched cache handling for HF Spaces)
import os
from pathlib import Path

# Put every cache under /tmp (always writable in Spaces)
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Make sure libraries don't fall back to "~/.cache" -> "/.cache"
os.environ.setdefault("HF_HOME", CACHE_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", CACHE_DIR)
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", CACHE_DIR)
os.environ.setdefault("XDG_CACHE_HOME", CACHE_DIR)
os.environ.setdefault("TORCH_HOME", CACHE_DIR)
# util.py (Spaces-safe + metrics)
from time import perf_counter
import threading
from io import BytesIO
from typing import List, Sequence, Tuple, Dict, Any
import io
import base64

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image as hf_load_image

from grounding_dino2 import get_runner as get_gdino_runner, visualize_detections

def _has_flash_attn() -> bool:
    try:
        import flash_attn  # noqa: F401
        return True
    except Exception:
        return False

def _pick_backend_and_dtype():
    if not torch.cuda.is_available():
        return "eager", torch.float32, "cpu"
    major, _ = torch.cuda.get_device_capability()
    dev = "cuda"
    bf16_ok = torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if bf16_ok else torch.float16
    if major >= 8:  # Ampere+
        attn = "flash_attention_2" if _has_flash_attn() else "eager"
    else:
        attn = "eager"
    return attn, dtype, dev
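
# What the helper resolves to on typical hardware (derived from the logic above; illustrative):
#   no GPU                       -> ("eager", torch.float32, "cpu")
#   Ampere+ GPU with flash-attn  -> ("flash_attention_2", torch.bfloat16, "cuda")
#   Ampere+ GPU, no flash-attn   -> ("eager", torch.bfloat16, "cuda")
#   pre-Ampere GPU (e.g. T4)     -> ("eager", torch.float16, "cuda")
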
class SmolVLMRunner:
    """Portable wrapper with per-call metrics."""

    def __init__(self, model_id: str | None = None, device: str | None = None):
        self.model_id = model_id or os.getenv("SMOLVLM_MODEL_ID", "HuggingFaceTB/SmolVLM-Instruct")
        attn_impl, dtype, dev = _pick_backend_and_dtype()
        attn_impl = os.getenv("SMOLVLM_ATTN", attn_impl)  # optional override
        self.device = device or dev
        self.dtype = dtype
        self.attn_impl = attn_impl
if self.device == "cuda" and self.attn_impl == "sdpa": | |
try: | |
from torch.backends.cuda import sdp_kernel | |
sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True) | |
except Exception: | |
pass | |
        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=CACHE_DIR)
        self.model = AutoModelForVision2Seq.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            _attn_implementation=self.attn_impl,
            cache_dir=CACHE_DIR,
        ).to(self.device)
        try:
            self.model.config._attn_implementation = self.attn_impl
        except Exception:
            pass
        self.model.eval()
        self._lock = threading.Lock()

    # ---------- Image utils ----------
    @staticmethod
    def _ensure_rgb(img: Image.Image) -> Image.Image:
        return img.convert("RGB") if img.mode != "RGB" else img

    @classmethod
    def load_pil_from_urls(cls, urls: Sequence[str]) -> List[Image.Image]:
        return [cls._ensure_rgb(hf_load_image(u)) for u in urls]

    @classmethod
    def load_pil_from_bytes(cls, blobs: Sequence[bytes]) -> List[Image.Image]:
        return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]
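
    # Usage sketch for the loaders (illustrative only; the URL is a placeholder, not part of
    # this Space):
    #
    #     imgs = SmolVLMRunner.load_pil_from_urls(["https://example.com/cat.jpg"])
    #     imgs[0].mode   # -> "RGB"
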
    # ---------- Inference ----------
    def detect_and_describe(
        self,
        image: Image.Image,
        labels: list[str] | str,
        *,
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
        pad_frac: float = 0.06,
        max_new_tokens: int = 160,
        temperature: float | None = None,
        top_p: float | None = None,
        return_overlay: bool = False,
    ) -> list[dict] | dict:
        """
        Uses Grounding DINO to detect boxes for `labels`, then asks SmolVLM to
        describe each cropped box.

        If return_overlay=False (default): returns a list of dicts:
            [{ 'label', 'score', 'box_xyxy', 'description' }, ...]
        If return_overlay=True: returns a dict:
            { 'detections': [...], 'overlay_png_b64': '<base64 PNG>' }
        """
        gdino = get_gdino_runner()
        detections = gdino.detect(
            image=image,
            labels=labels,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            pad_frac=pad_frac,
        )
        if not detections:
            return [] if not return_overlay else {"detections": [], "overlay_png_b64": None}

        results: list[dict] = []
        for det in detections:
            crop = det["crop"]
prompt_txt = f"The image gets the label: '{det['label']}'. Describe the object inside this crop in detail." | |
content = [{"type": "image"}, {"type": "text", "text": prompt_txt}] | |
messages = [{"role": "user", "content": content}] | |
chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) | |
inputs = self.processor(text=chat_prompt, images=[crop], return_tensors="pt") | |
inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()} | |
            gen_kwargs = dict(max_new_tokens=max_new_tokens)
            if temperature is not None:
                gen_kwargs["temperature"] = float(temperature)
                gen_kwargs["do_sample"] = True  # temperature only takes effect when sampling
            if top_p is not None:
                gen_kwargs["top_p"] = float(top_p)
                gen_kwargs["do_sample"] = True  # top_p only takes effect when sampling
            with self._lock, torch.inference_mode():
                out_ids = self.model.generate(**inputs, **gen_kwargs)

            text = self.processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
            if text.startswith("Assistant:"):
                text = text[len("Assistant:"):].strip()

            results.append({
                "label": det["label"],
                "score": det["score"],
                "box_xyxy": det["box_xyxy"],
                "description": text,
            })

        if not return_overlay:
            return results

        # Build overlay image (PNG -> base64 string)
        overlay = visualize_detections(image, detections)
        buf = io.BytesIO()
        overlay.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return {"detections": results, "overlay_png_b64": b64}
    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image],
        max_new_tokens: int = 300,
        temperature: float | None = None,
        top_p: float | None = None,
        return_stats: bool = False,
    ) -> str | Tuple[str, Dict[str, Any]]:
        """
        Returns str by default.
        If return_stats=True, returns (text, metrics_dict).
        """
        meta = {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(self.dtype).replace("torch.", ""),
            "attn_backend": self.attn_impl,
            "image_count": len(images),
            "max_new_tokens": int(max_new_tokens),
            "temperature": None if temperature is None else float(temperature),
            "top_p": None if top_p is None else float(top_p),
        }
        t0 = perf_counter()

        content = [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}]
        messages = [{"role": "user", "content": content}]
        chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        # Preprocess (tokenize + vision)
        inputs = self.processor(text=chat_prompt, images=list(images), return_tensors="pt")
        inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
        t_pre_end = perf_counter()
        # Inference (generate)
        gen_kwargs = dict(max_new_tokens=max_new_tokens)
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
            gen_kwargs["do_sample"] = True  # temperature only takes effect when sampling
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)
            gen_kwargs["do_sample"] = True  # top_p only takes effect when sampling
if self.device == "cuda": | |
torch.cuda.synchronize() | |
torch.cuda.reset_peak_memory_stats() | |
with self._lock, torch.inference_mode(): | |
t_inf_start = perf_counter() | |
out_ids = self.model.generate(**inputs, **gen_kwargs) | |
if self.device == "cuda": | |
torch.cuda.synchronize() | |
t_inf_end = perf_counter() | |
# Decode | |
text = self.processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip() | |
if text.startswith("Assistant:"): | |
text = text[len("Assistant:"):].strip() | |
t_dec_end = perf_counter() | |
# Stats | |
input_tokens = int(inputs["input_ids"].shape[-1]) if "input_ids" in inputs else None | |
total_tokens = int(out_ids.shape[-1]) # includes prompt + generated | |
output_tokens = int(total_tokens - (input_tokens or 0)) if input_tokens is not None else None | |
pre_ms = (t_pre_end - t0) * 1000.0 | |
infer_ms = (t_inf_end - t_inf_start) * 1000.0 | |
decode_ms = (t_dec_end - t_inf_end) * 1000.0 | |
total_ms = (t_dec_end - t0) * 1000.0 | |
tps_infer = (output_tokens / ((t_inf_end - t_inf_start) + 1e-9)) if output_tokens else None | |
tps_total = ( | |
(output_tokens / ((t_dec_end - t0) + 1e-9)) if output_tokens else None | |
) | |
gpu_mem_alloc_mb = gpu_mem_resv_mb = None | |
gpu_name = None | |
if self.device == "cuda": | |
try: | |
gpu_mem_alloc_mb = round(torch.cuda.max_memory_allocated() / (1024**2), 2) | |
gpu_mem_resv_mb = round(torch.cuda.max_memory_reserved() / (1024**2), 2) | |
gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) | |
except Exception: | |
pass | |
metrics: Dict[str, Any] = { | |
**meta, | |
"gpu_name": gpu_name, | |
"timings_ms": { | |
"preprocess": round(pre_ms, 2), | |
"inference": round(infer_ms, 2), | |
"decode": round(decode_ms, 2), | |
"total": round(total_ms, 2), | |
}, | |
"tokens": { | |
"input": input_tokens, | |
"output": output_tokens, | |
"total": total_tokens, | |
}, | |
"throughput": { | |
"tokens_per_sec_inference": None if tps_infer is None else round(tps_infer, 2), | |
"tokens_per_sec_end_to_end": None if tps_total is None else round(tps_total, 2), | |
}, | |
"gpu_memory_mb": { | |
"max_allocated": gpu_mem_alloc_mb, | |
"max_reserved": gpu_mem_resv_mb, | |
}, | |
} | |
return (text, metrics) if return_stats else text | |

# Convenience singleton
_runner_singleton = None


def get_runner():
    global _runner_singleton
    if _runner_singleton is None:
        _runner_singleton = SmolVLMRunner()
    return _runner_singleton
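

# Local smoke test (illustrative sketch only; not used by the Space itself). The image URL is a
# placeholder, and `return_stats=True` simply returns the metrics dict built in `generate()`.
if __name__ == "__main__":
    runner = get_runner()
    demo_imgs = SmolVLMRunner.load_pil_from_urls(["https://example.com/demo.jpg"])  # placeholder URL
    text, stats = runner.generate(
        "Describe this image in one sentence.",
        demo_imgs,
        max_new_tokens=64,
        return_stats=True,
    )
    print(text)
    print(stats["timings_ms"], stats["tokens"])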