import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
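
# Usage sketch (illustrative, not part of the original file): the module maps a
# 1-D batch of timesteps to (batch, dim) embeddings, sines in the first half
# and cosines in the second. The batch size and dim below are assumed values.
def _demo_sinusoidal_embeddings():
    t = torch.randint(0, 1000, (8,))            # 8 diffusion timesteps
    emb = SinusoidalPositionEmbeddings(128)(t)
    assert emb.shape == (8, 128)                # 64 sin dims + 64 cos dims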
class DropoutSampler(nn.Module):
    def __init__(self, num_features, num_outputs, dropout_rate=0.5):
        super().__init__()
        self.linear = nn.Linear(num_features, num_features)
        self.linear2 = nn.Linear(num_features, num_features)
        self.predict = nn.Linear(num_features, num_outputs)
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.dropout_rate = dropout_rate

    def forward(self, x):
        x = F.relu(self.linear(x))
        if self.dropout_rate > 0:
            # F.dropout defaults to training=True, so dropout stays active even
            # in eval mode; given the class name, this appears intentional and
            # allows Monte Carlo-style sampling of predictions at inference.
            x = F.dropout(x, self.dropout_rate)
        x = F.relu(self.linear2(x))
        return self.predict(x)
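
# Usage sketch (illustrative): a two-hidden-layer ReLU head with dropout
# between the hidden layers. Because the dropout above runs with its default
# training=True, repeated calls on the same input give different samples even
# after .eval(). The feature and output dims below are assumptions.
def _demo_dropout_sampler():
    head = DropoutSampler(num_features=256, num_outputs=7)
    feats = torch.randn(4, 256)
    samples = torch.stack([head(feats) for _ in range(10)])  # 10 MC samples
    assert samples.shape == (10, 4, 7)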
class EncoderMLP(nn.Module):
    def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True):
        super().__init__()
        self.uses_pt = uses_pt
        self.output = out_dim
        d5 = 2 * in_dim if self.uses_pt else in_dim
        d6 = 2 * self.output
        d7 = self.output
        self.encode_position = nn.Sequential(
            nn.Linear(pt_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
        )
        self.fc_block = nn.Sequential(
            nn.Linear(d5, d6),
            nn.LayerNorm(d6),
            nn.ReLU(),
            nn.Linear(d6, d6),
            nn.LayerNorm(d6),
            nn.ReLU(),
            nn.Linear(d6, d7),
        )

    def forward(self, x, pt=None):
        if self.uses_pt:
            if pt is None:
                raise RuntimeError('did not provide pt')
            y = self.encode_position(pt)
            x = torch.cat([x, y], dim=-1)
        return self.fc_block(x)
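
# Usage sketch (illustrative): encode a 3-D point, concatenate it with a
# latent feature vector, and project to the output dimension. The dims below
# are assumptions; with uses_pt=True the fc_block input is 2 * in_dim.
def _demo_encoder_mlp():
    enc = EncoderMLP(in_dim=256, out_dim=64, pt_dim=3, uses_pt=True)
    feats = torch.randn(8, 256)                 # per-object latent features
    pts = torch.randn(8, 3)                     # e.g. object center positions
    out = enc(feats, pts)
    assert out.shape == (8, 64)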
class MeanEncoder(nn.Module):
    def __init__(self, input_channels=3, use_xyz=True, output=512, scale=0.04, factor=1):
        super().__init__()
        # The constructor arguments are unused; they appear to be kept for
        # interface parity with richer point-cloud encoders. This encoder has
        # no learnable parameters.
        self.uses_rgb = False
        self.dim = 3

    def forward(self, xyz, f=None):
        # Fix feature shape to (batch, channels, points); the features are
        # otherwise unused by this encoder.
        if f is not None:
            if len(f.shape) < 3:
                f = f.transpose(0, 1).contiguous()
                f = f[None]
            else:
                f = f.transpose(1, 2).contiguous()
        if len(xyz.shape) == 3:
            center = torch.mean(xyz, dim=1)
        elif len(xyz.shape) == 2:
            center = torch.mean(xyz, dim=0)
        else:
            raise RuntimeError('not sure what to do with points of shape ' + str(xyz.shape))
        assert xyz.shape[-1] == 3
        assert center.shape[-1] == 3
        return center, center
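
# Usage sketch (illustrative): MeanEncoder reduces each point cloud to its
# centroid and returns it twice, once as the position and once as the
# (degenerate) feature. Cloud sizes below are assumptions.
def _demo_mean_encoder():
    enc = MeanEncoder()
    xyz = torch.randn(2, 1024, 3)               # 2 clouds of 1024 points
    center, feat = enc(xyz)
    assert center.shape == (2, 3) and torch.equal(center, feat)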