from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import numpy as np
import cv2

# Load SegFormer-B0 fine-tuned on ADE20K for semantic segmentation.
processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# ADE20K class indices treated as road surface; all other classes are flagged as potential defects.
ROAD_LABELS = [0, 8]

def predict_defect(image: Image.Image):
    # Ensure a 3-channel RGB image so normalization and the red overlay behave consistently.
    image = image.convert("RGB")
    original = np.array(image)

    # Run SegFormer on the input image.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_classes, h, w) at reduced resolution

    # Per-pixel class prediction, upscaled back to the original image size.
    segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
    resized_mask = cv2.resize(segmentation.astype(np.uint8), (original.shape[1], original.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Mark only suspicious (non-road) areas; if more than 40% of the image is flagged,
    # treat it as oversegmentation and suppress the highlighting entirely.
    defect_mask = ~np.isin(resized_mask, ROAD_LABELS)
    if np.sum(defect_mask) / resized_mask.size > 0.4:
        defect_mask[:] = False

    # Paint suspected defect pixels red on a copy of the original image.
    overlay = original.copy()
    overlay[defect_mask] = [255, 0, 0]
    return Image.fromarray(overlay)
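

# Minimal local check (a sketch; "sample_road.jpg" is a hypothetical path,
# not part of this repo; swap in any road image to try the overlay).
if __name__ == "__main__":
    result = predict_defect(Image.open("sample_road.jpg"))
    result.save("defect_overlay.png")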