import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image

from sketch_models import SimpleGenerator, UnetGenerator
from utils import common_input_validate, resize_image_with_pad, HWC3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint paths, resolved relative to the process working directory.
REALISTIC_CKPT = "models/sk_model.pth"
ANIME_CKPT = "models/netG.pth"


@functools.lru_cache(maxsize=2)
def _load_model(mode):
    """Build the generator for ``mode`` ("realistic" or "anime"), load its weights and cache it.

    Cached so checkpoints are read from disk and moved to ``device`` once per
    process rather than on every request. Only two modes exist, so the cache
    is bounded.

    Raises:
        ValueError: if ``mode`` has no associated network.
    """
    # NOTE(review): torch.load unpickles the file; only point these paths at trusted checkpoints.
    if mode == "realistic":
        model = SimpleGenerator(3, 1, 3).to(device)
        model.load_state_dict(torch.load(REALISTIC_CKPT, map_location=device))
    elif mode == "anime":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
        ckpt = torch.load(ANIME_CKPT, map_location=device)
        # Strip the "module." prefix (as left by DataParallel-style wrappers) in place,
        # keeping the original state-dict object and any metadata attached to it.
        for key in list(ckpt.keys()):
            if "module." in key:
                ckpt[key.replace("module.", "")] = ckpt[key]
                del ckpt[key]
        model.load_state_dict(ckpt)
    else:
        raise ValueError(f"no model for mode {mode!r}")
    model.eval()
    return model


def _sketch_realistic(detected_map, remove_pad, width, height):
    """Run the realistic line model on the padded ``detected_map`` and resize to (width, height)."""
    model = _load_model("realistic")
    with torch.no_grad():
        image = torch.from_numpy(detected_map).float().to(device) / 255.0
        image = rearrange(image, "h w c -> 1 c h w")
        line = model(image)[0][0].cpu().numpy()
    line = (line * 255.0).clip(0, 255).astype(np.uint8)
    sketch = remove_pad(HWC3(line))
    return cv2.resize(sketch, (width, height), interpolation=cv2.INTER_LANCZOS4)


def _sketch_anime(input_image, width, height):
    """Run the anime U-Net on the original image and resize the result to (width, height)."""
    # Round each side up to a multiple of 256; presumably required because the
    # 8-level UnetGenerator halves the resolution 8 times — confirm in sketch_models.
    padded_h = 256 * int(np.ceil(height / 256.0))
    padded_w = 256 * int(np.ceil(width / 256.0))
    resized = cv2.resize(input_image, (padded_w, padded_h), interpolation=cv2.INTER_LANCZOS4)
    model = _load_model("anime")
    with torch.no_grad():
        # Map pixels from [0, 255] to [-1, 1] and back around the model call.
        feed = torch.from_numpy(resized).float().to(device) / 127.5 - 1.0
        feed = rearrange(feed, "h w c -> 1 c h w")
        line = model(feed)[0, 0] * 127.5 + 127.5
        line = line.cpu().numpy().clip(0, 255).astype(np.uint8)
    return cv2.resize(HWC3(line), (width, height), interpolation=cv2.INTER_LANCZOS4)


def _sketch_standard(detected_map, remove_pad, width, height,
                     gaussian_sigma=6.0, intensity_threshold=8):
    """Model-free sketch: dark strokes where pixels are darker than their Gaussian-blurred neighbourhood."""
    x = detected_map.astype(np.float32)
    blurred = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
    intensity = np.min(blurred - x, axis=2).clip(0, 255)
    strong = intensity[intensity > intensity_threshold]
    # Normalise by the median strong response (at least 16). With no strong
    # pixels the median of an empty slice is NaN (plus a RuntimeWarning) and
    # max(16, nan) == 16, so fall back to 16 explicitly.
    scale = max(16, np.median(strong)) if strong.size else 16
    intensity /= scale
    intensity *= 127
    sketch = intensity.clip(0, 255).astype(np.uint8)
    # Invert so that strong edges become dark strokes on a light background.
    sketch = remove_pad(255 - sketch)
    return cv2.resize(HWC3(sketch), (width, height), interpolation=cv2.INTER_LANCZOS4)


def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil",
               upscale_method="INTER_LANCZOS4", **kwargs):
    """Turn an image into a line sketch at the image's original resolution.

    Args:
        input_image: image to process; accepted forms are whatever
            ``common_input_validate`` accepts (the Gradio UI passes a numpy array).
        mode: "realistic" or "anime" select a neural line model; any other
            value (including None) uses the model-free "standard" filter.
        detect_resolution: working resolution passed to ``resize_image_with_pad``
            for the realistic and standard modes. The anime mode ignores it and
            works on the original image rounded up to multiples of 256.
        output_type: "pil" returns a ``PIL.Image``; otherwise the numpy array
            is returned (after ``common_input_validate`` normalisation).
        upscale_method: interpolation name forwarded to ``resize_image_with_pad``.
        **kwargs: forwarded to ``common_input_validate``.

    Raises:
        ValueError: if the padded working image is not 3-dimensional.
    """
    input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
    detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
    H, W, _ = input_image.shape
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if detected_map.ndim != 3:
        raise ValueError(f"expected a 3-dimensional image, got ndim={detected_map.ndim}")

    if mode == 'realistic':
        detected_map = _sketch_realistic(detected_map, remove_pad, W, H)
    elif mode == 'anime':
        detected_map = _sketch_anime(input_image, W, H)
    else:  # standard
        detected_map = _sketch_standard(detected_map, remove_pad, W, H)

    if output_type == "pil":
        detected_map = Image.fromarray(detected_map)
    return detected_map


iface = gr.Interface(
    fn=get_sketch,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        # Explicit default matching get_sketch's own default; an unselected
        # radio would otherwise send None and silently fall through to "standard".
        gr.Radio(["anime", "realistic", "standard"], value="anime", label="Mode", info="Process methods"),
    ],
    outputs=gr.Image(type="numpy", label="Sketch Output"),
    title="Get a Sketch",
    description="Upload an image and get a simplified sketch",
)

if __name__ == "__main__":
    iface.launch()