# Hugging Face Space (status: Running)
# baldhead.py | |
import os | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
import gradio as gr | |
# Keras imports (note: keras-contrib must be installed) | |
import keras.backend as K | |
from keras.layers import ( | |
Input, | |
Conv2D, | |
UpSampling2D, | |
LeakyReLU, | |
GlobalAveragePooling2D, | |
Dense, | |
Reshape, | |
Dropout, | |
Concatenate, | |
    multiply,  # ← added: import multiply
) | |
from keras.models import Model | |
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization | |
# RetinaFace + skimage for face alignment | |
from retinaface import RetinaFace | |
from skimage import transform as trans | |
# Hugging Face Hub helper | |
from huggingface_hub import hf_hub_download | |
# --- Face-alignment helpers (same as the original code) ---
# Size of every aligned face crop; indexed as [0] = height, [1] = width
# when building the (width, height) tuple for cv2.warpAffine.
image_size = [256, 256]

# Five (x, y) reference points that detected landmarks are registered to.
# NOTE(review): looks like a 5-point face template defined on a 112-px grid
# (hence the division by 112 below) — confirm against the original notebook.
src_landmarks = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041],
    ],
    dtype=np.float32,
)

# Offset the template (x by 8 then 15, y by 30), then rescale it from the
# 112-unit grid to a 200-unit one. The two x shifts are kept as separate
# steps so the float32 results stay bit-identical to the original code.
src_landmarks[:, 0] += 8.0
src_landmarks[:, 0] += 15.0
src_landmarks[:, 1] += 30.0
src_landmarks /= 112
src_landmarks *= 200
def list2array(values):
    """
    Materialise any iterable (for example a dict view) into a list and
    return it as a NumPy array.
    """
    items = [*values]
    return np.array(items)
def align_face(img: np.ndarray):
    """
    Detect faces + landmarks in `img` via RetinaFace and align each face.

    For every detection a similarity transform is estimated from the detected
    landmarks onto `src_landmarks`; the image and an all-white canvas of the
    same shape are then warped with it into an `image_size` crop.

    Args:
        img: Image array handed to `RetinaFace.detect_faces`.
            NOTE(review): confirm which channel order (RGB vs. BGR) the
            detector expects for array input.

    Returns:
        Tuple `(aligned_faces, masks, matrices)` of equal-length lists:
        the warped face crops, the warped white canvases (255 where the crop
        is backed by source pixels, 0 in the border fill), and the 3x3
        transform matrices. All three lists are empty when no face is found.
    """
    aligned_faces, masks, matrices = [], [], []

    faces = RetinaFace.detect_faces(img)
    # NOTE(review): some retina-face releases appear to return a non-dict
    # value when nothing is detected; indexing that like a dict would crash,
    # so anything that is not a non-empty dict is treated as "no faces".
    if not isinstance(faces, dict) or not faces:
        return aligned_faces, masks, matrices

    white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
    # cv2.warpAffine takes the output size as (width, height).
    out_size = (image_size[1], image_size[0])

    for face in faces.values():
        detected = list2array(face["landmarks"].values())
        tform = trans.SimilarityTransform()
        tform.estimate(detected, src_landmarks)
        M = tform.params[0:2, :]  # 2x3 affine part used for the forward warp
        warped_face = cv2.warpAffine(img, M, out_size, borderValue=0.0)
        warped_mask = cv2.warpAffine(white_canvas, M, out_size, borderValue=0.0)
        aligned_faces.append(warped_face)
        masks.append(warped_mask)
        # Keep the full 3x3 matrix so it can be inverted later.
        matrices.append(tform.params[0:3, :])

    return aligned_faces, masks, matrices
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """
    Paste every processed face back into a copy of `orig_img`.

    Each face and its mask are warped with the inverse of the matching
    transformation matrix onto a canvas of the original size; wherever the
    warped mask is fully white the original pixels are replaced by the
    warped face. `orig_img` itself is not modified.
    """
    height, width = orig_img.shape[:2]
    canvas_size = (width, height)  # cv2 expects (width, height)
    composite = orig_img.copy()

    for idx, face in enumerate(processed_faces):
        # Invert the 3x3 matrix and keep the 2x3 affine part for warpAffine.
        inverse = np.linalg.inv(matrices[idx])[0:2]
        face_on_canvas = cv2.warpAffine(face, inverse, canvas_size, borderValue=0.0)
        mask_on_canvas = cv2.warpAffine(masks[idx], inverse, canvas_size, borderValue=0.0)
        # Integer division keeps only fully white (255) mask pixels as 1.
        selector = (mask_on_canvas // 255).astype(np.uint8)
        # composite = composite * (1 - selector) + face * selector
        background = (composite * (1 - selector)).astype(np.uint8)
        composite = background + face_on_canvas * selector

    return composite
# ---------------------------- | |
# 2. GENERATOR ARCHITECTURE | |
# ---------------------------- | |
def squeeze_excite_block(x, ratio=4):
    """
    Squeeze-and-Excitation block: channel-wise attention.

    Pools `x` globally, passes the pooled vector through a two-layer
    bottleneck (channels // ratio with ReLU, then channels with sigmoid)
    and multiplies `x` by the resulting gate.

    NOTE(review): the gate is always reshaped to (1, 1, channels), even
    though the channel axis is looked up for "channels_first" as well —
    confirm that only channels-last data is used.
    """
    if K.image_data_format() == "channels_first":
        channel_axis = 1
    else:
        channel_axis = -1
    n_channels = x.shape[channel_axis]

    # Squeeze: one value per channel.
    gate = GlobalAveragePooling2D()(x)
    gate = Reshape((1, 1, n_channels))(gate)

    # Excite: bottleneck followed by a sigmoid gate.
    gate = Dense(
        n_channels // ratio,
        activation="relu",
        kernel_initializer="he_normal",
        use_bias=False,
    )(gate)
    gate = Dense(
        n_channels,
        activation="sigmoid",
        kernel_initializer="he_normal",
        use_bias=False,
    )(gate)

    return multiply([x, gate])
def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """
    Downsampling block: stride-2 Conv2D → LeakyReLU(0.2), optionally
    followed by InstanceNormalization (`bn`) and a squeeze-excite block (`se`).
    """
    out = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    out = LeakyReLU(alpha=0.2)(out)
    # Despite the flag name, `bn` toggles instance normalisation.
    out = InstanceNormalization()(out) if bn else out
    out = squeeze_excite_block(out) if se else out
    return out
def atrous(layer_input, filters, f_size=4, bn=True):
    """
    Atrous (dilated) convolution block.

    Runs three parallel convolutions over `layer_input` with dilation rates
    2, 4 and 8, concatenates their outputs, applies LeakyReLU(0.2) and,
    when `bn` is set, InstanceNormalization.
    """
    branches = [
        Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
        for rate in (2, 4, 8)
    ]
    merged = Concatenate()(branches)
    merged = LeakyReLU(alpha=0.2)(merged)
    if not bn:
        return merged
    return InstanceNormalization()(merged)
def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """
    Upsampling block: 2x UpSampling2D → Conv2D (ReLU) → optional Dropout →
    InstanceNormalization, then concatenation with the skip connection.

    Dropout is only inserted when `dropout_rate` is truthy (non-zero).
    """
    up = UpSampling2D(size=2)(layer_input)
    up = Conv2D(
        filters,
        kernel_size=f_size,
        strides=1,
        padding="same",
        activation="relu",
    )(up)
    if dropout_rate:
        up = Dropout(dropout_rate)(up)
    normed = InstanceNormalization()(up)
    return Concatenate()([normed, skip_input])
def build_generator():
    """
    Build the generator network and return it as a Keras Model.

    Layout: a 256x256x3 input, five stride-2 downsampling blocks, one atrous
    block, four upsampling blocks with skip connections to the encoder, and
    a final 2x upsample plus a 3-channel tanh convolution.

    Layers are created in the same order as the hand-written original so
    that weights loaded by position still line up.
    """
    base_filters = 64
    inputs = Input(shape=(256, 256, 3))

    # Encoder: (filter multiplier, instance norm?, squeeze-excite?)
    encoder_spec = [
        (1, False, True),
        (2, True, True),
        (4, True, True),
        (8, True, False),
        (8, True, False),
    ]
    skips = []
    x = inputs
    for mult, use_norm, use_se in encoder_spec:
        x = conv2d(x, base_filters * mult, bn=use_norm, se=use_se)
        skips.append(x)

    # Bottleneck: atrous block on the deepest encoder output.
    x = atrous(x, base_filters * 8)

    # Decoder: walk the encoder outputs from second-deepest to shallowest.
    for skip, mult in zip(reversed(skips[:-1]), (8, 4, 2, 1)):
        x = deconv2d(x, skip, base_filters * mult)

    # Final upsample + conv back to a 3-channel image in [-1, 1] (tanh).
    x = UpSampling2D(size=2)(x)
    output_img = Conv2D(
        3, kernel_size=4, strides=1, padding="same", activation="tanh"
    )(x)

    return Model(inputs, output_img)
# ---------------------------- | |
# 3. LOAD MODEL WEIGHTS | |
# ---------------------------- | |
# Location of the generator weights on the Hugging Face Hub.
HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Read the access token without raising when the variable is unset or empty:
# a KeyError here would abort the import before the guarded model load below
# gets a chance to fall back gracefully. None is passed on as "no explicit
# token". NOTE(review): confirm the repo is reachable without one.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or None
def load_generator_from_hub():
    """
    Fetch the .hdf5 weights file from the Hugging Face Hub, build a fresh
    generator and load the weights into it.

    Returns the generator model with weights loaded.
    """
    weights_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        token=HF_TOKEN,
    )
    generator = build_generator()
    generator.load_weights(weights_path)
    return generator
# Load the generator a single time at import. Any failure is reported on
# stdout and leaves GENERATOR as None instead of aborting startup.
GENERATOR = None
try:
    GENERATOR = load_generator_from_hub()
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
except Exception as err:
    print("[ERROR] Could not load generator:", err)
    GENERATOR = None
# ---------------------------- | |
# 4. INFERENCE FUNCTION | |
# ---------------------------- | |
# Annotations are quoted so the `X | None` union is never evaluated at
# definition time (that syntax needs Python 3.10+ when evaluated eagerly).
def inference(image: "Image.Image | None") -> "Image.Image | None":
    """
    Gradio-compatible inference function:
    - Convert PIL → numpy RGB
    - Align faces
    - For each face: normalize to [-1,1], run through generator,
      denormalize to uint8
    - Put processed faces back onto original image
    - Return full-image PIL

    Args:
        image: Input picture, or None when no picture was supplied.

    Returns:
        None for a None input; the input image unchanged when the generator
        failed to load or no face is detected; otherwise a new RGB image
        with every detected face replaced by the generator output.
    """
    # No input at all (e.g. submit pressed with an empty image box).
    if image is None:
        return None
    if GENERATOR is None:
        return image

    orig = np.array(image.convert("RGB"))
    faces, masks, mats = align_face(orig)
    if not faces:
        return image

    processed_faces = []
    for face in faces:
        face_input = face.astype(np.float32) / 127.5 - 1.0  # scale to [-1,1]
        face_input = np.expand_dims(face_input, axis=0)  # add batch axis
        pred = GENERATOR.predict(face_input)[0]  # tanh output, in [-1,1]
        # Clip before the cast so any float overshoot cannot wrap in uint8.
        pred = np.clip((pred + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8)
        processed_faces.append(pred)

    output_np = put_face_back(orig, processed_faces, masks, mats)
    return Image.fromarray(output_np)