|
|
import torch |
|
|
import torch.nn as nn |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
class ConvBNRelu(nn.Module):
    """
    Basic HiDDeN building block: 3x3 convolution -> batch norm -> activation.

    NOTE(review): despite the class name, the activation actually used is
    GELU, not ReLU. The name is kept as-is for checkpoint/state_dict
    compatibility.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        # Keep the exact Sequential layout/order so pretrained state_dicts load.
        block = [
            nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels_out, eps=1e-3),
            nn.GELU(),
        ]
        self.layers = nn.Sequential(*block)

    def forward(self, x):
        """Apply conv -> batch norm -> GELU to a (B, C_in, H, W) tensor."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class HiddenDecoder(nn.Module):
    """
    HiDDeN watermark decoder.

    Maps a watermarked image of shape (B, 3, H, W) to a tensor of
    (B, num_bits) soft bit scores. Each bit may be predicted `redundancy`
    times; the redundant copies are summed into the final per-bit score.
    """

    def __init__(self, num_blocks, num_bits, channels, redundancy=1):
        """
        Args:
            num_blocks: total number of ConvBNRelu blocks at `channels`
                width (including the initial 3 -> channels block).
            num_bits: number of message bits to extract.
            channels: feature width of the intermediate conv blocks.
            redundancy: number of redundant predictions per bit (summed).
        """
        super().__init__()

        # Conv stack: 3 -> channels, (num_blocks - 1) x channels -> channels,
        # then channels -> num_bits * redundancy, followed by a global
        # average pool down to 1x1. Order matches pretrained checkpoints.
        stem = [ConvBNRelu(3, channels)]
        stem += [ConvBNRelu(channels, channels) for _ in range(num_blocks - 1)]
        stem.append(ConvBNRelu(channels, num_bits * redundancy))
        stem.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.layers = nn.Sequential(*stem)

        self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy)

        self.num_bits = num_bits
        self.redundancy = redundancy

    def forward(self, img_w):
        """Extract soft bit scores of shape (B, num_bits) from `img_w`."""
        feats = self.layers(img_w)             # (B, num_bits * redundancy, 1, 1)
        feats = feats.squeeze(-1).squeeze(-1)  # (B, num_bits * redundancy)
        feats = self.linear(feats)

        # Collapse the redundant predictions for each bit by summation.
        feats = feats.view(-1, self.num_bits, self.redundancy)
        return feats.sum(dim=-1)
|
|
|
|
|
|
|
|
class MsgExtractor(nn.Module, PyTorchModelHubMixin):
    """
    Wraps a decoder backbone with a linear projection head.

    The PyTorchModelHubMixin base makes the module loadable/pushable via
    the Hugging Face Hub (`from_pretrained` / `push_to_hub`).
    """

    def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
        """
        Args:
            hidden_decoder: backbone whose output's last dimension must be
                `in_features` (it is fed directly into the head).
            in_features: feature size emitted by `hidden_decoder`.
            out_features: size of the projected message representation.
        """
        super().__init__()
        self.hidden_decoder = hidden_decoder
        self.head = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Run the backbone on `x`, then project through the linear head."""
        return self.head(self.hidden_decoder(x))
|
|
|