File size: 5,708 Bytes
bb25a6e
 
 
 
 
 
 
 
727fe2d
bb25a6e
 
 
 
 
 
 
 
727fe2d
 
bb25a6e
 
 
727fe2d
 
bb25a6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0335a1
 
 
 
 
 
bb25a6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0335a1
bb25a6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0335a1
bb25a6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Author: @huzpsb 2024/Dec.17
# Licensed under the MIT License.

import gradio as gr
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion
torch.serialization.add_safe_globals([object])
# NOTE(review): the allowlist above only applies to weights_only=True loads.
# Both checkpoints below use weights_only=False (plain pickle), so it has no
# effect on them. Plain-pickle loading can execute code embedded in the
# checkpoint, so only ever load trusted .pth files here.

# FAQ:
# Why is the detection model needed if meaningful_diff can already detect watermark masks?
# The detection model is trained on a larger, more diverse dataset and is more
# robust to different watermark types.
# As you may have noticed, the detection model is even larger than the fix model.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_model(path):
    """Load a fully pickled model from ``path`` for use on the global ``device``."""
    if device.type == 'cuda':
        # With CUDA present, saved tensors can be restored wherever they were
        # saved; the module is then moved onto the GPU explicitly.
        return torch.load(path, weights_only=False).to(device)
    # Without CUDA, any GPU-saved tensors must be remapped to the CPU at load time.
    return torch.load(path, map_location=device, weights_only=False)


detect = _load_model('./det_vgg.pth')
fix = _load_model('./fix_vgg2.pth')

def fix_img(img0, threshold=0.5, erosion=2, dilation=6, mask_only=False,
            meaningful_diff=150, meaningful_ed=2):
    """Detect watermarks in a PIL image and repaint them.

    The image is converted to RGB. If it is larger than 2048 px on either side,
    it is downscaled (aspect ratio kept) to fit within 2048x2048. The result is
    returned at that working size, not at the original size.

    :param img0: PIL Image to process.
    :param threshold: float. Pixels whose detector score is below ``threshold``
        are left untouched; all others are treated as watermark.
        A value lower than 0 means to repaint the whole image (assuming the
        detector's scores are non-negative).
        A value of 0 means to repaint wherever the model is uncertain whether
        it's a watermark; this can also be used to 'detoxify' the image.
        A value of 0.1 means to repaint only where the model is certain it's a
        watermark.
    :param erosion: int, >= 1. Side length of the square structuring element
        used to erode (shrink) the watermark mask; 1 disables erosion.
    :param dilation: int, >= 1. Side length of the square structuring element
        used to dilate (grow) the watermark mask; 1 disables dilation.
    :param mask_only: bool. If True, return only the mask as an 'L' image:
        0 (black) where a watermark was detected, 255 (white) elsewhere.
    :param meaningful_diff: number, 0-255. A repainted pixel is kept only if the
        harmonic mean of its per-channel ``|fixed - original| + 1`` exceeds this
        value; 0 disables the check.
    :param meaningful_ed: int, >= 1. Side length of the square structuring
        element used to erode and then dilate the meaningful-difference mask;
        1 disables it.
    :return: PIL Image; mode 'L' when ``mask_only`` is True, otherwise 'RGB'.
    :raises ValueError: if ``img0`` is None (e.g. Process clicked with no image).
    """
    if img0 is None:
        raise ValueError("No input image was provided.")

    # UI sliders may hand these over as floats; np.ones needs integer sizes.
    erosion = int(erosion)
    dilation = int(dilation)
    meaningful_ed = int(meaningful_ed)

    sz_max = 2048
    width, height = img0.size
    scale = min(sz_max / width, sz_max / height)
    img = img0.convert('RGB')
    if scale < 0.99:
        # Clamp each side to at least 1 px so extremely elongated images do not
        # round down to an invalid zero-sized resize target.
        img = img.resize((max(1, int(width * scale)), max(1, int(height * scale))))
    npa = np.array(img)
    originals = [npa[:, :, c] for c in range(3)]

    # Inference only: no_grad stops autograd from recording a graph and keeping
    # intermediate activations alive for these potentially 2048x2048 inputs.
    with torch.no_grad():
        blue_torch = torch.tensor(originals[2]).float().to(device)
        scores = detect(blue_torch.unsqueeze(0)).cpu().squeeze().numpy()
    mask = np.where(scores < threshold, 0, 1)

    if erosion > 1:
        struct_elem = np.ones((erosion, erosion), dtype=bool)
        mask = binary_erosion(mask, structure=struct_elem, iterations=1)

    if dilation > 1:
        struct_elem = np.ones((dilation, dilation), dtype=bool)
        mask = binary_dilation(mask, structure=struct_elem, iterations=1)

    if mask_only:
        # Invert so watermark pixels are black (0) and everything else white (255).
        return Image.fromarray(((1 - mask) * 255).astype(np.uint8), 'L')

    # Each channel is repainted independently; only pixels inside the watermark
    # mask take the repainted value.
    fixed = []
    with torch.no_grad():
        for channel in originals:
            channel_torch = torch.tensor(channel).float().to(device)
            repainted = fix(channel_torch.unsqueeze(0).unsqueeze(0)).cpu().squeeze().numpy()
            fixed.append(np.where(mask, repainted, channel))

    if meaningful_diff > 0:
        # Harmonic mean of the per-channel (|fixed - original| + 1). It is at
        # most 3x the smallest term, so a pixel only counts as meaningfully
        # changed when every channel moved noticeably.
        inv_sum = sum(1.0 / (np.abs(f - o) + 1) for f, o in zip(fixed, originals))
        diff = 3 / inv_sum
        meaningful = diff > meaningful_diff

        if meaningful_ed > 1:
            struct_elem = np.ones((meaningful_ed, meaningful_ed), dtype=bool)
            # Erode then dilate (an opening) to drop isolated specks.
            meaningful = binary_erosion(meaningful, structure=struct_elem, iterations=1)
            meaningful = binary_dilation(meaningful, structure=struct_elem, iterations=1)
        fixed = [np.where(meaningful, f, o) for f, o in zip(fixed, originals)]

    result = np.clip(np.stack(fixed, axis=2), 0, 255)
    return Image.fromarray(result.astype(np.uint8), 'RGB')


# Define the Gradio interface: the first row holds the input image and all
# tuning controls; the second holds the output image and the Process button.
with gr.Blocks() as demo:
    gr.Markdown("# Watermark Remover")
    gr.Markdown("Author: huzpsb | Just a fun project ;3")

    # Inputs Section
    with gr.Row():
        image_input = gr.Image(type="pil", label="Image to test")
        threshold = gr.Slider(minimum=-0.001, maximum=0.5, value=0.001, label="Threshold", step=0.001)
        erosion = gr.Slider(minimum=1, maximum=10, value=2, label="Erosion", step=1)
        dilation = gr.Slider(minimum=1, maximum=10, value=10, label="Dilation", step=1)
        mask_only = gr.Checkbox(value=False, label="Mask Only")
        meaningful_diff = gr.Slider(minimum=0, maximum=128, value=50, label="Meaningful Diff", step=1)
        meaningful_ed = gr.Slider(minimum=1, maximum=10, value=2, label="Meaningful ED", step=1)

    # Outputs Section
    with gr.Row():
        image_output = gr.Image(type="pil", label="Output Image")
        process_button = gr.Button("Process")

    # Wire the button to fix_img; the input order must match its positional parameters.
    process_button.click(fix_img,
                         inputs=[image_input, threshold, erosion, dilation, mask_only, meaningful_diff, meaningful_ed],
                         outputs=[image_output])

    gr.Examples(examples=["./sample.png"],
                inputs=[image_input], label="Examples (Click to populate)")

# Start the server only when this file is run as a script, so merely importing
# the module does not block on a running web server.
if __name__ == "__main__":
    demo.launch()