aamirshakir committed on
Commit 155d104 · verified · 1 Parent(s): 5951289

Make infer general so it runs on non-CUDA devices


infer hard-codes .cuda(), which prevents it from running on CPU or other devices. Instead, use self.device so it can run on CPU, MPS, etc.
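
A minimal sketch of the pattern this change applies, using a toy module. TinyModel, its device property, and infer_step are illustrative names only, not part of this repository; in the real model, transformers' PreTrainedModel already exposes the device property that self.device resolves to.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy stand-in used only to illustrate the device-agnostic pattern."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self) -> torch.device:
        # Resolve the device from the module's own parameters, mirroring
        # the self.device property that PreTrainedModel provides.
        return next(self.parameters()).device

    def infer_step(self, x: torch.Tensor) -> torch.Tensor:
        # Before: x = x.cuda()   -> fails on machines without CUDA.
        # After:  move inputs to wherever the model's weights live.
        x = x.to(self.device)
        # torch.autocast takes a device-type string ("cuda", "cpu", "mps"),
        # so self.device.type generalizes the hard-coded "cuda".
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            with torch.no_grad():
                return self.proj(x)


model = TinyModel()                        # stays on CPU in this sketch
out = model.infer_step(torch.randn(2, 8))  # no .cuda() anywhere
```

Note that autocast coverage for the "mps" backend varies by PyTorch version, so bfloat16 behaviour off-CUDA may differ from a CUDA run.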

Files changed (1)
  1. modeling_deepseekocr.py +365 -257
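
A hedged usage sketch of the patched infer on a machine without CUDA. The repo id, image path, and output directory below are placeholders, and your checkout's auto_map or dtype/attention settings may require additional from_pretrained kwargs; the infer signature itself matches the one defined in this file.

```python
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "deepseek-ai/DeepSeek-OCR"  # placeholder repo id

# Pick the best available backend, falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.eval().to(device)

# eval_mode=True makes infer return the decoded text instead of streaming it.
text = model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file="page.png",   # placeholder input image
    output_path="out",       # placeholder output directory
    eval_mode=True,
)
print(text)
```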
modeling_deepseekocr.py CHANGED
@@ -1,6 +1,9 @@
1
  from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
2
  from .configuration_deepseek_v2 import DeepseekV2Config
3
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
4
  from typing import List, Optional, Tuple, Union
5
  from transformers.cache_utils import Cache
6
  import requests
@@ -25,14 +28,13 @@ import time
25
 
26
 
27
  def load_image(image_path):
28
-
29
  try:
30
  image = Image.open(image_path)
31
-
32
  corrected_image = ImageOps.exif_transpose(image)
33
-
34
  return corrected_image
35
-
36
  except Exception as e:
37
  print(f"error: {e}")
38
  try:
@@ -42,7 +44,7 @@ def load_image(image_path):
42
 
43
 
44
  def re_match(text):
45
- pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
46
  matches = re.findall(pattern, text, re.DOTALL)
47
 
48
  # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
@@ -51,7 +53,7 @@ def re_match(text):
51
  mathes_image = []
52
  mathes_other = []
53
  for a_match in matches:
54
- if '<|ref|>image<|/ref|>' in a_match[0]:
55
  mathes_image.append(a_match[0])
56
  else:
57
  mathes_other.append(a_match[0])
@@ -59,7 +61,6 @@ def re_match(text):
59
 
60
 
61
  def extract_coordinates_and_label(ref_text, image_width, image_height):
62
-
63
  try:
64
  label_type = ref_text[1]
65
  cor_list = eval(ref_text[2])
@@ -71,33 +72,36 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
71
 
72
 
73
  def draw_bounding_boxes(image, refs, ouput_path):
74
-
75
  image_width, image_height = image.size
76
-
77
  img_draw = image.copy()
78
  draw = ImageDraw.Draw(img_draw)
79
 
80
- overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
81
  draw2 = ImageDraw.Draw(overlay)
82
-
83
  # try:
84
  # except IOError:
85
  # try:
86
- # font = ImageFont.truetype("DejaVuSans.ttf", 20)
87
  # except IOError:
88
  font = ImageFont.load_default()
89
 
90
  img_idx = 0
91
-
92
  for i, ref in enumerate(refs):
93
  try:
94
  result = extract_coordinates_and_label(ref, image_width, image_height)
95
  if result:
96
  label_type, points_list = result
97
-
98
- color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
99
 
100
- color_a = color + (20, )
 
 
 
 
 
 
101
  for points in points_list:
102
  x1, y1, x2, y2 = points
103
 
@@ -107,7 +111,7 @@ def draw_bounding_boxes(image, refs, ouput_path):
107
  x2 = int(x2 / 999 * image_width)
108
  y2 = int(y2 / 999 * image_height)
109
 
110
- if label_type == 'image':
111
  try:
112
  cropped = image.crop((x1, y1, x2, y2))
113
  cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
@@ -115,24 +119,35 @@ def draw_bounding_boxes(image, refs, ouput_path):
115
  print(e)
116
  pass
117
  img_idx += 1
118
-
119
  try:
120
- if label_type == 'title':
121
  draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
122
- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
 
 
 
 
 
123
  else:
124
  draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
125
- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
 
 
 
 
 
126
  text_x = x1
127
  text_y = max(0, y1 - 15)
128
-
129
-
130
  text_bbox = draw.textbbox((0, 0), label_type, font=font)
131
  text_width = text_bbox[2] - text_bbox[0]
132
  text_height = text_bbox[3] - text_bbox[1]
133
- draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
134
- fill=(255, 255, 255, 30))
135
-
 
 
136
  draw.text((text_x, text_y), label_type, font=font, fill=color)
137
  except:
138
  pass
@@ -143,17 +158,13 @@ def draw_bounding_boxes(image, refs, ouput_path):
143
 
144
 
145
  def process_image_with_refs(image, ref_texts, output_path):
146
-
147
  result_image = draw_bounding_boxes(image, ref_texts, output_path)
148
-
149
- return result_image
150
-
151
-
152
 
 
153
 
154
 
155
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
156
- best_ratio_diff = float('inf')
157
  best_ratio = (1, 1)
158
  area = width * height
159
  for ratio in target_ratios:
@@ -169,20 +180,27 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
169
  return best_ratio
170
 
171
 
172
- def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
 
 
173
  orig_width, orig_height = image.size
174
  aspect_ratio = orig_width / orig_height
175
 
176
  # calculate the existing image aspect ratio
177
  target_ratios = set(
178
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
179
- i * j <= max_num and i * j >= min_num)
 
 
 
 
180
  # print(target_ratios)
181
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
182
 
183
  # find the closest aspect ratio to the target
184
  target_aspect_ratio = find_closest_aspect_ratio(
185
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
 
186
 
187
  # print(target_aspect_ratio)
188
  # calculate the target width and height
@@ -198,7 +216,7 @@ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnai
198
  (i % (target_width // image_size)) * image_size,
199
  (i // (target_width // image_size)) * image_size,
200
  ((i % (target_width // image_size)) + 1) * image_size,
201
- ((i // (target_width // image_size)) + 1) * image_size
202
  )
203
  # split the image
204
  split_img = resized_img.crop(box)
@@ -210,15 +228,14 @@ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnai
210
  return processed_images, target_aspect_ratio
211
 
212
 
213
-
214
  def normalize_transform(mean, std):
215
  if mean is None and std is None:
216
  transform = None
217
  elif mean is None and std is not None:
218
- mean = [0.] * len(std)
219
  transform = transforms.Normalize(mean=mean, std=std)
220
  elif mean is not None and std is None:
221
- std = [1.] * len(mean)
222
  transform = transforms.Normalize(mean=mean, std=std)
223
  else:
224
  transform = transforms.Normalize(mean=mean, std=std)
@@ -226,11 +243,10 @@ def normalize_transform(mean, std):
226
  return transform
227
 
228
 
229
-
230
  def format_messages(
231
- conversations: List[Dict[str, str]],
232
- sft_format: str = "deepseek",
233
- system_prompt: str = "",
234
  ):
235
  """
236
  Applies the SFT template to conversation.
@@ -264,6 +280,7 @@ def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
264
 
265
  return t
266
 
 
267
  def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
268
  """
269
 
@@ -294,7 +311,7 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
294
  # print(image_path)
295
  # print('----------------')
296
  # exit()
297
-
298
  # pil_img = Image.open(image_path)
299
  pil_img = load_image(image_path)
300
  pil_img = pil_img.convert("RGB")
@@ -304,7 +321,6 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
304
 
305
 
306
  class BaseTransform(ABC):
307
-
308
  def set_rng(self, *args, **kwargs):
309
  pass
310
 
@@ -318,32 +334,32 @@ class BaseTransform(ABC):
318
 
319
  class BasicImageTransform(BaseTransform):
320
  def __init__(
321
- self,
322
  mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
323
  std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
324
- normalize: bool = True
325
  ):
326
  self.mean = mean
327
  self.std = std
328
-
329
- transform_pipelines = [
330
- transforms.ToTensor()
331
- ]
332
 
333
  normalize = normalize_transform(mean, std) if normalize else nn.Identity()
334
  if normalize is not None:
335
  transform_pipelines.append(normalize)
336
 
337
  self.transform = transforms.Compose(transform_pipelines)
338
-
339
  def __call__(self, x):
340
  x = self.transform(x)
341
  return x
342
 
 
343
  class NoEOSTextStreamer(TextStreamer):
344
  def on_finalized_text(self, text: str, stream_end: bool = False):
345
-
346
- eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
 
347
  text = text.replace(eos_text, "\n")
348
  print(text, flush=True, end="")
349
 
@@ -351,6 +367,7 @@ class NoEOSTextStreamer(TextStreamer):
351
  class DeepseekOCRConfig(DeepseekV2Config):
352
  model_type = "DeepseekOCR"
353
 
 
354
  class DeepseekOCRModel(DeepseekV2Model):
355
  config_class = DeepseekOCRConfig
356
 
@@ -361,14 +378,13 @@ class DeepseekOCRModel(DeepseekV2Model):
361
  self.vision_model = build_clip_l()
362
  # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
363
  n_embed = 1280
364
- self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
 
 
365
  embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
366
  self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
367
  self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
368
 
369
-
370
-
371
-
372
  def forward(
373
  self,
374
  input_ids: torch.LongTensor = None,
@@ -384,28 +400,23 @@ class DeepseekOCRModel(DeepseekV2Model):
384
  images_spatial_crop: Optional[torch.FloatTensor] = None,
385
  return_dict: Optional[bool] = None,
386
  ) -> Union[Tuple, BaseModelOutputWithPast]:
387
-
388
-
389
-
390
-
391
  if inputs_embeds is None:
392
  # inputs_embeds = self.embed_tokens(input_ids)
393
  inputs_embeds = self.get_input_embeddings()(input_ids)
394
 
395
-
396
-
397
- sam_model = getattr(self, 'sam_model', None)
398
  # sam_model = self.sam_model
399
- vision_model = getattr(self, 'vision_model', None)
400
-
401
-
402
-
403
- if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:
404
 
 
 
 
 
 
405
  idx = 0
406
-
407
  # sam_model = torch.jit.script(sam_model)
408
-
409
  # start_time = time.time()
410
  for image, crop_shape in zip(images, images_spatial_crop):
411
  images_in_this_batch = []
@@ -414,53 +425,86 @@ class DeepseekOCRModel(DeepseekV2Model):
414
  image_ori = image[1]
415
 
416
  with torch.no_grad():
417
- # with torch.inference_mode():
418
-
419
  if torch.sum(patches).item() != 0:
420
  # P, C, H, W = patches.shape
421
  crop_flag = 1
422
  local_features_1 = sam_model(patches)
423
 
424
- local_features_2 = vision_model(patches, local_features_1)
425
  # vit_time = time.time()
426
- local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
427
  local_features = self.projector(local_features)
428
 
429
-
430
  global_features_1 = sam_model(image_ori)
431
- global_features_2 = vision_model(image_ori, global_features_1)
432
- global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
433
  global_features = self.projector(global_features)
434
 
435
- print('=====================')
436
- print('BASE: ', global_features.shape)
437
- print('PATCHES: ', local_features.shape)
438
- print('=====================')
439
 
440
  _, hw, n_dim = global_features.shape
441
- h = w = int(hw ** 0.5)
442
 
443
  _2, hw2, n_dim2 = local_features.shape
444
- h2 = w2 = int(hw2 ** 0.5)
445
 
446
  width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
447
 
448
  global_features = global_features.view(h, w, n_dim)
449
 
450
  global_features = torch.cat(
451
- [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
 
 
 
 
452
  )
453
 
454
  global_features = global_features.view(-1, n_dim)
455
 
456
-
457
- local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
 
 
 
 
 
458
  local_features = torch.cat(
459
- [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
 
 
 
 
 
 
460
  )
461
  local_features = local_features.view(-1, n_dim2)
462
 
463
- global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
 
 
 
 
 
 
 
464
 
465
  # end_time = time.time()
466
 
@@ -469,32 +513,42 @@ class DeepseekOCRModel(DeepseekV2Model):
469
  # print('all: ', end_time - start_time)
470
 
471
  # exit()
472
-
473
  else:
474
  global_features_1 = sam_model(image_ori)
475
- global_features_2 = vision_model(image_ori, global_features_1)
476
- global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
 
 
 
 
 
 
477
  global_features = self.projector(global_features)
478
- print('=====================')
479
- print('BASE: ', global_features.shape)
480
- print('NO PATCHES')
481
- print('=====================')
482
  _, hw, n_dim = global_features.shape
483
- h = w = int(hw ** 0.5)
484
-
485
 
486
  global_features = global_features.view(h, w, n_dim)
487
 
488
  global_features = torch.cat(
489
- [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
 
 
 
 
490
  )
491
 
492
  global_features = global_features.view(-1, n_dim)
493
 
494
- global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
 
 
495
 
496
  images_in_this_batch.append(global_local_features)
497
-
498
 
499
  # print(inputs_embeds.shape)
500
 
@@ -502,21 +556,27 @@ class DeepseekOCRModel(DeepseekV2Model):
502
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
503
  # exit()
504
 
505
- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
 
 
 
506
 
507
  idx += 1
508
-
509
 
510
  return super(DeepseekOCRModel, self).forward(
511
- input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
512
- inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
513
- output_attentions=output_attentions, output_hidden_states=output_hidden_states,
514
- return_dict=return_dict
 
 
 
 
 
515
  )
516
-
517
 
518
- class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
519
 
 
520
  config_class = DeepseekOCRConfig
521
  # supports_gradient_checkpointing = True
522
 
@@ -536,7 +596,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
536
  def get_model(self):
537
  return self.model
538
 
539
-
540
  def forward(
541
  self,
542
  input_ids: torch.LongTensor = None,
@@ -552,17 +611,22 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
552
  images_seq_mask: Optional[torch.FloatTensor] = None,
553
  images_spatial_crop: Optional[torch.FloatTensor] = None,
554
  return_dict: Optional[bool] = None,
555
-
556
  ) -> Union[Tuple, CausalLMOutputWithPast]:
557
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
558
  output_hidden_states = (
559
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
560
  )
561
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
562
-
563
-
564
 
565
- outputs = self.model(
566
  input_ids=input_ids,
567
  past_key_values=past_key_values,
568
  attention_mask=attention_mask,
@@ -572,14 +636,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
572
  output_attentions=output_attentions,
573
  output_hidden_states=output_hidden_states,
574
  images=images,
575
- images_seq_mask = images_seq_mask,
576
- images_spatial_crop = images_spatial_crop,
577
- return_dict=return_dict
578
-
579
  )
580
 
581
-
582
-
583
  # print(transformer_outputs)
584
 
585
  hidden_states = outputs[0]
@@ -613,9 +674,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
613
  attentions=outputs.attentions,
614
  )
615
 
616
-
617
  def prepare_inputs_for_generation(
618
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
619
  ):
620
  # Omit tokens covered by past_key_values
621
  past_length = 0
@@ -632,7 +697,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
632
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
633
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
634
  # input)
635
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
 
 
 
636
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
637
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
638
  # input_ids based on the past_length.
@@ -668,7 +736,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
668
 
669
  # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
670
  # same goes for position ids. Could also help with continued generation.
671
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
 
 
 
672
 
673
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
674
  if inputs_embeds is not None and past_key_values is None:
@@ -688,45 +760,55 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
688
  }
689
  )
690
  return model_inputs
691
-
692
 
693
  def disable_torch_init(self):
694
  """
695
  Disable the redundant torch default initialization to accelerate model creation.
696
  """
697
  import torch
 
698
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
699
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
700
 
701
-
702
-
703
- def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
 
 
 
 
 
 
 
 
 
 
704
  self.disable_torch_init()
705
 
706
  os.makedirs(output_path, exist_ok=True)
707
- os.makedirs(f'{output_path}/images', exist_ok=True)
708
 
709
  if prompt and image_file:
710
  conversation = [
711
  {
712
  "role": "<|User|>",
713
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
714
- "content": f'{prompt}',
715
  # "content": "君不见黄河之水天上来的下一句是什么?",
716
  # "content": "<image>\nFree OCR. ",
717
  # "content": "<image>\nParse the figure. ",
718
  # "content": "<image>\nExtract the text in the image. ",
719
- "images": [f'{image_file}'],
720
  },
721
  {"role": "<|Assistant|>", "content": ""},
722
  ]
723
-
724
  elif prompt:
725
  conversation = [
726
  {
727
  "role": "<|User|>",
728
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
729
- "content": f'{prompt}',
730
  # "content": "君不见黄河之水天上来的下一句是什么?",
731
  # "content": "<image>\nFree OCR. ",
732
  # "content": "<image>\nParse the figure. ",
@@ -736,9 +818,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
736
  {"role": "<|Assistant|>", "content": ""},
737
  ]
738
  else:
739
- assert False, f'prompt is none!'
740
-
741
- prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
 
 
742
 
743
  patch_size = 16
744
  downsample_ratio = 4
@@ -749,15 +833,16 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
749
 
750
  image_draw = images[0].copy()
751
 
752
- w,h = image_draw.size
753
  # print(w, h)
754
  ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
755
-
756
 
757
- image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
 
 
758
  images_seq_mask = []
759
 
760
- image_token = '<image>'
761
  image_token_id = 128815
762
  text_splits = prompt.split(image_token)
763
 
@@ -765,13 +850,11 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
765
  tokenized_str = []
766
  images_spatial_crop = []
767
  for text_sep, image in zip(text_splits, images):
768
-
769
  tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
770
  tokenized_str += tokenized_sep
771
  images_seq_mask += [False] * len(tokenized_sep)
772
 
773
  if crop_mode:
774
-
775
  if image.size[0] <= 640 and image.size[1] <= 640:
776
  crop_ratio = [1, 1]
777
 
@@ -782,23 +865,22 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
782
  else:
783
  # best_width, best_height = self.image_size, self.image_size
784
  crop_ratio = [1, 1]
785
-
786
  """process the global view"""
787
  # image = image.resize((base_size, base_size))
788
- global_view = ImageOps.pad(image, (base_size, base_size),
789
- color=tuple(int(x * 255) for x in image_transform.mean))
790
-
 
 
 
791
  if base_size == 1024:
792
  valid_img_tokens += int(256 * ratio)
793
  elif base_size == 1280:
794
  valid_img_tokens += int(400 * ratio)
795
  # elif base_size == 640:
796
  # valid_img_tokens += int(100 * ratio)
797
-
798
-
799
 
800
-
801
-
802
  images_list.append(image_transform(global_view).to(torch.bfloat16))
803
 
804
  # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
@@ -806,31 +888,34 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
806
  width_crop_num, height_crop_num = crop_ratio
807
 
808
  images_spatial_crop.append([width_crop_num, height_crop_num])
809
-
810
-
811
  if width_crop_num > 1 or height_crop_num > 1:
812
  """process the local views"""
813
-
814
  for i in range(len(images_crop_raw)):
815
- images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
816
-
 
 
817
  if image_size == 640:
818
  valid_img_tokens += len(images_crop_list) * 100
819
 
820
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
821
- num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
822
-
823
-
824
 
825
  """add image tokens"""
826
 
827
-
828
-
829
- tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
830
  tokenized_image += [image_token_id]
831
  if width_crop_num > 1 or height_crop_num > 1:
832
- tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
833
- num_queries * height_crop_num)
 
 
834
  tokenized_str += tokenized_image
835
  images_seq_mask += [True] * len(tokenized_image)
836
  # num_image_tokens.append(len(tokenized_image))
@@ -841,11 +926,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
841
 
842
  """process the global view"""
843
  if image_size <= 640:
844
- print('directly resize')
845
  image = image.resize((image_size, image_size))
846
  # else:
847
- global_view = ImageOps.pad(image, (image_size, image_size),
848
- color=tuple(int(x * 255) for x in image_transform.mean))
 
 
 
849
  images_list.append(image_transform(global_view).to(torch.bfloat16))
850
 
851
  if base_size == 1024:
@@ -861,18 +949,18 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
861
 
862
  images_spatial_crop.append([width_crop_num, height_crop_num])
863
 
864
-
865
  """add image tokens"""
866
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
867
 
868
- tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
 
 
869
  tokenized_image += [image_token_id]
870
  # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
871
  # num_queries * height_crop_num)
872
  tokenized_str += tokenized_image
873
  images_seq_mask += [True] * len(tokenized_image)
874
  # num_image_tokens.append(len(tokenized_image))
875
-
876
 
877
  """process the last text split"""
878
  tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
@@ -881,19 +969,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
881
 
882
  """add the bos tokens"""
883
  bos_id = 0
884
- tokenized_str = [bos_id] + tokenized_str
885
  images_seq_mask = [False] + images_seq_mask
886
 
887
-
888
-
889
  input_ids = torch.LongTensor(tokenized_str)
890
 
891
-
892
-
893
-
894
  images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
895
 
896
-
897
  if len(images_list) == 0:
898
  images_ori = torch.zeros((1, 3, image_size, image_size))
899
  images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
@@ -907,131 +989,157 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
907
  else:
908
  images_crop = torch.zeros((1, 3, base_size, base_size))
909
 
910
-
911
-
912
  if not eval_mode:
913
- streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
914
- with torch.autocast("cuda", dtype=torch.bfloat16):
 
 
915
  with torch.no_grad():
916
  output_ids = self.generate(
917
- input_ids.unsqueeze(0).cuda(),
918
- images=[(images_crop.cuda(), images_ori.cuda())],
919
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
920
- images_spatial_crop = images_spatial_crop,
 
 
921
  # do_sample=False,
922
  # num_beams = 1,
923
  temperature=0.0,
924
  eos_token_id=tokenizer.eos_token_id,
925
  streamer=streamer,
926
  max_new_tokens=8192,
927
- no_repeat_ngram_size = 20,
928
- use_cache = True
929
- )
930
 
931
  else:
932
- with torch.autocast("cuda", dtype=torch.bfloat16):
933
  with torch.no_grad():
934
  output_ids = self.generate(
935
- input_ids.unsqueeze(0).cuda(),
936
- images=[(images_crop.cuda(), images_ori.cuda())],
937
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
938
- images_spatial_crop = images_spatial_crop,
 
 
939
  # do_sample=False,
940
  # num_beams = 1,
941
  temperature=0.0,
942
  eos_token_id=tokenizer.eos_token_id,
943
  max_new_tokens=8192,
944
- no_repeat_ngram_size = 35,
945
- use_cache = True
946
- )
947
-
948
-
949
- if '<image>' in conversation[0]['content'] and eval_mode:
950
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
951
- stop_str = '<|end▁of▁sentence|>'
952
- if outputs.endswith(stop_str):
953
- outputs = outputs[:-len(stop_str)]
954
- # re_match
955
- outputs = outputs.strip()
956
-
957
- return outputs
958
-
959
- if '<image>' in conversation[0]['content'] and test_compress:
960
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
961
- pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
962
- print('='*50)
963
- print('image size: ', (w, h))
964
- print('valid image tokens: ', int(valid_img_tokens))
965
- print('output texts tokens (valid): ', pure_texts_outputs_token_length)
966
- print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
967
- print('='*50)
968
-
969
-
970
- if '<image>' in conversation[0]['content'] and save_results:
971
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
972
- stop_str = '<|end▁of▁sentence|>'
973
-
974
- print('='*15 + 'save results:' + '='*15)
975
-
 
 
 
 
 
 
 
 
 
976
  # # # # conv.messages[-1][-1] = outputs
977
  if outputs.endswith(stop_str):
978
- outputs = outputs[:-len(stop_str)]
979
  outputs = outputs.strip()
980
 
981
  matches_ref, matches_images, mathes_other = re_match(outputs)
982
  # print(matches_ref)
983
  result = process_image_with_refs(image_draw, matches_ref, output_path)
984
 
985
-
986
  for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
987
- outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
988
-
989
- for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
990
- outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
991
 
 
 
 
 
 
 
992
 
993
  # if 'structural formula' in conversation[0]['content']:
994
  # outputs = '<smiles>' + outputs + '</smiles>'
995
- with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
996
  afile.write(outputs)
997
 
998
- if 'line_type' in outputs:
999
  import matplotlib.pyplot as plt
1000
- lines = eval(outputs)['Line']['line']
1001
 
1002
- line_type = eval(outputs)['Line']['line_type']
 
 
1003
  # print(lines)
1004
 
1005
- endpoints = eval(outputs)['Line']['line_endpoint']
1006
 
1007
- fig, ax = plt.subplots(figsize=(3,3), dpi=200)
1008
  ax.set_xlim(-15, 15)
1009
  ax.set_ylim(-15, 15)
1010
 
1011
  for idx, line in enumerate(lines):
1012
  try:
1013
- p0 = eval(line.split(' -- ')[0])
1014
- p1 = eval(line.split(' -- ')[-1])
1015
 
1016
- if line_type[idx] == '--':
1017
- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
 
 
1018
  else:
1019
- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
 
 
1020
 
1021
- ax.scatter(p0[0], p0[1], s=5, color = 'k')
1022
- ax.scatter(p1[0], p1[1], s=5, color = 'k')
1023
  except:
1024
  pass
1025
 
1026
  for endpoint in endpoints:
1027
-
1028
- label = endpoint.split(': ')[0]
1029
- (x, y) = eval(endpoint.split(': ')[1])
1030
- ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
1031
- fontsize=5, fontweight='light')
1032
-
1033
-
1034
- plt.savefig(f'{output_path}/geo.jpg')
 
 
 
 
1035
  plt.close()
1036
 
1037
  result.save(f"{output_path}/result_with_boxes.jpg")
 
1
  from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
2
  from .configuration_deepseek_v2 import DeepseekV2Config
3
+ from transformers.modeling_outputs import (
4
+ BaseModelOutputWithPast,
5
+ CausalLMOutputWithPast,
6
+ )
7
  from typing import List, Optional, Tuple, Union
8
  from transformers.cache_utils import Cache
9
  import requests
 
28
 
29
 
30
  def load_image(image_path):
 
31
  try:
32
  image = Image.open(image_path)
33
+
34
  corrected_image = ImageOps.exif_transpose(image)
35
+
36
  return corrected_image
37
+
38
  except Exception as e:
39
  print(f"error: {e}")
40
  try:
 
44
 
45
 
46
  def re_match(text):
47
+ pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
48
  matches = re.findall(pattern, text, re.DOTALL)
49
 
50
  # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
 
53
  mathes_image = []
54
  mathes_other = []
55
  for a_match in matches:
56
+ if "<|ref|>image<|/ref|>" in a_match[0]:
57
  mathes_image.append(a_match[0])
58
  else:
59
  mathes_other.append(a_match[0])
 
61
 
62
 
63
  def extract_coordinates_and_label(ref_text, image_width, image_height):
 
64
  try:
65
  label_type = ref_text[1]
66
  cor_list = eval(ref_text[2])
 
72
 
73
 
74
  def draw_bounding_boxes(image, refs, ouput_path):
 
75
  image_width, image_height = image.size
76
+
77
  img_draw = image.copy()
78
  draw = ImageDraw.Draw(img_draw)
79
 
80
+ overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
81
  draw2 = ImageDraw.Draw(overlay)
82
+
83
  # try:
84
  # except IOError:
85
  # try:
86
+ # font = ImageFont.truetype("DejaVuSans.ttf", 20)
87
  # except IOError:
88
  font = ImageFont.load_default()
89
 
90
  img_idx = 0
91
+
92
  for i, ref in enumerate(refs):
93
  try:
94
  result = extract_coordinates_and_label(ref, image_width, image_height)
95
  if result:
96
  label_type, points_list = result
 
 
97
 
98
+ color = (
99
+ np.random.randint(0, 200),
100
+ np.random.randint(0, 200),
101
+ np.random.randint(0, 255),
102
+ )
103
+
104
+ color_a = color + (20,)
105
  for points in points_list:
106
  x1, y1, x2, y2 = points
107
 
 
111
  x2 = int(x2 / 999 * image_width)
112
  y2 = int(y2 / 999 * image_height)
113
 
114
+ if label_type == "image":
115
  try:
116
  cropped = image.crop((x1, y1, x2, y2))
117
  cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
 
119
  print(e)
120
  pass
121
  img_idx += 1
122
+
123
  try:
124
+ if label_type == "title":
125
  draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
126
+ draw2.rectangle(
127
+ [x1, y1, x2, y2],
128
+ fill=color_a,
129
+ outline=(0, 0, 0, 0),
130
+ width=1,
131
+ )
132
  else:
133
  draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
134
+ draw2.rectangle(
135
+ [x1, y1, x2, y2],
136
+ fill=color_a,
137
+ outline=(0, 0, 0, 0),
138
+ width=1,
139
+ )
140
  text_x = x1
141
  text_y = max(0, y1 - 15)
142
+
 
143
  text_bbox = draw.textbbox((0, 0), label_type, font=font)
144
  text_width = text_bbox[2] - text_bbox[0]
145
  text_height = text_bbox[3] - text_bbox[1]
146
+ draw.rectangle(
147
+ [text_x, text_y, text_x + text_width, text_y + text_height],
148
+ fill=(255, 255, 255, 30),
149
+ )
150
+
151
  draw.text((text_x, text_y), label_type, font=font, fill=color)
152
  except:
153
  pass
 
158
 
159
 
160
  def process_image_with_refs(image, ref_texts, output_path):
 
161
  result_image = draw_bounding_boxes(image, ref_texts, output_path)
 
 
 
 
162
 
163
+ return result_image
164
 
165
 
166
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
167
+ best_ratio_diff = float("inf")
168
  best_ratio = (1, 1)
169
  area = width * height
170
  for ratio in target_ratios:
 
180
  return best_ratio
181
 
182
 
183
+ def dynamic_preprocess(
184
+ image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
185
+ ):
186
  orig_width, orig_height = image.size
187
  aspect_ratio = orig_width / orig_height
188
 
189
  # calculate the existing image aspect ratio
190
  target_ratios = set(
191
+ (i, j)
192
+ for n in range(min_num, max_num + 1)
193
+ for i in range(1, n + 1)
194
+ for j in range(1, n + 1)
195
+ if i * j <= max_num and i * j >= min_num
196
+ )
197
  # print(target_ratios)
198
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
199
 
200
  # find the closest aspect ratio to the target
201
  target_aspect_ratio = find_closest_aspect_ratio(
202
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
203
+ )
204
 
205
  # print(target_aspect_ratio)
206
  # calculate the target width and height
 
216
  (i % (target_width // image_size)) * image_size,
217
  (i // (target_width // image_size)) * image_size,
218
  ((i % (target_width // image_size)) + 1) * image_size,
219
+ ((i // (target_width // image_size)) + 1) * image_size,
220
  )
221
  # split the image
222
  split_img = resized_img.crop(box)
 
228
  return processed_images, target_aspect_ratio
229
 
230
 
 
231
  def normalize_transform(mean, std):
232
  if mean is None and std is None:
233
  transform = None
234
  elif mean is None and std is not None:
235
+ mean = [0.0] * len(std)
236
  transform = transforms.Normalize(mean=mean, std=std)
237
  elif mean is not None and std is None:
238
+ std = [1.0] * len(mean)
239
  transform = transforms.Normalize(mean=mean, std=std)
240
  else:
241
  transform = transforms.Normalize(mean=mean, std=std)
 
243
  return transform
244
 
245
 
 
246
  def format_messages(
247
+ conversations: List[Dict[str, str]],
248
+ sft_format: str = "deepseek",
249
+ system_prompt: str = "",
250
  ):
251
  """
252
  Applies the SFT template to conversation.
 
280
 
281
  return t
282
 
283
+
284
  def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
285
  """
286
 
 
311
  # print(image_path)
312
  # print('----------------')
313
  # exit()
314
+
315
  # pil_img = Image.open(image_path)
316
  pil_img = load_image(image_path)
317
  pil_img = pil_img.convert("RGB")
 
321
 
322
 
323
  class BaseTransform(ABC):
 
324
  def set_rng(self, *args, **kwargs):
325
  pass
326
 
 
334
 
335
  class BasicImageTransform(BaseTransform):
336
  def __init__(
337
+ self,
338
  mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
339
  std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
340
+ normalize: bool = True,
341
  ):
342
  self.mean = mean
343
  self.std = std
344
+
345
+ transform_pipelines = [transforms.ToTensor()]
 
 
346
 
347
  normalize = normalize_transform(mean, std) if normalize else nn.Identity()
348
  if normalize is not None:
349
  transform_pipelines.append(normalize)
350
 
351
  self.transform = transforms.Compose(transform_pipelines)
352
+
353
  def __call__(self, x):
354
  x = self.transform(x)
355
  return x
356
 
357
+
358
  class NoEOSTextStreamer(TextStreamer):
359
  def on_finalized_text(self, text: str, stream_end: bool = False):
360
+ eos_text = self.tokenizer.decode(
361
+ [self.tokenizer.eos_token_id], skip_special_tokens=False
362
+ )
363
  text = text.replace(eos_text, "\n")
364
  print(text, flush=True, end="")
365
 
 
367
  class DeepseekOCRConfig(DeepseekV2Config):
368
  model_type = "DeepseekOCR"
369
 
370
+
371
  class DeepseekOCRModel(DeepseekV2Model):
372
  config_class = DeepseekOCRConfig
373
 
 
378
  self.vision_model = build_clip_l()
379
  # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
380
  n_embed = 1280
381
+ self.projector = MlpProjector(
382
+ Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)
383
+ )
384
  embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
385
  self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
386
  self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
387
 
 
 
 
388
  def forward(
389
  self,
390
  input_ids: torch.LongTensor = None,
 
400
  images_spatial_crop: Optional[torch.FloatTensor] = None,
401
  return_dict: Optional[bool] = None,
402
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
 
 
 
403
  if inputs_embeds is None:
404
  # inputs_embeds = self.embed_tokens(input_ids)
405
  inputs_embeds = self.get_input_embeddings()(input_ids)
406
 
407
+ sam_model = getattr(self, "sam_model", None)
 
 
408
  # sam_model = self.sam_model
409
+ vision_model = getattr(self, "vision_model", None)
 
 
 
 
410
 
411
+ if (
412
+ sam_model is not None
413
+ and (input_ids.shape[1] != 1 or self.training)
414
+ and torch.sum(images[0][1]).item() != 0
415
+ ):
416
  idx = 0
417
+
418
  # sam_model = torch.jit.script(sam_model)
419
+
420
  # start_time = time.time()
421
  for image, crop_shape in zip(images, images_spatial_crop):
422
  images_in_this_batch = []
 
425
  image_ori = image[1]
426
 
427
  with torch.no_grad():
428
+ # with torch.inference_mode():
429
+
430
  if torch.sum(patches).item() != 0:
431
  # P, C, H, W = patches.shape
432
  crop_flag = 1
433
  local_features_1 = sam_model(patches)
434
 
435
+ local_features_2 = vision_model(patches, local_features_1)
436
  # vit_time = time.time()
437
+ local_features = torch.cat(
438
+ (
439
+ local_features_2[:, 1:],
440
+ local_features_1.flatten(2).permute(0, 2, 1),
441
+ ),
442
+ dim=-1,
443
+ )
444
  local_features = self.projector(local_features)
445
 
 
446
  global_features_1 = sam_model(image_ori)
447
+ global_features_2 = vision_model(image_ori, global_features_1)
448
+ global_features = torch.cat(
449
+ (
450
+ global_features_2[:, 1:],
451
+ global_features_1.flatten(2).permute(0, 2, 1),
452
+ ),
453
+ dim=-1,
454
+ )
455
  global_features = self.projector(global_features)
456
 
457
+ print("=====================")
458
+ print("BASE: ", global_features.shape)
459
+ print("PATCHES: ", local_features.shape)
460
+ print("=====================")
461
 
462
  _, hw, n_dim = global_features.shape
463
+ h = w = int(hw**0.5)
464
 
465
  _2, hw2, n_dim2 = local_features.shape
466
+ h2 = w2 = int(hw2**0.5)
467
 
468
  width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
469
 
470
  global_features = global_features.view(h, w, n_dim)
471
 
472
  global_features = torch.cat(
473
+ [
474
+ global_features,
475
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
476
+ ],
477
+ dim=1,
478
  )
479
 
480
  global_features = global_features.view(-1, n_dim)
481
 
482
+ local_features = (
483
+ local_features.view(
484
+ height_crop_num, width_crop_num, h2, w2, n_dim2
485
+ )
486
+ .permute(0, 2, 1, 3, 4)
487
+ .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
488
+ )
489
  local_features = torch.cat(
490
+ [
491
+ local_features,
492
+ self.image_newline[None, None, :].expand(
493
+ height_crop_num * h2, 1, n_dim2
494
+ ),
495
+ ],
496
+ dim=1,
497
  )
498
  local_features = local_features.view(-1, n_dim2)
499
 
500
+ global_local_features = torch.cat(
501
+ [
502
+ local_features,
503
+ global_features,
504
+ self.view_seperator[None, :],
505
+ ],
506
+ dim=0,
507
+ )
508
 
509
  # end_time = time.time()
510
 
 
513
  # print('all: ', end_time - start_time)
514
 
515
  # exit()
516
+
517
  else:
518
  global_features_1 = sam_model(image_ori)
519
+ global_features_2 = vision_model(image_ori, global_features_1)
520
+ global_features = torch.cat(
521
+ (
522
+ global_features_2[:, 1:],
523
+ global_features_1.flatten(2).permute(0, 2, 1),
524
+ ),
525
+ dim=-1,
526
+ )
527
  global_features = self.projector(global_features)
528
+ print("=====================")
529
+ print("BASE: ", global_features.shape)
530
+ print("NO PATCHES")
531
+ print("=====================")
532
  _, hw, n_dim = global_features.shape
533
+ h = w = int(hw**0.5)
 
534
 
535
  global_features = global_features.view(h, w, n_dim)
536
 
537
  global_features = torch.cat(
538
+ [
539
+ global_features,
540
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
541
+ ],
542
+ dim=1,
543
  )
544
 
545
  global_features = global_features.view(-1, n_dim)
546
 
547
+ global_local_features = torch.cat(
548
+ [global_features, self.view_seperator[None, :]], dim=0
549
+ )
550
 
551
  images_in_this_batch.append(global_local_features)
 
552
 
553
  # print(inputs_embeds.shape)
554
 
 
556
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
557
  # exit()
558
 
559
+ inputs_embeds[idx].masked_scatter_(
560
+ images_seq_mask[idx].unsqueeze(-1).to(self.device),
561
+ images_in_this_batch,
562
+ )
563
 
564
  idx += 1
 
565
 
566
  return super(DeepseekOCRModel, self).forward(
567
+ input_ids=None,
568
+ attention_mask=attention_mask,
569
+ past_key_values=past_key_values,
570
+ inputs_embeds=inputs_embeds,
571
+ use_cache=use_cache,
572
+ position_ids=position_ids,
573
+ output_attentions=output_attentions,
574
+ output_hidden_states=output_hidden_states,
575
+ return_dict=return_dict,
576
  )
 
577
 
 
578
 
579
+ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
580
  config_class = DeepseekOCRConfig
581
  # supports_gradient_checkpointing = True
582
 
 
596
  def get_model(self):
597
  return self.model
598
 
 
599
  def forward(
600
  self,
601
  input_ids: torch.LongTensor = None,
 
611
  images_seq_mask: Optional[torch.FloatTensor] = None,
612
  images_spatial_crop: Optional[torch.FloatTensor] = None,
613
  return_dict: Optional[bool] = None,
 
614
  ) -> Union[Tuple, CausalLMOutputWithPast]:
615
+ output_attentions = (
616
+ output_attentions
617
+ if output_attentions is not None
618
+ else self.config.output_attentions
619
+ )
620
  output_hidden_states = (
621
+ output_hidden_states
622
+ if output_hidden_states is not None
623
+ else self.config.output_hidden_states
624
+ )
625
+ return_dict = (
626
+ return_dict if return_dict is not None else self.config.use_return_dict
627
  )
 
 
 
628
 
629
+ outputs = self.model(
630
  input_ids=input_ids,
631
  past_key_values=past_key_values,
632
  attention_mask=attention_mask,
 
636
  output_attentions=output_attentions,
637
  output_hidden_states=output_hidden_states,
638
  images=images,
639
+ images_seq_mask=images_seq_mask,
640
+ images_spatial_crop=images_spatial_crop,
641
+ return_dict=return_dict,
 
642
  )
643
 
 
 
644
  # print(transformer_outputs)
645
 
646
  hidden_states = outputs[0]
 
674
  attentions=outputs.attentions,
675
  )
676
 
 
677
  def prepare_inputs_for_generation(
678
+ self,
679
+ input_ids,
680
+ past_key_values=None,
681
+ attention_mask=None,
682
+ inputs_embeds=None,
683
+ **kwargs,
684
  ):
685
  # Omit tokens covered by past_key_values
686
  past_length = 0
 
697
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
698
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
699
  # input)
700
+ if (
701
+ attention_mask is not None
702
+ and attention_mask.shape[1] > input_ids.shape[1]
703
+ ):
704
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
705
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
706
  # input_ids based on the past_length.
 
736
 
737
  # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
738
  # same goes for position ids. Could also help with continued generation.
739
+ cache_position = torch.arange(
740
+ past_length,
741
+ past_length + position_ids.shape[-1],
742
+ device=position_ids.device,
743
+ )
744
 
745
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
746
  if inputs_embeds is not None and past_key_values is None:
 
760
  }
761
  )
762
  return model_inputs
 
763
 
764
  def disable_torch_init(self):
765
  """
766
  Disable the redundant torch default initialization to accelerate model creation.
767
  """
768
  import torch
769
+
770
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
771
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
772
 
773
+ def infer(
774
+ self,
775
+ tokenizer,
776
+ prompt="",
777
+ image_file="",
778
+ output_path="",
779
+ base_size=1024,
780
+ image_size=640,
781
+ crop_mode=True,
782
+ test_compress=False,
783
+ save_results=False,
784
+ eval_mode=False,
785
+ ):
786
  self.disable_torch_init()
787
 
788
  os.makedirs(output_path, exist_ok=True)
789
+ os.makedirs(f"{output_path}/images", exist_ok=True)
790
 
791
  if prompt and image_file:
792
  conversation = [
793
  {
794
  "role": "<|User|>",
795
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
796
+ "content": f"{prompt}",
797
  # "content": "君不见黄河之水天上来的下一句是什么?",
798
  # "content": "<image>\nFree OCR. ",
799
  # "content": "<image>\nParse the figure. ",
800
  # "content": "<image>\nExtract the text in the image. ",
801
+ "images": [f"{image_file}"],
802
  },
803
  {"role": "<|Assistant|>", "content": ""},
804
  ]
805
+
806
  elif prompt:
807
  conversation = [
808
  {
809
  "role": "<|User|>",
810
  # "content": "<image>\n<|grounding|>Given the layout of the image. ",
811
+ "content": f"{prompt}",
812
  # "content": "君不见黄河之水天上来的下一句是什么?",
813
  # "content": "<image>\nFree OCR. ",
814
  # "content": "<image>\nParse the figure. ",
 
818
  {"role": "<|Assistant|>", "content": ""},
819
  ]
820
  else:
821
+ assert False, f"prompt is none!"
822
+
823
+ prompt = format_messages(
824
+ conversations=conversation, sft_format="plain", system_prompt=""
825
+ )
826
 
827
  patch_size = 16
828
  downsample_ratio = 4
 
833
 
834
  image_draw = images[0].copy()
835
 
836
+ w, h = image_draw.size
837
  # print(w, h)
838
  ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
 
839
 
840
+ image_transform = BasicImageTransform(
841
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True
842
+ )
843
  images_seq_mask = []
844
 
845
+ image_token = "<image>"
846
  image_token_id = 128815
847
  text_splits = prompt.split(image_token)
848
 
 
850
  tokenized_str = []
851
  images_spatial_crop = []
852
  for text_sep, image in zip(text_splits, images):
 
853
  tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
854
  tokenized_str += tokenized_sep
855
  images_seq_mask += [False] * len(tokenized_sep)
856
 
857
  if crop_mode:
 
858
  if image.size[0] <= 640 and image.size[1] <= 640:
859
  crop_ratio = [1, 1]
860
 
 
865
  else:
866
  # best_width, best_height = self.image_size, self.image_size
867
  crop_ratio = [1, 1]
868
+
869
  """process the global view"""
870
  # image = image.resize((base_size, base_size))
871
+ global_view = ImageOps.pad(
872
+ image,
873
+ (base_size, base_size),
874
+ color=tuple(int(x * 255) for x in image_transform.mean),
875
+ )
876
+
877
  if base_size == 1024:
878
  valid_img_tokens += int(256 * ratio)
879
  elif base_size == 1280:
880
  valid_img_tokens += int(400 * ratio)
881
  # elif base_size == 640:
882
  # valid_img_tokens += int(100 * ratio)
 
 
883
 
 
 
884
  images_list.append(image_transform(global_view).to(torch.bfloat16))
885
 
886
  # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
 
888
  width_crop_num, height_crop_num = crop_ratio
889
 
890
  images_spatial_crop.append([width_crop_num, height_crop_num])
891
+
 
892
  if width_crop_num > 1 or height_crop_num > 1:
893
  """process the local views"""
894
+
895
  for i in range(len(images_crop_raw)):
896
+ images_crop_list.append(
897
+ image_transform(images_crop_raw[i]).to(torch.bfloat16)
898
+ )
899
+
900
  if image_size == 640:
901
  valid_img_tokens += len(images_crop_list) * 100
902
 
903
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
904
+ num_queries_base = math.ceil(
905
+ (base_size // patch_size) / downsample_ratio
906
+ )
907
 
908
  """add image tokens"""
909
 
910
+ tokenized_image = (
911
+ [image_token_id] * num_queries_base + [image_token_id]
912
+ ) * num_queries_base
913
  tokenized_image += [image_token_id]
914
  if width_crop_num > 1 or height_crop_num > 1:
915
+ tokenized_image += (
916
+ [image_token_id] * (num_queries * width_crop_num)
917
+ + [image_token_id]
918
+ ) * (num_queries * height_crop_num)
919
  tokenized_str += tokenized_image
920
  images_seq_mask += [True] * len(tokenized_image)
921
  # num_image_tokens.append(len(tokenized_image))
 
926
 
927
  """process the global view"""
928
  if image_size <= 640:
929
+ print("directly resize")
930
  image = image.resize((image_size, image_size))
931
  # else:
932
+ global_view = ImageOps.pad(
933
+ image,
934
+ (image_size, image_size),
935
+ color=tuple(int(x * 255) for x in image_transform.mean),
936
+ )
937
  images_list.append(image_transform(global_view).to(torch.bfloat16))
938
 
939
  if base_size == 1024:
 
949
 
950
  images_spatial_crop.append([width_crop_num, height_crop_num])
951
 
 
952
  """add image tokens"""
953
  num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
954
 
955
+ tokenized_image = (
956
+ [image_token_id] * num_queries + [image_token_id]
957
+ ) * num_queries
958
  tokenized_image += [image_token_id]
959
  # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
960
  # num_queries * height_crop_num)
961
  tokenized_str += tokenized_image
962
  images_seq_mask += [True] * len(tokenized_image)
963
  # num_image_tokens.append(len(tokenized_image))
 
964
 
965
  """process the last text split"""
966
  tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
 
969
 
970
  """add the bos tokens"""
971
  bos_id = 0
972
+ tokenized_str = [bos_id] + tokenized_str
973
  images_seq_mask = [False] + images_seq_mask
974
 
 
 
975
  input_ids = torch.LongTensor(tokenized_str)
976
 
 
 
 
977
  images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
978
 
 
979
  if len(images_list) == 0:
980
  images_ori = torch.zeros((1, 3, image_size, image_size))
981
  images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
 
989
  else:
990
  images_crop = torch.zeros((1, 3, base_size, base_size))
991
 
 
 
992
  if not eval_mode:
993
+ streamer = NoEOSTextStreamer(
994
+ tokenizer, skip_prompt=True, skip_special_tokens=False
995
+ )
996
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
997
  with torch.no_grad():
998
  output_ids = self.generate(
999
+ input_ids.unsqueeze(0).to(self.device),
1000
+ images=[
1001
+ (images_crop.to(self.device), images_ori.to(self.device))
1002
+ ],
1003
+ images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
1004
+ images_spatial_crop=images_spatial_crop,
1005
  # do_sample=False,
1006
  # num_beams = 1,
1007
  temperature=0.0,
1008
  eos_token_id=tokenizer.eos_token_id,
1009
  streamer=streamer,
1010
  max_new_tokens=8192,
1011
+ no_repeat_ngram_size=20,
1012
+ use_cache=True,
1013
+ )
1014
 
1015
  else:
1016
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
1017
  with torch.no_grad():
1018
  output_ids = self.generate(
1019
+ input_ids.unsqueeze(0).to(self.device),
1020
+ images=[
1021
+ (images_crop.to(self.device), images_ori.to(self.device))
1022
+ ],
1023
+ images_seq_mask=images_seq_mask.unsqueeze(0).to(self.device),
1024
+ images_spatial_crop=images_spatial_crop,
1025
  # do_sample=False,
1026
  # num_beams = 1,
1027
  temperature=0.0,
1028
  eos_token_id=tokenizer.eos_token_id,
1029
  max_new_tokens=8192,
1030
+ no_repeat_ngram_size=35,
1031
+ use_cache=True,
1032
+ )
1033
+
1034
+ if "<image>" in conversation[0]["content"] and eval_mode:
1035
+ outputs = tokenizer.decode(
1036
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1037
+ )
1038
+ stop_str = "<|end▁of▁sentence|>"
1039
+ if outputs.endswith(stop_str):
1040
+ outputs = outputs[: -len(stop_str)]
1041
+ # re_match
1042
+ outputs = outputs.strip()
1043
+
1044
+ return outputs
1045
+
1046
+ if "<image>" in conversation[0]["content"] and test_compress:
1047
+ outputs = tokenizer.decode(
1048
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1049
+ )
1050
+ pure_texts_outputs_token_length = len(
1051
+ text_encode(tokenizer, outputs, bos=False, eos=False)
1052
+ )
1053
+ print("=" * 50)
1054
+ print("image size: ", (w, h))
1055
+ print("valid image tokens: ", int(valid_img_tokens))
1056
+ print("output texts tokens (valid): ", pure_texts_outputs_token_length)
1057
+ print(
1058
+ "compression ratio: ",
1059
+ round(pure_texts_outputs_token_length / valid_img_tokens, 2),
1060
+ )
1061
+ print("=" * 50)
1062
+
1063
+ if "<image>" in conversation[0]["content"] and save_results:
1064
+ outputs = tokenizer.decode(
1065
+ output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1] :]
1066
+ )
1067
+ stop_str = "<|end▁of▁sentence|>"
1068
+
1069
+ print("=" * 15 + "save results:" + "=" * 15)
1070
+
1071
  # # # # conv.messages[-1][-1] = outputs
1072
  if outputs.endswith(stop_str):
1073
+ outputs = outputs[: -len(stop_str)]
1074
  outputs = outputs.strip()
1075
 
1076
  matches_ref, matches_images, mathes_other = re_match(outputs)
1077
  # print(matches_ref)
1078
  result = process_image_with_refs(image_draw, matches_ref, output_path)
1079
 
 
1080
  for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
1081
+ outputs = outputs.replace(
1082
+ a_match_image, "![](images/" + str(idx) + ".jpg)\n"
1083
+ )
 
1084
 
1085
+ for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
1086
+ outputs = (
1087
+ outputs.replace(a_match_other, "")
1088
+ .replace("\\coloneqq", ":=")
1089
+ .replace("\\eqqcolon", "=:")
1090
+ )
1091
 
1092
  # if 'structural formula' in conversation[0]['content']:
1093
  # outputs = '<smiles>' + outputs + '</smiles>'
1094
+ with open(f"{output_path}/result.mmd", "w", encoding="utf-8") as afile:
1095
  afile.write(outputs)
1096
 
1097
+ if "line_type" in outputs:
1098
  import matplotlib.pyplot as plt
 
1099
 
1100
+ lines = eval(outputs)["Line"]["line"]
1101
+
1102
+ line_type = eval(outputs)["Line"]["line_type"]
1103
  # print(lines)
1104
 
1105
+ endpoints = eval(outputs)["Line"]["line_endpoint"]
1106
 
1107
+ fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
1108
  ax.set_xlim(-15, 15)
1109
  ax.set_ylim(-15, 15)
1110
 
1111
  for idx, line in enumerate(lines):
1112
  try:
1113
+ p0 = eval(line.split(" -- ")[0])
1114
+ p1 = eval(line.split(" -- ")[-1])
1115
 
1116
+ if line_type[idx] == "--":
1117
+ ax.plot(
1118
+ [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
1119
+ )
1120
  else:
1121
+ ax.plot(
1122
+ [p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
1123
+ )
1124
 
1125
+ ax.scatter(p0[0], p0[1], s=5, color="k")
1126
+ ax.scatter(p1[0], p1[1], s=5, color="k")
1127
  except:
1128
  pass
1129
 
1130
  for endpoint in endpoints:
1131
+ label = endpoint.split(": ")[0]
1132
+ (x, y) = eval(endpoint.split(": ")[1])
1133
+ ax.annotate(
1134
+ label,
1135
+ (x, y),
1136
+ xytext=(1, 1),
1137
+ textcoords="offset points",
1138
+ fontsize=5,
1139
+ fontweight="light",
1140
+ )
1141
+
1142
+ plt.savefig(f"{output_path}/geo.jpg")
1143
  plt.close()
1144
 
1145
  result.save(f"{output_path}/result_with_boxes.jpg")