import torch
from torch import nn, einsum
from torch.nn import functional as F


class VectorQuantize(nn.Module):
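    """VQ-VAE-style vector quantizer: project inputs to the codebook dimension,
    snap each spatial position to its nearest code, and train with codebook +
    commitment losses plus a straight-through gradient estimator."""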
    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost

        # 1x1 conv maps encoder features to the codebook dimension.
        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
    def forward(self, z):
        B, C, H, W = z.shape

        z_e = self.proj(z)
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # Squared L2 distance to every code: ||x||^2 - 2 x.e + ||e||^2.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)  # index of the nearest code
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        # Commitment term (encoder -> codes) plus codebook term (codes -> encoder).
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
            + (z_q - z_e.detach()).pow(2).mean()

        # Straight-through estimator: gradients bypass the discrete lookup.
        z_q = z_e + (z_q - z_e).detach()

        return z_q, diff, embed_ind
    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class VectorQuantizeEMA(nn.Module):
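    """Vector quantizer with exponential-moving-average (EMA) codebook updates:
    codes are maintained as running means of the encoder outputs assigned to
    them, so only the commitment loss contributes to the returned objective."""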
    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1,
                 decay=0.99,
                 eps=1e-5,
                 pre_proj=True,
                 training_loc=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.training_loc = training_loc

        self.pre_proj = pre_proj
        if self.pre_proj:
            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

        # Running statistics for the EMA codebook update.
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", self.embed.weight.data.clone())
        self.decay = decay
        self.eps = eps
    def forward(self, z):
        B, C, H, W = z.shape

        if self.pre_proj:
            z_e = self.proj(z)
        else:
            z_e = z
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # Squared L2 distance to every code: ||x||^2 - 2 x.e + ||e||^2.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        if self.training and self.training_loc:
            # Standard VQ-VAE EMA codebook update (`training_loc` is assumed to
            # gate it): track assignment counts and assigned-feature sums, then
            # refit each code as a Laplace-smoothed running mean.
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = embed_onehot.t() @ flatten
            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            self.embed.weight.data.copy_(self.embed_avg / cluster_size.unsqueeze(1))

        # Only the commitment loss is returned; the codebook is updated via EMA.
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()

        # Straight-through estimator.
        z_q = z_e + (z_q - z_e).detach()

        return z_q, diff, embed_ind
    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class GumbelQuantize(nn.Module):
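    """Gumbel-softmax quantizer (taming-transformers style): predicts logits
    over the codebook, samples a soft or hard one-hot assignment per position,
    and adds a KL penalty toward the uniform prior over codes."""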
    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1,
                 straight_through=True,
                 kl_weight=5e-4,
                 temp_init=1.,
                 eps=1e-5):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.kl_weight = kl_weight
        self.temperature = temp_init
        self.eps = eps

        # 1x1 conv predicts one logit per codebook entry (not an embedding).
        self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
        self.straight_through = straight_through
    def forward(self, z, temp=None):
        # Use soft assignments only while training; always hard at eval time.
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        B, C, H, W = z.shape
        z_e = self.proj(z)  # (B, n_embed, H, W) logits

        # Sample a (relaxed) one-hot code assignment per spatial position.
        soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # KL of the assignment posterior against the uniform prior over codes.
        qy = F.softmax(z_e, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()

        embed_ind = soft_one_hot.argmax(dim=1)
        z_q = z_q.permute(0, 2, 3, 1)  # (B, H, W, C), matching the other quantizers

        return z_q, diff, embed_ind
    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
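

# Minimal usage sketch: push a random feature map through each quantizer and
# check the returned shapes. The dimensions below are illustrative assumptions,
# not values taken from any particular config.
if __name__ == "__main__":
    B, C, H, W = 2, 16, 8, 8
    z = torch.randn(B, C, H, W)

    for quantizer in (
        VectorQuantize(hidden_dim=C, embedding_dim=4, n_embed=32),
        VectorQuantizeEMA(hidden_dim=C, embedding_dim=4, n_embed=32),
        GumbelQuantize(hidden_dim=C, embedding_dim=4, n_embed=32),
    ):
        z_q, diff, embed_ind = quantizer(z)
        # Each returns channel-last features (B, H, W, embedding_dim), a scalar
        # auxiliary loss, and per-position code indices of shape (B, H, W).
        print(type(quantizer).__name__, tuple(z_q.shape), float(diff), tuple(embed_ind.shape))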