import base64
import math
import os
import re
from io import BytesIO
from typing import Union

import requests
import torch
import torchvision.transforms as T
from PIL import Image
from torch import nn

from transformers import Qwen3ForCausalLM
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_andesvl import AndesVLConfig
from .modeling_aimv2_navit_rope import Aimv2VisionModel

logger = logging.get_logger(__name__)

class AndesVLForConditionalGeneration(PreTrainedModel):
    config_class = AndesVLConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['Aimv2VisionModel', 'Qwen3DecoderLayer']

    def __init__(self, config: AndesVLConfig):
        super().__init__(config)

        self.config = config
        self.vision_encoder = Aimv2VisionModel(config.vision_config)
        self.language_model = Qwen3ForCausalLM(config.text_config)

        vit_hidden_size = self.vision_encoder.config.hidden_size
        llm_hidden_size = self.language_model.config.hidden_size
        self.patch_size = self.vision_encoder.config.patch_size
        # Projector from vision to language space. The embeddings of each
        # 2x2 patch window are concatenated (hence the factor of 4) before
        # being mapped to the LLM hidden size.
        self.mlp = nn.Sequential(
            nn.Linear(vit_hidden_size * 4, vit_hidden_size * 4),
            nn.GELU(),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
        )
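
    # Example token budget (illustrative numbers, assuming patch_size = 14):
    # a 448x448 image gives a 32x32 grid of patches (1024 in total), which
    # the 2x2 merge reduces to 256 tokens handed to the language model.
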
    def get_input_embeddings(self):
        return self.language_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.language_model.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.language_model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.language_model.lm_head = new_embeddings

    def get_flated_pixel_values(self, pixel_values):
        """Flatten a list of (3, H, W) images into one (N_patches, 3*p*p)
        tensor, ordered so each 2x2 window of patches is contiguous."""
        flated_pixel_values = []
        image_grid_hw = []
        for pv in pixel_values:
            c, h, w = pv.shape
            assert c == 3 and h % self.patch_size == 0 and w % self.patch_size == 0, \
                f"{c}, {h}, {w}, {self.patch_size}"
            image_grid_hw.append((h // self.patch_size, w // self.patch_size))
            # Split each axis into (windows, 2, patch_size) blocks, then order
            # the axes so the four patches of a 2x2 window are adjacent rows.
            fpv = pv.reshape(c, h // (2 * self.patch_size), 2, self.patch_size,
                             w // (2 * self.patch_size), 2, self.patch_size)
            flated_pixel_values.append(
                fpv.permute(1, 4, 2, 5, 0, 3, 6).reshape(-1, c * self.patch_size * self.patch_size)
            )
        flated_pixel_values = torch.cat(flated_pixel_values, dim=0)
        image_grid_hw = torch.tensor(image_grid_hw, device=flated_pixel_values.device)
        return flated_pixel_values, image_grid_hw
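
    # A quick shape check for the flattening above (illustrative only; the
    # patch size of 14 is an assumption, not read from any config):
    #
    #   pv = torch.randn(3, 56, 84)      # one 56x84 image, p = 14
    #   flat, grid = model.get_flated_pixel_values([pv])
    #   flat.shape   -> (24, 588)        # (56/14)*(84/14) patches, 3*14*14 values
    #   grid         -> tensor([[4, 6]])
    #
    # Consecutive groups of 4 rows form one 2x2 patch window, which is what
    # lets get_vit_embeds_and_merge fold them with view(-1, hidden_size * 4).
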
    def get_vit_embeds_and_merge(self, pixel_values, image_grid_hw, input_embeds, image_flags):
        """
        Args:
            pixel_values: (Len_img, H_vit0), flattened initial patch features,
                concatenated along the sequence dimension
            image_grid_hw: (N_img, 2), the patch-grid height and width of each image
            input_embeds: (Bt, Lt, Ht), the embedding of each token
            image_flags: (Bt, Lt), whether each token is an image token
        """
        vit_embeds = self.vision_encoder(pixel_values, image_grid_hw)
        # Fold each 2x2 window of patch embeddings (4 consecutive rows) into
        # one token, then project it into the LLM embedding space.
        vit_embeds = vit_embeds.view(-1, vit_embeds.shape[-1] * 4)
        vit_embeds = self.mlp(vit_embeds)
        vit_embeds = vit_embeds[:image_flags.sum()]
        # Scatter the visual tokens into the text embedding sequence at the
        # positions flagged as image tokens.
        Bt, Lt, Ht = input_embeds.shape
        input_embeds = input_embeds.reshape(-1, Ht)
        image_flags = image_flags.view(-1)
        input_embeds[image_flags == 1] = vit_embeds
        input_embeds = input_embeds.view(Bt, Lt, Ht)
        return input_embeds
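
    # Illustrative numbers for the merge above: the 24 flattened patches from
    # the 56x84 example (with an assumed ViT hidden size of 1024) become
    # view(-1, 4096) -> 6 merged tokens after the MLP, so the prompt must
    # contain exactly 6 <|vision_pad|> tokens for that image, i.e.
    # (image_flags == 1).sum() == 6.
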
    @torch.inference_mode()
    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def generate(
        self,
        pixel_values=None,
        input_ids=None,
        attention_mask=None,
        image_flags=None,
        generation_config=None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        if image_flags is not None and (image_flags == 1).sum() > 0:
            flated_pixel_values, image_grid_hw = self.get_flated_pixel_values(pixel_values)
            input_embeds = self.get_vit_embeds_and_merge(flated_pixel_values, image_grid_hw, input_embeds, image_flags)
        # input_ids is passed alongside inputs_embeds so the returned sequence
        # includes the prompt tokens, which completion() slices off.
        outputs = self.language_model.generate(
            input_ids=input_ids,
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            use_cache=True,
            **generate_kwargs,
        )
        return outputs

    def completion(self, prompt, images, tokenizer, image_processor, **kwargs):
        """Complete a text prompt that references a list of images, where each
        image is marked in the text by the placeholder <image>."""
        assert prompt.count("<image>") == len(images), \
            "number of images does not match the number of <image> placeholders"

        def replacement(m):
            # Replace each <image> placeholder with <img>...</img> wrapping the
            # right number of vision-pad tokens (image_tokens is filled below).
            token_count = image_tokens.pop(0)
            return f"<img>{'<|vision_pad|>' * token_count}</img>"

        # Pop max_size so it is not forwarded to generate() as an unknown kwarg.
        max_size = kwargs.pop("max_size", 733)
        base = self.patch_size * 2
        image_token_id = tokenizer.vocab['<|vision_pad|>']
        background_color = tuple(int(x * 255) for x in image_processor.image_mean)
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        ])
        pixel_values = []
        image_tokens = []
        for image in images:
            if isinstance(image, (tuple, list)):
                image, detail = image
            else:
                detail = "low"
            image = load_image(image)
            if detail == "low":
                image = native_preprocess(image, max_size, base, background_color, min_tokens=4)
                pixel_values.append(transform(image))
                # One merged token per (2*patch_size)^2 pixel block.
                image_tokens.append(image.size[0] * image.size[1] // (base * base))
            else:
                raise NotImplementedError("high-detail mode is not implemented yet")
        new_prompt = re.sub(r"<image>", replacement, prompt)
        input_ids = tokenizer(new_prompt, return_tensors="pt", add_special_tokens=False).input_ids
        image_flags = (input_ids == image_token_id).int()
        input_ids = input_ids.to(self.vision_encoder.device)
        pixel_values = [pv.to(self.vision_encoder.device) for pv in pixel_values]
        image_flags = image_flags.to(self.vision_encoder.device)
        output_ids = self.generate(pixel_values=pixel_values, input_ids=input_ids,
                                   image_flags=image_flags, **kwargs)[0][input_ids.shape[1]:]
        return tokenizer.decode(output_ids, skip_special_tokens=True)
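
    # Minimal usage sketch (the checkpoint path and image URL are hypothetical,
    # and the processor classes are assumptions, not pinned by this file):
    #
    #   from transformers import AutoTokenizer, AutoImageProcessor
    #   model = AndesVLForConditionalGeneration.from_pretrained("path/to/andesvl").cuda()
    #   tokenizer = AutoTokenizer.from_pretrained("path/to/andesvl")
    #   image_processor = AutoImageProcessor.from_pretrained("path/to/andesvl")
    #   print(model.completion("Describe <image> briefly.",
    #                          ["https://example.com/cat.jpg"],
    #                          tokenizer, image_processor, max_new_tokens=64))
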
    def chat(self, messages, tokenizer, image_processor, **kwargs):
        """Take a list of chat messages (OpenAI format) and return the reply."""
        prompt = ""
        images = []
        for message in messages:
            role = message["role"]
            assert role in ["user", "assistant", "system"], f"invalid role {role}"
            content = message['content']
            if isinstance(content, str):
                prompt += f"<|im_start|>{role}\n{content}{tokenizer.eos_token}\n"
            elif isinstance(content, list):
                temp = ""
                for sub_content in content:
                    if sub_content['type'] == 'text':
                        temp += f"{sub_content['text']}"
                    elif sub_content['type'] == 'image_url':
                        temp += "<image>"
                        images.append([load_image(sub_content['image_url']['url']),
                                       sub_content['image_url'].get("detail", 'low')])
                prompt += f"<|im_start|>{role}\n{temp}{tokenizer.eos_token}\n"
            else:
                raise ValueError(f"invalid content {content}")
        thinking = kwargs.pop('thinking', False)
        # Open the assistant turn once; optionally pre-fill a <think> tag.
        prompt += "<|im_start|>assistant\n" + ('<think>' if thinking else '')
        return ('<think>' if thinking else '') + self.completion(prompt, images, tokenizer, image_processor, **kwargs)

def load_image(source: Union[str, Image.Image]) -> Image.Image:
    """Load an image from a PIL image, URL, local path, or data URL."""
    if isinstance(source, Image.Image):
        img = source
    elif isinstance(source, str):
        if source.startswith('http'):
            response = requests.get(source)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        elif os.path.exists(source):
            img = Image.open(source)
        elif source.startswith('data:image'):
            img = Image.open(BytesIO(base64.b64decode(source.split(',')[1])))
        else:
            raise ValueError("Unsupported image source")
    else:
        raise ValueError("Unsupported image source")
    return img.convert('RGB')

def get_scaled_img_size(image_size, max_area, base, max_resolution=4172, upper=True):
    """Compute the scaled image size and the bounding-box size that wraps it."""
    aspect_ratio = image_size[0] / image_size[1]

    # Largest (width, height) with this aspect ratio whose area is <= max_area,
    # clamped to [base, max_resolution] on each side.
    max_width = math.floor(math.sqrt(max_area * aspect_ratio))
    max_height = math.floor(math.sqrt(max_area / aspect_ratio))
    max_width, max_height = min(max_width, max_resolution), min(max_height, max_resolution)
    max_width, max_height = max(max_width, base), max(max_height, base)

    if not upper:
        # Round down to a multiple of base.
        max_width = max_width - max_width % base
        max_height = max_height - max_height % base
    else:
        # Round up to a multiple of base (no change if already aligned).
        max_width = min(math.ceil(max_width / base) * base, max_resolution)
        max_height = min(math.ceil(max_height / base) * base, max_resolution)

    # Scale the image to fit inside the bounding box, keeping its aspect ratio.
    scale_factor = min(max_width / image_size[0], max_height / image_size[1])
    new_image_size = (
        round(image_size[0] * scale_factor),
        round(image_size[1] * scale_factor),
    )
    bounding_box_size = (max_width, max_height)
    return new_image_size, bounding_box_size
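
# Worked example (illustrative numbers): for a 1000x800 image with
# max_area = 733**2 and base = 28, the unconstrained box is (819, 655),
# which rounds up to (840, 672); scale_factor = 0.84, so the function
# returns new_image_size = (840, 672) inside a (840, 672) bounding box:
#
#   get_scaled_img_size((1000, 800), 733**2, 28)   # ((840, 672), (840, 672))
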
def max_preprocess(
    img, max_size, base, background_color, max_resolution=4172, upper=True, force_resize=False
):
    """Preprocess an image so that its area is close to max_size**2."""
    w, h = img.size
    if max(w, h) > max_resolution:
        scale = max_resolution / max(w, h)
        w, h = int(w * scale), int(h * scale)

    new_image_size, bounding_box_size = get_scaled_img_size(
        (w, h), max_size**2, base, max_resolution, upper
    )
    if force_resize:
        # Stretch directly to the bounding box, ignoring the aspect ratio.
        return img.resize(bounding_box_size)

    # Paste the aspect-preserving resize onto a padded canvas, centered.
    canvas = Image.new("RGB", bounding_box_size, background_color)
    paste_width = (bounding_box_size[0] - new_image_size[0]) // 2
    paste_height = (bounding_box_size[1] - new_image_size[1]) // 2
    canvas.paste(img.resize(new_image_size), (paste_width, paste_height))
    return canvas

def native_preprocess(
    img, max_size, base, background_color, max_resolution=4172, min_tokens=64
):
    """Keep the image close to its native resolution, padding each side up to
    a multiple of base; fall back to max_preprocess when the image is too
    large or would yield fewer than min_tokens tokens."""
    w, h = img.size

    if max(w, h) > max_resolution:
        scale = max_resolution / max(w, h)
        w, h = int(w * scale), int(h * scale)
        img = img.resize((w, h))
    if w * h > max_size**2:
        # Too large: shrink to roughly max_size**2 pixels.
        return max_preprocess(img, max_size, base, background_color, max_resolution)
    if w * h < (base * base * min_tokens):
        # Too small: grow so the image yields at least min_tokens merged tokens.
        return max_preprocess(
            img,
            int(base * (min_tokens**0.5)),
            base,
            background_color,
            max_resolution,
        )
    # Round each side up to a multiple of base (no change if already aligned,
    # so the early return below can actually fire).
    w1, h1 = math.ceil(w / base) * base, math.ceil(h / base) * base
    if w1 == w and h1 == h:
        return img
    else:
        scale = min(w1 / w, h1 / h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h))
        canvas = Image.new("RGB", (w1, h1), background_color)
        canvas.paste(img, ((w1 - new_w) // 2, (h1 - new_h) // 2))
        return canvas
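

# A minimal, self-contained smoke test of the preprocessing helpers (the
# sizes and gray background below are arbitrary choices for illustration):
if __name__ == "__main__":
    demo = Image.new("RGB", (1000, 800), (200, 30, 30))
    out = native_preprocess(demo, max_size=733, base=28, background_color=(127, 127, 127))
    # 1000x800 exceeds 733**2 pixels, so max_preprocess shrinks it; both sides
    # come back as multiples of 28 (840x672 for these numbers).
    print(out.size, out.size[0] % 28, out.size[1] % 28)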