# Page-header residue from the hosting UI: "Spaces: Sleeping"
# app.py
from io import BytesIO
from time import perf_counter
from typing import List, Optional, Union

import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from PIL import Image
from pydantic import BaseModel, Field, HttpUrl

from util import get_runner, SmolVLMRunner
# The FastAPI application object for this module.
app = FastAPI(
    title="SmolVLM Inference API",
    version="1.2.0",
)

# Shared model runner; remains None until `_load_model_on_startup` assigns it.
_runner: Optional[SmolVLMRunner] = None
# ----------------------- Pydantic models -----------------------
class URLRequest(BaseModel):
    """JSON body for generation from image URLs: prompt, URLs, sampling options."""

    # Required fields.
    prompt: str = Field(..., description="Text prompt to accompany the images.")
    image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")

    # Optional generation settings; the bounds below are enforced by validation.
    max_new_tokens: int = Field(default=300, ge=1, le=1024)  # 1..1024 inclusive
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)  # 0.0..2.0 inclusive
    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)  # (0.0, 1.0]
class DetectDescribeURLRequest(BaseModel):
    """JSON body for detect-and-describe on a single image URL."""

    # Required fields.
    image_url: HttpUrl
    labels: Union[str, List[str]]  # one string or a list of strings

    # Optional settings, forwarded by the URL endpoint to `detect_and_describe`.
    # NOTE(review): unlike URLRequest, none of these declare numeric bounds -
    # confirm whether temperature/top_p/max_new_tokens should be range-checked.
    box_threshold: float = Field(default=0.40)
    text_threshold: float = Field(default=0.30)
    pad_frac: float = Field(default=0.06)
    max_new_tokens: int = Field(default=160)
    return_overlay: bool = Field(default=True)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
# ----------------------- Startup / health -----------------------
async def _load_model_on_startup():
    """Store the runner returned by ``get_runner()`` in the module-level ``_runner``."""
    # NOTE(review): no startup registration (decorator or add_event_handler) is
    # visible in this view - confirm this coroutine is hooked into app startup.
    global _runner
    loaded_runner = get_runner()
    _runner = loaded_runner
def health():
    """Return a status payload with the runner's ``model_id`` (``None`` if no runner)."""
    model_id = None
    if _runner:
        model_id = _runner.model_id
    return {"status": "ok", "model": model_id}
# ----------------------- Core VLM endpoints -----------------------
async def generate_from_files(
    prompt: str = Form(...),
    images: List[UploadFile] = File(..., description="One or more image files."),
    max_new_tokens: int = Form(300),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
):
    """Generate text from a prompt and one or more uploaded image files.

    Args:
        prompt: Text prompt forwarded to the runner.
        images: Uploaded files; each must declare an ``image/*`` content type.
        max_new_tokens: Forwarded to ``generate`` unchanged.
        temperature: Forwarded to ``generate``; ``None`` when omitted.
        top_p: Forwarded to ``generate``; ``None`` when omitted.

    Returns:
        ``{"text": ..., "metrics": ...}`` where ``metrics`` is the runner's
        stats merged with a ``request_ms`` entry holding ``image_load`` and
        ``end_to_end`` timings in milliseconds.

    Raises:
        HTTPException: 503 if the model runner is not loaded; 400 if no image
            is supplied or the uploads cannot be read/decoded; 415 if a file
            does not declare an image content type.
    """
    # Without this guard a missing runner surfaces as an AttributeError (HTTP 500).
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not images:
        raise HTTPException(status_code=400, detail="At least one image must be provided.")

    t_req_start = perf_counter()

    # Validate declared content types first, so the 415 is never swallowed by
    # the broad except around reading/decoding below.
    t_load_start = perf_counter()
    for upload in images:
        if not upload.content_type or not upload.content_type.startswith("image/"):
            raise HTTPException(
                status_code=415,
                detail=f"Unsupported file type: {upload.content_type}",
            )

    # Read and decode; a corrupt upload becomes a 400, matching detect_describe.
    try:
        blobs = [await upload.read() for upload in images]
        pil_images = _runner.load_pil_from_bytes(blobs)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to read image: {exc}") from exc
    t_load_end = perf_counter()

    # NOTE(review): generate() is invoked synchronously (no await), so this
    # coroutine holds the event loop during inference - confirm that is acceptable.
    text, inner_metrics = _runner.generate(
        prompt=prompt,
        images=pil_images,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        return_stats=True,
    )
    t_req_end = perf_counter()

    metrics = {
        **inner_metrics,
        "request_ms": {
            "image_load": round((t_load_end - t_load_start) * 1000.0, 2),
            "end_to_end": round((t_req_end - t_req_start) * 1000.0, 2),
        },
    }
    return {"text": text, "metrics": metrics}
async def generate_from_urls(req: URLRequest):
    """Generate text from a prompt and one or more image URLs.

    Args:
        req: Validated request body with the prompt, image URLs and
            generation settings.

    Returns:
        ``{"text": ..., "metrics": ...}`` where ``metrics`` is the runner's
        stats merged with a ``request_ms`` entry holding ``image_load`` and
        ``end_to_end`` timings in milliseconds.

    Raises:
        HTTPException: 503 if the model runner is not loaded; 400 if no URL
            is given or the images cannot be fetched.
    """
    # Without this guard a missing runner surfaces as an AttributeError (HTTP 500).
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    t_req_start = perf_counter()
    if not req.image_urls:
        raise HTTPException(status_code=400, detail="At least one image URL is required.")

    # NOTE(review): the URLs come straight from the request body - confirm that
    # load_pil_from_urls restricts which hosts may be fetched (SSRF).
    t_load_start = perf_counter()
    try:
        pil_images = _runner.load_pil_from_urls([str(url) for url in req.image_urls])
    except Exception as exc:
        # Same 400 mapping that detect_describe_url applies to fetch failures.
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {exc}") from exc
    t_load_end = perf_counter()

    # NOTE(review): generate() is invoked synchronously (no await), so this
    # coroutine holds the event loop during inference - confirm that is acceptable.
    text, inner_metrics = _runner.generate(
        prompt=req.prompt,
        images=pil_images,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        return_stats=True,
    )
    t_req_end = perf_counter()

    metrics = {
        **inner_metrics,
        "request_ms": {
            "image_load": round((t_load_end - t_load_start) * 1000.0, 2),
            "end_to_end": round((t_req_end - t_req_start) * 1000.0, 2),
        },
    }
    return {"text": text, "metrics": metrics}
# ----------------------- Detect & Describe endpoints -----------------------
async def detect_describe(
    image: UploadFile = File(..., description="One image file (image/*)"),
    labels: str = Form(..., description='Comma-separated phrases, e.g. "a man,a dog"'),
    box_threshold: float = Form(0.40),
    text_threshold: float = Form(0.30),
    pad_frac: float = Form(0.06),
    max_new_tokens: int = Form(160),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    return_overlay: bool = Form(True),
):
    """Run detect-and-describe on one uploaded image.

    Args:
        image: Uploaded file; must declare an ``image/*`` content type.
        labels: Comma-separated phrases, passed to the runner as one string.
        box_threshold: Forwarded to ``detect_and_describe`` unchanged.
        text_threshold: Forwarded to ``detect_and_describe`` unchanged.
        pad_frac: Forwarded to ``detect_and_describe`` unchanged.
        max_new_tokens: Forwarded to ``detect_and_describe`` unchanged.
        temperature: Forwarded to ``detect_and_describe``; ``None`` when omitted.
        top_p: Forwarded to ``detect_and_describe``; ``None`` when omitted.
        return_overlay: Forwarded to ``detect_and_describe`` unchanged.

    Returns:
        Whatever ``detect_and_describe`` returns, unmodified.

    Raises:
        HTTPException: 503 if the model runner is not loaded; 415 if the file
            does not declare an image content type; 400 if the image cannot
            be read or decoded.
    """
    # Without this guard a missing runner surfaces as an AttributeError (HTTP 500).
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not image.content_type or not image.content_type.startswith("image/"):
        raise HTTPException(status_code=415, detail=f"Unsupported file type: {image.content_type}")

    try:
        raw = await image.read()
        # Decoded image is converted to RGB before being handed to the runner.
        pil = Image.open(BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to read image: {e}") from e

    return _runner.detect_and_describe(
        image=pil,
        labels=labels,  # comma-separated string OK
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        pad_frac=pad_frac,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        return_overlay=return_overlay,
    )
async def detect_describe_url(req: DetectDescribeURLRequest):
    """Run detect-and-describe on one image fetched from a URL.

    Args:
        req: Validated request body with the image URL, labels and settings.

    Returns:
        Whatever ``detect_and_describe`` returns, unmodified.

    Raises:
        HTTPException: 503 if the model runner is not loaded; 400 if the
            image cannot be fetched.
    """
    # Must run before the try below: otherwise the AttributeError from a None
    # runner is caught and misreported as a 400 "Failed to fetch image".
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    try:
        # The loader takes a list of URLs; only the first result is used.
        pil = _runner.load_pil_from_urls([str(req.image_url)])[0]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}") from e

    return _runner.detect_and_describe(
        image=pil,
        labels=req.labels,
        box_threshold=req.box_threshold,
        text_threshold=req.text_threshold,
        pad_frac=req.pad_frac,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        return_overlay=req.return_overlay,
    )
# ----------------------- Entrypoint -----------------------
if __name__ == "__main__":
    # Run with: python app.py (or: uvicorn app:app --host 0.0.0.0 --port 8000)
    # host "0.0.0.0" listens on all interfaces; reload stays disabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": False,
    }
    uvicorn.run("app:app", **server_options)