File size: 1,967 Bytes
36ccd8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class ConvBNRelu(nn.Module):
    """
    Building block used in the HiDDeN network: a sequence of a 3x3 convolution,
    batch normalization, and a GELU activation.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; batch normalization uses ``eps=1e-3``.

    Note: despite the class name, the activation is GELU, not ReLU. The name is
    kept unchanged for compatibility with existing code and checkpoints.

    Args:
        channels_in: number of input channels of the convolution.
        channels_out: number of output channels (also the BatchNorm width).
    """

    def __init__(self, channels_in: int, channels_out: int):
        super().__init__()
        # The order of these submodules defines the state_dict keys
        # (layers.0.* = conv, layers.1.* = batch norm); do not reorder
        # without migrating saved checkpoints.
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels_out, eps=1e-3),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution -> batch norm -> GELU to ``x``."""
        return self.layers(x)
class HiddenDecoder(nn.Module):
    """
    Decoder module: takes a watermarked image batch and extracts the watermark.

    The network is built as follows:
      * a first ``ConvBNRelu`` mapping 3 input channels to ``channels``;
      * ``num_blocks - 1`` further ``ConvBNRelu`` blocks at ``channels`` width
        (none when ``num_blocks <= 1``);
      * a ``ConvBNRelu`` projecting to ``num_bits * redundancy`` channels;
      * global average pooling down to 1x1;
      * a square linear layer on the pooled ``num_bits * redundancy`` features.

    The linear output is regrouped into ``redundancy`` consecutive values per
    bit, and these are summed to give one score per bit.

    Args:
        num_blocks: controls how many ``channels``-wide blocks are stacked.
        num_bits: number of bits in the extracted message.
        channels: hidden channel width of the convolutional blocks.
        redundancy: number of feature slots summed per bit.
    """

    def __init__(self, num_blocks, num_bits, channels, redundancy=1):
        super().__init__()
        width = num_bits * redundancy
        # Construction order (stem, body, projection, pool, then linear) is
        # kept fixed: it determines parameter init order and state_dict keys.
        stem = ConvBNRelu(3, channels)
        body = [ConvBNRelu(channels, channels) for _ in range(num_blocks - 1)]
        projection = ConvBNRelu(channels, width)
        self.layers = nn.Sequential(
            stem,
            *body,
            projection,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )
        self.linear = nn.Linear(width, width)
        self.num_bits = num_bits
        self.redundancy = redundancy

    def forward(self, img_w):
        """Return one summed score per bit for each image in ``img_w``."""
        pooled = self.layers(img_w)                  # b, k*r, 1, 1
        feats = torch.flatten(pooled, start_dim=1)   # b, k*r
        feats = self.linear(feats)
        grouped = feats.view(-1, self.num_bits, self.redundancy)  # b, k, r
        return grouped.sum(dim=-1)                   # b, k
class MsgExtractor(nn.Module, PyTorchModelHubMixin):
    """
    Message extractor: a wrapped decoder network followed by a linear head.

    ``forward`` runs the input through ``hidden_decoder`` and then through
    ``head`` (an ``nn.Linear(in_features, out_features)``).

    Args:
        hidden_decoder: module applied first to the input. Its output's last
            dimension presumably equals ``in_features`` — this is not checked
            here (TODO confirm against the decoder actually used).
        in_features: input size of the linear head.
        out_features: output size of the linear head.

    NOTE(review): ``hidden_decoder`` is an ``nn.Module`` passed to ``__init__``;
    ``PyTorchModelHubMixin`` may be unable to store it in the saved config, so
    reloading might require supplying it explicitly — verify.
    """

    def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
        super().__init__()
        self.hidden_decoder = hidden_decoder
        self.head = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Decode ``x`` with ``hidden_decoder``, then project it with ``head``."""
        return self.head(self.hidden_decoder(x))
|