# OmniSVG-3B / decoder.py
# Uploaded by OmniSVG ("Upload 3 files"), commit 0a475cf (verified).
import torch.nn as nn
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
class SketchDecoder(nn.Module):
    """Autoregressive generative model built on a Qwen2.5-VL backbone.

    The backbone is loaded from a pretrained checkpoint. Its config is overridden
    with an enlarged vocabulary and custom BOS/EOS/PAD token ids, and the token
    embeddings are then resized to ``vocab_size``.

    Only construction is provided here: ``forward`` is not part of the
    open-source release and always raises ``NotImplementedError``.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the Qwen2.5-VL
            checkpoint used for both the config and the weights.
        vocab_size: Size of the extended token vocabulary.
        bos_token_id: Beginning-of-sequence token id.
        eos_token_id: End-of-sequence token id.
        pad_token_id: Padding token id. By default this is the same id as
            ``bos_token_id``.
        model_kwargs: Optional extra keyword arguments forwarded to
            ``Qwen2_5_VLForConditionalGeneration.from_pretrained``, for example
            ``torch_dtype=torch.bfloat16``,
            ``attn_implementation="flash_attention_2"`` or
            ``device_map="cuda"``. Entries here override the defaults
            (``config`` and ``ignore_mismatched_sizes=True``).
        **kwargs: Accepted and ignored, so existing callers keep working.

    Raises:
        ValueError: If any special-token id lies outside ``[0, vocab_size)``.
    """

    def __init__(
        self,
        *,
        pretrained_model_name_or_path: str = "Qwen/Qwen2.5-VL-3B-Instruct",
        vocab_size: int = 196042,
        bos_token_id: int = 151643,
        eos_token_id: int = 196041,
        pad_token_id: int = 151643,
        model_kwargs=None,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        # Validate cheaply before any (potentially large) checkpoint download.
        special_ids = {
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
        }
        for name, token_id in special_ids.items():
            if not 0 <= token_id < vocab_size:
                raise ValueError(
                    f"{name}={token_id} is outside the vocabulary range "
                    f"[0, {vocab_size})"
                )

        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        # NOTE(review): the config's vocab_size already differs from the
        # checkpoint's. With ignore_mismatched_sizes=True, transformers then
        # freshly initializes the shape-mismatched tensors (presumably the
        # token embedding / LM head) instead of loading them. Confirm this is
        # intended. If the pretrained embeddings should be kept, load with the
        # stock config and let resize_token_embeddings below enlarge them.
        load_kwargs = {"config": config, "ignore_mismatched_sizes": True}
        if model_kwargs:
            load_kwargs.update(model_kwargs)
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path,
            **load_kwargs,
        )

        # Ensure the embedding matrix has exactly `vocab_size` rows. This is
        # likely a no-op given the config override above (TODO confirm).
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Not available: the forward pass is not part of the open-source release.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Forward pass not included in open-source version")