|
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
|
|
# Load the SegFormer-B0 checkpoint fine-tuned on ADE20K (150 classes) at 512x512.
# Both objects are created once at import time and shared by predict_defect().
processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")


# ADE20K class ids treated as "road" (i.e. non-defect) surface.
# NOTE(review): assumes ids 0 and 8 map to the intended surface classes in
# this checkpoint's label set — confirm against model.config.id2label.
ROAD_LABELS = [0, 8]
|
|
|
def predict_defect(image: Image.Image, max_defect_ratio: float = 0.4) -> Image.Image:
    """Highlight non-road ("defect") pixels of *image* in red.

    Runs SegFormer semantic segmentation, upsamples the predicted class map
    back to the input resolution, and paints every pixel whose class is not
    in ``ROAD_LABELS`` red. If the flagged area covers more than
    ``max_defect_ratio`` of the image, the prediction is treated as
    unreliable and nothing is highlighted.

    Args:
        image: Input photo (any PIL mode; normalized to RGB internally).
        max_defect_ratio: Fraction of the frame above which the defect mask
            is discarded as a likely whole-image misclassification.

    Returns:
        An RGB ``PIL.Image`` with defect pixels overpainted in pure red.
    """
    # Force 3-channel RGB so the [255, 0, 0] assignment below is valid even
    # for grayscale ("L") or RGBA inputs, which would otherwise raise a
    # shape-mismatch error.
    rgb_image = image.convert("RGB")
    original = np.array(rgb_image)

    inputs = processor(images=rgb_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # logits: (1, num_labels, H/4, W/4). Index the batch dim explicitly
    # rather than squeeze(), which would also drop a singleton class dim.
    segmentation = torch.argmax(outputs.logits[0], dim=0).cpu().numpy()

    # Nearest-neighbour interpolation keeps class ids intact while scaling
    # the low-resolution prediction back up to the original image size.
    resized_mask = cv2.resize(
        segmentation.astype(np.uint8),
        (original.shape[1], original.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Any pixel that is not a known road class counts as a defect.
    defect_mask = ~np.isin(resized_mask, ROAD_LABELS)

    # Sanity guard: if "defects" cover most of the frame the model almost
    # certainly failed (e.g. the photo is not a road scene), so show nothing.
    if np.sum(defect_mask) / resized_mask.size > max_defect_ratio:
        defect_mask[:] = False

    overlay = original.copy()
    overlay[defect_mask] = [255, 0, 0]  # paint defect pixels pure red
    return Image.fromarray(overlay)