import torch
import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin


class ConvBNRelu(nn.Module):
    """
    Building block used in the HiDDeN network: a sequence of Convolution,
    Batch Normalization, and activation (GELU here, in place of the original ReLU).
    """

    def __init__(self, channels_in, channels_out):
        super(ConvBNRelu, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels_out, eps=1e-3),
            nn.GELU()
        )

    def forward(self, x):
        return self.layers(x)


class HiddenDecoder(nn.Module):
    """
    Decoder module. Receives a watermarked image and extracts the watermark.
    """

    def __init__(self, num_blocks, num_bits, channels, redundancy=1):
        super(HiddenDecoder, self).__init__()

        # Stack of conv blocks, ending with a block that maps to num_bits * redundancy
        # channels, followed by global average pooling to a per-image feature vector.
        layers = [ConvBNRelu(3, channels)]
        for _ in range(num_blocks - 1):
            layers.append(ConvBNRelu(channels, channels))
        layers.append(ConvBNRelu(channels, num_bits * redundancy))
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.layers = nn.Sequential(*layers)

        self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy)

        self.num_bits = num_bits
        self.redundancy = redundancy

    def forward(self, img_w):
        x = self.layers(img_w)  # b d 1 1
        x = x.squeeze(-1).squeeze(-1)  # b d
        x = self.linear(x)  # b d

        # Fold the redundancy dimension: repeated copies of each bit are summed.
        x = x.view(-1, self.num_bits, self.redundancy)  # b k*r -> b k r
        x = torch.sum(x, dim=-1)  # b k r -> b k

        return x


class MsgExtractor(nn.Module, PyTorchModelHubMixin):
    """
    Wraps a HiddenDecoder with a linear head. The PyTorchModelHubMixin adds
    save_pretrained / from_pretrained / push_to_hub support.
    """

    def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
        super().__init__()
        self.hidden_decoder = hidden_decoder
        self.head = nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.hidden_decoder(x)  # b k: one logit per message bit
        x = self.head(x)  # b out_features
        return x
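

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module). The hyperparameters below
# (num_blocks=8, num_bits=48, channels=64) and the 256x256 input size are
# illustrative assumptions; match them to whatever checkpoint is actually used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_bits = 48
    decoder = HiddenDecoder(num_blocks=8, num_bits=num_bits, channels=64)
    extractor = MsgExtractor(decoder, in_features=num_bits, out_features=num_bits)

    img_w = torch.rand(2, 3, 256, 256)  # batch of watermarked images in [0, 1]
    logits = extractor(img_w)           # shape: (2, num_bits)
    bits = (logits > 0).int()           # threshold logits to hard message bits
    print(logits.shape, bits.shape)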