# HF Spaces: Running on Zero (hardware tag from the Space page; kept as a comment)
from collections import defaultdict
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import gradio as gr
import pytorch_lightning as pl
from omegaconf import DictConfig

from ncut_pytorch import NCUT, kway_ncut, nystrom_ncut
from ncut_pytorch.ncut_pytorch import affinity_from_features, ncut
# NOTE(review): find_gamma_by_degree_after_fps is imported from two different
# modules; at runtime the LAST import wins (ncut_pytorch.affinity_gamma).
from ncut_pytorch.ncut_pytorch import find_gamma_by_degree_after_fps
from ncut_pytorch.affinity_gamma import find_gamma_by_degree_after_fps

from riemann_curvature_loss import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss
from riemann_curvature_loss import compute_axis_align_loss
# NOTE(review): this re-import shadows the four riemann_curvature_loss helpers
# above — the ncut_pytorch.math_utils versions are the ones actually used.
from ncut_pytorch.math_utils import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss, compute_axis_align_loss
def _kway_ncut_loss(eigvec_gt, eigvec_hat, n_eig): | |
_eigvec_gt = eigvec_gt[:, :n_eig] | |
_eigvec_hat = eigvec_hat[:, :n_eig] | |
loss = F.smooth_l1_loss(_eigvec_gt @ _eigvec_gt.T, _eigvec_hat @ _eigvec_hat.T) | |
return loss | |
def flag_space_loss(eigvec_gt, eigvec_hat, n_eig, start=4, step_mult=2):
    """Multi-scale subspace loss over a geometric schedule of eigenvector counts.

    Sums `_kway_ncut_loss` at k = start, start*step_mult, ... capped at
    min(n_eig, available columns), so coarse (few-eigenvector) structure and
    fine structure are both matched.

    Bug fix: the original overwrote the `n_eig` parameter immediately
    (`n_eig = start // step_mult`), so the caller-supplied cap was ignored and
    the schedule always ran to the full column count. It also returned an
    integer tensor on the all-zero early exit; now returns float.

    Args:
        eigvec_gt: (n, k_gt) ground-truth eigenvectors.
        eigvec_hat: (n, k_hat) predicted eigenvectors (same row count).
        n_eig: maximum number of eigenvectors to compare.
        start: first subspace width in the schedule (must be >= 1).
        step_mult: geometric growth factor (must be > 1 to terminate).

    Returns:
        Scalar float tensor; 0.0 if either input is entirely zero.
    """
    if torch.all(eigvec_gt == 0) or torch.all(eigvec_hat == 0):
        # Degenerate input (e.g. empty foreground) — contribute nothing.
        return torch.tensor(0.0, device=eigvec_gt.device)
    max_eig = min(n_eig, eigvec_gt.shape[1], eigvec_hat.shape[1])
    loss = 0
    k = start
    while True:
        # Clamp so the final (possibly overshooting) scale still lands on max_eig.
        loss = loss + _kway_ncut_loss(eigvec_gt, eigvec_hat, min(k, max_eig))
        if k >= max_eig:
            break
        k *= step_mult
    return loss
def ncut_wrapper(features, n_eig, distance='rbf', gamma=0.5):
    """Run normalized cuts on a dense affinity matrix built from `features`.

    Returns (eigenvectors, eigenvalues) of the top n_eig components.
    """
    affinity = affinity_from_features(features, distance=distance, gamma=gamma)
    eigenvectors, eigenvalues = ncut(affinity, n_eig)
    return eigenvectors, eigenvalues
def get_fg_mask(image_embeds, num_clusters=3):
    """Estimate a per-token foreground mask from ViT-style patch embeddings.

    Clusters the patch tokens with k-way normalized cuts and labels as
    foreground the cluster that owns the central patch of the grid.

    Bug fixes vs. the original: the grid indices were hard-coded to a 16x16
    grid (`[:, 8, 8]` for the center, `15` for the corners) even though `hw`
    is computed from the input length — now generalized via `hw // 2`; the
    unused `corners`/`corner_mode` computation was removed; the class-token
    column is concatenated with an explicit bool dtype instead of relying on
    `torch.cat` type promotion.

    Args:
        image_embeds: (b, l, c) or (l, c) embeddings; token 0 is assumed to be
            a class token and the remaining l-1 tokens a square hw x hw grid
            (i.e. l == hw*hw + 1) — TODO confirm against the upstream encoder.
        num_clusters: number of k-way clusters used to separate fg from bg.

    Returns:
        Bool tensor of shape (b, l); True marks foreground. The class token
        (position 0) is always foreground.
    """
    if image_embeds.dim() == 2:
        image_embeds = image_embeds.unsqueeze(0)
    b, l, c = image_embeds.shape
    hw = int(np.sqrt(l))  # grid side length; floor(sqrt) handles l = hw*hw + 1
    # Drop the class token and flatten all patches across the batch.
    patches = image_embeds[:, 1:].reshape(b * hw * hw, c)
    gamma = find_gamma_by_degree_after_fps(patches, 0.1, distance='rbf')
    # NOTE(review): device is hard-coded to 'cuda'; confirm whether NCUT accepts
    # a torch.device so this can follow image_embeds.device on CPU-only hosts.
    eigvec, eigval = NCUT(10, affinity_focal_gamma=gamma, distance='rbf', device='cuda').fit_transform(patches)
    kway_onehot = kway_ncut(eigvec[:, :num_clusters])
    kway_index = kway_onehot.argmax(dim=-1).reshape(b, hw, hw)
    # The cluster that most often owns the central patch is taken as foreground.
    center = hw // 2
    center_mode = kway_index[:, center, center].mode().values.item()
    fg_mask = (kway_index == center_mode).reshape(b, hw * hw)
    # Prepend the class token, always treated as foreground.
    cls_col = torch.ones((b, 1), dtype=torch.bool, device=fg_mask.device)
    fg_mask = torch.cat([cls_col, fg_mask], dim=1)
    return fg_mask
class MLP(nn.Module):
    """GELU multilayer perceptron: in_dim -> latent_dim x (n_layer+1) -> out_dim.

    The module tree (outer Sequential holding nested Sequential hidden blocks)
    matches the original layout, so state_dict keys are unchanged.
    """

    def __init__(self, in_dim, out_dim, n_layer=4, latent_dim=4096):
        super().__init__()
        layers = [nn.Linear(in_dim, latent_dim), nn.GELU()]
        for _ in range(n_layer):
            layers.append(nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.GELU()))
        layers.append(nn.Linear(latent_dim, out_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
class CompressionModel(pl.LightningModule):
    """Autoencoder-style compressor trained to preserve NCUT eigenvector structure.

    `compress` maps input features to a low-dimensional "mood" space and
    `uncompress` decodes to the target feature space. When `id_mapping` is
    True, a second decoder (`uncompress_dummy`) reconstructs the *input*
    features from the same code. Training combines an eigenvector-subspace
    loss with weighted reconstruction and geometric regularization terms;
    each term is enabled by a positive weight in `cfg`.
    """
    def __init__(self, cfg, gradio_progress=False, id_mapping=True):
        # cfg must provide: in_dim, out_dim, mood_dim, n_layer, latent_dim,
        # n_eig, lr, steps, and the per-term loss weights read in training_step.
        super().__init__()
        self.id_mapping = id_mapping
        self.compress = MLP(cfg.in_dim, cfg.mood_dim, cfg.n_layer, cfg.latent_dim)
        self.uncompress = MLP(cfg.mood_dim, cfg.out_dim, cfg.n_layer, cfg.latent_dim)
        if self.id_mapping:
            # Auxiliary decoder back to the input space (identity-mapping path).
            self.uncompress_dummy = MLP(cfg.mood_dim, cfg.in_dim, cfg.n_layer, cfg.latent_dim)
        self.cfg = cfg
        self.loss_history = defaultdict(list)
        self.gradio_progress = gradio_progress
        # NOTE(review): gr.Progress() is constructed even when gradio_progress
        # is False — confirm this is safe outside a Gradio event handler.
        self.progress = gr.Progress()
    def training_step(self, batch):
        """One optimization step; batch = (feats, target_feats, fg_masks)."""
        if self.gradio_progress and self.trainer.global_step % 10 == 0 and self.trainer.global_step > 0:
            # NOTE(review): indexes loss_history['recon'][-1]; raises IndexError
            # if cfg.recon_loss_fg == 0 (the list is never populated) — confirm
            # gradio_progress is only used with the fg recon loss enabled.
            self.progress(self.trainer.global_step/self.cfg.steps, desc=f"Training, loss = {self.loss_history['recon'][-1]:.4f}")
        feats = batch[0]
        target_feats = batch[1]
        fg_masks = batch[2].flatten()
        feats_compressed = self.compress(feats)
        feats_uncompressed = self.uncompress(feats_compressed)
        if self.id_mapping:
            feats_uncompressed_dummy = self.uncompress_dummy(feats_compressed)
        if self.trainer.global_step == 0:
            # Affinity gamma is calibrated once, on the first batch's fg tokens,
            # and reused for every subsequent step.
            self.gamma = find_gamma_by_degree_after_fps(feats[fg_masks], 0.1, distance='rbf')
        # Ground-truth spectrum from the raw fg features; predicted spectrum
        # from the compressed code of ALL tokens, then restricted to fg rows so
        # both eigenvector matrices have matching row counts.
        eigvec_gt, eigval_gt = ncut_wrapper(feats[fg_masks], self.cfg.n_eig, gamma=self.gamma)
        eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.cfg.n_eig, gamma=self.gamma)
        eigvec_hat = eigvec_hat[fg_masks]
        total_loss = 0
        # Each term below is added only when its cfg weight is positive; fg/bg
        # reconstruction terms are additionally skipped when the mask is empty.
        if self.cfg.eigvec_loss > 0:
            eigvec_loss = flag_space_loss(eigvec_gt, eigvec_hat, n_eig=self.cfg.n_eig)
            self.log("loss/eigvec", eigvec_loss, prog_bar=True)
            total_loss += eigvec_loss * self.cfg.eigvec_loss
            self.loss_history['eigvec'].append(eigvec_loss.item())
        if (self.cfg.recon_loss_fg > 0) and torch.any(fg_masks):
            # Foreground reconstruction against the target feature space.
            recon_loss_fg = F.smooth_l1_loss(target_feats[fg_masks], feats_uncompressed[fg_masks])
            self.log("loss/recon_fg", recon_loss_fg, prog_bar=True)
            total_loss += recon_loss_fg * self.cfg.recon_loss_fg
            self.loss_history['recon'].append(recon_loss_fg.item())
        if self.id_mapping and self.cfg.recon_loss_fg_dummy > 0 and torch.any(fg_masks):
            # Foreground reconstruction of the INPUT features (identity path).
            recon_loss_fg_dummy = F.smooth_l1_loss(feats[fg_masks], feats_uncompressed_dummy[fg_masks])
            self.log("loss/recon_fg_dummy", recon_loss_fg_dummy, prog_bar=True)
            total_loss += recon_loss_fg_dummy * self.cfg.recon_loss_fg_dummy
        if (self.cfg.recon_loss_bg > 0) and not torch.all(fg_masks):
            recon_loss_bg = F.smooth_l1_loss(target_feats[~fg_masks], feats_uncompressed[~fg_masks])
            self.log("loss/recon_bg", recon_loss_bg, prog_bar=True)
            total_loss += recon_loss_bg * self.cfg.recon_loss_bg
        if self.id_mapping and self.cfg.recon_loss_bg_dummy > 0 and not torch.all(fg_masks):
            recon_loss_bg_dummy = F.smooth_l1_loss(feats[~fg_masks], feats_uncompressed_dummy[~fg_masks])
            self.log("loss/recon_bg_dummy", recon_loss_bg_dummy, prog_bar=True)
            total_loss += recon_loss_bg_dummy * self.cfg.recon_loss_bg_dummy
        # Geometric regularizers on the compressed (mood-space) embedding.
        if self.cfg.riemann_curvature_loss > 0:
            riemann_curvature_loss = compute_riemann_curvature_loss(feats_compressed[fg_masks])
            self.log("loss/riemann_curvature", riemann_curvature_loss, prog_bar=True)
            total_loss += riemann_curvature_loss * self.cfg.riemann_curvature_loss
        if self.cfg.axis_align_loss > 0:
            axis_align_loss = compute_axis_align_loss(feats_compressed[fg_masks])
            self.log("loss/axis_align", axis_align_loss, prog_bar=True)
            total_loss += axis_align_loss * self.cfg.axis_align_loss
        if self.cfg.repulsion_loss > 0:
            repulsion_loss = compute_repulsion_loss(feats_compressed[fg_masks])
            self.log("loss/repulsion", repulsion_loss, prog_bar=True)
            total_loss += repulsion_loss * self.cfg.repulsion_loss
        if self.cfg.boundary_loss > 0:
            # Boundary loss sees ALL tokens, not just foreground.
            boundary_loss = compute_boundary_loss(feats_compressed)
            self.log("loss/boundary", boundary_loss, prog_bar=True)
            total_loss += boundary_loss * self.cfg.boundary_loss
        loss = total_loss
        self.log("loss/total", loss, prog_bar=True)
        return loss
    def configure_optimizers(self):
        # NAdam over all parameters (both decoders included when id_mapping).
        optimizer = torch.optim.NAdam(self.parameters(), lr=self.cfg.lr)
        return optimizer
class DatasetWithSimplices(torch.utils.data.Dataset):
    """Index-aligned triplet dataset yielding (input_feat, target_feat, plus_mask)."""

    def __init__(self, input_feats, target_feats, plus_masks):
        self.input_feats = input_feats
        self.target_feats = target_feats
        self.plus_masks = plus_masks

    def __len__(self):
        # All three sequences are assumed index-aligned; length follows inputs.
        return len(self.input_feats)

    def __getitem__(self, idx):
        sources = (self.input_feats, self.target_feats, self.plus_masks)
        return tuple(seq[idx] for seq in sources)
def free_memory():
    """Release Python garbage and cached CUDA memory.

    Fix vs. original: gc.collect() now runs BEFORE empty_cache(), so tensors
    that just became unreachable are destroyed first and their CUDA blocks can
    actually be returned to the driver. CUDA calls are also guarded so the
    helper is a clean no-op on CPU-only hosts.
    """
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def train_compression_model(model, cfg: DictConfig, input_feats, target_feats,
                            plus_masks=None, devices=None, compute_fg_mask=False):
    """Fit `model` on flattened per-token features with a Lightning trainer.

    Fix vs. original: `devices=[0]` was a mutable default argument; a None
    sentinel now defaults to [0] inside the body with identical behavior.

    Args:
        model: LightningModule to train (e.g. CompressionModel).
        cfg: config providing steps, grad_clip_val, batch_size.
        input_feats: (b, l, c) input features.
        target_feats: (b, l, c) reconstruction targets.
        plus_masks: optional per-token foreground mask; defaults to all-True.
        devices: GPU indices for the trainer; defaults to [0].
        compute_fg_mask: derive the mask via get_fg_mask when plus_masks is None.

    Returns:
        The fitted pl.Trainer.
    """
    if devices is None:
        devices = [0]
    free_memory()
    b, l, c = input_feats.shape
    if compute_fg_mask and plus_masks is None:
        plus_masks = get_fg_mask(input_feats)
    if plus_masks is None:
        # No mask supplied or computed: treat every token as foreground.
        plus_masks = torch.ones(b * l, dtype=torch.bool)
    plus_masks = plus_masks.flatten()
    # Collapse (b, l, c) -> (b*l, c) so each token is one training sample.
    input_feats = input_feats.flatten(end_dim=-2)
    target_feats = target_feats.flatten(end_dim=-2)
    trainer = pl.Trainer(max_steps=cfg.steps,
                         gradient_clip_val=cfg.grad_clip_val,
                         accelerator="gpu",
                         devices=devices,
                         enable_checkpointing=False,
                         )
    dataset = DatasetWithSimplices(input_feats, target_feats, plus_masks)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)
    trainer.fit(model, dataloader)
    return trainer