Make infer general so it runs on non-CUDA devices

#6
Files changed (1)
  1. modeling_deepseekocr.py +365 -257
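
In short, the diff below replaces the hard-coded `.cuda()` calls and `torch.autocast("cuda", ...)` contexts in `infer` (and in the image-feature scatter inside `DeepseekOCRModel.forward`) with the model's own device. A minimal sketch of that pattern, not taken verbatim from the file (the `nn.Linear` stand-in is for illustration only):

```python
import torch
import torch.nn as nn

# Illustrative sketch (not code from the diff): resolve the model's actual
# device and autocast on that device type, instead of hard-coding .cuda()
# and torch.autocast("cuda", ...).
model = nn.Linear(4, 4)                    # stand-in for the OCR model
device = next(model.parameters()).device   # plays the role of self.device

x = torch.randn(2, 4).to(device)           # replaces x.cuda()
with torch.autocast(device.type, dtype=torch.bfloat16):  # replaces autocast("cuda", ...)
    with torch.no_grad():
        y = model(x)
```

With this, the same code path runs on CUDA, CPU, or Apple MPS, provided the backend supports bfloat16 autocast.
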
modeling_deepseekocr.py CHANGED
@@ -1,6 +1,9 @@
1
  from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
2
  from .configuration_deepseek_v2 import DeepseekV2Config
3
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
4
  from typing import List, Optional, Tuple, Union
5
  from transformers.cache_utils import Cache
6
  import requests
@@ -25,14 +28,13 @@ import time
25
 
26
 
27
  def load_image(image_path):
28
-
29
  try:
30
  image = Image.open(image_path)
31
-
32
  corrected_image = ImageOps.exif_transpose(image)
33
-
34
  return corrected_image
35
-
36
  except Exception as e:
37
  print(f"error: {e}")
38
  try:
@@ -42,7 +44,7 @@ def load_image(image_path):
42
 
43
 
44
  def re_match(text):
45
- pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
46
  matches = re.findall(pattern, text, re.DOTALL)
47
 
48
  # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
@@ -51,7 +53,7 @@ def re_match(text):
51
  mathes_image = []
52
  mathes_other = []
53
  for a_match in matches:
54
- if '<|ref|>image<|/ref|>' in a_match[0]:
55
  mathes_image.append(a_match[0])
56
  else:
57
  mathes_other.append(a_match[0])
@@ -59,7 +61,6 @@ def re_match(text):
59
 
60
 
61
  def extract_coordinates_and_label(ref_text, image_width, image_height):
62
-
63
  try:
64
  label_type = ref_text[1]
65
  cor_list = eval(ref_text[2])
@@ -71,33 +72,36 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
71
 
72
 
73
  def draw_bounding_boxes(image, refs, ouput_path):
74
-
75
  image_width, image_height = image.size
76
-
77
  img_draw = image.copy()
78
  draw = ImageDraw.Draw(img_draw)
79
 
80
- overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
81
  draw2 = ImageDraw.Draw(overlay)
82
-
83
  # try:
84
  # except IOError:
85
  # try:
86
- # font = ImageFont.truetype("DejaVuSans.ttf", 20)
87
  # except IOError:
88
  font = ImageFont.load_default()
89
 
90
  img_idx = 0
91
-
92
  for i, ref in enumerate(refs):
93
  try:
94
  result = extract_coordinates_and_label(ref, image_width, image_height)
95
  if result:
96
  label_type, points_list = result
97
-
98
- color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
99
 
100
- color_a = color + (20, )
 
 
 
 
 
 
101
  for points in points_list:
102
  x1, y1, x2, y2 = points
103
 
@@ -107,7 +111,7 @@ def draw_bounding_boxes(image, refs, ouput_path):
107
  x2 = int(x2 / 999 * image_width)
108
  y2 = int(y2 / 999 * image_height)
109
 
110
- if label_type == 'image':
111
  try:
112
  cropped = image.crop((x1, y1, x2, y2))
113
  cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
@@ -115,24 +119,35 @@ def draw_bounding_boxes(image, refs, ouput_path):
115
  print(e)
116
  pass
117
  img_idx += 1
118
-
119
  try:
120
- if label_type == 'title':
121
  draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
122
- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
 
 
 
 
 
123
  else:
124
  draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
125
- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
 
 
 
 
 
126
  text_x = x1
127
  text_y = max(0, y1 - 15)
128
-
129
-
130
  text_bbox = draw.textbbox((0, 0), label_type, font=font)
131
  text_width = text_bbox[2] - text_bbox[0]
132
  text_height = text_bbox[3] - text_bbox[1]
133
- draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
134
- fill=(255, 255, 255, 30))
135
-
 
 
136
  draw.text((text_x, text_y), label_type, font=font, fill=color)
137
  except:
138
  pass
@@ -143,17 +158,13 @@ def draw_bounding_boxes(image, refs, ouput_path):
143
 
144
 
145
  def process_image_with_refs(image, ref_texts, output_path):
146
-
147
  result_image = draw_bounding_boxes(image, ref_texts, output_path)
148
-
149
- return result_image
150
-
151
-
152
 
 
153
 
154
 
155
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
156
- best_ratio_diff = float('inf')
157
  best_ratio = (1, 1)
158
  area = width * height
159
  for ratio in target_ratios:
@@ -169,20 +180,27 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
169
  return best_ratio
170
 
171
 
172
- def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
 
 
173
  orig_width, orig_height = image.size
174
  aspect_ratio = orig_width / orig_height
175
 
176
  # calculate the existing image aspect ratio
177
  target_ratios = set(
178
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
179
- i * j <= max_num and i * j >= min_num)
 
 
 
 
180
  # print(target_ratios)
181
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
182
 
183
  # find the closest aspect ratio to the target
184
  target_aspect_ratio = find_closest_aspect_ratio(
185
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
 
186
 
187
  # print(target_aspect_ratio)
188
  # calculate the target width and height
@@ -198,7 +216,7 @@ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnai
198
  (i % (target_width // image_size)) * image_size,
199
  (i // (target_width // image_size)) * image_size,
200
  ((i % (target_width // image_size)) + 1) * image_size,
201
- ((i // (target_width // image_size)) + 1) * image_size
202
  )
203
  # split the image
204
  split_img = resized_img.crop(box)
@@ -210,15 +228,14 @@ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnai
210
  return processed_images, target_aspect_ratio
211
 
212
 
213
-
214
  def normalize_transform(mean, std):
215
  if mean is None and std is None:
216
  transform = None
217
  elif mean is None and std is not None:
218
- mean = [0.] * len(std)
219
  transform = transforms.Normalize(mean=mean, std=std)
220
  elif mean is not None and std is None:
221
- std = [1.] * len(mean)
222
  transform = transforms.Normalize(mean=mean, std=std)
223
  else:
224
  transform = transforms.Normalize(mean=mean, std=std)
@@ -226,11 +243,10 @@ def normalize_transform(mean, std):
226
  return transform
227
 
228
 
229
-
230
  def format_messages(
231
- conversations: List[Dict[str, str]],
232
- sft_format: str = "deepseek",
233
- system_prompt: str = "",
234
  ):
235
  """
236
  Applies the SFT template to conversation.
@@ -264,6 +280,7 @@ def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
264
 
265
  return t
266
 
 
267
  def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
268
  """
269
 
@@ -294,7 +311,7 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
294
  # print(image_path)
295
  # print('----------------')
296
  # exit()
297
-
298
  # pil_img = Image.open(image_path)
299
  pil_img = load_image(image_path)
300
  pil_img = pil_img.convert("RGB")
@@ -304,7 +321,6 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
304
 
305
 
306
  class BaseTransform(ABC):
307
-
308
  def set_rng(self, *args, **kwargs):
309
  pass
310
 
@@ -318,32 +334,32 @@ class BaseTransform(ABC):
318
 
319
  class BasicImageTransform(BaseTransform):
320
  def __init__(
321
- self,
322
  mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
323
  std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
324
- normalize: bool = True
325
  ):
326
  self.mean = mean
327
  self.std = std
328
-
329
- transform_pipelines = [
330
- transforms.ToTensor()
331
- ]
332
 
333
  normalize = normalize_transform(mean, std) if normalize else nn.Identity()
334
  if normalize is not None:
335
  transform_pipelines.append(normalize)
336
 
337
  self.transform = transforms.Compose(transform_pipelines)
338
-
339
  def __call__(self, x):
340
  x = self.transform(x)
341
  return x
342
 
 
343
  class NoEOSTextStreamer(TextStreamer):
344
  def on_finalized_text(self, text: str, stream_end: bool = False):
345
-
346
- eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
 
347
  text = text.replace(eos_text, "\n")
348
  print(text, flush=True, end="")
349
 
@@ -351,6 +367,7 @@ class NoEOSTextStreamer(TextStreamer):
351
  class DeepseekOCRConfig(DeepseekV2Config):
352
  model_type = "DeepseekOCR"
353
 
 
354
  class DeepseekOCRModel(DeepseekV2Model):
355
  config_class = DeepseekOCRConfig
356
 
@@ -361,14 +378,13 @@ class DeepseekOCRModel(DeepseekV2Model):
361
  self.vision_model = build_clip_l()
362
  # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
363
  n_embed = 1280
364
- self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
 
 
365
  embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
366
  self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
367
  self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
368
 
369
-
370
-
371
-
372
  def forward(
373
  self,
374
  input_ids: torch.LongTensor = None,
@@ -384,28 +400,23 @@ class DeepseekOCRModel(DeepseekV2Model):
384
  images_spatial_crop: Optional[torch.FloatTensor] = None,
385
  return_dict: Optional[bool] = None,
386
  ) -> Union[Tuple, BaseModelOutputWithPast]:
387
-
388
-
389
-
390
-
391
  if inputs_embeds is None:
392
  # inputs_embeds = self.embed_tokens(input_ids)
393
  inputs_embeds = self.get_input_embeddings()(input_ids)
394
 
395
-
396
-
397
- sam_model = getattr(self, 'sam_model', None)
398
  # sam_model = self.sam_model
399
- vision_model = getattr(self, 'vision_model', None)
400
-
401
-
402
-
403
- if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:
404
 
 
 
 
 
 
405
  idx = 0
406
-
407
  # sam_model = torch.jit.script(sam_model)
408
-
409
  # start_time = time.time()
410
  for image, crop_shape in zip(images, images_spatial_crop):
411
  images_in_this_batch = []
@@ -414,53 +425,86 @@ class DeepseekOCRModel(DeepseekV2Model):
414
  image_ori = image[1]
415
 
416
  with torch.no_grad():
417
- # with torch.inference_mode():
418
-
419
  if torch.sum(patches).item() != 0:
420
  # P, C, H, W = patches.shape
421
  crop_flag = 1
422
  local_features_1 = sam_model(patches)
423
 
424
- local_features_2 = vision_model(patches, local_features_1)
425
  # vit_time = time.time()
426
- local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
427
  local_features = self.projector(local_features)
428
 
429
-
430
  global_features_1 = sam_model(image_ori)
431
- global_features_2 = vision_model(image_ori, global_features_1)
432
- global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
433
  global_features = self.projector(global_features)
434
 
435
- print('=====================')
436
- print('BASE: ', global_features.shape)
437
- print('PATCHES: ', local_features.shape)
438
- print('=====================')
439
 
440
  _, hw, n_dim = global_features.shape
441
- h = w = int(hw ** 0.5)
442
 
443
  _2, hw2, n_dim2 = local_features.shape
444
- h2 = w2 = int(hw2 ** 0.5)
445
 
446
  width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
447
 
448
  global_features = global_features.view(h, w, n_dim)
449
 
450
  global_features = torch.cat(
451
- [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
 
 
 
 
452
  )
453
 
454
  global_features = global_features.view(-1, n_dim)
455
 
456
-
457
- local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
 
 
 
 
 
458
  local_features = torch.cat(
459
- [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
 
 
 
 
 
 
460
  )
461
  local_features = local_features.view(-1, n_dim2)
462
 
463
- global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
 
 
 
 
 
 
 
464
 
465
  # end_time = time.time()
466
 
@@ -469,32 +513,42 @@ class DeepseekOCRModel(DeepseekV2Model):
469
  # print('all: ', end_time - start_time)
470
 
471
  # exit()
472
-
473
  else:
474
  global_features_1 = sam_model(image_ori)
475
- global_features_2 = vision_model(image_ori, global_features_1)
476
- global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
477
  global_features = self.projector(global_features)
478
- print('=====================')
479
- print('BASE: ', global_features.shape)
480
- print('NO PATCHES')
481
- print('=====================')
482
  _, hw, n_dim = global_features.shape
483
- h = w = int(hw ** 0.5)
484
-
485
 
486
  global_features = global_features.view(h, w, n_dim)
487
 
488
  global_features = torch.cat(
489
- [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
 
 
 
 
490
  )
491
 
492
  global_features = global_features.view(-1, n_dim)
493
 
494
- global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
 
 
495
 
496
  images_in_this_batch.append(global_local_features)
497
-
498
 
499
  # print(inputs_embeds.shape)
500
 
@@ -502,21 +556,27 @@ class DeepseekOCRModel(DeepseekV2Model):
502
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
503
  # exit()
504
 
505
- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
 
 
 
506
 
507
  idx += 1
508
-
509
 
510
  return super(DeepseekOCRModel, self).forward(
511
- input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
512
- inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
513
- output_attentions=output_attentions, output_hidden_states=output_hidden_states,
514
- return_dict=return_dict
 
 
 
 
 
515
  )
516
-
517
 
518
- class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
519
 
 
520
  config_class = DeepseekOCRConfig
521
  # supports_gradient_checkpointing = True
522
 
@@ -536,7 +596,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
536
  def get_model(self):
537
  return self.model
538
 
539
-
540
  def forward(
541
  self,
542
  input_ids: torch.LongTensor = None,
@@ -552,17 +611,22 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
552
  images_seq_mask: Optional[torch.FloatTensor] = None,
553
  images_spatial_crop: Optional[torch.FloatTensor] = None,
554
  return_dict: Optional[bool] = None,
555
-
556
  ) -> Union[Tuple, CausalLMOutputWithPast]:
557
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
558
  output_hidden_states = (
559
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
560
  )
561
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
562
-
563
-
564
 
565
- outputs = self.model(
566
  input_ids=input_ids,
567
  past_key_values=past_key_values,
568
  attention_mask=attention_mask,
@@ -572,14 +636,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
572
  output_attentions=output_attentions,
573
  output_hidden_states=output_hidden_states,
574
  images=images,
575
- images_seq_mask = images_seq_mask,
576
- images_spatial_crop = images_spatial_crop,
577
- return_dict=return_dict
578
-
579
  )
580
 
581
-
582
-
583
  # print(transformer_outputs)
584
 
585
  hidden_states = outputs[0]
@@ -613,9 +674,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
613
  attentions=outputs.attentions,
614
  )
615
 
616
-
617
  def prepare_inputs_for_generation(
618
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
619
  ):
620
  # Omit tokens covered by past_key_values
621
  past_length = 0
@@ -632,7 +697,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
632
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
633
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
634
  # input)
635
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
 
 
 
636
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
637
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
638
  # input_ids based on the past_length.
@@ -668,7 +736,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
668
 
669
  # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
670
  # same goes for position ids. Could also help with continued generation.
671
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
 
 
 
672
 
673
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
674
  if inputs_embeds is not None and past_key_values is None:
@@ -688,45 +760,55 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
688
  }
689
  )
690
  return model_inputs
691
-
692
 
693
  def disable_torch_init(self):
694
  """
695
  Disable the redundant torch default initialization to accelerate model creation.
696
  """
697
  import torch
 
698
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
699
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
700
 
701
-
702
-
703
- def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
 
 
 
 
 
 
 
 
 
 
704
  self.disable_torch_init()
705
 
706
  os.makedirs(output_path, exist_ok=True)
707
- os.makedirs(f'{output_path}/images', exist_ok=True)
708
 
709
  if prompt and image_file:
710
  conversation = [
711
  {
712
  "role": "<|User|>",
713
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
714
- "content": f'{prompt}',
715
  # "content": "君不见黄河之水天上来的下一句是什么?",
716
  # "content": "<image>\nFree OCR. ",
717
  # "content": "<image>\nParse the figure. ",
718
  # "content": "<image>\nExtract the text in the image. ",
719
- "images": [f'{image_file}'],
720
  },
721
  {"role": "<|Assistant|>", "content": ""},
722
  ]
723
-
724
  elif prompt:
725
  conversation = [
726
  {
727
  "role": "<|User|>",
728
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
729
- "content": f'{prompt}',
730
  # "content": "君不见黄河之水天上来的下一句是什么?",
731
  # "content": "<image>\nFree OCR. ",
732
  # "content": "<image>\nParse the figure. ",
@@ -736,9 +818,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
736
  {"role": "<|Assistant|>", "content": ""},
737
  ]
738
  else:
739
- assert False, f'prompt is none!'
740
-
741
- prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
 
 
742
 
743
  patch_size = 16
744
  downsample_ratio = 4
@@ -749,15 +833,16 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
749
 
750
  image_draw = images[0].copy()
751
 
752
- w,h = image_draw.size
753
  # print(w, h)
754
  ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
755
-
756
 
757
- image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
 
 
758
  images_seq_mask = []
759
 
760
- image_token = '<image>'
761
  image_token_id = 128815
762
  text_splits = prompt.split(image_token)
763
 
@@ -765,13 +850,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
765
  tokenized_str = []
766
  images_spatial_crop = []
767
  for text_sep, image in zip(text_splits, images):
768
-
769
  tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
770
  tokenized_str += tokenized_sep
771
  images_seq_mask += [False] * len(tokenized_sep)
772
 
773
  if crop_mode:
774
-
775
  if image.size[0] <= 640 and image.size[1] <= 640:
776
  crop_ratio = [1, 1]
777
 
@@ -782,23 +865,22 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
782
  else:
783
  # best_width, best_height = self.image_size, self.image_size
784
  crop_ratio = [1, 1]
785
-
786
  """process the global view"""
787
  # image = image.resize((base_size, base_size))
788
- global_view = ImageOps.pad(image, (base_size, base_size),
789
- color=tuple(int(x * 255) for x in image_transform.mean))
790
-
 
 
 
791
  if base_size == 1024:
792
  valid_img_tokens += int(256 * ratio)
793
  elif base_size == 1280:
794
  valid_img_tokens += int(400 * ratio)
795
  # elif base_size == 640:
796
  # valid_img_tokens += int(100 * ratio)
797
-
798
-
799
 
800
-
801
-
802
  images_list.append(image_transform(global_view).to(torch.bfloat16))
803
 
804
  # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
@@ -806,31 +888,34 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
806
  width_crop_num, height_crop_num = crop_ratio
807
 
808
  images_spatial_crop.append([width_crop_num, height_crop_num])
809
-
810
-
811
  if width_crop_num > 1 or height_crop_num > 1:
812
  """process the local views"""
813
-
814
  for i in range(len(images_crop_raw)):
815
- images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
816
-
 
 
817
  if image_size == 640:
818
  valid_img_tokens += len(images_crop_list) * 100
819
 
820
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
821
- num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
822
-
823
-
824
 
825
  """add image tokens"""
826
 
827
-
828
-
829
- tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
830
  tokenized_image += [image_token_id]
831
  if width_crop_num > 1 or height_crop_num > 1:
832
- tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
833
- num_queries * height_crop_num)
 
 
834
  tokenized_str += tokenized_image
835
  images_seq_mask += [True] * len(tokenized_image)
836
  # num_image_tokens.append(len(tokenized_image))
@@ -841,11 +926,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
841
 
842
  """process the global view"""
843
  if image_size <= 640:
844
- print('directly resize')
845
  image = image.resize((image_size, image_size))
846
  # else:
847
- global_view = ImageOps.pad(image, (image_size, image_size),
848
- color=tuple(int(x * 255) for x in image_transform.mean))
 
 
 
849
  images_list.append(image_transform(global_view).to(torch.bfloat16))
850
 
851
  if base_size == 1024:
@@ -861,18 +949,18 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
861
 
862
  images_spatial_crop.append([width_crop_num, height_crop_num])
863
 
864
-
865
  """add image tokens"""
866
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
867
 
868
- tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
 
 
869
  tokenized_image += [image_token_id]
870
  # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
871
  # num_queries * height_crop_num)
872
  tokenized_str += tokenized_image
873
  images_seq_mask += [True] * len(tokenized_image)
874
  # num_image_tokens.append(len(tokenized_image))
875
-
876
 
877
  """process the last text split"""
878
  tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
@@ -881,19 +969,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
881
 
882
  """add the bos tokens"""
883
  bos_id = 0
884
- tokenized_str = [bos_id] + tokenized_str
885
  images_seq_mask = [False] + images_seq_mask
886
 
887
-
888
-
889
  input_ids = torch.LongTensor(tokenized_str)
890
 
891
-
892
-
893
-
894
  images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
895
 
896
-
897
  if len(images_list) == 0:
898
  images_ori = torch.zeros((1, 3, image_size, image_size))
899
  images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
@@ -907,131 +989,157 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
907
  else:
908
  images_crop = torch.zeros((1, 3, base_size, base_size))
909
 
910
-
911
-
912
  if not eval_mode:
913
- streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
914
- with torch.autocast("cuda", dtype=torch.bfloat16):
 
 
915
  with torch.no_grad():
916
  output_ids = self.generate(
917
- input_ids.unsqueeze(0).cuda(),
918
- images=[(images_crop.cuda(), images_ori.cuda())],
919
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
920
- images_spatial_crop = images_spatial_crop,
 
 
921
  # do_sample=False,
922
  # num_beams = 1,
923
  temperature=0.0,
924
  eos_token_id=tokenizer.eos_token_id,
925
  streamer=streamer,
926
  max_new_tokens=8192,
927
- no_repeat_ngram_size = 20,
928
- use_cache = True
929
- )
930
 
931
  else:
932
- with torch.autocast("cuda", dtype=torch.bfloat16):
933
  with torch.no_grad():
934
  output_ids = self.generate(
935
- input_ids.unsqueeze(0).cuda(),
936
- images=[(images_crop.cuda(), images_ori.cuda())],
937
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
938
- images_spatial_crop = images_spatial_crop,
 
 
939
  # do_sample=False,
940
  # num_beams = 1,
941
  temperature=0.0,
942
  eos_token_id=tokenizer.eos_token_id,
943
  max_new_tokens=8192,
944
- no_repeat_ngram_size = 35,
945
- use_cache = True
946
- )
947
-
948
-
949
- if '<image>' in conversation[0]['content'] and eval_mode:
950
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
951
- stop_str = '<|end▁of▁sentence|>'
952
- if outputs.endswith(stop_str):
953
- outputs = outputs[:-len(stop_str)]
954
- # re_match
955
- outputs = outputs.strip()
956
-
957
- return outputs
958
-
959
- if '<image>' in conversation[0]['content'] and test_compress:
960
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
961
- pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
962
- print('='*50)
963
- print('image size: ', (w, h))
964
- print('valid image tokens: ', int(valid_img_tokens))
965
- print('output texts tokens (valid): ', pure_texts_outputs_token_length)
966
- print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
967
- print('='*50)
968
-
969
-
970
- if '<image>' in conversation[0]['content'] and save_results:
971
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
972
- stop_str = '<|end▁of▁sentence|>'
973
-
974
- print('='*15 + 'save results:' + '='*15)
975
-
 
 
 
 
 
 
 
 
 
976
  # # # # conv.messages[-1][-1] = outputs
977
  if outputs.endswith(stop_str):
978
- outputs = outputs[:-len(stop_str)]
979
  outputs = outputs.strip()
980
 
981
  matches_ref, matches_images, mathes_other = re_match(outputs)
982
  # print(matches_ref)
983
  result = process_image_with_refs(image_draw, matches_ref, output_path)
984
 
985
-
986
  for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
987
- outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
988
-
989
- for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
990
- outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
991
 
 
 
 
 
 
 
992
 
993
  # if 'structural formula' in conversation[0]['content']:
994
  # outputs = '<smiles>' + outputs + '</smiles>'
995
- with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
996
  afile.write(outputs)
997
 
998
- if 'line_type' in outputs:
999
  import matplotlib.pyplot as plt
1000
- lines = eval(outputs)['Line']['line']
1001
 
1002
- line_type = eval(outputs)['Line']['line_type']
 
 
1003
  # print(lines)
1004
 
1005
- endpoints = eval(outputs)['Line']['line_endpoint']
1006
 
1007
- fig, ax = plt.subplots(figsize=(3,3), dpi=200)
1008
  ax.set_xlim(-15, 15)
1009
  ax.set_ylim(-15, 15)
1010
 
1011
  for idx, line in enumerate(lines):
1012
  try:
1013
- p0 = eval(line.split(' -- ')[0])
1014
- p1 = eval(line.split(' -- ')[-1])
1015
 
1016
- if line_type[idx] == '--':
1017
- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
 
 
1018
  else:
1019
- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
 
 
1020
 
1021
- ax.scatter(p0[0], p0[1], s=5, color = 'k')
1022
- ax.scatter(p1[0], p1[1], s=5, color = 'k')
1023
  except:
1024
  pass
1025
 
1026
  for endpoint in endpoints:
1027
-
1028
- label = endpoint.split(': ')[0]
1029
- (x, y) = eval(endpoint.split(': ')[1])
1030
- ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
1031
- fontsize=5, fontweight='light')
1032
-
1033
-
1034
- plt.savefig(f'{output_path}/geo.jpg')
 
 
 
 
1035
  plt.close()
1036
 
1037
  result.save(f"{output_path}/result_with_boxes.jpg")
 
1
  from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
2
  from .configuration_deepseek_v2 import DeepseekV2Config
3
+ from transformers.modeling_outputs import (
4
+ BaseModelOutputWithPast,
5
+ CausalLMOutputWithPast,
6
+ )
7
  from typing import List, Optional, Tuple, Union
8
  from transformers.cache_utils import Cache
9
  import requests
 
28
 
29
 
30
  def load_image(image_path):
 
31
  try:
32
  image = Image.open(image_path)
33
+
34
  corrected_image = ImageOps.exif_transpose(image)
35
+
36
  return corrected_image
37
+
38
  except Exception as e:
39
  print(f"error: {e}")
40
  try:
 
44
 
45
 
46
  def re_match(text):
47
+ pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
48
  matches = re.findall(pattern, text, re.DOTALL)
49
 
50
  # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
 
53
  mathes_image = []
54
  mathes_other = []
55
  for a_match in matches:
56
+ if "<|ref|>image<|/ref|>" in a_match[0]:
57
  mathes_image.append(a_match[0])
58
  else:
59
  mathes_other.append(a_match[0])
 
61
 
62
 
63
  def extract_coordinates_and_label(ref_text, image_width, image_height):
 
64
  try:
65
  label_type = ref_text[1]
66
  cor_list = eval(ref_text[2])
 
72
 
73
 
74
  def draw_bounding_boxes(image, refs, ouput_path):
 
75
  image_width, image_height = image.size
76
+
77
  img_draw = image.copy()
78
  draw = ImageDraw.Draw(img_draw)
79
 
80
+ overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
81
  draw2 = ImageDraw.Draw(overlay)
82
+
83
  # try:
84
  # except IOError:
85
  # try:
86
+ # font = ImageFont.truetype("DejaVuSans.ttf", 20)
87
  # except IOError:
88
  font = ImageFont.load_default()
89
 
90
  img_idx = 0
91
+
92
  for i, ref in enumerate(refs):
93
  try:
94
  result = extract_coordinates_and_label(ref, image_width, image_height)
95
  if result:
96
  label_type, points_list = result
 
 
97
 
98
+ color = (
99
+ np.random.randint(0, 200),
100
+ np.random.randint(0, 200),
101
+ np.random.randint(0, 255),
102
+ )
103
+
104
+ color_a = color + (20,)
105
  for points in points_list:
106
  x1, y1, x2, y2 = points
107
 
 
111
  x2 = int(x2 / 999 * image_width)
112
  y2 = int(y2 / 999 * image_height)
113
 
114
+ if label_type == "image":
115
  try:
116
  cropped = image.crop((x1, y1, x2, y2))
117
  cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
 
119
  print(e)
120
  pass
121
  img_idx += 1
122
+
123
  try:
124
+ if label_type == "title":
125
  draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
126
+ draw2.rectangle(
127
+ [x1, y1, x2, y2],
128
+ fill=color_a,
129
+ outline=(0, 0, 0, 0),
130
+ width=1,
131
+ )
132
  else:
133
  draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
134
+ draw2.rectangle(
135
+ [x1, y1, x2, y2],
136
+ fill=color_a,
137
+ outline=(0, 0, 0, 0),
138
+ width=1,
139
+ )
140
  text_x = x1
141
  text_y = max(0, y1 - 15)
142
+
 
143
  text_bbox = draw.textbbox((0, 0), label_type, font=font)
144
  text_width = text_bbox[2] - text_bbox[0]
145
  text_height = text_bbox[3] - text_bbox[1]
146
+ draw.rectangle(
147
+ [text_x, text_y, text_x + text_width, text_y + text_height],
148
+ fill=(255, 255, 255, 30),
149
+ )
150
+
151
  draw.text((text_x, text_y), label_type, font=font, fill=color)
152
  except:
153
  pass
 
158
 
159
 
160
  def process_image_with_refs(image, ref_texts, output_path):
 
161
  result_image = draw_bounding_boxes(image, ref_texts, output_path)
 
 
 
 
162
 
163
+ return result_image
164
 
165
 
166
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
167
+ best_ratio_diff = float("inf")
168
  best_ratio = (1, 1)
169
  area = width * height
170
  for ratio in target_ratios:
 
180
  return best_ratio
181
 
182
 
183
+ def dynamic_preprocess(
184
+ image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
185
+ ):
186
  orig_width, orig_height = image.size
187
  aspect_ratio = orig_width / orig_height
188
 
189
  # calculate the existing image aspect ratio
190
  target_ratios = set(
191
+ (i, j)
192
+ for n in range(min_num, max_num + 1)
193
+ for i in range(1, n + 1)
194
+ for j in range(1, n + 1)
195
+ if i * j <= max_num and i * j >= min_num
196
+ )
197
  # print(target_ratios)
198
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
199
 
200
  # find the closest aspect ratio to the target
201
  target_aspect_ratio = find_closest_aspect_ratio(
202
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
203
+ )
204
 
205
  # print(target_aspect_ratio)
206
  # calculate the target width and height
 
216
  (i % (target_width // image_size)) * image_size,
217
  (i // (target_width // image_size)) * image_size,
218
  ((i % (target_width // image_size)) + 1) * image_size,
219
+ ((i // (target_width // image_size)) + 1) * image_size,
220
  )
221
  # split the image
222
  split_img = resized_img.crop(box)
 
228
  return processed_images, target_aspect_ratio
229
 
230
 
 
231
  def normalize_transform(mean, std):
232
  if mean is None and std is None:
233
  transform = None
234
  elif mean is None and std is not None:
235
+ mean = [0.0] * len(std)
236
  transform = transforms.Normalize(mean=mean, std=std)
237
  elif mean is not None and std is None:
238
+ std = [1.0] * len(mean)
239
  transform = transforms.Normalize(mean=mean, std=std)
240
  else:
241
  transform = transforms.Normalize(mean=mean, std=std)
 
243
  return transform
244
 
245
 
 
246
  def format_messages(
247
+ conversations: List[Dict[str, str]],
248
+ sft_format: str = "deepseek",
249
+ system_prompt: str = "",
250
  ):
251
  """
252
  Applies the SFT template to conversation.
 
280
 
281
  return t
282
 
283
+
284
  def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
285
  """
286
 
 
311
  # print(image_path)
312
  # print('----------------')
313
  # exit()
314
+
315
  # pil_img = Image.open(image_path)
316
  pil_img = load_image(image_path)
317
  pil_img = pil_img.convert("RGB")
 
321
 
322
 
323
  class BaseTransform(ABC):
 
324
  def set_rng(self, *args, **kwargs):
325
  pass
326
 
 
334
 
335
  class BasicImageTransform(BaseTransform):
336
  def __init__(
337
+ self,
338
  mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
339
  std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
340
+ normalize: bool = True,
341
  ):
342
  self.mean = mean
343
  self.std = std
344
+
345
+ transform_pipelines = [transforms.ToTensor()]
 
 
346
 
347
  normalize = normalize_transform(mean, std) if normalize else nn.Identity()
348
  if normalize is not None:
349
  transform_pipelines.append(normalize)
350
 
351
  self.transform = transforms.Compose(transform_pipelines)
352
+
353
  def __call__(self, x):
354
  x = self.transform(x)
355
  return x
356
 
357
+
358
  class NoEOSTextStreamer(TextStreamer):
359
  def on_finalized_text(self, text: str, stream_end: bool = False):
360
+ eos_text = self.tokenizer.decode(
361
+ [self.tokenizer.eos_token_id], skip_special_tokens=False
362
+ )
363
  text = text.replace(eos_text, "\n")
364
  print(text, flush=True, end="")
365
 
 
367
  class DeepseekOCRConfig(DeepseekV2Config):
368
  model_type = "DeepseekOCR"
369
 
370
+
371
  class DeepseekOCRModel(DeepseekV2Model):
372
  config_class = DeepseekOCRConfig
373
 
 
378
  self.vision_model = build_clip_l()
379
  # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
380
  n_embed = 1280
381
+ self.projector = MlpProjector(
382
+ Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)
383
+ )
384
  embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
385
  self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
386
  self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
387
 
 
 
 
388
  def forward(
389
  self,
390
  input_ids: torch.LongTensor = None,
 
400
  images_spatial_crop: Optional[torch.FloatTensor] = None,
401
  return_dict: Optional[bool] = None,
402
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
 
 
 
403
  if inputs_embeds is None:
404
  # inputs_embeds = self.embed_tokens(input_ids)
405
  inputs_embeds = self.get_input_embeddings()(input_ids)
406
 
407
+ sam_model = getattr(self, "sam_model", None)
 
 
408
  # sam_model = self.sam_model
409
+ vision_model = getattr(self, "vision_model", None)
 
 
 
 
410
 
411
+ if (
412
+ sam_model is not None
413
+ and (input_ids.shape[1] != 1 or self.training)
414
+ and torch.sum(images[0][1]).item() != 0
415
+ ):
416
  idx = 0
417
+
418
  # sam_model = torch.jit.script(sam_model)
419
+
420
  # start_time = time.time()
421
  for image, crop_shape in zip(images, images_spatial_crop):
422
  images_in_this_batch = []
 
425
  image_ori = image[1]
426
 
427
  with torch.no_grad():
428
+ # with torch.inference_mode():
429
+
430
  if torch.sum(patches).item() != 0:
431
  # P, C, H, W = patches.shape
432
  crop_flag = 1
433
  local_features_1 = sam_model(patches)
434
 
435
+ local_features_2 = vision_model(patches, local_features_1)
436
  # vit_time = time.time()
437
+ local_features = torch.cat(
438
+ (
439
+ local_features_2[:, 1:],
440
+ local_features_1.flatten(2).permute(0, 2, 1),
441
+ ),
442
+ dim=-1,
443
+ )
444
  local_features = self.projector(local_features)
445
 
 
446
  global_features_1 = sam_model(image_ori)
447
+ global_features_2 = vision_model(image_ori, global_features_1)
448
+ global_features = torch.cat(
449
+ (
450
+ global_features_2[:, 1:],
451
+ global_features_1.flatten(2).permute(0, 2, 1),
452
+ ),
453
+ dim=-1,
454
+ )
455
  global_features = self.projector(global_features)
456
 
457
+ print("=====================")
458
+ print("BASE: ", global_features.shape)
459
+ print("PATCHES: ", local_features.shape)
460
+ print("=====================")
461
 
462
  _, hw, n_dim = global_features.shape
463
+ h = w = int(hw**0.5)
464
 
465
  _2, hw2, n_dim2 = local_features.shape
466
+ h2 = w2 = int(hw2**0.5)
467
 
468
  width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
469
 
470
  global_features = global_features.view(h, w, n_dim)
471
 
472
  global_features = torch.cat(
473
+ [
474
+ global_features,
475
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
476
+ ],
477
+ dim=1,
478
  )
479
 
480
  global_features = global_features.view(-1, n_dim)
481
 
482
+ local_features = (
483
+ local_features.view(
484
+ height_crop_num, width_crop_num, h2, w2, n_dim2
485
+ )
486
+ .permute(0, 2, 1, 3, 4)
487
+ .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
488
+ )
489
  local_features = torch.cat(
490
+ [
491
+ local_features,
492
+ self.image_newline[None, None, :].expand(
493
+ height_crop_num * h2, 1, n_dim2
494
+ ),
495
+ ],
496
+ dim=1,
497
  )
498
  local_features = local_features.view(-1, n_dim2)
499
 
500
+ global_local_features = torch.cat(
501
+ [
502
+ local_features,
503
+ global_features,
504
+ self.view_seperator[None, :],
505
+ ],
506
+ dim=0,
507
+ )
508
 
509
  # end_time = time.time()
510
 
 
513
  # print('all: ', end_time - start_time)
514
 
515
  # exit()
516
+
517
  else:
518
  global_features_1 = sam_model(image_ori)
519
+ global_features_2 = vision_model(image_ori, global_features_1)
520
+ global_features = torch.cat(
521
+ (
522
+ global_features_2[:, 1:],
523
+ global_features_1.flatten(2).permute(0, 2, 1),
524
+ ),
525
+ dim=-1,
526
+ )
527
  global_features = self.projector(global_features)
528
+ print("=====================")
529
+ print("BASE: ", global_features.shape)
530
+ print("NO PATCHES")
531
+ print("=====================")
532
  _, hw, n_dim = global_features.shape
533
+ h = w = int(hw**0.5)
 
534
 
535
  global_features = global_features.view(h, w, n_dim)
536
 
537
  global_features = torch.cat(
538
+ [
539
+ global_features,
540
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
541
+ ],
542
+ dim=1,
543
  )
544
 
545
  global_features = global_features.view(-1, n_dim)
546
 
547
+ global_local_features = torch.cat(
548
+ [global_features, self.view_seperator[None, :]], dim=0
549
+ )
550
 
551
  images_in_this_batch.append(global_local_features)
 
552
 
553
  # print(inputs_embeds.shape)
554
 
 
556
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
557
  # exit()
558
 
559
+ inputs_embeds[idx].masked_scatter_(
560
+ images_seq_mask[idx].unsqueeze(-1).to(self.device),
561
+ images_in_this_batch,
562
+ )
563
 
564
  idx += 1
 
565
 
566
  return super(DeepseekOCRModel, self).forward(
567
+ input_ids=None,
568
+ attention_mask=attention_mask,
569
+ past_key_values=past_key_values,
570
+ inputs_embeds=inputs_embeds,
571
+ use_cache=use_cache,
572
+ position_ids=position_ids,
573
+ output_attentions=output_attentions,
574
+ output_hidden_states=output_hidden_states,
575
+ return_dict=return_dict,
576
  )
 
577
 
 
578
 
579
+ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
580
  config_class = DeepseekOCRConfig
581
  # supports_gradient_checkpointing = True
582
 
 
596
  def get_model(self):
597
  return self.model
598
 
 
599
  def forward(
600
  self,
601
  input_ids: torch.LongTensor = None,
 
611
  images_seq_mask: Optional[torch.FloatTensor] = None,
612
  images_spatial_crop: Optional[torch.FloatTensor] = None,
613
  return_dict: Optional[bool] = None,
 
614
  ) -> Union[Tuple, CausalLMOutputWithPast]:
615
+ output_attentions = (
616
+ output_attentions
617
+ if output_attentions is not None
618
+ else self.config.output_attentions
619
+ )
620
  output_hidden_states = (
621
+ output_hidden_states
622
+ if output_hidden_states is not None
623
+ else self.config.output_hidden_states
624
+ )
625
+ return_dict = (
626
+ return_dict if return_dict is not None else self.config.use_return_dict
627
  )
 
 
 
628
 
629
+ outputs = self.model(
630
  input_ids=input_ids,
631
  past_key_values=past_key_values,
632
  attention_mask=attention_mask,
 
636
  output_attentions=output_attentions,
637
  output_hidden_states=output_hidden_states,
638
  images=images,
639
+ images_seq_mask=images_seq_mask,
640
+ images_spatial_crop=images_spatial_crop,
641
+ return_dict=return_dict,
 
642
  )
643
 
 
 
644
  # print(transformer_outputs)
645
 
646
  hidden_states = outputs[0]
 
674
  attentions=outputs.attentions,
675
  )
676
 
 
677
  def prepare_inputs_for_generation(
678
+ self,
679
+ input_ids,
680
+ past_key_values=None,
681
+ attention_mask=None,
682
+ inputs_embeds=None,
683
+ **kwargs,
684
  ):
685
  # Omit tokens covered by past_key_values
686
  past_length = 0
 
697
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
698
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
699
  # input)
700
+ if (
701
+ attention_mask is not None
702
+ and attention_mask.shape[1] > input_ids.shape[1]
703
+ ):
704
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
705
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
706
  # input_ids based on the past_length.
 
736
 
737
  # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
738
  # same goes for position ids. Could also help with continued generation.
739
+ cache_position = torch.arange(
740
+ past_length,
741
+ past_length + position_ids.shape[-1],
742
+ device=position_ids.device,
743
+ )
744
 
745
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
746
  if inputs_embeds is not None and past_key_values is None:
 
760
  }
761
  )
762
  return model_inputs
 
763
 
764
  def disable_torch_init(self):
765
  """
766
  Disable the redundant torch default initialization to accelerate model creation.
767
  """
768
  import torch
769
+
770
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
771
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
772
 
773
+ def infer(
774
+ self,
775
+ tokenizer,
776
+ prompt="",
777
+ image_file="",
778
+ output_path="",
779
+ base_size=1024,
780
+ image_size=640,
781
+ crop_mode=True,
782
+ test_compress=False,
783
+ save_results=False,
784
+ eval_mode=False,
785
+ ):
786
  self.disable_torch_init()
787
 
788
  os.makedirs(output_path, exist_ok=True)
789
+ os.makedirs(f"{output_path}/images", exist_ok=True)
790
 
791
  if prompt and image_file:
792
  conversation = [
793
  {
794
  "role": "<|User|>",
795
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
796
+ "content": f"{prompt}",
797
  # "content": "君不见黄河之水天上来的下一句是什么?",
798
  # "content": "<image>\nFree OCR. ",
799
  # "content": "<image>\nParse the figure. ",
800
  # "content": "<image>\nExtract the text in the image. ",
801
+ "images": [f"{image_file}"],
802
  },
803
  {"role": "<|Assistant|>", "content": ""},
804
  ]
805
+
806
  elif prompt:
807
  conversation = [
808
  {
809
  "role": "<|User|>",
810
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
811
+ "content": f"{prompt}",
812
  # "content": "君不见黄河之水天上来的下一句是什么?",
813
  # "content": "<image>\nFree OCR. ",
814
  # "content": "<image>\nParse the figure. ",
 
818
  {"role": "<|Assistant|>", "content": ""},
819
  ]
820
  else:
821
+ assert False, f"prompt is none!"
822
+
823
+ prompt = format_messages(
824
+ conversations=conversation, sft_format="plain", system_prompt=""
825
+ )
826
 
827
  patch_size = 16
828
  downsample_ratio = 4
 
833
 
834
  image_draw = images[0].copy()
835
 
836
+ w, h = image_draw.size
837
  # print(w, h)
838
  ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
 
839
 
840
+ image_transform = BasicImageTransform(
841
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True
842
+ )
843
  images_seq_mask = []
844
 
845
+ image_token = "<image>"
846
  image_token_id = 128815
847
  text_splits = prompt.split(image_token)
848
 
 
850
  tokenized_str = []
851
  images_spatial_crop = []
852
  for text_sep, image in zip(text_splits, images):
 
853
  tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
854
  tokenized_str += tokenized_sep
855
  images_seq_mask += [False] * len(tokenized_sep)
856
 
857
  if crop_mode:
 
858
  if image.size[0] <= 640 and image.size[1] <= 640:
859
  crop_ratio = [1, 1]
860
 
 
865
  else:
866
  # best_width, best_height = self.image_size, self.image_size
867
  crop_ratio = [1, 1]
868
+
869
  """process the global view"""
870
  # image = image.resize((base_size, base_size))
871
+ global_view = ImageOps.pad(
872
+ image,
873
+ (base_size, base_size),
874
+ color=tuple(int(x * 255) for x in image_transform.mean),
875
+ )
876
+
877
  if base_size == 1024:
878
  valid_img_tokens += int(256 * ratio)
879
  elif base_size == 1280:
880
  valid_img_tokens += int(400 * ratio)
881
  # elif base_size == 640:
882
  # valid_img_tokens += int(100 * ratio)
 
 
883
 
 
 
884
  images_list.append(image_transform(global_view).to(torch.bfloat16))
885
 
886
  # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
 
888
  width_crop_num, height_crop_num = crop_ratio
889
 
890
  images_spatial_crop.append([width_crop_num, height_crop_num])
891
+
 
892
  if width_crop_num > 1 or height_crop_num > 1:
893
  """process the local views"""
894
+
895
  for i in range(len(images_crop_raw)):
896
+ images_crop_list.append(
897
+ image_transform(images_crop_raw[i]).to(torch.bfloat16)
898
+ )
899
+
900
  if image_size == 640:
901
  valid_img_tokens += len(images_crop_list) * 100
902
 
903
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
904
+ num_queries_base = math.ceil(
905
+ (base_size // patch_size) / downsample_ratio
906
+ )
907
 
908
  """add image tokens"""
909
 
910
+ tokenized_image = (
911
+ [image_token_id] * num_queries_base + [image_token_id]
912
+ ) * num_queries_base
913
  tokenized_image += [image_token_id]
914
  if width_crop_num > 1 or height_crop_num > 1:
915
+ tokenized_image += (
916
+ [image_token_id] * (num_queries * width_crop_num)
917
+ + [image_token_id]
918
+ ) * (num_queries * height_crop_num)
919
  tokenized_str += tokenized_image
920
  images_seq_mask += [True] * len(tokenized_image)
921
  # num_image_tokens.append(len(tokenized_image))
 
926
 
927
  """process the global view"""
928
  if image_size <= 640:
929
+ print("directly resize")
930
  image = image.resize((image_size, image_size))
931
  # else:
932
+ global_view = ImageOps.pad(
933
+ image,
934
+ (image_size, image_size),
935
+ color=tuple(int(x * 255) for x in image_transform.mean),
936
+ )
937
  images_list.append(image_transform(global_view).to(torch.bfloat16))
938
 
939
  if base_size == 1024:
 
949
 
950
  images_spatial_crop.append([width_crop_num, height_crop_num])
951
 
 
952
  """add image tokens"""
953
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
954
 
955
+ tokenized_image = (
956
+ [image_token_id] * num_queries + [image_token_id]
957
+ ) * num_queries
958
  tokenized_image += [image_token_id]
959
  # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
960
  # num_queries * height_crop_num)
961
  tokenized_str += tokenized_image
962
  images_seq_mask += [True] * len(tokenized_image)
963
  # num_image_tokens.append(len(tokenized_image))
 
964
 
965
  """process the last text split"""
966
  tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
 
969
 
970
  """add the bos tokens"""
971
  bos_id = 0
972
+ tokenized_str = [bos_id] + tokenized_str
973
  images_seq_mask = [False] + images_seq_mask
974
 
 
 
975
  input_ids = torch.LongTensor(tokenized_str)
976
 
 
 
 
977
  images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
978
 
 
979
  if len(images_list) == 0:
980
  images_ori = torch.zeros((1, 3, image_size, image_size))
981
  images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
 
989
  else:
990
  images_crop = torch.zeros((1, 3, base_size, base_size))
991
 
 
 
992
  if not eval_mode:
993
+ streamer = NoEOSTextStreamer(
994
+ tokenizer, skip_prompt=True, skip_special_tokens=False
995
+ )
996
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
997
  with torch.no_grad():
998
  output_ids = self.generate(
999
+ input_ids.unsqueeze(0).to(self.device),
1000
+ images=[
1001
+ (images_crop.to(self.device), images_ori.to(self.device))
1002
+ ],
1003
+ images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
1004
+ images_spatial_crop=images_spatial_crop,
1005
  # do_sample=False,
1006
  # num_beams = 1,
1007
  temperature=0.0,
1008
  eos_token_id=tokenizer.eos_token_id,
1009
  streamer=streamer,
1010
  max_new_tokens=8192,
1011
+ no_repeat_ngram_size=20,
1012
+ use_cache=True,
1013
+ )
1014
 
1015
  else:
1016
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
1017
  with torch.no_grad():
1018
  output_ids = self.generate(
1019
+ input_ids.unsqueeze(0).to(self.device),
1020
+ images=[
1021
+ (images_crop.to(self.device), images_ori.to(self.device))
1022
+ ],
1023
+ images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
1024
+ images_spatial_crop=images_spatial_crop,
1025
  # do_sample=False,
1026
  # num_beams = 1,
1027
  temperature=0.0,
1028
  eos_token_id=tokenizer.eos_token_id,
1029
  max_new_tokens=8192,
1030
+ no_repeat_ngram_size=35,
1031
+ use_cache=True,
1032
+ )
1033
+
1034
+ if "<image>" in conversation[0]["content"] and eval_mode:
1035
+ outputs = tokenizer.decode(
1036
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1037
+ )
1038
+ stop_str = "<|end▁of▁sentence|>"
1039
+ if outputs.endswith(stop_str):
1040
+ outputs = outputs[: -len(stop_str)]
1041
+ # re_match
1042
+ outputs = outputs.strip()
1043
+
1044
+ return outputs
1045
+
1046
+ if "<image>" in conversation[0]["content"] and test_compress:
1047
+ outputs = tokenizer.decode(
1048
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1049
+ )
1050
+ pure_texts_outputs_token_length = len(
1051
+ text_encode(tokenizer, outputs, bos=False, eos=False)
1052
+ )
1053
+ print("=" * 50)
1054
+ print("image size: ", (w, h))
1055
+ print("valid image tokens: ", int(valid_img_tokens))
1056
+ print("output texts tokens (valid): ", pure_texts_outputs_token_length)
1057
+ print(
1058
+ "compression ratio: ",
1059
+ round(pure_texts_outputs_token_length / valid_img_tokens, 2),
1060
+ )
1061
+ print("=" * 50)
1062
+
1063
+ if "<image>" in conversation[0]["content"] and save_results:
1064
+ outputs = tokenizer.decode(
1065
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1066
+ )
1067
+ stop_str = "<|end▁of▁sentence|>"
1068
+
1069
+ print("=" * 15 + "save results:" + "=" * 15)
1070
+
1071
  # # # # conv.messages[-1][-1] = outputs
1072
  if outputs.endswith(stop_str):
1073
+ outputs = outputs[: -len(stop_str)]
1074
  outputs = outputs.strip()
1075
 
1076
  matches_ref, matches_images, mathes_other = re_match(outputs)
1077
  # print(matches_ref)
1078
  result = process_image_with_refs(image_draw, matches_ref, output_path)
1079
 
 
1080
  for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
1081
+ outputs = outputs.replace(
1082
+ a_match_image, "![](images/" + str(idx) + ".jpg)\n"
1083
+ )
 
1084
 
1085
+ for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
1086
+ outputs = (
1087
+ outputs.replace(a_match_other, "")
1088
+ .replace("\\coloneqq", ":=")
1089
+ .replace("\\eqqcolon", "=:")
1090
+ )
1091
 
1092
  # if 'structural formula' in conversation[0]['content']:
1093
  # outputs = '<smiles>' + outputs + '</smiles>'
1094
+ with open(f"{output_path}/result.mmd", "w", encoding="utf-8") as afile:
1095
  afile.write(outputs)
1096
 
1097
+ if "line_type" in outputs:
1098
  import matplotlib.pyplot as plt
 
1099
 
1100
+ lines = eval(outputs)["Line"]["line"]
1101
+
1102
+ line_type = eval(outputs)["Line"]["line_type"]
1103
  # print(lines)
1104
 
1105
+ endpoints = eval(outputs)["Line"]["line_endpoint"]
1106
 
1107
+ fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
1108
  ax.set_xlim(-15, 15)
1109
  ax.set_ylim(-15, 15)
1110
 
1111
  for idx, line in enumerate(lines):
1112
  try:
1113
+ p0 = eval(line.split(" -- ")[0])
1114
+ p1 = eval(line.split(" -- ")[-1])
1115
 
1116
+ if line_type[idx] == "--":
1117
+ ax.plot(
1118
+ [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
1119
+ )
1120
  else:
1121
+ ax.plot(
1122
+ [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
1123
+ )
1124
 
1125
+ ax.scatter(p0[0], p0[1], s=5, color="k")
1126
+ ax.scatter(p1[0], p1[1], s=5, color="k")
1127
  except:
1128
  pass
1129
 
1130
  for endpoint in endpoints:
1131
+ label = endpoint.split(": ")[0]
1132
+ (x, y) = eval(endpoint.split(": ")[1])
1133
+ ax.annotate(
1134
+ label,
1135
+ (x, y),
1136
+ xytext=(1, 1),
1137
+ textcoords="offset points",
1138
+ fontsize=5,
1139
+ fontweight="light",
1140
+ )
1141
+
1142
+ plt.savefig(f"{output_path}/geo.jpg")
1143
  plt.close()
1144
 
1145
  result.save(f"{output_path}/result_with_boxes.jpg")
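
For completeness, a hypothetical usage sketch of the updated `infer` on a machine without CUDA; the model id, image path, and loading flags here are assumptions for illustration, not part of the diff:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Placeholders: adjust the model id and paths to your setup.
model_id = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Any device works now; infer() moves tensors to model.device instead of calling .cuda().
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.eval().to(device)

model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file="example.jpg",
    output_path="./output",
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
)
```

With save_results=True the recognized text is written to output/result.mmd, as in the save-results branch of the diff; only the device handling changes.
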