import torch
from torch import nn
from transformers import Qwen3ForCausalLM
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_andesvl import AndesVLConfig
from .modeling_aimv2_navit_rope import Aimv2VisionModel
logger = logging.get_logger(__name__)
class AndesVLForConditionalGeneration(PreTrainedModel):
config_class = AndesVLConfig
main_input_name = 'pixel_values'
_supports_flash_attn_2 = True
    _no_split_modules = ['Aimv2VisionModel', 'Qwen3DecoderLayer']
def __init__(self, config: AndesVLConfig):
super().__init__(config)
self.config = config
self.vision_encoder = Aimv2VisionModel(config.vision_config)
self.language_model = Qwen3ForCausalLM(config.text_config)
vit_hidden_size = self.vision_encoder.config.hidden_size
llm_hidden_size = self.language_model.config.hidden_size
self.patch_size = self.vision_encoder.config.patch_size
        # The projector consumes 4 concatenated ViT patch embeddings at a time
        # (a 2x2 patch merge), hence the vit_hidden_size * 4 input width.
        self.mlp = nn.Sequential(
            nn.Linear(vit_hidden_size * 4, vit_hidden_size * 4),
            nn.GELU(),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
        )
def get_input_embeddings(self):
return self.language_model.model.embed_tokens
def set_input_embeddings(self, value):
self.language_model.model.embed_tokens = value
def get_output_embeddings(self):
return self.language_model.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.lm_head = new_embeddings
    def get_flated_pixel_values(self, pixel_values):
        """Flatten a list of (3, H, W) image tensors into one patch sequence.

        H and W must be multiples of 2 * patch_size; 2x2 neighboring patches are
        kept adjacent in the output so they can later be merged into one LLM token.
        """
        flated_pixel_values = []
        image_grid_hw = []
        for pv in pixel_values:
            c, h, w = pv.shape
            assert c == 3 and h % (2 * self.patch_size) == 0 and w % (2 * self.patch_size) == 0, f"{c}, {h}, {w}, {self.patch_size}"
            image_grid_hw.append((h // self.patch_size, w // self.patch_size))
            fpv = pv.reshape(c, h // (2 * self.patch_size), 2, self.patch_size, w // (2 * self.patch_size), 2, self.patch_size)
            flated_pixel_values.append(fpv.permute(1, 4, 2, 5, 0, 3, 6).reshape(-1, c * self.patch_size * self.patch_size))
        flated_pixel_values = torch.cat(flated_pixel_values, dim=0)  # (Len_img, C * patch_size * patch_size)
        image_grid_hw = torch.tensor(image_grid_hw, device=flated_pixel_values.device)  # (N_img, 2)
        return flated_pixel_values, image_grid_hw
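    # Shape sketch (illustrative, not from the original file; a patch_size of 14
    # is assumed) for a single 56x84 RGB image:
    #
    #     pv = torch.rand(3, 56, 84)
    #     fpv, grid = model.get_flated_pixel_values([pv])
    #     # grid.tolist() == [[4, 6]]        -> a 4x6 grid of 14x14 patches
    #     # fpv.shape == (24, 3 * 14 * 14)   -> one flattened patch per row, with
    #     #                                     2x2 neighbors stored consecutively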
def get_vit_embeds_and_merge(self, pixel_values, image_grid_hw, input_embeds, image_flags):
"""
Args:
pixel_values: (Len_img, H_vit0), 拉平后的初始patch特征,按照序列维度拼接在一起
image_grid_hw: (N_img, 2), 每个图片的宽高
input_embeds: (Bt, Lt, Ht), 每个token的embedding
image_flags: (Bt, Lt), 每个token是否是图片
"""
vit_embeds = self.vision_encoder(pixel_values, image_grid_hw) # (Len_img, H_vit)
vit_embeds = vit_embeds.view(-1, vit_embeds.shape[-1]*4) # (Len_img//4, H_vit*4)
vit_embeds = self.mlp(vit_embeds) # (Len_img//4, H_llm)
vit_embeds = vit_embeds[:image_flags.sum()]
Bt, Lt, Ht = input_embeds.shape
input_embeds = input_embeds.reshape(-1, Ht)
image_flags = image_flags.view(-1)
input_embeds[image_flags == 1] = vit_embeds
input_embeds = input_embeds.view(Bt, Lt, Ht)
return input_embeds
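    # Contract sketch (illustrative): after the 2x2 merge and MLP projection there
    # are Len_img // 4 vision embeddings, and image_flags must contain exactly that
    # many 1s; each flagged position is overwritten in sequence order:
    #
    #     image_flags = torch.tensor([[0, 1, 1, 0]])   # expects exactly 2 vision embeddings
    #     # input_embeds[0, 1] and input_embeds[0, 2] are replaced by vit_embeds[0:2]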
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def generate(
self,
pixel_values=None,
input_ids=None,
attention_mask=None,
image_flags=None, # (Bt, Lt)
generation_config=None,
**generate_kwargs,
) -> torch.LongTensor:
input_embeds = self.language_model.get_input_embeddings()(input_ids) # (Bt, Lt, Ht)
        if image_flags is not None and (image_flags == 1).sum() > 0:
flated_pixel_values, image_grid_hw = self.get_flated_pixel_values(pixel_values)
input_embeds = self.get_vit_embeds_and_merge(flated_pixel_values, image_grid_hw, input_embeds, image_flags)
outputs = self.language_model.generate(
input_ids=input_ids,
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
use_cache=True,
**generate_kwargs,
)
return outputs
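    # Low-level usage sketch (hypothetical; completion() builds these inputs
    # automatically, and vision_pad_id is assumed to be tokenizer.vocab['<|vision_pad|>']):
    #
    #     out = model.generate(
    #         pixel_values=[pv],                             # list of (3, H, W) tensors
    #         input_ids=input_ids,                           # (1, Lt), containing vision_pad_id runs
    #         image_flags=(input_ids == vision_pad_id).int(),
    #         max_new_tokens=64,
    #     )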
    # NOTE: The completion and chat interfaces do not support batched inference yet;
    # to batch, build the inputs of self.generate manually.
    def completion(self, prompt, images, tokenizer, image_processor, **kwargs):
        """Complete a text prompt that references a list of images; each image's
        position in the prompt is marked with the <image> placeholder."""
        assert prompt.count("<image>") == len(images), "number of images does not match the number of <image> placeholders"
        def replacement(m):
            token_count = image_tokens.pop(0)
            return f"<img>{'<|vision_pad|>' * token_count}</img>"
        # First preprocess every image to obtain its size.
        max_size = kwargs.pop("max_size", 733)  # max_size**2 is the maximum supported image area
        base = self.patch_size * 2
        image_token_id = tokenizer.vocab['<|vision_pad|>']  # placeholder id for image tokens
        background_color = tuple(int(x * 255) for x in image_processor.image_mean)
        transform = T.Compose([T.ToTensor(), T.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)])
pixel_values = []
image_tokens = []
for image in images:
if isinstance(image, (tuple, list)):
image, detail = image
else:
detail = "low"
image = load_image(image)
if detail=="low":
image = native_preprocess(image, max_size, base, background_color, min_tokens=4)
pixel_values.append(transform(image))
image_tokens.append(image.size[0]*image.size[1]//(base*base))
else:
raise NotImplementedError("暂未实现")
new_prompt = re.sub(r"<image>", replacement, prompt)
        input_ids = tokenizer(new_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
        image_flags = (input_ids == image_token_id).int()
        pixel_values = [pv.to(self.vision_encoder.device) for pv in pixel_values]
        output_ids = self.generate(pixel_values=pixel_values, input_ids=input_ids, image_flags=image_flags, **kwargs)[0][input_ids.shape[1]:]
return tokenizer.decode(output_ids, skip_special_tokens=True)
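    # Usage sketch (illustrative; loading the model, tokenizer, and image processor
    # via the standard transformers Auto* classes is an assumption):
    #
    #     text = model.completion(
    #         "Describe <image> briefly.",
    #         ["cat.jpg"],                       # paths, URLs, data URIs, or PIL images
    #         tokenizer,
    #         image_processor,
    #         max_new_tokens=128,
    #     )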
    def chat(self, messages, tokenizer, image_processor, **kwargs):
        """Generate a reply given a conversation in OpenAI chat format."""
        prompt = ""
        images = []
        for message in messages:
            role = message["role"]
            assert role in ["user", "assistant", "system"], f"invalid role: {role}"
            content = message['content']
            if isinstance(content, str):
                prompt += f"<|im_start|>{role}\n{content}{tokenizer.eos_token}\n"
            elif isinstance(content, list):
                temp = ""
                for sub_content in content:
                    if sub_content['type'] == 'text':
                        temp += f"{sub_content['text']}"
                    elif sub_content['type'] == 'image_url':
                        temp += "<image>"
                        images.append([load_image(sub_content['image_url']['url']), sub_content['image_url'].get("detail", 'low')])
                prompt += f"<|im_start|>{role}\n{temp}{tokenizer.eos_token}\n"
            else:
                raise ValueError(f"invalid content: {content}")
        thinking = kwargs.pop('thinking', False)
        prompt += "<|im_start|>assistant\n" + ('<think>' if thinking else '')
        return ('<think>' if thinking else '') + self.completion(prompt, images, tokenizer, image_processor, **kwargs)
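    # Usage sketch (illustrative) with an OpenAI-format conversation:
    #
    #     messages = [
    #         {"role": "user", "content": [
    #             {"type": "text", "text": "What is in this picture?"},
    #             {"type": "image_url", "image_url": {"url": "cat.jpg", "detail": "low"}},
    #         ]},
    #     ]
    #     reply = model.chat(messages, tokenizer, image_processor, max_new_tokens=64)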
###################################
### Image preprocessing helpers ###
###################################
import os
import math
import re
from typing import Union
import requests
import base64
from io import BytesIO
from PIL import Image
import torchvision.transforms as T
def load_image(source: Union[str, Image.Image]) -> Image.Image:
    """Load an image from a PIL Image, URL, data URI, or local file path."""
    if isinstance(source, Image.Image):
        img = source
    elif isinstance(source, str):
        if source.startswith('http'):
            response = requests.get(source)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        elif source.startswith('data:image'):
            img = Image.open(BytesIO(base64.b64decode(source.split(',')[1])))
        elif os.path.exists(source):
            img = Image.open(source)
        else:
            raise ValueError("Unsupported image source")
    else:
        raise ValueError("Unsupported image source")
    return img.convert('RGB')
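# Accepted sources (illustrative values):
#
#     load_image("https://example.com/cat.png")        # URL
#     load_image("data:image/png;base64,iVBOR...")     # data URI (truncated here)
#     load_image("/tmp/cat.png")                       # local file path
#     load_image(Image.new("RGB", (32, 32)))           # PIL image, passed through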
def get_scaled_img_size(image_size, max_area, base, max_resolution=4172, upper=True):
    """Compute the scaled image size and the size of its bounding box."""
    # Aspect ratio of the original image.
    aspect_ratio = image_size[0] / image_size[1]
    # Maximum possible width and height of the bounding box.
    max_width = math.floor(math.sqrt(max_area * aspect_ratio))
    max_height = math.floor(math.sqrt(max_area / aspect_ratio))
    max_width, max_height = min(max_width, max_resolution), min(
        max_height, max_resolution
    )
    max_width, max_height = max(max_width, base), max(max_height, base)
    # Make the bounding box width and height exact multiples of base.
    if not upper:
        # Round down, so the area never exceeds max_area.
        max_width = max_width - max_width % base
        max_height = max_height - max_height % base
    else:
        # Round up, without exceeding max_resolution (the maximum side length).
        max_width = min(math.ceil(max_width / base) * base, max_resolution)
        max_height = min(math.ceil(max_height / base) * base, max_resolution)
    # Scale factor that fits the image inside the bounding box.
    scale_factor = min(max_width / image_size[0], max_height / image_size[1])
    # Scaled image size.
    new_image_size = (
        round(image_size[0] * scale_factor),
        round(image_size[1] * scale_factor),
    )
    # Bounding box size.
    bounding_box_size = (max_width, max_height)
    return new_image_size, bounding_box_size
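# Worked example (illustrative), with max_area=733**2 and base=28:
# a 1600x900 image has aspect_ratio 16/9, so
#     max_width  = floor(sqrt(733**2 * 16/9)) = 977 -> rounded up to 980 (35 * 28)
#     max_height = floor(sqrt(733**2 * 9/16)) = 549 -> rounded up to 560 (20 * 28)
#     scale_factor = min(980/1600, 560/900) = 0.6125
#     new_image_size == (980, 551), bounding_box_size == (980, 560)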
def max_preprocess(
    img, max_size, base, background_color, max_resolution=4172, upper=True, force_resize=False
):
    """Preprocess the image so that its area is close to max_size**2."""
    # Cap the working size so that neither side exceeds max_resolution.
    w, h = img.size
    if max(w, h) > max_resolution:
        scale = max_resolution / max(w, h)
        w, h = int(w * scale), int(h * scale)
    # Get the scaled image size and the bounding box size.
    new_image_size, bounding_box_size = get_scaled_img_size(
        (w, h), max_size**2, base, max_resolution, upper
    )
    if force_resize:
        return img.resize(bounding_box_size)
    # Create a new canvas filled with the background color.
    canvas = Image.new("RGB", bounding_box_size, background_color)
    # Offsets that center the resized image on the canvas.
    paste_width = (bounding_box_size[0] - new_image_size[0]) // 2
    paste_height = (bounding_box_size[1] - new_image_size[1]) // 2
    # Paste the resized image onto the canvas.
    canvas.paste(img.resize(new_image_size), (paste_width, paste_height))
    return canvas
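# Illustrative effect with max_size=733 and base=28: a 1600x900 input is resized
# to (980, 551) and centered on a (980, 560) canvas filled with background_color;
# 980 * 560 // 28**2 = 700 is then the number of <|vision_pad|> tokens it occupies.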
def native_preprocess(
    img, max_size, base, background_color, max_resolution=4172, min_tokens=64
):
    """Preprocess the image so that its width and height are multiples of base,
    resizing it down first if its longest side exceeds max_resolution."""
    w, h = img.size
    # First ensure the longest side does not exceed max_resolution (the ViT's limit).
    if max(w, h) > max_resolution:
        scale = max_resolution / max(w, h)
        w, h = int(w * scale), int(h * scale)
        img = img.resize((w, h))
    if w * h > max_size**2:
        return max_preprocess(img, max_size, base, background_color, max_resolution)
    if w * h < (base * base * min_tokens):
        return max_preprocess(
            img,
            int(base * (min_tokens**0.5)),
            base,
            background_color,
            max_resolution,
        )
    # Round both sides up to the next multiple of base.
    w1, h1 = math.ceil(w / base) * base, math.ceil(h / base) * base
    if w1 == w and h1 == h:
        return img
    else:
        # Create a new (w1, h1) canvas and resize the image so that padding
        # appears along only one side.
        scale = min(w1 / w, h1 / h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h))
        canvas = Image.new("RGB", (w1, h1), background_color)
        canvas.paste(img, ((w1 - new_w) // 2, (h1 - new_h) // 2))
        return canvas
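# Illustrative behavior with base=28, max_size=733, min_tokens=64:
#   - a 640x480 image (within both area bounds) is padded up to 644x504, both
#     multiples of 28, occupying 644 * 504 // 28**2 = 414 vision tokens;
#   - a 100x100 image (area below 28 * 28 * 64) is routed through max_preprocess
#     with target side int(28 * 64**0.5) = 224, i.e. roughly 64 tokens.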