# PixNerd/src/models/encoder.py
import copy
import torch
import torch.nn as nn
import timm
from torchvision.transforms import Normalize
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
import os
class IndentityMapping(nn.Module):
    """Pass-through "encoder": flattens an image into a per-pixel token sequence."""
    def __init__(self):
        super().__init__()

    def forward(self, x, resize=True):
        # (B, C, H, W) -> (B, H*W, C): one token per pixel, channels as features.
        b, c, h, w = x.shape
        x = x.reshape(b, c, h * w).transpose(1, 2)
        return x
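
# Usage sketch (illustrative only; the tensor sizes are arbitrary):
#   mapper = IndentityMapping()
#   tokens = mapper(torch.randn(2, 3, 16, 16))   # -> (2, 256, 3): one token per pixel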
class DINOv2(nn.Module):
    def __init__(self, weight_path: str, base_patch_size=16):
        super(DINOv2, self).__init__()
        # weight_path is "<local hub repo dir>/<entrypoint>", loaded via torch.hub with source="local".
        directory = os.path.dirname(weight_path)
        weight_path = os.path.basename(weight_path)
        self.encoder = torch.hub.load(
            directory,
            weight_path,
            source="local",
            skip_validation=True
        )
        self.encoder = self.encoder.to(torch.bfloat16)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()
        self.base_patch_size = base_patch_size
        self.encoder.compile()

    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
    def forward(self, x, resize=True):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        if resize:
            # DINOv2 uses a patch size of 14, so rescale the input to keep its token grid
            # aligned with the base patch grid of the input resolution.
            x = torch.nn.functional.interpolate(x, (int(14 * h / self.base_patch_size), int(14 * w / self.base_patch_size)), mode='bicubic')
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        feature = feature.to(torch.bfloat16)
        return feature
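
# Usage sketch (the local hub path is a hypothetical example; DINOv2.__init__ splits it
# into "<repo_dir>"/"<entrypoint>" for torch.hub.load(source="local")):
#   enc = DINOv2("/path/to/dinov2/dinov2_vitb14", base_patch_size=16).cuda()
#   feats = enc(torch.rand(1, 3, 256, 256, device='cuda'))   # -> (1, N, C) patch tokens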
from transformers import CLIPModel, CLIPTokenizer
class CLIP(nn.Module):
    def __init__(self, weight_path: str):
        super(CLIP, self).__init__()
        self.model = CLIPModel.from_pretrained(weight_path).to(torch.bfloat16)
        self.tokenizer = CLIPTokenizer.from_pretrained(weight_path)
        self.height = self.model.config.vision_config.image_size
        self.width = self.model.config.vision_config.image_size
        self.model.vision_model.compile()
        self.model.text_model.compile()

    def forward(self, x, text, resize=True):
        tokens = self.tokenizer(text, truncation=True, return_tensors='pt', padding="max_length", max_length=self.tokenizer.model_max_length).input_ids.cuda()
        text_output = self.model.text_model(input_ids=tokens).last_hidden_state
        text_output = self.model.text_projection(text_output)
        text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2)
        if resize:
            x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic')
        x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
        # Drop the CLS token and keep only per-patch embeddings.
        vision_output = self.model.vision_model(x).last_hidden_state[:, 1:]
        vision_output = self.model.visual_projection(vision_output)
        vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2)
        # Patch-token x text-token cosine-similarity map: (B, num_patches, num_text_tokens).
        output = torch.bmm(vision_output, text_output.transpose(1, 2))
        return output
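
# Usage sketch (the checkpoint name is an assumption; the weights are cast to bfloat16,
# so the call is wrapped in autocast here):
#   clip = CLIP("openai/clip-vit-base-patch16").cuda()
#   images = torch.rand(2, 3, 256, 256, device='cuda')
#   with torch.autocast('cuda', dtype=torch.bfloat16):
#       sim = clip(images, ["a photo of a dog", "a photo of a cat"])
#   # sim: (2, num_patches, num_text_tokens) patch-text similarity map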
from transformers import SiglipModel, GemmaTokenizer, SiglipTokenizer
class SigLIP(nn.Module):
    def __init__(self, weight_path: str):
        super(SigLIP, self).__init__()
        # SigLIP2 checkpoints ship a Gemma tokenizer; the original SigLIP uses its own tokenizer.
        if "siglip2" in weight_path:
            self.tokenizer = GemmaTokenizer.from_pretrained(weight_path)
        else:
            self.tokenizer = SiglipTokenizer.from_pretrained(weight_path)
        self.model = SiglipModel.from_pretrained(weight_path).to(torch.bfloat16)
        # Target resolution of the vision tower, needed by forward() when resize=True.
        self.height = self.model.config.vision_config.image_size
        self.width = self.model.config.vision_config.image_size
        # SigLIP normalizes images to [-1, 1].
        self.mean = 0.5
        self.std = 0.5
        self.model.vision_model.compile()
        self.model.text_model.compile()

    def forward(self, x, text, resize=True):
        tokens = self.tokenizer(text, truncation=True, return_tensors='pt', padding="max_length", max_length=64).input_ids.cuda()
        text_output = self.model.text_model(input_ids=tokens).last_hidden_state
        text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2)
        if resize:
            x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic')
        x = (x - self.mean) / self.std
        vision_output = self.model.vision_model(x).last_hidden_state
        vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2)
        # Patch-token x text-token similarity map: (B, num_patches, num_text_tokens).
        output = torch.bmm(vision_output, text_output.transpose(1, 2))
        return output
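
# Usage sketch (the checkpoint name is an assumption; text is padded to 64 tokens):
#   siglip = SigLIP("google/siglip-base-patch16-256").cuda()
#   with torch.autocast('cuda', dtype=torch.bfloat16):
#       sim = siglip(torch.rand(2, 3, 256, 256, device='cuda'), ["a dog", "a cat"])
#   # sim: (2, num_patches, 64) patch-text similarity map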
from transformers import SiglipVisionModel
class SigLIPVision(nn.Module):
    def __init__(self, weight_path: str, base_patch_size=16):
        super(SigLIPVision, self).__init__()
        self.model = SiglipVisionModel.from_pretrained(weight_path).to(torch.bfloat16)
        self.height = self.model.config.image_size
        self.width = self.model.config.image_size
        self.patch_size = self.model.config.patch_size
        self.base_patch_size = base_patch_size
        self.model.compile()
        # SigLIP normalizes images to [-1, 1].
        self.mean = 0.5
        self.std = 0.5

    def forward(self, x, resize=True):
        if resize:
            # Rescale so that the SigLIP patch grid matches the base patch grid of the input.
            h, w = x.shape[-2:]
            new_h = int(self.patch_size * h / self.base_patch_size)
            new_w = int(self.patch_size * w / self.base_patch_size)
            x = torch.nn.functional.interpolate(x, (new_h, new_w), mode='bicubic')
        x = (x - self.mean) / self.std
        vision_output = self.model.vision_model(x).last_hidden_state
        return vision_output
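
# Usage sketch (the checkpoint name is an assumption):
#   vis = SigLIPVision("google/siglip-base-patch16-256", base_patch_size=16).cuda()
#   with torch.autocast('cuda', dtype=torch.bfloat16):
#       feats = vis(torch.rand(1, 3, 256, 256, device='cuda'))   # -> (1, num_patches, hidden_size)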