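"""Vector quantization layers: a plain codebook lookup with a straight-through
estimator (VectorQuantizer) and an EMA-updated variant with optional random
restarts for under-used codes (EMAVectorQuantizer)."""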
from einops import rearrange
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    def __init__(self, n_embed, embed_dim, l2_norm, beta, input_format='bchw'):
        super().__init__()
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.beta = beta
        assert input_format in ['bchw', 'blc']
        self.input_format = input_format

        self.embedding = nn.Embedding(n_embed, embed_dim)
        self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)
        self.bits_per_index = int(np.ceil(np.log2(n_embed)))

    def forward(self, z):
        batch = z.shape[0]
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')

        if self.l2_norm:
            # cosine-similarity lookup: both sides are unit-norm, so minimizing
            # the negative dot product picks the nearest code
            z = F.normalize(z, dim=-1)
            z_flatten = z.reshape(-1, self.embed_dim)
            embedding_weight = F.normalize(self.embedding.weight, dim=-1)
            d = -z_flatten @ embedding_weight.t()
        else:
            # squared Euclidean distance, expanded as ||z||^2 + ||e||^2 - 2 z.e
            z_flatten = z.reshape(-1, self.embed_dim)
            d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) \
                + torch.sum(self.embedding.weight ** 2, dim=1) \
                - 2 * z_flatten @ self.embedding.weight.t()

        min_encoding_indices = torch.argmin(d.detach(), dim=1)
        if not self.training:
            used_codes = torch.unique(min_encoding_indices, return_counts=False)
        else:
            used_codes = None
        cb_usage = F.one_hot(min_encoding_indices, self.n_embed).sum(0)
        cb_entropy = self.get_entropy(cb_usage)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        if self.l2_norm:
            z_q = F.normalize(z_q, dim=-1)

        # fix the issue with loss scaling:
        # the loss weight should not depend on the codeword dimensionality
        # loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        loss = self.beta * torch.mean(((z_q.detach() - z) ** 2).sum(dim=-1)) \
            + torch.mean(((z_q - z.detach()) ** 2).sum(dim=-1))

        # straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()
        if self.input_format == 'bchw':
            z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, {"H": cb_entropy, "used_codes": used_codes, "indices": min_encoding_indices.view(batch, -1)}

    def get_entropy(self, count, eps=1e-4):
        probs = (count + eps) / (count + eps).sum()
        H = -(probs * torch.log(probs)).sum()
        return H

    def get_codebook_entry(self, indices):
        z_q = self.embedding(indices)
        if self.l2_norm:
            z_q = F.normalize(z_q, dim=-1)
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q


class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embed_dim, l2_norm, beta, decay=0.99, eps=1e-5,
                 random_restart=True, restart_threshold=1.0, input_format='bchw'):
        super().__init__()
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.beta = beta
        self.decay = decay
        self.eps = eps
        self.random_restart = random_restart
        self.restart_threshold = restart_threshold
        self.input_format = input_format

        self.embedding = nn.Embedding(n_embed, embed_dim)
        self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)  # TODO (yzhao): test other initialization methods
        # EMA statistics: per-code assignment counts and the running sum of
        # encoder outputs assigned to each code
        self.register_buffer("ema_cluster_size", torch.zeros(self.n_embed))
        self.embedding_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embed_dim))
        self.embedding_avg.data.copy_(self.embedding.weight.data)

    def _tile(self, z):
        # If the batch provides fewer vectors than codes, repeat it (with a
        # little noise) so random restart can sample n_embed candidates.
        n_z, embedding_dim = z.shape
        if n_z < self.n_embed:
            n_repeats = (self.n_embed + n_z - 1) // n_z
            std = 0.01 / np.sqrt(embedding_dim)
            z = z.repeat(n_repeats, 1)
            z = z + torch.randn_like(z) * std
        return z

    def forward(self, z):
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')
        z_flatten = z.reshape(-1, self.embed_dim)

        # squared Euclidean distance to every codebook entry
        d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) \
            + torch.sum(self.embedding.weight ** 2, dim=1) \
            - 2 * z_flatten @ self.embedding.weight.t()
        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.size(0), self.n_embed, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        if self.l2_norm:
            z = F.normalize(z, dim=-1)
            z_q = F.normalize(z_q, dim=-1)
        if self.training:
            # EMA update of the cluster sizes
            encodings_sum = encodings.sum(0)
            if dist.is_initialized():
                dist.all_reduce(encodings_sum)
            self.ema_cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)

            # EMA update of the embedding vectors
            dw = encodings.t() @ z_flatten
            if dist.is_initialized():
                dist.all_reduce(dw)
            self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

            # Laplace smoothing of the cluster sizes
            n = torch.sum(self.ema_cluster_size)
            weights = (self.ema_cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            self.embedding.weight.data = self.embedding_avg.data / weights.unsqueeze(1)

            if self.random_restart:
                # re-initialize under-used codes from (noisy copies of) the
                # current batch of encoder outputs
                zz = self._tile(z_flatten)
                _k_rand = zz[torch.randperm(zz.size(0))][:self.n_embed]
                if dist.is_initialized():
                    dist.broadcast(_k_rand, 0)
                usage = (self.ema_cluster_size.view(-1, 1) > self.restart_threshold).float()
                self.embedding.weight.data.mul_(usage).add_(_k_rand * (1 - usage))

        # only the commitment term is needed; the codebook itself is updated by EMA
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2)

        # straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()
        if self.input_format == 'bchw':
            z_q = rearrange(z_q, 'b h w c -> b c h w')
        # TODO (yzhao): monitor utility of the dictionary
        return z_q, loss, {}
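

if __name__ == "__main__":
    # Minimal usage sketch: the hyperparameters and tensor sizes below are
    # arbitrary examples, chosen only to illustrate the expected input/output
    # shapes of both quantizers with the default 'bchw' input format.
    torch.manual_seed(0)

    x = torch.randn(2, 16, 8, 8)  # (batch, embed_dim, height, width)

    vq = VectorQuantizer(n_embed=512, embed_dim=16, l2_norm=True, beta=0.25)
    z_q, vq_loss, info = vq(x)
    print(z_q.shape, vq_loss.item(), info["H"].item())  # z_q keeps the input shape

    ema_vq = EMAVectorQuantizer(n_embed=512, embed_dim=16, l2_norm=False, beta=0.25)
    z_q, ema_loss, _ = ema_vq(x)  # training mode: the codebook is updated by EMA
    print(z_q.shape, ema_loss.item())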