File size: 1,137 Bytes
0a475cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch.nn as nn
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
 
class SketchDecoder(nn.Module):
    """Autoregressive generative model.

    Wraps a ``Qwen2_5_VLForConditionalGeneration`` instance loaded from the
    "Qwen/Qwen2.5-VL-3B-Instruct" checkpoint, configured with this module's
    own vocabulary size and special-token ids. The wrapped model is exposed
    as ``self.transformer``.

    Only construction is provided here; ``forward`` always raises
    ``NotImplementedError``.
    """

    def __init__(self, **kwargs):
        """Build the wrapped model with the module's token settings.

        Args:
            **kwargs: Accepted for call-site compatibility; not read by
                this constructor.
        """
        super().__init__()

        checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"

        # Vocabulary size and special-token ids. BOS and PAD share one id.
        # The same mapping is mirrored onto the instance and forwarded to
        # the config so the two cannot drift apart.
        token_settings = {
            "vocab_size": 196042,
            "bos_token_id": 151643,
            "eos_token_id": 196041,
            "pad_token_id": 151643,
        }
        for field, value in token_settings.items():
            setattr(self, field, value)

        # Start from the checkpoint's own config, overriding only the
        # token-related fields above.
        hf_config = AutoConfig.from_pretrained(checkpoint, **token_settings)

        # dtype, attention implementation and device placement are not
        # passed, so the library defaults apply. ignore_mismatched_sizes
        # asks the loader to tolerate checkpoint tensors whose shapes differ
        # from the configured model -- presumably the vocabulary-sized
        # layers; confirm against the transformers documentation.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint,
            config=hf_config,
            ignore_mismatched_sizes=True,
        )

        # Explicitly request embeddings sized to the module's vocabulary.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Not available in this release; always raises.

        Raises:
            NotImplementedError: On every call, regardless of arguments.
        """
        raise NotImplementedError("Forward pass not included in open-source version")