# ghep_image / baldhead.py
# Author: VanNguyen1214
# Last change: "Update baldhead.py" (commit f88f2c2, verified)
# baldhead.py
import os
import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr
# Keras imports (note: keras-contrib must be installed)
import keras.backend as K
from keras.layers import (
Input,
Conv2D,
UpSampling2D,
LeakyReLU,
GlobalAveragePooling2D,
Dense,
Reshape,
Dropout,
Concatenate,
multiply, # ← Thêm import multiply
)
from keras.models import Model
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
# RetinaFace + skimage for face alignment
from retinaface import RetinaFace
from skimage import transform as trans
# Hugging Face Hub helper
from huggingface_hub import hf_hub_download
# --- Face-alignment helpers (same as the original code) ---
# Height and width of every aligned face crop produced by align_face().
image_size = [256, 256]

# Five reference landmark positions (x, y) that detected faces are aligned onto.
# By their y values they are presumably: two eyes (y ~ 51.6), the nose tip
# (y ~ 71.7) and two mouth corners (y ~ 92.3); they appear to be the common
# 96x112 five-point face template — confirm against the training pipeline.
src_landmarks = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041],
    ],
    dtype=np.float32,
)
# The shifts and scaling are applied step by step in float32 (same order as the
# original) so the resulting template is bit-for-bit unchanged.
# Total shift: +23 in x, +30 in y; then scale by 200/112. After this the eye
# midpoint lands near x ~ 127 of the 256-wide crop and the face sits in the
# lower part of the crop (mouth corners at y ~ 218).
src_landmarks[:, 0] = src_landmarks[:, 0] + 8.0
src_landmarks[:, 0] = src_landmarks[:, 0] + 15.0
src_landmarks[:, 1] = src_landmarks[:, 1] + 30.0
src_landmarks = src_landmarks / 112 * 200
def list2array(values):
    """Materialise any iterable (e.g. a ``dict.values()`` view) into a NumPy array."""
    return np.array([*values])
def align_face(img: np.ndarray):
    """Detect faces in ``img`` with RetinaFace and warp each to a canonical crop.

    For every detected face, a similarity transform is fitted from its five
    RetinaFace landmarks onto ``src_landmarks`` and used to warp both the
    image and an all-white canvas into an ``image_size`` crop.

    Args:
        img: Image array handed directly to ``RetinaFace.detect_faces``.
            NOTE(review): the caller in this file passes RGB (from PIL);
            confirm RetinaFace expects that channel order rather than BGR.

    Returns:
        ``(aligned_faces, masks, matrices)`` — three lists with one entry per
        successfully aligned face:

        * ``aligned_faces``: the warped crops (pixels that fall outside
          ``img`` are 0).
        * ``masks``: the warped uint8 white canvas — 255 where the crop came
          from inside ``img``, 0 where it fell outside, intermediate values
          along the interpolated border.
        * ``matrices``: 3x3 homogeneous transforms mapping ``img``
          coordinates to crop coordinates (invert them to paste back).

        All three lists are empty when no face is found.
    """
    faces = RetinaFace.detect_faces(img)
    # A successful detection is a dict with one entry per face; an empty dict
    # or any non-dict "nothing found" value means there is nothing to align.
    if not isinstance(faces, dict) or not faces:
        return [], [], []

    white_canvas = np.full(img.shape, 255, dtype=np.uint8)
    crop_wh = (image_size[1], image_size[0])  # cv2.warpAffine wants (width, height)

    aligned_faces, masks, matrices = [], [], []
    for face in faces.values():
        dst = list2array(face["landmarks"].values())  # detected landmarks
        tform = trans.SimilarityTransform()
        # estimate() returns False when the fit fails (e.g. degenerate
        # landmarks); its params are then unusable (NaN), so skip the face.
        if not tform.estimate(dst, src_landmarks):
            continue
        M = tform.params[0:2, :]  # 2x3 affine part for cv2.warpAffine
        warped_face = cv2.warpAffine(img, M, crop_wh, borderValue=0.0)
        warped_mask = cv2.warpAffine(white_canvas, M, crop_wh, borderValue=0.0)
        aligned_faces.append(warped_face)
        masks.append(warped_mask)
        matrices.append(tform.params[0:3, :])
    return aligned_faces, masks, matrices
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """Paste processed face crops back into a copy of the full image.

    Each crop and its mask are warped back with the inverse of the matching
    alignment matrix. Wherever the warped mask equals 255, the warped crop
    replaces the original pixel; elsewhere the original is kept. Faces are
    pasted in list order, so a later face overwrites an earlier one where
    they overlap.

    Args:
        orig_img: The full image the faces were cut from (it is copied, not modified).
        processed_faces: One crop per face to paste back.
        masks: Crop-space masks, as produced by ``align_face`` (255 = valid).
        matrices: 3x3 forward alignment matrices, as produced by ``align_face``.

    Returns:
        The composited image as a new array.
    """
    canvas = orig_img.copy()
    height, width = orig_img.shape[:2]
    for idx, face in enumerate(processed_faces):
        # Invert the 3x3 forward transform; warpAffine only takes the top 2 rows.
        back = np.linalg.inv(matrices[idx])[0:2]
        face_full = cv2.warpAffine(face, back, (width, height), borderValue=0.0)
        mask_full = cv2.warpAffine(masks[idx], back, (width, height), borderValue=0.0)
        # For uint8 masks, // 255 keeps only fully covered pixels (255 -> 1);
        # partially interpolated border pixels become 0 and keep the original.
        keep = (mask_full // 255).astype(np.uint8)
        # canvas * (1 - keep) + face * keep, with the uint8 cast in between.
        canvas = (canvas * (1 - keep)).astype(np.uint8)
        canvas = canvas + face_full * keep
    return canvas
# ----------------------------
# 2. GENERATOR ARCHITECTURE
# ----------------------------
def squeeze_excite_block(x, ratio=4):
    """Squeeze-and-Excitation block: rescale each channel of ``x`` by a learned gate.

    The feature map is global-average-pooled to one value per channel,
    reshaped to ``(1, 1, channels)``, passed through a bottleneck Dense
    (``channels // ratio`` units, ReLU) and an expansion Dense (``channels``
    units, sigmoid), and the resulting weights are multiplied onto ``x``.

    NOTE(review): the channel axis honours ``K.image_data_format()`` but the
    gate is always reshaped channels-last; the generator input in this file
    is declared as (256, 256, 3), i.e. channels-last.
    """
    axis = -1 if K.image_data_format() != "channels_first" else 1
    n_channels = x.shape[axis]

    gate = GlobalAveragePooling2D()(x)
    gate = Reshape((1, 1, n_channels))(gate)
    gate = Dense(
        n_channels // ratio,
        activation="relu",
        kernel_initializer="he_normal",
        use_bias=False,
    )(gate)
    gate = Dense(
        n_channels,
        activation="sigmoid",
        kernel_initializer="he_normal",
        use_bias=False,
    )(gate)
    return multiply([x, gate])
def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """Encoder block with a stride-2 convolution.

    Conv2D(stride 2, 'same' padding) -> LeakyReLU(0.2)
    -> InstanceNormalization when ``bn`` -> squeeze-excite block when ``se``.
    """
    out = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    out = LeakyReLU(alpha=0.2)(out)
    out = InstanceNormalization()(out) if bn else out
    out = squeeze_excite_block(out) if se else out
    return out
def atrous(layer_input, filters, f_size=4, bn=True):
    """Multi-rate dilated convolution block.

    Three parallel Conv2D branches (dilation 2, 4 and 8, 'same' padding) are
    applied to the same input and concatenated on the last axis (giving
    ``3 * filters`` channels), followed by LeakyReLU(0.2) and, when ``bn``,
    InstanceNormalization.
    """
    branches = [
        Conv2D(filters, f_size, dilation_rate=dilation, padding="same")(layer_input)
        for dilation in (2, 4, 8)
    ]
    out = Concatenate()(branches)
    out = LeakyReLU(alpha=0.2)(out)
    return InstanceNormalization()(out) if bn else out
def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """Decoder block: upsample, convolve, normalise, then join the skip tensor.

    UpSampling2D(x2) -> Conv2D(stride 1, 'same', ReLU)
    -> Dropout when ``dropout_rate`` is truthy -> InstanceNormalization
    -> Concatenate with ``skip_input`` (skip features appended last).
    """
    x = UpSampling2D(size=2)(layer_input)
    x = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(x)
    x = Dropout(dropout_rate)(x) if dropout_rate else x
    x = InstanceNormalization()(x)
    merged = Concatenate()([x, skip_input])
    return merged
def build_generator():
    """Rebuild the generator graph and return it as a freshly initialised Keras Model.

    The graph must match the one the saved HDF5 weights belong to, so layers
    are created in exactly this order:

    * input of shape (256, 256, 3);
    * five ``conv2d`` stride-2 stages with 64, 128, 256, 512, 512 filters
      (no InstanceNorm on the first; squeeze-excite on the first three);
    * one ``atrous`` block with 512 filters per dilation branch;
    * four ``deconv2d`` stages with 512, 256, 128, 64 filters, each joined to
      the matching encoder output, deepest first;
    * a final x2 upsample and a 3-channel tanh Conv2D.
    """
    gf = 64
    inputs = Input(shape=(256, 256, 3))

    # (filters, instance-norm?, squeeze-excite?) for each encoder stage.
    encoder_spec = [
        (gf, False, True),
        (gf * 2, True, True),
        (gf * 4, True, True),
        (gf * 8, True, False),
        (gf * 8, True, False),
    ]
    skips = []
    x = inputs
    for filters, use_norm, use_se in encoder_spec:
        x = conv2d(x, filters, bn=use_norm, se=use_se)
        skips.append(x)

    x = atrous(x, gf * 8)

    # The deepest encoder output is not reused as a skip; the decoder pairs
    # with the 4th, 3rd, 2nd and 1st encoder outputs in that order.
    for skip, filters in zip(reversed(skips[:-1]), (gf * 8, gf * 4, gf * 2, gf)):
        x = deconv2d(x, skip, filters)

    x = UpSampling2D(size=2)(x)
    output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(x)
    return Model(inputs, output_img)
# ----------------------------
# 3. LOAD MODEL WEIGHTS
# ----------------------------
# Hugging Face Hub location of the generator weights.
HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Optional access token. .get() (instead of os.environ[...]) keeps a missing
# variable from raising KeyError at import time; hf_hub_download accepts
# token=None and then falls back to its default credentials.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
def load_generator_from_hub():
    """Fetch the generator weights from the Hugging Face Hub and return a ready model.

    ``hf_hub_download`` fetches ``HF_FILENAME`` from ``HF_REPO_ID`` (using
    ``HF_TOKEN``) into the local HF cache and returns its path. The
    architecture is then rebuilt with ``build_generator()`` and the HDF5
    weights are loaded into it. Download or weight-loading errors propagate
    to the caller.
    """
    weights_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        token=HF_TOKEN,
    )
    model = build_generator()
    model.load_weights(weights_path)
    return model
# Build the generator once at import time. On any failure the error is printed
# and GENERATOR is set to None, which makes inference() return its input unchanged.
try:
    GENERATOR = load_generator_from_hub()
except Exception as e:
    print("[ERROR] Could not load generator:", e)
    GENERATOR = None
else:
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
# ----------------------------
# 4. INFERENCE FUNCTION
# ----------------------------
def inference(image: Image.Image) -> Image.Image:
    """Gradio-compatible inference: run the generator on every detected face.

    Steps:
      1. Convert the PIL image to an RGB numpy array.
      2. Detect and align faces with ``align_face``.
      3. Scale all aligned crops to [-1, 1] and run them through the
         generator in one batched ``predict`` call.
      4. Map the tanh output back to uint8 [0, 255] (clipped and rounded).
      5. Paste the crops back with ``put_face_back`` and return a PIL image.

    The input is returned unchanged when the generator failed to load or no
    face is found. ``None`` (what Gradio passes for an empty input) is
    returned as-is instead of raising.
    """
    if image is None or GENERATOR is None:
        return image

    orig = np.array(image.convert("RGB"))
    faces, masks, mats = align_face(orig)
    if not faces:
        return image

    # All crops share the image_size shape, so they stack into (N, H, W, 3).
    # The generator's InstanceNormalization is per-sample, so batching does not
    # change per-face results.
    batch = np.stack(faces).astype(np.float32) / 127.5 - 1.0  # scale to [-1, 1]
    preds = GENERATOR.predict(batch, verbose=0)  # (N, H, W, 3) in [-1, 1]
    # Clip guards the uint8 cast against any overshoot (negative floats cast to
    # uint8 are undefined); rint avoids truncation's downward bias.
    processed = np.clip(np.rint((preds + 1.0) * 127.5), 0, 255).astype(np.uint8)

    output_np = put_face_back(orig, list(processed), masks, mats)
    return Image.fromarray(output_np)