# PixNerd/src/models/layers/adv_head.py
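# Discriminator-style scoring heads: each maps features to a sigmoid score
# clamped to [0.01, 0.99], presumably for an adversarial (GAN-style) loss.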
import math

import torch
import torch.nn as nn


class ConvHead(nn.Module):
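    """Conv-only discriminator head.

    Reshapes (B, L, C) token features into a (B, C, H, W) map, downsamples it
    with three stride-2 convs, global-average-pools, and projects to one score.
    """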
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # project pooled features to one logit
)

    def forward(self, feature, text_embedding=None):
        # text_embedding is unused here; kept for a uniform head interface
        # tokens are assumed to form a square grid: L = H * W with H == W
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        # clamping keeps discriminator scores away from saturated 0/1,
        # which avoids exploding adversarial-loss gradients
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out


class ConvLinearMMHead(nn.Module):
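    """Conv + linear discriminator head.

    A conv branch scores the image token grid while an MLP branch scores a
    multimodal embedding of shape (B, mm_channels); the two branch outputs
    are concatenated and fused by a final linear layer.
    """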
def __init__(self, im_channels, mm_channels, hidden_size):
super().__init__()
self.conv_head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
)
self.linear_head = nn.Sequential(
nn.Linear(mm_channels, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
nn.SiLU(),
)
self.out = nn.Linear(hidden_size*2, 1)

    def forward(self, im_feature, mm_feature=None):
        # image tokens are assumed to form a square grid: L = H * W with H == W
        B, L, C = im_feature.shape
        H = W = int(math.sqrt(L))
        im_feature = im_feature.permute(0, 2, 1)
        im_feature = im_feature.view(B, C, H, W)
        im_out = self.conv_head(im_feature).view(B, -1)
        # mm_feature is expected to be a pooled embedding of shape (B, mm_channels)
        mm_out = self.linear_head(mm_feature).view(B, -1)
        out = self.out(torch.cat([im_out, mm_out], dim=-1)).sigmoid().clamp(0.01, 0.99)
        return out


class ConvMMHead(nn.Module):
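    """Dual-conv discriminator head.

    Both the image tokens and the multimodal tokens are treated as square
    grids, each downsampled by its own conv stack before linear fusion.
    """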
def __init__(self, im_channels, mm_channels, hidden_size):
super().__init__()
self.conv1_head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
)
self.conv2_head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=mm_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
)
self.out = nn.Linear(hidden_size*2, 1)

    def forward(self, im_feature, mm_feature=None):
        # both token sequences are assumed to form square grids
        B, L, C = im_feature.shape
        H = W = int(math.sqrt(L))
        im_feature = im_feature.permute(0, 2, 1)
        im_feature = im_feature.view(B, C, H, W)
        B, Lmm, Cmm = mm_feature.shape
        Hmm = Wmm = int(math.sqrt(Lmm))
        mm_feature = mm_feature.permute(0, 2, 1)
        mm_feature = mm_feature.view(B, Cmm, Hmm, Wmm)
        im_out = self.conv1_head(im_feature).view(B, -1)
        mm_out = self.conv2_head(mm_feature).view(B, -1)
        out = self.out(torch.cat([im_out, mm_out], dim=-1)).sigmoid().clamp(0.01, 0.99)
        return out

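# ---------------------------------------------------------------------------
# Alternative head designs, left commented out in the original source.
# ---------------------------------------------------------------------------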
# class ConvTextHead(nn.Module):
# def __init__(self, in_channels, text_channels, hidden_size):
# super().__init__()
# self.head = nn.Sequential(
# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.AdaptiveAvgPool2d(1),
# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=hidden_size, stride=1, padding=0), # 1x1 -> 1x1
# )
# self.text_head = nn.Sequential(
# nn.Linear(text_channels, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size),
# )
#
# def forward(self, feature, text_embedding=None):
# # assume sqrt image size
# B, L, C = feature.shape
# H = W = int(math.sqrt(L))
# feature = feature.permute(0, 2, 1)
# feature = feature.view(B, C, H, W)
# feature = self.head(feature).view(B, -1)
# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False)
# text_embedding = self.text_head(text_embedding)
# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False)
# score = logits.sigmoid().clamp(0.01, 0.99)
# return score
#
# class LinearHead(nn.Module):
# def __init__(self, in_channels, hidden_size):
# super().__init__()
# self.head = nn.Sequential(
# nn.Linear(in_channels, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, 1),
# )
# def forward(self, feature, text_embedding=None):
# out = self.head(feature).sigmoid().clamp(0.01, 0.99)
# return out
# class ConvMultiModalHead(nn.Module):
# def __init__(self, in_channels, mm_channels, hidden_size):
# super().__init__()
# self.image_head = nn.Sequential(
# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
# nn.GroupNorm(num_groups=32, num_channels=hidden_size),
# nn.SiLU(),
# nn.AdaptiveAvgPool2d(1),
# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=hidden_size, stride=1, padding=0),  # keep hidden_size channels so the dot product below is valid
# )
# self.mm_head = nn.Sequential(
# nn.Linear(mm_channels, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size),
# )
#
# def forward(self, feature, text_embedding=None):
# # assume sqrt image size
# B, L, C = feature.shape
# H = W = int(math.sqrt(L))
# feature = feature.permute(0, 2, 1)
# feature = feature.view(B, C, H, W)
# feature = self.image_head(feature).view(B, -1)
# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False)
# text_embedding = self.mm_head(text_embedding)
# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False)
# score = logits.sigmoid().clamp(0.01, 0.99)
# return score
# class TransformerTextHead(nn.Module):
# def __init__(self, in_channels, text_channels, hidden_size):
# super().__init__()
#
# self.transformer = nn.Sequential(
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True),
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True),
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True),
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True),
# )
# self.text_head = nn.Sequential(
# nn.Linear(text_channels, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size),
# )
# self.feature_head = nn.Sequential(
# nn.Linear(in_channels, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, hidden_size),
# )
# self.cls_head = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
# nn.SiLU(),
# nn.Linear(hidden_size, 1),
# )
#
# def forward(self, feature, text_embedding=None):
# # assume sqrt image size
# feature = self.feature_head(feature)
# text_embedding = self.text_head(text_embedding)
# tokens = torch.cat([feature, text_embedding], dim=1)
# tokens = self.transformer(tokens)
# # no dedicated CLS token: score every token, then average the logits
# logits = self.cls_head(tokens)
# logits = torch.mean(logits, dim=1, keepdim=False)
# score = logits.sigmoid().clamp(0.01, 0.99)
# return score
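

# ---------------------------------------------------------------------------
# Minimal usage sketch (not from the original file): a smoke test under the
# assumption of a 16x16 token grid (L = 256) and illustrative channel sizes.
# The dimensions below are hypothetical; only the head signatures are real.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, L, C_im, C_mm, hidden = 2, 256, 64, 128, 256
    im_tokens = torch.randn(B, L, C_im)   # image token features
    mm_tokens = torch.randn(B, L, C_mm)   # multimodal token grid
    mm_vector = torch.randn(B, C_mm)      # pooled multimodal embedding

    head = ConvHead(in_channels=C_im, hidden_size=hidden)
    print(head(im_tokens).shape)          # torch.Size([2, 1, 1, 1])

    lin_head = ConvLinearMMHead(im_channels=C_im, mm_channels=C_mm, hidden_size=hidden)
    print(lin_head(im_tokens, mm_vector).shape)   # torch.Size([2, 1])

    mm_head = ConvMMHead(im_channels=C_im, mm_channels=C_mm, hidden_size=hidden)
    print(mm_head(im_tokens, mm_tokens).shape)    # torch.Size([2, 1])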