from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import numpy as np
import cv2

# Single source of truth for the checkpoint so the processor and the model
# can never drift apart.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Loaded once at import time and reused by every predict_defect() call.
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)

# Class indices to consider as road.
# NOTE(review): in the ADE20K label map usually shipped with this checkpoint,
# indices 0 and 8 appear to be "wall" and "windowpane", while "road" is 6 --
# verify against model.config.id2label before relying on these values.
ROAD_LABELS = [0, 8]


def predict_defect(
    image: Image.Image,
    *,
    max_defect_fraction: float = 0.4,
) -> Image.Image:
    """Return a copy of *image* with suspected defect pixels painted red.

    The image is segmented with the module-level SegFormer model. Every pixel
    whose predicted class is not in ``ROAD_LABELS`` is treated as a suspected
    defect. If the suspected area exceeds ``max_defect_fraction`` of the
    image, the prediction is considered an oversegmentation and nothing is
    highlighted.

    Args:
        image: Input picture in any PIL mode; it is converted to RGB first so
            that the 3-channel overlay colour always fits.
        max_defect_fraction: Upper bound (as a fraction of all pixels) on the
            highlighted area before the whole mask is discarded.

    Returns:
        A new RGB ``PIL.Image.Image`` of the same size as the input.
    """
    # Grayscale / RGBA / palette inputs would otherwise break the
    # `overlay[mask] = [255, 0, 0]` assignment below (channel-count mismatch).
    rgb_image = image.convert("RGB")
    original = np.array(rgb_image)
    height, width = original.shape[:2]

    inputs = processor(images=rgb_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Index the single batch element explicitly instead of squeeze(), which
    # would also drop any other dimension that happens to have size 1.
    logits = outputs.logits[0]
    segmentation = torch.argmax(logits, dim=0).detach().cpu().numpy()

    # Nearest-neighbour keeps class ids intact while scaling the label map
    # back to the input resolution. cv2 expects dsize as (width, height).
    # The uint8 cast assumes fewer than 256 classes.
    resized_mask = cv2.resize(
        segmentation.astype(np.uint8),
        (width, height),
        interpolation=cv2.INTER_NEAREST,
    )

    # Mark only suspicious areas (non-road) unless it's oversegmenting.
    defect_mask = ~np.isin(resized_mask, ROAD_LABELS)
    if defect_mask.mean() > max_defect_fraction:
        defect_mask[:] = False

    overlay = original.copy()
    overlay[defect_mask] = [255, 0, 0]
    return Image.fromarray(overlay)