# Author: @huzpsb 2024/Dec.17
# Licensed under the MIT License.

import gradio as gr
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion

torch.serialization.add_safe_globals([object])

# FAQ:
# Why is the detection model needed since the meaningful_diff can already detect watermark masks?
# The detection model is trained on a larger, more diverse dataset, and is more robust to different watermark types.
# As you may have noticed, the detection model is even larger than the fix model.

# Longest side (in pixels) that is fed to the models; larger images are
# downscaled before processing.
MAX_SIDE = 2048

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_model(path):
    """Load a fully pickled model from *path* directly onto ``device``."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoint files that come from a trusted source.
    # map_location is passed on every path: CPU-only environments need it to
    # remap CUDA storages, and on CUDA it places the model on the target
    # device without a separate .to() call.
    return torch.load(path, map_location=device, weights_only=False)


# NOTE(review): the models run in whatever train/eval mode they were saved
# in -- confirm the checkpoints were exported in eval mode.
detect = _load_model('./det_vgg.pth')
fix = _load_model('./fix_vgg2.pth')


def _square(size):
    """Return a ``size`` x ``size`` boolean structuring element."""
    return np.ones((size, size), dtype=bool)


def _infer(model, plane, leading_dims):
    """Run *model* on one 2-D image plane and return the result as a NumPy array.

    ``leading_dims`` singleton dimensions are prepended to the plane before the
    forward pass, and singleton dimensions are squeezed out of the result.
    """
    tensor = torch.tensor(plane).float().to(device)
    for _ in range(leading_dims):
        tensor = tensor.unsqueeze(0)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = model(tensor)
    return result.cpu().detach().squeeze().numpy()


def fix_img(img0, threshold=0.5, erosion=2, dilation=6, mask_only=False, meaningful_diff=150, meaningful_ed=2):
    """Detect a watermark in an image and repaint the detected area.

    :param img0: PIL Image, or None (in which case None is returned).
    :param threshold: float, 0-1. The threshold for the mask.
        A value lower than 0 means to repaint the whole image.
        A value of 0 means to repaint wherever the model is uncertain if it's a watermark.
        It can also be used to 'detoxify' the image.
        A value of 0.1 means to repaint wherever the model is certain it's a watermark.
    :param erosion: int, 1-inf. The size of the erosion structuring element.
        Values of 1 or less skip the erosion step.
    :param dilation: int, 1-inf. The size of the dilation structuring element.
        Values of 1 or less skip the dilation step.
    :param mask_only: bool. If True, only the mask is returned, as an 'L'
        image in which masked pixels are 0 and all others are 255.
    :param meaningful_diff: int, 0-255. The threshold for the meaningful difference.
        Repainted pixels whose combined per-channel change does not exceed it
        are reverted to the original. 0 or less disables this filter.
    :param meaningful_ed: int, 1-inf. The size of the erosion/dilation
        structuring element for the meaningful difference.
    :return: PIL Image ('RGB', or 'L' when ``mask_only`` is set). Images whose
        longest side exceeds ``MAX_SIDE`` are returned at the reduced size.
    """
    # Nothing to process, e.g. "Process" was clicked before an image was set.
    if img0 is None:
        return None

    # UI sliders may deliver floats; structuring-element shapes must be ints.
    erosion = int(erosion)
    dilation = int(dilation)
    meaningful_ed = int(meaningful_ed)

    img = img0.convert('RGB')
    width, height = img0.size
    scale = min(MAX_SIDE / width, MAX_SIDE / height)
    if scale < 0.99:
        # Clamp to 1 px so extreme aspect ratios never produce an empty image.
        img = img.resize((max(1, int(width * scale)), max(1, int(height * scale))))
    npa = np.array(img)

    # Watermark detection runs on the blue channel only.
    score = _infer(detect, npa[:, :, 2], leading_dims=1)
    mask = np.where(score < threshold, 0, 1)
    if erosion > 1:
        mask = binary_erosion(mask, structure=_square(erosion), iterations=1)
    if dilation > 1:
        mask = binary_dilation(mask, structure=_square(dilation), iterations=1)

    if mask_only:
        # Inverted for display: masked pixels are black, the rest white.
        return Image.fromarray(np.where(mask, 0, 255).astype(np.uint8), 'L')

    # Repaint each colour channel independently; keep the original pixel
    # wherever the mask is not set.
    original = [npa[:, :, channel] for channel in range(3)]
    fixed = [np.where(mask, _infer(fix, plane, leading_dims=2), plane) for plane in original]

    if meaningful_diff > 0:
        # Harmonic mean of the (diff + 1) values of the three channels: it
        # stays small unless all channels changed noticeably.
        inverse = [1.0 / (np.abs(new - old) + 1) for new, old in zip(fixed, original)]
        diff = 3 / (inverse[0] + inverse[1] + inverse[2])
        meaningful = diff > meaningful_diff
        if meaningful_ed > 1:
            # Erosion followed by dilation removes isolated specks.
            struct_elem = _square(meaningful_ed)
            meaningful = binary_erosion(meaningful, structure=struct_elem, iterations=1)
            meaningful = binary_dilation(meaningful, structure=struct_elem, iterations=1)
        fixed = [np.where(meaningful, new, old) for new, old in zip(fixed, original)]

    result = np.clip(np.stack(fixed, axis=2), 0, 255)
    return Image.fromarray(result.astype(np.uint8), 'RGB')


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Watermark Remover")
    gr.Markdown("Author: huzpsb | Just a fun project ;3")

    # Inputs Section
    with gr.Row():
        image_input = gr.Image(type="pil", label="Image to test")
        threshold = gr.Slider(minimum=-0.001, maximum=0.5, value=0.001, label="Threshold", step=0.001)
        erosion = gr.Slider(minimum=1, maximum=10, value=2, label="Erosion", step=1)
        dilation = gr.Slider(minimum=1, maximum=10, value=10, label="Dilation", step=1)
        mask_only = gr.Checkbox(value=False, label="Mask Only")
        meaningful_diff = gr.Slider(minimum=0, maximum=128, value=50, label="Meaningful Diff", step=1)
        meaningful_ed = gr.Slider(minimum=1, maximum=10, value=2, label="Meaningful ED", step=1)

    # Outputs Section
    with gr.Row():
        image_output = gr.Image(type="pil", label="Output Image")
        process_button = gr.Button("Process")

    # Define the process function
    process_button.click(
        fix_img,
        inputs=[image_input, threshold, erosion, dilation, mask_only, meaningful_diff, meaningful_ed],
        outputs=[image_output],
    )

    gr.Examples(examples=["./sample.png"], inputs=[image_input], label="Examples (Click to populate)")

demo.launch()