DurgaDeepak committed on
Commit
08d991a
·
verified ·
1 Parent(s): 343283f

Update models/segmentation/segmenter.py

Browse files
Files changed (1) hide show
  1. models/segmentation/segmenter.py +89 -89
models/segmentation/segmenter.py CHANGED
@@ -1,89 +1,89 @@
1
- import logging
2
- import torch
3
- from PIL import Image
4
- import numpy as np
5
- from torchvision import transforms
6
- from torchvision.models.segmentation import deeplabv3_resnet50
7
- from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
class Segmenter:
    """
    Generalized semantic segmentation wrapper for SegFormer and DeepLabV3.

    The backend is chosen from ``model_key``: any key containing
    ``"segformer"`` is loaded through Hugging Face ``transformers``; the exact
    key ``"deeplabv3_resnet50"`` loads the torchvision model. Any other key
    raises ``ValueError``.
    """

    _DEEPLAB_KEY = "deeplabv3_resnet50"
    # Channel statistics that torchvision's pretrained weights expect inputs
    # to be normalized with (see the torchvision model documentation).
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_key="nvidia/segformer-b0-finetuned-ade-512-512", device="cpu"):
        """
        Initialize the segmentation model.

        Args:
            model_key (str): Model identifier, e.g., Hugging Face model id or 'deeplabv3_resnet50'.
            device (str): Inference device ("cpu" or "cuda").

        Raises:
            ValueError: If ``model_key`` matches no supported backend.
        """
        logger.info("Initializing segmenter with model: %s", model_key)
        self.device = device
        self.model_key = model_key
        self.model, self.processor = self._load_model()

    def _load_model(self):
        """
        Load the segmentation model and, for SegFormer, its input processor.

        Returns:
            Tuple[torch.nn.Module, Optional[Processor]]: The model moved to
            ``self.device`` in eval mode, and the processor (``None`` for
            DeepLabV3, which is preprocessed with torchvision transforms).

        Raises:
            ValueError: If ``self.model_key`` is not supported.
        """
        if "segformer" in self.model_key:
            model = SegformerForSemanticSegmentation.from_pretrained(self.model_key)
            # Explicit eval() keeps both branches consistent for inference.
            model = model.to(self.device).eval()
            # NOTE(review): recent transformers releases deprecate
            # SegformerFeatureExtractor in favour of SegformerImageProcessor;
            # kept because the file only imports the former.
            processor = SegformerFeatureExtractor.from_pretrained(self.model_key)
            return model, processor

        if self.model_key == self._DEEPLAB_KEY:
            # NOTE(review): newer torchvision prefers `weights=` over
            # `pretrained=True`; left unchanged for compatibility.
            model = deeplabv3_resnet50(pretrained=True).to(self.device).eval()
            return model, None

        raise ValueError(f"Unsupported model key: {self.model_key}")

    def predict(self, image: Image.Image):
        """
        Perform segmentation on the input image.

        Args:
            image (PIL.Image.Image): Input image. It is converted to RGB
                before inference so grayscale/RGBA inputs are accepted.

        Returns:
            np.ndarray: 2-D array of per-pixel class indices (argmax over the
            class dimension). Its size is that of the model output, which may
            differ from the input image size — ``draw`` resizes it.

        Raises:
            ValueError: If ``self.model_key`` is not supported.
        """
        logger.info("Running segmentation")

        # Both backends are fed 3-channel input.
        rgb_image = image.convert("RGB")

        if "segformer" in self.model_key:
            inputs = self.processor(images=rgb_image, return_tensors="pt").to(self.device)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = self.model(**inputs).logits
        elif self.model_key == self._DEEPLAB_KEY:
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                # Pretrained torchvision weights expect normalized input.
                transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
            ])
            inputs = preprocess(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model(inputs)["out"]
        else:
            # Fail loudly instead of silently returning None.
            raise ValueError(f"Unsupported model key: {self.model_key}")

        # Take the single batch entry by index; squeeze() would also drop a
        # spatial dimension that happens to have size 1.
        return logits.argmax(dim=1)[0].cpu().numpy()

    @staticmethod
    def _mask_to_grayscale(mask):
        """Scale a class-index mask to uint8 0-255; an all-zero mask stays black."""
        values = np.asarray(mask)
        peak = values.max() if values.size else 0
        if peak <= 0:
            # Nothing but background: avoid dividing by zero.
            return np.zeros(values.shape, dtype=np.uint8)
        return (values * 255 / peak).astype(np.uint8)

    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5):
        """
        Overlay the segmentation mask on the input image.

        Args:
            image (PIL.Image.Image): Original image.
            mask (np.ndarray): Segmentation mask.
            alpha (float): Blend strength.

        Returns:
            PIL.Image.Image: RGB image with the grayscale mask blended in.
        """
        logger.info("Drawing segmentation overlay")
        gray = Image.fromarray(self._mask_to_grayscale(mask)).convert("L")
        # Nearest-neighbour keeps label boundaries crisp; interpolation would
        # invent intensities that correspond to no class.
        gray = gray.resize(image.size, resample=Image.NEAREST)
        overlay = Image.merge("RGB", (gray, gray, gray))
        # Image.blend requires matching modes, so normalize the base image.
        return Image.blend(image.convert("RGB"), overlay, alpha)
 
1
+ import logging
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ from torchvision import transforms
6
+ from torchvision.models.segmentation import deeplabv3_resnet50
7
+ from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class Segmenter:
    """
    Generalized semantic segmentation wrapper for SegFormer and DeepLabV3.

    The backend is chosen from ``model_key``: any key containing
    ``"segformer"`` is loaded through Hugging Face ``transformers``; the exact
    key ``"deeplabv3_resnet50"`` loads the torchvision model. Any other key
    raises ``ValueError``.
    """

    _DEEPLAB_KEY = "deeplabv3_resnet50"
    # Channel statistics that torchvision's pretrained weights expect inputs
    # to be normalized with (see the torchvision model documentation).
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_key="nvidia/segformer-b0-finetuned-ade-512-512", device="cpu"):
        """
        Initialize the segmentation model.

        Args:
            model_key (str): Model identifier, e.g., Hugging Face model id or 'deeplabv3_resnet50'.
            device (str): Inference device ("cpu" or "cuda").

        Raises:
            ValueError: If ``model_key`` matches no supported backend.
        """
        logger.info("Initializing segmenter with model: %s", model_key)
        self.device = device
        self.model_key = model_key
        self.model, self.processor = self._load_model()

    def _load_model(self):
        """
        Load the segmentation model and, for SegFormer, its input processor.

        Returns:
            Tuple[torch.nn.Module, Optional[Processor]]: The model moved to
            ``self.device`` in eval mode, and the processor (``None`` for
            DeepLabV3, which is preprocessed with torchvision transforms).

        Raises:
            ValueError: If ``self.model_key`` is not supported.
        """
        if "segformer" in self.model_key:
            model = SegformerForSemanticSegmentation.from_pretrained(self.model_key)
            # Explicit eval() keeps both branches consistent for inference.
            model = model.to(self.device).eval()
            # NOTE(review): recent transformers releases deprecate
            # SegformerFeatureExtractor in favour of SegformerImageProcessor;
            # kept because the file only imports the former.
            processor = SegformerFeatureExtractor.from_pretrained(self.model_key)
            return model, processor

        if self.model_key == self._DEEPLAB_KEY:
            # NOTE(review): newer torchvision prefers `weights=` over
            # `pretrained=True`; left unchanged for compatibility.
            model = deeplabv3_resnet50(pretrained=True).to(self.device).eval()
            return model, None

        raise ValueError(f"Unsupported model key: {self.model_key}")

    def predict(self, image: Image.Image, **kwargs):
        """
        Perform segmentation on the input image.

        Args:
            image (PIL.Image.Image): Input image. It is converted to RGB
                before inference so grayscale/RGBA inputs are accepted.
            **kwargs: Accepted so callers may pass extra options; this
                implementation does not read them.

        Returns:
            np.ndarray: 2-D array of per-pixel class indices (argmax over the
            class dimension). Its size is that of the model output, which may
            differ from the input image size — ``draw`` resizes it.

        Raises:
            ValueError: If ``self.model_key`` is not supported.
        """
        logger.info("Running segmentation")

        # Both backends are fed 3-channel input.
        rgb_image = image.convert("RGB")

        if "segformer" in self.model_key:
            inputs = self.processor(images=rgb_image, return_tensors="pt").to(self.device)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = self.model(**inputs).logits
        elif self.model_key == self._DEEPLAB_KEY:
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                # Pretrained torchvision weights expect normalized input.
                transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
            ])
            inputs = preprocess(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model(inputs)["out"]
        else:
            # Fail loudly instead of silently returning None.
            raise ValueError(f"Unsupported model key: {self.model_key}")

        # Take the single batch entry by index; squeeze() would also drop a
        # spatial dimension that happens to have size 1.
        return logits.argmax(dim=1)[0].cpu().numpy()

    @staticmethod
    def _mask_to_grayscale(mask):
        """Scale a class-index mask to uint8 0-255; an all-zero mask stays black."""
        values = np.asarray(mask)
        peak = values.max() if values.size else 0
        if peak <= 0:
            # Nothing but background: avoid dividing by zero.
            return np.zeros(values.shape, dtype=np.uint8)
        return (values * 255 / peak).astype(np.uint8)

    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5):
        """
        Overlay the segmentation mask on the input image.

        Args:
            image (PIL.Image.Image): Original image.
            mask (np.ndarray): Segmentation mask.
            alpha (float): Blend strength.

        Returns:
            PIL.Image.Image: RGB image with the grayscale mask blended in.
        """
        logger.info("Drawing segmentation overlay")
        gray = Image.fromarray(self._mask_to_grayscale(mask)).convert("L")
        # Nearest-neighbour keeps label boundaries crisp; interpolation would
        # invent intensities that correspond to no class.
        gray = gray.resize(image.size, resample=Image.NEAREST)
        overlay = Image.merge("RGB", (gray, gray, gray))
        # Image.blend requires matching modes, so normalize the base image.
        return Image.blend(image.convert("RGB"), overlay, alpha)