import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import os
import numpy as np

class Options:
def __init__(self):
# Image dimensions
self.fine_height = 256
self.fine_width = 192
# GMM parameters
self.grid_size = 5
self.input_nc = 22 # For extractionA
self.input_nc_B = 1 # For extractionB
# TOM parameters
self.tom_input_nc = 26 # 3(agnostic) + 3(warped) + 1(mask) + 19(features)
self.tom_output_nc = 4 # 3(rendered) + 1(composite mask)
# Training settings
self.use_dropout = False
self.norm_layer = nn.BatchNorm2d

def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)

def init_weights(net, init_type='normal'):
    # Only normal initialization is implemented; fail loudly otherwise
    if init_type != 'normal':
        raise NotImplementedError(f'initialization method [{init_type}] is not supported')
    print(f'initialization method [{init_type}]')
    net.apply(weights_init_normal)

class FeatureExtraction(nn.Module):
def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
super(FeatureExtraction, self).__init__()
# Build feature extraction layers
layers = [
nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
nn.ReLU(True),
norm_layer(ngf)
]
for i in range(n_layers):
in_channels = min(2**i * ngf, 512)
out_channels = min(2**(i+1) * ngf, 512)
layers += [
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.ReLU(True),
norm_layer(out_channels)
]
        # Final processing blocks keep the channel count from the last downsampling stage
        final_nc = min(2**n_layers * ngf, 512)
        layers += [
            nn.Conv2d(final_nc, final_nc, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            norm_layer(final_nc),
            nn.Conv2d(final_nc, final_nc, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ]
self.model = nn.Sequential(*layers)
init_weights(self.model)
def forward(self, x):
return self.model(x)
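
# Shape check sketch (illustrative, not part of the original pipeline): with the
# defaults ngf=64 and n_layers=3, the extractor applies n_layers + 1 = 4 stride-2
# convolutions, downsampling a 256x192 input by 16 to a 512-channel 16x12 map.
def _demo_extraction_shape():
    net = FeatureExtraction(input_nc=22)
    x = torch.randn(1, 22, 256, 192)
    print(net(x).shape)  # expected: torch.Size([1, 512, 16, 12])
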
class FeatureL2Norm(nn.Module):
def __init__(self):
super(FeatureL2Norm, self).__init__()
def forward(self, feature):
epsilon = 1e-6
norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
return torch.div(feature, norm)

class FeatureCorrelation(nn.Module):
    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, feature_A, feature_B):
        b, c, h, w = feature_A.size()
        # Flatten both maps and correlate every location of B with every location of A
        feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)
        feature_B = feature_B.view(b, c, h * w).transpose(1, 2)
        feature_mul = torch.bmm(feature_B, feature_A)  # [b, h*w, h*w]
        # Output [b, h*w, h, w]: one correlation channel per source location
        return feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)
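
# Shape sketch (illustrative): correlating two [b, c, h, w] feature maps yields
# [b, h*w, h, w], which is exactly what FeatureRegression consumes below.
def _demo_correlation_shape():
    corr = FeatureCorrelation()
    feat_a = torch.randn(2, 512, 16, 12)
    feat_b = torch.randn(2, 512, 16, 12)
    print(corr(feat_a, feat_b).shape)  # expected: torch.Size([2, 192, 16, 12])
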
class FeatureRegression(nn.Module):
def __init__(self, input_nc=512, output_dim=6):
super(FeatureRegression, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(64 * 4 * 3, output_dim)
self.tanh = nn.Tanh()
def forward(self, x):
x = self.conv(x)
x = x.contiguous().view(x.size(0), -1)
x = self.linear(x)
return self.tanh(x)
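
# Shape sketch (illustrative): the regressor maps the [b, 192, 16, 12] correlation
# tensor to 2 * grid_size**2 = 50 TPS control-point offsets squashed into (-1, 1).
def _demo_regression_shape():
    reg = FeatureRegression(input_nc=192, output_dim=50)
    print(reg(torch.randn(2, 192, 16, 12)).shape)  # expected: torch.Size([2, 50])
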

class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, grid_size=5):
        super(TpsGridGen, self).__init__()
        self.out_h = out_h
        self.out_w = out_w
        self.grid_size = grid_size
        self.N = grid_size * grid_size
        # Regular sampling grid with X and Y coordinates in [-1, 1]
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        # Register as buffers so they move with the module across devices
        self.register_buffer('grid_X_base', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3))  # [1, H, W, 1]
        self.register_buffer('grid_Y_base', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3))  # [1, H, W, 1]
        # Regular grid of TPS control points
        axis_coords = np.linspace(-1, 1, grid_size)
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
        P_X = np.reshape(P_X, (-1, 1))  # [N, 1]
        P_Y = np.reshape(P_Y, (-1, 1))  # [N, 1]
        self.register_buffer('P_X_base', torch.FloatTensor(P_X))
        self.register_buffer('P_Y_base', torch.FloatTensor(P_Y))
        # Precompute the inverse TPS system matrix L^-1
        Li = self.compute_L_inverse(P_X, P_Y)
        self.register_buffer('Li', torch.FloatTensor(Li))

def compute_L_inverse(self, X, Y):
N = X.shape[0] # num of points (along dim 0)
# Construct matrix K
Xmat = np.tile(X, (1, N))
Ymat = np.tile(Y, (1, N))
P_dist_squared = np.power(Xmat - Xmat.T, 2) + np.power(Ymat - Ymat.T, 2)
P_dist_squared[P_dist_squared == 0] = 1 # make diagonal 1 to avoid NaN in log computation
K = P_dist_squared * np.log(P_dist_squared)
# Construct matrix L
O = np.ones((N, 1))
Z = np.zeros((3, 3))
P = np.concatenate((O, X, Y), axis=1)
L = np.concatenate((np.concatenate((K, P), axis=1),
np.concatenate((P.T, Z), axis=1)), axis=0)
Li = np.linalg.inv(L)
return Li

    def forward(self, theta):
        batch_size = theta.size(0)
        # Split theta into control-point offsets and add them to the base grid
        Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
        Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
        # Regular sampling grid to be transformed: [B, H, W, 2]
        points = torch.cat((self.grid_X_base.expand(batch_size, -1, -1, -1),
                            self.grid_Y_base.expand(batch_size, -1, -1, -1)), 3)
        # Control points broadcast over the spatial grid: [1, 1, 1, 1, N]
        P_X = self.P_X_base.view(1, 1, 1, 1, self.N)
        P_Y = self.P_Y_base.view(1, 1, 1, 1, self.N)
        # Weights for the non-linear (radial basis) part: [B, N, 1]
        W_X = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
        W_Y = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)
        # Reshape to [B, H, W, 1, N]
        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
        # Weights for the affine part: [B, 3, 1]
        A_X = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
        A_Y = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)
        # Reshape to [B, H, W, 1, 3]
        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
        # Distance from every grid point to every control point: [B, H, W, 1, N]
        points_X = points[:, :, :, 0].unsqueeze(3).unsqueeze(4)  # [B, H, W, 1, 1]
        points_Y = points[:, :, :, 1].unsqueeze(3).unsqueeze(4)  # [B, H, W, 1, 1]
        delta_X = points_X - P_X
        delta_Y = points_Y - P_Y
        # Radial basis function U(r^2) = r^2 * log(r^2)
        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
        U = dist_squared * torch.log(dist_squared)
        # Non-affine part: [B, H, W, 1]
        points_X_prime = torch.sum(torch.mul(W_X, U), dim=4)
        points_Y_prime = torch.sum(torch.mul(W_Y, U), dim=4)
        # Affine part
        points_X_batch = points[:, :, :, 0].unsqueeze(3)  # [B, H, W, 1]
        points_Y_batch = points[:, :, :, 1].unsqueeze(3)  # [B, H, W, 1]
        points_X_prime = points_X_prime + A_X[:, :, :, :, 0] \
            + torch.mul(A_X[:, :, :, :, 1], points_X_batch) \
            + torch.mul(A_X[:, :, :, :, 2], points_Y_batch)
        points_Y_prime = points_Y_prime + A_Y[:, :, :, :, 0] \
            + torch.mul(A_Y[:, :, :, :, 1], points_X_batch) \
            + torch.mul(A_Y[:, :, :, :, 2], points_Y_batch)
        # Sampling grid for F.grid_sample: [B, H, W, 2]
        return torch.cat((points_X_prime, points_Y_prime), 3)
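
# Sanity sketch (an assumption, not original code): with zero offsets the control
# points stay on their base grid, so the TPS mapping should reproduce the identity
# sampling grid up to the numerical error of inverting L.
def _demo_tps_identity():
    tps = TpsGridGen(out_h=256, out_w=192, grid_size=5)
    theta = torch.zeros(2, 2 * tps.N)  # zero control-point offsets
    grid = tps(theta)  # [2, 256, 192, 2]
    base = torch.cat((tps.grid_X_base, tps.grid_Y_base), 3).expand_as(grid)
    print((grid - base).abs().max())  # expected: close to 0
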
class GMM(nn.Module):
def __init__(self, opt=None):
super(GMM, self).__init__()
if opt is None:
opt = Options()
self.extractionA = FeatureExtraction(opt.input_nc)
self.extractionB = FeatureExtraction(opt.input_nc_B)
self.l2norm = FeatureL2Norm()
self.correlation = FeatureCorrelation()
        # Correlation of 16x12 feature maps yields h*w = 192 input channels;
        # the output is 2 * grid_size**2 TPS offsets (x and y per control point)
        self.regression = FeatureRegression(input_nc=192, output_dim=2*opt.grid_size**2)
self.gridGen = TpsGridGen(opt.fine_height, opt.fine_width, opt.grid_size)
def forward(self, inputA, inputB):
featureA = self.extractionA(inputA)
featureB = self.extractionB(inputB)
featureA = self.l2norm(featureA)
featureB = self.l2norm(featureB)
correlation = self.correlation(featureA, featureB)
theta = self.regression(correlation)
grid = self.gridGen(theta)
return grid, theta
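
# Usage sketch (an assumption based on the standard GMM/TOM pipeline; the tensor
# names are illustrative): the predicted grid is consumed by F.grid_sample to
# warp the in-shop cloth onto the person representation.
def _demo_gmm_warp():
    opt = Options()
    gmm = GMM(opt)
    agnostic = torch.randn(1, opt.input_nc, opt.fine_height, opt.fine_width)
    cloth_mask = torch.randn(1, opt.input_nc_B, opt.fine_height, opt.fine_width)
    grid, theta = gmm(agnostic, cloth_mask)  # grid: [1, 256, 192, 2]
    cloth = torch.randn(1, 3, opt.fine_height, opt.fine_width)
    warped = F.grid_sample(cloth, grid, padding_mode='border', align_corners=True)
    print(warped.shape)  # expected: torch.Size([1, 3, 256, 192])
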
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False,
norm_layer=nn.InstanceNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)

class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
norm_layer=nn.InstanceNorm2d, use_dropout=False):
super(UnetGenerator, self).__init__()
        # Build the U-Net recursively from the innermost block outwards:
        # 1 innermost + (num_downs - 5) ngf*8 blocks + 3 widening blocks
        # + 1 outermost block = num_downs downsampling stages in total
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                norm_layer=norm_layer, use_dropout=use_dropout)
unet_block = UnetSkipConnectionBlock(
ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(
ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(
ngf, ngf * 2, input_nc=None, submodule=unet_block,
norm_layer=norm_layer)
self.model = UnetSkipConnectionBlock(
output_nc, ngf, input_nc=input_nc, submodule=unet_block,
outermost=True, norm_layer=norm_layer)
def forward(self, input):
return self.model(input)
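
# Shape sketch (illustrative): num_downs=6 needs spatial sizes divisible by
# 2**6, so 256x192 bottoms out at 4x3 in the innermost block and is restored
# on the way back up.
def _demo_unet_shape():
    net = UnetGenerator(input_nc=26, output_nc=4, num_downs=6)
    print(net(torch.randn(1, 26, 256, 192)).shape)  # expected: torch.Size([1, 4, 256, 192])
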
class TOM(nn.Module):
def __init__(self, opt=None):
super(TOM, self).__init__()
if opt is None:
opt = Options()
self.unet = UnetGenerator(
input_nc=opt.tom_input_nc,
output_nc=opt.tom_output_nc,
num_downs=6,
norm_layer=nn.InstanceNorm2d
)
def forward(self, x):
output = self.unet(x)
p_rendered, m_composite = torch.split(output, [3, 1], dim=1)
p_rendered = torch.tanh(p_rendered)
m_composite = torch.sigmoid(m_composite)
return p_rendered, m_composite
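
# Usage sketch (an assumption based on the standard try-on composition; the
# tensor names are illustrative): the composite mask blends the warped cloth
# with the rendered person to form the final try-on image.
def _demo_tom_compose():
    opt = Options()
    tom = TOM(opt)
    x = torch.randn(1, opt.tom_input_nc, opt.fine_height, opt.fine_width)
    warped_cloth = torch.randn(1, 3, opt.fine_height, opt.fine_width)
    p_rendered, m_composite = tom(x)
    p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
    print(p_tryon.shape)  # expected: torch.Size([1, 3, 256, 192])
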
def save_checkpoint(model, save_path):
    save_dir = os.path.dirname(save_path)
    # Guard against bare filenames, for which dirname is an empty string
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(model.state_dict(), save_path)

def load_checkpoint(model, checkpoint_path, strict=True):
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # Remap legacy parameter names to match the current architecture
    new_state_dict = {}
    consumed_keys = set()
    for key, value in state_dict.items():
        new_key = key
        if 'gridGen' in key and 'base' not in key:
            # Map old TPS buffer names to the new *_base names
            if 'P_X' in key:
                new_key = key.replace('P_X', 'P_X_base')
            elif 'P_Y' in key:
                new_key = key.replace('P_Y', 'P_Y_base')
            elif 'grid_X' in key:
                new_key = key.replace('grid_X', 'grid_X_base')
            elif 'grid_Y' in key:
                new_key = key.replace('grid_Y', 'grid_Y_base')
        # Only keep keys that exist in the current model
        if new_key in model.state_dict():
            new_state_dict[new_key] = value
            consumed_keys.add(key)
# Add missing TPS parameters if needed
tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
'gridGen.grid_X_base', 'gridGen.grid_Y_base']
for param in tps_params:
if param not in new_state_dict and hasattr(model, 'gridGen'):
if param in model.state_dict():
print(f"Initializing missing TPS parameter: {param}")
new_state_dict[param] = model.state_dict()[param]
    # Load the remapped weights (non-strict: names were already reconciled above)
    model.load_state_dict(new_state_dict, strict=False)
    # Report anything that could not be matched
    missing = set(model.state_dict().keys()) - set(new_state_dict.keys())
    unexpected = set(state_dict.keys()) - consumed_keys
    if missing:
        print(f"Missing keys: {sorted(missing)}")
    if unexpected:
        print(f"Unexpected keys: {sorted(unexpected)}")
    if strict and missing:
        raise RuntimeError(f"Missing keys in strict mode: {sorted(missing)}")
    return model
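
# Round-trip sketch (an assumption; the path below is illustrative): save a
# freshly initialized GMM and reload it through the renaming logic above.
if __name__ == '__main__':
    model = GMM(Options())
    save_checkpoint(model, 'checkpoints/gmm_demo.pth')
    model = load_checkpoint(GMM(Options()), 'checkpoints/gmm_demo.pth')
    print('checkpoint round-trip OK')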