Robys01 committed on
Commit 0fd4d4e · 1 Parent(s): bef7a72

Added face aging model and interface with Gradio

Files changed (7)
  1. .gitignore +5 -0
  2. app.py +36 -4
  3. assets/mask1024.jpg +0 -0
  4. assets/mask512.jpg +0 -0
  5. models.py +98 -0
  6. requirements.txt +66 -0
  7. test_functions.py +100 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ # Virtual Environment
+ venv
+
+ __pycache__
+
app.py CHANGED
@@ -1,7 +1,39 @@
+ import os
+ import torch
+ from models import UNet
+ from test_functions import process_image
+ from PIL import Image
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from huggingface_hub import hf_hub_download
+
+ MODEL_PATH = hf_hub_download(repo_id="Robys01/face-aging", filename="best_unet_model.pth")
+ print(f"Model downloaded to {MODEL_PATH}")
+
+ model = UNet()
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False))
+ model.eval()
+
+ def age_image(image: Image.Image, source_age: int, target_age: int) -> Image.Image:
+     # Ensure the image is in RGB or grayscale; if not, convert it.
+     if image.mode not in ["RGB", "L"]:
+         print(f"Converting image from {image.mode} to RGB")
+         image = image.convert("RGB")
+
+     processed_image = process_image(model, image, source_age, target_age)
+     return processed_image
+
+ iface = gr.Interface(
+     fn=age_image,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),
+         gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose the current age"),
+         gr.Slider(10, 90, value=70, step=1, label="Target age", info="Choose the desired age")
+     ],
+     outputs=gr.Image(type="pil", label="Aged Image"),
+     title="Face Aging Demo",
+     description="Upload an image along with a source age approximation and a target age to generate an aged version of the face."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
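For reference, a minimal sketch of calling this interface programmatically once app.py is running. The local URL, the `/predict` endpoint name, and the argument order are assumptions inferred from the `gr.Interface` definition above, not part of this commit:

```python
# Hypothetical client-side call; assumes the app runs locally on Gradio's
# default port and exposes the auto-generated /predict endpoint.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
result = client.predict(
    handle_file("face.jpg"),  # Input Image (placeholder filename)
    25,                       # Current age
    70,                       # Target age
    api_name="/predict",
)
print(result)  # local path of the returned aged image
```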
assets/mask1024.jpg ADDED
assets/mask512.jpg ADDED
models.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ import torch.nn as nn
+ import antialiased_cnns
+
+ class DownLayer(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DownLayer, self).__init__()
+         self.layer = nn.Sequential(
+             nn.MaxPool2d(kernel_size=2, stride=1),
+             antialiased_cnns.BlurPool(in_channels, stride=2),
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.LeakyReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.LeakyReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.layer(x)
+
+
+ class UpLayer(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(UpLayer, self).__init__()
+         # Conv transpose upsampling
+
+         self.blur_upsample = nn.Sequential(
+             nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+             antialiased_cnns.BlurPool(out_channels, stride=1)
+         )
+
+         self.layer = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.LeakyReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.LeakyReLU(inplace=True)
+         )
+
+     def forward(self, x, skip):
+         x = self.blur_upsample(x)
+         x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
+         return self.layer(x)
+
+
+ class UNet(nn.Module):
+     def __init__(self):
+         super(UNet, self).__init__()
+         self.init_conv = nn.Sequential(
+             nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+             nn.LeakyReLU(inplace=True),
+             nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+             nn.LeakyReLU(inplace=True)
+         )
+
+         self.down1 = DownLayer(64, 128)    # output: 256 x 256 x 128
+         self.down2 = DownLayer(128, 256)   # output: 128 x 128 x 256
+         self.down3 = DownLayer(256, 512)   # output: 64 x 64 x 512
+         self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
+         self.up1 = UpLayer(1024, 512)      # output: 64 x 64 x 512
+         self.up2 = UpLayer(512, 256)       # output: 128 x 128 x 256
+         self.up3 = UpLayer(256, 128)       # output: 256 x 256 x 128
+         self.up4 = UpLayer(128, 64)        # output: 512 x 512 x 64
+         self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
+
+     def forward(self, x):
+         x0 = self.init_conv(x)
+         x1 = self.down1(x0)
+         x2 = self.down2(x1)
+         x3 = self.down3(x2)
+         x4 = self.down4(x3)
+         x = self.up1(x4, x3)
+         x = self.up2(x, x2)
+         x = self.up3(x, x1)
+         x = self.up4(x, x0)
+         x = self.final_conv(x)
+         return x
+
+
+ class PatchGANDiscriminator(nn.Module):
+     def __init__(self, input_channels=3):
+         super(PatchGANDiscriminator, self).__init__()
+         self.model = nn.Sequential(
+             nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+             # Output layer with 1 channel for binary classification
+         )
+
+     def forward(self, x):
+         return self.model(x)
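As a quick sanity check of the generator above, a minimal sketch of one forward pass; the shapes follow the comments in `UNet` (the 5 input channels are RGB plus the source-age and target-age planes built in `test_functions.py`):

```python
# Shape smoke test for the UNet generator; not part of this commit.
import torch
from models import UNet

model = UNet().eval()
x = torch.randn(1, 5, 512, 512)  # RGB + source-age plane + target-age plane
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 3, 512, 512])
```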
requirements.txt ADDED
@@ -0,0 +1,66 @@
+ -f https://download.pytorch.org/whl/cpu/torch_stable.html
+ torch==2.3.1+cpu
+ torchvision==0.18.1+cpu
+ aiofiles==23.2.1
+ annotated-types==0.7.0
+ antialiased-cnns==0.3
+ anyio==4.8.0
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ dlib==19.24.6
+ face-recognition==1.3.0
+ face_recognition_models==0.3.0
+ fastapi==0.115.8
+ ffmpy==0.5.0
+ filelock==3.17.0
+ fsspec==2025.2.0
+ gradio==5.15.0
+ gradio_client==1.7.0
+ h11==0.14.0
+ httpcore==1.0.7
+ httptools==0.6.4
+ httpx==0.28.1
+ huggingface-hub==0.28.1
+ idna==3.10
+ Jinja2==3.1.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.4.2
+ numpy==2.2.2
+ orjson==3.10.15
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.1.0
+ pydantic==2.10.6
+ pydantic_core==2.27.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.20
+ pytz==2025.1
+ PyYAML==6.0.2
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.9.6
+ safehttpx==0.1.6
+ semantic-version==2.10.0
+ setuptools==75.8.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.45.3
+ sympy==1.13.1
+ tomlkit==0.13.2
+ tqdm==4.67.1
+ typer==0.15.1
+ typing_extensions==4.12.2
+ tzdata==2025.1
+ urllib3==2.3.0
+ uvicorn==0.34.0
+ uvloop==0.21.0
+ watchfiles==1.0.4
+ websockets==14.2
test_functions.py ADDED
@@ -0,0 +1,100 @@
+ import face_recognition
+ import numpy as np
+ import torch
+ from torch.autograd import Variable
+ from torchvision import transforms
+ from PIL import Image
+
+ mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
+ small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
+
+ def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
+     """
+     Apply the aging operation on the input tensor using a sliding-window method.
+     Runs on the model's device (GPU, if available).
+     """
+
+     input_tensor = input_tensor.to(next(your_model.parameters()).device)
+     mask = mask.to(next(your_model.parameters()).device)
+     small_mask = small_mask.to(next(your_model.parameters()).device)
+
+     n, c, h, w = input_tensor.size()
+     output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
+
+     count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
+
+     add = 2 if window_size % stride != 0 else 1
+
+     for y in range(0, h - window_size + add, stride):
+         for x in range(0, w - window_size + add, stride):
+             window = input_tensor[:, :, y:y + window_size, x:x + window_size]
+
+             # Apply the same preprocessing as during training
+             input_variable = Variable(window, requires_grad=False)  # Variable is a legacy no-op wrapper in modern PyTorch
+
+             # Forward pass
+             with torch.no_grad():
+                 output = your_model(input_variable)
+
+             output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
+             count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
+
+     count_tensor = torch.clamp(count_tensor, min=1.0)
+
+     # Average the overlapping regions
+     output_tensor /= count_tensor
+
+     # Apply mask
+     output_tensor *= mask
+
+     return output_tensor.cpu()
+
+
+ def process_image(your_model, image, source_age, target_age=0,
+                   window_size=512, stride=256, steps=18):
+
+     input_size = (1024, 1024)
+
+     # image = face_recognition.load_image_file(filename)
+     image = np.array(image)
+
+     fl = face_recognition.face_locations(image)[0]
+
+     # calculate margins
+     margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger, as the forehead is often cut off
+     margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
+     margin_x = int((fl[1] - fl[3]) // (2 / .85))
+     margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure the crop stays square
+
+     l_y = max([fl[0] - margin_y_t, 0])
+     r_y = min([fl[2] + margin_y_b, image.shape[0]])
+     l_x = max([fl[3] - margin_x, 0])
+     r_x = min([fl[1] + margin_x, image.shape[1]])
+
+     # crop image
+     cropped_image = image[l_y:r_y, l_x:r_x, :]
+
+     # Resizing
+     orig_size = cropped_image.shape[:2]
+
+     cropped_image = transforms.ToTensor()(cropped_image)
+
+     cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)
+
+     source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
+     target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
+     input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
+
+     image = transforms.ToTensor()(image)
+
+     # run the aging model over the crop with a sliding window
+     aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+     # resize back to original size
+     aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
+         aged_cropped_image)
+
+     # add the masked aging delta back onto the original image region
+     image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
+     image = torch.clamp(image, 0, 1)
+
+     return transforms.functional.to_pil_image(image)
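Taken together, a minimal offline sketch of the full pipeline; the weight and image filenames are placeholders (in `app.py` the weights come from `hf_hub_download`), and the script must run from the repo root so that `assets/mask1024.jpg` and `assets/mask512.jpg` resolve:

```python
# Hypothetical end-to-end test; filenames are placeholders.
import torch
from PIL import Image
from models import UNet
from test_functions import process_image

model = UNet()
model.load_state_dict(torch.load("best_unet_model.pth", map_location="cpu"))
model.eval()

# With the defaults window_size=512 and stride=256 on the 1024x1024 crop,
# the sliding window visits offsets 0, 256 and 512 along each axis.
aged = process_image(model, Image.open("face.jpg"), source_age=25, target_age=70)
aged.save("face_aged.jpg")
```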