Spaces:
Running
Running
File size: 5,708 Bytes
bb25a6e 727fe2d bb25a6e 727fe2d bb25a6e 727fe2d bb25a6e c0335a1 bb25a6e c0335a1 bb25a6e c0335a1 bb25a6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Author: @huzpsb 2024/Dec.17
# Licensed under the MIT License.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion
torch.serialization.add_safe_globals([object])


# FAQ:
# Q: Why is the detection model needed when meaningful_diff can already
#    detect watermark masks?
# A: The detection model is trained on a larger, more diverse dataset, and is
#    more robust to different watermark types. As you may have noticed, the
#    detection model is even larger than the fix model.

def _load_model(path):
    """Load the pickled model stored at ``path`` and place it on ``device``.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects, so the
    checkpoint files must come from a trusted source.
    """
    if device.type == 'cuda':
        # CUDA environments also have a CPU, so no map_location is forced;
        # the model is moved to the GPU after loading.
        return torch.load(path, weights_only=False).to(device)
    # CPU-only environments have no CUDA, so storages are remapped to the CPU
    # while loading.
    return torch.load(path, map_location=device, weights_only=False)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detect = _load_model('./det_vgg.pth')
fix = _load_model('./fix_vgg2.pth')
def fix_img(img0, threshold=0.5, erosion=2, dilation=6, mask_only=False,
            meaningful_diff=150, meaningful_ed=2):
    """Detect a watermark in ``img0`` and repaint the detected region.

    The blue channel is passed to the module-level ``detect`` model to obtain a
    mask, which is thresholded and then cleaned up with an erosion followed by
    a dilation. Each colour channel is passed through the module-level ``fix``
    model, and the repainted pixels replace the originals only inside the mask
    and, when ``meaningful_diff`` is positive, only where the repaint differs
    enough from the original.

    Images for which ``2048 / width`` or ``2048 / height`` drops below 0.99 are
    downscaled before processing; the result is returned at that reduced size.

    :param img0: PIL Image, or None (nothing uploaded).
    :param threshold: float, 0-1. The threshold for the mask.
        A value lower than 0 means to repaint the whole image.
        A value of 0 means to repaint wherever the model is uncertain if it's a watermark.
        It can also be used to 'detoxify' the image.
        A value of 0.1 means to repaint wherever the model is certain it's a watermark.
    :param erosion: int, 1-inf. The size of the erosion structuring element.
        Values of 1 or less skip the erosion.
    :param dilation: int, 1-inf. The size of the dilation structuring element.
        Values of 1 or less skip the dilation.
    :param mask_only: bool. If True, only the mask is returned
        (masked pixels are black, everything else is white).
    :param meaningful_diff: int, 0-255. The threshold for the meaningful difference.
        Values of 0 or less disable this filter.
    :param meaningful_ed: int, 1-inf. The size of the erosion/dilation structuring
        element for the meaningful difference. Values of 1 or less skip it.
    :return: PIL Image ('L' when ``mask_only`` is True, otherwise 'RGB'),
        or None when ``img0`` is None.
    """
    # Nothing to process (e.g. the button was pressed without an image).
    if img0 is None:
        return None

    # UI widgets may deliver numbers as floats; np.ones needs integer sizes.
    erosion = int(erosion)
    dilation = int(dilation)
    meaningful_ed = int(meaningful_ed)

    sz_max = 2048
    width, height = img0.size
    scale = min(sz_max / width, sz_max / height)
    img = img0.convert('RGB')
    if scale < 0.99:
        # Clamp to 1 px so extreme aspect ratios cannot yield a zero-sized image.
        img = img.resize((max(1, int(width * scale)), max(1, int(height * scale))))
    npa = np.array(img)

    # Inference only: no_grad avoids recording an autograd graph.
    with torch.no_grad():
        blue_torch = torch.tensor(npa[:, :, 2]).float().to(device)
        mask = detect(blue_torch.unsqueeze(0)).cpu().detach().squeeze().numpy()
    mask = np.where(mask < threshold, 0, 1)

    # Erode first to drop small specks, then dilate to grow what is left.
    if erosion > 1:
        struct_elem = np.ones((erosion, erosion), dtype=bool)
        mask = binary_erosion(mask, structure=struct_elem, iterations=1)
    if dilation > 1:
        struct_elem = np.ones((dilation, dilation), dtype=bool)
        mask = binary_dilation(mask, structure=struct_elem, iterations=1)

    if mask_only:
        # Invert so that masked pixels come out black on a white background.
        mask = 1 - mask
        return Image.fromarray((mask * 255).astype(np.uint8), 'L')

    # Repaint every channel independently, keeping original pixels outside the mask.
    repainted = []
    with torch.no_grad():
        for idx in range(3):
            channel = npa[:, :, idx]
            channel_torch = torch.tensor(channel).float().to(device)
            fixed_channel = fix(channel_torch.unsqueeze(0).unsqueeze(0)).cpu().detach().squeeze().numpy()
            repainted.append(np.where(mask, fixed_channel, channel))
    fixed = np.stack(repainted, axis=2)

    if meaningful_diff > 0:
        # Harmonic mean of the per-channel (|difference| + 1); it stays small
        # unless all three channels changed noticeably.
        inv = 1.0 / (np.abs(fixed - npa) + 1)
        diff = 3 / (inv[:, :, 0] + inv[:, :, 1] + inv[:, :, 2])
        meaningful = diff > meaningful_diff
        if meaningful_ed > 1:
            struct_elem = np.ones((meaningful_ed, meaningful_ed), dtype=bool)
            meaningful = binary_erosion(meaningful, structure=struct_elem, iterations=1)
            meaningful = binary_dilation(meaningful, structure=struct_elem, iterations=1)
        # Keep the repaint only where the change is meaningful.
        fixed = np.where(meaningful[:, :, np.newaxis], fixed, npa)

    fixed = np.clip(fixed, 0, 255)
    return Image.fromarray(fixed.astype(np.uint8), 'RGB')
# Define the Gradio interface: one input image, the tuning controls for
# fix_img, one output image and a button that triggers processing.
with gr.Blocks() as demo:
gr.Markdown("# Watermark Remover")
gr.Markdown("Author: huzpsb | Just a fun project ;3")
# Inputs Section
# NOTE: the initial values below are the UI defaults and differ from the
# keyword defaults in fix_img's signature (threshold 0.001 vs 0.5,
# dilation 10 vs 6, meaningful_diff 50 vs 150).
with gr.Row():
image_input = gr.Image(type="pil", label="Image to test")
# The minimum is slightly below 0 so the "repaint the whole image" mode
# described in the fix_img parameter notes is reachable from the UI.
threshold = gr.Slider(minimum=-0.001, maximum=0.5, value=0.001, label="Threshold", step=0.001)
erosion = gr.Slider(minimum=1, maximum=10, value=2, label="Erosion", step=1)
dilation = gr.Slider(minimum=1, maximum=10, value=10, label="Dilation", step=1)
mask_only = gr.Checkbox(value=False, label="Mask Only")
# Setting this slider to 0 disables the meaningful-difference filter in fix_img.
meaningful_diff = gr.Slider(minimum=0, maximum=128, value=50, label="Meaningful Diff", step=1)
meaningful_ed = gr.Slider(minimum=1, maximum=10, value=2, label="Meaningful ED", step=1)
# Outputs Section
with gr.Row():
image_output = gr.Image(type="pil", label="Output Image")
process_button = gr.Button("Process")
# Wire the button to fix_img; the inputs list follows the order of
# fix_img's positional parameters.
process_button.click(fix_img,
inputs=[image_input, threshold, erosion, dilation, mask_only, meaningful_diff, meaningful_ed],
outputs=[image_output])
# Sample image offered for the image input.
gr.Examples(examples=["./sample.png"],
inputs=[image_input], label="Examples (Click to populate)")
# Start the app.
demo.launch()
|