Make infer general, so it runs on non-CUDA devices
infer hard-codes .cuda(), which prevents it from running on CPU or other devices; use self.device instead so it can run on CPU, MPS, etc.
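The fix boils down to resolving the device from the module and keying autocast off its device type, which is what the diff does inside infer via self.device. Below is a minimal sketch of that pattern; move_inputs_to_model_device is an illustrative helper, not part of the model class, and bfloat16 autocast coverage on CPU/MPS depends on the PyTorch build.

```python
import torch
from torch import nn


def move_inputs_to_model_device(model: nn.Module, input_ids, images_crop, images_ori, images_seq_mask):
    # Resolve the device from the model instead of hard-coding .cuda();
    # on a transformers PreTrainedModel this is what self.device returns.
    device = next(model.parameters()).device

    inputs = {
        "input_ids": input_ids.unsqueeze(0).to(device),
        "images": [(images_crop.to(device), images_ori.to(device))],
        "images_seq_mask": images_seq_mask.unsqueeze(0).to(device),
    }
    # torch.autocast is keyed by the device *type* ("cuda", "cpu", "mps"),
    # so the same call works on whichever backend the build supports.
    autocast_ctx = torch.autocast(device.type, dtype=torch.bfloat16)
    return inputs, autocast_ctx
```

With this pattern, moving the model with .to("cpu") or .to("mps") is enough; no per-backend code changes are needed.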
- modeling_deepseekocr.py +365 -257
modeling_deepseekocr.py
CHANGED
[Original (pre-change) side of the split diff: the removed lines were truncated in this capture, so they are summarized rather than reproduced. The removed code pinned inference to CUDA (.cuda() on the input tensors and a torch.autocast context hard-wired to CUDA) and used more compact single-line formatting; the hunks touch the whole file, consistent with the +365 -257 summary above. The complete post-change listing follows.]
| 1 |
from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
|
| 2 |
from .configuration_deepseek_v2 import DeepseekV2Config
|
| 3 |
+
from transformers.modeling_outputs import (
|
| 4 |
+
BaseModelOutputWithPast,
|
| 5 |
+
CausalLMOutputWithPast,
|
| 6 |
+
)
|
| 7 |
from typing import List, Optional, Tuple, Union
|
| 8 |
from transformers.cache_utils import Cache
|
| 9 |
import requests
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def load_image(image_path):
|
|
|
|
| 31 |
try:
|
| 32 |
image = Image.open(image_path)
|
| 33 |
+
|
| 34 |
corrected_image = ImageOps.exif_transpose(image)
|
| 35 |
+
|
| 36 |
return corrected_image
|
| 37 |
+
|
| 38 |
except Exception as e:
|
| 39 |
print(f"error: {e}")
|
| 40 |
try:
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def re_match(text):
|
| 47 |
+
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
|
| 48 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 49 |
|
| 50 |
# pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
|
|
|
|
| 53 |
mathes_image = []
|
| 54 |
mathes_other = []
|
| 55 |
for a_match in matches:
|
| 56 |
+
if "<|ref|>image<|/ref|>" in a_match[0]:
|
| 57 |
mathes_image.append(a_match[0])
|
| 58 |
else:
|
| 59 |
mathes_other.append(a_match[0])
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def extract_coordinates_and_label(ref_text, image_width, image_height):
|
|
|
|
| 64 |
try:
|
| 65 |
label_type = ref_text[1]
|
| 66 |
cor_list = eval(ref_text[2])
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def draw_bounding_boxes(image, refs, ouput_path):
|
|
|
|
| 75 |
image_width, image_height = image.size
|
| 76 |
+
|
| 77 |
img_draw = image.copy()
|
| 78 |
draw = ImageDraw.Draw(img_draw)
|
| 79 |
|
| 80 |
+
overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
|
| 81 |
draw2 = ImageDraw.Draw(overlay)
|
| 82 |
+
|
| 83 |
# try:
|
| 84 |
# except IOError:
|
| 85 |
# try:
|
| 86 |
+
# font = ImageFont.truetype("DejaVuSans.ttf", 20)
|
| 87 |
# except IOError:
|
| 88 |
font = ImageFont.load_default()
|
| 89 |
|
| 90 |
img_idx = 0
|
| 91 |
+
|
| 92 |
for i, ref in enumerate(refs):
|
| 93 |
try:
|
| 94 |
result = extract_coordinates_and_label(ref, image_width, image_height)
|
| 95 |
if result:
|
| 96 |
label_type, points_list = result
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
color = (
|
| 99 |
+
np.random.randint(0, 200),
|
| 100 |
+
np.random.randint(0, 200),
|
| 101 |
+
np.random.randint(0, 255),
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
color_a = color + (20,)
|
| 105 |
for points in points_list:
|
| 106 |
x1, y1, x2, y2 = points
|
| 107 |
|
|
|
|
| 111 |
x2 = int(x2 / 999 * image_width)
|
| 112 |
y2 = int(y2 / 999 * image_height)
|
| 113 |
|
| 114 |
+
if label_type == "image":
|
| 115 |
try:
|
| 116 |
cropped = image.crop((x1, y1, x2, y2))
|
| 117 |
cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
|
|
|
|
| 119 |
print(e)
|
| 120 |
pass
|
| 121 |
img_idx += 1
|
| 122 |
+
|
| 123 |
try:
|
| 124 |
+
if label_type == "title":
|
| 125 |
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
|
| 126 |
+
draw2.rectangle(
|
| 127 |
+
[x1, y1, x2, y2],
|
| 128 |
+
fill=color_a,
|
| 129 |
+
outline=(0, 0, 0, 0),
|
| 130 |
+
width=1,
|
| 131 |
+
)
|
| 132 |
else:
|
| 133 |
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
|
| 134 |
+
draw2.rectangle(
|
| 135 |
+
[x1, y1, x2, y2],
|
| 136 |
+
fill=color_a,
|
| 137 |
+
outline=(0, 0, 0, 0),
|
| 138 |
+
width=1,
|
| 139 |
+
)
|
| 140 |
text_x = x1
|
| 141 |
text_y = max(0, y1 - 15)
|
| 142 |
+
|
|
|
|
| 143 |
text_bbox = draw.textbbox((0, 0), label_type, font=font)
|
| 144 |
text_width = text_bbox[2] - text_bbox[0]
|
| 145 |
text_height = text_bbox[3] - text_bbox[1]
|
| 146 |
+
draw.rectangle(
|
| 147 |
+
[text_x, text_y, text_x + text_width, text_y + text_height],
|
| 148 |
+
fill=(255, 255, 255, 30),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
draw.text((text_x, text_y), label_type, font=font, fill=color)
|
| 152 |
except:
|
| 153 |
pass
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
def process_image_with_refs(image, ref_texts, output_path):
|
|
|
|
| 161 |
result_image = draw_bounding_boxes(image, ref_texts, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
return result_image
|
| 164 |
|
| 165 |
|
| 166 |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 167 |
+
best_ratio_diff = float("inf")
|
| 168 |
best_ratio = (1, 1)
|
| 169 |
area = width * height
|
| 170 |
for ratio in target_ratios:
|
|
|
|
| 180 |
return best_ratio
|
| 181 |
|
| 182 |
|
| 183 |
+
def dynamic_preprocess(
|
| 184 |
+
image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
|
| 185 |
+
):
|
| 186 |
orig_width, orig_height = image.size
|
| 187 |
aspect_ratio = orig_width / orig_height
|
| 188 |
|
| 189 |
# calculate the existing image aspect ratio
|
| 190 |
target_ratios = set(
|
| 191 |
+
(i, j)
|
| 192 |
+
for n in range(min_num, max_num + 1)
|
| 193 |
+
for i in range(1, n + 1)
|
| 194 |
+
for j in range(1, n + 1)
|
| 195 |
+
if i * j <= max_num and i * j >= min_num
|
| 196 |
+
)
|
| 197 |
# print(target_ratios)
|
| 198 |
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 199 |
|
| 200 |
# find the closest aspect ratio to the target
|
| 201 |
target_aspect_ratio = find_closest_aspect_ratio(
|
| 202 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size
|
| 203 |
+
)
|
| 204 |
|
| 205 |
# print(target_aspect_ratio)
|
| 206 |
# calculate the target width and height
|
|
|
|
| 216 |
(i % (target_width // image_size)) * image_size,
|
| 217 |
(i // (target_width // image_size)) * image_size,
|
| 218 |
((i % (target_width // image_size)) + 1) * image_size,
|
| 219 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 220 |
)
|
| 221 |
# split the image
|
| 222 |
split_img = resized_img.crop(box)
|
|
|
|
| 228 |
return processed_images, target_aspect_ratio
|
| 229 |
|
| 230 |
|
|
|
|
| 231 |
def normalize_transform(mean, std):
|
| 232 |
if mean is None and std is None:
|
| 233 |
transform = None
|
| 234 |
elif mean is None and std is not None:
|
| 235 |
+
mean = [0.0] * len(std)
|
| 236 |
transform = transforms.Normalize(mean=mean, std=std)
|
| 237 |
elif mean is not None and std is None:
|
| 238 |
+
std = [1.0] * len(mean)
|
| 239 |
transform = transforms.Normalize(mean=mean, std=std)
|
| 240 |
else:
|
| 241 |
transform = transforms.Normalize(mean=mean, std=std)
|
|
|
|
| 243 |
return transform
|
| 244 |
|
| 245 |
|
|
|
|
| 246 |
def format_messages(
|
| 247 |
+
conversations: List[Dict[str, str]],
|
| 248 |
+
sft_format: str = "deepseek",
|
| 249 |
+
system_prompt: str = "",
|
| 250 |
):
|
| 251 |
"""
|
| 252 |
Applies the SFT template to conversation.
|
|
|
|
| 280 |
|
| 281 |
return t
|
| 282 |
|
| 283 |
+
|
| 284 |
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
|
| 285 |
"""
|
| 286 |
|
|
|
|
| 311 |
# print(image_path)
|
| 312 |
# print('----------------')
|
| 313 |
# exit()
|
| 314 |
+
|
| 315 |
# pil_img = Image.open(image_path)
|
| 316 |
pil_img = load_image(image_path)
|
| 317 |
pil_img = pil_img.convert("RGB")
|
|
|
|
| 321 |
|
| 322 |
|
| 323 |
class BaseTransform(ABC):
|
|
|
|
| 324 |
def set_rng(self, *args, **kwargs):
|
| 325 |
pass
|
| 326 |
|
|
|
|
| 334 |
|
| 335 |
class BasicImageTransform(BaseTransform):
|
| 336 |
def __init__(
|
| 337 |
+
self,
|
| 338 |
mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
|
| 339 |
std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
|
| 340 |
+
normalize: bool = True,
|
| 341 |
):
|
| 342 |
self.mean = mean
|
| 343 |
self.std = std
|
| 344 |
+
|
| 345 |
+
transform_pipelines = [transforms.ToTensor()]
|
|
|
|
|
|
|
| 346 |
|
| 347 |
normalize = normalize_transform(mean, std) if normalize else nn.Identity()
|
| 348 |
if normalize is not None:
|
| 349 |
transform_pipelines.append(normalize)
|
| 350 |
|
| 351 |
self.transform = transforms.Compose(transform_pipelines)
|
| 352 |
+
|
| 353 |
def __call__(self, x):
|
| 354 |
x = self.transform(x)
|
| 355 |
return x
|
| 356 |
|
| 357 |
+
|
| 358 |
class NoEOSTextStreamer(TextStreamer):
|
| 359 |
def on_finalized_text(self, text: str, stream_end: bool = False):
|
| 360 |
+
eos_text = self.tokenizer.decode(
|
| 361 |
+
[self.tokenizer.eos_token_id], skip_special_tokens=False
|
| 362 |
+
)
|
| 363 |
text = text.replace(eos_text, "\n")
|
| 364 |
print(text, flush=True, end="")
|
| 365 |
|
|
|
|
| 367 |
class DeepseekOCRConfig(DeepseekV2Config):
|
| 368 |
model_type = "DeepseekOCR"
|
| 369 |
|
| 370 |
+
|
| 371 |
class DeepseekOCRModel(DeepseekV2Model):
|
| 372 |
config_class = DeepseekOCRConfig
|
| 373 |
|
|
|
|
| 378 |
self.vision_model = build_clip_l()
|
| 379 |
# self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
|
| 380 |
n_embed = 1280
|
| 381 |
+
self.projector = MlpProjector(
|
| 382 |
+
Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)
|
| 383 |
+
)
|
| 384 |
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
|
| 385 |
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 386 |
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 387 |
|
|
|
|
|
|
|
|
|
|
| 388 |
def forward(
|
| 389 |
self,
|
| 390 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 400 |
images_spatial_crop: Optional[torch.FloatTensor] = None,
|
| 401 |
return_dict: Optional[bool] = None,
|
| 402 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
if inputs_embeds is None:
|
| 404 |
# inputs_embeds = self.embed_tokens(input_ids)
|
| 405 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 406 |
|
| 407 |
+
sam_model = getattr(self, "sam_model", None)
|
|
|
|
|
|
|
| 408 |
# sam_model = self.sam_model
|
| 409 |
+
vision_model = getattr(self, "vision_model", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
if (
|
| 412 |
+
sam_model is not None
|
| 413 |
+
and (input_ids.shape[1] != 1 or self.training)
|
| 414 |
+
and torch.sum(images[0][1]).item() != 0
|
| 415 |
+
):
|
| 416 |
idx = 0
|
| 417 |
+
|
| 418 |
# sam_model = torch.jit.script(sam_model)
|
| 419 |
+
|
| 420 |
# start_time = time.time()
|
| 421 |
for image, crop_shape in zip(images, images_spatial_crop):
|
| 422 |
images_in_this_batch = []
|
|
|
|
| 425 |
image_ori = image[1]
|
| 426 |
|
| 427 |
with torch.no_grad():
|
| 428 |
+
# with torch.inference_mode():
|
| 429 |
+
|
| 430 |
if torch.sum(patches).item() != 0:
|
| 431 |
# P, C, H, W = patches.shape
|
| 432 |
crop_flag = 1
|
| 433 |
local_features_1 = sam_model(patches)
|
| 434 |
|
| 435 |
+
local_features_2 = vision_model(patches, local_features_1)
|
| 436 |
# vit_time = time.time()
|
| 437 |
+
local_features = torch.cat(
|
| 438 |
+
(
|
| 439 |
+
local_features_2[:, 1:],
|
| 440 |
+
local_features_1.flatten(2).permute(0, 2, 1),
|
| 441 |
+
),
|
| 442 |
+
dim=-1,
|
| 443 |
+
)
|
| 444 |
local_features = self.projector(local_features)
|
| 445 |
|
|
|
|
| 446 |
global_features_1 = sam_model(image_ori)
|
| 447 |
+
global_features_2 = vision_model(image_ori, global_features_1)
|
| 448 |
+
global_features = torch.cat(
|
| 449 |
+
(
|
| 450 |
+
global_features_2[:, 1:],
|
| 451 |
+
global_features_1.flatten(2).permute(0, 2, 1),
|
| 452 |
+
),
|
| 453 |
+
dim=-1,
|
| 454 |
+
)
|
| 455 |
global_features = self.projector(global_features)
|
| 456 |
|
| 457 |
+
print("=====================")
|
| 458 |
+
print("BASE: ", global_features.shape)
|
| 459 |
+
print("PATCHES: ", local_features.shape)
|
| 460 |
+
print("=====================")
|
| 461 |
|
| 462 |
_, hw, n_dim = global_features.shape
|
| 463 |
+
h = w = int(hw**0.5)
|
| 464 |
|
| 465 |
_2, hw2, n_dim2 = local_features.shape
|
| 466 |
+
h2 = w2 = int(hw2**0.5)
|
| 467 |
|
| 468 |
width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
|
| 469 |
|
| 470 |
global_features = global_features.view(h, w, n_dim)
|
| 471 |
|
| 472 |
global_features = torch.cat(
|
| 473 |
+
[
|
| 474 |
+
global_features,
|
| 475 |
+
self.image_newline[None, None, :].expand(h, 1, n_dim),
|
| 476 |
+
],
|
| 477 |
+
dim=1,
|
| 478 |
)
|
| 479 |
|
| 480 |
global_features = global_features.view(-1, n_dim)
|
| 481 |
|
| 482 |
+
local_features = (
|
| 483 |
+
local_features.view(
|
| 484 |
+
height_crop_num, width_crop_num, h2, w2, n_dim2
|
| 485 |
+
)
|
| 486 |
+
.permute(0, 2, 1, 3, 4)
|
| 487 |
+
.reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
|
| 488 |
+
)
|
| 489 |
local_features = torch.cat(
|
| 490 |
+
[
|
| 491 |
+
local_features,
|
| 492 |
+
self.image_newline[None, None, :].expand(
|
| 493 |
+
height_crop_num * h2, 1, n_dim2
|
| 494 |
+
),
|
| 495 |
+
],
|
| 496 |
+
dim=1,
|
| 497 |
)
|
| 498 |
local_features = local_features.view(-1, n_dim2)
|
| 499 |
|
| 500 |
+
global_local_features = torch.cat(
|
| 501 |
+
[
|
| 502 |
+
local_features,
|
| 503 |
+
global_features,
|
| 504 |
+
self.view_seperator[None, :],
|
| 505 |
+
],
|
| 506 |
+
dim=0,
|
| 507 |
+
)
|
| 508 |
|
| 509 |
# end_time = time.time()
|
| 510 |
|
|
|
|
| 513 |
# print('all: ', end_time - start_time)
|
| 514 |
|
| 515 |
# exit()
|
| 516 |
+
|
| 517 |
else:
|
| 518 |
global_features_1 = sam_model(image_ori)
|
| 519 |
+
global_features_2 = vision_model(image_ori, global_features_1)
|
| 520 |
+
global_features = torch.cat(
|
| 521 |
+
(
|
| 522 |
+
global_features_2[:, 1:],
|
| 523 |
+
global_features_1.flatten(2).permute(0, 2, 1),
|
| 524 |
+
),
|
| 525 |
+
dim=-1,
|
| 526 |
+
)
|
| 527 |
global_features = self.projector(global_features)
|
| 528 |
+
print("=====================")
|
| 529 |
+
print("BASE: ", global_features.shape)
|
| 530 |
+
print("NO PATCHES")
|
| 531 |
+
print("=====================")
|
| 532 |
_, hw, n_dim = global_features.shape
|
| 533 |
+
h = w = int(hw**0.5)
|
|
|
|
| 534 |
|
| 535 |
global_features = global_features.view(h, w, n_dim)
|
| 536 |
|
| 537 |
global_features = torch.cat(
|
| 538 |
+
[
|
| 539 |
+
global_features,
|
| 540 |
+
self.image_newline[None, None, :].expand(h, 1, n_dim),
|
| 541 |
+
],
|
| 542 |
+
dim=1,
|
| 543 |
)
|
| 544 |
|
| 545 |
global_features = global_features.view(-1, n_dim)
|
| 546 |
|
| 547 |
+
global_local_features = torch.cat(
|
| 548 |
+
[global_features, self.view_seperator[None, :]], dim=0
|
| 549 |
+
)
|
| 550 |
|
| 551 |
images_in_this_batch.append(global_local_features)
|
|
|
|
| 552 |
|
| 553 |
# print(inputs_embeds.shape)
|
| 554 |
|
|
|
|
| 556 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 557 |
# exit()
|
| 558 |
|
| 559 |
+
inputs_embeds[idx].masked_scatter_(
|
| 560 |
+
images_seq_mask[idx].unsqueeze(-1).to(self.device),
|
| 561 |
+
images_in_this_batch,
|
| 562 |
+
)
|
| 563 |
|
| 564 |
idx += 1
|
|
|
|
| 565 |
|
| 566 |
return super(DeepseekOCRModel, self).forward(
|
| 567 |
+
input_ids=None,
|
| 568 |
+
attention_mask=attention_mask,
|
| 569 |
+
past_key_values=past_key_values,
|
| 570 |
+
inputs_embeds=inputs_embeds,
|
| 571 |
+
use_cache=use_cache,
|
| 572 |
+
position_ids=position_ids,
|
| 573 |
+
output_attentions=output_attentions,
|
| 574 |
+
output_hidden_states=output_hidden_states,
|
| 575 |
+
return_dict=return_dict,
|
| 576 |
)
|
|
|
|
| 577 |
|
|
|
|
| 578 |
|
| 579 |
+
class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
| 580 |
config_class = DeepseekOCRConfig
|
| 581 |
# supports_gradient_checkpointing = True
|
| 582 |
|
|
|
|
| 596 |
def get_model(self):
|
| 597 |
return self.model
|
| 598 |
|
|
|
|
| 599 |
def forward(
|
| 600 |
self,
|
| 601 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 611 |
images_seq_mask: Optional[torch.FloatTensor] = None,
|
| 612 |
images_spatial_crop: Optional[torch.FloatTensor] = None,
|
| 613 |
return_dict: Optional[bool] = None,
|
|
|
|
| 614 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 615 |
+
output_attentions = (
|
| 616 |
+
output_attentions
|
| 617 |
+
if output_attentions is not None
|
| 618 |
+
else self.config.output_attentions
|
| 619 |
+
)
|
| 620 |
output_hidden_states = (
|
| 621 |
+
output_hidden_states
|
| 622 |
+
if output_hidden_states is not None
|
| 623 |
+
else self.config.output_hidden_states
|
| 624 |
+
)
|
| 625 |
+
return_dict = (
|
| 626 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 627 |
)
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
+
outputs = self.model(
|
| 630 |
input_ids=input_ids,
|
| 631 |
past_key_values=past_key_values,
|
| 632 |
attention_mask=attention_mask,
|
|
|
|
| 636 |
output_attentions=output_attentions,
|
| 637 |
output_hidden_states=output_hidden_states,
|
| 638 |
images=images,
|
| 639 |
+
images_seq_mask=images_seq_mask,
|
| 640 |
+
images_spatial_crop=images_spatial_crop,
|
| 641 |
+
return_dict=return_dict,
|
|
|
|
| 642 |
)
|
| 643 |
|
|
|
|
|
|
|
| 644 |
# print(transformer_outputs)
|
| 645 |
|
| 646 |
hidden_states = outputs[0]
|
|
|
|
| 674 |
attentions=outputs.attentions,
|
| 675 |
)
|
| 676 |
|
|
|
|
| 677 |
def prepare_inputs_for_generation(
|
| 678 |
+
self,
|
| 679 |
+
input_ids,
|
| 680 |
+
past_key_values=None,
|
| 681 |
+
attention_mask=None,
|
| 682 |
+
inputs_embeds=None,
|
| 683 |
+
**kwargs,
|
| 684 |
):
|
| 685 |
# Omit tokens covered by past_key_values
|
| 686 |
past_length = 0
|
|
|
|
| 697 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 698 |
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 699 |
# input)
|
| 700 |
+
if (
|
| 701 |
+
attention_mask is not None
|
| 702 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
| 703 |
+
):
|
| 704 |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 705 |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 706 |
# input_ids based on the past_length.
|
|
|
|
| 736 |
|
| 737 |
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
|
| 738 |
# same goes for position ids. Could also help with continued generation.
|
| 739 |
+
cache_position = torch.arange(
|
| 740 |
+
past_length,
|
| 741 |
+
past_length + position_ids.shape[-1],
|
| 742 |
+
device=position_ids.device,
|
| 743 |
+
)
|
| 744 |
|
| 745 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 746 |
if inputs_embeds is not None and past_key_values is None:
|
|
|
|
| 760 |
}
|
| 761 |
)
|
| 762 |
return model_inputs
|
|
|
|
| 763 |
|
| 764 |
def disable_torch_init(self):
|
| 765 |
"""
|
| 766 |
Disable the redundant torch default initialization to accelerate model creation.
|
| 767 |
"""
|
| 768 |
import torch
|
| 769 |
+
|
| 770 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 771 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 772 |
|
| 773 |
+
def infer(
|
| 774 |
+
self,
|
| 775 |
+
tokenizer,
|
| 776 |
+
prompt="",
|
| 777 |
+
image_file="",
|
| 778 |
+
output_path="",
|
| 779 |
+
base_size=1024,
|
| 780 |
+
image_size=640,
|
| 781 |
+
crop_mode=True,
|
| 782 |
+
test_compress=False,
|
| 783 |
+
save_results=False,
|
| 784 |
+
eval_mode=False,
|
| 785 |
+
):
|
| 786 |
self.disable_torch_init()
|
| 787 |
|
| 788 |
os.makedirs(output_path, exist_ok=True)
|
| 789 |
+
os.makedirs(f"{output_path}/images", exist_ok=True)
|
| 790 |
|
| 791 |
if prompt and image_file:
|
| 792 |
conversation = [
|
| 793 |
{
|
| 794 |
"role": "<|User|>",
|
| 795 |
# "content": "<image>\n<|grounding|>Given the layout of the image. ",
|
| 796 |
+
"content": f"{prompt}",
|
| 797 |
# "content": "君不见黄河之水天上来的下一句是什么?",
|
| 798 |
# "content": "<image>\nFree OCR. ",
|
| 799 |
# "content": "<image>\nParse the figure. ",
|
| 800 |
# "content": "<image>\nExtract the text in the image. ",
|
| 801 |
+
"images": [f"{image_file}"],
|
| 802 |
},
|
| 803 |
{"role": "<|Assistant|>", "content": ""},
|
| 804 |
]
|
| 805 |
+
|
| 806 |
elif prompt:
|
| 807 |
conversation = [
|
| 808 |
{
|
| 809 |
"role": "<|User|>",
|
| 810 |
# "content": "<image>\n<|grounding|>Given the layout of the image. ",
|
| 811 |
+
"content": f"{prompt}",
|
| 812 |
# "content": "君不见黄河之水天上来的下一句是什么?",
|
| 813 |
# "content": "<image>\nFree OCR. ",
|
| 814 |
# "content": "<image>\nParse the figure. ",
|
|
|
|
| 818 |
{"role": "<|Assistant|>", "content": ""},
|
| 819 |
]
|
| 820 |
else:
|
| 821 |
+
assert False, f"prompt is none!"
|
| 822 |
+
|
| 823 |
+
prompt = format_messages(
|
| 824 |
+
conversations=conversation, sft_format="plain", system_prompt=""
|
| 825 |
+
)
|
| 826 |
|
| 827 |
patch_size = 16
|
| 828 |
downsample_ratio = 4
|
|
|
|
| 833 |
|
| 834 |
image_draw = images[0].copy()
|
| 835 |
|
| 836 |
+
w, h = image_draw.size
|
| 837 |
# print(w, h)
|
| 838 |
ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
|
|
|
|
| 839 |
|
| 840 |
+
image_transform = BasicImageTransform(
|
| 841 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True
|
| 842 |
+
)
|
| 843 |
images_seq_mask = []
|
| 844 |
|
| 845 |
+
image_token = "<image>"
|
| 846 |
image_token_id = 128815
|
| 847 |
text_splits = prompt.split(image_token)
|
| 848 |
|
|
|
|
| 850 |
tokenized_str = []
|
| 851 |
images_spatial_crop = []
|
| 852 |
for text_sep, image in zip(text_splits, images):
|
|
|
|
| 853 |
tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
|
| 854 |
tokenized_str += tokenized_sep
|
| 855 |
images_seq_mask += [False] * len(tokenized_sep)
|
| 856 |
|
| 857 |
if crop_mode:
|
|
|
|
| 858 |
if image.size[0] <= 640 and image.size[1] <= 640:
|
| 859 |
crop_ratio = [1, 1]
|
| 860 |
|
|
|
|
| 865 |
else:
|
| 866 |
# best_width, best_height = self.image_size, self.image_size
|
| 867 |
crop_ratio = [1, 1]
|
| 868 |
+
|
| 869 |
"""process the global view"""
|
| 870 |
# image = image.resize((base_size, base_size))
|
| 871 |
+
global_view = ImageOps.pad(
|
| 872 |
+
image,
|
| 873 |
+
(base_size, base_size),
|
| 874 |
+
color=tuple(int(x * 255) for x in image_transform.mean),
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
if base_size == 1024:
|
| 878 |
valid_img_tokens += int(256 * ratio)
|
| 879 |
elif base_size == 1280:
|
| 880 |
valid_img_tokens += int(400 * ratio)
|
| 881 |
# elif base_size == 640:
|
| 882 |
# valid_img_tokens += int(100 * ratio)
|
|
|
|
|
|
|
| 883 |
|
|
|
|
|
|
|
| 884 |
images_list.append(image_transform(global_view).to(torch.bfloat16))
|
| 885 |
|
| 886 |
# global_view_tensor = image_transform(global_view).to(torch.bfloat16)
|
|
|
|
| 888 |
width_crop_num, height_crop_num = crop_ratio
|
| 889 |
|
| 890 |
images_spatial_crop.append([width_crop_num, height_crop_num])
|
| 891 |
+
|
|
|
|
| 892 |
if width_crop_num > 1 or height_crop_num > 1:
|
| 893 |
"""process the local views"""
|
| 894 |
+
|
| 895 |
for i in range(len(images_crop_raw)):
|
| 896 |
+
images_crop_list.append(
|
| 897 |
+
image_transform(images_crop_raw[i]).to(torch.bfloat16)
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
if image_size == 640:
|
| 901 |
valid_img_tokens += len(images_crop_list) * 100
|
| 902 |
|
| 903 |
num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
|
| 904 |
+
num_queries_base = math.ceil(
|
| 905 |
+
(base_size // patch_size) / downsample_ratio
|
| 906 |
+
)
|
| 907 |
|
| 908 |
"""add image tokens"""
|
| 909 |
|
| 910 |
+
tokenized_image = (
|
| 911 |
+
[image_token_id] * num_queries_base + [image_token_id]
|
| 912 |
+
) * num_queries_base
|
| 913 |
tokenized_image += [image_token_id]
|
| 914 |
if width_crop_num > 1 or height_crop_num > 1:
|
| 915 |
+
tokenized_image += (
|
| 916 |
+
[image_token_id] * (num_queries * width_crop_num)
|
| 917 |
+
+ [image_token_id]
|
| 918 |
+
) * (num_queries * height_crop_num)
|
| 919 |
tokenized_str += tokenized_image
|
| 920 |
images_seq_mask += [True] * len(tokenized_image)
|
| 921 |
# num_image_tokens.append(len(tokenized_image))
|
|
|
|
| 926 |
|
| 927 |
"""process the global view"""
|
| 928 |
if image_size <= 640:
|
| 929 |
+
print("directly resize")
|
| 930 |
image = image.resize((image_size, image_size))
|
| 931 |
# else:
|
| 932 |
+
global_view = ImageOps.pad(
|
| 933 |
+
image,
|
| 934 |
+
(image_size, image_size),
|
| 935 |
+
color=tuple(int(x * 255) for x in image_transform.mean),
|
| 936 |
+
)
|
| 937 |
images_list.append(image_transform(global_view).to(torch.bfloat16))
|
| 938 |
|
| 939 |
if base_size == 1024:
|
|
|
|
| 949 |
|
| 950 |
images_spatial_crop.append([width_crop_num, height_crop_num])
|
| 951 |
|
|
|
|
| 952 |
"""add image tokens"""
|
| 953 |
num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
|
| 954 |
|
| 955 |
+
tokenized_image = (
|
| 956 |
+
[image_token_id] * num_queries + [image_token_id]
|
| 957 |
+
) * num_queries
|
| 958 |
tokenized_image += [image_token_id]
|
| 959 |
# tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
|
| 960 |
# num_queries * height_crop_num)
|
| 961 |
tokenized_str += tokenized_image
|
| 962 |
images_seq_mask += [True] * len(tokenized_image)
|
| 963 |
# num_image_tokens.append(len(tokenized_image))
|
|
|
|
| 964 |
|
| 965 |
"""process the last text split"""
|
| 966 |
tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
|
|
|
|
| 969 |
|
| 970 |
"""add the bos tokens"""
|
| 971 |
bos_id = 0
|
| 972 |
+
tokenized_str = [bos_id] + tokenized_str
|
| 973 |
images_seq_mask = [False] + images_seq_mask
|
| 974 |
|
|
|
|
|
|
|
| 975 |
input_ids = torch.LongTensor(tokenized_str)
|
| 976 |
|
|
|
|
|
|
|
|
|
|
| 977 |
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
| 978 |
|
|
|
|
| 979 |
if len(images_list) == 0:
|
| 980 |
images_ori = torch.zeros((1, 3, image_size, image_size))
|
| 981 |
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
|
|
|
|
| 989 |
else:
|
| 990 |
images_crop = torch.zeros((1, 3, base_size, base_size))
|
| 991 |
|
|
|
|
|
|
|
| 992 |
if not eval_mode:
|
| 993 |
+
streamer = NoEOSTextStreamer(
|
| 994 |
+
tokenizer, skip_prompt=True, skip_special_tokens=False
|
| 995 |
+
)
|
| 996 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
| 997 |
with torch.no_grad():
|
| 998 |
output_ids = self.generate(
|
| 999 |
+
input_ids.unsqueeze(0).to(self.device),
|
| 1000 |
+
images=[
|
| 1001 |
+
(images_crop.to(self.device), images_ori.to(self.device))
|
| 1002 |
+
],
|
| 1003 |
+
images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
|
| 1004 |
+
images_spatial_crop=images_spatial_crop,
|
| 1005 |
# do_sample=False,
|
| 1006 |
# num_beams = 1,
|
| 1007 |
temperature=0.0,
|
| 1008 |
eos_token_id=tokenizer.eos_token_id,
|
| 1009 |
streamer=streamer,
|
| 1010 |
max_new_tokens=8192,
|
| 1011 |
+
no_repeat_ngram_size=20,
|
| 1012 |
+
use_cache=True,
|
| 1013 |
+
)
|
| 1014 |
|
| 1015 |
else:
|
| 1016 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
| 1017 |
with torch.no_grad():
|
| 1018 |
output_ids = self.generate(
|
| 1019 |
+
input_ids.unsqueeze(0).to(self.device),
|
| 1020 |
+
images=[
|
| 1021 |
+
(images_crop.to(self.device), images_ori.to(self.device))
|
| 1022 |
+
],
|
| 1023 |
+
images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
|
| 1024 |
+
images_spatial_crop=images_spatial_crop,
|
| 1025 |
# do_sample=False,
|
| 1026 |
# num_beams = 1,
|
| 1027 |
temperature=0.0,
|
| 1028 |
eos_token_id=tokenizer.eos_token_id,
|
| 1029 |
max_new_tokens=8192,
|
| 1030 |
+
no_repeat_ngram_size=35,
|
| 1031 |
+
use_cache=True,
|
| 1032 |
+
)
|
| 1033 |
+
|
| 1034 |
+
if "<image>" in conversation[0]["content"] and eval_mode:
|
| 1035 |
+
outputs = tokenizer.decode(
|
| 1036 |
+
output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
|
| 1037 |
+
)
|
| 1038 |
+
stop_str = "<|end▁of▁sentence|>"
|
| 1039 |
+
if outputs.endswith(stop_str):
|
| 1040 |
+
outputs = outputs[: -len(stop_str)]
|
| 1041 |
+
# re_match
|
| 1042 |
+
outputs = outputs.strip()
|
| 1043 |
+
|
| 1044 |
+
return outputs
|
| 1045 |
+
|
| 1046 |
+
if "<image>" in conversation[0]["content"] and test_compress:
|
| 1047 |
+
outputs = tokenizer.decode(
|
| 1048 |
+
output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
|
| 1049 |
+
)
|
| 1050 |
+
pure_texts_outputs_token_length = len(
|
| 1051 |
+
text_encode(tokenizer, outputs, bos=False, eos=False)
|
| 1052 |
+
)
|
| 1053 |
+
print("=" * 50)
|
| 1054 |
+
print("image size: ", (w, h))
|
| 1055 |
+
print("valid image tokens: ", int(valid_img_tokens))
|
| 1056 |
+
print("output texts tokens (valid): ", pure_texts_outputs_token_length)
|
| 1057 |
+
print(
|
| 1058 |
+
"compression ratio: ",
|
| 1059 |
+
round(pure_texts_outputs_token_length / valid_img_tokens, 2),
|
| 1060 |
+
)
|
| 1061 |
+
print("=" * 50)
|
| 1062 |
+
|
| 1063 |
+
if "<image>" in conversation[0]["content"] and save_results:
|
| 1064 |
+
outputs = tokenizer.decode(
|
| 1065 |
+
output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
|
| 1066 |
+
)
|
| 1067 |
+
stop_str = "<|end▁of▁sentence|>"
|
| 1068 |
+
|
| 1069 |
+
print("=" * 15 + "save results:" + "=" * 15)
|
| 1070 |
+
|
| 1071 |
# # # # conv.messages[-1][-1] = outputs
|
| 1072 |
if outputs.endswith(stop_str):
|
| 1073 |
+
outputs = outputs[: -len(stop_str)]
|
| 1074 |
outputs = outputs.strip()
|
| 1075 |
|
| 1076 |
matches_ref, matches_images, mathes_other = re_match(outputs)
|
| 1077 |
# print(matches_ref)
|
| 1078 |
result = process_image_with_refs(image_draw, matches_ref, output_path)
|
| 1079 |
|
|
|
|
| 1080 |
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
|
| 1081 |
+
outputs = outputs.replace(
|
| 1082 |
+
a_match_image, " + ".jpg)\n"
|
| 1083 |
+
)
|
|
|
|
| 1084 |
|
| 1085 |
+
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
|
| 1086 |
+
outputs = (
|
| 1087 |
+
outputs.replace(a_match_other, "")
|
| 1088 |
+
.replace("\\coloneqq", ":=")
|
| 1089 |
+
.replace("\\eqqcolon", "=:")
|
| 1090 |
+
)
|
| 1091 |
|
| 1092 |
# if 'structural formula' in conversation[0]['content']:
|
| 1093 |
# outputs = '<smiles>' + outputs + '</smiles>'
|
| 1094 |
+
with open(f"{output_path}/result.mmd", "w", encoding="utf-8") as afile:
|
| 1095 |
afile.write(outputs)
|
| 1096 |
|
| 1097 |
+
if "line_type" in outputs:
|
| 1098 |
import matplotlib.pyplot as plt
|
|
|
|
| 1099 |
|
| 1100 |
+
lines = eval(outputs)["Line"]["line"]
|
| 1101 |
+
|
| 1102 |
+
line_type = eval(outputs)["Line"]["line_type"]
|
| 1103 |
# print(lines)
|
| 1104 |
|
| 1105 |
+
endpoints = eval(outputs)["Line"]["line_endpoint"]
|
| 1106 |
|
| 1107 |
+
fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
|
| 1108 |
ax.set_xlim(-15, 15)
|
| 1109 |
ax.set_ylim(-15, 15)
|
| 1110 |
|
| 1111 |
for idx, line in enumerate(lines):
|
| 1112 |
try:
|
| 1113 |
+
p0 = eval(line.split(" -- ")[0])
|
| 1114 |
+
p1 = eval(line.split(" -- ")[-1])
|
| 1115 |
|
| 1116 |
+
if line_type[idx] == "--":
|
| 1117 |
+
ax.plot(
|
| 1118 |
+
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
|
| 1119 |
+
)
|
| 1120 |
else:
|
| 1121 |
+
ax.plot(
|
| 1122 |
+
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
|
| 1123 |
+
)
|
| 1124 |
|
| 1125 |
+
ax.scatter(p0[0], p0[1], s=5, color="k")
|
| 1126 |
+
ax.scatter(p1[0], p1[1], s=5, color="k")
|
| 1127 |
except:
|
| 1128 |
pass
|
| 1129 |
|
| 1130 |
for endpoint in endpoints:
|
| 1131 |
+
label = endpoint.split(": ")[0]
|
| 1132 |
+
(x, y) = eval(endpoint.split(": ")[1])
|
| 1133 |
+
ax.annotate(
|
| 1134 |
+
label,
|
| 1135 |
+
(x, y),
|
| 1136 |
+
xytext=(1, 1),
|
| 1137 |
+
textcoords="offset points",
|
| 1138 |
+
fontsize=5,
|
| 1139 |
+
fontweight="light",
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
plt.savefig(f"{output_path}/geo.jpg")
|
| 1143 |
plt.close()
|
| 1144 |
|
| 1145 |
result.save(f"{output_path}/result_with_boxes.jpg")
|
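For completeness, a hedged end-to-end sketch of driving the updated infer on whatever backend is available. The repo id, image path, and prompt are placeholders, loading assumes the checkpoint is consumed with trust_remote_code as usual for this model, and bfloat16 throughput on CPU/MPS will vary.

```python
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "deepseek-ai/DeepSeek-OCR"  # placeholder: use the checkpoint you actually have

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Pick the best available backend; infer now follows model.device instead of assuming CUDA.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
model = model.eval().to(device)

# With save_results=True, result.mmd, result_with_boxes.jpg and cropped images land in output_path.
model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file="sample_page.jpg",  # placeholder input image
    output_path="./ocr_output",
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
)
```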