File size: 1,145 Bytes
8470e33
8ef6baf
8470e33
 
46cc277
2326974
 
 
8470e33
8ef6baf
8470e33
 
46cc277
8470e33
 
 
 
 
46cc277
f07ca33
8ef6baf
 
 
 
2233dcb
 
8ef6baf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import numpy as np
import cv2

# Hugging Face checkpoint id (SegFormer-B0, ADE20K fine-tune per its name).
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Pre-processing pipeline and segmentation network are loaded once, at import
# time, and reused by every prediction in this module.
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)

# Class indices treated as road. They are resolved by name from the loaded
# checkpoint's own label map instead of being hard-coded. In the ADE20K label
# map used by this checkpoint, 0 is "wall" and 8 is "windowpane", so a literal
# like [0, 8] would treat walls/windows as road and flag real road as defects.
# Extend _ROAD_CLASS_NAMES (e.g. "sidewalk", "path") to accept more surfaces.
_ROAD_CLASS_NAMES = {"road"}
ROAD_LABELS = sorted(
    int(idx)
    for idx, name in model.config.id2label.items()
    # Some ADE20K label names carry stray whitespace (e.g. "bed "), so normalise.
    if name.strip().lower() in _ROAD_CLASS_NAMES
)
if not ROAD_LABELS:
    # Without any road class every pixel would count as a defect; fail loudly
    # instead of silently producing meaningless overlays.
    raise RuntimeError(
        f"none of {sorted(_ROAD_CLASS_NAMES)} found in the model's id2label map"
    )

def predict_defect(
    image: Image.Image,
    *,
    road_labels=None,
    max_defect_fraction: float = 0.4,
    color=(255, 0, 0),
):
    """Paint non-road pixels of *image* as suspected defects.

    The image is segmented with the module-level ``processor``/``model``.
    Every pixel whose predicted class is not in *road_labels* is painted
    *color*. If more than *max_defect_fraction* of the image would be painted,
    the segmentation is treated as over-segmenting and nothing is painted.

    Args:
        image: Input picture. Any PIL mode is accepted; it is converted to RGB.
        road_labels: Class indices treated as road. Defaults to ``ROAD_LABELS``.
        max_defect_fraction: Fraction of flagged pixels above which the mask
            is discarded. Defaults to 0.4.
        color: RGB triple used for flagged pixels. Defaults to red.

    Returns:
        A new RGB ``PIL.Image`` of the same size as *image* with flagged
        pixels painted.
    """
    if road_labels is None:
        road_labels = ROAD_LABELS

    # Normalise the mode: grayscale input yields a 2-D array and RGBA a
    # 4-channel one, and both break the 3-value colour assignment below.
    rgb = image.convert("RGB")
    original = np.array(rgb)
    height, width = original.shape[:2]

    inputs = processor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        # (batch, num_labels, h', w') with h', w' smaller than the input size.
        logits = model(**inputs).logits

    # Upsample the logits (not the argmax) to the original size so class
    # boundaries are interpolated instead of blown up as nearest-neighbour
    # blocks from the low-resolution prediction.
    upsampled = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    label_map = upsampled[0].argmax(dim=0).cpu().numpy()

    # Mark only suspicious (non-road) areas, unless the model is
    # over-segmenting: flagging most of the picture is noise, so flag nothing.
    defect_mask = ~np.isin(label_map, road_labels)
    if defect_mask.mean() > max_defect_fraction:
        defect_mask[:] = False

    overlay = original.copy()
    overlay[defect_mask] = color
    return Image.fromarray(overlay)