from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM from .configuration_deepseek_v2 import DeepseekV2Config from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from typing import List, Optional, Tuple, Union from transformers.cache_utils import Cache import requests from PIL import Image, ImageOps, ImageDraw, ImageFont from io import BytesIO import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import os from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector from addict import Dict from transformers import TextStreamer from .conversation import get_conv_template from abc import ABC import math import re from tqdm import tqdm import numpy as np import time def load_image(image_path): try: image = Image.open(image_path) corrected_image = ImageOps.exif_transpose(image) return corrected_image except Exception as e: print(f"error: {e}") try: return Image.open(image_path) except: return None def re_match(text): pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' matches = re.findall(pattern, text, re.DOTALL) # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n' # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL) mathes_image = [] mathes_other = [] for a_match in matches: if '<|ref|>image<|/ref|>' in a_match[0]: mathes_image.append(a_match[0]) else: mathes_other.append(a_match[0]) return matches, mathes_image, mathes_other def extract_coordinates_and_label(ref_text, image_width, image_height): try: label_type = ref_text[1] cor_list = eval(ref_text[2]) except Exception as e: print(e) return None return (label_type, cor_list) def draw_bounding_boxes(image, refs, ouput_path): image_width, image_height = image.size img_draw = image.copy() draw = ImageDraw.Draw(img_draw) overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) draw2 = ImageDraw.Draw(overlay) # try: # except IOError: # try: # font = ImageFont.truetype("DejaVuSans.ttf", 20) # except IOError: font = ImageFont.load_default() img_idx = 0 for i, ref in enumerate(refs): try: result = extract_coordinates_and_label(ref, image_width, image_height) if result: label_type, points_list = result color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) color_a = color + (20, ) for points in points_list: x1, y1, x2, y2 = points x1 = int(x1 / 999 * image_width) y1 = int(y1 / 999 * image_height) x2 = int(x2 / 999 * image_width) y2 = int(y2 / 999 * image_height) if label_type == 'image': try: cropped = image.crop((x1, y1, x2, y2)) cropped.save(f"{ouput_path}/images/{img_idx}.jpg") except Exception as e: print(e) pass img_idx += 1 try: if label_type == 'title': draw.rectangle([x1, y1, x2, y2], outline=color, width=4) draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) else: draw.rectangle([x1, y1, x2, y2], outline=color, width=2) draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) text_x = x1 text_y = max(0, y1 - 15) text_bbox = draw.textbbox((0, 0), label_type, font=font) text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - text_bbox[1] draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], fill=(255, 255, 255, 30)) draw.text((text_x, text_y), label_type, font=font, fill=color) except: pass except: continue img_draw.paste(overlay, (0, 0), overlay) return img_draw def process_image_with_refs(image, ref_texts, output_path): result_image = draw_bounding_boxes(image, ref_texts, output_path) return result_image def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) # print(target_ratios) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # print(target_aspect_ratio) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images, target_aspect_ratio def normalize_transform(mean, std): if mean is None and std is None: transform = None elif mean is None and std is not None: mean = [0.] * len(std) transform = transforms.Normalize(mean=mean, std=std) elif mean is not None and std is None: std = [1.] * len(mean) transform = transforms.Normalize(mean=mean, std=std) else: transform = transforms.Normalize(mean=mean, std=std) return transform def format_messages( conversations: List[Dict[str, str]], sft_format: str = "deepseek", system_prompt: str = "", ): """ Applies the SFT template to conversation. Args: conversations (List[Dict]): A List of messages. sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". Returns: sft_prompt (str): The formatted text. """ conv = get_conv_template(sft_format) conv.set_system_message(system_prompt) for message in conversations: conv.append_message(message["role"], message["content"].strip()) sft_prompt = conv.get_prompt().strip() return sft_prompt def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): t = tokenizer.encode(text, add_special_tokens=False) bos_id = 0 eos_id = 1 if bos: t = [bos_id] + t if eos: t = t + [eos_id] return t def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: """ Args: conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : [ { "role": "User", "content": "\nExtract all information from this image and convert them into markdown format.", "images": ["./examples/table_datasets.png"] }, {"role": "Assistant", "content": ""}, ] Returns: pil_images (List[PIL.Image.Image]): the list of PIL images. """ pil_images = [] for message in conversations: if "images" not in message: continue for image_path in message["images"]: # print('----------------') # print(image_path) # print('----------------') # exit() # pil_img = Image.open(image_path) pil_img = load_image(image_path) pil_img = pil_img.convert("RGB") pil_images.append(pil_img) return pil_images class BaseTransform(ABC): def set_rng(self, *args, **kwargs): pass def __call__(self, *args, **kwargs) -> torch.Tensor: pass @property def default_shape(self): raise NotImplementedError class BasicImageTransform(BaseTransform): def __init__( self, mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), normalize: bool = True ): self.mean = mean self.std = std transform_pipelines = [ transforms.ToTensor() ] normalize = normalize_transform(mean, std) if normalize else nn.Identity() if normalize is not None: transform_pipelines.append(normalize) self.transform = transforms.Compose(transform_pipelines) def __call__(self, x): x = self.transform(x) return x class NoEOSTextStreamer(TextStreamer): def on_finalized_text(self, text: str, stream_end: bool = False): eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) text = text.replace(eos_text, "\n") print(text, flush=True, end="") class DeepseekOCRConfig(DeepseekV2Config): model_type = "DeepseekOCR" class DeepseekOCRModel(DeepseekV2Model): config_class = DeepseekOCRConfig def __init__(self, config: DeepseekV2Config): super(DeepseekOCRModel, self).__init__(config) self.sam_model = build_sam_vit_b() self.vision_model = build_clip_l() # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2) n_embed = 1280 self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)) embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.FloatTensor] = None, images_spatial_crop: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: if inputs_embeds is None: # inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids) sam_model = getattr(self, 'sam_model', None) # sam_model = self.sam_model vision_model = getattr(self, 'vision_model', None) if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: idx = 0 # sam_model = torch.jit.script(sam_model) # start_time = time.time() for image, crop_shape in zip(images, images_spatial_crop): images_in_this_batch = [] patches = image[0] image_ori = image[1] with torch.no_grad(): # with torch.inference_mode(): if torch.sum(patches).item() != 0: # P, C, H, W = patches.shape crop_flag = 1 local_features_1 = sam_model(patches) local_features_2 = vision_model(patches, local_features_1) # vit_time = time.time() local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1) local_features = self.projector(local_features) global_features_1 = sam_model(image_ori) global_features_2 = vision_model(image_ori, global_features_1) global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features = self.projector(global_features) print('=====================') print('BASE: ', global_features.shape) print('PATCHES: ', local_features.shape) print('=====================') _, hw, n_dim = global_features.shape h = w = int(hw ** 0.5) _2, hw2, n_dim2 = local_features.shape h2 = w2 = int(hw2 ** 0.5) width_crop_num, height_crop_num = crop_shape[0], crop_shape[1] global_features = global_features.view(h, w, n_dim) global_features = torch.cat( [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 ) global_features = global_features.view(-1, n_dim) local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2) local_features = torch.cat( [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1 ) local_features = local_features.view(-1, n_dim2) global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) # end_time = time.time() # print('sam: ', sam_time - start_time) # print('vit: ', vit_time - sam_time) # print('all: ', end_time - start_time) # exit() else: global_features_1 = sam_model(image_ori) global_features_2 = vision_model(image_ori, global_features_1) global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features = self.projector(global_features) print('=====================') print('BASE: ', global_features.shape) print('NO PATCHES') print('=====================') _, hw, n_dim = global_features.shape h = w = int(hw ** 0.5) global_features = global_features.view(h, w, n_dim) global_features = torch.cat( [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 ) global_features = global_features.view(-1, n_dim) global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) images_in_this_batch.append(global_local_features) # print(inputs_embeds.shape) if images_in_this_batch: images_in_this_batch = torch.cat(images_in_this_batch, dim=0) # exit() inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch) idx += 1 return super(DeepseekOCRModel, self).forward( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM): config_class = DeepseekOCRConfig # supports_gradient_checkpointing = True def __init__(self, config): super(DeepseekV2ForCausalLM, self).__init__(config) self.model = DeepseekOCRModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.FloatTensor] = None, images_spatial_crop: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, images=images, images_seq_mask = images_seq_mask, images_spatial_crop = images_spatial_crop, return_dict=return_dict ) # print(transformer_outputs) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() # logits loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if self.generation_config.cache_implementation == "static": # # generation with static cache # cache_position = kwargs.get("cache_position", None) # if cache_position is None: # past_length = 0 # else: # past_length = cache_position[-1] + 1 # input_ids = input_ids[:, past_length:] # position_ids = position_ids[:, past_length:] # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": kwargs.get("images", None), "images_seq_mask": kwargs.get("images_seq_mask", None), "images_spatial_crop": kwargs.get("images_spatial_crop", None), } ) return model_inputs def disable_torch_init(self): """ Disable the redundant torch default initialization to accelerate model creation. """ import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): self.disable_torch_init() os.makedirs(output_path, exist_ok=True) os.makedirs(f'{output_path}/images', exist_ok=True) if prompt and image_file: conversation = [ { "role": "<|User|>", # "content": "\n<|grounding|>Given the layout of the image. ", "content": f'{prompt}', # "content": "君不见黄河之水天上来的下一句是什么?", # "content": "\nFree OCR. ", # "content": "\nParse the figure. ", # "content": "\nExtract the text in the image. ", "images": [f'{image_file}'], }, {"role": "<|Assistant|>", "content": ""}, ] elif prompt: conversation = [ { "role": "<|User|>", # "content": "\n<|grounding|>Given the layout of the image. ", "content": f'{prompt}', # "content": "君不见黄河之水天上来的下一句是什么?", # "content": "\nFree OCR. ", # "content": "\nParse the figure. ", # "content": "\nExtract the text in the image. ", # "images": [f'{image_file}'], }, {"role": "<|Assistant|>", "content": ""}, ] else: assert False, f'prompt is none!' prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') patch_size = 16 downsample_ratio = 4 images = load_pil_images(conversation) valid_img_tokens = 0 ratio = 1 image_draw = images[0].copy() w,h = image_draw.size # print(w, h) ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) images_seq_mask = [] image_token = '' image_token_id = 128815 text_splits = prompt.split(image_token) images_list, images_crop_list, images_seq_mask = [], [], [] tokenized_str = [] images_spatial_crop = [] for text_sep, image in zip(text_splits, images): tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) tokenized_str += tokenized_sep images_seq_mask += [False] * len(tokenized_sep) if crop_mode: if image.size[0] <= 640 and image.size[1] <= 640: crop_ratio = [1, 1] else: if crop_mode: # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions) images_crop_raw, crop_ratio = dynamic_preprocess(image) else: # best_width, best_height = self.image_size, self.image_size crop_ratio = [1, 1] """process the global view""" # image = image.resize((base_size, base_size)) global_view = ImageOps.pad(image, (base_size, base_size), color=tuple(int(x * 255) for x in image_transform.mean)) if base_size == 1024: valid_img_tokens += int(256 * ratio) elif base_size == 1280: valid_img_tokens += int(400 * ratio) # elif base_size == 640: # valid_img_tokens += int(100 * ratio) images_list.append(image_transform(global_view).to(torch.bfloat16)) # global_view_tensor = image_transform(global_view).to(torch.bfloat16) width_crop_num, height_crop_num = crop_ratio images_spatial_crop.append([width_crop_num, height_crop_num]) if width_crop_num > 1 or height_crop_num > 1: """process the local views""" for i in range(len(images_crop_raw)): images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16)) if image_size == 640: valid_img_tokens += len(images_crop_list) * 100 num_queries = math.ceil((image_size // patch_size) / downsample_ratio) num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) """add image tokens""" tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base tokenized_image += [image_token_id] if width_crop_num > 1 or height_crop_num > 1: tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * ( num_queries * height_crop_num) tokenized_str += tokenized_image images_seq_mask += [True] * len(tokenized_image) # num_image_tokens.append(len(tokenized_image)) else: # best_width, best_height = self.image_size, self.image_size # print(image.size, (best_width, best_height)) # check the select_best_resolutions func """process the global view""" if image_size <= 640: print('directly resize') image = image.resize((image_size, image_size)) # else: global_view = ImageOps.pad(image, (image_size, image_size), color=tuple(int(x * 255) for x in image_transform.mean)) images_list.append(image_transform(global_view).to(torch.bfloat16)) if base_size == 1024: valid_img_tokens += int(256 * ratio) elif base_size == 1280: valid_img_tokens += int(400 * ratio) elif base_size == 640: valid_img_tokens += int(100 * 1) elif base_size == 512: valid_img_tokens += int(64 * 1) width_crop_num, height_crop_num = 1, 1 images_spatial_crop.append([width_crop_num, height_crop_num]) """add image tokens""" num_queries = math.ceil((image_size // patch_size) / downsample_ratio) tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries tokenized_image += [image_token_id] # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * ( # num_queries * height_crop_num) tokenized_str += tokenized_image images_seq_mask += [True] * len(tokenized_image) # num_image_tokens.append(len(tokenized_image)) """process the last text split""" tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) tokenized_str += tokenized_sep images_seq_mask += [False] * len(tokenized_sep) """add the bos tokens""" bos_id = 0 tokenized_str = [bos_id] + tokenized_str images_seq_mask = [False] + images_seq_mask input_ids = torch.LongTensor(tokenized_str) images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) if len(images_list) == 0: images_ori = torch.zeros((1, 3, image_size, image_size)) images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) images_crop = torch.zeros((1, 3, base_size, base_size)) else: images_ori = torch.stack(images_list, dim=0) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) if images_crop_list: images_crop = torch.stack(images_crop_list, dim=0) else: images_crop = torch.zeros((1, 3, base_size, base_size)) if not eval_mode: streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) with torch.autocast("cuda", dtype=torch.bfloat16): with torch.no_grad(): output_ids = self.generate( input_ids.unsqueeze(0).cuda(), images=[(images_crop.cuda(), images_ori.cuda())], images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), images_spatial_crop = images_spatial_crop, # do_sample=False, # num_beams = 1, temperature=0.0, eos_token_id=tokenizer.eos_token_id, streamer=streamer, max_new_tokens=8192, no_repeat_ngram_size = 20, use_cache = True ) else: with torch.autocast("cuda", dtype=torch.bfloat16): with torch.no_grad(): output_ids = self.generate( input_ids.unsqueeze(0).cuda(), images=[(images_crop.cuda(), images_ori.cuda())], images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), images_spatial_crop = images_spatial_crop, # do_sample=False, # num_beams = 1, temperature=0.0, eos_token_id=tokenizer.eos_token_id, max_new_tokens=8192, no_repeat_ngram_size = 35, use_cache = True ) if '' in conversation[0]['content'] and eval_mode: outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) stop_str = '<|end▁of▁sentence|>' if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] # re_match outputs = outputs.strip() return outputs if '' in conversation[0]['content'] and test_compress: outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) print('='*50) print('image size: ', (w, h)) print('valid image tokens: ', int(valid_img_tokens)) print('output texts tokens (valid): ', pure_texts_outputs_token_length) print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) print('='*50) if '' in conversation[0]['content'] and save_results: outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) stop_str = '<|end▁of▁sentence|>' print('='*15 + 'save results:' + '='*15) # # # # conv.messages[-1][-1] = outputs if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() matches_ref, matches_images, mathes_other = re_match(outputs) # print(matches_ref) result = process_image_with_refs(image_draw, matches_ref, output_path) for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n') for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') # if 'structural formula' in conversation[0]['content']: # outputs = '' + outputs + '' with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: afile.write(outputs) if 'line_type' in outputs: import matplotlib.pyplot as plt lines = eval(outputs)['Line']['line'] line_type = eval(outputs)['Line']['line_type'] # print(lines) endpoints = eval(outputs)['Line']['line_endpoint'] fig, ax = plt.subplots(figsize=(3,3), dpi=200) ax.set_xlim(-15, 15) ax.set_ylim(-15, 15) for idx, line in enumerate(lines): try: p0 = eval(line.split(' -- ')[0]) p1 = eval(line.split(' -- ')[-1]) if line_type[idx] == '--': ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') else: ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') ax.scatter(p0[0], p0[1], s=5, color = 'k') ax.scatter(p1[0], p1[1], s=5, color = 'k') except: pass for endpoint in endpoints: label = endpoint.split(': ')[0] (x, y) = eval(endpoint.split(': ')[1]) ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', fontsize=5, fontweight='light') plt.savefig(f'{output_path}/geo.jpg') plt.close() result.save(f"{output_path}/result_with_boxes.jpg")