DurgaDeepak committed
Commit 52aa226 · verified · 1 Parent(s): 03c9511

Update models/segmentation/segmenter.py

Files changed (1):
  1. models/segmentation/segmenter.py +75 -40
models/segmentation/segmenter.py CHANGED
@@ -4,86 +4,121 @@ from PIL import Image
 import numpy as np
 from torchvision import transforms
 from torchvision.models.segmentation import deeplabv3_resnet50
-from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
+from transformers import (
+    SegformerForSemanticSegmentation,
+    SegformerFeatureExtractor,
+    AutoProcessor,
+    CLIPSegForImageSegmentation,
+)
 
 logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 class Segmenter:
     """
-    Generalized Semantic Segmentation Wrapper for SegFormer and DeepLabV3.
+    Generalized Semantic Segmentation Wrapper for SegFormer, DeepLabV3, and CLIPSeg.
     """
 
     def __init__(self, model_key="nvidia/segformer-b0-finetuned-ade-512-512", device="cpu"):
         """
-        Initialize the segmentation model.
-
         Args:
-            model_key (str): Model identifier, e.g., Hugging Face model id or 'deeplabv3_resnet50'.
-            device (str): Inference device ("cpu" or "cuda").
+            model_key (str): HF model identifier or 'deeplabv3_resnet50'.
+            device (str): 'cpu' or 'cuda'.
         """
-        logger.info(f"Initializing segmenter with model: {model_key}")
+        logger.info(f"Initializing Segmenter for model '{model_key}' on {device}")
+        self.model_key = model_key.lower()
         self.device = device
-        self.model_key = model_key
-        self.model, self.processor = self._load_model()
+        self.model = None
+        self.processor = None  # for transformers-based models
 
     def _load_model(self):
         """
-        Load the segmentation model and processor.
-
-        Returns:
-            Tuple[torch.nn.Module, Optional[Processor]]
+        Lazy-load the model & processor based on model_key.
         """
+        if self.model is not None:
+            return
+
+        # SegFormer
         if "segformer" in self.model_key:
-            model = SegformerForSemanticSegmentation.from_pretrained(self.model_key).to(self.device)
-            processor = SegformerFeatureExtractor.from_pretrained(self.model_key)
-            return model, processor
+            self.model = SegformerForSemanticSegmentation.from_pretrained(self.model_key).to(self.device).eval()
+            self.processor = SegformerFeatureExtractor.from_pretrained(self.model_key)
+
+        # DeepLabV3
         elif self.model_key == "deeplabv3_resnet50":
-            model = deeplabv3_resnet50(pretrained=True).to(self.device).eval()
-            return model, None
+            self.model = deeplabv3_resnet50(pretrained=True).to(self.device).eval()
+            self.processor = None
+
+        # CLIPSeg
+        elif "clipseg" in self.model_key:
+            self.model = CLIPSegForImageSegmentation.from_pretrained(self.model_key).to(self.device).eval()
+            self.processor = AutoProcessor.from_pretrained(self.model_key)
+
         else:
-            raise ValueError(f"Unsupported model key: {self.model_key}")
+            raise ValueError(f"Unsupported segmentation model key: '{self.model_key}'")
+
+        logger.info(f"Loaded segmentation model '{self.model_key}'")
 
-    def predict(self, image: Image.Image, **kwargs):
+    def predict(self, image: Image.Image, prompt: str = "", **kwargs) -> np.ndarray:
         """
-        Perform segmentation on the input image.
+        Perform segmentation.
 
         Args:
-            image (PIL.Image.Image): Input image.
-
+            image (PIL.Image.Image): Input.
+            prompt (str): Only used for CLIPSeg.
         Returns:
-            np.ndarray: Segmentation mask.
+            np.ndarray: Segmentation mask (H×W).
         """
-        logger.info("Running segmentation")
+        self._load_model()
 
+        # SegFormer path
         if "segformer" in self.model_key:
             inputs = self.processor(images=image, return_tensors="pt").to(self.device)
             outputs = self.model(**inputs)
             mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
             return mask
 
-        elif self.model_key == "deeplabv3_resnet50":
-            transform = transforms.Compose([
-                transforms.ToTensor(),
-            ])
-            inputs = transform(image).unsqueeze(0).to(self.device)
+        # DeepLabV3 path
+        if self.model_key == "deeplabv3_resnet50":
+            tf = transforms.ToTensor()
+            inp = tf(image).unsqueeze(0).to(self.device)
+            with torch.no_grad():
+                out = self.model(inp)["out"]
+            mask = out.argmax(1).squeeze().cpu().numpy()
+            return mask
+
+        # CLIPSeg path
+        if "clipseg" in self.model_key:
+            # CLIPSeg expects both text and image
+            inputs = self.processor(
+                text=[prompt],    # list of prompts
+                images=[image],   # list of images
+                return_tensors="pt"
+            ).to(self.device)
             with torch.no_grad():
-                outputs = self.model(inputs)["out"]
-            mask = outputs.argmax(1).squeeze().cpu().numpy()
-            return mask
+                outputs = self.model(**inputs)
+            # outputs.logits shape: (batch=1, height, width)
+            mask = outputs.logits.squeeze(0).cpu().numpy()
+            # Optionally threshold to binary:
+            # mask = (mask > kwargs.get("threshold", 0.5)).astype(np.uint8)
+            return mask
 
-    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5):
+        raise RuntimeError("Unreachable segmentation branch")
+
+    def draw(self, image: Image.Image, mask: np.ndarray, alpha=0.5) -> Image.Image:
         """
         Overlay the segmentation mask on the input image.
 
         Args:
-            image (PIL.Image.Image): Original image.
+            image (PIL.Image.Image): Original.
             mask (np.ndarray): Segmentation mask.
             alpha (float): Blend strength.
-
         Returns:
-            PIL.Image.Image: Image with mask overlay.
+            PIL.Image.Image: Blended output.
         """
         logger.info("Drawing segmentation overlay")
-        mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8)).convert("L").resize(image.size)
-        mask_colored = Image.merge("RGB", (mask_img, mask_img, mask_img))
-        return Image.blend(image, mask_colored, alpha)
+        # Normalize mask to 0–255
+        gray = ((mask - mask.min()) / (mask.ptp()) * 255).astype(np.uint8)
+        mask_img = Image.fromarray(gray).convert("L").resize(image.size)
+        # Make it RGB
+        color_mask = Image.merge("RGB", (mask_img, mask_img, mask_img))
+        return Image.blend(image, color_mask, alpha)
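
For context, a minimal usage sketch of the class as it stands after this commit. It is an illustration, not part of the diff: the model id, file names, and prompt below are assumptions, and CLIPSeg returns logits on its own low-resolution grid (typically 352×352 for CIDAS/clipseg-rd64-refined), which is why draw() resizes the mask back to image.size.

    # Hypothetical usage of the updated Segmenter (not in the commit).
    # Assumes torch, torchvision, transformers, and Pillow are installed;
    # the model id, file paths, and prompt are illustrative only.
    from PIL import Image

    seg = Segmenter(model_key="CIDAS/clipseg-rd64-refined", device="cpu")

    image = Image.open("street.jpg").convert("RGB")  # hypothetical input file
    mask = seg.predict(image, prompt="a dog")        # prompt is only consumed by the CLIPSeg path

    overlay = seg.draw(image, mask, alpha=0.5)
    overlay.save("street_overlay.png")

One design note: __init__ lowercases model_key before from_pretrained() ever sees it, so a mixed-case Hub id such as CIDAS/clipseg-rd64-refined relies on the Hub resolving repository names case-insensitively; if that assumption does not hold in a given environment, passing an already-lowercase id (or dropping the .lower() call) avoids the issue.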
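
The normalization in draw() divides by mask.ptp(), which is zero for a constant mask, and the ndarray.ptp() method was removed in NumPy 2.0 in favor of the np.ptp() function. A defensive variant, offered as a sketch under those assumptions rather than as part of the commit:

    # Hypothetical hardened variant of Segmenter.draw (not in the commit).
    import numpy as np
    from PIL import Image

    def draw_safe(image: Image.Image, mask: np.ndarray, alpha: float = 0.5) -> Image.Image:
        """Overlay a mask, guarding against flat masks and NumPy 2.x API changes."""
        mask = mask.astype(np.float32)
        span = float(np.ptp(mask))  # function form works on both NumPy 1.x and 2.x
        if span == 0.0:
            gray = np.zeros(mask.shape, dtype=np.uint8)  # constant mask: nothing to highlight
        else:
            gray = ((mask - mask.min()) / span * 255).astype(np.uint8)
        mask_img = Image.fromarray(gray).convert("L").resize(image.size)
        color_mask = Image.merge("RGB", (mask_img, mask_img, mask_img))
        return Image.blend(image.convert("RGB"), color_mask, alpha)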