import torch.nn as nn import torch from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig class SketchDecoder(nn.Module): """ Autoregressive generative model """ def __init__(self, **kwargs): super().__init__() self.vocab_size = 196042 self.bos_token_id = 151643 self.eos_token_id = 196041 self.pad_token_id = 151643 config = AutoConfig.from_pretrained( "Qwen/Qwen2.5-VL-3B-Instruct", #n_positions=8192, vocab_size=self.vocab_size, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id) self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2.5-VL-3B-Instruct", config=config, #torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", #device_map ="cuda", ignore_mismatched_sizes=True ) self.transformer.resize_token_embeddings(self.vocab_size) def forward(self, *args, **kwargs): raise NotImplementedError("Forward pass not included in open-source version")