Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import copy | |
import timm | |
from torch.nn import Parameter | |
from src.utils.no_grad import no_grad | |
from typing import Callable, Iterator, Tuple | |
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from torchvision.transforms import Normalize | |
from src.diffusion.base.training import * | |
from src.diffusion.base.scheduling import BaseScheduler | |
def inverse_sigma(alpha, sigma):
    """Return the loss weight ``1 / sigma**2``.

    ``alpha`` is accepted so the function shares the ``(alpha, sigma)``
    signature of the other weighting functions, but it is not used.
    """
    sigma_squared = sigma ** 2
    return 1 / sigma_squared
def snr(alpha, sigma):
    """Return the loss weight ``alpha / sigma``."""
    ratio = alpha / sigma
    return ratio
def minsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every element raised to at least ``threshold``.

    NOTE(review): despite the name, ``min=threshold`` applies a *lower* bound
    to the ratio -- confirm this is the intended weighting.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every element capped at ``threshold``.

    NOTE(review): despite the name, ``max=threshold`` applies an *upper* bound
    to the ratio -- confirm this is the intended weighting.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, max=threshold)
def constant(alpha, sigma):
    """Return a uniform loss weight of ``1``; both arguments are ignored."""
    uniform_weight = 1
    return uniform_weight
def time_shift_fn(t, timeshift=1.0):
    """Warp a timestep ``t`` by ``t / (t + (1 - t) * timeshift)``.

    With ``timeshift == 1.0`` the denominator is ``1`` and ``t`` is returned
    unchanged in value.
    """
    remaining = 1 - t
    denominator = t + remaining * timeshift
    return t / denominator
class DisperseTrainer(BaseTrainer):
    """Trainer combining a velocity-regression loss with a dispersive feature loss.

    Each train step noises the input with the scheduler's ``alpha``/``sigma``,
    regresses the network output onto ``dalpha * x + dsigma * noise``, and adds
    a regularizer computed from the output of one intermediate block of the
    network (captured through a temporary forward hook).
    """

    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        feat_loss_weight: float = 0.5,
        lognorm_t=False,
        timeshift=1.0,
        align_layer=8,
        temperature=1.0,
        *args,
        **kwargs
    ):
        """Store the training hyper-parameters.

        Args:
            scheduler: Provides ``alpha``, ``dalpha``, ``sigma`` and ``dsigma``
                as functions of the timestep.
            loss_weight_fn: Called as ``loss_weight_fn(alpha, sigma)``; its
                result multiplies the squared velocity error.
            feat_loss_weight: Multiplier on the dispersive loss in the total.
            lognorm_t: If true, base timesteps are ``sigmoid`` of standard
                normal samples; otherwise they are uniform on ``[0, 1)``.
            timeshift: Passed to ``time_shift_fn`` to warp the base timesteps.
            align_layer: 1-based index of the block whose output is captured.
            temperature: Divides the pairwise squared distances before ``exp``.
            *args: Forwarded to the base trainer.
            **kwargs: Forwarded to the base trainer.
        """
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.timeshift = timeshift
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.temperature = temperature

    def _dispersive_loss(self, features, batch_size):
        """Return the log of the summed pairwise-similarity terms of ``features``.

        Raises:
            RuntimeError: If ``features`` is empty, i.e. the forward hook never
                fired during the network call.
        """
        if not features:
            raise RuntimeError(
                "DisperseTrainer: the forward hook on block "
                f"{self.align_layer} captured no features"
            )
        total = 0.0
        for feat in features:
            # Average over dim 1 -- assumes features are (batch, tokens, dim);
            # TODO confirm against the hooked block's output.
            pooled = torch.mean(feat, dim=1, keepdim=False)
            # Pairwise squared Euclidean distance between batch entries.
            diff = pooled[None, :, :] - pooled[:, None, :]
            sq_dist = (diff ** 2).sum(dim=-1)
            similarity = torch.exp(-sq_dist / self.temperature)
            # Zero the diagonal so a sample is not compared with itself.
            off_diag = 1 - torch.eye(
                batch_size, device=sq_dist.device, dtype=sq_dist.dtype
            )
            # Normalised by the full mask size (diagonal entries included);
            # the 1e-6 keeps the argument of the log strictly positive.
            total = total + ((similarity * off_diag).sum() / off_diag.numel() + 1e-6)
        return total.log()

    def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None):
        """Run one training step and return the loss terms.

        Args:
            net: Network called as ``net(x_t, t, y)``. Must expose ``blocks``,
                either directly or under a non-``None`` ``encoder`` attribute.
            ema_net: Unused by this method.
            solver: Unused by this method.
            x: Clean input batch; its shape must have exactly four entries.
            y: Conditioning passed through to ``net``.
            metadata: Unused by this method.

        Returns:
            dict with keys ``fm_loss`` (mean weighted squared velocity error),
            ``cos_loss`` (the dispersive loss) and ``loss`` (their weighted sum).

        Raises:
            RuntimeError: If the hooked block produced no output during the
                forward pass.
        """
        # Unpacking four values also rejects inputs that are not 4-D.
        batch_size, _, _, _ = x.shape

        if self.lognorm_t:
            base_t = torch.randn(
                (batch_size), device=x.device, dtype=torch.float32
            ).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
        # Timesteps are sampled in float32 and only then cast to the input dtype.
        t = time_shift_fn(base_t, self.timeshift).to(x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []

        def forward_hook(_module, _inputs, output):
            feature = output
            if isinstance(feature, tuple):
                feature = feature[0]  # mmdit
            src_feature.append(feature)

        if getattr(net, "encoder", None) is not None:
            target_block = net.encoder.blocks[self.align_layer - 1]
        else:
            target_block = net.blocks[self.align_layer - 1]

        handle = target_block.register_forward_hook(forward_hook)
        try:
            out = net(x_t, t, y)
        finally:
            # Always detach the hook, even when the forward pass raises, so a
            # failed step cannot leave a stale hook attached to the model.
            handle.remove()

        disperse_loss = self._dispersive_loss(src_feature, batch_size)

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = (weight * (out - v_t) ** 2).mean()

        return dict(
            fm_loss=fm_loss,
            cos_loss=disperse_loss.mean(),
            loss=fm_loss + self.feat_loss_weight * disperse_loss.mean(),
        )