# img2sketch / app.py
# kairusann
# first working version
# 6ee3369
import functools
import cv2
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from einops import rearrange
from sketch_models import SimpleGenerator, UnetGenerator
from utils import common_input_validate, resize_image_with_pad, HWC3
# Device used for all model inference: CUDA when PyTorch reports it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@functools.lru_cache(maxsize=None)
def _load_realistic_model():
    """Build the 'realistic' SimpleGenerator, load models/sk_model.pth, and cache it for reuse."""
    model = SimpleGenerator(3, 1, 3).to(device)
    # Path is relative to the current working directory.
    model.load_state_dict(torch.load("models/sk_model.pth", map_location=device))
    model.eval()
    return model


@functools.lru_cache(maxsize=None)
def _load_anime_model():
    """Build the 'anime' UnetGenerator, load models/netG.pth, and cache it for reuse."""
    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
    ckpt = torch.load("models/netG.pth", map_location=device)
    # Strip a 'module.' prefix from checkpoint keys (as left by a wrapped/DataParallel
    # model) so they match the bare UnetGenerator's parameter names.
    for key in list(ckpt.keys()):
        if 'module.' in key:
            ckpt[key.replace('module.', '')] = ckpt[key]
            del ckpt[key]
    model.load_state_dict(ckpt)
    model.eval()
    return model


def _standard_lineart(image, gaussian_sigma=6.0, intensity_threshold=8):
    """Return a uint8 stroke-intensity map (bright = stroke) for an H x W x C image.

    A pixel scores > 0 only when it is darker than its Gaussian-blurred neighbourhood
    in every channel; scores are normalised by the median of the clearly visible
    strokes (values above ``intensity_threshold``), with a floor of 16.
    """
    x = image.astype(np.float32)
    blurred = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
    intensity = np.min(blurred - x, axis=2).clip(0, 255)
    strong = intensity[intensity > intensity_threshold]
    # With no strong pixels np.median([]) would warn and return nan, which max() then
    # ignored (yielding 16); use 16 directly instead.
    scale = max(16, np.median(strong)) if strong.size else 16
    intensity /= scale
    intensity *= 127
    return intensity.clip(0, 255).astype(np.uint8)


def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil", upscale_method="INTER_LANCZOS4", **kwargs):
    """Convert an image into a line sketch.

    Args:
        input_image: image accepted by ``common_input_validate`` (the Gradio UI
            passes a numpy array).
        mode: 'realistic' uses SimpleGenerator (models/sk_model.pth); 'anime' uses
            UnetGenerator (models/netG.pth); any other value, including None,
            uses the model-free 'standard' blur-difference method.
        detect_resolution: resolution passed to ``resize_image_with_pad`` for the
            'realistic' and 'standard' modes. 'anime' mode does not use it.
        output_type: "pil" returns a ``PIL.Image``; otherwise the numpy array is
            returned (value as normalised by ``common_input_validate``).
        upscale_method: interpolation name passed to ``resize_image_with_pad``.
        **kwargs: forwarded to ``common_input_validate``.

    Returns:
        The sketch resized to the input's original width/height, as a PIL image
        or numpy array depending on ``output_type``.
    """
    input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
    detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
    H, W, C = input_image.shape
    assert detected_map.ndim == 3

    if mode == 'realistic':
        model = _load_realistic_model()
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            line = model(image)[0][0]
            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = HWC3(line)
        detected_map = remove_pad(detected_map)
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)
    elif mode == 'anime':
        model = _load_anime_model()
        # NOTE: this mode runs on the original image, not the padded detect_resolution
        # copy, so cost grows with the upload size. Sides are rounded up to multiples
        # of 256, presumably because the 8-level U-Net needs dimensions divisible by 2**8.
        Hn = 256 * int(np.ceil(float(H) / 256.0))
        Wn = 256 * int(np.ceil(float(W) / 256.0))
        resized = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_LANCZOS4)
        with torch.no_grad():
            image_feed = torch.from_numpy(resized).float().to(device)
            image_feed = image_feed / 127.5 - 1.0  # scale [0, 255] -> [-1, 1]
            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
            line = model(image_feed)[0, 0] * 127.5 + 127.5  # back to [0, 255]
            line = line.cpu().numpy()
            line = line.clip(0, 255).astype(np.uint8)
        detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_LANCZOS4)
    else:  # 'standard' (also any unrecognised mode, e.g. None)
        detected_map = _standard_lineart(detected_map)
        # Invert so strokes come out dark on a light background.
        detected_map = remove_pad(255 - detected_map)
        detected_map = cv2.resize(HWC3(detected_map), (W, H), interpolation=cv2.INTER_LANCZOS4)

    if output_type == "pil":
        detected_map = Image.fromarray(detected_map)
    return detected_map
# Gradio UI: an uploaded image plus a mode choice are passed positionally to get_sketch.
iface = gr.Interface(
    fn=get_sketch,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        # Preselect 'anime' to match get_sketch's own default. Without a value, an
        # unselected Radio submits None, which get_sketch treats as 'standard'.
        gr.Radio(["anime", "realistic", "standard"], value="anime", label="Mode", info="Process methods"),
    ],
    outputs=gr.Image(type="numpy", label="Sketch Output"),
    title="Get a Sketch",
    description="Upload an image and get a simplified sketch",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()