# Spaces: Running
# Author: @huzpsb 2024/Dec.17
# Licensed under the MIT License.
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from scipy.ndimage import binary_dilation, binary_erosion | |
# Registers `object` as an allowed global for the restricted (weights_only)
# unpickler.
# NOTE(review): both loads below pass weights_only=False, so this allow-list
# is presumably not consulted for them; kept so module behaviour is unchanged.
torch.serialization.add_safe_globals([object])

# FAQ:
# Q: Why is the detection model needed when the meaningful_diff filter can
#    already find watermark masks?
# A: The detection model is trained on a larger, more diverse dataset and is
#    more robust to different watermark types. As you may have noticed, the
#    detection model is even larger than the fix model.

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY: weights_only=False unpickles whole Python objects and can execute
# arbitrary code embedded in the file. Only load checkpoints you trust.
#
# map_location remaps every stored tensor onto `device` at load time, so the
# same two lines work on CPU-only hosts and on CUDA hosts regardless of the
# device the checkpoint was saved from. The trailing .to(device) is a no-op
# safety net that guarantees the module ends up on `device`.
#
# NOTE(review): the models are never switched to eval() in this file; confirm
# the checkpoints were saved in eval mode if they contain dropout/batch-norm.
detect = torch.load('./det_vgg.pth', map_location=device, weights_only=False).to(device)
fix = torch.load('./fix_vgg2.pth', map_location=device, weights_only=False).to(device)
def fix_img(img0, threshold=0.5, erosion=2, dilation=6, mask_only=False,
            meaningful_diff=150, meaningful_ed=2):
    """Detect watermark pixels in an image and repaint them with the fix model.

    The blue channel is fed to the module-level ``detect`` model to obtain a
    per-pixel score map, which is thresholded and cleaned up with an erosion
    followed by a dilation. Each colour channel is then passed through the
    module-level ``fix`` model and the result is used only where the mask is
    set (and, optionally, only where the change is large enough).

    Images whose longer side exceeds 2048 px are scaled down first; the
    returned image has the scaled-down size, not the original size.

    :param img0: PIL Image, or None (in which case None is returned).
    :param threshold: float. Pixels whose detection score is not below this
                      value are marked for repainting.
                      A value lower than 0 means to repaint the whole image.
                      A value of 0 means to repaint wherever the model is
                      uncertain if it's a watermark. It can also be used to
                      'detoxify' the image.
                      A value of 0.1 means to repaint wherever the model is
                      certain it's a watermark.
    :param erosion: int. Side length of the square erosion structuring
                    element; values of 1 or less skip the erosion.
    :param dilation: int. Side length of the square dilation structuring
                     element; values of 1 or less skip the dilation.
    :param mask_only: bool. If True, only the mask is returned, as a
                      grayscale image in which masked pixels are black (0)
                      and all other pixels are white (255).
    :param meaningful_diff: int. A repainted pixel is kept only where the
                            harmonic mean of the three per-channel values
                            ``abs(fixed - original) + 1`` exceeds this
                            threshold; 0 or less disables the filter.
    :param meaningful_ed: int. Side length of the square structuring element
                          used to erode and then dilate the meaningful-change
                          map; values of 1 or less skip this step.
    :return: PIL Image ('L' when mask_only is True, otherwise 'RGB'),
             or None when img0 is None.
    """
    if img0 is None:
        # Gradio passes None when "Process" is clicked without an image.
        return None

    # Slider values are used as array dimensions below, so force them to int.
    erosion = int(erosion)
    dilation = int(dilation)
    meaningful_ed = int(meaningful_ed)

    # Cap the longer side at sz_max; the 0.99 cut-off avoids a pointless
    # resample for images that are only marginally too large.
    sz_max = 2048
    width, height = img0.size
    scale = min(sz_max / width, sz_max / height)
    img = img0.convert('RGB')
    if scale < 0.99:
        # max(1, ...) keeps very elongated images from collapsing to 0 px.
        img = img.resize((max(1, int(width * scale)),
                          max(1, int(height * scale))))

    npa = np.array(img)
    channels = [npa[:, :, idx] for idx in range(3)]  # R, G, B views

    # --- Detection: score map from the blue channel -------------------------
    # no_grad: this is pure inference, so skip building autograd graphs.
    with torch.no_grad():
        blue_torch = torch.tensor(channels[2]).float().to(device)
        scores = detect(blue_torch.unsqueeze(0)).cpu().detach().squeeze().numpy()

    mask = np.where(scores < threshold, 0, 1)
    if erosion > 1:
        struct_elem = np.ones((erosion, erosion), dtype=bool)
        mask = binary_erosion(mask, structure=struct_elem, iterations=1)
    if dilation > 1:
        struct_elem = np.ones((dilation, dilation), dtype=bool)
        mask = binary_dilation(mask, structure=struct_elem, iterations=1)

    if mask_only:
        # Inverted for display: masked pixels come out black.
        mask = 1 - mask
        return Image.fromarray((mask * 255).astype(np.uint8), 'L')

    # --- Repair: run the fix model on each channel independently ------------
    fixed_channels = []
    with torch.no_grad():
        for original in channels:
            plane = torch.tensor(original).float().to(device)
            repaired = fix(plane.unsqueeze(0).unsqueeze(0)).cpu().detach().squeeze().numpy()
            # Take the model output only inside the mask.
            fixed_channels.append(np.where(mask, repaired, original))
    fixed_r, fixed_g, fixed_b = fixed_channels

    # --- Optional filter: drop changes that are too small to matter ---------
    if meaningful_diff > 0:
        diff_r = np.abs(fixed_r - channels[0])
        diff_g = np.abs(fixed_g - channels[1])
        diff_b = np.abs(fixed_b - channels[2])
        # Harmonic mean of (diff + 1): stays small unless all three channels
        # changed noticeably; the +1 avoids division by zero.
        diff = 3 / (1.0 / (diff_r + 1) + 1.0 / (diff_g + 1) + 1.0 / (diff_b + 1))
        meaningful = diff > meaningful_diff
        if meaningful_ed > 1:
            # Erode then dilate to discard isolated specks.
            struct_elem = np.ones((meaningful_ed, meaningful_ed), dtype=bool)
            meaningful = binary_erosion(meaningful, structure=struct_elem, iterations=1)
            meaningful = binary_dilation(meaningful, structure=struct_elem, iterations=1)
        fixed_r = np.where(meaningful, fixed_r, channels[0])
        fixed_g = np.where(meaningful, fixed_g, channels[1])
        fixed_b = np.where(meaningful, fixed_b, channels[2])

    fixed = np.stack([fixed_r, fixed_g, fixed_b], axis=2)
    fixed = np.clip(fixed, 0, 255)
    return Image.fromarray(fixed.astype(np.uint8), 'RGB')
# Gradio front end: an input image and the fix_img tuning controls on one
# side, the processed image on the other, joined by a "Process" button.
# NOTE(review): the nesting of the components inside the two rows was
# reconstructed from a copy of this file that had lost its indentation;
# confirm the layout against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Watermark Remover")
    gr.Markdown("Author: huzpsb | Just a fun project ;3")

    # Inputs: the image plus one control per fix_img keyword argument.
    with gr.Row():
        image_input = gr.Image(label="Image to test", type="pil")
        threshold = gr.Slider(
            label="Threshold",
            minimum=-0.001, maximum=0.5, step=0.001, value=0.001,
        )
        erosion = gr.Slider(
            label="Erosion",
            minimum=1, maximum=10, step=1, value=2,
        )
        dilation = gr.Slider(
            label="Dilation",
            minimum=1, maximum=10, step=1, value=10,
        )
        mask_only = gr.Checkbox(label="Mask Only", value=False)
        meaningful_diff = gr.Slider(
            label="Meaningful Diff",
            minimum=0, maximum=128, step=1, value=50,
        )
        meaningful_ed = gr.Slider(
            label="Meaningful ED",
            minimum=1, maximum=10, step=1, value=2,
        )

    # Outputs: the result image and the button that triggers processing.
    with gr.Row():
        image_output = gr.Image(label="Output Image", type="pil")
        process_button = gr.Button("Process")

    # Clicking the button calls fix_img; the inputs are listed in the same
    # order as fix_img's positional parameters.
    process_button.click(
        fn=fix_img,
        inputs=[
            image_input,
            threshold,
            erosion,
            dilation,
            mask_only,
            meaningful_diff,
            meaningful_ed,
        ],
        outputs=[image_output],
    )

    # A bundled sample that fills the image input when clicked.
    gr.Examples(
        examples=["./sample.png"],
        inputs=[image_input],
        label="Examples (Click to populate)",
    )

demo.launch()