import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class Options:
    def __init__(self):
        # Image dimensions
        self.fine_height = 256
        self.fine_width = 192
        # GMM parameters
        self.grid_size = 5
        self.input_nc = 22   # For extractionA
        self.input_nc_B = 1  # For extractionB
        # TOM parameters
        self.tom_input_nc = 26  # 3 (agnostic) + 3 (warped) + 1 (mask) + 19 (features)
        self.tom_output_nc = 4  # 3 (rendered) + 1 (composite mask)
        # Training settings
        self.use_dropout = False
        self.norm_layer = nn.BatchNorm2d


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    # Only normal initialization is implemented; init_type is kept for logging.
    print(f'initialization method [{init_type}]')
    net.apply(weights_init_normal)


class FeatureExtraction(nn.Module):
    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(FeatureExtraction, self).__init__()
        # One stride-2 conv, then n_layers more, doubling channels up to 512.
        # With n_layers=3 the input is downsampled by 16 (256x192 -> 16x12).
        layers = [
            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            norm_layer(ngf)
        ]
        for i in range(n_layers):
            in_channels = min(2**i * ngf, 512)
            out_channels = min(2**(i + 1) * ngf, 512)
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.ReLU(True),
                norm_layer(out_channels)
            ]
        # Final processing blocks at the coarsest resolution
        layers += [
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            norm_layer(512),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        return self.model(x)


class FeatureL2Norm(nn.Module):
    def __init__(self):
        super(FeatureL2Norm, self).__init__()

    def forward(self, feature):
        # L2-normalize along the channel dimension
        epsilon = 1e-6
        norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
        return torch.div(feature, norm)


class FeatureCorrelation(nn.Module):
    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, feature_A, feature_B):
        # Dense correlation between every spatial location of A and B.
        # Output shape: [b, h*w, h, w]
        b, c, h, w = feature_A.size()
        feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)
        feature_B = feature_B.view(b, c, h * w).transpose(1, 2)
        feature_mul = torch.bmm(feature_B, feature_A)
        return feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)


class FeatureRegression(nn.Module):
    def __init__(self, input_nc=512, output_dim=6):
        super(FeatureRegression, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # 64 channels at 4x3 after two stride-2 convs from 16x12
        self.linear = nn.Linear(64 * 4 * 3, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.conv(x)
        x = x.contiguous().view(x.size(0), -1)
        x = self.linear(x)
        return self.tanh(x)
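# Illustrative sketch (not part of the original module): for 256x192 inputs,
# FeatureExtraction downsamples by 16 to 16x12 maps, so FeatureCorrelation
# yields h*w = 192 channels -- the input_nc that GMM passes to
# FeatureRegression further below.
def _check_correlation_shape():
    opt = Options()
    extA = FeatureExtraction(opt.input_nc)
    extB = FeatureExtraction(opt.input_nc_B)
    a = torch.rand(1, opt.input_nc, opt.fine_height, opt.fine_width)
    b = torch.rand(1, opt.input_nc_B, opt.fine_height, opt.fine_width)
    with torch.no_grad():
        corr = FeatureCorrelation()(FeatureL2Norm()(extA(a)), FeatureL2Norm()(extB(b)))
    print(corr.shape)  # expected: torch.Size([1, 192, 16, 12])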
# TPS grid generator: maps the regression output theta (control-point offsets)
# to a dense sampling grid usable with F.grid_sample.
class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, grid_size=5):
        super(TpsGridGen, self).__init__()
        self.out_h = out_h
        self.out_w = out_w
        self.grid_size = grid_size
        self.N = grid_size * grid_size

        # Sampling grid with dim-0 (Y) and dim-1 (X) coords in [-1, 1]
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)  # [1, H, W, 1]
        grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)  # [1, H, W, 1]
        self.register_buffer('grid_X_base', grid_X)
        self.register_buffer('grid_Y_base', grid_Y)

        # Regular grid of control points
        axis_coords = np.linspace(-1, 1, grid_size)
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
        P_X = np.reshape(P_X, (-1, 1))  # [N, 1]
        P_Y = np.reshape(P_Y, (-1, 1))  # [N, 1]
        self.register_buffer('P_X_base', torch.FloatTensor(P_X))
        self.register_buffer('P_Y_base', torch.FloatTensor(P_Y))

        # Precompute the inverse TPS system matrix L^-1
        Li = self.compute_L_inverse(P_X, P_Y)
        self.register_buffer('Li', torch.FloatTensor(Li))

    def compute_L_inverse(self, X, Y):
        N = X.shape[0]  # number of control points
        # Matrix K of TPS radial basis values between control points
        Xmat = np.tile(X, (1, N))
        Ymat = np.tile(Y, (1, N))
        P_dist_squared = np.power(Xmat - Xmat.T, 2) + np.power(Ymat - Ymat.T, 2)
        P_dist_squared[P_dist_squared == 0] = 1  # make diagonal 1 to avoid NaN in log
        K = P_dist_squared * np.log(P_dist_squared)
        # Full system matrix L = [[K, P], [P^T, 0]]
        O = np.ones((N, 1))
        Z = np.zeros((3, 3))
        P = np.concatenate((O, X, Y), axis=1)
        L = np.concatenate((np.concatenate((K, P), axis=1),
                            np.concatenate((P.T, Z), axis=1)), axis=0)
        return np.linalg.inv(L)

    def forward(self, theta):
        batch_size = theta.size(0)

        # theta holds the X and Y offsets of the N control points; add the
        # rest positions to get the target control-point coordinates Q
        Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
        Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

        # Base sampling grid, [B, H, W, 1] per coordinate
        points_X = self.grid_X_base.expand(batch_size, -1, -1, -1)
        points_Y = self.grid_Y_base.expand(batch_size, -1, -1, -1)

        # Control points broadcast against the sampling grid: [1, 1, 1, N]
        P_X = self.P_X_base.view(1, 1, 1, self.N)
        P_Y = self.P_Y_base.view(1, 1, 1, self.N)

        # Weights for the non-linear (TPS kernel) part: [B, N, 1] -> [B, 1, 1, N]
        W_X = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
        W_Y = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)
        W_X = W_X.view(batch_size, 1, 1, self.N)
        W_Y = W_Y.view(batch_size, 1, 1, self.N)

        # Weights for the affine part: [B, 3, 1] -> [B, 1, 1, 3]
        A_X = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
        A_Y = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)
        A_X = A_X.view(batch_size, 1, 1, 3)
        A_Y = A_Y.view(batch_size, 1, 1, 3)

        # Distance from each grid point to each control point: [B, H, W, N]
        delta_X = points_X - P_X
        delta_Y = points_Y - P_Y

        # TPS radial basis function U(r) = r^2 log(r^2)
        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
        U = dist_squared * torch.log(dist_squared)

        # Non-affine part: sum over control points -> [B, H, W, 1]
        points_X_prime = torch.sum(W_X * U, dim=3, keepdim=True)
        points_Y_prime = torch.sum(W_Y * U, dim=3, keepdim=True)

        # Affine part: a0 + a1*x + a2*y
        points_X_prime = points_X_prime + A_X[..., 0:1] + A_X[..., 1:2] * points_X + A_X[..., 2:3] * points_Y
        points_Y_prime = points_Y_prime + A_Y[..., 0:1] + A_Y[..., 1:2] * points_X + A_Y[..., 2:3] * points_Y

        # [B, H, W, 2] sampling grid for F.grid_sample
        return torch.cat((points_X_prime, points_Y_prime), 3)
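# Illustrative sketch (not part of the original module): with theta == 0 the
# control points keep their rest positions, so the TPS should reduce to (very
# nearly) the identity mapping and the generated grid should match the base
# sampling grid; warping with F.grid_sample is then a near no-op.
def _check_tps_identity():
    opt = Options()
    tps = TpsGridGen(opt.fine_height, opt.fine_width, opt.grid_size)
    theta = torch.zeros(2, 2 * opt.grid_size ** 2)  # zero control-point offsets
    grid = tps(theta)                               # [2, 256, 192, 2]
    base = torch.cat((tps.grid_X_base, tps.grid_Y_base), 3).expand_as(grid)
    print(grid.shape, (grid - base).abs().max())    # max deviation should be ~0
    image = torch.rand(2, 3, opt.fine_height, opt.fine_width)
    warped = F.grid_sample(image, grid, padding_mode='border', align_corners=True)
    print((warped - image).abs().max())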
class GMM(nn.Module):
    """Geometric Matching Module: predicts a TPS warp from two inputs."""

    def __init__(self, opt=None):
        super(GMM, self).__init__()
        if opt is None:
            opt = Options()
        self.extractionA = FeatureExtraction(opt.input_nc)
        self.extractionB = FeatureExtraction(opt.input_nc_B)
        self.l2norm = FeatureL2Norm()
        self.correlation = FeatureCorrelation()
        # 192 = 16 * 12 correlation channels for 256x192 inputs
        self.regression = FeatureRegression(input_nc=192, output_dim=2 * opt.grid_size**2)
        self.gridGen = TpsGridGen(opt.fine_height, opt.fine_width, opt.grid_size)

    def forward(self, inputA, inputB):
        featureA = self.extractionA(inputA)
        featureB = self.extractionB(inputB)
        featureA = self.l2norm(featureA)
        featureB = self.l2norm(featureB)
        correlation = self.correlation(featureA, featureB)
        theta = self.regression(correlation)
        grid = self.gridGen(theta)
        return grid, theta


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 outermost=False, innermost=False,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            # No output activation here: TOM applies tanh/sigmoid to its split
            # outputs, so the generator must emit raw values.
            up = [uprelu, upconv]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate input with the block's output
            return torch.cat([x, self.model(x)], 1)


class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # Build the U-Net from the innermost block outward
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=unet_block,
            outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        return self.model(input)
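# Illustrative sketch (not part of the original module): the TOM U-Net below
# consumes a tom_input_nc-channel tensor and returns tom_output_nc channels at
# the same 256x192 resolution; channel counts come from Options above.
def _check_unet_shapes():
    opt = Options()
    net = UnetGenerator(opt.tom_input_nc, opt.tom_output_nc, num_downs=6)
    x = torch.rand(1, opt.tom_input_nc, opt.fine_height, opt.fine_width)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 4, 256, 192])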
class TOM(nn.Module):
    """Try-On Module: U-Net that renders the person and a composition mask."""

    def __init__(self, opt=None):
        super(TOM, self).__init__()
        if opt is None:
            opt = Options()
        self.unet = UnetGenerator(
            input_nc=opt.tom_input_nc,
            output_nc=opt.tom_output_nc,
            num_downs=6,
            norm_layer=nn.InstanceNorm2d
        )

    def forward(self, x):
        output = self.unet(x)
        p_rendered, m_composite = torch.split(output, [3, 1], dim=1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        return p_rendered, m_composite


def save_checkpoint(model, save_path):
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)


def load_checkpoint(model, checkpoint_path, strict=False):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model_state = model.state_dict()

    # Build a state dict whose keys match the current model architecture
    new_state_dict = {}
    unexpected = []
    for key, value in state_dict.items():
        # Map old TPS tensor names to the registered buffer names
        new_key = key
        if 'gridGen' in key and 'base' not in key:
            for old, new in (('P_X', 'P_X_base'), ('P_Y', 'P_Y_base'),
                             ('grid_X', 'grid_X_base'), ('grid_Y', 'grid_Y_base')):
                if old in key:
                    new_key = key.replace(old, new)
                    break
        # Only include keys that exist in the current model
        if new_key in model_state:
            new_state_dict[new_key] = value
        else:
            unexpected.append(key)

    # Fill in any TPS buffers missing from the checkpoint with the freshly
    # initialized values (they are deterministic functions of the grid size)
    tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
                  'gridGen.grid_X_base', 'gridGen.grid_Y_base']
    for param in tps_params:
        if param not in new_state_dict and param in model_state:
            print(f"Initializing missing TPS parameter: {param}")
            new_state_dict[param] = model_state[param]

    # Load the remapped state dict
    model.load_state_dict(new_state_dict, strict=strict)

    # Report key mismatches
    missing = set(model_state.keys()) - set(new_state_dict.keys())
    if missing:
        print(f"Missing keys: {sorted(missing)}")
    if unexpected:
        print(f"Unexpected keys: {sorted(unexpected)}")
    return model
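# Illustrative end-to-end sketch (not part of the original module), assuming
# the default Options. The 26-channel TOM input is a random placeholder: how
# agnostic, warped cloth, mask, and feature channels are stacked is
# dataset-specific. The final composition mirrors the usual try-on blend of
# warped cloth and rendered person via the predicted mask.
if __name__ == "__main__":
    _check_correlation_shape()
    _check_tps_identity()
    _check_unet_shapes()

    opt = Options()
    gmm = GMM(opt)
    tom = TOM(opt)

    agnostic = torch.rand(1, opt.input_nc, opt.fine_height, opt.fine_width)
    cloth_mask = torch.rand(1, opt.input_nc_B, opt.fine_height, opt.fine_width)
    cloth = torch.rand(1, 3, opt.fine_height, opt.fine_width)

    with torch.no_grad():
        grid, theta = gmm(agnostic, cloth_mask)  # [1, 256, 192, 2], [1, 50]
        warped_cloth = F.grid_sample(cloth, grid, padding_mode='border', align_corners=True)
        tom_input = torch.rand(1, opt.tom_input_nc, opt.fine_height, opt.fine_width)
        p_rendered, m_composite = tom(tom_input)
        # Blend warped cloth into the rendered person using the composite mask
        tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)

    print(grid.shape, theta.shape)
    print(p_rendered.shape, m_composite.shape, tryon.shape)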