File size: 4,219 Bytes
6d314be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from hair_swap import get_parser, HairFast
from scripts.pp_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files


class ImageException(Exception):
    """Exception that carries an image object along with its message.

    The constructor stores the given image on ``image`` and the text on
    ``message``; the text is also passed to ``Exception`` so that ``str(exc)``
    yields it.
    """

    def __init__(self, image, message="Return image before PP"):
        super().__init__(message)
        self.message = message
        self.image = image


def hairfast_wo_pp(hair_fast):
    """Patch ``hair_fast`` in place so blending hands back the image before PP.

    ``hair_fast.blend.downsample_256`` is replaced by a module whose forward
    pass computes ``(image[0] + 1) / 2`` clipped to ``[0, 1]`` and raises the
    result inside an ``ImageException``. ``hair_fast.blend.blend_images`` is
    wrapped so that this exception is caught and the carried image becomes the
    return value.

    Args:
        hair_fast: object exposing ``blend.downsample_256`` and
            ``blend.blend_images``; both attributes are overwritten.

    Returns:
        None. ``hair_fast`` is modified in place.
    """
    from functools import wraps

    class RaiseDownsample(nn.Module):
        """Replacement downsampling step that aborts by raising the image."""

        def forward(self, image):
            # Only the first element of the input is kept and rescaled.
            image = ((image[0] + 1) / 2).clip(0, 1)
            raise ImageException(image)

    def blend_images(func):
        """Wrap ``func`` so an ``ImageException`` is turned into its image."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # If the exception never fires, pass the real result through
                # instead of silently returning None.
                return func(*args, **kwargs)
            except ImageException as e:
                return e.image

        return wrapper

    hair_fast.blend.downsample_256 = RaiseDownsample()
    hair_fast.blend.blend_images = blend_images(hair_fast.blend.blend_images)


def load_image(path):
    """Open the image at ``path`` with PIL and return ``to_tensor`` of it."""
    pil_image = Image.open(path)
    return T.functional.to_tensor(pil_image)


@torch.no_grad()
def load_dataset_images(imgs, dataset_path, *, source_path=None, batch_size=64):
    """Build the dataset records for one batch of generated images.

    For every entry of ``imgs`` the source image (``entry[0]`` joined onto
    ``source_path``) and the generated image (``entry[1]`` joined onto
    ``dataset_path``) are loaded. Both are passed through
    ``Trainer.generate_mask`` in mini-batches on ``'cuda'`` and one record per
    entry is collected.

    Args:
        imgs: sequence of entries whose item 0 is the source file name and
            whose item 1 is the generated file name.
        dataset_path: directory that holds the generated images.
        source_path: directory that holds the source images. When ``None``
            the module-level ``args.FFHQ`` is used; that name only exists
            when this file is executed as a script.
        batch_size: number of image pairs processed per mini-batch.

    Returns:
        A list of tuples ``(source_file, downsampled_target, target_mask,
        HT_E)``; the three tensor items have been moved to the CPU.
    """
    if source_path is None:
        # Backward-compatible fallback to the CLI namespace of this script.
        source_path = args.FFHQ

    net_trainer = Trainer()

    # Resolve every path once; the source paths are also stored in the records.
    source_files = [os.path.join(source_path, img[0]) for img in imgs]
    target_files = [os.path.join(dataset_path, img[1]) for img in imgs]

    source_images = Parallel(n_jobs=-1)(delayed(load_image)(path) for path in tqdm(source_files))
    target_images = Parallel(n_jobs=-1)(delayed(load_image)(path) for path in tqdm(target_files))
    tensors_dataloader = DataLoader(list(zip(source_images, target_images)), batch_size=batch_size,
                                    pin_memory=False, shuffle=False, drop_last=False)

    data = []
    total = 0  # number of pairs already consumed; the loader is not shuffled
    for batch in tqdm(tensors_dataloader):
        source, target = [elem.to('cuda') for elem in batch]

        HS_D, _ = net_trainer.generate_mask(source)
        HT_D, HT_E = net_trainer.generate_mask(target)
        # Product of the complements of the first mask output of each image.
        target_mask = (1 - HS_D) * (1 - HT_D)

        count = len(source)
        data.extend(zip(
            source_files[total:total + count],
            net_trainer.downsample_256(target).clip(0, 1).cpu(),
            target_mask.cpu(),
            HT_E.cpu(),
        ))
        total += count

    return data


def main(args):
    """Generate the dataset parts and save them under ``args.output``.

    Seeds the run, builds a ``HairFast`` model patched by ``hairfast_wo_pp``,
    draws ``args.size`` random (face, shape, color) triples from the images
    listed in ``args.FFHQ`` and processes them in chunks of 5,000. Each chunk
    is rendered into a temporary directory, converted with
    ``load_dataset_images`` and stored as ``pp_part_<n>.dataset``.
    """
    seed_everything(args.seed)

    # Model with default arguments, patched to return the image before PP.
    hair_fast = HairFast(get_parser().parse_args([]))
    hairfast_wo_pp(hair_fast)

    # Sample the image triples.
    os.makedirs(args.output, exist_ok=True)
    images = list_image_files(args.FFHQ)
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    exps = []
    for triple in zip(face, shape, color):
        # Output name is made of the text before the first dot of each name.
        stems = [name.split('.')[0] for name in triple]
        exps.append([triple[0], f"{'_'.join(stems)}.png", triple])

    chunk_size = 5_000
    for part, start in enumerate(range(0, len(exps), chunk_size), start=1):
        chunk = exps[start:start + chunk_size]
        with tempfile.TemporaryDirectory() as temp_dir:
            for _, file_name, (im1, im2, im3) in tqdm(chunk):
                image = hair_fast(args.FFHQ / im1, args.FFHQ / im2, args.FFHQ / im3)
                save_image(image, os.path.join(temp_dir, file_name))

            # Must run while the temporary images still exist.
            batch_data = load_dataset_images(chunk, temp_dir)
            torch.save(batch_data, os.path.join(args.output, f'pp_part_{part}.dataset'))


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the generation.
    parser = argparse.ArgumentParser(description='Blending dataset')
    # Required: without it args.FFHQ would be None and the run would only fail
    # later, inside the image listing / path joins, with an unclear error.
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='directory with the source images')
    parser.add_argument('--seed', type=int, default=3407,
                        help='random seed')
    parser.add_argument('--size', type=int, default=10_000,
                        help='number of image triples to generate')
    parser.add_argument('--output', type=Path, default='input/pp_dataset',
                        help='directory where the dataset parts are saved')
    args = parser.parse_args()

    main(args)