# StableSignatureDecoder / modeling_message_extractor.py
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ConvBNRelu(nn.Module):
    """
    Building block used in the HiDDeN network: a 3x3 convolution (stride 1,
    padding 1, so spatial size is preserved), batch normalization, and a GELU
    activation (the class keeps the historical "Relu" name from HiDDeN).
    """

    def __init__(self, channels_in, channels_out):
        super(ConvBNRelu, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels_out, eps=1e-3),
            nn.GELU(),
        )

    def forward(self, x):
        return self.layers(x)


class HiddenDecoder(nn.Module):
    """
    Decoder module. Receives a watermarked image and extracts the watermark.
    """

    def __init__(self, num_blocks, num_bits, channels, redundancy=1):
        super(HiddenDecoder, self).__init__()

        layers = [ConvBNRelu(3, channels)]
        for _ in range(num_blocks - 1):
            layers.append(ConvBNRelu(channels, channels))
        layers.append(ConvBNRelu(channels, num_bits * redundancy))
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.layers = nn.Sequential(*layers)

        self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy)

        self.num_bits = num_bits
        self.redundancy = redundancy

    def forward(self, img_w):
        x = self.layers(img_w)  # b d 1 1
        x = x.squeeze(-1).squeeze(-1)  # b d
        x = self.linear(x)
        x = x.view(-1, self.num_bits, self.redundancy)  # b k*r -> b k r
        x = torch.sum(x, dim=-1)  # b k r -> b k
        return x
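

# How the redundancy folding in HiddenDecoder.forward works (a short worked
# example, not part of the original file): with num_bits=2 and redundancy=3
# (hypothetical values), the linear layer emits 6 numbers per image, which are
# reshaped to (batch, 2, 3) and summed over the last axis, so each bit's logit
# aggregates its 3 redundant copies:
#
#     x = torch.arange(6.).view(1, 6)  # b=1, k*r = 6
#     x = x.view(-1, 2, 3)             # b k r -> tensor([[[0., 1., 2.], [3., 4., 5.]]])
#     x.sum(dim=-1)                    # b k   -> tensor([[ 3., 12.]])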


class MsgExtractor(nn.Module, PyTorchModelHubMixin):
    """
    Wraps a HiddenDecoder and applies a final linear head to its output.
    PyTorchModelHubMixin adds save_pretrained / from_pretrained / push_to_hub
    so the extractor can be shared via the Hugging Face Hub.
    """

    def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
        super().__init__()
        self.hidden_decoder = hidden_decoder
        self.head = nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.hidden_decoder(x)
        x = self.head(x)
        return x
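

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file). The
    # 8-block / 48-bit / 64-channel configuration is an assumption, chosen
    # only to illustrate the shapes.
    decoder = HiddenDecoder(num_blocks=8, num_bits=48, channels=64)
    extractor = MsgExtractor(decoder, in_features=48, out_features=48)

    imgs = torch.randn(2, 3, 256, 256)  # b 3 h w
    logits = extractor(imgs)            # b 48, one logit per message bit
    bits = (logits > 0).int()           # hard bit decisions
    print(bits.shape)                   # torch.Size([2, 48])

    # Because MsgExtractor mixes in PyTorchModelHubMixin, trained weights can
    # be saved and reloaded; the repo id below is hypothetical. Init arguments
    # that are not JSON-serializable, such as hidden_decoder, are passed again
    # at load time:
    #   extractor.save_pretrained("msg-extractor")
    #   extractor = MsgExtractor.from_pretrained(
    #       "your-org/msg-extractor",
    #       hidden_decoder=decoder, in_features=48, out_features=48,
    #   )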