Spaces:
Running
on
L4
Running
on
L4
File size: 2,952 Bytes
c1ce505 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
import torch.nn as nn
class FCN(nn.Module):
    """Two-headed projection from decoder features.

    Emits command logits plus an argument head. With ``abs_targets`` the
    argument head regresses ``n_args`` values directly; otherwise it emits
    classification logits over ``args_dim`` buckets per argument slot.
    """

    def __init__(self, d_model, n_commands, n_args, args_dim=256, abs_targets=False):
        super().__init__()
        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets
        self.command_fcn = nn.Linear(d_model, n_commands)
        # Regression head: one value per argument.
        # Classification head: a full args_dim distribution per argument.
        arg_features = n_args if abs_targets else n_args * args_dim
        self.args_fcn = nn.Linear(d_model, arg_features)

    def forward(self, out):
        S, N, _ = out.shape
        command_logits = self.command_fcn(out)  # [S, N, n_commands]
        args_logits = self.args_fcn(out)        # [S, N, arg_features]
        if not self.abs_targets:
            # Unflatten into one logit vector per argument slot.
            args_logits = args_logits.view(S, N, self.n_args, self.args_dim)
        return command_logits, args_logits
class ArgumentFCN(nn.Module):
    """Argument-only head: absolute regression or per-bucket logits.

    With ``abs_targets`` the head maps features through two stacked linear
    layers down to ``n_args`` regressed values; otherwise it produces
    ``args_dim`` classification logits per argument slot.
    """

    def __init__(self, d_model, n_args, args_dim=256, abs_targets=False):
        super().__init__()
        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets
        if abs_targets:
            # Classification -> regression: widen to n_args * args_dim,
            # then project down to one value per argument.
            # NOTE(review): no nonlinearity between the two Linears, so
            # this composes to a single affine map — presumably intentional.
            self.args_fcn = nn.Sequential(
                nn.Linear(d_model, n_args * args_dim),
                nn.Linear(n_args * args_dim, n_args),
            )
        else:
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        S, N, _ = out.shape
        logits = self.args_fcn(out)
        if self.abs_targets:
            return logits  # [S, N, n_args]
        # One logit vector per argument slot: [S, N, n_args, args_dim].
        return logits.view(S, N, self.n_args, self.args_dim)
class HierarchFCN(nn.Module):
    """Per-group heads predicting visibility logits and a latent code.

    NOTE(review): ``d_model`` is accepted but unused — both heads project
    from ``dim_z``-sized inputs (the earlier d_model-based variants were
    deliberately disabled). Confirm callers feed dim_z features.
    """

    def __init__(self, d_model, dim_z):
        super().__init__()
        self.visibility_fcn = nn.Linear(dim_z, 2)
        self.z_fcn = nn.Linear(dim_z, dim_z)

    def forward(self, out):
        # Expect a 3-D [G, N, dim_z] input.
        G, N, _ = out.shape
        vis = self.visibility_fcn(out)  # [G, N, 2]
        latent = self.z_fcn(out)        # [G, N, dim_z]
        # Prepend a singleton leading axis to both outputs.
        return vis[None], latent[None]
class ResNet(nn.Module):
    """Four residual Linear+ReLU stages over a fixed-width feature vector.

    Each stage computes z <- z + ReLU(W z + b); feature width d_model is
    preserved throughout.
    """

    def __init__(self, d_model):
        super().__init__()

        # Four identical residual branches, kept as individually named
        # attributes so checkpoints saved against the original layout
        # (linear1..linear4) still load.
        def branch():
            return nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())

        self.linear1 = branch()
        self.linear2 = branch()
        self.linear3 = branch()
        self.linear4 = branch()

    def forward(self, z):
        # Apply the four residual stages in order.
        for stage in (self.linear1, self.linear2, self.linear3, self.linear4):
            z = z + stage(z)
        return z
|