# promptpix/model/seg_module.py
import copy
import functools
import math
from functools import partial
from typing import Optional, Union, Tuple, List, Callable, Dict, Iterable

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum, Tensor
from einops import rearrange, repeat
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
from torchvision import transforms
from torchvision.transforms import InterpolationMode
coco_category_list_check_person = [
"arm",
'person',
"man",
"woman",
"child",
"boy",
"girl",
"teenager"
]
VOC_category_list_check = {
'aeroplane':['aerop','lane'],
'bicycle':['bicycle'],
'bird':['bird'],
'boat':['boat'],
'bottle':['bottle'],
'bus':['bus'],
'car':['car'],
'cat':['cat'],
'chair':['chair'],
'cow':['cow'],
'diningtable':['table'],
'dog':['dog'],
'horse':['horse'],
'motorbike':['motorbike'],
'person':coco_category_list_check_person,
'pottedplant':['pot','plant','ted'],
'sheep':['sheep'],
'sofa':['sofa'],
'train':['train'],
    'tvmonitor':['monitor','tv']
}
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
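# Illustrative usage (a minimal sketch; the 512/240 sizes match the defaults used by
# segmodule below): MLP stacks `num_layers` Linear layers with ReLU between them.
#
#   mlp = MLP(input_dim=512, hidden_dim=512, output_dim=240, num_layers=3)
#   x = torch.randn(2, 77, 512)   # (batch, tokens, dim)
#   y = mlp(x)                    # -> (2, 77, 240)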
def resize_fn(img, size):
return transforms.Resize(size, InterpolationMode.BICUBIC)(
transforms.ToPILImage()(img))
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
def forward(self, tgt, memory, pos = None, query_pos = None):
output = tgt
for layer in self.layers:
output = layer(output, memory, pos=pos, query_pos=query_pos)
return output
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm = False,
activation="relu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.norm3 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory, pos = None, query_pos = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
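# Illustrative usage (a minimal sketch): the decoder follows a pre-norm, DETR-style
# layout (self-attention over the queries, cross-attention into the memory, then an
# FFN). nn.MultiheadAttention defaults to batch_first=False, so 3-D inputs are
# (sequence, batch, dim):
#
#   layer = TransformerDecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
#   decoder = TransformerDecoder(layer, num_layers=3)
#   tgt = torch.randn(77, 2, 512)       # query tokens
#   memory = torch.randn(256, 2, 512)   # image-patch tokens
#   out = decoder(tgt, memory)          # -> (77, 2, 512)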
# Projection of x onto y
def proj(x, y):
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
for y in ys:
x = x - proj(x, y)
return x
def power_iteration(W, u_, update=True, eps=1e-12):
# Lists holding singular vectors and values
us, vs, svs = [], [], []
for i, u in enumerate(u_):
# Run one step of the power iteration
with torch.no_grad():
v = torch.matmul(u, W)
# Run Gram-Schmidt to subtract components of all other singular vectors
v = F.normalize(gram_schmidt(v, vs), eps=eps)
# Add to the list
vs += [v]
# Update the other singular vector
u = torch.matmul(v, W.t())
# Run Gram-Schmidt to subtract components of all other singular vectors
u = F.normalize(gram_schmidt(u, us), eps=eps)
# Add to the list
us += [u]
if update:
u_[i][:] = u
# Compute this singular value and add it to the list
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
#svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
return svs, us, vs
# Spectral normalization base class
class SN(object):
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
# Number of power iterations per step
self.num_itrs = num_itrs
# Number of singular values
self.num_svs = num_svs
# Transposed?
self.transpose = transpose
# Epsilon value for avoiding divide-by-0
self.eps = eps
# Register a singular vector for each sv
for i in range(self.num_svs):
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
self.register_buffer('sv%d' % i, torch.ones(1))
# Singular vectors (u side)
@property
def u(self):
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
# Singular values;
# note that these buffers are just for logging and are not used in training.
@property
def sv(self):
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
# Compute the spectrally-normalized weight
def W_(self):
W_mat = self.weight.view(self.weight.size(0), -1)
if self.transpose:
W_mat = W_mat.t()
# Apply num_itrs power iterations
for _ in range(self.num_itrs):
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
# Update the svs
if self.training:
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
for i, sv in enumerate(svs):
self.sv[i][:] = sv
return self.weight / svs[0]
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
def __init__(self, in_features, out_features, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Linear.__init__(self, in_features, out_features, bias)
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
def forward(self, x):
return F.linear(x, self.W_(), self.bias)
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
def forward(self, x):
return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)
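# Illustrative usage (a minimal sketch): SNLinear / SNConv2d behave like their nn
# counterparts, except W_() divides the weight by its largest singular value,
# estimated with `num_itrs` steps of power iteration per forward pass
# (BigGAN-style spectral normalization). The `u%d` buffers persist the running
# singular vectors between steps.
#
#   conv = SNConv2d(48, 48, kernel_size=3, padding=1, num_svs=1, num_itrs=1, eps=1e-4)
#   x = torch.randn(2, 48, 32, 32)
#   y = conv(x)                         # -> (2, 48, 32, 32)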
class SegBlock(nn.Module):
def __init__(self, in_channels, out_channels, con_channels,
which_conv=nn.Conv2d, which_linear=None, activation=None,
upsample=None):
super(SegBlock, self).__init__()
self.in_channels, self.out_channels = in_channels, out_channels
self.which_conv, self.which_linear = which_conv, which_linear
self.activation = activation
self.upsample = upsample
self.conv1 = self.which_conv(self.in_channels, self.out_channels)
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
self.learnable_sc = in_channels != out_channels or upsample
if self.learnable_sc:
self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
self.register_buffer('stored_mean1', torch.zeros(in_channels))
self.register_buffer('stored_var1', torch.ones(in_channels))
self.register_buffer('stored_mean2', torch.zeros(out_channels))
self.register_buffer('stored_var2', torch.ones(out_channels))
self.upsample = upsample
def forward(self, x, y=None):
x = F.batch_norm(x, self.stored_mean1, self.stored_var1, None, None,
self.training, 0.1, 1e-4)
h = self.activation(x)
if self.upsample:
h = self.upsample(h)
x = self.upsample(x)
h = self.conv1(h)
h = F.batch_norm(h, self.stored_mean2, self.stored_var2, None, None,
self.training, 0.1, 1e-4)
h = self.activation(h)
h = self.conv2(h)
if self.learnable_sc:
x = self.conv_sc(x)
return h + x
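# Illustrative usage (a minimal sketch): SegBlock is a residual block using batch norm
# with buffer-stored running statistics and no affine parameters; `con_channels`,
# `which_linear` and `y` are accepted for interface compatibility but unused. When
# in/out channels differ or an upsample is given, a 1x1 shortcut conv matches the
# residual path.
#
#   block = SegBlock(in_channels=48, out_channels=48, con_channels=128,
#                    which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1),
#                    activation=nn.ReLU(inplace=True), upsample=False)
#   out = block(torch.randn(2, 48, 32, 32), y=None)   # -> (2, 48, 32, 32)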
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
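# Illustrative usage (a minimal sketch): make_coord returns the normalized centers of
# each grid cell in (-1, 1) (or the given ranges), flattened to (H*W, ndim):
#
#   make_coord((2, 2))
#   # tensor([[-0.5, -0.5],
#   #         [-0.5,  0.5],
#   #         [ 0.5, -0.5],
#   #         [ 0.5,  0.5]])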
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs).double()
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x.double() * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : False,
'input_dims' : 2,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
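# Illustrative usage (a minimal sketch): get_embedder builds a NeRF-style sinusoidal
# positional encoding over 2-D coordinates. With include_input=False the output
# dimension is input_dims * 2 * multires (a sin and a cos term per frequency):
#
#   embed, out_dim = get_embedder(multires=10)   # out_dim == 2 * 2 * 10 == 40
#   enc = embed(make_coord((16, 16)))            # (256, 2) -> (256, 40)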
def aggregate_attention(attention_store, res: int, from_where: List[str], is_cross: bool, select: int, prompts=None):
out = []
attention_maps = attention_store.get_average_attention()
num_pixels = res ** 2
for location in from_where:
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
if item.shape[1] == num_pixels:
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])
out.append(cross_maps)
out = torch.cat(out, dim=1)
return out
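# Note: aggregate_attention gathers the stored attention maps whose flattened spatial
# size matches res**2, reshapes each to (len(prompts), -1, res, res, tokens) and
# concatenates along dim 1, so the result is (len(prompts), total_maps, res, res,
# tokens). In _prepare_features below the token axis is folded into the channel axis
# ('b c h w d -> b (c d) h w'), which is where the `*77` terms in the 1x1 conv input
# sizes come from (77 = CLIP prompt length).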
class segmodule(nn.Module):
def __init__(self,
embedding_dim=512,
num_heads=8,
num_layers=3,
hidden_dim=2048,
dropout_rate=0):
super().__init__()
self.max_depth = 10
low_feature_channel = 16
mid_feature_channel = 32
high_feature_channel = 64
highest_feature_channel=128
self.low_feature_conv = nn.Sequential(
nn.Conv2d(1280*14+8*77, low_feature_channel, kernel_size=1, bias=False),
)
self.mid_feature_conv = nn.Sequential(
nn.Conv2d(16640+40*77, mid_feature_channel, kernel_size=1, bias=False),
)
self.mid_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel,
out_channels=low_feature_channel+mid_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
self.high_feature_conv = nn.Sequential(
nn.Conv2d(9600+40*77, high_feature_channel, kernel_size=1, bias=False),
)
self.high_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
out_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
self.highest_feature_conv = nn.Sequential(
nn.Conv2d((640+320*6)*2+40*77, highest_feature_channel, kernel_size=1, bias=False),
)
self.highest_feature_mix_conv = SegBlock(
in_channels=low_feature_channel+mid_feature_channel+high_feature_channel+highest_feature_channel,
out_channels=low_feature_channel+mid_feature_channel+high_feature_channel+highest_feature_channel,
con_channels=128,
which_conv=functools.partial(SNConv2d,
kernel_size=3, padding=1,
num_svs=1, num_itrs=1,
eps=1e-04),
which_linear=functools.partial(SNLinear,
num_svs=1, num_itrs=1,
eps=1e-04),
activation=nn.ReLU(inplace=True),
upsample=False,
)
feature_dim=low_feature_channel+mid_feature_channel+high_feature_channel+highest_feature_channel
query_dim=feature_dim*16
decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate)
self.transfromer_decoder = TransformerDecoder(decoder_layer, num_layers)
self.mlp = MLP(embedding_dim, embedding_dim, feature_dim, 3)
context_dim=768
self.to_k = nn.Linear(query_dim, embedding_dim, bias=False)
self.to_q = nn.Linear(context_dim, embedding_dim, bias=False)
    def forward(self, diffusion_features, controller, prompts, tokenizer, classs, text_embedding):
        # Fuse diffusion features and cross-attention maps into one feature map.
        image_feature = self._prepare_features(diffusion_features, controller, prompts, classs, tokenizer)
        final_image_feature = F.interpolate(image_feature, size=512, mode='bilinear', align_corners=False)
        b = final_image_feature.size()[0]
        # Unfold the fused map into non-overlapping 4x4 patches that serve as keys.
        patch_size = 4
        patch_number = int(image_feature.size()[2] / patch_size)
        image_feature = torch.nn.functional.unfold(image_feature, patch_size, stride=patch_size).transpose(1, 2).contiguous()
        image_feature = rearrange(image_feature, 'b n d -> (b n) d')
        text_embedding = rearrange(text_embedding, 'b n d -> (b n) d')
        # Text tokens are the queries; image patches are the keys/values of the decoder.
        q = self.to_q(text_embedding)
        k = self.to_k(image_feature)
        output_query = self.transfromer_decoder(q, k, None)
        output_query = rearrange(output_query, '(b n) d -> b n d', b=b)
        # Per-token mask embeddings, dotted with the upsampled features to get logits.
        mask_embedding = self.mlp(output_query)
        seg_result = einsum('b d h w, b n d -> b n h w', final_image_feature, mask_embedding)
        return seg_result
def _prepare_features(self, features, attention_store,prompts,classs,tokenizer, upsample='bilinear'):
self.low_feature_size = 16
self.mid_feature_size = 32
self.high_feature_size = 64
low_features = [
F.interpolate(i, size=self.low_feature_size, mode=upsample, align_corners=False) for i in features["low"]
]
low_features = torch.cat(low_features, dim=1)
mid_features = [
F.interpolate(i, size=self.mid_feature_size, mode=upsample, align_corners=False) for i in features["mid"]
]
mid_features = torch.cat(mid_features, dim=1)
high_features = [
F.interpolate(i, size=self.high_feature_size, mode=upsample, align_corners=False) for i in features["high"]
]
high_features = torch.cat(high_features, dim=1)
highest_features=torch.cat(features["highest"],dim=1)
## Attention map
from_where=("up", "down")
select = 0
tokens = tokenizer.encode(prompts[select])
decoder = tokenizer.decode
category_list_check = VOC_category_list_check[classs[0]]
# "up", "down"
attention_maps_8s = aggregate_attention(attention_store, 8, ("up", "mid", "down"), True, select,prompts=prompts)
attention_maps_16s = aggregate_attention(attention_store, 16, from_where, True, select,prompts=prompts)
attention_maps_32 = aggregate_attention(attention_store, 32, from_where, True, select,prompts=prompts)
attention_maps_64 = aggregate_attention(attention_store, 64, from_where, True, select,prompts=prompts)
attention_maps_8s = rearrange(attention_maps_8s, 'b c h w d-> b (c d) h w')
attention_maps_16s = rearrange(attention_maps_16s, 'b c h w d-> b (c d) h w')
attention_maps_32 = rearrange(attention_maps_32, 'b c h w d-> b (c d) h w')
attention_maps_64 = rearrange(attention_maps_64, 'b c h w d-> b (c d) h w')
attention_maps_8s = F.interpolate(attention_maps_8s, size=self.low_feature_size, mode=upsample, align_corners=False)
attention_maps_16s = F.interpolate(attention_maps_16s, size=self.mid_feature_size, mode=upsample, align_corners=False)
attention_maps_32 = F.interpolate(attention_maps_32, size=self.high_feature_size, mode=upsample, align_corners=False)
features_dict = {
'low': torch.cat([low_features, attention_maps_8s], dim=1) ,
'mid': torch.cat([mid_features, attention_maps_16s], dim=1) ,
'high': torch.cat([high_features, attention_maps_32], dim=1) ,
'highest':torch.cat([highest_features, attention_maps_64], dim=1) ,
}
low_feat = self.low_feature_conv(features_dict['low'])
low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
mid_feat = self.mid_feature_conv(features_dict['mid'])
mid_feat = torch.cat([low_feat, mid_feat], dim=1)
mid_feat = self.mid_feature_mix_conv(mid_feat, y=None)
mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
high_feat = self.high_feature_conv(features_dict['high'])
high_feat = torch.cat([mid_feat, high_feat], dim=1)
high_feat = self.high_feature_mix_conv(high_feat, y=None)
highest_feat=self.highest_feature_conv(features_dict['highest'])
highest_feat=torch.cat([high_feat,highest_feat],dim=1)
highest_feat=self.highest_feature_mix_conv(highest_feat,y=None)
return highest_feat
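# Shape walkthrough (a descriptive sketch, assuming 77 prompt tokens and the default
# channel sizes above):
#   _prepare_features fuses the diffusion feature pyramid (16 -> 32 -> 64 px) with the
#   token-folded cross-attention maps into a (B, 240, 64, 64) map, where
#   240 = 16 + 32 + 64 + 128.
#   forward() upsamples this map to 512x512 for the final product, unfolds the 64x64
#   version into 4x4 patches of 240*16 = 3840 dims (query_dim), projects them with
#   to_k to 512-d keys and the 768-d text embeddings with to_q to 512-d queries, runs
#   the transformer decoder, and maps each output query through the MLP to a 240-d
#   mask embedding. The closing einsum then yields segmentation logits of shape
#   (B, 77, 512, 512), one map per prompt token.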