|
|
import argparse |
|
|
from typing import Any, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from timm.models.layers import trunc_normal_ |
|
|
from copy import deepcopy |
|
|
import os |
|
|
import torch.backends.cudnn as cudnn |
|
|
|
|
|
import models.vision_transformer as vits |
|
|
|
|
|
class vit(nn.Module):
    """Wrap a Vision Transformer built by ``models.vision_transformer``.

    The backbone is created with ``patch_size=16``, optionally frozen, and
    optionally initialised from a checkpoint file.

    Args:
        model_size: Name of the builder to look up in ``models.vision_transformer``
            (e.g. ``"vit_small"``, ``"vit_base"``, ``"vit_large"``). A name
            without the ``"vit_"`` prefix (such as the default ``"base"``) is
            resolved to the prefixed builder when only that one exists.
        freeze_transformer: When true, every backbone parameter gets
            ``requires_grad = False``.
        pretrained_weights: Optional path to a checkpoint. It is loaded only if
            the path points to an existing file.

    Attributes set on the instance:
        model_size, freeze_transformer, pretrained_weights: the constructor
            arguments, stored unchanged.
        embedding_size: 384 / 768 / 1024 / 1536 for the small / base / large /
            giant sizes; not set for any other ``model_size``.
        transformer: the backbone module.

    Raises:
        KeyError: if no builder matching ``model_size`` exists.
    """

    # Embedding width per model size, keyed by the name without the "vit_" prefix.
    _EMBEDDING_SIZES = {"small": 384, "base": 768, "large": 1024, "giant": 1536}
    _ARCH_PREFIX = "vit_"

    def __init__(self, model_size="base", freeze_transformer=True, pretrained_weights=None):
        # The original called super(ibotvit, self).__init__(); ``ibotvit`` is not
        # defined in this class's scope, so construction failed with NameError.
        super().__init__()
        self.model_size = model_size
        self.freeze_transformer = freeze_transformer
        self.pretrained_weights = pretrained_weights

        prefix = self._ARCH_PREFIX
        size_key = model_size[len(prefix):] if model_size.startswith(prefix) else model_size
        if size_key in self._EMBEDDING_SIZES:
            self.embedding_size = self._EMBEDDING_SIZES[size_key]

        # Resolve the builder. Names are tried verbatim first so every value
        # that worked before still selects the same builder.
        builders = vits.__dict__
        arch = model_size
        if arch not in builders and (prefix + arch) in builders:
            arch = prefix + arch
        if arch not in builders:
            raise KeyError(
                "unknown model_size %r: no matching builder in models.vision_transformer"
                % (model_size,)
            )
        self.transformer = deepcopy(builders[arch](patch_size=16))

        if self.freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

        if self.pretrained_weights:
            if os.path.isfile(self.pretrained_weights):
                self._load_pretrained(self.pretrained_weights)
            else:
                # Previously this case was skipped silently, leaving the
                # backbone randomly initialised without any indication.
                print(
                    "pretrained weights not found at %r; backbone keeps its initial weights"
                    % (self.pretrained_weights,)
                )

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it."""
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_pretrained(self, path):
        """Load the checkpoint at ``path`` into the backbone (non-strict) and print the result."""
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(path, map_location="cpu")

        # Checkpoints may nest the weights under a "teacher" or "model" entry.
        if 'teacher' in state_dict:
            state_dict = state_dict['teacher']
        elif 'model' in state_dict:
            state_dict = state_dict['model']

        # Strip wrapper prefixes in two sequential passes ("teacher." first,
        # then "backbone.") so that "teacher.backbone.x" ends up as "x".
        state_dict = self._strip_prefix(state_dict, "teacher.")
        state_dict = self._strip_prefix(state_dict, "backbone.")

        # strict=False: missing/unexpected keys are reported in ``msg`` rather
        # than raising.
        msg = self.transformer.load_state_dict(state_dict, strict=False)
        print(self.model_size, msg)

    def forward(self, x):
        """Return the backbone's output for input ``x``."""
        return self.transformer(x)
|
|
|
|
|
|
|
|
|
|
|
def build_model(args):
    """Build the frozen ``vit_base`` feature extractor.

    Args:
        args: Object exposing ``pretrained_weights`` (checkpoint path or
            ``None``); no other attribute is read here.

    Returns:
        The constructed ``vit`` module, moved to the GPU when CUDA is
        available and left on the CPU otherwise.
    """
    net = vit("vit_base", freeze_transformer=True, pretrained_weights=args.pretrained_weights)

    # The original called net.cuda() unconditionally, which raises on hosts
    # without CUDA. On CUDA hosts the behaviour is unchanged.
    if torch.cuda.is_available():
        net.cuda()

    return net
|
|
|