Spaces:
Runtime error
Runtime error
File size: 1,006 Bytes
e85c9d4 3519028 e85c9d4 dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a dd5fd0f 7345a5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import io
import os

# HF_HOME is set before `transformers` (and its Hugging Face Hub dependency) is
# imported below, so the model/processor cache lives under /tmp.
os.environ["HF_HOME"] = "/tmp"

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification
# ASGI application object; the route handlers below register on it.
app = FastAPI()

# Hugging Face Hub repository providing both the preprocessing config and the weights.
MODEL_ID = "OttoYu/TreeClassification"

# Load the preprocessor and classifier once, at import time, so every request
# reuses the same objects. `.eval()` returns the module itself, so chaining it
# binds `model` to the same instance while switching it to inference mode.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = SwinForImageClassification.from_pretrained(MODEL_ID).eval()

# Class-index -> human-readable label mapping, read from the model's config.
id2label = model.config.id2label
@app.get("/")
def root():
    """Return a fixed JSON message confirming the API is serving requests."""
    status_message = "TreeClassification API is running 🚀"
    return dict(message=status_message)
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: status 400 when the bytes are empty, not a recognised
            image format, truncated, or exceed Pillow's decompression-bomb limit.
    """
    try:
        # `with` closes the underlying decoder; `convert` returns a new,
        # fully loaded image, so the result stays usable after the block exits.
        with Image.open(io.BytesIO(image_bytes)) as img:
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError is a subclass of OSError, so it is covered here.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _predict_label_id(image: Image.Image) -> int:
    """Run the classifier on one image and return the arg-max class index."""
    inputs = processor(images=image, return_tensors="pt")
    # Autograd mode is thread-local, so it must be disabled here, inside the
    # worker thread that actually executes the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.argmax(-1).item()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: multipart upload containing the image bytes.

    Returns:
        ``{"prediction": <label>, "label_id": <class index>}``.

    Raises:
        HTTPException: status 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    # Decoding and inference are CPU-bound; run them off the event loop so that
    # other requests (e.g. "/") are not blocked while the model runs.
    image = await run_in_threadpool(_decode_image, image_bytes)
    pred_id = await run_in_threadpool(_predict_label_id, image)
    return {
        "prediction": id2label[pred_id],
        "label_id": pred_id,
    }
|