ZJU-AI4H committed (verified)
Commit ee9327c · 1 Parent(s): 02a9fd2

Upload folder using huggingface_hub
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "HulumedVisionEncoderModel"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_hulumed_encoder.HulumedVisionEncoderConfig",
+     "AutoModel": "modeling_hulumed_encoder.HulumedVisionEncoderModel"
+   },
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 1152,
+   "intermediate_size": 4304,
+   "layer_norm_eps": 1e-06,
+   "model_type": "hulumed_vision_encoder",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 27,
+   "patch_size": 14,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.46.3"
+ }
+
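The `auto_map` entries above register the custom config and model classes with the Auto* machinery, so the encoder can be loaded directly from the Hub with `trust_remote_code=True`. A minimal loading sketch; the repository id below is a placeholder, not the actual repo path.

import torch
from transformers import AutoConfig, AutoModel

repo_id = "ZJU-AI4H/HuluMed-vision-encoder"  # placeholder repo id, substitute the real one
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
print(config.hidden_size, config.num_hidden_layers)  # 1152, 27 per this config.json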
configuration_hulumed_encoder.py ADDED
@@ -0,0 +1,49 @@
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
+ # Below is the original copyright:
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """HuluMed vision encoder model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class HulumedVisionEncoderConfig(PretrainedConfig):
+
+     model_type = "hulumed_vision_encoder"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_channels=3,
+         patch_size=16,
+         hidden_act="gelu_pytorch_tanh",
+         layer_norm_eps=1e-6,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.attention_dropout = attention_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
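The constructor defaults above match the SigLIP-base sizes; the shipped config.json overrides them with the larger dimensions used by this checkpoint. A small sketch, assuming the file is importable from the repository directory:

from configuration_hulumed_encoder import HulumedVisionEncoderConfig

config = HulumedVisionEncoderConfig(
    hidden_size=1152,
    intermediate_size=4304,
    num_hidden_layers=27,
    num_attention_heads=16,
    patch_size=14,
)
print(config.model_type)                                 # hulumed_vision_encoder
print(config.hidden_size // config.num_attention_heads)  # per-head dimension: 72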
image_processing_hulumed.py ADDED
@@ -0,0 +1,473 @@
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
+ # Below is the original copyright:
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Image processor class for HuluMed."""
+
+ import math
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+
+ import torch
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.image_utils import ImageInput
+ from transformers.image_transforms import (
+     convert_to_rgb,
+     resize,
+     to_channel_dimension_format,
+ )
+ from transformers.image_utils import (
+     OPENAI_CLIP_MEAN,
+     OPENAI_CLIP_STD,
+     ChannelDimension,
+     ImageInput,
+     PILImageResampling,
+     VideoInput,
+     get_image_size,
+     infer_channel_dimension_format,
+     is_scaled_image,
+     is_valid_image,
+     make_list_of_images,
+     to_numpy_array,
+ )
+ from transformers.utils import TensorType, is_vision_available, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ if is_vision_available():
+     from PIL import Image
+
+
+ def is_valid_video(video) -> bool:
+     if isinstance(video, (list, tuple)):
+         return all(is_valid_image(frame) for frame in video)
+     elif isinstance(video, np.ndarray):
+         return video.ndim == 4
+     elif isinstance(video, torch.Tensor):
+         return video.ndim == 4
+     return False
+
+
+ def make_batched_images(images) -> List[List[ImageInput]]:
+     """
+     Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+     Args:
+         images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+             The input image.
+
+     Returns:
+         list: A list of images.
+     """
+     if isinstance(images, (list, tuple)):
+         # list of images/videos
+         if not all(is_valid_video(image) or is_valid_image(image) for image in images):
+             raise ValueError(f"Could not make batched images from {images}")
+         return images
+     elif is_valid_video(images) or is_valid_image(images):
+         # single image/video
+         return [images]
+
+     raise ValueError(f"Could not make batched images from {images}")
+
+
+ def simple_batched_resize(
+     images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
+ ):
+     min_pixels = min_tokens * factor * factor
+     max_pixels = max_tokens * factor * factor
+
+     num_images = 0
+     for image in images:
+         if is_valid_video(image):
+             num_images += len(image)
+         else:
+             num_images += 1
+
+     image_sizes = []
+     for image in images:
+         if is_valid_video(image):
+             image = image[0]
+         if isinstance(image, Image.Image):
+             height, width = image.size
+         else:
+             height, width = get_image_size(image, channel_dim=input_data_format)
+         image_sizes.append([height, width])
+
+     tmp_image_sizes = []
+     for height, width in image_sizes:
+         h_bar = round(height / factor) * factor
+         w_bar = round(width / factor) * factor
+         if h_bar * w_bar > (max_pixels // num_images):
+             beta = math.sqrt((height * width) / (max_pixels // num_images))
+             h_bar = math.floor(height / beta / factor) * factor
+             w_bar = math.floor(width / beta / factor) * factor
+         # per image min_pixels
+         if h_bar * w_bar < min_pixels:
+             beta = math.sqrt(min_pixels / (height * width))
+             h_bar = math.ceil(height * beta / factor) * factor
+             w_bar = math.ceil(width * beta / factor) * factor
+         tmp_image_sizes.append((h_bar, w_bar))
+     image_sizes = tmp_image_sizes
+     return image_sizes
+
+
+ def batched_resize(
+     images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
+ ):
+     image_sizes = []
+     for image in images:
+         if is_valid_video(image):
+             num_frame = len(image)
+             image = image[0]
+         else:
+             num_frame = 1
+         if isinstance(image, Image.Image):
+             height, width = image.size
+         else:
+             height, width = get_image_size(image, channel_dim=input_data_format)
+         image_sizes.append([num_frame, height, width])
+
+     # global max_pixels
+     smart_scale_factors = 1.0
+     total_tokens = 0
+     for (num_frame, height, width), factor in zip(image_sizes, factors):
+         total_tokens += num_frame * math.ceil(height / factor) * math.ceil(width / factor)
+
+     # TODO: add min_pixels
+     if total_tokens > max_tokens:
+         beta = math.sqrt(total_tokens / max_tokens)
+         tmp_image_sizes = []
+         for (_, height, width), factor in zip(image_sizes, factors):
+             h_bar = math.floor(height / beta / factor) * factor
+             w_bar = math.floor(width / beta / factor) * factor
+             tmp_image_sizes.append((h_bar, w_bar))
+         image_sizes = tmp_image_sizes
+     else:
+         tmp_image_sizes = []
+         for (_, height, width), factor in zip(image_sizes, factors):
+             height = round(height / factor) * factor
+             width = round(width / factor) * factor
+             tmp_image_sizes.append((height, width))
+         image_sizes = tmp_image_sizes
+
+     return image_sizes
+
+
+ class HulumedImageProcessor(BaseImageProcessor):
+     r"""
+     Constructs a HuluMed image processor that dynamically resizes images based on the original images.
+
+     Args:
+         do_resize (`bool`, *optional*, defaults to `True`):
+             Whether to resize the image's (height, width) dimensions.
+         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+             Resampling filter to use when resizing the image.
+         do_rescale (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the image by the specified scale `rescale_factor`.
+         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+             Scale factor to use if rescaling the image.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image.
+         image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+             Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+         image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+             Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+         do_convert_rgb (`bool`, *optional*, defaults to `True`):
+             Whether to convert the image to RGB.
+         min_tokens (`int`, *optional*, defaults to `4 * 4`):
+             The minimum number of vision tokens an image is resized to.
+         max_tokens (`int`, *optional*, defaults to `16384`):
+             The maximum number of vision tokens an image is resized to.
+         patch_size (`int`, *optional*, defaults to 14):
+             The spatial patch size of the vision encoder.
+     """
+
+     model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
+
+     def __init__(
+         self,
+         do_resize: bool = True,
+         resample: PILImageResampling = PILImageResampling.BICUBIC,
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = True,
+         min_tokens: int = 4 * 4,
+         max_tokens: int = 16384,
+         patch_size: int = 14,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.do_resize = do_resize
+         self.resample = resample
+         self.do_rescale = do_rescale
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+         self.min_tokens = min_tokens
+         self.max_tokens = max_tokens
+         self.patch_size = patch_size
+         self.do_convert_rgb = do_convert_rgb
+
+     def _preprocess(
+         self,
+         images: Union[ImageInput, VideoInput],
+         target_size: List[int],
+         merge_size: int = 1,
+         do_resize: bool = None,
+         resample: PILImageResampling = None,
+         do_rescale: bool = None,
+         rescale_factor: float = None,
+         do_normalize: bool = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = None,
+         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         """
+         Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+         Args:
+             images (`ImageInput`):
+                 Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+             target_size (`List[int]`):
+                 The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
+             merge_size (`int`, *optional*, defaults to `1`):
+                 The merge size after the vision encoder.
+             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                 Whether to resize the image.
+             resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                 Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image.
+             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                 Scale factor to use if rescaling the image.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                 Whether to convert the image to RGB.
+             data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         images = make_list_of_images(images)
+
+         if do_convert_rgb:
+             images = [convert_to_rgb(image) for image in images]
+
+         # All transformations expect numpy arrays.
+         images = [to_numpy_array(image) for image in images]
+
+         if is_scaled_image(images[0]) and do_rescale:
+             logger.warning_once(
+                 "It looks like you are trying to rescale already rescaled images. If the input"
+                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+             )
+         if input_data_format is None:
+             # We assume that all images have the same channel dimension format.
+             input_data_format = infer_channel_dimension_format(images[0])
+
+         height, width = get_image_size(images[0], channel_dim=input_data_format)
+         resized_height, resized_width = height, width
+         processed_images = []
+         for image in images:
+             if do_resize:
+                 resized_height, resized_width = target_size
+                 image = resize(
+                     image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
+                 )
+
+             if do_rescale:
+                 image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+
+             if do_normalize:
+                 image = self.normalize(
+                     image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+                 )
+
+             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+             processed_images.append(image)
+
+         patches = np.array(processed_images)
+         if data_format == ChannelDimension.LAST:
+             patches = patches.transpose(0, 3, 1, 2)
+         t = patches.shape[0]
+         channel = patches.shape[1]
+         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+         patches = patches.reshape(
+             t,
+             channel,
+             grid_h // merge_size,
+             merge_size,
+             self.patch_size,
+             grid_w // merge_size,
+             merge_size,
+             self.patch_size,
+         )
+         patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
+         flatten_patches = patches.reshape(
+             t * grid_h * grid_w, channel * self.patch_size * self.patch_size
+         )
+
+         return flatten_patches, (t, grid_h, grid_w)
+
+     def preprocess(
+         self,
+         images: ImageInput,
+         do_resize: bool = None,
+         resample: PILImageResampling = None,
+         do_rescale: bool = None,
+         rescale_factor: float = None,
+         do_normalize: bool = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_convert_rgb: bool = None,
+         merge_size: Optional[Union[int, List[int]]] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         """
+         Args:
+             images (`ImageInput`):
+                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                 Whether to resize the image.
+             resample (`int`, *optional*, defaults to `self.resample`):
+                 Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                 has an effect if `do_resize` is set to `True`.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image.
+             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                 Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                 `True`.
+             do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                 Whether to convert the image to RGB.
+             return_tensors (`str` or `TensorType`, *optional*):
+                 The type of tensors to return. Can be one of:
+                 - Unset: Return a list of `np.ndarray`.
+                 - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                 - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         do_resize = do_resize if do_resize is not None else self.do_resize
+         resample = resample if resample is not None else self.resample
+         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+         image_mean = image_mean if image_mean is not None else self.image_mean
+         image_std = image_std if image_std is not None else self.image_std
+         merge_size = merge_size if merge_size is not None else self.merge_size
+         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+         images = make_batched_images(images)
+
+         if isinstance(merge_size, (list, tuple)):
+             assert len(merge_size) == len(images), "Merge size must be the same length as images."
+             merge_sizes = merge_size
+         else:
+             merge_sizes = [merge_size for _ in images]
+
+         if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
+             target_sizes = simple_batched_resize(
+                 images,
+                 factor=self.patch_size * merge_sizes[0],
+                 min_tokens=self.min_tokens,
+                 max_tokens=self.max_tokens,
+                 input_data_format=input_data_format,
+             )
+         else:
+             target_sizes = batched_resize(
+                 images,
+                 factors=[self.patch_size * merge_size for merge_size in merge_sizes],
+                 min_tokens=self.min_tokens,
+                 max_tokens=self.max_tokens,
+                 input_data_format=input_data_format,
+             )
+
+         pixel_values, grid_sizes = [], []
+         for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
+             patches, grid_size = self._preprocess(
+                 image,
+                 target_size=target_size,
+                 merge_size=merge_size,
+                 do_resize=do_resize,
+                 resample=resample,
+                 do_rescale=do_rescale,
+                 rescale_factor=rescale_factor,
+                 do_normalize=do_normalize,
+                 image_mean=image_mean,
+                 image_std=image_std,
+                 data_format=data_format,
+                 do_convert_rgb=do_convert_rgb,
+                 input_data_format=input_data_format,
+             )
+             pixel_values.append(patches)
+             grid_sizes.append(grid_size)
+
+         pixel_values = np.concatenate(pixel_values, axis=0)
+         grid_sizes = np.array(grid_sizes)
+         merge_sizes = np.array(merge_sizes)
+
+         data = {
+             "pixel_values": pixel_values,
+             "grid_sizes": grid_sizes,
+             "merge_sizes": merge_sizes,
+         }
+
+         return BatchFeature(data=data, tensor_type=return_tensors)
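Usage sketch for the processor defined above. `preprocess` flattens every image into rows of `channels * patch_size * patch_size` values and reports a `(t, grid_h, grid_w)` grid per input; `merge_size` is supplied at call time (the constructor does not store a default for it). The toy image below is an assumption for illustration only.

import numpy as np
from PIL import Image
from image_processing_hulumed import HulumedImageProcessor

processor = HulumedImageProcessor(patch_size=14, min_tokens=16, max_tokens=16384)
image = Image.fromarray(np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8))

batch = processor.preprocess(image, merge_size=1, return_tensors="pt")
print(batch["pixel_values"].shape)  # (1024, 588): 32 x 32 patches, 3 * 14 * 14 values each
print(batch["grid_sizes"])          # tensor([[ 1, 32, 32]]) -> (t, grid_h, grid_w)
print(batch["merge_sizes"])         # tensor([1])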
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:758ae92931ff54c6d278664af3fed5a452f83a2e89f534ab2e3f4ac0c6e9c061
+ size 824342816
modeling_hulumed_encoder.py ADDED
@@ -0,0 +1,534 @@
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
+ # Below is the original copyright:
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch HuluMed vision encoder model."""
+
+ import importlib.util
+ import os.path as osp
+ import math
+ import warnings
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import is_flash_attn_2_available
+
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_varlen_func
+ else:
+     flash_attn_varlen_func = None
+
+ try:
+     from .configuration_hulumed_encoder import HulumedVisionEncoderConfig
+ except ImportError:
+     spec = importlib.util.spec_from_file_location(
+         "configuration_hulumed_encoder",
+         osp.join(osp.dirname(__file__), "configuration_hulumed_encoder.py"),
+     )
+     configuration_hulumed_encoder = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(configuration_hulumed_encoder)
+     HulumedVisionEncoderConfig = getattr(
+         configuration_hulumed_encoder,
+         "HulumedVisionEncoderConfig",
+     )
+
+
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.0))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+
+
+ def trunc_normal_tf_(
+     tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+ ) -> torch.Tensor:
+     """Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \\leq \text{mean} \\leq b`.
+
+     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+     and the result is subsequently scaled and shifted by the mean and std args.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         _trunc_normal_(tensor, 0, 1.0, a, b)
+         tensor.mul_(std).add_(mean)
+
+
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         with torch.no_grad():
+             tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         with torch.no_grad():
+             tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ def default_flax_embed_init(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+     orig_dtype = tensor.dtype
+     tensor = tensor.float()
+     cos = freqs.cos()
+     sin = freqs.sin()
+     cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     output = output.to(orig_dtype)
+     return output
+
+
+ class VisionRotaryEmbedding(nn.Module):
+
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+
+ class HulumedVisionEmbeddings(nn.Module):
+
+     def __init__(self, config: HulumedVisionEncoderConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = hidden_states.view(
+             -1, self.config.num_channels, self.patch_size, self.patch_size
+         )
+         patch_embeds = self.patch_embedding(hidden_states)  # shape = [*, width, grid, grid]
+         # embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         embeddings = patch_embeds.view(-1, self.embed_dim)
+
+         return embeddings
+
+
+ class VisionAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """Input shape: Time x Channel"""
+
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         # Block-diagonal mask: each sample in the packed sequence may only attend to itself.
+         attention_mask = torch.zeros([1, q_len, q_len], device=query_states.device, dtype=torch.bool)
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
+         # Exclude cross-sample positions before the softmax.
+         attn_weights = attn_weights.masked_fill(~attention_mask, torch.finfo(attn_weights.dtype).min)
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(q_len, -1)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionFlashAttention2(VisionAttention):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         # Flash attention requires the input to have the shape
+         # batch_size x seq_length x head_dim x hidden_dim
+         # therefore we just need to keep the original shape
+         query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+         key_states = key_states.view(q_len, self.num_heads, self.head_dim)
+         value_states = value_states.view(q_len, self.num_heads, self.head_dim)
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+         attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+             q_len, -1
+         )
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VisionSdpaAttention(VisionAttention):
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
+         key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
+         value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
+
+         query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+         attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
+         for i in range(1, len(cu_seqlens)):
+             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+
+         query_states = query_states.transpose(0, 1)
+         key_states = key_states.transpose(0, 1)
+         value_states = value_states.transpose(0, 1)
+         attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
+         attn_output = attn_output.transpose(0, 1)
+         attn_output = attn_output.reshape(seq_length, -1)
+         attn_output = self.out_proj(attn_output)
+         return attn_output
+
+
+ VISION_ATTENTION_CLASSES = {
+     "eager": VisionAttention,
+     "flash_attention_2": VisionFlashAttention2,
+     "sdpa": VisionSdpaAttention,
+ }
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Hulumed
+ class HulumedVisionMLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ class HulumedVisionEncoderLayer(nn.Module):
+
+     def __init__(self, config: HulumedVisionEncoderConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = HulumedVisionMLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     # Ignore copy
+     def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+         hidden_states = hidden_states + self.self_attn(
+             self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+         )
+         hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
+         return hidden_states
+
+
+ class HulumedVisionTransformerEncoder(nn.Module):
+
+     def __init__(self, config: HulumedVisionEncoderConfig):
+         super().__init__()
+         self.config = config
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+         self.layers = nn.ModuleList([HulumedVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def rot_pos_emb(self, grid_sizes, merge_sizes):
+         pos_ids = []
+         for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
+             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+             hpos_ids = hpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+             hpos_ids = hpos_ids.flatten()
+
+             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+             wpos_ids = wpos_ids.reshape(
+                 h // merge_size,
+                 merge_size,
+                 w // merge_size,
+                 merge_size,
+             )
+             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+             wpos_ids = wpos_ids.flatten()
+             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+         pos_ids = torch.cat(pos_ids, dim=0)
+         max_grid_size = grid_sizes[:, 1:].max()
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+
+         return rotary_pos_emb
+
+     def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
+         rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
+
+         cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         for blk in self.layers:
+             if self.gradient_checkpointing and self.training:
+                 hidden_states = self._gradient_checkpointing_func(
+                     blk.__call__,
+                     hidden_states,
+                     cu_seqlens,
+                     rotary_pos_emb
+                 )
+             else:
+                 hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+         return hidden_states
+
+
+ class HulumedVisionEncoderModel(PreTrainedModel):
+
+     config_class = HulumedVisionEncoderConfig
+     base_model_prefix = "hulumed"
+     main_input_name = "pixel_values"
+     supports_gradient_checkpointing = True
+     _no_split_modules = [
+         "HulumedVisionEncoderLayer",
+         "HulumedVisionEmbeddings",
+     ]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+
+     def __init__(self, config: HulumedVisionEncoderConfig):
+         super().__init__(config=config)
+         embed_dim = config.hidden_size
+
+         self.embeddings = HulumedVisionEmbeddings(config)
+         self.encoder = HulumedVisionTransformerEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+         self.post_init()
+
+     def forward(self, pixel_values, grid_sizes, merge_sizes=None) -> torch.Tensor:
+         hidden_states = self.embeddings(pixel_values)
+         hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
+         hidden_states = self.post_layernorm(hidden_states)
+
+         hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
+         outputs = []
+
+         for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
+             # NOTE: previous implementation, which supports downsampling with any factor
+             c = hidden_states.shape[-1]
+             hidden_states = hidden_states.view(
+                 grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
+             ).permute(0, 1, 3, 2, 4, 5)
+             hidden_states = hidden_states.reshape(
+                 grid_size[0], grid_size[1], grid_size[2], c
+             ).permute(0, 3, 1, 2)
+             hidden_states = torch.nn.functional.interpolate(
+                 hidden_states,
+                 size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
+                 mode='bilinear'
+             )
+             hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)
+
+             # NOTE: simplified implementation, which only supports downsampling with an integer factor
+             # NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2 but may cause slightly different results
+             # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
+             # hidden_states = hidden_states.mean(dim=1)
+
+             outputs.append(hidden_states)
+
+         return torch.cat(outputs, dim=0)
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Embedding):
+             default_flax_embed_init(module.weight)
+         elif isinstance(module, VisionAttention):
+             nn.init.xavier_uniform_(module.q_proj.weight)
+             nn.init.xavier_uniform_(module.k_proj.weight)
+             nn.init.xavier_uniform_(module.v_proj.weight)
+             nn.init.xavier_uniform_(module.out_proj.weight)
+             nn.init.zeros_(module.q_proj.bias)
+             nn.init.zeros_(module.k_proj.bias)
+             nn.init.zeros_(module.v_proj.bias)
+             nn.init.zeros_(module.out_proj.bias)
+         elif isinstance(module, HulumedVisionMLP):
+             nn.init.xavier_uniform_(module.fc1.weight)
+             nn.init.xavier_uniform_(module.fc2.weight)
+             nn.init.normal_(module.fc1.bias, std=1e-6)
+             nn.init.normal_(module.fc2.bias, std=1e-6)
+         elif isinstance(module, (nn.Linear, nn.Conv2d)):
+             lecun_normal_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
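End-to-end sketch tying the image processor to the encoder. The model consumes the flattened patches plus `grid_sizes`/`merge_sizes` and, after bilinear pooling over each `merge_size x merge_size` block, returns one feature row per merged token. A deliberately tiny, randomly initialised config is assumed here so the snippet runs quickly; the released checkpoint uses the sizes in config.json.

import numpy as np
import torch
from PIL import Image

from configuration_hulumed_encoder import HulumedVisionEncoderConfig
from image_processing_hulumed import HulumedImageProcessor
from modeling_hulumed_encoder import HulumedVisionEncoderModel

# Toy-sized config (assumption for the sketch); the attention backend
# (eager / sdpa / flash_attention_2) is resolved by transformers at init.
config = HulumedVisionEncoderConfig(
    hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, patch_size=14,
)
model = HulumedVisionEncoderModel(config).eval()

processor = HulumedImageProcessor(patch_size=14)
image = Image.fromarray(np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8))
batch = processor.preprocess(image, merge_size=2, return_tensors="pt")

with torch.no_grad():
    features = model(batch["pixel_values"], batch["grid_sizes"], batch["merge_sizes"])

# 448x448 -> 32x32 patches, merged 2x2 -> 16x16 = 256 output tokens
print(features.shape)  # torch.Size([256, 64])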
preprocessor_config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_hulumed.HulumedImageProcessor"
+   },
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "HulumedImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "max_tokens": 16384,
+   "min_tokens": 16,
+   "patch_size": 14,
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098
+ }
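For reference, the token budget above translates into a pixel budget at preprocessing time: with `patch_size` 14 and a call-time `merge_size` of m, an image is resized so that its area stays between roughly `min_tokens * (14 * m)^2` and `max_tokens * (14 * m)^2` pixels (the max budget is shared across all images in a call). The 0.5 mean/std here override the OPENAI_CLIP defaults that the processor falls back to. A small arithmetic sketch, assuming a single image and `merge_size = 1`:

# Pixel budget implied by this preprocessor config (single image, merge_size = 1).
patch_size, merge_size = 14, 1
min_tokens, max_tokens = 16, 16384

factor = patch_size * merge_size           # resize granularity: 14 px
min_pixels = min_tokens * factor * factor  # 16 * 196 = 3,136 px
max_pixels = max_tokens * factor * factor  # 16384 * 196 = 3,211,264 px (~3.2 MP)
print(factor, min_pixels, max_pixels)

assert abs(0.00392156862745098 - 1 / 255) < 1e-12  # rescale_factor is 1/255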