File size: 1,006 Bytes
e85c9d4
3519028
e85c9d4
dd5fd0f
7345a5a
dd5fd0f
7345a5a
dd5fd0f
 
 
 
7345a5a
dd5fd0f
7345a5a
 
 
dd5fd0f
7345a5a
dd5fd0f
7345a5a
 
 
dd5fd0f
7345a5a
 
 
 
 
dd5fd0f
7345a5a
 
 
 
 
dd5fd0f
7345a5a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import io
import os

# HF_HOME must be set before transformers / huggingface_hub are imported,
# since they read the cache location from the environment at import time.
os.environ["HF_HOME"] = "/tmp"

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification

app = FastAPI()

# Hugging Face Hub repository providing both the image-processor config and
# the Swin image-classification weights.
MODEL_ID = "OttoYu/TreeClassification"

# Both objects are created once, at import time, and reused by every request.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() returns the module itself, so it can be chained here; it
# switches layers such as dropout to inference behaviour.
model = SwinForImageClassification.from_pretrained(MODEL_ID).eval()

# Class-index -> label-name mapping read from the loaded model's config.
id2label = model.config.id2label

@app.get("/")
def root():
    """Liveness endpoint: report that the API process is up and serving."""
    status_text = "TreeClassification API is running 🚀"
    return dict(message=status_text)

def _classify_image_bytes(image_bytes):
    """Decode raw upload bytes as an image and classify it with the model.

    This function is synchronous and does blocking CPU/GPU work (image
    decoding and a forward pass), so async callers should run it in a
    worker thread rather than on the event loop.

    Args:
        image_bytes: Raw bytes of the uploaded file.

    Returns:
        dict with ``prediction`` (label looked up in ``id2label``) and
        ``label_id`` (index of the highest logit).

    Raises:
        HTTPException: status 400 if the bytes cannot be decoded as an image.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            # convert() loads the pixel data (so corrupt/truncated files fail
            # here, inside the try) and returns a new image that outlives the
            # closed source.
            image = raw_image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown/empty data) and truncated-file errors
        # are OSError subclasses; DecompressionBombError guards huge images.
        # All are client errors, not server faults.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    return {
        "prediction": id2label[pred_id],
        "label_id": pred_id,
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label.

    Args:
        file: Multipart-uploaded image file.

    Returns:
        JSON object ``{"prediction": <label>, "label_id": <int>}``.

    Raises:
        HTTPException: 400 when the upload is not a decodable image.
    """
    image_bytes = await file.read()
    # Decoding + inference are blocking; running them in the threadpool keeps
    # the event loop free to serve other requests meanwhile.
    return await run_in_threadpool(_classify_image_bytes, image_bytes)