# ui.py
"""Streamlit front-end for the SmolVLM FastAPI service.

Two tabs:

* "SmolVLM Detection" posts one or more uploaded images plus a prompt to
  ``{API_BASE}/generate`` and shows the returned ``text`` and ``metrics``.
* "Grounded Detection" posts a single image plus comma-separated labels to
  ``{API_BASE}/detect_describe`` and shows the returned overlay image and the
  per-detection descriptions.

``API_BASE`` is read from the environment (default ``http://127.0.0.1:8000``).
"""
import os
import io
import json
import base64

import requests
import streamlit as st
from PIL import Image

st.set_page_config(page_title="SmolVLM UI", layout="wide")
st.title("SmolVLM Grounding")

API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000")
# Seconds to wait for either endpoint before giving up (model inference can be slow).
REQUEST_TIMEOUT_S = 300


def _section(metrics: dict, key: str) -> dict:
    """Return ``metrics[key]`` when it is a dict, otherwise an empty dict.

    Guards the chained lookups below against a section that is missing or
    sent as ``null`` / a non-object by the server.
    """
    value = metrics.get(key)
    return value if isinstance(value, dict) else {}


def _fmt_number(value, spec: str, missing: str = "—") -> str:
    """Format *value* with the format *spec* (e.g. ``".0f"``).

    Returns *missing* when *value* is ``None`` and ``str(value)`` when the
    value cannot be formatted with *spec* (e.g. a string), instead of raising.
    """
    if value is None:
        return missing
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def show_metrics(metrics: dict) -> None:
    """Render headline metrics as four ``st.metric`` tiles plus the raw JSON.

    Args:
        metrics: The ``metrics`` object from an API response. The tiles read
            ``timings_ms.total``, ``timings_ms.inference``,
            ``throughput.tokens_per_sec_inference`` and
            ``gpu_memory_mb.max_reserved``; any missing value is shown as "—".
            Empty or non-dict input renders nothing.
    """
    if not metrics or not isinstance(metrics, dict):
        return
    timings = _section(metrics, "timings_ms")
    throughput = _section(metrics, "throughput")
    gpu_memory = _section(metrics, "gpu_memory_mb")

    cols = st.columns(4)
    cols[0].metric("Total (ms)", _fmt_number(timings.get("total"), ".0f"))
    cols[1].metric("Inference (ms)", _fmt_number(timings.get("inference"), ".0f"))
    cols[2].metric("Tok/s (infer)", _fmt_number(throughput.get("tokens_per_sec_inference"), ".1f"))
    cols[3].metric("GPU reserved (MB)", _fmt_number(gpu_memory.get("max_reserved"), ".0f"))
    st.expander("All metrics").json(metrics)


def _show_request_error(err: Exception) -> None:
    """Report a failed API call, including the server's response body if one exists."""
    st.error(f"Request failed: {err}")
    response = getattr(err, "response", None)
    if response is not None:
        try:
            st.code(response.text, language="json")
        except Exception:
            # Fall back to plain text if the body cannot be rendered as a code block.
            st.write(response.text)


def _format_detection(index: int, det: dict) -> str:
    """Build the markdown entry for one item of ``/detect_describe``'s ``detections``.

    Reads ``label``, ``score``, ``box_xyxy`` and ``description``; a missing
    field degrades to a placeholder instead of raising ``KeyError``.
    """
    label = det.get("label", "?")
    score = _fmt_number(det.get("score"), ".2f")
    box = det.get("box_xyxy")
    description = det.get("description", "")
    return f"**{index}. {label}** (score={score}, box={box})\n\n{description}"


tab_upload, tab_detect = st.tabs(["SmolVLM Detection", "Grounded Detection"])

# -------------------- Tab 1: uploads -> /generate --------------------
with tab_upload:
    st.subheader("Upload an image")
    files = st.file_uploader("Images", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=True)
    prompt = st.text_area("Prompt", "Can you describe the image?", height=80)
    run = st.button("Generate", type="primary", use_container_width=True, key="run_files")

    max_new_tokens = st.slider("max_new_tokens", 1, 1024, 300, step=1)
    temperature_on = st.toggle("Set temperature?", value=False)
    temperature = st.slider("temperature", 0.0, 2.0, 0.2, step=0.05) if temperature_on else None
    topp_on = st.toggle("Set top_p?", value=False)
    top_p = st.slider("top_p", 0.05, 1.0, 0.95, step=0.05) if topp_on else None
    st.caption("API base: " + API_BASE)

    if run:
        if not files or not prompt.strip():
            st.error("Please add at least one image and a prompt.")
        else:
            with st.spinner("Calling FastAPI…"):
                data = {
                    "prompt": prompt,
                    "max_new_tokens": str(max_new_tokens),  # form fields are strings
                }
                # Optional sampling params are only sent when the user enabled them.
                if temperature is not None:
                    data["temperature"] = str(temperature)
                if top_p is not None:
                    data["top_p"] = str(top_p)

                multipart = []
                previews = []
                preview_names = []  # kept in lockstep with `previews` for st.image captions
                for f in files:
                    # getvalue() returns the whole upload regardless of the read position.
                    content = f.getvalue()
                    multipart.append(("images", (f.name, content, f.type or "application/octet-stream")))
                    try:
                        img = Image.open(io.BytesIO(content))
                    except Exception:
                        # Preview is best-effort; the server still receives the raw bytes.
                        continue
                    previews.append(img)
                    preview_names.append(f.name)

                try:
                    resp = requests.post(
                        f"{API_BASE}/generate", data=data, files=multipart, timeout=REQUEST_TIMEOUT_S
                    )
                    resp.raise_for_status()
                    out = resp.json()
                except (requests.RequestException, ValueError) as e:
                    # ValueError covers a non-JSON body on requests versions whose
                    # JSONDecodeError is not a RequestException.
                    _show_request_error(e)
                else:
                    st.success("Done!")
                    if previews:
                        # Full-size previews for uploads (pass width=... to shrink them).
                        st.image(previews, caption=preview_names)
                    st.subheader("Answer")
                    st.write(out.get("text", ""))
                    show_metrics(out.get("metrics", {}))

# -------------------- Tab 2: Detect & Describe -> /detect_describe --------------------
with tab_detect:
    st.subheader("SmolVLM Grounded Detection")

    # Upload + labels
    det_image = st.file_uploader("Image", type=["jpg", "jpeg", "png", "webp"], accept_multiple_files=False)
    det_labels = st.text_input("Labels (comma-separated)", "a man,a dog")

    # Preview + description placeholders are created here so they render ABOVE the
    # sliders/button and can be overwritten with results after the request.
    preview_placeholder = st.empty()
    desc_placeholder = st.empty()

    det_bytes = None
    if det_image:
        det_bytes = det_image.getvalue()
        # Small preview (fixed width)
        preview_placeholder.image(det_bytes, caption=det_image.name, width=480)

    # Controls
    det_box_thr = st.slider("box_threshold", 0.05, 0.95, 0.40, 0.01)
    det_text_thr = st.slider("text_threshold", 0.05, 0.95, 0.30, 0.01)
    det_pad = st.slider("crop padding (fraction)", 0.0, 0.2, 0.06, 0.01)
    det_max_new = st.slider("max_new_tokens", 1, 512, 160, 1)
    run_det = st.button("Detect", type="primary", use_container_width=True)

    if run_det:
        if not det_bytes or not det_labels.strip():
            st.error("Please provide an image and at least one label.")
        else:
            with st.spinner("Calling FastAPI…"):
                det_data = {
                    "labels": det_labels,
                    "box_threshold": str(det_box_thr),
                    "text_threshold": str(det_text_thr),
                    "pad_frac": str(det_pad),
                    "max_new_tokens": str(det_max_new),
                    "return_overlay": "true",
                }
                det_files = [
                    ("image", (det_image.name, det_bytes, det_image.type or "application/octet-stream"))
                ]
                try:
                    det_resp = requests.post(
                        f"{API_BASE}/detect_describe",
                        data=det_data,
                        files=det_files,
                        timeout=REQUEST_TIMEOUT_S,
                    )
                    det_resp.raise_for_status()
                    det_out = det_resp.json()
                except (requests.RequestException, ValueError) as e:
                    _show_request_error(e)
                else:
                    # Replace the preview with the overlay, still small.
                    b64 = det_out.get("overlay_png_b64")
                    if b64:
                        try:
                            overlay_bytes = base64.b64decode(b64)
                        except ValueError:  # binascii.Error subclasses ValueError
                            st.warning("Server returned an invalid overlay image; showing the original upload.")
                        else:
                            preview_placeholder.image(
                                overlay_bytes, caption=f"Detections: {det_image.name}", width=480
                            )

                    # Show descriptions in the placeholder above the controls.
                    dets = det_out.get("detections") or []
                    if not dets:
                        desc_placeholder.info("No detections at current thresholds.")
                    else:
                        entries = [_format_detection(i, d) for i, d in enumerate(dets, 1)]
                        desc_placeholder.markdown("\n\n---\n\n".join(entries), unsafe_allow_html=False)