SuriRaja commited on
Commit
2326974
·
verified ·
1 Parent(s): 87686ae

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -11
model.py CHANGED
@@ -1,30 +1,32 @@
1
  from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
2
- from PIL import Image
3
  import torch
4
  import numpy as np
5
  import cv2
 
 
 
 
 
6
 
7
- # 🔁 Replace this with a model trained for road defects
8
- processor = AutoImageProcessor.from_pretrained("segments/DeepLabV3")
9
- model = AutoModelForSemanticSegmentation.from_pretrained("segments/DeepLabV3")
10
 
11
  def predict_defect(image: Image.Image):
12
  original = np.array(image)
13
  inputs = processor(images=image, return_tensors="pt")
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
-
17
  logits = outputs.logits
18
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
19
 
20
- # Resize mask to original image size
21
  resized_mask = cv2.resize(segmentation.astype(np.uint8), (original.shape[1], original.shape[0]), interpolation=cv2.INTER_NEAREST)
22
 
23
- # 📌 NOTE: Update the label index below based on your dataset
24
- road_defect_label_index = 1 # Assume 1 represents cracks/potholes
25
- mask = (resized_mask == road_defect_label_index)
26
 
27
- # Overlay red where defect detected
28
  overlay = original.copy()
29
- overlay[mask] = [255, 0, 0]
 
30
  return Image.fromarray(overlay)
 
1
  from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ from PIL import Image
6
+
7
+ # Load model
8
+ processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
9
+ model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
10
 
11
+ # SegFormer ADE20K label IDs: road is class 0 or 8 depending on the mapping
12
+ ROAD_LABELS = [0, 8] # adjust if needed based on actual mapping
 
13
 
14
  def predict_defect(image: Image.Image):
15
  original = np.array(image)
16
  inputs = processor(images=image, return_tensors="pt")
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
+
20
  logits = outputs.logits
21
  segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
22
 
23
+ # Resize mask to match original image size
24
  resized_mask = cv2.resize(segmentation.astype(np.uint8), (original.shape[1], original.shape[0]), interpolation=cv2.INTER_NEAREST)
25
 
26
+ # Highlight anything that's NOT road
27
+ mask = ~np.isin(resized_mask, ROAD_LABELS)
 
28
 
 
29
  overlay = original.copy()
30
+ overlay[mask] = [255, 0, 0] # red highlight for all non-road anomalies
31
+
32
  return Image.fromarray(overlay)