Spaces:
Running
on
L4
Running
on
L4
File size: 1,137 Bytes
0a475cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch.nn as nn
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
class SketchDecoder(nn.Module):
    """Autoregressive generative model wrapping a Qwen2.5-VL backbone.

    The backbone is built from the ``Qwen/Qwen2.5-VL-3B-Instruct`` checkpoint.
    Its config is overridden with a custom vocabulary size and custom special
    token ids.

    Only model construction ships in this release; ``forward`` always raises.
    """

    def __init__(self, **kwargs):
        # NOTE: keyword arguments are accepted but currently ignored.
        super().__init__()

        base_checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"

        self.vocab_size = 196042
        self.bos_token_id = 151643
        # Equals vocab_size - 1, i.e. the final id of the configured vocabulary.
        self.eos_token_id = 196041
        # Padding deliberately reuses the BOS id.
        self.pad_token_id = 151643

        special_ids = {
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id,
            "pad_token_id": self.pad_token_id,
        }
        backbone_config = AutoConfig.from_pretrained(
            base_checkpoint,
            vocab_size=self.vocab_size,
            **special_ids,
        )

        # NOTE(review): with ignore_mismatched_sizes=True, checkpoint tensors
        # whose shapes disagree with backbone_config are not loaded. They keep
        # fresh initialisation instead, presumably the vocab-sized
        # embedding/output layers here. If so, the resize below is a no-op.
        # Confirm this is intended, e.g. because trained weights are loaded
        # later.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            base_checkpoint,
            config=backbone_config,
            ignore_mismatched_sizes=True,
        )
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Not available in the open-source release; always raises."""
        message = "Forward pass not included in open-source version"
        raise NotImplementedError(message)
|