shai commited on
Commit
4e6fc5a
·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ library_name: transformers
5
+ license: apache-2.0
6
+ tags:
7
+ - gpt
8
+ - llm
9
+ - multimodal large language model
10
+ thumbnail: >-
11
+ https://h2o.ai/etc.clientlibs/h2o/clientlibs/clientlib-site/resources/images/favicon.ico
12
+ pipeline_tag: text-generation
13
+ ---
14
+ # Model Card
15
+ The H2OVL-Mississippi-2B is a high-performing, general-purpose vision-language model developed by H2O.ai to handle a wide range of multimodal tasks. This model, with 2 billion parameters, excels in tasks such as image captioning, visual question answering (VQA), and document understanding, while maintaining efficiency for real-world applications.
16
+
17
+ The Mississippi-2B model builds on the strong foundations of our H2O-Danube language models, now extended to integrate vision and language tasks. It competes with larger models across various benchmarks, offering a versatile and scalable solution for document AI, OCR, and multimodal reasoning.
18
+
19
+
20
+ <div align="center">
21
+ <img src="./assets/Mississippi-2B_benchmarks.png" alt="Mississippi-2B Benchmarks" width="600"/>
22
+ </div>
23
+
24
+
25
+
26
+ ## Key Features:
27
+
28
+ - 2 Billion Parameters: Balance between performance and efficiency, making it suitable for document processing, OCR, VQA, and more.
29
+ - Optimized for Vision-Language Tasks: Achieves high performance across a wide range of applications, including document AI, OCR, and multimodal reasoning.
30
+ - Comprehensive Dataset: Trained on 17M image-text pairs, ensuring broad coverage and strong task generalization.
31
+
32
+ ## Usage
33
+
34
+ ### Install dependencies:
35
+ ```bash
36
+ pip install transformers torch torchvision einops timm peft sentencepiece
37
+ ```
38
+
39
+ If you have Ampere GPUs, install flash-attention to speed up inference:
40
+ ```bash
41
+ pip install flash_attn
42
+ ```
43
+
44
+ ### Sample demo:
45
+
46
+ ```python
47
+ import torch
48
+ from transformers import AutoModel, AutoTokenizer
49
+
50
+
51
+ # Set up the model and tokenizer
52
+ model_path = 'h2oai/h2o-mississippi-2b'
53
+ model = AutoModel.from_pretrained(
54
+ model_path,
55
+ torch_dtype=torch.bfloat16,
56
+ low_cpu_mem_usage=True,
57
+ trust_remote_code=True).eval().cuda()
58
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
59
+ generation_config = dict(max_new_tokens=1024, do_sample=True)
60
+
61
+
62
+ # pure-text conversation
63
+ question = 'Hello, who are you?'
64
+ response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
65
+ print(f'User: {question}\nAssistant: {response}')
66
+
67
+
68
+ # Example for single image
69
+ image_file = './examples/image1.jpg'
70
+ question = '<image>\nPlease describe the image in detail.'
71
+ response, history = model.chat(tokenizer, image_file, question, generation_config, history=None, return_history=True)
72
+ print(f'User: {question}\nAssistant: {response}')
73
+
74
+
75
+ # Example for multiple images - multiround conversation
76
+ image_files = ['./examples/image1.jpg', './examples/image2.jpg']
77
+ question = 'Image-1: <image>\nImage-2: <image>\nDescribe the Image-1 and Image-2 in detail.'
78
+ response, history = model.chat(tokenizer, image_files, question, generation_config, history=None, return_history=True)
79
+ print(f'User: {question}\nAssistant: {response}')
80
+
81
+ question = 'What are the similarities and differences between these two images?'
82
+ response, history = model.chat(tokenizer, image_files, question, generation_config=generation_config, history=history, return_history=True)
83
+ print(f'User: {question}\nAssistant: {response}')
84
+
85
+
86
+ ```
87
+
88
+ ## Acknowledgments
89
+
90
+ We would like to express our gratitude to the [InternVL team at OpenGVLab](https://github.com/OpenGVLab/InternVL) for their research and codebases, upon which we have built and expanded. We also acknowledge the work of the [LLaVA team](https://github.com/haotian-liu/LLaVA) and the [Monkey team](https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey) for their insights and techniques used in improving multimodal models.
91
+
92
+ ## Disclaimer
93
+
94
+ Please read this disclaimer carefully before using the large language model provided in this repository. Your use of the model signifies your agreement to the following terms and conditions.
95
+
96
+ - Biases and Offensiveness: The large language model is trained on a diverse range of internet text data, which may contain biased, racist, offensive, or otherwise inappropriate content. By using this model, you acknowledge and accept that the generated content may sometimes exhibit biases or produce content that is offensive or inappropriate. The developers of this repository do not endorse, support, or promote any such content or viewpoints.
97
+ - Limitations: The large language model is an AI-based tool and not a human. It may produce incorrect, nonsensical, or irrelevant responses. It is the user's responsibility to critically evaluate the generated content and use it at their discretion.
98
+ - Use at Your Own Risk: Users of this large language model must assume full responsibility for any consequences that may arise from their use of the tool. The developers and contributors of this repository shall not be held liable for any damages, losses, or harm resulting from the use or misuse of the provided model.
99
+ - Ethical Considerations: Users are encouraged to use the large language model responsibly and ethically. By using this model, you agree not to use it for purposes that promote hate speech, discrimination, harassment, or any form of illegal or harmful activities.
100
+ - Reporting Issues: If you encounter any biased, offensive, or otherwise inappropriate content generated by the large language model, please report it to the repository maintainers through the provided channels. Your feedback will help improve the model and mitigate potential issues.
101
+ - Changes to this Disclaimer: The developers of this repository reserve the right to modify or update this disclaimer at any time without prior notice. It is the user's responsibility to periodically review the disclaimer to stay informed about any changes.
102
+
103
+ By using the large language model provided in this repository, you agree to accept and comply with the terms and conditions outlined in this disclaimer. If you do not agree with any part of this disclaimer, you should refrain from using the model and any content generated by it.
added_tokens.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</box>": 32008,
3
+ "</img>": 32001,
4
+ "</quad>": 32004,
5
+ "</ref>": 32006,
6
+ "<IMG_CONTEXT>": 32002,
7
+ "<box>": 32007,
8
+ "<img>": 32000,
9
+ "<quad>": 32003,
10
+ "<ref>": 32005,
11
+ "<|end|>": 32009
12
+ }
assets/Mississippi-2B_benchmarks.png ADDED
config.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "H2OVLChatModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_h2ovl_chat.H2OVLChatConfig",
7
+ "AutoModel": "modeling_h2ovl_chat.H2OVLChatModel",
8
+ "AutoModelForCausalLM": "modeling_h2ovl_chat.H2OVLChatModel"
9
+ },
10
+
11
+ "downsample_ratio": 0.5,
12
+ "dynamic_image_size": true,
13
+ "force_image_size": 448,
14
+ "max_dynamic_patch": 6,
15
+ "min_dynamic_patch": 1,
16
+ "model_type": "h2ovl_chat",
17
+ "pad2square": false,
18
+ "ps_version": "v2",
19
+ "select_layer": -1,
20
+ "template": "h2ogpt2",
21
+ "torch_dtype": "bfloat16",
22
+ "use_backbone_lora": 0,
23
+ "use_llm_lora": 0,
24
+ "use_thumbnail": true,
25
+ "use_msac": true,
26
+
27
+ "llm_config": {
28
+ "_name_or_path": "h2oai/h2o-danube2-1.8b-chat",
29
+ "model_type": "mistral",
30
+ "architectures": [
31
+ "MistralForCausalLM"
32
+ ],
33
+ "attention_dropout": 0.0,
34
+ "torch_dtype": "bfloat16",
35
+ "use_bfloat16": true,
36
+ "hidden_size": 2560,
37
+ "intermediate_size": 6912,
38
+ "num_hidden_layers": 24,
39
+ "num_attention_heads": 32,
40
+ "num_key_value_heads": 8,
41
+ "rms_norm_eps": 1e-05,
42
+ "hidden_act": "silu",
43
+ "bos_token_id": 1,
44
+ "eos_token_id": 2,
45
+ "pad_token_id": 0,
46
+ "vocab_size": 32010,
47
+ "add_cross_attention": false,
48
+ "return_dict": true,
49
+ "output_attentions": false,
50
+ "output_hidden_states": false,
51
+ "output_scores": false,
52
+ "prefix": null,
53
+ "rope_theta": 10000,
54
+ "sep_token_id": null,
55
+ "sliding_window": null,
56
+ "tie_word_embeddings": false,
57
+ "tie_encoder_decoder": false,
58
+ "torchscript": false,
59
+ "use_cache": true,
60
+ "transformers_version": "4.44.0"
61
+ },
62
+
63
+ "vision_config": {
64
+ "architectures": [
65
+ "InternVisionModel"
66
+ ],
67
+ "hidden_size": 1024,
68
+ "image_size": 448,
69
+ "intermediate_size": 4096,
70
+ "model_type": "intern_vit_6b",
71
+ "norm_type": "layer_norm",
72
+ "num_attention_heads": 16,
73
+ "num_channels": 3,
74
+ "num_hidden_layers": 24,
75
+ "patch_size": 14,
76
+ "qk_normalization": false,
77
+ "qkv_bias": true,
78
+ "return_dict": true,
79
+ "torch_dtype": "bfloat16",
80
+ "use_bfloat16": true,
81
+ "use_flash_attn": true
82
+ }
83
+ }
84
+
configuration_h2ovl_chat.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import copy
8
+
9
+ from transformers import AutoConfig, MistralConfig
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+
13
+ from .configuration_intern_vit import InternVisionConfig
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
class H2OVLChatConfig(PretrainedConfig):
    """Composite configuration for the H2OVL chat model.

    Holds an ``InternVisionConfig`` (``vision_config``), a ``MistralConfig``
    (``llm_config``) and a set of top-level options that are stored on the
    instance unchanged.

    Args:
        vision_config (`dict`, *optional*):
            Keyword arguments forwarded to ``InternVisionConfig``. ``None`` is
            treated as an empty dict, i.e. all vision defaults.
        llm_config (`dict`, *optional*):
            Keyword arguments forwarded to ``MistralConfig``. ``None`` is
            treated as an empty dict. When the dict carries no
            ``'architectures'`` entry, ``'MistralForCausalLM'`` is assumed.
        use_backbone_lora (`int`, *optional*, defaults to 0): Stored as-is.
        use_llm_lora (`int`, *optional*, defaults to 0): Stored as-is.
        pad2square (`bool`, *optional*, defaults to `False`): Stored as-is.
        select_layer (`int`, *optional*, defaults to -1): Stored as-is; logged
            as ``vision_select_layer``.
        force_image_size (`int`, *optional*): Stored as-is.
        downsample_ratio (`float`, *optional*, defaults to 0.5): Stored as-is.
        template (`str`, *optional*): Stored as-is.
            NOTE(review): presumably the conversation template name; confirm
            against the model code.
        dynamic_image_size (`bool`, *optional*, defaults to `False`): Stored as-is.
        use_thumbnail (`bool`, *optional*, defaults to `False`): Stored as-is.
        ps_version (`str`, *optional*, defaults to `'v1'`): Pixel shuffle version.
        min_dynamic_patch (`int`, *optional*, defaults to 1): Stored as-is.
        max_dynamic_patch (`int`, *optional*, defaults to 6): Stored as-is.
        use_msac (`bool`, *optional*, defaults to `False`): Stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``llm_config['architectures'][0]`` names anything other
            than ``'MistralForCausalLM'``.
    """

    model_type = 'h2ovl_chat'
    is_composition = True

    def __init__(
            self,
            vision_config=None,
            llm_config=None,
            use_backbone_lora=0,
            use_llm_lora=0,
            pad2square=False,
            select_layer=-1,
            force_image_size=None,
            downsample_ratio=0.5,
            template=None,
            dynamic_image_size=False,
            use_thumbnail=False,
            ps_version='v1',
            min_dynamic_patch=1,
            max_dynamic_patch=6,
            use_msac=False,
            **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')

        if llm_config is None:
            llm_config = {}
            logger.info('llm_config is None. Initializing the MistralConfig with default values.')

        self.vision_config = InternVisionConfig(**vision_config)

        # An empty or missing 'architectures' entry falls back to the only
        # supported language model, so that a bare ``H2OVLChatConfig()`` can be
        # constructed instead of failing with a KeyError.
        architectures = llm_config.get('architectures') or ['MistralForCausalLM']
        llm_architecture = architectures[0]
        if llm_architecture != 'MistralForCausalLM':
            raise ValueError('Unsupported architecture: {}'.format(llm_architecture))
        self.llm_config = MistralConfig(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.use_msac = use_msac

        logger.info(f'vision_select_layer: {self.select_layer}')
        logger.info(f'ps_version: {self.ps_version}')
        logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
        logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested ``vision_config`` and ``llm_config`` objects are replaced by
        their own ``to_dict()`` output, and ``model_type`` is taken from the class.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        # Re-assert the top-level options explicitly so they are always present
        # in the serialized form.
        for key in (
                'use_backbone_lora',
                'use_llm_lora',
                'pad2square',
                'select_layer',
                'force_image_size',
                'downsample_ratio',
                'template',
                'dynamic_image_size',
                'use_thumbnail',
                'ps_version',
                'min_dynamic_patch',
                'max_dynamic_patch',
                'use_msac',
        ):
            output[key] = getattr(self, key)

        return output
configuration_intern_vit.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
class InternVisionConfig(PretrainedConfig):
    r"""
    Configuration for an [`InternVisionModel`] vision encoder.

    Every constructor argument is stored on the instance under the same name. The class inherits from
    [`PretrainedConfig`]; see its documentation for the options handled by the base class.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries and values in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        norm_type (`str`, *optional*, defaults to `"rms_norm"`):
            Name of the normalization layer type. Stored as given; the accepted values are defined by the
            model code, not by this class.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    model_type = 'intern_vit_6b'

    def __init__(
            self,
            num_channels=3,
            patch_size=14,
            image_size=224,
            qkv_bias=False,
            hidden_size=3200,
            num_attention_heads=25,
            intermediate_size=12800,
            qk_normalization=True,
            num_hidden_layers=48,
            use_flash_attn=True,
            hidden_act='gelu',
            norm_type='rms_norm',
            layer_norm_eps=1e-6,
            dropout=0.0,
            drop_path_rate=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=0.1,
            **kwargs,
    ):
        super().__init__(**kwargs)

        # Layer widths and regularisation.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate

        # Depth and attention layout.
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Initialisation and numerics.
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.norm_type = norm_type

        # Attention implementation switches.
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """Load the vision configuration from a checkpoint name or path.

        When the loaded dictionary contains a ``'vision_config'`` entry, that nested entry
        is used instead of the top-level dictionary. A warning is logged if the stored
        ``model_type`` differs from this class's ``model_type``.
        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # Composite checkpoints nest the vision settings one level down.
        if 'vision_config' in config_dict:
            config_dict = config_dict['vision_config']

        if 'model_type' in config_dict and hasattr(cls, 'model_type'):
            if config_dict['model_type'] != cls.model_type:
                logger.warning(
                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                    f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
                )

        return cls.from_dict(config_dict, **kwargs)
conversation.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    NO_COLON_SINGLE = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    Messages are stored as ``[role, message]`` pairs. A ``None`` (or empty)
    message marks a turn that is still to be generated: ``get_prompt`` then
    emits only the role marker for it.
    """

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of two roles
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is (role, message).
    # A fresh list per instance, so append_message works on a directly
    # constructed Conversation and instances never share history.
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: Union[str, None] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Union[List[int], None] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    # Open turn: leave the role marker for the model to complete.
                    ret += role + ':'
            return ret
        if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            elif msg is not None:
                ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        """Return a new Conversation whose message list is independent of this one."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return the template name, system message, roles, messages and offset as a dict."""
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }
128
+
129
+
130
+ # A global registry for all conversation templates
131
+ conv_templates: Dict[str, Conversation] = {}
132
+
133
+
134
def register_conv_template(template: Conversation, override: bool = False):
    """Store ``template`` in the global ``conv_templates`` registry under ``template.name``.

    Unless ``override`` is true, registering a name that is already present
    trips an assertion.
    """
    already_registered = template.name in conv_templates
    if not override:
        assert not already_registered, f'{template.name} has been registered.'

    conv_templates[template.name] = template
142
+
143
+
144
def get_conv_template(name: str) -> Conversation:
    """Look up ``name`` in ``conv_templates`` and return a copy of that template."""
    registered = conv_templates[name]
    return registered.copy()
147
+
148
+
149
+
150
# 'h2ogpt2' template: role markers are emitted without a colon, and every
# completed message is followed by the '<|end|>' separator.
register_conv_template(
    Conversation(
        name='h2ogpt2',
        roles=('<|prompt|>', '<|answer|>'),
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep='<|end|>',
        stop_token_ids=[2, 32009],
    )
)
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "repetition_penalty": 1.0,
4
+ "temperature": 1.0,
5
+ "max_length": 1024,
6
+ "eos_token_id": [
7
+ 2,
8
+ 32009
9
+ ],
10
+ "transformers_version": "4.44.0"
11
+ }
image_process.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ from torchvision.transforms.functional import InterpolationMode
5
+
6
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
7
+ IMAGENET_STD = (0.229, 0.224, 0.225)
8
+
9
+
10
def build_transform(input_size):
    """Build the torchvision pipeline applied to each image tile.

    The pipeline converts non-RGB images to RGB, resizes to
    ``(input_size, input_size)`` with bicubic interpolation, converts to a
    tensor and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
19
+
20
+
21
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid from ``target_ratios`` whose first/second ratio is nearest to ``aspect_ratio``.

    Each candidate is indexed as ``candidate[0]`` and ``candidate[1]``. On an
    exact tie with the current best, the later candidate wins only when the
    image area (``width * height``) exceeds half the pixel area that candidate
    would cover at ``image_size`` per tile. Returns ``(1, 1)`` when
    ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        if gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
35
+
36
+
37
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize ``image`` to a tile grid and cut it into ``image_size`` square tiles.

    Candidate grids are all ``(i, j)`` pairs with ``min_num <= i * j <= max_num``;
    the one chosen by ``find_closest_aspect_ratio`` fixes the resize target.
    Tiles are cropped row by row, left to right. When ``use_thumbnail`` is set
    and more than one tile was produced, the whole image resized to one tile
    is appended last.

    Returns:
        A tuple ``(tiles, grid)`` with the list of cropped images and the
        selected grid.
    """
    src_width, src_height = image.size
    aspect_ratio = src_width / src_height

    # Enumerate every grid whose tile count lies within [min_num, max_num].
    candidate_grids = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, src_width, src_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
74
+
75
+
76
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
    """Tile ``image`` with a grid chosen to differ from an earlier tiling pass.

    Candidate grids are all ``(i, j)`` pairs with ``min_num <= i * j <= max_num``.
    When ``prior_aspect_ratio`` is given, a candidate is kept only if
    ``prior_aspect_ratio[0] % i != 0`` and ``prior_aspect_ratio[1] % j != 0``.
    When it is ``None`` every candidate is eligible. Previously the eligible
    list stayed empty in that case, so the function always fell back to a
    single ``(1, 1)`` tile.

    If the filter removes every candidate, ``find_closest_aspect_ratio``
    returns its ``(1, 1)`` fallback and a single tile is produced.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image
            in the visible callers).
        min_num: Smallest allowed number of tiles in a grid.
        max_num: Largest allowed number of tiles in a grid.
        image_size: Edge length of each square tile.
        use_thumbnail: Append the whole image resized to one tile when more
            than one tile was produced.
        prior_aspect_ratio: Grid selected by a previous pass, or ``None``.

    Returns:
        The list of cropped tiles (plus the optional thumbnail as last item).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every grid whose tile count lies within [min_num, max_num].
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    if prior_aspect_ratio is None:
        # No earlier pass to avoid: consider every candidate grid.
        new_target_ratios = target_ratios
    else:
        new_target_ratios = [
            ratio for ratio in target_ratios
            if prior_aspect_ratio[0] % ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
        ]

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image, then crop tiles row by row, left to right
    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
121
+
122
def load_image1(image_file, input_size=448, min_num=1, max_num=12):
    """Open an image, tile it with ``dynamic_preprocess`` and stack the transformed tiles.

    The thumbnail option is always enabled for this pass.

    Returns:
        A tuple ``(pixel_values, target_aspect_ratio)``: the stacked tile
        tensor and the grid selected by ``dynamic_preprocess``.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess(
        rgb_image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = torch.stack([to_tensor(tile) for tile in tiles])
    return pixel_values, target_aspect_ratio
129
+
130
def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
    """Open an image, tile it with ``dynamic_preprocess2`` and stack the transformed tiles.

    ``target_aspect_ratio`` is forwarded as the ``prior_aspect_ratio`` of
    ``dynamic_preprocess2``; the thumbnail option is always enabled.

    Returns:
        The stacked tile tensor.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess2(
        rgb_image,
        image_size=input_size,
        use_thumbnail=True,
        min_num=min_num,
        max_num=max_num,
        prior_aspect_ratio=target_aspect_ratio,
    )
    return torch.stack([to_tensor(tile) for tile in tiles])
137
+
138
def load_single_image(file_name, max_num=6, msac=False):
    """Load one image as bfloat16 CUDA tiles, optionally combining two tiling passes.

    Without ``msac`` the result is the ``load_image1`` tiling. With ``msac`` a
    second pass (``load_image2`` with ``min_num=3`` and the first pass's grid
    as prior) is run, and the output is the concatenation of: the second pass
    without its last entry, the first pass without its last entry, and the
    last entry of the second pass.

    Returns:
        A tuple ``(pixel_values, num_patches_list)`` where ``num_patches_list``
        is a one-element list holding the number of rows in ``pixel_values``.
    """
    first_pass, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=max_num)
    pixel_values = first_pass.to(torch.bfloat16).cuda()

    if msac:
        second_pass = load_image2(
            file_name, min_num=3, max_num=max_num, target_aspect_ratio=target_aspect_ratio)
        second_pass = second_pass.to(torch.bfloat16).cuda()
        merged = torch.cat([second_pass[:-1], pixel_values[:-1], second_pass[-1:]], dim=0)
        pixel_values = merged.to(torch.bfloat16).cuda()

    # Number of patches in the final tensor (after MSAC when enabled).
    return pixel_values, [pixel_values.size(0)]
150
+
151
def load_multi_images(image_files, max_num=6):
    """Load several images with ``load_image1`` and concatenate their tiles.

    Each image's tiles are cast to bfloat16 and moved to CUDA before being
    concatenated along dimension 0.

    Returns:
        A tuple ``(pixel_values, num_patches_list)`` where ``num_patches_list``
        holds the tile count of each image in input order.
    """
    per_image_tiles = []
    for path in image_files:
        tiles, _ = load_image1(path, max_num=max_num)
        per_image_tiles.append(tiles.to(torch.bfloat16).cuda())

    num_patches_list = [tiles.size(0) for tiles in per_image_tiles]
    pixel_values = torch.cat(per_image_tiles, dim=0)

    return pixel_values, num_patches_list
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:682177bfbfb6516c8df626f36a3ceea423b15a290d636537b71641752823ccf8
3
+ size 4304703888
modeling_h2ovl_chat.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import warnings
7
+ from typing import Any, List, Optional, Tuple, Union
8
+
9
+ import torch.utils.checkpoint
10
+ import transformers
11
+ from peft import LoraConfig, get_peft_model
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig, MistralForCausalLM)
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import ModelOutput, logging
18
+
19
+ from .conversation import get_conv_template
20
+ from .configuration_h2ovl_chat import H2OVLChatConfig
21
+ from .modeling_intern_vit import InternVisionModel
22
+ from .image_process import load_single_image, load_multi_images
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a function from the ``operator`` module.

    Args:
        v1: Left-hand version string.
        v2: Right-hand version string.
        op: Name of the ``operator`` function to apply, e.g. ``'eq'``, ``'ge'``, ``'lt'``.

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator

    from packaging import version

    # Resolve the comparator first so an unknown name fails before any parsing.
    compare = getattr(operator, op)
    left = version.parse(v1)
    right = version.parse(v2)
    return compare(left, right)
33
+
34
+
35
class H2OVLChatModel(PreTrainedModel):
    """Vision-language chat model combining a vision encoder, an MLP projector and a causal LM.

    Image tiles are encoded by ``vision_model``, spatially down-sampled with
    ``pixel_shuffle``, projected to the language model's hidden size by
    ``mlp1`` and written into the token embeddings at every position whose id
    equals ``img_context_token_id``.
    """

    config_class = H2OVLChatConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['InternVisionModel', 'MistralDecoderLayer']
    _supports_flash_attn_2 = True

    def __init__(self, config: H2OVLChatConfig, vision_model=None, language_model=None):
        """Build the model from ``config``.

        Args:
            config: Model configuration (vision config, LLM config, template, LoRA ranks, ...).
            vision_model: Optional pre-built vision encoder; built from
                ``config.vision_config`` when omitted.
            language_model: Optional pre-built language model; built from
                ``config.llm_config`` when omitted.

        Raises:
            NotImplementedError: If no language model is supplied and the
                configured architecture is not ``MistralForCausalLM``.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Number of LLM tokens produced per image tile after pixel-shuffle down-sampling.
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        self.llm_arch_name = config.llm_config.architectures[0]
        self.use_msac = config.use_msac

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'MistralForCausalLM':
                self.language_model = MistralForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        # Projector input width matches the channel growth caused by pixel_shuffle.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Set by chat(); forward()/generate() need it to locate image placeholder tokens.
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        if hasattr(config, 'system_message'):
            self.system_message = config.system_message
        else:
            self.system_message = self.conv_template.system_message
        self.num_samples = 0

        # The config value doubles as the LoRA rank; alpha is fixed at twice the rank.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the vision encoder's attention and MLP projections with LoRA adapters."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the language model's attention and MLP projections with LoRA adapters.

        Raises:
            NotImplementedError: If the language model architecture has no
                known list of target modules.
        """
        # Determine the target modules based on the architecture of the language model
        if self.llm_arch_name in ['MistralForCausalLM']:
            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
        else:
            # `NotImplemented` is a comparison sentinel, not an exception; raising it
            # produced a confusing TypeError instead of the intended error.
            raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')
        lora_config = LoraConfig(
            r=r,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on text embeddings with image features spliced in.

        Args:
            pixel_values: Image tiles passed to ``extract_feature``.
            input_ids: Token ids; positions equal to ``img_context_token_id``
                receive image features.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded to the language model.
            image_flags: Optional per-tile flags; only tiles whose flag equals 1
                contribute features. When ``None`` every tile is used.
            labels: Optional targets for the shifted next-token cross-entropy loss.
            return_dict: Return a ``CausalLMOutputWithPast`` instead of a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a
            tuple starting with the loss (if computed) followed by the logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        # `image_flags` defaults to None; only filter when flags were actually supplied.
        if image_flags is not None:
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            # Multiplying by 0.0 (rather than overwriting) keeps the embedding
            # lookup in the autograd graph.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            # Deliberate best-effort: on a placeholder/feature count mismatch, use as
            # many features as fit and zero out this batch's loss below.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # With return_dict=False the language model returns a plain tuple, which has
        # no `.logits`; labels are not forwarded, so the logits are its first element.
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if ignore_flag:
                loss = loss * 0.0

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an ``(N, W, H, C)`` tensor.

        With ``scale_factor`` s the result has both spatial sides scaled by s
        and ``C / s**2`` channels. For ``ps_version == 'v1'`` the final
        permute is skipped, leaving the two spatial axes swapped.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode image tiles into language-model-sized embeddings.

        Takes the vision encoder's last hidden state (``select_layer == -1``)
        or the hidden state at ``select_layer``, drops the leading token,
        reshapes the rest into a square grid, applies ``pixel_shuffle`` and
        projects with ``mlp1``.
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        # Drop the first token of every sequence; only patch tokens are kept.
        vit_embeds = vit_embeds[:, 1:, :]

        # The remaining tokens are assumed to form a square grid.
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def chat(self, tokenizer, image_files, question, generation_config, max_tiles=6, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):
        """Answer ``question`` about zero, one or several images.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the reply.
            image_files: A single image, a list of images, or a falsy value for
                text-only chat. A list is loaded with ``load_multi_images``;
                anything else with ``load_single_image`` (MSAC per ``use_msac``).
            question: User message. On the first turn with images, a single
                ``'<image>\\n'`` is prepended if no ``'<image>'`` placeholder is
                present; each placeholder is filled for one image, in order.
            generation_config: Mapping of keyword arguments for ``generate``.
                It is copied, so the caller's mapping is left untouched.
            max_tiles: Maximum number of tiles per image.
            history: Optional list of ``(question, answer)`` pairs; the new
                turn is appended to this list in place.
            return_history: When true, return ``(response, history)``.
            num_patches_list: Accepted for compatibility; the value is always
                replaced by the counts computed while loading ``image_files``.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Special tokens
                used to expand each ``'<image>'`` placeholder.
            verbose: Print the tile batch size and, when history is not
                returned, the prompt and response.

        Returns:
            The response string, or ``(response, history)`` if ``return_history``.
        """
        if image_files:
            if isinstance(image_files, list):
                pixel_values, num_patches_list = load_multi_images(image_files, max_num=max_tiles)  # Load multiple images
            else:
                pixel_values, num_patches_list = load_single_image(image_files, max_num=max_tiles, msac=self.use_msac)  # Load single image
        else:
            pixel_values = None
            num_patches_list = []

        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Expand one '<image>' placeholder per image into its context tokens.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        # Work on a copy so the caller's generation settings are not mutated.
        generation_config = dict(generation_config) if generation_config is not None else {}
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)
            return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate tokens from a prompt whose image placeholders are filled with image features.

        Args:
            pixel_values: Image tiles, or ``None`` for text-only generation.
            input_ids: Prompt token ids.
            attention_mask: Attention mask forwarded to the language model.
            visual_features: Pre-computed image embeddings used instead of
                running ``extract_feature`` (only consulted when
                ``pixel_values`` is not ``None``).
            generation_config, output_hidden_states, return_dict,
            **generate_kwargs: Forwarded to ``language_model.generate``.

        Returns:
            Whatever ``language_model.generate`` returns.
        """
        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            # Overwrite every image placeholder embedding with the image features.
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
modeling_intern_vit.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
# flash-attn is optional. Any failure while importing it (package missing, broken
# binary, incompatible build) disables the flash path instead of aborting the import
# of this module; KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    try:  # v1
        from flash_attn.flash_attn_interface import \
            flash_attn_unpadded_qkvpacked_func
    except ImportError:  # v2 exposes the same function under a new name
        from flash_attn.flash_attn_interface import \
            flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func

    from flash_attn.bert_padding import pad_input, unpad_input

    has_flash_attn = True
except Exception:
    print('FlashAttention is not installed.')
    has_flash_attn = False
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
class FlashAttention(nn.Module):
    """Scaled dot-product softmax attention computed with the flash-attn kernel.

    Arguments
    ---------
        softmax_scale: Temperature for the softmax; ``None`` is passed through
            to the kernel unchanged.
        attention_dropout: Dropout probability applied while training
            (default: 0.0).
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        # `device` and `dtype` are accepted but not used by this module.
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Apply multi-head softmax attention to packed query/key/value.

        Arguments
        ---------
            qkv: Packed tensor, ``(B, S, 3, H, D)`` when padded, or
                ``(nnz, 3, h, d)`` when already unpadded (``cu_seqlens`` given).
            key_padding_mask: Optional bool tensor of shape ``(B, S)``.
            causal: Forwarded to the kernel.
            cu_seqlens: Cumulative sequence lengths for already-unpadded input.
            max_s: Maximum sequence length; required with ``cu_seqlens``.
            need_weights: Must be false; attention weights are never returned.

        Returns:
            ``(output, None)``.
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            # Caller already supplies unpadded input plus its sequence boundaries.
            assert max_s is not None
            attended = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            return attended, None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]

        if key_padding_mask is None:
            # No padding: every sequence has the same length, so the boundaries
            # are simply multiples of seqlen.
            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
            max_s = seqlen
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            attended = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            attended = rearrange(attended, '(b s) ... -> b s ...', b=batch_size)
            return attended, None

        # Padded input: strip padding, attend, then scatter back to (B, S, H, D).
        nheads = qkv.shape[-2]
        flat = rearrange(qkv, 'b s three h d -> b s (three h d)')
        flat_unpad, indices, cu_seqlens, max_s = unpad_input(flat, key_padding_mask)
        flat_unpad = rearrange(flat_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        attended_unpad = flash_attn_unpadded_qkvpacked_func(
            flat_unpad, cu_seqlens, max_s, dropout_p,
            softmax_scale=self.softmax_scale, causal=causal
        )
        repadded = pad_input(rearrange(attended_unpad, 'nnz h d -> nnz (h d)'),
                             indices, batch_size, seqlen)
        attended = rearrange(repadded, 'b s (h d) -> b s h d', h=nheads)
        return attended, None
102
+
103
+
104
class InternRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` by the RMS of its last dimension and apply ``weight``."""
        original_dtype = hidden_states.dtype
        # Statistics are computed in float32, then cast back before scaling.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
116
+
117
+
118
# When apex is importable, its fused RMSNorm replaces the pure-PyTorch
# InternRMSNorm under the same name, so code below picks it up transparently.
try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # apex is not installed: keep the InternRMSNorm defined in this module
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')


# Maps the config's norm type name to the normalisation layer class.
NORM2FN = dict(
    rms_norm=InternRMSNorm,
    layer_norm=nn.LayerNorm,
)
136
+
137
+
138
class InternVisionEmbeddings(nn.Module):
    """Turn images into a token sequence: a class token followed by patch tokens, plus position embeddings."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resample the patch position embeddings to an ``H x W`` grid."""
        target_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        # Interpolation runs in float32 on a (1, C, side, side) layout.
        grid = pos_embed.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
        # Back to (1, H * W, C) in the original dtype.
        return grid.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` and add position embeddings resized to the actual patch grid."""
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values)  # (batch, embed_dim, grid_h, grid_w)
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # The class-token position is kept as is; only patch positions are resampled.
        cls_position = self.position_embedding[:, :1, :]
        patch_positions = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        position_embedding = torch.cat([cls_position, patch_positions], dim=1)
        return embeddings + position_embedding.to(target_dtype)
180
+
181
+
182
class InternAttention(nn.Module):
    """Multi-head self-attention with optional query/key RMS normalisation and an optional flash-attn path."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # The flash path is used only when requested AND the package imported successfully.
        flash_requested = config.use_flash_attn
        self.use_flash_attn = flash_requested and has_flash_attn
        if flash_requested and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')

        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        # One projection produces query, key and value together.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization
        if self.qk_normalization:
            # Normalisation acts on the full embedding width (all heads flattened).
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Plain PyTorch attention over ``x`` of shape ``(B, N, C)``."""
        batch, tokens, channels = x.shape
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        # (3, B, heads, N, head_dim) -> three tensors of (B, heads, N, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            b_, h_, n_, d_ = q.shape
            # Merge heads, normalise, then split heads again.
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(b_, n_, h_, d_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(b_, n_, h_, d_).transpose(1, 2)

        scores = ((q * self.scale) @ k.transpose(-2, -1))
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """Attention via the ``FlashAttention`` wrapper module."""
        qkv = rearrange(self.qkv(x), 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        projected = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        return self.proj_drop(projected)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Dispatch to the flash or naive implementation according to ``use_flash_attn``."""
        if self.use_flash_attn:
            return self._flash_attn(hidden_states)
        return self._naive_attn(hidden_states)
254
+
255
+
256
class InternMLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, linear."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # Activation is looked up by name in the ACT2FN table.
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand to the intermediate size, apply the activation, project back."""
        expanded = self.fc1(hidden_states)
        activated = self.act(expanded)
        return self.fc2(activated)
269
+
270
+
271
class InternVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block with per-channel scaling and stochastic depth on both residual branches."""

    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        """Build the layer.

        Args:
            config: Vision configuration (hidden size, norm type, epsilon, ...).
            drop_path_rate: Stochastic-depth rate for this layer; ``DropPath``
                is only instantiated when the rate is positive.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Learned per-channel scales for the two branches, initialised to initializer_factor.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention branch, then the MLP branch, each as a residual update.

        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

        Returns:
            The updated hidden states as a single tensor (not a tuple).
        """
        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)

        return hidden_states
301
+
302
+
303
class InternVisionEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`InternVisionEncoderLayer`] blocks applied in sequence.

    Args:
        config (`InternVisionConfig`):
            The vision configuration used to build every layer.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule: rates grow linearly from 0 to config.drop_path_rate
        drop_path_rates = torch.linspace(0, config.drop_path_rate, config.num_hidden_layers).tolist()
        self.layers = nn.ModuleList(
            [InternVisionEncoderLayer(config, rate) for rate in drop_path_rates]
        )
        # Checkpointing is on by default; it only takes effect in training mode.
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether to also return the input of every layer plus the final output.
                Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
                Defaults to `config.use_return_dict`.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        collected = [] if output_hidden_states else None
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the state entering this layer.
                collected.append(hidden_states)
            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                hidden_states = encoder_layer(
                    hidden_states,
                )

        if output_hidden_states:
            # Finally record the output of the last layer.
            collected.append(hidden_states)

        encoder_states = tuple(collected) if collected is not None else None

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )
367
+
368
+
369
class InternVisionModel(PreTrainedModel):
    """Vision transformer: ``InternVisionEmbeddings`` followed by ``InternVisionEncoder``."""

    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    config_class = InternVisionConfig
    _no_split_modules = ['InternVisionEncoderLayer']

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Bicubically resample the stored position embeddings for a new image size.

        Args:
            old_size: Image side length the current embeddings were built for.
            new_size: Target image side length.
            patch_size: Patch side length used to derive the grid sizes.

        The class-token position is kept unchanged; the embedding parameter and
        ``embeddings.image_size`` are replaced in place.
        """
        pos_emb = self.embeddings.position_embedding
        _, _, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        # Interpolate in float32, then cast back to the stored dtype.
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images (or pre-computed embeddings) into hidden states.

        Args:
            pixel_values: 4-D image tensor; embedded with ``self.embeddings``.
            output_hidden_states: Whether to return all intermediate states;
                defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``BaseModelOutputWithPooling``;
                defaults to ``config.use_return_dict``.
            pixel_embeds: Already-embedded input; takes precedence over
                ``pixel_values`` when given.

        Returns:
            ``BaseModelOutputWithPooling`` or, when ``return_dict`` is false,
            ``(last_hidden_state, pooled_output, *extra_encoder_outputs)``.
            ``pooled_output`` is the first token of the last hidden state.

        Raises:
            ValueError: If neither input is given, or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index instead of attribute access: with return_dict=False the encoder returns
        # a plain tuple, and element 0 is the last hidden state in both return forms.
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
special_tokens_map.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<img>",
4
+ "</img>",
5
+ "<IMG_CONTEXT>",
6
+ "<quad>",
7
+ "</quad>",
8
+ "<ref>",
9
+ "</ref>",
10
+ "<box>",
11
+ "</box>",
12
+ "<|end|>"
13
+ ],
14
+ "bos_token": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "cls_token": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false
27
+ },
28
+ "eos_token": {
29
+ "content": "</s>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ },
35
+ "pad_token": {
36
+ "content": "<unk>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false
41
+ },
42
+ "sep_token": {
43
+ "content": "</s>",
44
+ "lstrip": false,
45
+ "normalized": false,
46
+ "rstrip": false,
47
+ "single_word": false
48
+ },
49
+ "unk_token": {
50
+ "content": "<unk>",
51
+ "lstrip": false,
52
+ "normalized": false,
53
+ "rstrip": false,
54
+ "single_word": false
55
+ }
56
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": true,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "32000": {
31
+ "content": "<img>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "32001": {
39
+ "content": "</img>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "32002": {
47
+ "content": "<IMG_CONTEXT>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "32003": {
55
+ "content": "<quad>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "32004": {
63
+ "content": "</quad>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "32005": {
71
+ "content": "<ref>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "32006": {
79
+ "content": "</ref>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "32007": {
87
+ "content": "<box>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "32008": {
95
+ "content": "</box>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "32009": {
103
+ "content": "<|end|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ }
110
+ },
111
+ "additional_special_tokens": [
112
+ "<img>",
113
+ "</img>",
114
+ "<IMG_CONTEXT>",
115
+ "<quad>",
116
+ "</quad>",
117
+ "<ref>",
118
+ "</ref>",
119
+ "<box>",
120
+ "</box>",
121
+ "<|end|>"
122
+ ],
123
+ "bos_token": "<s>",
124
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}",
125
+ "clean_up_tokenization_spaces": false,
126
+ "cls_token": "</s>",
127
+ "eos_token": "<|end|>",
128
+ "legacy": true,
129
+ "model_max_length": 8192,
130
+ "pad_token": "<unk>",
131
+ "sep_token": "</s>",
132
+ "sp_model_kwargs": {},
133
+ "spaces_between_special_tokens": false,
134
+ "tokenizer_class": "LlamaTokenizer",
135
+ "unk_token": "<unk>",
136
+ "use_default_system_prompt": false
137
+ }