File size: 4,015 Bytes
6cd35b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Copyright (c) Microsoft Corporation.
import math
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from .positional_encoding import SphericalHarmonics
class LocationEncoder(nn.Module):
    """Positional encoding followed by a SIREN network.

    The input is first passed through ``SphericalHarmonics`` and the
    resulting embedding is then fed to a ``SirenNet`` whose input width is
    taken from the encoder's ``embedding_dim``.

    Args:
        dim_hidden: Hidden width handed to ``SirenNet``.
        num_layers: Number of hidden layers handed to ``SirenNet``.
        dim_out: Output width handed to ``SirenNet``.
        legendre_polys: Forwarded to ``SphericalHarmonics``.
    """

    def __init__(
        self,
        dim_hidden: int,
        num_layers: int,
        dim_out: int,
        legendre_polys: int = 10,
    ):
        super().__init__()
        self.posenc = SphericalHarmonics(legendre_polys=legendre_polys)
        # The network's input width must match whatever the encoder emits.
        siren_config = dict(
            dim_in=self.posenc.embedding_dim,
            dim_hidden=dim_hidden,
            num_layers=num_layers,
            dim_out=dim_out,
        )
        self.nnet = SirenNet(**siren_config)

    def forward(self, x):
        """Encode ``x`` and run the embedding through the SIREN network.

        NOTE(review): ``x`` is presumably a batch of coordinates in whatever
        format ``SphericalHarmonics`` expects; confirm against that module.
        """
        return self.nnet(self.posenc(x))
class SirenNet(nn.Module):
    """Sinusoidal Representation Network (SIREN).

    A stack of ``num_layers`` hidden ``Siren`` layers followed by one output
    ``Siren`` layer.

    Args:
        dim_in: Width of the input features.
        dim_hidden: Width of every hidden layer.
        dim_out: Width of the output layer.
        num_layers: Number of hidden ``Siren`` layers.
        w0: Frequency factor used by every layer except the first.
        w0_initial: Frequency factor used by the first hidden layer.
        use_bias: Whether the layers carry a bias term.
        final_activation: Activation for the output layer; ``None`` selects
            ``nn.Identity``.
        degreeinput: When true, ``forward`` converts its input from degrees
            to radians and shifts it by ``-pi`` before the first layer.
        dropout: Forwarded to every hidden layer. The output layer is always
            built with ``dropout=False``.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0=1.0,
        w0_initial=30.0,
        use_bias=True,
        final_activation=None,
        degreeinput=False,
        dropout=True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.degreeinput = degreeinput

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            # Only the first layer sees the raw input width and uses the
            # larger initial frequency factor.
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            self.layers.append(
                Siren(
                    dim_in=layer_dim_in,
                    dim_out=dim_hidden,
                    w0=layer_w0,
                    use_bias=use_bias,
                    is_first=is_first,
                    dropout=dropout,
                )
            )

        final_activation = (
            nn.Identity() if not exists(final_activation) else final_activation
        )

        self.last_layer = Siren(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            activation=final_activation,
            dropout=False,
        )

    def forward(self, x, mods=None):
        """Run ``x`` through the hidden layers and the output layer.

        Args:
            x: Input tensor for the first layer.
            mods: Optional per-layer modulation. Either a single value that
                is reused for every hidden layer, or a tuple with one entry
                per hidden layer. Each non-``None`` entry must be a 1-D
                tensor; it is reshaped to ``(1, d)`` and multiplied in place
                into that layer's output.

        Returns:
            The output of the final ``Siren`` layer.

        Raises:
            ValueError: If ``mods`` is a tuple with fewer entries than there
                are hidden layers.
        """
        # Normalization: for inputs in [0, 360] degrees this yields values in
        # the [-pi, pi] range.
        if self.degreeinput:
            x = torch.deg2rad(x) - torch.pi

        mods = cast_tuple(mods, self.num_layers)

        # ``zip`` stops at the shortest iterable, so a too-short ``mods``
        # tuple would silently skip the remaining hidden layers. Fail loudly
        # instead. Surplus entries are still ignored, as before.
        if len(mods) < len(self.layers):
            raise ValueError(
                f"mods has {len(mods)} entries but the network has "
                f"{len(self.layers)} hidden layers; provide one per layer"
            )

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= rearrange(mod, "d -> () d")

        return self.last_layer(x)
class Sine(nn.Module):
    """Sine activation with a frequency factor: ``sin(w0 * x)``."""

    def __init__(self, w0=1.0):
        super().__init__()
        # Stored as a plain attribute, not as a parameter or buffer.
        self.w0 = w0

    def forward(self, x):
        """Scale ``x`` by ``w0`` and apply the element-wise sine."""
        scaled = self.w0 * x
        return torch.sin(scaled)
class Siren(nn.Module):
    """Single SIREN layer: linear map, optional dropout, then activation.

    Args:
        dim_in: Number of input features.
        dim_out: Number of output features.
        w0: Frequency factor for the default ``Sine`` activation; also used
            to scale the initialization range of non-first layers.
        c: Constant used in the initialization range of non-first layers.
        is_first: Whether this is the first layer of the network, which
            selects the ``1 / dim_in`` initialization range.
        use_bias: Whether to create a bias parameter.
        activation: Activation to apply; ``None`` selects ``Sine(w0)``.
        dropout: When truthy, ``F.dropout`` (with its default probability) is
            applied to the linear output before the activation.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=1.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=False,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.dim_out = dim_out
        self.dropout = dropout

        # Initialize the raw tensors first, then wrap them as parameters.
        raw_weight = torch.zeros(dim_out, dim_in)
        raw_bias = torch.zeros(dim_out) if use_bias else None
        self.init_(raw_weight, raw_bias, c=c, w0=w0)

        self.weight = nn.Parameter(raw_weight)
        if use_bias:
            self.bias = nn.Parameter(raw_bias)
        else:
            self.bias = None

        if activation is None:
            self.activation = Sine(w0)
        else:
            self.activation = activation

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` (and ``bias`` if given) in place, uniformly.

        The range is ``+-1/dim_in`` for the first layer and
        ``+-sqrt(c/dim_in)/w0`` for every other layer.
        """
        fan_in = self.dim_in
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the linear map, optional dropout and the activation to ``x``."""
        projected = F.linear(x, self.weight, self.bias)
        if not self.dropout:
            return self.activation(projected)
        # Dropout is only active while the module is in training mode.
        dropped = F.dropout(projected, training=self.training)
        return self.activation(dropped)
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True
def cast_tuple(val, repeat=1):
    """Return ``val`` unchanged if it is a tuple, else repeat it in a tuple.

    A non-tuple ``val`` (including ``None`` or a list) becomes a tuple that
    contains ``val`` exactly ``repeat`` times.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * repeat
|