# promptpix/model/depth_module.py
import copy
import functools
import math
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, einsum
from einops import rearrange, repeat
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import fvcore.nn.weight_init as weight_init
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
class SelfAttentionLayer(nn.Module):
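    """Self-attention block with a residual connection and LayerNorm.

    Supports both post-norm (default) and pre-norm (`normalize_before=True`) variants.
    Inputs are expected in (sequence, batch, d_model) layout, as required by
    nn.MultiheadAttention with the default `batch_first=False`.
    """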
def __init__(self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, tgt_mask,
tgt_key_padding_mask, query_pos)
return self.forward_post(tgt, tgt_mask,
tgt_key_padding_mask, query_pos)
class CrossAttentionLayer(nn.Module):
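    """Cross-attention block: target queries attend to a `memory` sequence.

    Optional positional embeddings are added to the queries (`query_pos`) and keys (`pos`);
    post-norm or pre-norm is selected via `normalize_before`.
    """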
def __init__(self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)
class FFNLayer(nn.Module):
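    """Position-wise feed-forward block (Linear -> activation -> Linear) with residual and LayerNorm."""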
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
activation="relu", normalize_before=False):
super().__init__()
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm = nn.LayerNorm(d_model)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt):
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt):
tgt2 = self.norm(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt):
if self.normalize_before:
return self.forward_pre(tgt)
return self.forward_post(tgt)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def resize_fn(img, size):
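    """Resize a tensor image with bicubic interpolation, returning a PIL image."""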
return transforms.Resize(size, InterpolationMode.BICUBIC)(
transforms.ToPILImage()(img))
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class TransformerDecoder(nn.Module):
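    """Stack of `num_layers` deep-copied decoder layers applied sequentially."""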
def __init__(self, decoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
def forward(self, tgt, memory, pos = None, query_pos = None):
output = tgt
for layer in self.layers:
output = layer(output, memory, pos=pos, query_pos=query_pos)
return output
class TransformerDecoderLayer(nn.Module):
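    """Pre-norm transformer decoder layer: self-attention, cross-attention over `memory`, then an FFN.

    There is no final LayerNorm, and all norms can be disabled with `no_norm=True`.
    """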
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm = False,
activation="relu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.norm3 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory, pos = None, query_pos = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
# Projection of x onto y
def proj(x, y):
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
for y in ys:
x = x - proj(x, y)
return x
def power_iteration(W, u_, update=True, eps=1e-12):
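    """One step of power iteration per tracked singular value of W.

    `u_` holds the persistent left singular-vector estimates; Gram-Schmidt keeps the
    estimates mutually orthogonal when more than one singular value is tracked.
    """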
# Lists holding singular vectors and values
us, vs, svs = [], [], []
for i, u in enumerate(u_):
# Run one step of the power iteration
with torch.no_grad():
v = torch.matmul(u, W)
# Run Gram-Schmidt to subtract components of all other singular vectors
v = F.normalize(gram_schmidt(v, vs), eps=eps)
# Add to the list
vs += [v]
# Update the other singular vector
u = torch.matmul(v, W.t())
# Run Gram-Schmidt to subtract components of all other singular vectors
u = F.normalize(gram_schmidt(u, us), eps=eps)
# Add to the list
us += [u]
if update:
u_[i][:] = u
# Compute this singular value and add it to the list
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
#svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
return svs, us, vs
# Spectral normalization base class
class SN(object):
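    """Spectral-normalization mixin (BigGAN-style); meant to be combined with an nn.Module
    that owns `self.weight` and provides `register_buffer`."""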
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
# Number of power iterations per step
self.num_itrs = num_itrs
# Number of singular values
self.num_svs = num_svs
# Transposed?
self.transpose = transpose
# Epsilon value for avoiding divide-by-0
self.eps = eps
# Register a singular vector for each sv
for i in range(self.num_svs):
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
self.register_buffer('sv%d' % i, torch.ones(1))
# Singular vectors (u side)
@property
def u(self):
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
# Singular values;
# note that these buffers are just for logging and are not used in training.
@property
def sv(self):
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
# Compute the spectrally-normalized weight
def W_(self):
W_mat = self.weight.view(self.weight.size(0), -1)
if self.transpose:
W_mat = W_mat.t()
# Apply num_itrs power iterations
for _ in range(self.num_itrs):
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
# Update the svs
if self.training:
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
for i, sv in enumerate(svs):
self.sv[i][:] = sv
return self.weight / svs[0]
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
def __init__(self, in_features, out_features, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Linear.__init__(self, in_features, out_features, bias)
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
def forward(self, x):
return F.linear(x, self.W_(), self.bias)
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
def forward(self, x):
return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)
class SegBlock(nn.Module):
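    """BigGAN-style residual block: BN (stored statistics) -> activation -> conv, twice,
    with an optional 1x1 learnable shortcut and optional upsampling.
    (`con_channels` and `which_linear` are accepted for API compatibility but unused here.)
    """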
def __init__(self, in_channels, out_channels, con_channels,
which_conv=nn.Conv2d, which_linear=None, activation=None,
upsample=None):
super(SegBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
self.which_conv, self.which_linear = which_conv, which_linear
self.activation = activation
self.upsample = upsample
self.conv1 = self.which_conv(self.in_channels, self.out_channels)
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
self.learnable_sc = in_channels != out_channels or upsample
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
self.register_buffer('stored_mean1', torch.zeros(in_channels))
self.register_buffer('stored_var1', torch.ones(in_channels))
self.register_buffer('stored_mean2', torch.zeros(out_channels))
self.register_buffer('stored_var2', torch.ones(out_channels))
self.upsample = upsample
def forward(self, x, y=None):
x = F.batch_norm(x, self.stored_mean1, self.stored_var1, None, None,
self.training, 0.1, 1e-4)
h = self.activation(x)
if self.upsample:
h = self.upsample(h)
x = self.upsample(x)
h = self.conv1(h)
h = F.batch_norm(h, self.stored_mean2, self.stored_var2, None, None,
self.training, 0.1, 1e-4)
h = self.activation(h)
h = self.conv2(h)
if self.learnable_sc:
x = self.conv_sc(x)
return h + x
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
class Embedder:
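    """NeRF-style positional encoding: maps each input coordinate to a bank of
    periodic (sin/cos) features at multiple frequencies."""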
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs).double()
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x.double() * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
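    """Build a positional-encoding function over 2D coordinates with `multires` frequency
    bands; `i == -1` disables the encoding and returns an identity mapping."""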
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : False,
'input_dims' : 2,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
def aggregate_attention(attention_store, res: int, from_where: List[str], is_cross: bool, select: int, prompts=None):
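    """Gather averaged attention maps of spatial resolution `res` from the requested U-Net
    locations, reshaped to (len(prompts), maps, res, res, dim) and concatenated over maps."""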
out = []
attention_maps = attention_store.get_average_attention()
num_pixels = res ** 2
for location in from_where:
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
if item.shape[1] == num_pixels:
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])
out.append(cross_maps)
out = torch.cat(out, dim=1)
return out
class Depthmodule(nn.Module):
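    """Depth head that fuses multi-scale diffusion U-Net features with aggregated
    cross-attention maps and regresses a dense depth map in [0, max_depth]."""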
def __init__(self,
max_depth=80
):
super().__init__()
self.max_depth = max_depth
low_feature_channel = 256
mid_feature_channel = 256
high_feature_channel = 256
        highest_feature_channel = 256
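        # The 1x1 conv input sizes below are the channel counts of the concatenated
        # diffusion feature maps at each scale plus the flattened cross-attention maps
        # (number of collected maps x 77 prompt tokens) appended in _prepare_features.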
self.low_feature_conv = nn.Sequential(
nn.Conv2d(1280*14+8*77, low_feature_channel, kernel_size=1, bias=False),
)
self.mid_feature_conv = nn.Sequential(
nn.Conv2d(16640+40*77, mid_feature_channel, kernel_size=1, bias=False),
)
self.mid_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel,
out_channels=low_feature_channel+mid_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
self.high_feature_conv = nn.Sequential(
nn.Conv2d(9600+40*77, high_feature_channel, kernel_size=1, bias=False),
)
self.high_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
out_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
self.highest_feature_conv = nn.Sequential(
nn.Conv2d((640+320*6)*2+40*77, highest_feature_channel, kernel_size=1, bias=False),
)
self.highest_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel+high_feature_channel+highest_feature_channel,
out_channels=low_feature_channel+mid_feature_channel+high_feature_channel+highest_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
        feature_dim = low_feature_channel + mid_feature_channel + high_feature_channel + highest_feature_channel
self.last_layer_depth = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(feature_dim, 1, kernel_size=3, stride=1, padding=1))
for m in self.last_layer_depth.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, std=0.001, bias=0)
    def forward(self, diffusion_features, controller, prompts, tokenizer):
        image_feature = self._prepare_features(diffusion_features, controller, prompts, tokenizer)
        final_image_feature = F.interpolate(image_feature, size=512, mode='bilinear', align_corners=False)
        out_depth = self.last_layer_depth(final_image_feature)
        out_depth = torch.sigmoid(out_depth) * self.max_depth
        return out_depth
    def _prepare_features(self, features, attention_store, prompts, tokenizer, upsample='bilinear'):
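        """Fuse the multi-scale feature dict ('low'/'mid'/'high'/'highest') with cross-attention
        maps from `attention_store`, progressively upsampling and mixing from coarse to fine."""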
self.low_feature_size = 16
self.mid_feature_size = 32
self.high_feature_size = 64
self.final_high_feature_size = 160
low_features = [
F.interpolate(i, size=self.low_feature_size, mode=upsample, align_corners=False) for i in features["low"]
]
low_features = torch.cat(low_features, dim=1)
mid_features = [
F.interpolate(i, size=self.mid_feature_size, mode=upsample, align_corners=False) for i in features["mid"]
]
mid_features = torch.cat(mid_features, dim=1)
high_features = [
F.interpolate(i, size=self.high_feature_size, mode=upsample, align_corners=False) for i in features["high"]
]
high_features = torch.cat(high_features, dim=1)
        highest_features = torch.cat(features["highest"], dim=1)
        # Aggregate the stored cross-attention maps at each spatial resolution.
        from_where = ("up", "down")
        select = 0
        attention_maps_8s = aggregate_attention(attention_store, 8, ("up", "mid", "down"), True, select, prompts=prompts)
        attention_maps_16s = aggregate_attention(attention_store, 16, from_where, True, select, prompts=prompts)
        attention_maps_32 = aggregate_attention(attention_store, 32, from_where, True, select, prompts=prompts)
        attention_maps_64 = aggregate_attention(attention_store, 64, from_where, True, select, prompts=prompts)
        attention_maps_8s = rearrange(attention_maps_8s, 'b c h w d -> b (c d) h w')
        attention_maps_16s = rearrange(attention_maps_16s, 'b c h w d -> b (c d) h w')
        attention_maps_32 = rearrange(attention_maps_32, 'b c h w d -> b (c d) h w')
        attention_maps_64 = rearrange(attention_maps_64, 'b c h w d -> b (c d) h w')
attention_maps_8s = F.interpolate(attention_maps_8s, size=self.low_feature_size, mode=upsample, align_corners=False)
attention_maps_16s = F.interpolate(attention_maps_16s, size=self.mid_feature_size, mode=upsample, align_corners=False)
attention_maps_32 = F.interpolate(attention_maps_32, size=self.high_feature_size, mode=upsample, align_corners=False)
        features_dict = {
            'low': torch.cat([low_features, attention_maps_8s], dim=1),
            'mid': torch.cat([mid_features, attention_maps_16s], dim=1),
            'high': torch.cat([high_features, attention_maps_32], dim=1),
            'highest': torch.cat([highest_features, attention_maps_64], dim=1),
        }
low_feat = self.low_feature_conv(features_dict['low'])
low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
mid_feat = self.mid_feature_conv(features_dict['mid'])
mid_feat = torch.cat([low_feat, mid_feat], dim=1)
mid_feat = self.mid_feature_mix_conv(mid_feat, y=None)
mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
high_feat = self.high_feature_conv(features_dict['high'])
high_feat = torch.cat([mid_feat, high_feat], dim=1)
high_feat = self.high_feature_mix_conv(high_feat, y=None)
        highest_feat = self.highest_feature_conv(features_dict['highest'])
        highest_feat = torch.cat([high_feat, highest_feat], dim=1)
        highest_feat = self.highest_feature_mix_conv(highest_feat, y=None)
highest_feat = F.interpolate(highest_feat, size=self.final_high_feature_size, mode='bilinear', align_corners=False)
return highest_feat
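
# A minimal smoke test for the standalone building blocks (a sketch only; running the full
# Depthmodule additionally requires diffusion U-Net features and an attention store produced
# elsewhere in the pipeline, which are not constructed here).
if __name__ == "__main__":
    # Self-attention block expects (sequence, batch, d_model) inputs.
    attn = SelfAttentionLayer(d_model=256, nhead=8)
    queries = torch.randn(100, 2, 256)
    print(attn(queries).shape)  # torch.Size([100, 2, 256])

    # Spectrally-normalized residual block with matching channel counts (identity shortcut).
    block = SegBlock(
        in_channels=64, out_channels=64, con_channels=128,
        which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1),
        activation=nn.ReLU(inplace=True), upsample=False,
    )
    print(block(torch.randn(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])

    # Positional encoding over 2D grid-center coordinates.
    embed, embed_dim = get_embedder(multires=4)
    coords = make_coord((32, 32))  # (32*32, 2) coordinates in [-1, 1]
    print(embed(coords).shape, embed_dim)  # torch.Size([1024, 16]) 16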