File size: 885 Bytes
45e7a3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from pathlib import Path
from hyimage.common.constants import PRECISION_TO_TYPE
from .hunyuanimage_vae import HunyuanVAE2D

def load_vae(device, vae_path: str = None, vae_precision: str = None):
    """Build a ``HunyuanVAE2D`` and prepare it for inference.

    The model is instantiated from the config returned by
    ``HunyuanVAE2D.load_config(vae_path)``. If ``vae_path`` exists on disk,
    weights are read from ``<vae_path>/pytorch_model.ckpt``. When the loaded
    object has a ``"state_dict"`` entry, that entry is used as the weight
    mapping. Only entries whose key starts with ``"vae."`` are kept, and that
    leading prefix is removed before they are passed to ``load_state_dict``.

    Args:
        device: Target passed to ``vae.to(...)``. ``None`` skips the move.
        vae_path: Path handed to ``load_config`` and used as the directory
            containing ``pytorch_model.ckpt``.
        vae_precision: Key into ``PRECISION_TO_TYPE`` selecting the dtype to
            cast the model to. ``None`` skips the cast.

    Returns:
        The VAE with ``requires_grad_(False)`` applied and in eval mode.
    """
    # NOTE(review): the default ``vae_path=None`` is passed straight to
    # ``load_config`` and ``Path()``; ``Path(None)`` raises TypeError, so
    # callers presumably always supply a path -- confirm.
    prefix = "vae."

    config = HunyuanVAE2D.load_config(vae_path)
    vae = HunyuanVAE2D.from_config(config)

    if Path(vae_path).exists():
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or switch to weights_only=True once the
        # checkpoint format is confirmed to be tensor-only).
        ckpt = torch.load(Path(vae_path) / "pytorch_model.ckpt", map_location='cpu')
        if "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        # Strip only the leading prefix. str.replace would also delete any
        # later "vae." inside the key (e.g. "vae.foo_vae.weight").
        vae_ckpt = {
            k[len(prefix):]: v
            for k, v in ckpt.items()
            if k.startswith(prefix)
        }
        vae.load_state_dict(vae_ckpt)

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    vae.requires_grad_(False)

    if device is not None:
        vae = vae.to(device)

    vae.eval()
    return vae