import functools
import cv2
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from einops import rearrange
from sketch_models import SimpleGenerator, UnetGenerator
from utils import common_input_validate, resize_image_with_pad, HWC3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil", upscale_method="INTER_LANCZOS4", **kwargs):
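        """Convert an image into a line sketch.

        mode='anime' runs a UNet generator (models/netG.pth), mode='realistic'
        runs a lightweight generator (models/sk_model.pth), and any other value
        falls back to a model-free Gaussian blur-difference method. Returns a
        PIL image when output_type == "pil", otherwise a numpy array.
        """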
        
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        H, W, C = input_image.shape
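        # Round up to the nearest multiple of 256: the anime UNet downsamples
        # eight times, so its input size must be divisible by 2**8 = 256.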
        Hn = 256 * int(np.ceil(float(H) / 256.0))
        Wn = 256 * int(np.ceil(float(W) / 256.0))

        assert detected_map.ndim == 3

        if mode == 'realistic':
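            # Single-channel line-art generator; expects RGB scaled to [0, 1]
            # and returns a line map in [0, 1].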
            model = SimpleGenerator(3,1,3).to(device)
            model.load_state_dict(torch.load("models/sk_model.pth", map_location=device))
            model.eval()

            with torch.no_grad():
                image = torch.from_numpy(detected_map).float().to(device)
                image = image / 255.0
                image = rearrange(image, 'h w c -> 1 c h w')

                line = model(image)[0][0]
                line = line.cpu().numpy()
                line = (line * 255.0).clip(0, 255).astype(np.uint8)

            detected_map = HWC3(line)

            detected_map = remove_pad(detected_map)
            detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)

        elif mode == 'anime':
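            # UNet generator with 8 downsampling stages (hence the
            # multiple-of-256 input size); expects RGB normalized to [-1, 1].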
            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
            model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
            ckpt = torch.load("models/netG.pth", map_location=device)
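            # Checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip it so the keys match this plain model.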
            for key in list(ckpt.keys()):
                if 'module.' in key:
                    ckpt[key.replace('module.', '')] = ckpt[key]
                    del ckpt[key]
            model.load_state_dict(ckpt)
            model.eval()

            input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_LANCZOS4)

            with torch.no_grad():
                image_feed = torch.from_numpy(input_image).float().to(device)
                image_feed = image_feed / 127.5 - 1.0
                image_feed = rearrange(image_feed, 'h w c -> 1 c h w')

                line = model(image_feed)[0, 0] * 127.5 + 127.5
                line = line.cpu().numpy()
                line = line.clip(0, 255).astype(np.uint8)
            
            # `line` was generated from the resized original image, so there is
            # no padding to remove, and the model already outputs dark lines on
            # a white background. A1111 uses INTER_AREA for downscaling, so use
            # it here as well.
            detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_AREA)
        
        else: # standard
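            # Model-free method: pixels darker than their Gaussian-blurred
            # neighborhood are treated as line intensity, normalized by the
            # median of the above-threshold responses.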
            gaussian_sigma = 6.0
            intensity_threshold = 8
            x = detected_map.astype(np.float32)
            g = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
            intensity = np.min(g - x, axis=2).clip(0, 255)
            intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
            intensity *= 127
            detected_map = intensity.clip(0, 255).astype(np.uint8)
            
            detected_map = remove_pad(255 - detected_map)
            detected_map = cv2.resize(HWC3(detected_map), (W, H), interpolation=cv2.INTER_LANCZOS4)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        
        return detected_map

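# Minimal Gradio UI: only the image and mode are exposed; detect_resolution
# and output_type keep their defaults.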
iface = gr.Interface(
    fn=get_sketch,
    inputs=[
         gr.Image(type="numpy", label="Upload Image"),
         gr.Radio(["anime", "realistic", "standard"], value="anime", label="Mode", info="Processing method"),
    ],
    outputs=gr.Image(type="numpy", label="Sketch Output"),
    title="Get a Sketch",
    description="Upload an image and get a simplified sketch."
)

if __name__ == "__main__":
    iface.launch()