import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models

from .embedders import get_embedder
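
# ImplicitNet: MLP mapping a 3D point (optionally positionally encoded) to a
# signed distance value plus a feature vector, optionally conditioned on the
# SMPL body pose ('smpl') or a learned per-frame code ('frame'). The
# 'geometry' branch appears to follow the SAL/IGR geometric initialization,
# so the network starts out approximating the SDF of a sphere.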
class ImplicitNet(nn.Module):
    def __init__(self, opt):
        super().__init__()

        dims = [opt.d_in] + list(opt.dims) + [opt.d_out + opt.feature_vector_size]
        self.num_layers = len(dims)
        self.skip_in = opt.skip_in
        self.embed_fn = None
        self.opt = opt

        if opt.multires > 0:
            embed_fn, input_ch = get_embedder(opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
            self.embed_fn = embed_fn
            dims[0] = input_ch

        self.cond = opt.cond
        if self.cond == 'smpl':
            self.cond_layer = [0]
            self.cond_dim = 69
        elif self.cond == 'frame':
            self.cond_layer = [0]
            self.cond_dim = opt.dim_frame_encoding

        self.dim_pose_embed = 0
        if self.dim_pose_embed > 0:
            self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
            self.cond_dim = self.dim_pose_embed

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            if self.cond != 'none' and l in self.cond_layer:
                lin = nn.Linear(dims[l] + self.cond_dim, out_dim)
            else:
                lin = nn.Linear(dims[l], out_dim)

            if opt.init == 'geometry':
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight,
                                          mean=np.sqrt(np.pi) / np.sqrt(dims[l]),
                                          std=0.0001)
                    torch.nn.init.constant_(lin.bias, -opt.bias)
                elif opt.multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif opt.multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if opt.init == 'zero':
                init_val = 1e-5
                if l == self.num_layers - 2:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.uniform_(lin.weight, -init_val, init_val)

            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)

        self.softplus = nn.Softplus(beta=100)
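
    # `input` is (batch, num_points, d_in); a 2D tensor is treated as a single
    # batch. `cond` is a dict holding the conditioning tensor of shape
    # (batch, cond_dim) under the key named by self.cond. Returns a tensor of
    # shape (batch, num_points, d_out + feature_vector_size).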
    def forward(self, input, cond, current_epoch=None):
        if input.ndim == 2:
            input = input.unsqueeze(0)

        num_batch, num_point, num_dim = input.shape
        if num_batch * num_point == 0:
            return input

        input = input.reshape(num_batch * num_point, num_dim)

        if self.cond != 'none':
            num_batch, num_cond = cond[self.cond].shape
            input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
            input_cond = input_cond.reshape(num_batch * num_point, num_cond)
            if self.dim_pose_embed:
                input_cond = self.lin_p0(input_cond)

        if self.embed_fn is not None:
            input = self.embed_fn(input)

        x = input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if self.cond != 'none' and l in self.cond_layer:
                x = torch.cat([x, input_cond], dim=-1)
            if l in self.skip_in:
                x = torch.cat([x, input], 1) / np.sqrt(2)
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.softplus(x)

        x = x.reshape(num_batch, num_point, -1)
        return x
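
    # Differentiates the signed-distance output with respect to the query
    # points via autograd; presumably used for surface normals and the
    # Eikonal regularizer.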
    def gradient(self, x, cond):
        x.requires_grad_(True)
        y = self.forward(x, cond)[:, :1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(outputs=y,
                                        inputs=x,
                                        grad_outputs=d_output,
                                        create_graph=True,
                                        retain_graph=True,
                                        only_inputs=True)[0]
        return gradients.unsqueeze(1)
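
# RenderingNet: MLP predicting RGB colors for surface points. In
# 'nerf_frame_encoding' mode it conditions on embedded view directions and a
# per-frame latent code; in 'pose' mode on point positions, normals, and a
# compact embedding of the SMPL body pose.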
class RenderingNet(nn.Module):
    def __init__(self, opt):
        super().__init__()

        self.mode = opt.mode
        dims = [opt.d_in + opt.feature_vector_size] + list(opt.dims) + [opt.d_out]

        self.embedview_fn = None
        if opt.multires_view > 0:
            embedview_fn, input_ch = get_embedder(opt.multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        if self.mode == 'nerf_frame_encoding':
            dims[0] += opt.dim_frame_encoding
        if self.mode == 'pose':
            self.dim_cond_embed = 8
            self.cond_dim = 69  # dimension of the body pose, global orientation excluded
            # lower the condition dimension
            self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)

        self.num_layers = len(dims)
        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)
            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
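
    # In 'pose' mode, `points`, `normals`, and `feature_vectors` are
    # (num_points, ·) tensors and `body_pose` is (batch, 69), with the batch
    # assumed to be 1 here; in 'nerf_frame_encoding' mode, `view_dirs`,
    # `frame_latent_code`, and `feature_vectors` are used instead. Colors are
    # returned in [0, 1] via the final sigmoid.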
    def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
        if self.embedview_fn is not None:
            if self.mode == 'nerf_frame_encoding':
                view_dirs = self.embedview_fn(view_dirs)

        if self.mode == 'nerf_frame_encoding':
            frame_latent_code = frame_latent_code.expand(view_dirs.shape[0], -1)
            rendering_input = torch.cat([view_dirs, frame_latent_code, feature_vectors], dim=-1)
        elif self.mode == 'pose':
            num_points = points.shape[0]
            body_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1).reshape(num_points, -1)
            body_pose = self.lin_pose(body_pose)
            rendering_input = torch.cat([points, normals, body_pose, feature_vectors], dim=-1)
        else:
            raise NotImplementedError

        x = rendering_input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.relu(x)
        x = self.sigmoid(x)
        return x
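
# GeometryEncodingNet: compresses a per-frame encoding vector to a small code,
# concatenates it with each point's 3D coordinates, and predicts a 3-vector
# per point. The near-identity initialization suggests the module is meant to
# start as a small perturbation of its input; the precise role of the output
# (e.g. a per-frame point offset) is not documented here.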
class GeometryEncodingNet(nn.Module):
    def __init__(self, input_size=259, hidden_size=256, hidden_encoding_size=3, output_size=3):
        super().__init__()

        # Encoder: compress the frame encoding vector to a low-dimensional code
        self.fc1_e = nn.Linear(input_size - 3, int(hidden_size / 2))
        # self.fc2_e = nn.Linear(hidden_size, int(hidden_size / 2))
        # self.fc3_e = nn.Linear(int(hidden_size / 2), int(hidden_size / 4))
        self.fc4_e = nn.Linear(int(hidden_size / 2), hidden_encoding_size)

        # Decoder: map the concatenated point coordinates and code to the output
        self.fc1 = nn.Linear(hidden_encoding_size + 3, hidden_encoding_size + 3)
        self.fc2 = nn.Linear(hidden_encoding_size + 3, hidden_encoding_size + 3)
        self.fc3 = nn.Linear(hidden_encoding_size + 3, hidden_encoding_size + 3)
        self.fc4 = nn.Linear(hidden_encoding_size + 3, output_size)

        # Initialize weights close to the identity function
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Weights drawn around one with a small amount of noise; biases at zero
                init.normal_(m.weight, mean=1, std=0.01)
                init.constant_(m.bias, 0)
    def forward(self, x, frame_encoding_vector):
        x_c = x

        # Compress the encoding vector of the frame
        frame_encoding_vector = torch.relu(self.fc1_e(frame_encoding_vector))
        # frame_encoding_vector = torch.relu(self.fc2_e(frame_encoding_vector))
        # frame_encoding_vector = torch.relu(self.fc3_e(frame_encoding_vector))
        frame_encoding_vector = self.fc4_e(frame_encoding_vector)

        # Concatenate the frame encoding vector with the point coordinates
        x = torch.cat((x_c, frame_encoding_vector.unsqueeze(0).expand(x_c.size(0), -1)), dim=-1)

        # Inject the frame encoding into the points' coordinates
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        output = self.fc4(x)
        return output
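

# A minimal smoke test under hypothetical configuration values (the option
# fields below are illustrative, not the project's actual config). Run as a
# module (python -m ...) so the relative import of `get_embedder` resolves.
if __name__ == "__main__":
    from types import SimpleNamespace

    # ImplicitNet without positional encoding or conditioning.
    opt = SimpleNamespace(d_in=3, dims=[64, 64], d_out=1, feature_vector_size=32,
                          skip_in=[], multires=0, cond='none', init='zero',
                          weight_norm=False)
    sdf_net = ImplicitNet(opt)
    out = sdf_net(torch.randn(2, 128, 3), cond={})
    print(out.shape)  # torch.Size([2, 128, 33]): 1 SDF value + 32 feature channels

    # RenderingNet in 'pose' mode: d_in = 3 (points) + 3 (normals) + 8 (pose embedding).
    r_opt = SimpleNamespace(mode='pose', d_in=14, d_out=3, dims=[128, 128],
                            feature_vector_size=32, multires_view=0, weight_norm=False)
    rgb_net = RenderingNet(r_opt)
    rgb = rgb_net(points=torch.randn(128, 3), normals=torch.randn(128, 3),
                  view_dirs=None, body_pose=torch.randn(1, 69),
                  feature_vectors=torch.randn(128, 32))
    print(rgb.shape)  # torch.Size([128, 3])

    # GeometryEncodingNet with its default sizes (frame encoding of length 256).
    geo_net = GeometryEncodingNet()
    offsets = geo_net(torch.randn(100, 3), torch.randn(256))
    print(offsets.shape)  # torch.Size([100, 3])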