Spaces:
Running
on
Zero
Running
on
Zero
File size: 868 Bytes
56238f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from src.models.autoencoder.base import BaseAE
class LatentAE(BaseAE):
    """Autoencoder backed by a pretrained diffusers ``AutoencoderKL``.

    Latents returned by the encoder are multiplied by the model's configured
    ``scaling_factor``; latents handed to the decoder are divided by it
    before being passed to the underlying model.
    """

    def __init__(self, precompute=False, weight_path: str = None):
        """Load the pretrained model and cache its latent scaling factor.

        Args:
            precompute: When true, ``_impl_encode`` skips the model's encoder
                and only rescales its input (presumably because the input
                already holds encoder latents -- confirm against callers).
            weight_path: Forwarded unchanged to
                ``AutoencoderKL.from_pretrained``.
        """
        super().__init__()
        self.precompute = precompute
        self.weight_path = weight_path

        # Imported here rather than at module level, so diffusers is only
        # needed once an instance is actually created.
        from diffusers.models import AutoencoderKL

        self.model = AutoencoderKL.from_pretrained(self.weight_path)
        self.scaling_factor = self.model.config.scaling_factor

    def _impl_encode(self, x):
        """Return scaled latents for ``x``.

        In precompute mode ``x`` is only multiplied by the scaling factor;
        otherwise it is encoded by the model, a latent is sampled from the
        resulting distribution, and that sample is scaled.

        The input tensor is never modified: scaling is done out of place so
        that calling this twice on the same tensor does not scale it twice.
        """
        assert self.model is not None
        if self.precompute:
            return x.mul(self.scaling_factor)
        posterior = self.model.encode(x).latent_dist
        return posterior.sample().mul(self.scaling_factor)

    def _impl_decode(self, x):
        """Decode scaled latents ``x`` and return the model's ``sample`` output.

        The scaling is undone out of place, so the caller's latent tensor
        keeps its values and can safely be decoded or reused again.
        """
        assert self.model is not None
        unscaled = x.div(self.scaling_factor)
        return self.model.decode(unscaled).sample