import torch
import torch.nn as nn
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig


class SketchDecoder(nn.Module):
    """
    Autoregressive generative model wrapping Qwen2.5-VL-3B-Instruct
    with an extended vocabulary.
    """
    def __init__(self, **kwargs):
        super().__init__()
        # Vocabulary extended beyond the base Qwen2.5-VL vocabulary;
        # BOS/PAD reuse the base end-of-text id, EOS is the last new token id.
        self.vocab_size = 196042
        self.bos_token_id = 151643
        self.eos_token_id = 196041
        self.pad_token_id = 151643

        # Load the base config and override vocabulary / special-token settings.
        config = AutoConfig.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            # n_positions=8192,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id)

        # ignore_mismatched_sizes is needed because the checkpoint's embedding
        # and LM-head shapes differ from the enlarged vocabulary above.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            config=config,
            # torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
            # device_map="cuda",
            ignore_mismatched_sizes=True
        )
        # Resize the token embeddings (and tied output head) to the new vocabulary size.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Forward pass not included in open-source version")
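

# Minimal usage sketch (not part of the original file): since forward() is not
# released, the wrapped Hugging Face model can still be driven directly via its
# generate() method. The prompt below is a placeholder for illustration only,
# and instantiating the class downloads the 3B checkpoint.
if __name__ == "__main__":
    decoder = SketchDecoder()
    decoder.eval()

    # Start generation from the BOS token; real prompts would be tokenized input.
    input_ids = torch.tensor([[decoder.bos_token_id]])
    with torch.no_grad():
        output_ids = decoder.transformer.generate(
            input_ids=input_ids,
            max_new_tokens=16,
            eos_token_id=decoder.eos_token_id,
            pad_token_id=decoder.pad_token_id,
        )
    print(output_ids)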