Spaces:
Sleeping
Sleeping
File size: 7,021 Bytes
eee3392 be6e716 eee3392 b6e71f1 eee3392 b6e71f1 b7bc425 b6e71f1 eee3392 b6e71f1 5b9a7b6 eee3392 be6e716 eee3392 b6e71f1 b7bc425 b6e71f1 be6e716 b7bc425 be6e716 b7bc425 b6e71f1 b7bc425 be6e716 b7bc425 be6e716 b7bc425 be6e716 b7bc425 be6e716 b7bc425 be6e716 b7bc425 be6e716 b7bc425 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# ui.py
import os
import io
import json
import base64
import requests
import streamlit as st
from PIL import Image
# Page-level configuration comes first: Streamlit has historically required
# set_page_config to run before any other st.* command in the script.
st.set_page_config(
    layout="wide",
    page_title="SmolVLM UI",
)
st.title("SmolVLM Grounding")
# Base URL of the FastAPI backend, overridable via the API_BASE environment
# variable. Trailing slashes are stripped so f"{API_BASE}/generate" never
# produces a double slash when the variable is set to e.g. "http://host:8000/".
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000").rstrip("/")
def format_metric(value, spec: str) -> str:
    """Format *value* with the format spec *spec*, or return "—" if unavailable.

    Both a missing value (None) and a value the spec cannot format (e.g. a
    string where a number was expected) render as an em dash instead of raising.
    """
    if value is None:
        return "—"
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return "—"


def metric_section(metrics: dict, key: str) -> dict:
    """Return ``metrics[key]`` when it is a dict, otherwise an empty dict.

    ``metrics.get(key, {})`` alone is not enough: a key that is present with a
    JSON ``null`` value yields None, and chaining ``.get`` on it would raise.
    """
    section = metrics.get(key)
    return section if isinstance(section, dict) else {}


def show_metrics(metrics: dict):
    """Render headline metrics from an API response, followed by the raw payload.

    Shows four ``st.metric`` tiles (total time, inference time, inference
    tokens/sec, and ``max_reserved`` GPU memory), read from the nested
    ``timings_ms``, ``throughput`` and ``gpu_memory_mb`` sections. After them
    comes an "All metrics" expander containing the whole dict as JSON. Missing
    or non-numeric values are shown as "—". Does nothing when *metrics* is
    empty or None.
    """
    if not metrics:
        return
    timings = metric_section(metrics, "timings_ms")
    throughput = metric_section(metrics, "throughput")
    gpu_mem = metric_section(metrics, "gpu_memory_mb")

    cols = st.columns(4)
    cols[0].metric("Total (ms)", format_metric(timings.get("total"), ".0f"))
    cols[1].metric("Inference (ms)", format_metric(timings.get("inference"), ".0f"))
    cols[2].metric("Tok/s (infer)", format_metric(throughput.get("tokens_per_sec_inference"), ".1f"))
    cols[3].metric("GPU reserved (MB)", format_metric(gpu_mem.get("max_reserved"), ".0f"))
    st.expander("All metrics").json(metrics)
# Top-level navigation: the first tab posts uploads to /generate, the second
# posts a single image plus labels to /detect_describe.
tab_upload, tab_detect = st.tabs(
    [
        "SmolVLM Detection",
        "Grounded Detection",
    ]
)
# -------------------- Tab 1: uploads -> /generate --------------------
# Free-form prompting over one or more uploaded images: every uploaded file is
# posted to /generate together with the prompt and generation settings, and
# the returned text and metrics are rendered below.
with tab_upload:
    st.subheader("Upload an image")
    files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
    prompt = st.text_area("Prompt", "Can you describe the image?", height=80)
    run = st.button("Generate", type="primary", use_container_width=True, key="run_files")
    max_new_tokens = st.slider("max_new_tokens", 1, 1024, 300, step=1)
    # Optional sampling parameters: each stays None, and is therefore left out
    # of the request, unless its toggle is switched on.
    temperature_on = st.toggle("Set temperature?", value=False)
    temperature = st.slider("temperature", 0.0, 2.0, 0.2, step=0.05) if temperature_on else None
    topp_on = st.toggle("Set top_p?", value=False)
    top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
    st.caption("API base: " + API_BASE)
    if run:
        if not files or not prompt.strip():
            st.error("Please add at least one image and a prompt.")
        else:
            with st.spinner("Calling FastAPI…"):
                data = {
                    "prompt": prompt,
                    "max_new_tokens": str(max_new_tokens),  # form fields are strings
                }
                if temperature is not None:
                    data["temperature"] = str(temperature)
                if top_p is not None:
                    data["top_p"] = str(top_p)
                multipart = []
                # Preview images and their captions are appended together so the
                # two lists always have the same length. st.image rejects a
                # caption list whose length differs from the image list, which
                # would happen if captions were built from every upload while
                # undecodable files were skipped.
                previews = []
                preview_captions = []
                for f in files:
                    # getvalue() returns the whole upload regardless of the
                    # current read position, unlike read().
                    content = f.getvalue()
                    multipart.append(("images", (f.name, content, f.type or "application/octet-stream")))
                    try:
                        preview = Image.open(io.BytesIO(content))
                    except Exception:
                        # Preview is best-effort only; the bytes are still sent to the API.
                        continue
                    previews.append(preview)
                    preview_captions.append(f.name)
                try:
                    r = requests.post(f"{API_BASE}/generate", data=data, files=multipart, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                    st.success("Done!")
                    if previews:
                        st.image(previews, caption=preview_captions)
                    st.subheader("Answer")
                    st.write(out.get("text", ""))
                    show_metrics(out.get("metrics", {}))
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    # No response object exists for e.g. connection errors or
                    # timeouts; only show a body when the server actually replied.
                    resp = getattr(e, "response", None)
                    if resp is not None:
                        try:
                            st.code(resp.text, language="json")
                        except Exception:
                            st.write(resp.text)
# -------------------- Tab 2: Detect & Describe -> /detect_describe --------------------
# Posts one image plus comma-separated labels to /detect_describe, then swaps
# the preview for the returned overlay and lists one description per detection.
with tab_detect:
    st.subheader("SmolVLM Grounded Detection")
    # Upload + labels
    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")
    # ---- Preview + description placeholders shown ABOVE sliders/button ----
    # st.empty() reserves a slot at this position; it is filled later (after
    # the request) so results appear above the controls.
    preview_placeholder = st.empty()
    desc_placeholder = st.empty()
    det_bytes = None
    if det_image:
        det_bytes = det_image.getvalue()
        # Small preview (fixed width)
        preview_placeholder.image(det_bytes, caption=det_image.name, width=480)
    # Controls
    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1)
    run_det = st.button("Detect", type="primary", use_container_width=True)
    if run_det:
        if not det_bytes or not det_labels.strip():
            st.error("Please provide an image and at least one label.")
        else:
            with st.spinner("Calling FastAPI…"):
                data = {
                    "labels": det_labels,
                    "box_threshold": str(det_box_thr),
                    "text_threshold": str(det_text_thr),
                    "pad_frac": str(det_pad),
                    "max_new_tokens": str(det_max_new),
                    "return_overlay": "true",
                }
                # Distinct name so Tab 1's module-level `files` is not overwritten.
                det_files = [("image", (det_image.name, det_bytes, det_image.type or "application/octet-stream"))]
                try:
                    r = requests.post(f"{API_BASE}/detect_describe", data=data, files=det_files, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                    # Replace preview with the overlay, still small
                    b64 = out.get("overlay_png_b64")
                    if b64:
                        try:
                            overlay_bytes = base64.b64decode(b64)
                        except ValueError:
                            # binascii.Error (bad padding) subclasses ValueError;
                            # keep the original preview rather than crash.
                            st.warning("Could not decode the overlay image; showing the original upload.")
                        else:
                            preview_placeholder.image(overlay_bytes, caption=f"Detections: {det_image.name}", width=480)
                    # Show descriptions right here (above controls)
                    dets = out.get("detections", [])
                    if not dets:
                        desc_placeholder.info("No detections at current thresholds.")
                    else:
                        lines = []
                        for i, d in enumerate(dets, 1):
                            # Tolerate missing/null fields from the server instead
                            # of raising KeyError/TypeError, which the
                            # RequestException handler below would not catch.
                            score = d.get("score")
                            score_txt = f"{score:.2f}" if isinstance(score, (int, float)) else "?"
                            lines.append(
                                f"**{i}. {d.get('label', '?')}** (score={score_txt}, box={d.get('box_xyxy')})"
                                f"\n\n{d.get('description', '')}"
                            )
                        desc_placeholder.markdown("\n\n---\n\n".join(lines), unsafe_allow_html=False)
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    # No response object exists for e.g. connection errors or
                    # timeouts; only show a body when the server actually replied.
                    resp = getattr(e, "response", None)
                    if resp is not None:
                        try:
                            st.code(resp.text, language="json")
                        except Exception:
                            st.write(resp.text)
|