Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import json | |
import os | |
from .api import new_module | |
class HVQVAE(nn.Module):
    """Multi-level VQ-VAE with one encoder, one shared quantizer and one decoder per level.

    Level 0 is the raw encoder output; level ``i`` (``i >= 1``) is obtained by
    applying ``down_samplers[0..i-1]`` in sequence.  Every level is quantized
    with the same ``quantize`` module.

    Decoder indexing is reversed with respect to encoder levels: ``forward``
    reverses the list of quantized maps before decoding, so ``decs[k]`` receives
    encoder level ``levels - 1 - k``.
    """

    def __init__(
        self,
        levels,
        embedding_dim,
        enc_config,
        quantize_config,
        down_sampler_configs,
        dec_configs,
        codebook_scale=1.
    ):
        """Build all sub-modules from their configs via ``new_module``.

        Args:
            levels: number of hierarchy levels (``levels`` decoders and
                ``levels - 1`` down-samplers are created).
            embedding_dim: accepted for interface compatibility; not used here.
            enc_config: config for the encoder.
            quantize_config: config for the quantizer shared by all levels.
            down_sampler_configs: indexable; entries ``0..levels-2`` are used.
            dec_configs: indexable; entries ``0..levels-1`` are used.
            codebook_scale: per-level geometric weight applied to the
                quantization losses in ``forward``.
        """
        super().__init__()
        self.levels = levels
        # Sub-modules are created in the same order as before (enc, decs,
        # quantize, down_samplers) so parameter ordering is unchanged.
        self.enc = new_module(enc_config)
        self.decs = nn.ModuleList(
            [new_module(dec_configs[i]) for i in range(levels)]
        )
        self.quantize = new_module(quantize_config)
        self.down_samplers = nn.ModuleList(
            [new_module(down_sampler_configs[i]) for i in range(levels - 1)]
        )
        self.codebook_scale = codebook_scale

    def _check_level(self, l):
        """Assert that ``l`` is a valid level index for this model."""
        # Kept as an assert so the exception type seen by callers stays
        # AssertionError; the bound now follows ``self.levels`` instead of a
        # hard-coded 3-level hierarchy.
        assert 0 <= l < self.levels, (
            f"level must be in [0, {self.levels - 1}], got {l}"
        )

    def forward(self, input):
        """Encode, quantize and decode every level.

        Returns:
            ``(dec_outputs, total_diff)`` where ``dec_outputs`` is the list
            produced by ``decode`` and ``total_diff`` is
            ``sum(diffs[i] * codebook_scale ** i)`` over the levels.
        """
        quants, diffs, _ = self.encode(input)
        # Decoders are indexed in the opposite order to encoder levels.
        dec_outputs = self.decode(quants[::-1])
        total_diff = diffs[0]
        scale = 1.
        for diff in diffs[1:]:
            scale *= self.codebook_scale
            total_diff = total_diff + diff * scale
        return dec_outputs, total_diff

    def encode(self, input):
        """Quantize every level of the hierarchy.

        Returns:
            ``(quants, diffs, ids)``: three lists of length ``levels`` ordered
            from level 0 (encoder output) to the most down-sampled level.
            Each quantized map is permuted with ``(0, 3, 1, 2)``
            (presumably channels-last -> channels-first; confirm against the
            quantizer).
        """
        enc_outputs = [self.enc(input)]
        for down_sampler in self.down_samplers[:self.levels - 1]:
            enc_outputs.append(down_sampler(enc_outputs[-1]))

        quants, diffs, ids = [], [], []
        for enc_output in enc_outputs:
            quant, diff, code_ids = self.quantize(enc_output)
            quants.append(quant.permute(0, 3, 1, 2))
            diffs.append(diff)
            ids.append(code_ids)
        return quants, diffs, ids

    def decode(self, quants):
        """Run ``decs[l]`` on ``quants[l]`` for every level.

        Outputs are returned in order of descending ``l``
        (``levels - 1`` first, ``0`` last).
        """
        return [
            self.decs[l](quants[l])
            for l in range(self.levels - 1, -1, -1)
        ]

    def decode_code(self, codes):
        """Embed ``codes[l]`` with the quantizer and decode it with ``decs[l]``.

        ``codes`` is indexed like the argument of ``decode`` (decoder order),
        i.e. reversed with respect to the ``ids`` list returned by ``encode``.
        """
        quants = [
            self.quantize.embed_code(codes[l]).permute(0, 3, 1, 2)
            for l in range(self.levels)
        ]
        return self.decode(quants)

    def single_encode(self, input, l):
        """Quantize only encoder level ``l`` (after ``l`` down-sampling steps).

        Returns the quantizer's ``(quant, diff, id)`` triple unchanged; unlike
        ``encode``, ``quant`` is NOT permuted here.
        """
        self._check_level(l)
        enc_output = self.enc(input)
        for i in range(l):
            enc_output = self.down_samplers[i](enc_output)
        quant, diff, code_ids = self.quantize(enc_output)
        return quant, diff, code_ids

    def single_decode(self, quant, l):
        """Decode ``quant`` with decoder ``decs[l]`` (decoder indexing)."""
        self._check_level(l)
        return self.decs[l](quant)

    def single_decode_code(self, code, l):
        """Embed ``code`` and decode it as encoder level ``l``.

        Encoder level ``l`` is handled by ``decs[levels - 1 - l]``, matching
        the reversal done in ``forward``.
        """
        self._check_level(l)
        quant = self.quantize.embed_code(code).permute(0, 3, 1, 2)
        return self.decs[self.levels - 1 - l](quant)

    def get_last_layer(self):
        """Return ``get_last_layer()`` of the last decoder in ``decs``."""
        return self.decs[-1].get_last_layer()