VanNguyen1214 commited on
Commit
a3f5fed
·
verified ·
1 Parent(s): 5f52a7e

Update segmentation.py

Browse files
Files changed (1) hide show
  1. segmentation.py +98 -29
segmentation.py CHANGED
@@ -1,87 +1,156 @@
 
 
1
  import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
  import cv2
5
  from PIL import Image
6
- from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
 
7
  import mediapipe as mp
 
8
 
9
- # SegFormer để phân vùng tóc
10
  processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
11
  model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
 
 
12
  mp_face_mesh = mp.solutions.face_mesh
13
 
14
  def remove_hair_from_image(image: Image.Image) -> Image.Image:
 
 
 
15
  rgb = image.convert("RGB")
16
- arr = np.array(rgb); h, w = arr.shape[:2]
 
17
 
 
18
  inputs = processor(images=rgb, return_tensors="pt")
19
  with torch.no_grad():
20
  logits = model(**inputs).logits.cpu()
21
  up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
22
- pred = up.argmax(dim=1)[0].numpy()
23
- hair_mask = (pred == 2).astype(np.uint8)
 
 
 
24
 
25
- alpha = np.full((h, w), 255, np.uint8)
26
- alpha[hair_mask > 0] = 0
27
  rgba = np.dstack([arr, alpha])
28
  return Image.fromarray(rgba)
29
 
30
  def get_facemesh_mask(image: Image.Image) -> np.ndarray:
 
 
 
31
  img = np.array(image.convert("RGB"))
32
  h, w = img.shape[:2]
33
- mask = np.zeros((h, w), np.uint8)
34
  with mp_face_mesh.FaceMesh(
35
- static_image_mode=True, max_num_faces=1,
36
- refine_landmarks=True, min_detection_confidence=0.5
 
 
37
  ) as mesh:
38
  res = mesh.process(img)
39
  if res.multi_face_landmarks:
40
- pts = [(int(lm.x*w), int(lm.y*h))
41
- for lm in res.multi_face_landmarks[0].landmark]
42
  hull = cv2.convexHull(np.array(pts, np.int32))
43
  cv2.fillConvexPoly(mask, hull, 1)
44
  return mask
45
 
46
- def expand_forehead_mask(face_mask: np.ndarray, pct: float=0.2) -> np.ndarray:
47
- ys = np.where(face_mask>0)[0]
48
- if ys.size==0: return face_mask
 
 
 
 
49
  y0, y1 = ys.min(), ys.max()
50
- exp = int((y1-y0)*pct)
51
  top = max(y0 - exp, 0)
52
  out = np.zeros_like(face_mask)
53
- out[top:top+(y1-y0)] = face_mask[y0:y1]
54
  return out
55
 
56
  def extract_face_and_forehead_no_hair(img: Image.Image) -> Image.Image:
 
 
 
57
  rgb = img.convert("RGB")
58
- arr = np.array(rgb); h, w = arr.shape[:2]
 
59
 
60
- # tóc
61
- inp = processor(images=rgb, return_tensors="pt")
62
  with torch.no_grad():
63
- logits = model(**inp).logits.cpu()
64
  up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
65
  seg = up.argmax(dim=1)[0].numpy()
66
  hair_mask = (seg == 2).astype(np.uint8)
67
 
68
- # mặt+trán
69
  fm = get_facemesh_mask(img)
70
  fm_exp = expand_forehead_mask(fm, 0.2)
71
  fore = cv2.bitwise_and(fm_exp, 1 - fm)
72
  fore = cv2.bitwise_and(fore, 1 - hair_mask)
73
- cm = ((fm + fore)>0).astype(np.uint8)
74
- cm = cv2.GaussianBlur(cm.astype(np.float32),(3,3),0)
75
- cm = (cm>0.5).astype(np.uint8)
76
 
77
  alpha = (cm * 255).astype(np.uint8)
78
  rgba = np.dstack([arr, alpha])
79
  return Image.fromarray(rgba)
80
 
81
  def remove_face_using_segmentation(img: Image.Image) -> Image.Image:
 
 
 
82
  ff = extract_face_and_forehead_no_hair(img)
83
- mask = np.array(ff)[...,3] > 0
84
  ori = img.convert("RGBA")
85
  arr = np.array(ori)
86
- arr[...,3][mask] = 0
87
  return Image.fromarray(arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # segmentation.py
2
+
3
  import numpy as np
 
 
4
  import cv2
5
  from PIL import Image
6
+ import torch
7
+ import torch.nn.functional as F
8
  import mediapipe as mp
9
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
10
 
11
+ # SegFormer setup for hair segmentation
12
  processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
13
  model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
14
+
15
+ # MediaPipe FaceMesh for face+forehead mask
16
  mp_face_mesh = mp.solutions.face_mesh
17
 
18
  def remove_hair_from_image(image: Image.Image) -> Image.Image:
19
+ """
20
+ Remove hair: return RGBA with hair area transparent.
21
+ """
22
  rgb = image.convert("RGB")
23
+ arr = np.array(rgb)
24
+ h, w = arr.shape[:2]
25
 
26
+ # SegFormer hair mask
27
  inputs = processor(images=rgb, return_tensors="pt")
28
  with torch.no_grad():
29
  logits = model(**inputs).logits.cpu()
30
  up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
31
+ mask = up.argmax(dim=1)[0].numpy()
32
+ hair_mask = (mask == 2).astype(np.uint8)
33
+
34
+ alpha = np.full((h, w), 255, dtype=np.uint8)
35
+ alpha[hair_mask == 1] = 0
36
 
 
 
37
  rgba = np.dstack([arr, alpha])
38
  return Image.fromarray(rgba)
39
 
40
  def get_facemesh_mask(image: Image.Image) -> np.ndarray:
41
+ """
42
+ Return binary mask of face+forehead (no hair) using MediaPipe.
43
+ """
44
  img = np.array(image.convert("RGB"))
45
  h, w = img.shape[:2]
46
+ mask = np.zeros((h, w), dtype=np.uint8)
47
  with mp_face_mesh.FaceMesh(
48
+ static_image_mode=True,
49
+ max_num_faces=1,
50
+ refine_landmarks=True,
51
+ min_detection_confidence=0.5
52
  ) as mesh:
53
  res = mesh.process(img)
54
  if res.multi_face_landmarks:
55
+ pts = [(int(lm.x * w), int(lm.y * h)) for lm in res.multi_face_landmarks[0].landmark]
 
56
  hull = cv2.convexHull(np.array(pts, np.int32))
57
  cv2.fillConvexPoly(mask, hull, 1)
58
  return mask
59
 
60
+ def expand_forehead_mask(face_mask: np.ndarray, pct: float = 0.2) -> np.ndarray:
61
+ """
62
+ Expand face mask upward to include forehead region.
63
+ """
64
+ ys = np.where(face_mask > 0)[0]
65
+ if ys.size == 0:
66
+ return face_mask
67
  y0, y1 = ys.min(), ys.max()
68
+ exp = int((y1 - y0) * pct)
69
  top = max(y0 - exp, 0)
70
  out = np.zeros_like(face_mask)
71
+ out[top:top + (y1 - y0)] = face_mask[y0:y1]
72
  return out
73
 
74
  def extract_face_and_forehead_no_hair(img: Image.Image) -> Image.Image:
75
+ """
76
+ Return RGBA where alpha=255 for face+forehead (no hair), alpha=0 elsewhere.
77
+ """
78
  rgb = img.convert("RGB")
79
+ arr = np.array(rgb)
80
+ h, w = arr.shape[:2]
81
 
82
+ # hair mask
83
+ inputs = processor(images=rgb, return_tensors="pt")
84
  with torch.no_grad():
85
+ logits = model(**inputs).logits.cpu()
86
  up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
87
  seg = up.argmax(dim=1)[0].numpy()
88
  hair_mask = (seg == 2).astype(np.uint8)
89
 
90
+ # face+forehead mask
91
  fm = get_facemesh_mask(img)
92
  fm_exp = expand_forehead_mask(fm, 0.2)
93
  fore = cv2.bitwise_and(fm_exp, 1 - fm)
94
  fore = cv2.bitwise_and(fore, 1 - hair_mask)
95
+ cm = ((fm + fore) > 0).astype(np.uint8)
96
+ cm = cv2.GaussianBlur(cm.astype(np.float32), (3, 3), 0)
97
+ cm = (cm > 0.5).astype(np.uint8)
98
 
99
  alpha = (cm * 255).astype(np.uint8)
100
  rgba = np.dstack([arr, alpha])
101
  return Image.fromarray(rgba)
102
 
103
  def remove_face_using_segmentation(img: Image.Image) -> Image.Image:
104
+ """
105
+ Remove face+forehead: return RGBA with hair-only (alpha=255 hair, alpha=0 face).
106
+ """
107
  ff = extract_face_and_forehead_no_hair(img)
108
+ mask = np.array(ff)[..., 3] > 0
109
  ori = img.convert("RGBA")
110
  arr = np.array(ori)
111
+ arr[..., 3][mask] = 0
112
  return Image.fromarray(arr)
113
+
114
+ def get_bbox_from_alpha(rgba: Image.Image):
115
+ """
116
+ Compute bounding box from alpha channel: (x1, y1, x2, y2) or None.
117
+ """
118
+ arr = np.array(rgba)
119
+ alpha = arr[..., 3]
120
+ ys, xs = np.where(alpha > 0)
121
+ if ys.size == 0:
122
+ return None
123
+ x1, x2 = xs.min(), xs.max()
124
+ y1, y2 = ys.min(), ys.max()
125
+ return x1, y1, x2, y2
126
+
127
+ def compute_scale(w_bg, h_bg, w_src, h_src):
128
+ return ((w_bg / w_src) + (h_bg / h_src)) / 2
129
+
130
+ def compute_offset(bbox_bg, bbox_src, scale):
131
+ x1, y1, x2, y2 = bbox_bg
132
+ bg_cx = x1 + (x2 - x1) // 2
133
+ bg_cy = y1 + (y2 - y1) // 2
134
+ sx1, sy1, sx2, sy2 = bbox_src
135
+ src_cx = int((sx1 + (sx2 - sx1) // 2) * scale)
136
+ src_cy = int((sy1 + (sy2 - sy1) // 2) * scale)
137
+ return bg_cx - src_cx, bg_cy - src_cy
138
+
139
+ def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int, int]) -> Image.Image:
140
+ res = bg.copy()
141
+ x, y = offset
142
+ h, w = src.shape[:2]
143
+ x1, y1 = max(x, 0), max(y, 0)
144
+ x2 = min(x + w, bg.shape[1])
145
+ y2 = min(y + h, bg.shape[0])
146
+ if x1 >= x2 or y1 >= y2:
147
+ return Image.fromarray(res)
148
+ cs = src[y1 - y:y2 - y, x1 - x:x2 - x]
149
+ cd = res[y1:y2, x1:x2]
150
+ mask = cs[..., 3] > 0
151
+ if cd.shape[2] == 3:
152
+ cd[mask] = cs[mask][..., :3]
153
+ else:
154
+ cd[mask] = cs[mask]
155
+ res[y1:y2, x1:x2] = cd
156
+ return Image.fromarray(res)