File size: 467 Bytes
6858cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from diffusers import AutoencoderDC, SanaTransformer2DModel
import torch


def build_sana(vision_tower_cfg, **kwargs):
    """Load the Sana diffusion transformer in bfloat16.

    Args:
        vision_tower_cfg: Config object whose ``diffusion_name_or_path``
            attribute is handed to ``from_pretrained`` as the checkpoint
            location (presumably a Hub repo id or local path — verify
            against callers).
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The ``SanaTransformer2DModel`` loaded from the ``"transformer"``
        subfolder of that checkpoint, with ``torch_dtype=torch.bfloat16``.
    """
    checkpoint = vision_tower_cfg.diffusion_name_or_path
    return SanaTransformer2DModel.from_pretrained(
        checkpoint,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )


def build_vae(vision_tower_cfg, **kwargs):
    """Load the deep-compression autoencoder (``AutoencoderDC``) in bfloat16.

    Args:
        vision_tower_cfg: Config object whose ``diffusion_name_or_path``
            attribute is handed to ``from_pretrained`` as the checkpoint
            location (presumably the same checkpoint used for the
            transformer — verify against callers).
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The ``AutoencoderDC`` loaded from the ``"vae"`` subfolder of that
        checkpoint, with ``torch_dtype=torch.bfloat16``.
    """
    checkpoint = vision_tower_cfg.diffusion_name_or_path
    return AutoencoderDC.from_pretrained(
        checkpoint,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )