Spaces:
Runtime error
Runtime error
transfer color for grayscale inputs.
Browse files- facelib/utils/face_restoration_helper.py +6 -4
- facelib/utils/misc.py +28 -0
- inference_codeformer.py +1 -1
facelib/utils/face_restoration_helper.py
CHANGED
|
@@ -6,7 +6,7 @@ from torchvision.transforms.functional import normalize
|
|
| 6 |
|
| 7 |
from facelib.detection import init_detection_model
|
| 8 |
from facelib.parsing import init_parsing_model
|
| 9 |
-
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
|
| 10 |
from basicsr.utils.misc import get_device
|
| 11 |
|
| 12 |
|
|
@@ -300,10 +300,12 @@ class FaceRestoreHelper(object):
|
|
| 300 |
torch.save(inverse_affine, save_path)
|
| 301 |
|
| 302 |
|
| 303 |
-
def add_restored_face(self,
|
| 304 |
if self.is_gray:
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
| 307 |
|
| 308 |
|
| 309 |
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
|
|
|
|
| 6 |
|
| 7 |
from facelib.detection import init_detection_model
|
| 8 |
from facelib.parsing import init_parsing_model
|
| 9 |
+
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
|
| 10 |
from basicsr.utils.misc import get_device
|
| 11 |
|
| 12 |
|
|
|
|
| 300 |
torch.save(inverse_affine, save_path)
|
| 301 |
|
| 302 |
|
| 303 |
+
def add_restored_face(self, restored_face, input_face=None):
|
| 304 |
if self.is_gray:
|
| 305 |
+
restored_face = bgr2gray(restored_face) # convert img into grayscale
|
| 306 |
+
if input_face is not None:
|
| 307 |
+
restored_face = adain_npy(restored_face, input_face) # transfer the color
|
| 308 |
+
self.restored_faces.append(restored_face)
|
| 309 |
|
| 310 |
|
| 311 |
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
|
facelib/utils/misc.py
CHANGED
|
@@ -172,3 +172,31 @@ def bgr2gray(img, out_channel=3):
|
|
| 172 |
if out_channel == 3:
|
| 173 |
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
|
| 174 |
return gray
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
if out_channel == 3:
|
| 173 |
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
|
| 174 |
return gray
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def calc_mean_std(feat, eps=1e-5):
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
feat (numpy): 3D [w h c]s
|
| 181 |
+
"""
|
| 182 |
+
size = feat.shape
|
| 183 |
+
assert len(size) == 3, 'The input feature should be 3D tensor.'
|
| 184 |
+
c = size[2]
|
| 185 |
+
feat_var = feat.reshape(-1, c).var(axis=0) + eps
|
| 186 |
+
feat_std = np.sqrt(feat_var).reshape(1, 1, c)
|
| 187 |
+
feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
|
| 188 |
+
return feat_mean, feat_std
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def adain_npy(content_feat, style_feat):
|
| 192 |
+
"""Adaptive instance normalization for numpy.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
content_feat (numpy): The input feature.
|
| 196 |
+
style_feat (numpy): The reference feature.
|
| 197 |
+
"""
|
| 198 |
+
size = content_feat.shape
|
| 199 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
| 200 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
| 201 |
+
normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
|
| 202 |
+
return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
|
inference_codeformer.py
CHANGED
|
@@ -205,7 +205,7 @@ if __name__ == '__main__':
|
|
| 205 |
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
| 206 |
|
| 207 |
restored_face = restored_face.astype('uint8')
|
| 208 |
-
face_helper.add_restored_face(restored_face)
|
| 209 |
|
| 210 |
# paste_back
|
| 211 |
if not args.has_aligned:
|
|
|
|
| 205 |
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
| 206 |
|
| 207 |
restored_face = restored_face.astype('uint8')
|
| 208 |
+
face_helper.add_restored_face(restored_face, cropped_face)
|
| 209 |
|
| 210 |
# paste_back
|
| 211 |
if not args.has_aligned:
|