Spaces:
Running
Running
Update segmentation.py
Browse files- segmentation.py +98 -29
segmentation.py
CHANGED
@@ -1,87 +1,156 @@
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn.functional as F
|
4 |
import cv2
|
5 |
from PIL import Image
|
6 |
-
|
|
|
7 |
import mediapipe as mp
|
|
|
8 |
|
9 |
-
# SegFormer
|
10 |
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
|
11 |
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
|
|
|
|
|
12 |
mp_face_mesh = mp.solutions.face_mesh
|
13 |
|
14 |
def remove_hair_from_image(image: Image.Image) -> Image.Image:
|
|
|
|
|
|
|
15 |
rgb = image.convert("RGB")
|
16 |
-
arr = np.array(rgb)
|
|
|
17 |
|
|
|
18 |
inputs = processor(images=rgb, return_tensors="pt")
|
19 |
with torch.no_grad():
|
20 |
logits = model(**inputs).logits.cpu()
|
21 |
up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
|
22 |
-
|
23 |
-
hair_mask = (
|
|
|
|
|
|
|
24 |
|
25 |
-
alpha = np.full((h, w), 255, np.uint8)
|
26 |
-
alpha[hair_mask > 0] = 0
|
27 |
rgba = np.dstack([arr, alpha])
|
28 |
return Image.fromarray(rgba)
|
29 |
|
30 |
def get_facemesh_mask(image: Image.Image) -> np.ndarray:
|
|
|
|
|
|
|
31 |
img = np.array(image.convert("RGB"))
|
32 |
h, w = img.shape[:2]
|
33 |
-
mask = np.zeros((h, w), np.uint8)
|
34 |
with mp_face_mesh.FaceMesh(
|
35 |
-
static_image_mode=True,
|
36 |
-
|
|
|
|
|
37 |
) as mesh:
|
38 |
res = mesh.process(img)
|
39 |
if res.multi_face_landmarks:
|
40 |
-
pts = [(int(lm.x*w), int(lm.y*h))
|
41 |
-
for lm in res.multi_face_landmarks[0].landmark]
|
42 |
hull = cv2.convexHull(np.array(pts, np.int32))
|
43 |
cv2.fillConvexPoly(mask, hull, 1)
|
44 |
return mask
|
45 |
|
46 |
-
def expand_forehead_mask(face_mask: np.ndarray, pct: float=0.2) -> np.ndarray:
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
y0, y1 = ys.min(), ys.max()
|
50 |
-
exp = int((y1-y0)*pct)
|
51 |
top = max(y0 - exp, 0)
|
52 |
out = np.zeros_like(face_mask)
|
53 |
-
out[top:top+(y1-y0)] = face_mask[y0:y1]
|
54 |
return out
|
55 |
|
56 |
def extract_face_and_forehead_no_hair(img: Image.Image) -> Image.Image:
|
|
|
|
|
|
|
57 |
rgb = img.convert("RGB")
|
58 |
-
arr = np.array(rgb)
|
|
|
59 |
|
60 |
-
#
|
61 |
-
|
62 |
with torch.no_grad():
|
63 |
-
logits = model(**
|
64 |
up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
|
65 |
seg = up.argmax(dim=1)[0].numpy()
|
66 |
hair_mask = (seg == 2).astype(np.uint8)
|
67 |
|
68 |
-
#
|
69 |
fm = get_facemesh_mask(img)
|
70 |
fm_exp = expand_forehead_mask(fm, 0.2)
|
71 |
fore = cv2.bitwise_and(fm_exp, 1 - fm)
|
72 |
fore = cv2.bitwise_and(fore, 1 - hair_mask)
|
73 |
-
cm = ((fm + fore)>0).astype(np.uint8)
|
74 |
-
cm = cv2.GaussianBlur(cm.astype(np.float32),(3,3),0)
|
75 |
-
cm = (cm>0.5).astype(np.uint8)
|
76 |
|
77 |
alpha = (cm * 255).astype(np.uint8)
|
78 |
rgba = np.dstack([arr, alpha])
|
79 |
return Image.fromarray(rgba)
|
80 |
|
81 |
def remove_face_using_segmentation(img: Image.Image) -> Image.Image:
|
|
|
|
|
|
|
82 |
ff = extract_face_and_forehead_no_hair(img)
|
83 |
-
mask = np.array(ff)[...,3] > 0
|
84 |
ori = img.convert("RGBA")
|
85 |
arr = np.array(ori)
|
86 |
-
arr[...,3][mask] = 0
|
87 |
return Image.fromarray(arr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# segmentation.py
|
2 |
+
|
3 |
import numpy as np
|
|
|
|
|
4 |
import cv2
|
5 |
from PIL import Image
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
import mediapipe as mp
|
9 |
+
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
|
10 |
|
11 |
+
# SegFormer setup for hair segmentation
|
12 |
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
|
13 |
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
|
14 |
+
|
15 |
+
# MediaPipe FaceMesh for face+forehead mask
|
16 |
mp_face_mesh = mp.solutions.face_mesh
|
17 |
|
18 |
def remove_hair_from_image(image: Image.Image) -> Image.Image:
|
19 |
+
"""
|
20 |
+
Remove hair: return RGBA with hair area transparent.
|
21 |
+
"""
|
22 |
rgb = image.convert("RGB")
|
23 |
+
arr = np.array(rgb)
|
24 |
+
h, w = arr.shape[:2]
|
25 |
|
26 |
+
# SegFormer hair mask
|
27 |
inputs = processor(images=rgb, return_tensors="pt")
|
28 |
with torch.no_grad():
|
29 |
logits = model(**inputs).logits.cpu()
|
30 |
up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
|
31 |
+
mask = up.argmax(dim=1)[0].numpy()
|
32 |
+
hair_mask = (mask == 2).astype(np.uint8)
|
33 |
+
|
34 |
+
alpha = np.full((h, w), 255, dtype=np.uint8)
|
35 |
+
alpha[hair_mask == 1] = 0
|
36 |
|
|
|
|
|
37 |
rgba = np.dstack([arr, alpha])
|
38 |
return Image.fromarray(rgba)
|
39 |
|
40 |
def get_facemesh_mask(image: Image.Image) -> np.ndarray:
|
41 |
+
"""
|
42 |
+
Return binary mask of face+forehead (no hair) using MediaPipe.
|
43 |
+
"""
|
44 |
img = np.array(image.convert("RGB"))
|
45 |
h, w = img.shape[:2]
|
46 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
47 |
with mp_face_mesh.FaceMesh(
|
48 |
+
static_image_mode=True,
|
49 |
+
max_num_faces=1,
|
50 |
+
refine_landmarks=True,
|
51 |
+
min_detection_confidence=0.5
|
52 |
) as mesh:
|
53 |
res = mesh.process(img)
|
54 |
if res.multi_face_landmarks:
|
55 |
+
pts = [(int(lm.x * w), int(lm.y * h)) for lm in res.multi_face_landmarks[0].landmark]
|
|
|
56 |
hull = cv2.convexHull(np.array(pts, np.int32))
|
57 |
cv2.fillConvexPoly(mask, hull, 1)
|
58 |
return mask
|
59 |
|
60 |
+
def expand_forehead_mask(face_mask: np.ndarray, pct: float = 0.2) -> np.ndarray:
|
61 |
+
"""
|
62 |
+
Expand face mask upward to include forehead region.
|
63 |
+
"""
|
64 |
+
ys = np.where(face_mask > 0)[0]
|
65 |
+
if ys.size == 0:
|
66 |
+
return face_mask
|
67 |
y0, y1 = ys.min(), ys.max()
|
68 |
+
exp = int((y1 - y0) * pct)
|
69 |
top = max(y0 - exp, 0)
|
70 |
out = np.zeros_like(face_mask)
|
71 |
+
out[top:top + (y1 - y0)] = face_mask[y0:y1]
|
72 |
return out
|
73 |
|
74 |
def extract_face_and_forehead_no_hair(img: Image.Image) -> Image.Image:
|
75 |
+
"""
|
76 |
+
Return RGBA where alpha=255 for face+forehead (no hair), alpha=0 elsewhere.
|
77 |
+
"""
|
78 |
rgb = img.convert("RGB")
|
79 |
+
arr = np.array(rgb)
|
80 |
+
h, w = arr.shape[:2]
|
81 |
|
82 |
+
# hair mask
|
83 |
+
inputs = processor(images=rgb, return_tensors="pt")
|
84 |
with torch.no_grad():
|
85 |
+
logits = model(**inputs).logits.cpu()
|
86 |
up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
|
87 |
seg = up.argmax(dim=1)[0].numpy()
|
88 |
hair_mask = (seg == 2).astype(np.uint8)
|
89 |
|
90 |
+
# face+forehead mask
|
91 |
fm = get_facemesh_mask(img)
|
92 |
fm_exp = expand_forehead_mask(fm, 0.2)
|
93 |
fore = cv2.bitwise_and(fm_exp, 1 - fm)
|
94 |
fore = cv2.bitwise_and(fore, 1 - hair_mask)
|
95 |
+
cm = ((fm + fore) > 0).astype(np.uint8)
|
96 |
+
cm = cv2.GaussianBlur(cm.astype(np.float32), (3, 3), 0)
|
97 |
+
cm = (cm > 0.5).astype(np.uint8)
|
98 |
|
99 |
alpha = (cm * 255).astype(np.uint8)
|
100 |
rgba = np.dstack([arr, alpha])
|
101 |
return Image.fromarray(rgba)
|
102 |
|
103 |
def remove_face_using_segmentation(img: Image.Image) -> Image.Image:
|
104 |
+
"""
|
105 |
+
Remove face+forehead: return RGBA with hair-only (alpha=255 hair, alpha=0 face).
|
106 |
+
"""
|
107 |
ff = extract_face_and_forehead_no_hair(img)
|
108 |
+
mask = np.array(ff)[..., 3] > 0
|
109 |
ori = img.convert("RGBA")
|
110 |
arr = np.array(ori)
|
111 |
+
arr[..., 3][mask] = 0
|
112 |
return Image.fromarray(arr)
|
113 |
+
|
114 |
+
def get_bbox_from_alpha(rgba: Image.Image):
|
115 |
+
"""
|
116 |
+
Compute bounding box from alpha channel: (x1, y1, x2, y2) or None.
|
117 |
+
"""
|
118 |
+
arr = np.array(rgba)
|
119 |
+
alpha = arr[..., 3]
|
120 |
+
ys, xs = np.where(alpha > 0)
|
121 |
+
if ys.size == 0:
|
122 |
+
return None
|
123 |
+
x1, x2 = xs.min(), xs.max()
|
124 |
+
y1, y2 = ys.min(), ys.max()
|
125 |
+
return x1, y1, x2, y2
|
126 |
+
|
127 |
+
def compute_scale(w_bg, h_bg, w_src, h_src):
|
128 |
+
return ((w_bg / w_src) + (h_bg / h_src)) / 2
|
129 |
+
|
130 |
+
def compute_offset(bbox_bg, bbox_src, scale):
|
131 |
+
x1, y1, x2, y2 = bbox_bg
|
132 |
+
bg_cx = x1 + (x2 - x1) // 2
|
133 |
+
bg_cy = y1 + (y2 - y1) // 2
|
134 |
+
sx1, sy1, sx2, sy2 = bbox_src
|
135 |
+
src_cx = int((sx1 + (sx2 - sx1) // 2) * scale)
|
136 |
+
src_cy = int((sy1 + (sy2 - sy1) // 2) * scale)
|
137 |
+
return bg_cx - src_cx, bg_cy - src_cy
|
138 |
+
|
139 |
+
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int, int]) -> Image.Image:
|
140 |
+
res = bg.copy()
|
141 |
+
x, y = offset
|
142 |
+
h, w = src.shape[:2]
|
143 |
+
x1, y1 = max(x, 0), max(y, 0)
|
144 |
+
x2 = min(x + w, bg.shape[1])
|
145 |
+
y2 = min(y + h, bg.shape[0])
|
146 |
+
if x1 >= x2 or y1 >= y2:
|
147 |
+
return Image.fromarray(res)
|
148 |
+
cs = src[y1 - y:y2 - y, x1 - x:x2 - x]
|
149 |
+
cd = res[y1:y2, x1:x2]
|
150 |
+
mask = cs[..., 3] > 0
|
151 |
+
if cd.shape[2] == 3:
|
152 |
+
cd[mask] = cs[mask][..., :3]
|
153 |
+
else:
|
154 |
+
cd[mask] = cs[mask]
|
155 |
+
res[y1:y2, x1:x2] = cd
|
156 |
+
return Image.fromarray(res)
|