Added face aging model and interface with Gradio
Files changed:
- .gitignore +5 -0
- app.py +36 -4
- assets/mask1024.jpg +0 -0
- assets/mask512.jpg +0 -0
- models.py +98 -0
- requirements.txt +66 -0
- test_functions.py +100 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+# Virtual Environment
+venv
+
+__pycache__
+
app.py CHANGED
@@ -1,7 +1,39 @@
+import os
+import torch
+from models import UNet
+from test_functions import process_image
+from PIL import Image
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
+from huggingface_hub import hf_hub_download
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+MODEL_PATH = hf_hub_download(repo_id="Robys01/face-aging", filename="best_unet_model.pth")
+print(f"Model downloaded to {MODEL_PATH}")
+
+model = UNet()
+model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False))
+model.eval()
+
+def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
+    # Ensure the image is in RGB or grayscale; if not, convert it.
+    if image.mode not in ["RGB", "L"]:
+        print(f"Converting image from {image.mode} to RGB")
+        image = image.convert("RGB")
+
+    processed_image = process_image(model, image, source_age, target_age)
+    return processed_image
+
+iface = gr.Interface(
+    fn=age_image,
+    inputs=[
+        gr.Image(type="pil", label="Input Image"),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
+        gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age")
+    ],
+    outputs=gr.Image(type="pil", label="Aged Image"),
+    title="Face Aging Demo",
+    description="Upload an image along with a source age approximation and a target age to generate an aged version of the face."
+)
+
+if __name__ == "__main__":
+    iface.launch()
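For readers wiring this Space into other code: a minimal sketch of calling the interface remotely with gradio_client (pinned in requirements.txt below). The api_name and keyword names are assumptions inferred from the age_image signature above, not confirmed by the commit.

```python
# Hedged sketch: assumes the Space is live as "Robys01/face-aging" and that
# the endpoint keeps Gradio's default api_name for an Interface ("/predict").
from gradio_client import Client, handle_file

client = Client("Robys01/face-aging")
result = client.predict(
    image=handle_file("face.jpg"),  # hypothetical local input image
    source_age=25,
    target_age=70,
    api_name="/predict",
)
print(result)  # path to the generated (aged) image file
```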
assets/mask1024.jpg ADDED (binary image)
assets/mask512.jpg ADDED (binary image)
models.py ADDED
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import antialiased_cnns
+
+class DownLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DownLayer, self).__init__()
+        self.layer = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=1),
+            antialiased_cnns.BlurPool(in_channels, stride=2),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class UpLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UpLayer, self).__init__()
+        # Conv transpose upsampling
+
+        self.blur_upsample = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+            antialiased_cnns.BlurPool(out_channels, stride=1)
+        )
+
+        self.layer = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x, skip):
+        x = self.blur_upsample(x)
+        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
+        return self.layer(x)
+
+
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.init_conv = nn.Sequential(
+            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True)
+        )
+
+        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
+        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
+        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
+        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
+        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
+        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
+        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
+        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
+        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
+
+    def forward(self, x):
+        x0 = self.init_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        x4 = self.down4(x3)
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        x = self.up4(x, x0)
+        x = self.final_conv(x)
+        return x
+
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, input_channels=3):
+        super(PatchGANDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+            # Output layer with 1 channel for binary classification
+        )
+
+    def forward(self, x):
+        return self.model(x)
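A quick way to sanity-check the architecture above: the UNet takes a 5-channel 512 x 512 input (RGB plus one constant channel each for source and target age, as test_functions.py builds it) and returns a 3-channel image of the same size. A minimal sketch, assuming torch and antialiased-cnns are installed as pinned in requirements.txt:

```python
# Shape check for the UNet defined above; a dummy batch stands in for a real face crop.
import torch
from models import UNet

model = UNet()
model.eval()

# 1 sample, 5 channels (RGB + source-age + target-age), 512 x 512
dummy = torch.randn(1, 5, 512, 512)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 3, 512, 512])
```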
requirements.txt ADDED
@@ -0,0 +1,66 @@
+-f https://download.pytorch.org/whl/cpu/torch_stable.html
+torch==2.3.1+cpu
+torchvision==0.18.1+cpu
+aiofiles==23.2.1
+annotated-types==0.7.0
+antialiased-cnns==0.3
+anyio==4.8.0
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+dlib==19.24.6
+face-recognition==1.3.0
+face_recognition_models==0.3.0
+fastapi==0.115.8
+ffmpy==0.5.0
+filelock==3.17.0
+fsspec==2025.2.0
+gradio==5.15.0
+gradio_client==1.7.0
+h11==0.14.0
+httpcore==1.0.7
+httptools==0.6.4
+httpx==0.28.1
+huggingface-hub==0.28.1
+idna==3.10
+Jinja2==3.1.5
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.2
+orjson==3.10.15
+packaging==24.2
+pandas==2.2.3
+pillow==11.1.0
+pydantic==2.10.6
+pydantic_core==2.27.2
+pydub==0.25.1
+Pygments==2.19.1
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.20
+pytz==2025.1
+PyYAML==6.0.2
+requests==2.32.3
+rich==13.9.4
+ruff==0.9.6
+safehttpx==0.1.6
+semantic-version==2.10.0
+setuptools==75.8.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.45.3
+sympy==1.13.1
+tomlkit==0.13.2
+tqdm==4.67.1
+typer==0.15.1
+typing_extensions==4.12.2
+tzdata==2025.1
+urllib3==2.3.0
+uvicorn==0.34.0
+uvloop==0.21.0
+watchfiles==1.0.4
+websockets==14.2
test_functions.py ADDED
@@ -0,0 +1,100 @@
+import face_recognition
+import numpy as np
+import torch
+from torch.autograd import Variable
+from torchvision import transforms
+from PIL import Image
+
+mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
+small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
+
+def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
+    """
+    Apply the aging operation to the input tensor using a sliding-window method.
+    Runs on the model's device (GPU, if available).
+    """
+
+    input_tensor = input_tensor.to(next(your_model.parameters()).device)
+    mask = mask.to(next(your_model.parameters()).device)
+    small_mask = small_mask.to(next(your_model.parameters()).device)
+
+    n, c, h, w = input_tensor.size()
+    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
+
+    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
+
+    add = 2 if window_size % stride != 0 else 1
+
+    for y in range(0, h - window_size + add, stride):
+        for x in range(0, w - window_size + add, stride):
+            window = input_tensor[:, :, y:y + window_size, x:x + window_size]
+
+            # Apply the same preprocessing as during training
+            input_variable = Variable(window, requires_grad=False)  # no gradients needed at inference
+
+            # Forward pass
+            with torch.no_grad():
+                output = your_model(input_variable)
+
+            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
+            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
+
+    count_tensor = torch.clamp(count_tensor, min=1.0)
+
+    # Average the overlapping regions
+    output_tensor /= count_tensor
+
+    # Apply mask
+    output_tensor *= mask
+
+    return output_tensor.cpu()
+
+
+def process_image(your_model, image, source_age, target_age=0,
+                  window_size=512, stride=256, steps=18):
+
+    input_size = (1024, 1024)
+
+    # image = face_recognition.load_image_file(filename)
+    image = np.array(image)
+
+    fl = face_recognition.face_locations(image)[0]
+
+    # calculate margins
+    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
+    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
+    margin_x = int((fl[1] - fl[3]) // (2 / .85))
+    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure square is preserved
+
+    l_y = max([fl[0] - margin_y_t, 0])
+    r_y = min([fl[2] + margin_y_b, image.shape[0]])
+    l_x = max([fl[3] - margin_x, 0])
+    r_x = min([fl[1] + margin_x, image.shape[1]])
+
+    # crop image
+    cropped_image = image[l_y:r_y, l_x:r_x, :]
+
+    # Resizing
+    orig_size = cropped_image.shape[:2]
+
+    cropped_image = transforms.ToTensor()(cropped_image)
+
+    cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)
+
+    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
+    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
+    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
+
+    image = transforms.ToTensor()(image)
+
+    # run the aging model over the cropped face
+    aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+    # resize back to original size
+    aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
+        aged_cropped_image)
+
+    # re-apply the aged crop onto the original image
+    image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
+    image = torch.clamp(image, 0, 1)
+
+    return transforms.functional.to_pil_image(image)
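For completeness, a hedged end-to-end sketch that mirrors app.py's age_image without the Gradio UI. The local paths (best_unet_model.pth, input.jpg) are hypothetical stand-ins; note that process_image raises IndexError if face_recognition detects no face in the input.

```python
# End-to-end sketch (assumed paths): load the checkpoint, age one image, save it.
import torch
from PIL import Image
from models import UNet
from test_functions import process_image

model = UNet()
model.load_state_dict(torch.load("best_unet_model.pth", map_location="cpu", weights_only=False))
model.eval()

image = Image.open("input.jpg").convert("RGB")
aged = process_image(model, image, source_age=25, target_age=70)
aged.save("aged.jpg")
```

With the default window_size=512 and stride=256 on the 1024 x 1024 working resolution, sliding_window_tensor visits a 3 x 3 grid of overlapping windows (origins 0, 256, 512 on each axis) and averages the masked outputs where they overlap.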