"""FastAPI service that classifies tree images with a Swin model.

Endpoints:
    GET  /         -- liveness check.
    POST /predict  -- classify an uploaded image file.
"""

import os

# HF_HOME must be set before transformers (and thus huggingface_hub) is
# imported, so the model/processor cache is written under /tmp. Presumably
# /tmp is the writable location in the deployment environment -- TODO confirm.
os.environ["HF_HOME"] = "/tmp"

import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification

app = FastAPI()

MODEL_ID = "OttoYu/TreeClassification"

# Loaded once at import time so every request reuses the same processor/weights.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = SwinForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # evaluation mode (e.g. disables dropout) for deterministic inference
id2label = model.config.id2label


@app.get("/")
def root():
    """Report that the API is up."""
    return {"message": "TreeClassification API is running 🚀"}


@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top-1 class.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    worker threadpool: model inference is synchronous and compute-bound and
    would otherwise block the event loop, stalling all other requests.

    Args:
        file: Multipart upload containing an image in any format Pillow reads.

    Returns:
        ``{"prediction": <label from model.config.id2label>, "label_id": <int>}``
        for the highest-scoring class.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = file.file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError and truncated-data errors are OSError
        # subclasses; bad input is the client's fault, so report 400, not 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    return {"prediction": id2label[pred_id], "label_id": pred_id}