| | import math |
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | from torch.distributions import Independent, Normal, MultivariateNormal |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoModel, AutoModelForCausalLM |
| | from tqdm import tqdm |
| | from tqdm.notebook import tqdm as tqdm_notebook |
| |
|
| |
|
class Res(nn.Module):
    """H -> H feature block: a linear projection plus two residual ReLU branches.

    The input is first projected by ``w``; then two branches of the form
    ``relu(v(relu(u(h))))`` are added on, one after the other.
    """

    def __init__(self, H):
        super().__init__()
        # Layers are created in the order u1, u2, v1, v2, w, which fixes both
        # the state-dict key order and the RNG draws used for initialisation.
        self.u1, self.u2 = nn.Linear(H, H), nn.Linear(H, H)
        self.v1, self.v2 = nn.Linear(H, H), nn.Linear(H, H)
        self.w = nn.Linear(H, H)

    def forward(self, x):
        h = self.w(x)
        for inner, outer in ((self.u1, self.v1), (self.u2, self.v2)):
            h = h + torch.relu(outer(torch.relu(inner(h))))
        return h
| |
|
| |
|
class MLP(nn.Module):
    """Three-layer ReLU perceptron: H -> H -> H -> out (``out`` defaults to H)."""

    def __init__(self, H, out=None):
        super().__init__()
        # Any falsy ``out`` (None or 0) falls back to the hidden size.
        out_dim = out or H
        hidden = [nn.Linear(H, H), nn.ReLU(), nn.Linear(H, H), nn.ReLU()]
        self.mlp = nn.Sequential(*hidden, nn.Linear(H, out_dim))

    def forward(self, x):
        return self.mlp(x)
| |
|
| |
|
class Encoder(nn.Module):
    """Wraps a pretrained ``AutoModel`` and returns its first-position hidden state.

    ``forward`` takes ``input_ids`` and ``attention_mask`` (and, if not None,
    ``token_type_ids``), moves them to the encoder's device and returns
    ``last_hidden_state[:, 0]``.
    """

    def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        # Resize the embedding matrix to the tokenizer's size (len(tokenizer)).
        self.encoder.resize_token_embeddings(len(tokenizer))
        self.dim = self.encoder.config.hidden_size

    @property
    def device(self):
        return self.encoder.device

    def forward(self, **inputs):
        target = self.device
        model_inputs = {
            "input_ids": inputs["input_ids"].to(target),
            "attention_mask": inputs["attention_mask"].to(target),
        }
        token_type_ids = inputs.get("token_type_ids")
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids.to(target)
        hidden = self.encoder(**model_inputs).last_hidden_state
        return hidden[:, 0]
| |
|
| |
|
class PrefixDecoder(nn.Module):
    """Causal LM conditioned on a latent vector through a learned key/value prefix.

    A latent ``z`` of size ``D`` (the LM's hidden size) is mapped by ``self.mlp``
    to ``D * L * K * 2`` numbers. These are reshaped into one (key, value) pair
    per transformer layer and passed to the LM as ``past_key_values``, giving
    ``K`` extra prefix positions ahead of the real tokens.
    """

    def __init__(
        self,
        tokenizer,
        model_name_or_path="gpt2",
        prefix_length=1,
        ffn="res",
        **kwargs,
    ):
        super().__init__()
        self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        # n_embd / n_layer / n_head are GPT-2-style config attribute names.
        self.hidden_dim = D = self.decoder.config.n_embd
        self.num_layers = L = self.decoder.config.n_layer
        self.num_heads = self.decoder.config.n_head
        self.prefix_length = K = prefix_length
        # NOTE(review): lin1 is not used by any method here; it is kept so
        # existing checkpoints (state-dict keys) continue to load.
        self.lin1 = nn.Linear(D, D * 2)
        self.z_size = D * L * K * 2
        if ffn == "res":
            self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
        else:
            self.mlp = MLP(D, self.z_size)

    def get_prefix(self, z):
        """Map latents ``z`` (B, D) to a tuple of ``num_layers`` (key, value) pairs.

        Each key and value has shape
        (B, num_heads, prefix_length, hidden_dim // num_heads).
        """
        B = z.shape[0]
        D, L, H, K = (
            self.hidden_dim,
            self.num_layers,
            self.num_heads,
            self.prefix_length,
        )
        z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
        # The last axis separates keys from values; the one before it indexes layers.
        keys, vals = z_up.unbind(-1)
        return tuple(zip(keys.unbind(-1), vals.unbind(-1)))

    def forward(self, z, **inputs):
        """Run the LM on ``inputs`` with the prefix derived from ``z``.

        ``inputs`` must contain ``input_ids`` and ``attention_mask``; both are
        moved to ``z``'s device. Returns the underlying LM's output unchanged.
        """
        layers = self.get_prefix(z)
        input_ids = inputs["input_ids"].to(z.device)
        attention_mask = inputs["attention_mask"].to(z.device)
        # Prepend K always-visible prefix positions. The ones take the mask's
        # own dtype/device so the concatenation never relies on bool/int
        # type promotion.
        prefix_mask = attention_mask.new_ones((z.shape[0], self.prefix_length))
        attention_mask = torch.cat([prefix_mask, attention_mask], 1)
        return self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=layers,
        )
| |
|
| |
|
def get_inputs(
    inputs, prefix, keys=("input_ids", "attention_mask", "token_type_ids")
):
    """Select ``prefix``-namespaced entries from ``inputs`` and strip the prefix.

    For example, with ``prefix="enc_"`` the value ``inputs["enc_input_ids"]``
    is returned under the key ``"input_ids"``. Every name in ``keys`` appears
    in the result; entries absent from ``inputs`` map to ``None``.
    """
    # A tuple default avoids a mutable default argument shared across calls.
    return {k: inputs.get(f"{prefix}{k}") for k in keys}
| |
|
| |
|
class VAE(nn.Module):
    """Variational autoencoder with a diagonal Gaussian posterior over codes.

    ``forward`` returns the decoder output plus a ``kl`` entry. ``kl`` holds
    the per-example KL(q(z|x) || N(0, I)), summed over code dimensions.
    """

    def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        hidden = decoder.hidden_dim
        # Projects an encoding to interleaved (mean, log-scale) pairs.
        self.lin = nn.Linear(hidden, hidden * 2)
        self.do_sample = do_sample

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, sample=True, **inputs):
        """Encode ``enc_*`` inputs; return (code, per-example KL).

        The code is a reparameterised sample when ``sample`` is true, otherwise
        the posterior mean.
        """
        enc = self.encoder(**get_inputs(inputs, "enc_"))
        batch, dim = enc.shape
        # view(B, D, 2): output 2d is the mean and 2d+1 the log-scale of dim d.
        params = self.lin(enc).view(batch, dim, 2)
        mu, log_scale = params[..., 0], params[..., 1]
        # NOTE(review): exp(log_scale) is used directly as Normal's scale
        # (standard deviation), so this head is a log-std, not a log-variance.
        qz = Normal(mu, log_scale.exp())
        prior = Normal(torch.zeros_like(mu[0]), torch.ones_like(mu[0]))
        kl = torch.distributions.kl_divergence(qz, prior).sum(-1)
        z = qz.rsample() if sample else mu
        return z, kl

    def forward(self, **inputs):
        z, kl = self.get_z(sample=self.do_sample, **inputs)
        decoded = self.decoder(z, **get_inputs(inputs, "dec_"))
        decoded["kl"] = kl
        return decoded
| |
|
| |
|
class AAE(nn.Module):
    """Adversarial autoencoder: a discriminator pushes codes toward N(0, I).

    ``forward`` returns the decoder output augmented with:
      * ``l_rec``  - per-example summed token negative log-likelihood,
      * ``loss_d`` - discriminator loss (codes labelled 0, prior samples 1),
      * ``adv``    - encoder loss for making the discriminator output 1 on codes.
    """

    def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        width = decoder.hidden_dim
        # Discriminator: code -> probability that it was drawn from the prior.
        self.D = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        )
        self.word_drop = word_drop

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode ``enc_*`` inputs, optionally with word dropout on the mask."""
        if self.word_drop is None:
            return self.encoder(**get_inputs(inputs, "enc_")), None
        mask = inputs["enc_attention_mask"]
        # A position survives when its uniform draw exceeds word_drop.
        survive = torch.rand_like(mask.float()) > self.word_drop
        inputs["enc_attention_mask"] = mask & survive
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_adv(self, z):
        """Return (discriminator loss, adversarial loss), each of shape (B, 1)."""
        prior_sample = torch.randn_like(z)
        n = len(z)
        fake = torch.zeros(n, 1, device=z.device)
        real = torch.ones(n, 1, device=z.device)
        # Codes are detached so the discriminator loss does not reach the encoder.
        d_fake = F.binary_cross_entropy(self.D(z.detach()), fake, reduction="none")
        d_real = F.binary_cross_entropy(self.D(prior_sample), real, reduction="none")
        fool = F.binary_cross_entropy(self.D(z), real, reduction="none")
        return d_fake + d_real, fool

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        logits = out["logits"]
        _batch, _length, _vocab = logits.shape
        # Position t predicts token t + 1; padded targets contribute zero.
        next_tokens = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        token_lp = logits.log_softmax(-1)[:, :-1].gather(-1, next_tokens).squeeze(-1)
        token_lp = token_lp.masked_fill(~inputs["dec_attention_mask"][:, 1:], 0)
        out["l_rec"] = -token_lp.sum(-1)
        out["loss_d"], out["adv"] = self.loss_adv(z)
        return out
| |
|
| |
|
class AE(nn.Module):
    """Plain autoencoder trained only on reconstruction.

    ``forward`` adds ``loss_r`` (per-example token NLL) and an all-zero
    ``loss_c`` of the same shape, so it has the same output keys as the
    contrastive models.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        width = decoder.hidden_dim
        # NOTE(review): D is not read by any method of this class; kept so the
        # module's parameters / state-dict keys stay unchanged.
        self.D = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        )

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        code = self.encoder(**get_inputs(inputs, "enc_"))
        return code, None

    def step(self, **inputs):
        """Encode, decode, and attach ``loss_r``; return (code, decoder output)."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        logits = out["logits"]
        _batch, _length, _vocab = logits.shape
        # Position t predicts token t + 1; padded targets contribute zero.
        next_tokens = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        token_lp = logits.log_softmax(-1)[:, :-1].gather(-1, next_tokens).squeeze(-1)
        padding = ~inputs["dec_attention_mask"][:, 1:]
        out["loss_r"] = -token_lp.masked_fill(padding, 0).sum(-1)
        return z, out

    def forward(self, **inputs):
        _, out = self.step(**inputs)
        out["loss_c"] = torch.zeros_like(out["loss_r"])
        return out
| |
|
| |
|
class CDAE(nn.Module):
    """Denoising autoencoder with a contrastive term between clean and noised codes.

    ``forward`` encodes and decodes the batch twice. The first pass uses the
    inputs as given; the second applies word dropout to both attention masks
    (see ``do_mask``). The two reconstruction losses are summed into
    ``loss_r``, and ``loss_c`` matches each clean code to its noised
    counterpart.
    """

    def __init__(
        self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        # NOTE(review): D and _lambda are not read by any method of this class;
        # kept for interface / state-dict compatibility.
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def do_mask(self, **inputs):
        """Return a copy of ``inputs`` with word dropout on both attention masks.

        An encoder position is kept when its uniform draw exceeds ``word_drop``.
        The decoder mask reuses the same keep pattern position by position,
        extended with fresh draws when the decoder sequence is longer.
        ``inputs`` itself is never modified. With ``word_drop=None`` the copy is
        returned unchanged.
        """
        masked = dict(inputs)
        if self.word_drop is None:
            return masked
        enc_mask = inputs["enc_attention_mask"]
        keep = torch.rand_like(enc_mask.float()) > self.word_drop
        masked["enc_attention_mask"] = enc_mask & keep

        dec_mask = inputs["dec_attention_mask"]
        B, N = dec_mask.shape
        M = enc_mask.shape[1]
        if N <= M:
            dec_keep = keep[:, :N]
        else:
            extra = torch.rand((B, N - M), device=keep.device) > self.word_drop
            dec_keep = torch.cat([keep, extra], -1)
        masked["dec_attention_mask"] = dec_mask & dec_keep
        return masked

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        """Encode, decode, and attach ``loss_r``; return (code, decoder output)."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        # Position t predicts token t + 1; padded targets contribute zero.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def loss_c(self, z, z2):
        """Per-row cross-entropy of matching z[i] to z2[i] under scores -||z - z2||^2 / tau."""
        scores = -(torch.cdist(z, z2) ** 2)
        log_probs = (scores / self.tau).log_softmax(-1)
        return -torch.diagonal(log_probs)

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        # do_mask returns a new dict, so the noised inputs must be passed on
        # explicitly to reach the second pass.
        z_, out_ = self.step(**self.do_mask(**inputs))
        out["loss_r"] = out["loss_r"] + out_["loss_r"]
        out["loss_c"] = self.loss_c(z, z_)
        return out
| |
|
| |
|
def run_aae_epoch(
    model,
    batches,
    opt,
    optD,
    num_samples=1,
    lambda_adv=1.0,
    desc="",
    notebook=True,
):
    """Train an adversarial autoencoder for one pass over ``batches``.

    Each batch performs two updates. First, ``opt`` steps on the batch sum of
    ``l_rec + lambda_adv * adv``. Then ``optD`` steps on the batch sum of
    ``loss_d``.

    Args:
        model: module whose forward returns ``l_rec``, ``adv`` and ``loss_d``
            and which exposes ``device``.
        batches: iterable of dicts; only tensor values are passed to the model.
        opt: optimizer for the autoencoder objective.
        optD: optimizer for the discriminator objective.
        num_samples: unused; kept for signature compatibility.
        lambda_adv: weight of the adversarial term.
        desc: progress-bar label.
        notebook: use the notebook progress bar instead of the console one.

    Returns:
        dict mapping ``"l_rec"``, ``"adv"`` and ``"loss_d"`` to the per-example
        values of all batches concatenated along axis 0 (empty arrays when
        ``batches`` is empty).
    """
    keys = ("l_rec", "adv", "loss_d")
    losses = {k: [] for k in keys}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    # Training mode, as in the other epoch loops of this file.
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)

        loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

        # optD.zero_grad() discards the gradients the adv term left on the
        # discriminator. NOTE(review): this assumes `opt` does not also own the
        # discriminator's parameters; if it does, opt.step() above modifies
        # them in place before this backward.
        loss_d = out["loss_d"].sum()
        optD.zero_grad()
        loss_d.backward()
        optD.step()

        d = {}
        for k in keys:
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {
        k: np.concatenate(v, 0) if v else np.empty(0) for k, v in losses.items()
    }
| |
|
| |
|
class GAE(nn.Module):
    """Autoencoder with an InfoNCE-style term over cosine similarities of codes.

    ``forward`` adds ``loss_r`` (per-example token NLL) and ``loss_c`` (each
    code contrasted with the other codes in the batch).
    """

    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2=None):
        """Per-row cross-entropy of matching row i of ``z`` to row i of ``z2``.

        Scores are cosine similarities divided by ``tau``. ``z2`` defaults to
        ``z``. In that case each diagonal score is fixed, so the loss only
        pushes different codes apart.
        """
        if z2 is None:
            z2 = z
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        return -torch.diagonal(log_probs)

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        # Position t predicts token t + 1; padded targets contribute zero.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # Single view: contrast the batch's codes against themselves.
        out["loss_c"] = self.loss_c(z)
        return out
| |
|
| |
|
class CAE(nn.Module):
    """Autoencoder with a contrastive term between two encodings of one batch.

    The second encoding is computed under ``torch.no_grad()``. The two passes
    can only differ through stochastic layers in the encoder (if any are
    active). ``forward`` adds ``loss_r`` (per-example token NLL) and
    ``loss_c``.
    """

    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        code = self.encoder(**get_inputs(inputs, "enc_"))
        return code, None

    def loss_c(self, z, z2):
        """Per-row cross-entropy over cosine-similarity / tau, targeting the diagonal."""
        unit_a = F.normalize(z, dim=-1)
        unit_b = F.normalize(z2, dim=-1)
        logits = (unit_a @ unit_b.T) / self.tau
        return -logits.log_softmax(-1).diagonal()

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        with torch.no_grad():
            target_z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        logits = out["logits"]
        _batch, _length, _vocab = logits.shape
        # Position t predicts token t + 1; padded targets contribute zero.
        next_tokens = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        token_lp = logits.log_softmax(-1)[:, :-1].gather(-1, next_tokens).squeeze(-1)
        token_lp = token_lp.masked_fill(~inputs["dec_attention_mask"][:, 1:], 0)
        out["loss_r"] = -token_lp.sum(-1)
        out["loss_c"] = self.loss_c(z, target_z)
        return out
| |
|
| |
|
def run_cae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    desc="",
    notebook=True,
):
    """Train a reconstruction + contrastive model for one pass over ``batches``.

    Per batch, ``opt`` steps on the batch sum of ``loss_r + lambda_c * loss_c``.

    Args:
        model: module whose forward returns ``loss_r`` and ``loss_c`` and
            which exposes ``device``.
        batches: iterable of dicts; only tensor values are passed to the model.
        opt: optimizer.
        num_samples: unused; kept for signature compatibility.
        lambda_c: weight of the contrastive term.
        desc: progress-bar label.
        notebook: use the notebook progress bar instead of the console one.

    Returns:
        dict mapping ``"loss_r"`` and ``"loss_c"`` to the per-example values of
        all batches concatenated along axis 0 (empty arrays when ``batches`` is
        empty).
    """
    keys = ("loss_r", "loss_c")
    losses = {k: [] for k in keys}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in keys:
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {
        k: np.concatenate(v, 0) if v else np.empty(0) for k, v in losses.items()
    }
| |
|
| |
|
def batch_kl(l1, s1, l2=None, s2=None):
    """Placeholder for a batched KL divergence.

    Not implemented: every argument is ignored and the function returns None.
    """
    return None
| |
|
| |
|
class SubpopCondAE(nn.Module):
    """Conditional autoencoder whose label classes are mixtures of Gaussians.

    Each of the ``num_labels`` classes owns ``sublabels`` diagonal Gaussian
    components, ``num_labels * sublabels`` in total. Component i has mean
    ``locs[i]`` and standard deviation ``exp(log_scales[i])``. A class's score
    for a code is the log-sum-exp over its own components.
    """

    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        sublabels=4,
        tau=0.05,
        disc_loss=True,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim))
        self.num_labels = num_labels
        self.sublabels = sublabels
        self.L = num_labels * sublabels
        self.tau = tau
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        """Classify codes ``z`` (B, dim) by mixture-component log-densities.

        Requires ``inputs["label"]`` holding class indices of shape (B,).
        Returns a dict with:
          * ``loss_c``: per-example NLL,
          * ``log_probs``: class scores, (B, num_labels),
          * ``sub_log_probs``: component scores, (B, num_labels * sublabels),
          * ``acc``: per-example accuracy as floats.
        """
        # (B, 1, dim) against (L, dim) broadcasts to (B, L, dim); summing the
        # per-dimension log-densities gives each diagonal Gaussian's log-density.
        components = Normal(loc=self.locs, scale=self.log_scales.exp())
        sub_log_probs = components.log_prob(z.unsqueeze(-2)).sum(-1)
        if self.disc_loss:
            # Normalise across all components (discriminative objective).
            sub_log_probs = sub_log_probs.log_softmax(-1)
        # The view groups components label-major: index = label * sublabels + k.
        log_probs = sub_log_probs.view(
            z.shape[0], self.num_labels, self.sublabels
        ).logsumexp(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {
            "loss_c": loss,
            "log_probs": log_probs,
            "sub_log_probs": sub_log_probs,
            "acc": acc.float(),
        }

    def get_kl(self):
        """Sum over all components of KL(N(loc, diag(scale**2)) || N(0, I)).

        Uses the same ``scale = exp(log_scales)`` parameterisation as
        ``loss_c``, so the penalty regularises the distributions actually used
        for classification.
        """
        q = Normal(self.locs, self.log_scales.exp())
        p = Normal(torch.zeros_like(self.locs), torch.ones_like(self.locs))
        return torch.distributions.kl_divergence(q, p).sum()

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        # Position t predicts token t + 1; padded targets contribute zero.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # Item assignment (not .update) so this also works on ModelOutput.
        for k, v in self.loss_c(z, **inputs).items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out
| |
|
| |
|
def gaussian_prob_product(m1, s1, m2, s2, rho=1.0):
    """Probability product kernel between two diagonal Gaussians.

    Computes, in closed form,
    ``K = integral of N(x; m1, diag(s1))**rho * N(x; m2, diag(s2))**rho dx``.
    With ``rho=1`` this is the expected-likelihood kernel, equal to
    ``N(m1; m2, diag(s1 + s2))``.

    Args:
        m1, m2: mean tensors of shape (..., D); leading dims broadcast.
        s1, s2: positive diagonal variances of shape (..., D).
        rho: kernel exponent (> 0).

    Returns:
        Tensor with the broadcast leading shape (0-dim for 1-D inputs).

    The computation is done in log space and exponentiated at the end, which
    avoids overflow/underflow of the determinant products for large D.
    """
    dim = m1.shape[-1]
    p1 = 1 / s1  # precisions
    p2 = 1 / s2
    s_hat = 1 / (p1 + p2)  # (S1^-1 + S2^-1)^-1, diagonal
    m_hat = p1 * m1 + p2 * m2  # S1^-1 m1 + S2^-1 m2
    quad = (
        (m1 * p1 * m1).sum(-1)
        + (m2 * p2 * m2).sum(-1)
        - (m_hat * s_hat * m_hat).sum(-1)
    )
    log_k = (
        (1 - 2 * rho) * dim / 2 * math.log(2 * math.pi)
        - dim / 2 * math.log(rho)
        + 0.5 * torch.log(s_hat).sum(-1)
        - rho / 2 * (torch.log(s1).sum(-1) + torch.log(s2).sum(-1))
        - rho / 2 * quad
    )
    return torch.exp(log_k)
| |
|
| |
|
class CondAE(nn.Module):
    """Conditional autoencoder with one diagonal Gaussian per label in code space.

    Label i has mean ``locs[i]`` and standard deviation ``exp(log_scales[i])``.
    ``forward`` adds:
      * ``loss_r``: per-example token NLL,
      * the outputs of ``loss_c``: label classification from the Gaussians,
      * ``kl``: a shape-(1,) regulariser from ``get_kl``.
    """

    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        logdet=False,
        l2_reg=False,
        disc_loss=True,
        tau=0.05,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        self.locs = nn.Parameter(torch.randn(num_labels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
        self.num_labels = num_labels
        self.tau = tau
        self.logdet = logdet
        self.l2_reg = l2_reg
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        """Classify codes ``z`` (B, dim) by their log-density under each label.

        Requires ``inputs["label"]`` holding class indices of shape (B,).
        Returns ``loss_c`` (per-example NLL), ``log_probs`` of shape
        (B, num_labels), normalised when ``disc_loss`` is set, and ``acc``.
        """
        # (B, 1, dim) against (num_labels, dim) broadcasts to (B, num_labels, dim).
        gaussians = Normal(loc=self.locs, scale=self.log_scales.exp())
        log_probs = gaussians.log_prob(z.unsqueeze(-2)).sum(-1)
        if self.disc_loss:
            log_probs = log_probs.log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def get_kl(self):
        """KL of the label Gaussians to N(0, I), plus an optional spread penalty.

        The KL uses the same ``scale = exp(log_scales)`` as ``loss_c``.
        With ``logdet``, the log-determinant of the RBF Gram matrix of ``locs``
        is added. Otherwise, with ``l2_reg``, the log of the squared Frobenius
        norm of that matrix divided by ``num_labels`` is added.
        """
        q = Normal(self.locs, self.log_scales.exp())
        p = Normal(torch.zeros_like(self.locs), torch.ones_like(self.locs))
        kl = torch.distributions.kl_divergence(q, p).sum()
        if self.logdet:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            # NOTE(review): adding +logdet(K) to a minimised loss rewards the
            # locs moving together (det K -> 0). The l2_reg branch instead
            # rewards spreading them apart. Confirm the intended sign.
            kl = kl + torch.logdet(K)
        elif self.l2_reg:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl = kl + torch.log(
                torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
            ).sum()
        return kl

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        # Position t predicts token t + 1; padded targets contribute zero.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # Item assignment (not .update) so this also works on ModelOutput.
        for k, v in self.loss_c(z, **inputs).items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out
| |
|
| |
|
class BasicCondAE(nn.Module):
    """Conditional autoencoder with a linear classifier head on the code.

    ``forward`` adds:
      * ``loss_r``: per-example token NLL,
      * the outputs of ``loss_c``,
      * an all-zero ``kl`` of the same shape as ``loss_r``.
    These match the keys read by ``run_cond_ae_epoch``.
    """

    def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = decoder.hidden_dim
        self.linear = nn.Linear(self.dim, num_labels)
        self.num_labels = num_labels
        self.tau = tau  # not read by any method of this class

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        code = self.encoder(**get_inputs(inputs, "enc_"))
        return code, None

    def loss_c(self, z, **inputs):
        """Linear classification of codes against ``inputs["label"]``."""
        labels = inputs["label"]
        log_probs = F.log_softmax(self.linear(z), dim=-1)
        return {
            "loss_c": F.nll_loss(log_probs, labels, reduction="none"),
            "log_probs": log_probs,
            "acc": (log_probs.argmax(-1) == labels).float(),
        }

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        logits = out["logits"]
        _batch, _length, _vocab = logits.shape
        # Position t predicts token t + 1; padded targets contribute zero.
        next_tokens = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        token_lp = logits.log_softmax(-1)[:, :-1].gather(-1, next_tokens).squeeze(-1)
        token_lp = token_lp.masked_fill(~inputs["dec_attention_mask"][:, 1:], 0)
        out["loss_r"] = -token_lp.sum(-1)
        for key, value in self.loss_c(z, **inputs).items():
            out[key] = value
        out["kl"] = torch.zeros_like(out["loss_r"])
        return out
| |
|
| |
|
def run_cond_ae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    lambda_r=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    """Train a conditional autoencoder for one pass over ``batches``.

    Per batch, ``opt`` steps on
    ``sum(lambda_r * loss_r + lambda_c * loss_c) + beta * sum(kl)``.

    Args:
        model: module whose forward returns ``loss_r``, ``loss_c``, ``kl`` and
            ``acc`` and which exposes ``device``.
        batches: iterable of dicts; only tensor values are passed to the model.
        opt: optimizer.
        num_samples: unused; kept for signature compatibility.
        lambda_c, lambda_r, beta: weights of the classification,
            reconstruction and KL terms.
        desc: progress-bar label.
        notebook: use the notebook progress bar instead of the console one.

    Returns:
        dict mapping each of ``loss_r``, ``loss_c``, ``kl`` and ``acc`` to the
        values of all batches concatenated along axis 0 (empty arrays when
        ``batches`` is empty).
    """
    keys = ("loss_r", "loss_c", "kl", "acc")
    losses = {k: [] for k in keys}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (
            lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in keys:
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {
        k: np.concatenate(v, 0) if v else np.empty(0) for k, v in losses.items()
    }
| |
|
| |
|
def run_cond_ae_eval(
    model,
    batches,
    lambda_c=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    """Evaluate a conditional autoencoder over ``batches`` without gradients.

    Puts ``model`` in eval mode (and leaves it there) and records the
    ``loss_r``, ``loss_c``, ``kl`` and ``acc`` outputs of every batch.

    Args:
        model: module whose forward returns those keys and exposes ``device``.
        batches: iterable of dicts; only tensor values are passed to the model.
        lambda_c, beta: accepted for signature parity with
            ``run_cond_ae_epoch``; no combined loss is reported, so they do not
            affect the result.
        desc: progress-bar label.
        notebook: use the notebook progress bar instead of the console one.

    Returns:
        dict mapping each recorded key to the values of all batches
        concatenated along axis 0 (empty arrays when ``batches`` is empty).
    """
    keys = ("loss_r", "loss_c", "kl", "acc")
    losses = {k: [] for k in keys}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        with torch.no_grad():
            out = model(**model_inputs)
        d = {}
        for k in keys:
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {
        k: np.concatenate(v, 0) if v else np.empty(0) for k, v in losses.items()
    }
| |
|
| |
|
def generate(
    model,
    tokenizer,
    batch=None,
    z=None,
    do_sample=False,
    max_length=128,
    **kwargs,
):
    """Decode text from latent codes through the model's prefix decoder.

    Args:
        model: autoencoder exposing ``device``, ``get_z`` and a ``decoder``
            with ``prefix_length``, ``mlp``, ``get_prefix`` and an inner
            ``decoder`` that has ``generate``.
        tokenizer: provides ``bos_token_id`` and ``batch_decode``.
        batch: inputs for ``model.get_z(sample=False, ...)``; used only when
            ``z`` is None.
        z: optional codes of shape (B, D), as an array-like or tensor.
        do_sample, max_length, **kwargs: forwarded to the LM's ``generate``.

    Returns:
        One string per code. The first output token (the BOS prompt) is
        dropped and ``"<|endoftext|>"`` is removed.
    """
    decoder = model.decoder
    if z is None:
        with torch.no_grad():
            z, _ = model.get_z(sample=False, **batch)
    else:
        # Match the prefix MLP's dtype. NumPy input (e.g. from interpolate)
        # can be float64, which float32 Linear layers reject.
        dtype = next(decoder.mlp.parameters()).dtype
        z = torch.as_tensor(z, dtype=dtype, device=model.device)
    B = z.shape[0]
    K = decoder.prefix_length
    with torch.no_grad():
        layers = decoder.get_prefix(z)
    output = decoder.decoder.generate(
        input_ids=torch.tensor(
            [[tokenizer.bos_token_id]] * B, device=model.device
        ),
        # K prefix positions plus the BOS token.
        attention_mask=torch.ones((B, K + 1), device=model.device),
        # NOTE(review): newer transformers versions expect `past_key_values`
        # here instead of `past`; confirm against the installed version.
        past=layers,
        do_sample=do_sample,
        max_length=max_length,
        **kwargs,
    )
    texts = tokenizer.batch_decode(output[:, 1:])
    return [text.replace("<|endoftext|>", "") for text in texts]
| |
|
| |
|
def get_embeddings(model, batches, desc="", notebook=True):
    """Encode every batch deterministically and stack the resulting codes.

    Puts ``model`` in eval mode and calls ``model.get_z(sample=False, ...)``
    under ``torch.no_grad()``, passing only the tensor values of each batch.

    Returns:
        NumPy array of all batches' codes concatenated along axis 0.
        ``np.concatenate`` raises ValueError if ``batches`` is empty.
    """
    out = []
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        with torch.no_grad():
            model_inputs = {
                k: v.to(model.device)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor)
            }
            z, _ = model.get_z(sample=False, **model_inputs)
        out.append(z.detach().cpu().numpy())
    return np.concatenate(out, 0)
| |
|
| |
|
def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
    """Decode ``num_steps`` evenly spaced points on the segment from code a to b.

    The points run from ``a`` (weight 0) to ``b`` (weight 1) inclusive and are
    decoded with ``generate``; extra keyword arguments are forwarded to it.
    """
    weights = np.linspace(0, 1.0, num_steps)
    path = np.stack([(1 - w) * a + w * b for w in weights], 0)
    return generate(model, tokenizer, z=path, **kwargs)
| |
|