Spaces:
Sleeping
Sleeping
# ui.py | |
import os | |
import io | |
import json | |
import base64 | |
import requests | |
import streamlit as st | |
from PIL import Image | |
# Page chrome: browser-tab title and wide layout, then the in-page heading.
# This is the first Streamlit call executed by the script.
st.set_page_config(
    layout="wide",
    page_title="SmolVLM UI",
)
st.title("SmolVLM Grounding")
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000") | |
def _metric_value(metrics: dict, section: str, key: str):
    """Return ``metrics[section][key]`` when it is a number, otherwise ``None``.

    Tolerates a missing section, a section that is JSON ``null`` (or any
    non-mapping), and non-numeric leaf values, so a partial metrics payload
    cannot crash the page.
    """
    block = metrics.get(section)
    if not isinstance(block, dict):
        return None
    value = block.get(key)
    return value if isinstance(value, (int, float)) else None


def show_metrics(metrics: dict):
    """Render headline backend metrics as Streamlit metric tiles.

    Draws four ``st.metric`` tiles (total time, inference time, inference
    tokens/sec, max reserved GPU memory) in a row of columns, followed by an
    "All metrics" expander containing the full payload as JSON. A headline
    value that is missing or non-numeric is displayed as "—".

    Args:
        metrics: The ``metrics`` object from a backend response. Falsy values
            (``None``, ``{}``) render nothing.
    """
    if not metrics:
        return

    # (tile label, payload section, key inside that section, format spec)
    tiles = (
        ("Total (ms)", "timings_ms", "total", ".0f"),
        ("Inference (ms)", "timings_ms", "inference", ".0f"),
        ("Tok/s (infer)", "throughput", "tokens_per_sec_inference", ".1f"),
        ("GPU reserved (MB)", "gpu_memory_mb", "max_reserved", ".0f"),
    )
    cols = st.columns(len(tiles))
    for col, (label, section, key, spec) in zip(cols, tiles):
        value = _metric_value(metrics, section, key)
        col.metric(label, "—" if value is None else format(value, spec))
    st.expander("All metrics").json(metrics)
tab_upload, tab_detect = st.tabs(["SmolVLM Detection", "Grounded Detection"])

# -------------------- Tab 1: uploads -> /generate --------------------
with tab_upload:
    st.subheader("Upload an image")
    files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
    prompt = st.text_area("Prompt", "Can you describe the image?", height=80)
    run = st.button("Generate", type="primary", use_container_width=True, key="run_files")

    # Generation settings. Streamlit re-executes the whole script on each
    # interaction, so on the rerun triggered by the button these widgets still
    # return their current values even though they are drawn below it.
    # Explicit key because the detect tab also has a slider labelled "max_new_tokens".
    max_new_tokens = st.slider("max_new_tokens", 1, 1024, 300, step=1, key="gen_max_new_tokens")
    temperature_on = st.toggle("Set temperature?", value=False)
    temperature = st.slider("temperature", 0.0, 2.0, 0.2, step=0.05) if temperature_on else None
    topp_on = st.toggle("Set top_p?", value=False)
    top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
    st.caption("API base: " + API_BASE)

    if run:
        if not files or not prompt.strip():
            st.error("Please add at least one image and a prompt.")
        else:
            with st.spinner("Calling FastAPI…"):
                # Multipart form fields are sent as strings; optional sampling
                # parameters are only included when the user enabled them.
                data = {
                    "prompt": prompt,
                    "max_new_tokens": str(max_new_tokens),
                }
                if temperature is not None:
                    data["temperature"] = str(temperature)
                if top_p is not None:
                    data["top_p"] = str(top_p)

                multipart = []
                previews = []
                preview_captions = []  # kept in lockstep with `previews` for st.image
                for f in files:
                    # getvalue() returns the whole upload regardless of the
                    # buffer's current read position (unlike read()).
                    content = f.getvalue()
                    multipart.append(("images", (f.name, content, f.type or "application/octet-stream")))
                    try:
                        img = Image.open(io.BytesIO(content))
                        # Decode now so a corrupt/truncated file is skipped here
                        # instead of failing later inside st.image.
                        img.load()
                    except Exception:
                        # Preview is best-effort: the bytes are still sent to the server.
                        continue
                    previews.append(img)
                    preview_captions.append(f.name)

                try:
                    r = requests.post(f"{API_BASE}/generate", data=data, files=multipart, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    resp = getattr(e, "response", None)
                    if resp is not None:
                        try:
                            st.code(resp.text, language="json")
                        except Exception:
                            st.write(resp.text)
                except ValueError as e:
                    # requests < 2.27 raises a plain ValueError for a non-JSON
                    # body; newer versions raise a RequestException subclass (above).
                    st.error(f"Server returned a non-JSON response: {e}")
                else:
                    st.success("Done!")
                    if previews:
                        # Captions come only from images that actually opened, so
                        # the two lists always have the same length.
                        st.image(previews, caption=preview_captions)
                    st.subheader("Answer")
                    st.write(out.get("text", ""))
                    show_metrics(out.get("metrics", {}))
# -------------------- Tab 2: Detect & Describe -> /detect_describe --------------------
def _format_detection(index, det):
    """Return one markdown entry for a detection returned by /detect_describe.

    Missing or malformed fields are rendered as placeholders instead of
    raising, so one incomplete detection cannot break the whole result list.
    """
    label = det.get("label", "?")
    score = det.get("score")
    score_txt = f"{score:.2f}" if isinstance(score, (int, float)) else "n/a"
    box = det.get("box_xyxy")
    description = det.get("description") or ""
    return f"**{index}. {label}** (score={score_txt}, box={box})\n\n{description}"


with tab_detect:
    st.subheader("SmolVLM Grounded Detection")

    # Upload + labels
    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")

    # Placeholders are created here so the preview/overlay and descriptions
    # render ABOVE the sliders and button, even though results arrive later.
    preview_placeholder = st.empty()
    desc_placeholder = st.empty()

    det_bytes = None
    if det_image:
        det_bytes = det_image.getvalue()
        # Small preview (fixed width)
        preview_placeholder.image(det_bytes, caption=det_image.name, width=480)

    # Controls. Explicit key on max_new_tokens because the upload tab also has
    # a slider with that label.
    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1, key="det_max_new_tokens")
    run_det = st.button("Detect", type="primary", use_container_width=True)

    if run_det:
        if not det_bytes or not det_labels.strip():
            st.error("Please provide an image and at least one label.")
        else:
            with st.spinner("Calling FastAPI…"):
                # Multipart form fields are sent as strings.
                data = {
                    "labels": det_labels,
                    "box_threshold": str(det_box_thr),
                    "text_threshold": str(det_text_thr),
                    "pad_frac": str(det_pad),
                    "max_new_tokens": str(det_max_new),
                    "return_overlay": "true",
                }
                # Distinct name so the upload tab's module-level `files` is not clobbered.
                det_files = [("image", (det_image.name, det_bytes, det_image.type or "application/octet-stream"))]
                try:
                    r = requests.post(f"{API_BASE}/detect_describe", data=data, files=det_files, timeout=300)
                    r.raise_for_status()
                    out = r.json()
                except requests.RequestException as e:
                    st.error(f"Request failed: {e}")
                    resp = getattr(e, "response", None)
                    if resp is not None:
                        try:
                            st.code(resp.text, language="json")
                        except Exception:
                            st.write(resp.text)
                except ValueError as e:
                    # requests < 2.27 raises a plain ValueError for a non-JSON
                    # body; newer versions raise a RequestException subclass (above).
                    st.error(f"Server returned a non-JSON response: {e}")
                else:
                    # Replace the preview with the overlay, still small.
                    b64 = out.get("overlay_png_b64")
                    if b64:
                        try:
                            overlay_bytes = base64.b64decode(b64)
                        except ValueError:  # binascii.Error is a ValueError subclass
                            st.warning("Could not decode the overlay image; showing the original upload.")
                        else:
                            preview_placeholder.image(overlay_bytes, caption=f"Detections: {det_image.name}", width=480)

                    # Show descriptions in the placeholder above the controls.
                    dets = out.get("detections") or []
                    if not dets:
                        desc_placeholder.info("No detections at current thresholds.")
                    else:
                        lines = [_format_detection(i, d) for i, d in enumerate(dets, 1)]
                        desc_placeholder.markdown("\n\n---\n\n".join(lines), unsafe_allow_html=False)