# deeptree / app.py
# Author: jkcoolkidz
# Commit 3519028: "change cache"
import io
import os

# Point the Hugging Face cache at /tmp. Kept ahead of the `transformers` import,
# presumably so the library sees it when it initializes its cache paths.
os.environ["HF_HOME"] = "/tmp"

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification
# ASGI application that the routes in this file register against.
app = FastAPI()

# Hugging Face Hub checkpoint supplying both the image processor and the Swin classifier.
MODEL_ID = "OttoYu/TreeClassification"

# Loaded once at import time; every request reuses these objects.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() switches to evaluation mode and returns the module itself.
model = SwinForImageClassification.from_pretrained(MODEL_ID).eval()

# Class-index -> label-name lookup taken from the checkpoint's config.
id2label = model.config.id2label
@app.get("/")
def root():
    """Return a static status message so callers can check the service is up."""
    # The emoji was previously stored mis-decoded ("πŸš€"); restore the intended 🚀.
    return {"message": "TreeClassification API is running 🚀"}
def _predict_sync(image_bytes):
    """Decode ``image_bytes`` and classify it with the loaded model (blocking).

    Returns the response payload ``{"prediction": <label>, "label_id": <int>}``,
    where the label comes from the model config's ``id2label`` mapping.

    Raises:
        HTTPException: status 400 when the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown/empty format) and truncated-file errors
        # are OSError subclasses; unhandled, they surface to the client as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index of the highest logit along the last dimension.
    pred_id = logits.argmax(-1).item()
    return {"prediction": id2label[pred_id], "label_id": pred_id}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top-scoring label.

    Args:
        file: multipart upload containing the image bytes.

    Returns:
        ``{"prediction": <label string>, "label_id": <int class index>}``.

    Raises:
        HTTPException: status 400 when the upload is not a decodable image.
    """
    image_bytes = await file.read()
    # Decoding and the forward pass are synchronous; running them inline in this
    # coroutine would block the event loop (and every other request) meanwhile.
    return await run_in_threadpool(_predict_sync, image_bytes)