# NOTE(review): removed non-Python web-scrape residue (Hugging Face Spaces page
# chrome: status lines, file size, commit hash, and a flattened line-number
# gutter) that preceded the actual source and would break the interpreter.
import functools
import cv2
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from einops import rearrange
from sketch_models import SimpleGenerator, UnetGenerator
from utils import common_input_validate, resize_image_with_pad, HWC3
# Run model inference on the GPU when CUDA is available, else fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil", upscale_method="INTER_LANCZOS4", **kwargs):
    """Convert an image into a line-art sketch.

    Args:
        input_image: Input image; validated/normalized by ``common_input_validate``.
        mode: ``'anime'`` (UNet generator), ``'realistic'`` (SimpleGenerator),
            or any other value for the model-free 'standard' blur-difference method.
        detect_resolution: Working resolution the image is resized/padded to
            before detection (used by the realistic and standard modes).
        output_type: ``"pil"`` to return a ``PIL.Image``; anything else returns
            a numpy uint8 array.
        upscale_method: Interpolation name forwarded to ``resize_image_with_pad``.
        **kwargs: Extra options forwarded to ``common_input_validate``.

    Returns:
        The sketch, resized back to the input's original H x W, as a
        ``PIL.Image`` or numpy uint8 HWC array depending on ``output_type``.
    """
    input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
    detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
    H, W, C = input_image.shape
    # The anime UNet has 8 down-sampling levels, so feed it dimensions that
    # are multiples of 256.
    Hn = 256 * int(np.ceil(float(H) / 256.0))
    Wn = 256 * int(np.ceil(float(W) / 256.0))
    assert detected_map.ndim == 3

    if mode == 'realistic':
        # NOTE(review): the checkpoint is re-loaded on every call; caching the
        # model at module level would speed up repeated requests.
        model = SimpleGenerator(3, 1, 3).to(device)
        model.load_state_dict(torch.load("models/sk_model.pth", map_location=device))
        model.eval()
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(device)
            image = image / 255.0  # scale to [0, 1]
            image = rearrange(image, 'h w c -> 1 c h w')
            line = model(image)[0][0]
            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = HWC3(line)
        detected_map = remove_pad(detected_map)
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)
    elif mode == 'anime':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
        ckpt = torch.load("models/netG.pth", map_location=device)
        # Strip DataParallel's 'module.' prefix so the keys match the bare model.
        for key in list(ckpt.keys()):
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt.pop(key)
        model.load_state_dict(ckpt)
        model.eval()
        # This branch works on the raw input resized to multiples of 256,
        # not on the padded detected_map.
        input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_LANCZOS4)
        with torch.no_grad():
            image_feed = torch.from_numpy(input_image).float().to(device)
            image_feed = image_feed / 127.5 - 1.0  # scale to [-1, 1]
            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
            line = model(image_feed)[0, 0] * 127.5 + 127.5  # back to [0, 255]
            line = line.cpu().numpy()
            line = line.clip(0, 255).astype(np.uint8)
        # BUGFIX: removed a dead `detected_map = remove_pad(255 - detected_map)`
        # whose result was immediately overwritten by the line below.
        detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_LANCZOS4)
    else:  # 'standard': model-free difference-of-Gaussian sketch
        gaussian_sigma = 6.0  # fixed typo: was 'guassian_sigma'
        intensity_threshold = 8
        x = detected_map.astype(np.float32)
        g = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
        # Darkest channel of (blurred - original): bright where edges are.
        intensity = np.min(g - x, axis=2).clip(0, 255)
        # Normalize by the median of the significant strokes, floored at 16.
        # Guard the empty case explicitly: the original fed an empty selection
        # to np.median, which raises a nan-median RuntimeWarning (the effective
        # divisor was still 16, so behavior is unchanged).
        strong = intensity[intensity > intensity_threshold]
        divisor = max(16, float(np.median(strong))) if strong.size else 16
        intensity /= divisor
        intensity *= 127
        detected_map = intensity.clip(0, 255).astype(np.uint8)
        detected_map = remove_pad(255 - detected_map)  # invert: dark lines on white
        detected_map = cv2.resize(HWC3(detected_map), (W, H), interpolation=cv2.INTER_LANCZOS4)

    if output_type == "pil":
        detected_map = Image.fromarray(detected_map)
    return detected_map
# Gradio UI: image upload plus a mode selector, wired to get_sketch.
# NOTE(review): detect_resolution/output_type are not exposed in the UI, so
# get_sketch's defaults (512, "pil") apply; the output component still renders
# the returned image.
iface = gr.Interface(
    fn=get_sketch,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Radio(["anime", "realistic", "standard"], label="Mode", info="Process methods"),
    ],
    outputs=gr.Image(type="numpy", label="Sketch Output"),
    title="Get a Sketch",
    description="Upload an image and get a simplified sketch"
)
# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()