import functools

import cv2
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from einops import rearrange

from sketch_models import SimpleGenerator, UnetGenerator
from utils import common_input_validate, resize_image_with_pad, HWC3
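
# Run inference on the GPU when available, otherwise fall back to the CPU.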
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil", upscale_method="INTER_LANCZOS4", **kwargs):
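    """Convert an RGB image into a sketch/line-art map.

    mode selects the extractor: 'realistic' and 'anime' run learned generators,
    'standard' uses a blur-difference heuristic. Returns a PIL image by default,
    or a numpy array when output_type != "pil".
    """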
    input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
    detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
    H, W, C = input_image.shape
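    # Next multiples of 256, used only by the 'anime' branch: its U-Net
    # downsamples 8 times, so input dimensions must be divisible by 256.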
    Hn = 256 * int(np.ceil(float(H) / 256.0))
    Wn = 256 * int(np.ceil(float(W) / 256.0))
    assert detected_map.ndim == 3

    if mode == 'realistic':
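        # Learned photo-to-sketch generator: takes RGB scaled to [0, 1] and
        # outputs a single-channel line map.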
        model = SimpleGenerator(3, 1, 3).to(device)
        model.load_state_dict(torch.load("models/sk_model.pth", map_location=device))
        model.eval()
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            line = model(image)[0][0]
            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = HWC3(line)
        detected_map = remove_pad(detected_map)
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)
    elif mode == 'anime':
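        # Anime line-art U-Net: takes RGB scaled to [-1, 1] at dimensions
        # rounded up to multiples of 256, outputs a single-channel line map.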
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
        ckpt = torch.load("models/netG.pth", map_location=device)
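        # Strip the 'module.' prefix that DataParallel adds to checkpoint keys.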
        for key in list(ckpt.keys()):
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt[key]
                del ckpt[key]
        model.load_state_dict(ckpt)
        model.eval()
        input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_LANCZOS4)
        with torch.no_grad():
            image_feed = torch.from_numpy(input_image).float().to(device)
            image_feed = image_feed / 127.5 - 1.0
            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
            line = model(image_feed)[0, 0] * 127.5 + 127.5
            line = line.cpu().numpy()
            line = line.clip(0, 255).astype(np.uint8)
        # A1111 uses INTER_AREA for downscaling; LANCZOS4 is kept here to match the other modes.
        detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_LANCZOS4)
    else:  # standard
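        # Model-free fallback: edge intensity is the per-pixel difference between
        # the image and a Gaussian-blurred copy, rescaled by its median response.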
        gaussian_sigma = 6.0
        intensity_threshold = 8
        x = detected_map.astype(np.float32)
        g = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
        intensity = np.min(g - x, axis=2).clip(0, 255)
        intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
        intensity *= 127
        detected_map = intensity.clip(0, 255).astype(np.uint8)
        detected_map = remove_pad(255 - detected_map)
        detected_map = cv2.resize(HWC3(detected_map), (W, H), interpolation=cv2.INTER_LANCZOS4)

    if output_type == "pil":
        detected_map = Image.fromarray(detected_map)
    return detected_map
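
# Example programmatic use (assumes an RGB image file "input.jpg" in the working directory):
#   sketch = get_sketch(np.array(Image.open("input.jpg").convert("RGB")), mode="standard")
#   sketch.save("sketch.png")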

iface = gr.Interface(
    fn=get_sketch,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Radio(["anime", "realistic", "standard"], value="anime", label="Mode", info="Processing method"),
    ],
    outputs=gr.Image(type="numpy", label="Sketch Output"),
    title="Get a Sketch",
    description="Upload an image and get a simplified sketch.",
)

if __name__ == "__main__":
    iface.launch()