from einops import rearrange
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    def __init__(self, n_embed, embed_dim, l2_norm, beta, input_format='bchw'):
        super().__init__()
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.beta = beta
        assert input_format in ['bchw', 'blc']
        self.input_format = input_format

        self.embedding = nn.Embedding(n_embed, embed_dim)
        self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)
        self.bits_per_index = int(np.ceil(np.log2(n_embed)))

    def forward(self, z):
        batch = z.shape[0]
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')

        # Distance from each input vector to every codebook entry.
        if self.l2_norm:
            # Cosine similarity: the nearest code maximizes the dot product of unit vectors.
            z = F.normalize(z, dim=-1)
            z_flatten = z.reshape(-1, self.embed_dim)
            embedding_weight = F.normalize(self.embedding.weight, dim=-1)
            d = -z_flatten @ embedding_weight.t()
        else:
            # Squared Euclidean distance: ||z||^2 + ||e||^2 - 2 * z @ e^T
            z_flatten = z.reshape(-1, self.embed_dim)
            d = (
                torch.sum(z_flatten ** 2, dim=1, keepdim=True)
                + torch.sum(self.embedding.weight ** 2, dim=1)
                - 2 * z_flatten @ self.embedding.weight.t()
            )

        min_encoding_indices = torch.argmin(d.detach(), dim=1)

        if not self.training:
            used_codes = torch.unique(min_encoding_indices, return_counts=False)
        else:
            used_codes = None
        cb_usage = F.one_hot(min_encoding_indices, self.n_embed).sum(0)
        cb_entropy = self.get_entropy(cb_usage)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        if self.l2_norm:
            z_q = F.normalize(z_q, dim=-1)

        # Fix the loss scaling: the loss weight should not depend on the dimensionality
        # of the codewords, so sum over the embedding dimension instead of averaging over it.
        # loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        loss = (
            self.beta * torch.mean(((z_q.detach() - z) ** 2).sum(dim=-1))
            + torch.mean(((z_q - z.detach()) ** 2).sum(dim=-1))
        )

        # Straight-through estimator: copy gradients from z_q to z.
        z_q = z + (z_q - z).detach()

        if self.input_format == 'bchw':
            z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, {"H": cb_entropy, "used_codes": used_codes, "indices": min_encoding_indices.view(batch, -1)}

    def get_entropy(self, count, eps=1e-4):
        probs = (count + eps) / (count + eps).sum()
        H = -(probs * torch.log(probs)).sum()
        return H

    def get_codebook_entry(self, indices):
        z_q = self.embedding(indices)
        if self.l2_norm:
            z_q = F.normalize(z_q, dim=-1)
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q
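

# Usage sketch for VectorQuantizer (hyperparameters and shapes below are
# illustrative assumptions, not values taken from any specific config):
#   vq = VectorQuantizer(n_embed=1024, embed_dim=256, l2_norm=True, beta=0.25)
#   z = torch.randn(4, 256, 16, 16)          # encoder output in 'bchw' format
#   z_q, vq_loss, info = vq(z)               # z_q has the same shape as z
#   indices = info['indices']                # (4, 256) code indices, one per spatial position
#   z_dec = vq.get_codebook_entry(indices)   # (4, 256, 16, 16) embeddings recovered from indices
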

class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embed_dim, l2_norm, beta, decay=0.99, eps=1e-5,
                 random_restart=True, restart_threshold=1.0, input_format='bchw'):
        super().__init__()
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.beta = beta
        self.decay = decay
        self.eps = eps
        self.random_restart = random_restart
        self.restart_threshold = restart_threshold
        self.input_format = input_format

        self.embedding = nn.Embedding(n_embed, embed_dim)
        self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)  # TODO (yzhao): test other initialization methods
        self.register_buffer("ema_cluster_size", torch.zeros(self.n_embed))
        self.embedding_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embed_dim))
        self.embedding_avg.data.copy_(self.embedding.weight.data)

    def _tile(self, z):
        # If there are fewer input vectors than codebook entries, repeat them with a
        # small amount of noise so that random restarts can sample n_embed distinct rows.
        n_z, embedding_dim = z.shape
        if n_z < self.n_embed:
            n_repeats = (self.n_embed + n_z - 1) // n_z
            std = 0.01 / np.sqrt(embedding_dim)
            z = z.repeat(n_repeats, 1)
            z = z + torch.randn_like(z) * std
        return z

    def forward(self, z):
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')
        z_flatten = z.reshape(-1, self.embed_dim)

        # Squared Euclidean distance to every codebook entry.
        d = (
            torch.sum(z_flatten ** 2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight ** 2, dim=1)
            - 2 * z_flatten @ self.embedding.weight.t()
        )
        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.size(0), self.n_embed, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        if self.l2_norm:
            z = F.normalize(z, dim=-1)
            z_q = F.normalize(z_q, dim=-1)

        if self.training:
            # EMA update of the cluster sizes
            encodings_sum = encodings.sum(0)
            if dist.is_initialized():
                dist.all_reduce(encodings_sum)
            self.ema_cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)

            # EMA update of the embedding vectors
            dw = encodings.t() @ z_flatten
            if dist.is_initialized():
                dist.all_reduce(dw)
            self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

            # Laplace smoothing of the cluster sizes
            n = torch.sum(self.ema_cluster_size)
            weights = (self.ema_cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            self.embedding.weight.data = self.embedding_avg.data / weights.unsqueeze(1)

            if self.random_restart:
                # Re-initialize rarely used codes with randomly sampled encoder outputs.
                zz = self._tile(z_flatten)
                _k_rand = zz[torch.randperm(zz.size(0))][:self.n_embed]
                if dist.is_initialized():
                    dist.broadcast(_k_rand, 0)
                usage = (self.ema_cluster_size.view(-1, 1) > self.restart_threshold).float()
                self.embedding.weight.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Commitment loss only; the codebook itself is updated via the EMA statistics above.
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2)

        # Straight-through estimator: copy gradients from z_q to z.
        z_q = z + (z_q - z).detach()

        if self.input_format == 'bchw':
            z_q = rearrange(z_q, 'b h w c -> b c h w')
        # TODO (yzhao): monitor utility of the dictionary
        return z_q, loss, {}
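

# A minimal, self-contained smoke test for both quantizers. The shapes and
# hyperparameters below are illustrative assumptions, not values used by any
# particular model in this file.
if __name__ == '__main__':
    torch.manual_seed(0)

    z = torch.randn(2, 64, 8, 8)  # a fake batch of encoder features in 'bchw' format

    vq = VectorQuantizer(n_embed=512, embed_dim=64, l2_norm=False, beta=0.25)
    z_q, loss, info = vq(z)
    print('VQ:', z_q.shape, float(loss), 'codebook entropy:', float(info['H']))

    ema_vq = EMAVectorQuantizer(n_embed=512, embed_dim=64, l2_norm=False, beta=0.25)
    ema_vq.train()  # training-mode call also runs the EMA codebook update
    z_q, loss, _ = ema_vq(z)
    print('EMA VQ:', z_q.shape, float(loss))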