# Hugging Face Spaces -- Running on Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
class ConvHead(nn.Module):
    """Convolutional scoring head over a square grid of image tokens.

    Three stride-2 convolutions (each followed by GroupNorm and SiLU) halve the
    spatial resolution, the result is globally average-pooled, and a 1x1
    convolution reduces it to a single channel.

    Args:
        in_channels: Channel dimension ``C`` of the incoming token features.
        hidden_size: Width of the convolutional trunk. Every GroupNorm here
            uses 32 groups, so this has to be divisible by 32.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        # The layer order (and therefore the Sequential indices) is part of the
        # state_dict key layout; keep it stable so existing checkpoints load.
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # e.g. 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature, text_embedding=None):
        """Score a batch of token features.

        Args:
            feature: Tensor of shape ``(B, L, C)``; ``L`` must be a perfect
                square because the tokens are folded into an ``H x W`` grid
                with ``H == W``.
            text_embedding: Not used by this head (presumably kept so the
                call signature matches other heads -- confirm with callers).

        Returns:
            Tensor of shape ``(B, 1, 1, 1)`` with values clamped to
            ``[0.01, 0.99]``.

        Raises:
            ValueError: If ``L`` is not a perfect square.
        """
        B, L, C = feature.shape
        side = int(math.sqrt(L))
        if side * side != L:
            # Fail early with a readable message instead of an opaque shape
            # error from the reshape below.
            raise ValueError(
                f"ConvHead expects a square token grid, got sequence length {L}"
            )
        # (B, L, C) -> (B, C, L) -> (B, C, H, W). reshape (rather than view)
        # also handles inputs whose strides do not permit a zero-copy view.
        grid = feature.permute(0, 2, 1).reshape(B, C, side, side)
        out = self.head(grid).sigmoid().clamp(0.01, 0.99)
        return out
class ConvLinearMMHead(nn.Module):
    """Scoring head that fuses a conv branch (image tokens) with an MLP branch.

    The image tokens go through three stride-2 conv/GroupNorm/SiLU stages and a
    global average pool; the multimodal feature goes through a two-layer MLP.
    Both ``hidden_size``-wide results are concatenated and projected to one
    score.

    Args:
        im_channels: Channel dimension of the image token features.
        mm_channels: Last-dimension size of the multimodal feature.
        hidden_size: Width of both branches. Every GroupNorm here uses 32
            groups, so this has to be divisible by 32.
    """

    def __init__(self, im_channels, mm_channels, hidden_size):
        super().__init__()
        # Sub-module names, order and Sequential indices define the state_dict
        # keys; keep them stable so existing checkpoints load.
        self.conv_head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1),  # e.g. 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.linear_head = nn.Sequential(
            nn.Linear(mm_channels, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
        )
        self.out = nn.Linear(hidden_size * 2, 1)

    def forward(self, im_feature, mm_feature=None):
        """Score a batch from image tokens plus a multimodal feature.

        Args:
            im_feature: Tensor of shape ``(B, L, C)``; ``L`` must be a perfect
                square because the tokens are folded into a square grid.
            mm_feature: Multimodal feature whose last dimension is
                ``mm_channels``. After the MLP it is flattened to ``(B, -1)``
                and must end up ``hidden_size`` wide to match ``self.out``
                (e.g. an input of shape ``(B, mm_channels)``). Required, even
                though the signature defaults it to ``None``.

        Returns:
            Tensor of shape ``(B, 1)`` with values clamped to ``[0.01, 0.99]``.

        Raises:
            ValueError: If ``mm_feature`` is ``None`` or ``L`` is not a perfect
                square.
        """
        if mm_feature is None:
            # The default exists only in the signature; the MLP branch cannot
            # run without an input.
            raise ValueError("ConvLinearMMHead.forward requires mm_feature")

        B, L, C = im_feature.shape
        side = int(math.sqrt(L))
        if side * side != L:
            raise ValueError(
                f"ConvLinearMMHead expects a square token grid, got sequence length {L}"
            )
        # (B, L, C) -> (B, C, L) -> (B, C, H, W). reshape (rather than view)
        # also handles inputs whose strides do not permit a zero-copy view.
        im_grid = im_feature.permute(0, 2, 1).reshape(B, C, side, side)

        im_out = self.conv_head(im_grid).reshape(B, -1)
        mm_out = self.linear_head(mm_feature).reshape(B, -1)
        fused = torch.cat([im_out, mm_out], dim=-1)
        out = self.out(fused).sigmoid().clamp(0.01, 0.99)
        return out
class ConvMMHead(nn.Module):
    """Scoring head with two parallel conv branches (image and multimodal).

    Each input is a sequence of tokens that is folded into a square grid and
    passed through its own stack of three stride-2 conv/GroupNorm/SiLU stages
    followed by a global average pool. The two ``hidden_size``-wide vectors are
    concatenated and projected to one score.

    Args:
        im_channels: Channel dimension of the image token features.
        mm_channels: Channel dimension of the multimodal token features.
        hidden_size: Width of both branches. Every GroupNorm here uses 32
            groups, so this has to be divisible by 32.
    """

    def __init__(self, im_channels, mm_channels, hidden_size):
        super().__init__()
        # Sub-module names, order and Sequential indices define the state_dict
        # keys; keep them stable so existing checkpoints load.
        self.conv1_head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1),  # e.g. 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.conv2_head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=mm_channels, out_channels=hidden_size, stride=2, padding=1),  # e.g. 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # e.g. 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.out = nn.Linear(hidden_size * 2, 1)

    @staticmethod
    def _tokens_to_grid(tokens, name):
        """Fold ``(B, L, C)`` tokens into a ``(B, C, H, W)`` grid with ``H == W``.

        Raises:
            ValueError: If ``L`` is not a perfect square.
        """
        batch, length, channels = tokens.shape
        side = int(math.sqrt(length))
        if side * side != length:
            raise ValueError(
                f"ConvMMHead expects a square token grid for {name}, "
                f"got sequence length {length}"
            )
        # reshape (rather than view) also handles inputs whose strides do not
        # permit a zero-copy view.
        return tokens.permute(0, 2, 1).reshape(batch, channels, side, side)

    def forward(self, im_feature, mm_feature=None):
        """Score a batch from image tokens plus multimodal tokens.

        Args:
            im_feature: Tensor of shape ``(B, L, C)`` with ``L`` a perfect
                square.
            mm_feature: Tensor of shape ``(B, Lmm, Cmm)`` with ``Lmm`` a
                perfect square (it may differ from ``L``). Required, even
                though the signature defaults it to ``None``.

        Returns:
            Tensor of shape ``(B, 1)`` with values clamped to ``[0.01, 0.99]``.

        Raises:
            ValueError: If ``mm_feature`` is ``None``, the batch sizes differ,
                or either sequence length is not a perfect square.
        """
        if mm_feature is None:
            # The default exists only in the signature; the second conv branch
            # cannot run without an input.
            raise ValueError("ConvMMHead.forward requires mm_feature")
        B = im_feature.shape[0]
        if mm_feature.shape[0] != B:
            # Checked up front: otherwise the mismatch only shows up later as
            # a confusing shape error in the final linear layer.
            raise ValueError(
                f"batch size mismatch: im_feature has {B}, "
                f"mm_feature has {mm_feature.shape[0]}"
            )

        im_grid = self._tokens_to_grid(im_feature, "im_feature")
        mm_grid = self._tokens_to_grid(mm_feature, "mm_feature")

        im_out = self.conv1_head(im_grid).reshape(B, -1)
        mm_out = self.conv2_head(mm_grid).reshape(B, -1)
        fused = torch.cat([im_out, mm_out], dim=-1)
        out = self.out(fused).sigmoid().clamp(0.01, 0.99)
        return out
# class ConvTextHead(nn.Module): | |
# def __init__(self, in_channels, text_channels, hidden_size): | |
# super().__init__() | |
# self.head = nn.Sequential( | |
# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.AdaptiveAvgPool2d(1), | |
# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=hidden_size, stride=1, padding=0), # 1x1 -> 1x1 | |
# ) | |
# self.text_head = nn.Sequential( | |
# nn.Linear(text_channels, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, hidden_size), | |
# ) | |
# | |
# def forward(self, feature, text_embedding=None): | |
# # assume sqrt image size | |
# B, L, C = feature.shape | |
# H = W = int(math.sqrt(L)) | |
# feature = feature.permute(0, 2, 1) | |
# feature = feature.view(B, C, H, W) | |
# feature = self.head(feature).view(B, -1) | |
# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False) | |
# text_embedding = self.text_head(text_embedding) | |
# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False) | |
# score = logits.sigmoid().clamp(0.01, 0.99) | |
# return score | |
# | |
# class LinearHead(nn.Module): | |
# def __init__(self, in_channels, hidden_size): | |
# super().__init__() | |
# self.head = nn.Sequential( | |
# nn.Linear(in_channels, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, 1), | |
# ) | |
# def forward(self, feature, text_embedding=None): | |
# out = self.head(feature).sigmoid().clamp(0.01, 0.99) | |
# return out | |
# class ConvMultiModalHead(nn.Module): | |
# def __init__(self, in_channels, mm_channels, hidden_size): | |
# super().__init__() | |
# self.image_head = nn.Sequential( | |
# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 | |
# nn.GroupNorm(num_groups=32, num_channels=hidden_size), | |
# nn.SiLU(), | |
# nn.AdaptiveAvgPool2d(1), | |
# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 | |
# ) | |
# self.mm_head = nn.Sequential( | |
# nn.Linear(mm_channels, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, hidden_size), | |
# ) | |
# | |
# def forward(self, feature, text_embedding=None): | |
# # assume sqrt image size | |
# B, L, C = feature.shape | |
# H = W = int(math.sqrt(L)) | |
# feature = feature.permute(0, 2, 1) | |
# feature = feature.view(B, C, H, W) | |
# feature = self.head(feature).view(B, -1) | |
# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False) | |
# text_embedding = self.text_head(text_embedding) | |
# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False) | |
# score = logits.sigmoid().clamp(0.01, 0.99) | |
# return score | |
# class TransformerTextHead(nn.Module): | |
# def __init__(self, in_channels, text_channels, hidden_size): | |
# super().__init__() | |
# | |
# self.transformer = nn.Sequential( | |
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), | |
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), | |
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), | |
# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), | |
# ) | |
# self.text_head = nn.Sequential( | |
# nn.Linear(text_channels, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, hidden_size), | |
# ) | |
# self.feature_head = nn.Sequential( | |
# nn.Linear(in_channels, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, hidden_size), | |
# ) | |
# self.cls_head = nn.Sequential( | |
# nn.Linear(hidden_size, hidden_size), | |
# nn.SiLU(), | |
# nn.Linear(hidden_size, 1), | |
# ) | |
# | |
# def forward(self, feature, text_embedding=None): | |
# # assume sqrt image size | |
# feature = self.feature_head(feature) | |
# text_embedding = self.text_head(text_embedding) | |
# tokens = torch.cat([feature, text_embedding], dim=1) | |
# tokens = self.transformer(tokens) | |
# cls_token = tokens | |
# logits = self.cls_head(cls_token) | |
# logits = torch.mean(logits, dim=1, keepdim=False) | |
# score = logits.sigmoid().clamp(0.01, 0.99) | |
# return score | |