Spaces:
Runtime error
Runtime error
import io
import os

# Redirect the Hugging Face cache to /tmp *before* importing transformers:
# huggingface_hub reads HF_HOME when it is imported, so setting it later has
# no effect. /tmp is presumably the writable location in the hosting
# sandbox — TODO confirm for other deployments.
os.environ["HF_HOME"] = "/tmp"

import torch  # noqa: E402
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import AutoImageProcessor, SwinForImageClassification  # noqa: E402

app = FastAPI()

# Hugging Face Hub repository holding the classifier checkpoint and its
# preprocessing config.
MODEL_ID = "OttoYu/TreeClassification"

# Load preprocessing and weights once at import time so every request reuses
# the same objects instead of reloading them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = SwinForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables training-only behaviour such as dropout

# Class-index -> label mapping as stored in the model's config.
id2label = model.config.id2label
@app.get("/")
def root():
    """Report that the service process is up and serving requests.

    Returns:
        A JSON-serialisable dict with a single ``message`` key.
    """
    # NOTE(review): the trailing "π" looks like a mis-encoded emoji from a
    # copy/paste; confirm the intended character before changing the response.
    return {"message": "TreeClassification API is running π"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top-1 class.

    Args:
        file: Multipart upload whose content should be an image Pillow can decode.

    Returns:
        ``{"prediction": <label from model.config.id2label>, "label_id": <int class index>}``.

    Raises:
        HTTPException: status 400 if the upload cannot be decoded as an image
            (empty, corrupt, unsupported format, or exceeding Pillow's
            decompression-bomb limit).
    """
    from fastapi import HTTPException

    image_bytes = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode, so truncated or
        # corrupt data is also caught here rather than later in the processor.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unrecognised/empty data) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # NOTE(review): preprocessing and the forward pass run synchronously inside
    # this async handler and block the event loop; consider offloading them
    # (e.g. fastapi.concurrency.run_in_threadpool) if concurrent requests matter.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits

    # Single image in the batch, so argmax over the class axis yields one index.
    pred_id = logits.argmax(-1).item()
    return {
        "prediction": id2label[pred_id],
        "label_id": pred_id,
    }