Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from src.models.autoencoder.base import BaseAE | |
class LatentAE(BaseAE):
    """Autoencoder wrapper around a pretrained ``diffusers`` ``AutoencoderKL``.

    Latents returned by ``_impl_encode`` are multiplied by the VAE's
    ``config.scaling_factor``. ``_impl_decode`` divides by the same factor
    before running the VAE decoder. Neither method modifies the tensor it is
    given.

    Args:
        precompute: If True, ``_impl_encode`` skips the VAE encoder and only
            applies the scaling factor to its input. The input is presumably
            a latent that was already produced by the encoder (assumption —
            confirm against the data pipeline).
        weight_path: Path or hub id passed to
            ``AutoencoderKL.from_pretrained``. Required despite the ``None``
            default.

    Raises:
        ValueError: If ``weight_path`` is None.
    """

    def __init__(self, precompute=False, weight_path: "str | None" = None):
        super().__init__()
        if weight_path is None:
            raise ValueError(
                "LatentAE requires weight_path to load AutoencoderKL.from_pretrained"
            )
        self.precompute = precompute
        self.weight_path = weight_path
        # Local import: diffusers is only imported when a LatentAE is constructed.
        from diffusers.models import AutoencoderKL

        self.model = AutoencoderKL.from_pretrained(self.weight_path)
        self.scaling_factor = self.model.config.scaling_factor

    def _impl_encode(self, x):
        """Return scaled latents for ``x`` without modifying ``x``.

        If ``precompute`` is set, ``x`` is only multiplied by
        ``scaling_factor``. Otherwise ``x`` is run through the VAE encoder, a
        sample is drawn from the resulting ``latent_dist``, and that sample is
        scaled.
        """
        assert self.model is not None
        if self.precompute:
            # Out-of-place on purpose: an in-place ``mul_`` would rescale the
            # caller's tensor, compounding the factor if it is encoded again.
            return x * self.scaling_factor
        # ``sample()`` yields a fresh tensor, so scaling it in place is safe.
        return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)

    def _impl_decode(self, x):
        """Decode scaled latents ``x`` with the VAE decoder.

        ``x`` is divided by ``scaling_factor`` (undoing the scaling applied in
        ``_impl_encode``) before decoding. The caller's tensor is left
        unchanged. Returns the ``.sample`` attribute of the decoder output.
        """
        assert self.model is not None
        # Out-of-place on purpose: an in-place ``div_`` would silently unscale
        # the caller's latents, corrupting them for any later use.
        return self.model.decode(x / self.scaling_factor).sample