# MoodSpace / compression_model_mkii.py
# (HF Space page header, preserved as comments so the module parses)
# author: huzey — commit c8b08bf "update ncut"
from collections import defaultdict
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ncut_pytorch import nystrom_ncut
from ncut_pytorch.ncut_pytorch import find_gamma_by_degree_after_fps
from ncut_pytorch import NCUT, kway_ncut
from ncut_pytorch.ncut_pytorch import find_gamma_by_degree_after_fps
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from riemann_curvature_loss import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss
from riemann_curvature_loss import compute_axis_align_loss
import gradio as gr
from ncut_pytorch.ncut_pytorch import affinity_from_features, ncut
from ncut_pytorch.affinity_gamma import find_gamma_by_degree_after_fps
from ncut_pytorch.math_utils import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss, compute_axis_align_loss
def _kway_ncut_loss(eigvec_gt, eigvec_hat, n_eig):
_eigvec_gt = eigvec_gt[:, :n_eig]
_eigvec_hat = eigvec_hat[:, :n_eig]
loss = F.smooth_l1_loss(_eigvec_gt @ _eigvec_gt.T, _eigvec_hat @ _eigvec_hat.T)
return loss
def flag_space_loss(eigvec_gt, eigvec_hat, n_eig, start=4, step_mult=2):
    """Multi-scale "flag space" alignment loss between two eigenvector sets.

    Sums `_kway_ncut_loss` over a geometric schedule of subspace sizes
    (start, start*step_mult, ...). The final term is computed once the size
    exceeds the width of either matrix (slicing clamps, so it covers all
    columns), then the loop stops.

    NOTE(review): the `n_eig` parameter is immediately overwritten by the
    schedule and never used — kept for interface compatibility; confirm
    whether it was meant to cap the schedule.

    Returns a scalar tensor; zero (float) if either input is all-zero.
    """
    if torch.all(eigvec_gt == 0) or torch.all(eigvec_hat == 0):
        # Fix: return a float zero (was torch.tensor(0), int64) so the
        # degenerate path matches the dtype of the normal loss path.
        return torch.tensor(0.0, device=eigvec_gt.device)
    loss = 0
    n_eig = start // step_mult
    while True:
        n_eig *= step_mult
        loss += _kway_ncut_loss(eigvec_gt, eigvec_hat, n_eig)
        if n_eig > eigvec_gt.shape[1] or n_eig > eigvec_hat.shape[1]:
            break
    return loss
def ncut_wrapper(features, n_eig, distance='rbf', gamma=0.5):
    """Compute the top `n_eig` normalized-cut eigenvectors and eigenvalues
    of the affinity graph built from `features`."""
    affinity = affinity_from_features(features, distance=distance, gamma=gamma)
    eigenvectors, eigenvalues = ncut(affinity, n_eig)
    return eigenvectors, eigenvalues
@torch.no_grad()
def get_fg_mask(image_embeds, num_clusters=3):
    """Estimate a per-token foreground mask via k-way NCUT clustering.

    Args:
        image_embeds: (b, l, c) token embeddings, or (l, c) which is
            auto-batched. Token 0 is treated as a CLS-style token and the
            remaining l-1 tokens as a square hw x hw patch grid
            (assumes l-1 == hw*hw — TODO confirm against callers).
        num_clusters: number of k-way NCUT clusters to form.

    Returns:
        Bool tensor of shape (b, l): True for foreground tokens. The first
        token of every sample is always marked foreground.
    """
    if image_embeds.dim() == 2:
        image_embeds = image_embeds.unsqueeze(0)
    b, l, c = image_embeds.shape
    hw = int(np.sqrt(l))
    # Drop the first token and flatten all patch tokens across the batch.
    inp = image_embeds[:, 1:].reshape(b*hw*hw, c)
    gamma = find_gamma_by_degree_after_fps(inp, 0.1, distance='rbf')
    eigvec, eigval = NCUT(10, affinity_focal_gamma=gamma, distance='rbf', device='cuda').fit_transform(inp)
    kway_onehot = kway_ncut(eigvec[:, :num_clusters])
    kway_index = kway_onehot.argmax(dim=-1)
    kway_index = kway_index.reshape(b, hw, hw)
    # Foreground cluster = the most common cluster id at the grid center.
    # Generalized from hard-coded indices (8 / 15, valid only for hw == 16)
    # so the heuristic works for any grid size; the corner-mode computation
    # was removed because its result was never used.
    center = hw // 2
    centers = kway_index[:, center, center]
    center_mode = centers.mode().values.item()
    fg_mask = kway_index == center_mode
    fg_mask = fg_mask.reshape(b, hw*hw)
    # add back the first (CLS) token, always foreground
    fg_mask = torch.cat([torch.ones((b, 1), device=fg_mask.device), fg_mask], dim=1)
    fg_mask = fg_mask.bool()
    return fg_mask
class MLP(nn.Module):
    """Plain GELU MLP: in_dim -> latent_dim (x n_layer hidden blocks) -> out_dim."""

    def __init__(self, in_dim, out_dim, n_layer=4, latent_dim=4096):
        super().__init__()
        # Build the layer list imperatively; module tree matches the
        # original splatted-comprehension construction exactly.
        blocks = [nn.Linear(in_dim, latent_dim), nn.GELU()]
        for _ in range(n_layer):
            blocks.append(nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU()))
        blocks.append(nn.Linear(latent_dim, out_dim))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, x):
        return self.mlp(x)
class CompressionModel(pl.LightningModule):
    """Learns a compressed "mood" embedding of token features.

    ``compress`` projects inputs (``cfg.in_dim``) down to ``cfg.mood_dim``;
    ``uncompress`` decodes to ``cfg.out_dim`` to match the reconstruction
    targets. With ``id_mapping`` enabled, an auxiliary decoder
    ``uncompress_dummy`` reconstructs the original *input* features from the
    same compressed code. Training combines NCUT eigenvector alignment,
    foreground/background reconstruction, and geometric regularizers, each
    gated and weighted by its ``cfg`` field.
    """
    def __init__(self, cfg, gradio_progress=False, id_mapping=True):
        super().__init__()
        self.id_mapping = id_mapping
        self.compress = MLP(cfg.in_dim, cfg.mood_dim, cfg.n_layer, cfg.latent_dim)
        self.uncompress = MLP(cfg.mood_dim, cfg.out_dim, cfg.n_layer, cfg.latent_dim)
        if self.id_mapping:
            # Auxiliary decoder back to the input space.
            self.uncompress_dummy = MLP(cfg.mood_dim, cfg.in_dim, cfg.n_layer, cfg.latent_dim)
        self.cfg = cfg
        # Running history per loss name, used by the progress callback below.
        self.loss_history = defaultdict(list)
        self.gradio_progress = gradio_progress
        # NOTE(review): gr.Progress() is created even when gradio_progress is
        # False — confirm this is safe outside a Gradio request context.
        self.progress = gr.Progress()
    def training_step(self, batch):
        """One optimization step over a batch of (feats, target_feats, fg_masks)."""
        # Report progress to the Gradio UI every 10 steps.
        # NOTE(review): reads loss_history['recon'][-1], which is only
        # appended in the recon_loss_fg branch — will IndexError at step 10
        # if cfg.recon_loss_fg is disabled; confirm configs always enable it.
        if self.gradio_progress and self.trainer.global_step % 10 == 0 and self.trainer.global_step > 0:
            self.progress(self.trainer.global_step/self.cfg.steps, desc=f"Training, loss = {self.loss_history['recon'][-1]:.4f}")
        feats = batch[0]
        target_feats = batch[1]
        fg_masks = batch[2].flatten()  # bool foreground mask per token
        feats_compressed = self.compress(feats)
        feats_uncompressed = self.uncompress(feats_compressed)
        if self.id_mapping:
            feats_uncompressed_dummy = self.uncompress_dummy(feats_compressed)
        if self.trainer.global_step == 0:
            # Calibrate the RBF affinity gamma once, on the first batch's
            # foreground features; reused for every later step.
            self.gamma = find_gamma_by_degree_after_fps(feats[fg_masks], 0.1, distance='rbf')
        # Ground-truth eigenvectors from foreground input features vs.
        # eigenvectors of the compressed features (restricted to foreground
        # rows after the decomposition).
        eigvec_gt, eigval_gt = ncut_wrapper(feats[fg_masks], self.cfg.n_eig, gamma=self.gamma)
        eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.cfg.n_eig, gamma=self.gamma)
        eigvec_hat = eigvec_hat[fg_masks]
        total_loss = 0
        # Each loss below is skipped entirely when its weight is <= 0;
        # reconstruction terms additionally guard against empty fg/bg masks.
        if self.cfg.eigvec_loss > 0:
            eigvec_loss = flag_space_loss(eigvec_gt, eigvec_hat, n_eig=self.cfg.n_eig)
            self.log("loss/eigvec", eigvec_loss, prog_bar=True)
            total_loss += eigvec_loss * self.cfg.eigvec_loss
            self.loss_history['eigvec'].append(eigvec_loss.item())
        if (self.cfg.recon_loss_fg > 0) and torch.any(fg_masks):
            # Foreground reconstruction against the target features.
            recon_loss_fg = F.smooth_l1_loss(target_feats[fg_masks], feats_uncompressed[fg_masks])
            self.log("loss/recon_fg", recon_loss_fg, prog_bar=True)
            total_loss += recon_loss_fg * self.cfg.recon_loss_fg
            self.loss_history['recon'].append(recon_loss_fg.item())
        if self.id_mapping and self.cfg.recon_loss_fg_dummy > 0 and torch.any(fg_masks):
            # Foreground reconstruction of the *input* features (identity path).
            recon_loss_fg_dummy = F.smooth_l1_loss(feats[fg_masks], feats_uncompressed_dummy[fg_masks])
            self.log("loss/recon_fg_dummy", recon_loss_fg_dummy, prog_bar=True)
            total_loss += recon_loss_fg_dummy * self.cfg.recon_loss_fg_dummy
        if (self.cfg.recon_loss_bg > 0) and not torch.all(fg_masks):
            # Background reconstruction against the target features.
            recon_loss_bg = F.smooth_l1_loss(target_feats[~fg_masks], feats_uncompressed[~fg_masks])
            self.log("loss/recon_bg", recon_loss_bg, prog_bar=True)
            total_loss += recon_loss_bg * self.cfg.recon_loss_bg
        if self.id_mapping and self.cfg.recon_loss_bg_dummy > 0 and not torch.all(fg_masks):
            # Background reconstruction of the input features (identity path).
            recon_loss_bg_dummy = F.smooth_l1_loss(feats[~fg_masks], feats_uncompressed_dummy[~fg_masks])
            self.log("loss/recon_bg_dummy", recon_loss_bg_dummy, prog_bar=True)
            total_loss += recon_loss_bg_dummy * self.cfg.recon_loss_bg_dummy
        # Geometric regularizers on the compressed embedding (fg rows only,
        # except the boundary loss which sees every token).
        if self.cfg.riemann_curvature_loss > 0:
            riemann_curvature_loss = compute_riemann_curvature_loss(feats_compressed[fg_masks])
            self.log("loss/riemann_curvature", riemann_curvature_loss, prog_bar=True)
            total_loss += riemann_curvature_loss * self.cfg.riemann_curvature_loss
        if self.cfg.axis_align_loss > 0:
            axis_align_loss = compute_axis_align_loss(feats_compressed[fg_masks])
            self.log("loss/axis_align", axis_align_loss, prog_bar=True)
            total_loss += axis_align_loss * self.cfg.axis_align_loss
        if self.cfg.repulsion_loss > 0:
            repulsion_loss = compute_repulsion_loss(feats_compressed[fg_masks])
            self.log("loss/repulsion", repulsion_loss, prog_bar=True)
            total_loss += repulsion_loss * self.cfg.repulsion_loss
        if self.cfg.boundary_loss > 0:
            boundary_loss = compute_boundary_loss(feats_compressed)
            self.log("loss/boundary", boundary_loss, prog_bar=True)
            total_loss += boundary_loss * self.cfg.boundary_loss
        loss = total_loss
        self.log("loss/total", loss, prog_bar=True)
        return loss
    def configure_optimizers(self):
        """Single NAdam optimizer over all parameters, lr from cfg."""
        optimizer = torch.optim.NAdam(self.parameters(), lr=self.cfg.lr)
        return optimizer
class DatasetWithSimplices(torch.utils.data.Dataset):
    """Indexable triplet dataset of (input feature, target feature, mask)."""

    def __init__(self, input_feats, target_feats, plus_masks):
        self.input_feats = input_feats
        self.target_feats = target_feats
        self.plus_masks = plus_masks

    def __len__(self):
        return len(self.input_feats)

    def __getitem__(self, idx):
        triplet = (self.input_feats[idx], self.target_feats[idx], self.plus_masks[idx])
        return triplet
def free_memory():
    """Collect Python garbage and release cached GPU memory.

    Runs the Python GC first so tensors held only by garbage cycles are
    actually freed before CUDA's caching allocator returns blocks to the
    driver — the original order could miss memory still held by garbage.
    CUDA calls are guarded so this helper is a no-op on CPU-only machines
    (torch.cuda.ipc_collect lazily initializes CUDA and would raise).
    """
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def train_compression_model(model, cfg: DictConfig, input_feats, target_feats,
                            plus_masks=None, devices=None, compute_fg_mask=False):
    """Fit `model` on flattened token features with a Lightning GPU trainer.

    Args:
        model: a LightningModule (e.g. CompressionModel).
        cfg: config with steps, grad_clip_val, batch_size (and model fields).
        input_feats: (b, l, c) input token features.
        target_feats: (b, l, c') reconstruction targets, flattened alongside.
        plus_masks: optional per-token bool mask; computed via get_fg_mask
            when compute_fg_mask is set, else defaults to all-True.
        devices: GPU device list for pl.Trainer; defaults to [0].
            (Fix: was a mutable default argument `devices=[0]`.)
        compute_fg_mask: derive plus_masks from input_feats when none given.

    Returns:
        The fitted pl.Trainer.
    """
    free_memory()
    if devices is None:
        devices = [0]
    b, l, c = input_feats.shape
    if compute_fg_mask and plus_masks is None:
        plus_masks = get_fg_mask(input_feats)
    if plus_masks is None:
        plus_masks = torch.ones((b*l)).bool()
    plus_masks = plus_masks.flatten()
    # Collapse the batch and token dims so the dataset indexes single tokens.
    input_feats = input_feats.flatten(end_dim=-2)
    target_feats = target_feats.flatten(end_dim=-2)
    # logger = pl.loggers.TensorBoardLogger(cfg.log_dir, name=cfg.name)
    trainer = pl.Trainer(max_steps=cfg.steps,
                         gradient_clip_val=cfg.grad_clip_val,
                         accelerator="gpu",
                         devices=devices,
                         enable_checkpointing=False,
                         # logger=logger,
                         )
    dataset = DatasetWithSimplices(input_feats, target_feats, plus_masks)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)
    trainer.fit(model, dataloader)
    return trainer