File size: 868 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from src.models.autoencoder.base import BaseAE

class LatentAE(BaseAE):
    """Autoencoder wrapper around a diffusers ``AutoencoderKL``.

    Encoding returns VAE latents multiplied by the model's configured
    ``scaling_factor``; decoding divides by that factor before running the
    VAE decoder. Neither direction modifies the tensor passed in by the
    caller.

    Args:
        precompute: If True, ``_impl_encode`` treats its input as an
            already-computed VAE latent and only applies the scaling factor
            instead of running the encoder (presumably latents cached ahead
            of time — confirm against the data pipeline).
        weight_path: Path or hub identifier passed to
            ``AutoencoderKL.from_pretrained``. Required despite the ``None``
            default.

    Raises:
        ValueError: If ``weight_path`` is None.
    """

    def __init__(self, precompute=False, weight_path: "str | None" = None):
        super().__init__()
        if weight_path is None:
            # Fail early with a clear message rather than inside from_pretrained.
            raise ValueError("LatentAE requires a weight_path for AutoencoderKL")
        self.precompute = precompute
        self.weight_path = weight_path

        # Local import: diffusers is only loaded when a LatentAE is constructed.
        from diffusers.models import AutoencoderKL
        self.model = AutoencoderKL.from_pretrained(self.weight_path)
        self.scaling_factor = self.model.config.scaling_factor

    def _impl_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return scaled latents for ``x``.

        When ``precompute`` is set, ``x`` is taken to be a latent already and
        is only multiplied by ``scaling_factor``. Otherwise ``x`` is encoded by
        the VAE, a sample is drawn from the encoder's ``latent_dist``, and that
        sample is scaled.
        """
        assert self.model is not None
        if self.precompute:
            # Out-of-place: the caller's tensor must not be rescaled as a side
            # effect (and in-place ops fail on leaf tensors requiring grad).
            return x * self.scaling_factor
        latents = self.model.encode(x).latent_dist.sample()
        # `latents` is produced by this call, so scaling it in place does not
        # touch any caller-owned data.
        return latents.mul_(self.scaling_factor)

    def _impl_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by ``scaling_factor`` and return the VAE decoder's ``.sample`` output."""
        assert self.model is not None
        # Out-of-place division so the caller's latents are left unchanged.
        return self.model.decode(x / self.scaling_factor).sample