Spaces:
Configuration error
Configuration error
File size: 4,033 Bytes
6858cdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os
import torch
from torch import nn
from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
from blip3o.utils import rank0_print
from tok.ta_tok import TextAlignedTokenizer
from tok.utils import ScalingLayer
class TATokVisionTower(nn.Module):
    """Vision tower that wraps a frozen ``TextAlignedTokenizer``.

    The wrapped tokenizer is loaded from the checkpoint named by
    ``vision_tower``. Loading happens in the constructor unless
    ``delay_load`` is set and the config does not indicate that vision tower
    weights are tunable; in that case ``load_model`` must be called later.
    """

    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
        """Create the tower and optionally load the tokenizer right away.

        Args:
            vision_tower: Checkpoint identifier handed to
                ``TextAlignedTokenizer.from_checkpoint``.
            vision_tower_cfg: Config object; only the optional attributes
                ``unfreeze_mm_vision_tower`` and ``mm_tunable_parts`` are
                read here.
            delay_load: When true, skip loading unless the config says the
                vision tower is tunable.
        """
        super().__init__()

        self.is_loaded = False
        self.config = None
        self.image_processor = SiglipImageProcessor()
        self.vision_tower_name = vision_tower

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Loading is deferred; `self.config` is still None at this point.
            self.cfg_only = self.config

    def load_model(self, device_map=None):
        """Load the tokenizer checkpoint, freeze it and cache its sizes.

        Calling this again after a successful load only logs a message.

        Args:
            device_map: Passed straight to ``.to(...)`` on the loaded
                tokenizer.
        """
        if self.is_loaded:
            rank0_print(f"{self.vision_tower_name} is already loaded, `load_model` called again, skipping.")
            return

        self.vision_tower = TextAlignedTokenizer.from_checkpoint(
            self.vision_tower_name, load_teacher=False
        ).to(device_map)
        self.vision_tower.bottleneck.regularizer.set_eval_deterministic(deterministic=True)
        self.vision_tower.input_type = 'rec'
        # NOTE(review): mean 0 / std 1 presumably makes the scaling layer a
        # no-op so inputs are used as produced by the image processor - confirm.
        self.vision_tower.scale_layer = ScalingLayer(mean=[0., 0., 0.], std=[1., 1., 1.])

        # The tower is used frozen and in eval mode.
        self.vision_tower.requires_grad_(False)
        self.vision_tower.eval()

        self.pool_scales = [1, 1, 2, 3]

        # Make the image processor produce images at the tokenizer's input size.
        input_size = self.vision_tower.input_size
        self.image_processor.size = (input_size, input_size)
        self.image_processor.crop_size = {'height': input_size, 'width': input_size}

        self.image_tokens = self.vision_tower.bottleneck_token_num
        self.bottleneck_dim = self.vision_tower.bottleneck_dim
        self.num_patches = self.image_tokens
        self.num_patches_per_side = int(self.num_patches ** 0.5)
        self.hidden_size = self.vision_tower.encoder_hidden_dim
        self.image_size = input_size

        self.is_loaded = True

    def get_embedding(self):
        """Return ``get_emb()`` of the tokenizer's bottleneck regularizer."""
        return self.vision_tower.bottleneck.regularizer.get_emb()

    def forward(self, images, pool_scale=1):
        """Run the tokenizer on a batch tensor or on a list of images.

        The effective pool scale is resolved in this order: the
        ``POOL_SCALE`` environment variable if set, otherwise the
        ``pool_scale`` argument, otherwise 1 when the argument is ``None``.

        Args:
            images: Either one tensor passed to the tokenizer as is, or a
                list of tensors; each list item gets a leading dimension
                added and is processed on its own.
            pool_scale: Requested pool scale; may be ``None``.

        Returns:
            A dict with ``"image_features"`` (the tokenizer's ``vq_feats``
            output cast back to the input dtype), ``"tokens"`` (its
            ``bottleneck_rep`` output) and ``"pool_scale"`` (the resolved
            value). For list input the first two entries are lists.

        Raises:
            ValueError: If ``POOL_SCALE`` or ``pool_scale`` cannot be
                converted to ``int``.
        """
        # pool_scale from ENV has the highest priority. The None fallback is
        # applied before int() so that pool_scale=None no longer raises.
        env_pool_scale = os.environ.get('POOL_SCALE')
        if env_pool_scale is not None:
            pool_scale = env_pool_scale
        elif pool_scale is None:
            pool_scale = 1
        pool_scale = int(pool_scale)

        if isinstance(images, list):
            image_features, tokens = [], []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    pool_scale=pool_scale,
                )
                image_features.append(image_forward_out['vq_feats'].to(image.dtype))
                tokens.append(image_forward_out['bottleneck_rep'])
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                pool_scale=pool_scale,
            )
            image_features = image_forward_outs['vq_feats'].to(images.dtype)
            tokens = image_forward_outs['bottleneck_rep']

        return {"image_features": image_features, "tokens": tokens, 'pool_scale': pool_scale}

    def _first_parameter(self):
        """Return the tokenizer's first parameter, or None if it has none."""
        for p in self.vision_tower.parameters():
            return p
        return None

    @property
    def dummy_feature(self):
        """A zero tensor of shape ``(1, hidden_size)`` on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the tokenizer's first parameter (None without parameters)."""
        p = self._first_parameter()
        return None if p is None else p.dtype

    @property
    def device(self):
        """Device of the tokenizer's first parameter (None without parameters)."""
        p = self._first_parameter()
        return None if p is None else p.device