File size: 8,658 Bytes
5644dea e3bf9ba 5644dea e3bf9ba 5644dea e3bf9ba 5644dea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import os
from typing import Optional, Tuple, Union, Dict
from PIL import Image
from functools import partial, reduce
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
import torch.distributed as dist
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
def rank0_print(*args):
    """Print ``args`` once per job instead of once per process.

    When no ``torch.distributed`` process group is initialized the arguments
    are printed unchanged. When one is initialized, only rank 0 prints (with a
    ``"Rank 0: "`` prefix) and every other rank stays silent.
    """
    if not dist.is_initialized():
        print(*args)
        return

    rank = dist.get_rank()
    if rank == 0:
        print(f"Rank {rank}: ", *args)
class BaseVisionTower(nn.Module):
    """Shared scaffolding for vision encoders wrapped as an ``nn.Module``.

    Subclasses implement ``load_model`` and ``_forward``. The properties below
    read ``self.vision_tower`` (and, before loading, ``self.cfg_only``), which
    this base class never assigns itself; subclasses are expected to set them.
    """

    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        """Record the tower name and the delay-load flag.

        Args:
            vision_tower_name: Identifier of the vision tower to load.
            vision_tower_cfg: Accepted for signature parity with subclasses;
                not read by the base class.
            delay_load: Stored as ``self.delay_load``; subclasses decide what
                to do with it.
        """
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower_name
        self.delay_load = delay_load

    # NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    # enforced at instantiation time; the ``NotImplementedError`` below is the
    # effective guard.
    @abstractmethod
    def load_model(self, device_map=None):
        """Load the underlying model. Must be overridden."""
        raise NotImplementedError("Subclasses must implement load_model")

    @abstractmethod
    def _forward(self, images):
        """Encode a batched image tensor. Must be overridden."""
        raise NotImplementedError("Subclasses must implement forward")

    def forward(self, images):
        """Encode ``images`` with ``_forward``.

        A list is encoded element by element, each element getting a leading
        dimension via ``unsqueeze(0)``, and a list of results is returned.
        Anything else is passed to ``_forward`` unchanged.
        """
        if isinstance(images, list):
            return [self._forward(image.unsqueeze(0)) for image in images]
        return self._forward(images)

    @property
    def dummy_feature(self):
        """Zero tensor of shape ``(1, hidden_size)`` on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the tower: its ``dtype`` attribute if present, otherwise the
        dtype of its first parameter, otherwise ``torch.float32``."""
        if hasattr(self.vision_tower, "dtype"):
            return self.vision_tower.dtype
        # Only the first parameter is needed; avoid materializing all of them.
        first_param = next(iter(self.vision_tower.parameters()), None)
        return first_param.dtype if first_param is not None else torch.float32

    @property
    def device(self):
        """Device of the tower: its ``device`` attribute if present, otherwise
        the device of its first parameter, otherwise CPU."""
        if hasattr(self.vision_tower, "device"):
            return self.vision_tower.device
        first_param = next(iter(self.vision_tower.parameters()), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    @property
    def config(self):
        """Config of the loaded tower, or ``self.cfg_only`` before loading."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size from the config, falling back to ``self._hidden_size``."""
        try:
            return self.config.hidden_size
        except AttributeError:
            # Config (or its ``hidden_size``) is unavailable; use the cached
            # value. Narrowed from a bare ``except`` so that interrupts and
            # unrelated errors are no longer swallowed.
            return self._hidden_size
class SigLipImageProcessor:
    """Minimal image preprocessor: RGB conversion, resize, rescale, normalize.

    The constructor stores the preprocessing settings; ``preprocess`` applies
    them. ``crop_size`` is normalized with ``get_size_dict`` and stored, but
    ``preprocess`` does not read it.
    """

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Dict[str, int] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
    ):
        # Fall back to a 384x384 crop when the caller supplies none.
        if crop_size is None:
            crop_size = {"height": 384, "width": 384}
        self.crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format

    def preprocess(self, images, return_tensors):
        """Run the preprocessing pipeline and wrap the result in a ``BatchFeature``.

        Args:
            images: A single ``PIL.Image.Image`` or an iterable of images
                (the iterable form is used for video frames).
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.

        Returns:
            ``BatchFeature`` with one key, ``"pixel_values"``.
        """
        if isinstance(images, Image.Image):
            frames = [images]
        else:
            # to adapt video data
            frames = [to_numpy_array(frame) for frame in images]
        assert isinstance(frames, list)

        fmt = self.data_format
        pipeline = (
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=fmt),
            partial(rescale, scale=self.rescale_factor, data_format=fmt),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=fmt),
            partial(to_channel_dimension_format, channel_dim=fmt, input_channel_dim=fmt),
        )
        # Each step is applied to every frame before the next step starts.
        for step in pipeline:
            frames = [step(frame) for frame in frames]

        return BatchFeature(data={"pixel_values": frames}, tensor_type=return_tensors)
class SigLipVisionTower(BaseVisionTower):
    """SigLIP vision encoder used as a frozen or tunable feature extractor.

    ``load_model`` loads a ``SiglipVisionModel``, removes its last encoder
    layer and replaces the pooling head with ``nn.Identity``. ``_forward``
    returns the last remaining hidden state.
    """

    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        """Create the tower and load weights now, or only the config when delayed.

        Args:
            vision_tower_name: Model id or path handed to ``from_pretrained``.
            vision_tower_cfg: Object optionally carrying
                ``unfreeze_mm_vision_tower`` and ``mm_tunable_parts``.
            delay_load: When true, weights are loaded only if the config
                indicates the vision tower is tunable; otherwise just the
                vision config is fetched.
        """
        super().__init__(vision_tower_name, vision_tower_cfg, delay_load)

        self.vision_tower_name = vision_tower_name
        self._image_size = 384
        self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower_name}")
            self.load_model()
        elif self.unfreeze_mm_vision_tower:
            # TODO: better detector is needed.
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # The ``config`` property returns ``self.cfg_only`` while the tower
            # is not loaded, so assigning ``self.config`` here raised
            # AttributeError. Fetch the vision config directly instead.
            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Load the SigLIP weights and prepare the tower for feature extraction.

        Args:
            device_map: Forwarded to ``SiglipVisionModel.from_pretrained``;
                ``None`` keeps the library default placement.
        """
        self.vision_model = "siglip"
        rank0_print(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.output_tokens = True

        self._hidden_size = self.vision_tower.config.hidden_size
        self.image_processor = SigLipImageProcessor()

        # Drop the final encoder layer and neutralize the pooling head.
        del self.vision_tower.vision_model.encoder.layers[-1:]
        self.vision_tower.vision_model.head = nn.Identity()
        # Gradients are enabled only when the tower is meant to be tuned.
        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)

        self.is_loaded = True

    def _forward(self, images):
        """Return the last hidden state of the (truncated) encoder for ``images``.

        Inputs are moved to the tower's device and dtype first. Gradient
        tracking follows ``self.unfreeze_mm_vision_tower``.
        """
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            outputs = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = outputs.hidden_states[-1]
        return image_features

    @property
    def dummy_feature(self):
        """Zero tensor of shape ``(1, hidden_size)`` on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the tower's first parameter (``None`` if it has none)."""
        first_param = next(iter(self.vision_tower.parameters()), None)
        return first_param.dtype if first_param is not None else None

    @property
    def device(self):
        """Device of the tower's first parameter (``None`` if it has none)."""
        first_param = next(iter(self.vision_tower.parameters()), None)
        return first_param.device if first_param is not None else None

    @property
    def hidden_size(self):
        """Hidden size taken from the active config."""
        return self.config.hidden_size

    # NOTE(review): the patch counts below are hard-coded from 336 // 14 (24 per
    # side, 576 total) while ``image_size`` reports 384 (384 // 14 would be 27).
    # Presumably downstream code resamples features to a 24x24 grid; confirm
    # before deriving these from the config.
    @property
    def num_patches(self):
        """Hard-coded number of patches: ``(336 // 14) ** 2``."""
        return (336 // 14) ** 2

    @property
    def num_patches_per_side(self):
        """Hard-coded number of patches per side: ``336 // 14``."""
        return 336 // 14

    @property
    def image_size(self):
        """Hard-coded input resolution."""
        return 384
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build a ``SigLipVisionTower`` for the tower named in ``vision_tower_cfg``.

    The name is read from ``vision_tower_cfg.mm_vision_tower`` or, when that
    attribute is absent, from ``vision_tower_cfg.vision_tower``. Every named
    tower is built as a ``SigLipVisionTower``; no check on the name is made.

    Args:
        vision_tower_cfg: Config object carrying the tower name.
        **kwargs: Forwarded to ``SigLipVisionTower`` (e.g. ``delay_load``).

    Returns:
        The constructed ``SigLipVisionTower``.

    Raises:
        ValueError: If the config does not name a vision tower.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if not vision_tower:
        # Previously this case crashed with a TypeError inside os.path.exists.
        raise ValueError(f"Unknown vision tower: {vision_tower}")
    return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|