# vlm_grounding / ui.py
# Author: reygml — commit f3e066f ("haha")
# ui.py
import os
import io
import json
import base64
import requests
import streamlit as st
from PIL import Image
st.set_page_config(page_title="SmolVLM UI", layout="wide")
st.title("SmolVLM Grounding")
# Base URL of the FastAPI backend. An empty env var falls back to the local default,
# and a trailing slash is stripped so f"{API_BASE}/generate" never yields "//generate"
# (which would not match the server's "/generate" route).
API_BASE = (os.getenv("API_BASE") or "http://127.0.0.1:8000").rstrip("/")
def _metric_section(metrics: dict, key: str) -> dict:
    """Return ``metrics[key]`` when it is a dict; otherwise (missing, null, other) an empty dict."""
    section = metrics.get(key)
    return section if isinstance(section, dict) else {}


def _format_metric(value, spec: str) -> str:
    """Format ``value`` with format ``spec``, or return an em dash when it is not a real number."""
    # bool is an int subclass, but a flag is not a meaningful metric value.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return "—"
    return format(value, spec)


def show_metrics(metrics: dict):
    """Render API-reported metrics: four headline numbers plus an expander with the raw JSON.

    Args:
        metrics: The ``"metrics"`` object from an API response. The headline numbers are
            read from ``timings_ms.total``, ``timings_ms.inference``,
            ``throughput.tokens_per_sec_inference`` and ``gpu_memory_mb.max_reserved``.
            Any section or value that is missing, null, or non-numeric is shown as "—"
            instead of raising.

    Returns:
        None. Nothing is rendered when ``metrics`` is empty/``None`` or not a dict.
    """
    if not metrics or not isinstance(metrics, dict):
        return
    timings = _metric_section(metrics, "timings_ms")
    throughput = _metric_section(metrics, "throughput")
    gpu_memory = _metric_section(metrics, "gpu_memory_mb")
    cols = st.columns(4)
    cols[0].metric("Total (ms)", _format_metric(timings.get("total"), ".0f"))
    cols[1].metric("Inference (ms)", _format_metric(timings.get("inference"), ".0f"))
    cols[2].metric("Tok/s (infer)", _format_metric(throughput.get("tokens_per_sec_inference"), ".1f"))
    cols[3].metric("GPU reserved (MB)", _format_metric(gpu_memory.get("max_reserved"), ".0f"))
    st.expander("All metrics").json(metrics)
# Two top-level tabs: free-form generation over uploaded images, and grounded detection.
_TAB_TITLES = ["SmolVLM Detection", "Grounded Detection"]
tab_upload, tab_detect = st.tabs(_TAB_TITLES)
# -------------------- Tab 1: uploads -> /generate --------------------
# Free-form generation: send one or more images plus a prompt to the API's /generate
# endpoint as multipart form data, then render the returned text and metrics.
with tab_upload:
    st.subheader("Upload an image")
    files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
    prompt = st.text_area("Prompt", "Can you describe the image?", height=80)
    # Generation controls. temperature / top_p are only sent when toggled on.
    max_new_tokens = st.slider("max_new_tokens", 1, 1024, 300, step=1)
    temperature_on = st.toggle("Set temperature?", value=False)
    temperature = st.slider("temperature", 0.0, 2.0, 0.2, step=0.05) if temperature_on else None
    topp_on = st.toggle("Set top_p?", value=False)
    top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
    # Button sits below the controls it submits, matching the Detect tab's layout.
    run = st.button("Generate", type="primary", use_container_width=True, key="run_files")
    st.caption("API base: " + API_BASE)
    if run:
        if not files or not prompt.strip():
            st.error("Please add at least one image and a prompt.")
        else:
            with st.spinner("Calling FastAPI…"):
                data = {
                    "prompt": prompt,
                    "max_new_tokens": str(max_new_tokens),  # form fields are strings
                }
                if temperature is not None:
                    data["temperature"] = str(temperature)
                if top_p is not None:
                    data["top_p"] = str(top_p)
                multipart = []
                previews = []
                # Captions are collected alongside previews so both lists always have the
                # same length; st.image rejects mismatched image/caption counts.
                preview_captions = []
                unpreviewable = []
                for f in files:
                    # getvalue() returns the full upload regardless of the stream position,
                    # unlike read(), which returns b"" once the buffer has been consumed.
                    content = f.getvalue()
                    multipart.append(("images", (f.name, content, f.type or "application/octet-stream")))
                    try:
                        preview = Image.open(io.BytesIO(content))
                    except Exception:
                        # Preview is best-effort; the file is still uploaded to the API.
                        unpreviewable.append(f.name)
                    else:
                        previews.append(preview)
                        preview_captions.append(f.name)
                try:
                    r = requests.post(f"{API_BASE}/generate", data=data, files=multipart, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    # Show the server's error body (often a FastAPI JSON "detail") when present.
                    response = getattr(e, "response", None)
                    if response is not None:
                        st.code(response.text, language="json")
                else:
                    st.success("Done!")
                    if previews:
                        st.image(previews, caption=preview_captions)
                    if unpreviewable:
                        st.caption("No preview available for: " + ", ".join(unpreviewable))
                    st.subheader("Answer")
                    st.write(out.get("text", ""))
                    show_metrics(out.get("metrics", {}))
# -------------------- Tab 2: Detect & Describe -> /detect_describe --------------------
# Grounded detection: send one image plus comma-separated labels to /detect_describe,
# then replace the upload preview with the returned overlay and list each detection.
with tab_detect:
    st.subheader("SmolVLM Grounded Detection")
    # Upload + labels
    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")
    # Placeholders created ABOVE the sliders/button so results can later be written
    # into them, replacing the upload preview in place.
    preview_placeholder = st.empty()
    desc_placeholder = st.empty()
    det_bytes = None
    if det_image:
        det_bytes = det_image.getvalue()
        # Small preview (fixed width)
        preview_placeholder.image(det_bytes, caption=det_image.name, width=480)
    # Controls
    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1)
    run_det = st.button("Detect", type="primary", use_container_width=True)
    if run_det:
        if not det_bytes or not det_labels.strip():
            st.error("Please provide an image and at least one label.")
        else:
            with st.spinner("Calling FastAPI…"):
                data = {
                    "labels": det_labels,
                    "box_threshold": str(det_box_thr),
                    "text_threshold": str(det_text_thr),
                    "pad_frac": str(det_pad),
                    "max_new_tokens": str(det_max_new),
                    "return_overlay": "true",
                }
                # Distinct name so Tab 1's module-level `files` is not clobbered.
                det_files = [("image", (det_image.name, det_bytes, det_image.type or "application/octet-stream"))]
                try:
                    r = requests.post(f"{API_BASE}/detect_describe", data=data, files=det_files, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    # Show the server's error body (often a FastAPI JSON "detail") when present.
                    response = getattr(e, "response", None)
                    if response is not None:
                        st.code(response.text, language="json")
                else:
                    # Replace preview with the overlay, still small.
                    b64 = out.get("overlay_png_b64")
                    if b64:
                        try:
                            overlay_bytes = base64.b64decode(b64)
                        except ValueError:  # binascii.Error subclasses ValueError
                            st.warning("Server returned an overlay that is not valid base64; keeping the original preview.")
                        else:
                            preview_placeholder.image(overlay_bytes, caption=f"Detections: {det_image.name}", width=480)
                    # Show descriptions right here (above controls). Each detection is
                    # presumably a dict with label/score/box_xyxy/description; missing or
                    # malformed fields are rendered as placeholders instead of crashing.
                    dets = out.get("detections") or []
                    if not dets:
                        desc_placeholder.info("No detections at current thresholds.")
                    else:
                        entries = []
                        for i, d in enumerate(dets, 1):
                            score = d.get("score")
                            score_txt = f"{score:.2f}" if isinstance(score, (int, float)) else "?"
                            entries.append(
                                f"**{i}. {d.get('label', '?')}** "
                                f"(score={score_txt}, box={d.get('box_xyxy', '?')})"
                                f"\n\n{d.get('description', '')}"
                            )
                        desc_placeholder.markdown("\n\n---\n\n".join(entries), unsafe_allow_html=False)