File size: 2,952 Bytes
c1ce505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn


class FCN(nn.Module):
    """Output head that maps decoder features to command logits and argument outputs.

    When ``abs_targets`` is False the arguments are predicted as per-argument
    classification logits over ``args_dim`` buckets; when True they are
    predicted directly as ``n_args`` regression values.
    """

    def __init__(self, d_model, n_commands, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        # One head for the command class, one for the arguments.
        self.command_fcn = nn.Linear(d_model, n_commands)

        # Regression head emits one scalar per argument; classification head
        # emits a full distribution (args_dim logits) per argument.
        args_out_features = n_args if abs_targets else n_args * args_dim
        self.args_fcn = nn.Linear(d_model, args_out_features)

    def forward(self, out):
        """Return ``(command_logits, args_logits)`` for input of shape [S, N, d_model]."""
        seq_len, batch_size, _ = out.shape

        command_logits = self.command_fcn(out)  # [S, N, n_commands]
        args_logits = self.args_fcn(out)        # [S, N, n_args] or [S, N, n_args * args_dim]

        if self.abs_targets:
            return command_logits, args_logits

        # Unflatten the per-argument logit distributions.
        args_logits = args_logits.reshape(seq_len, batch_size, self.n_args, self.args_dim)
        return command_logits, args_logits


class ArgumentFCN(nn.Module):
    """Argument-only output head (no command prediction).

    Classification mode (``abs_targets=False``) produces logits over
    ``args_dim`` buckets for each of the ``n_args`` arguments; regression
    mode (``abs_targets=True``) produces ``n_args`` direct values.
    """

    def __init__(self, d_model, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        if abs_targets:
            # Regression head: widen then project down to one value per argument.
            # NOTE(review): two stacked Linears with no activation in between
            # compose to a single affine map — confirm this is intentional.
            self.args_fcn = nn.Sequential(
                nn.Linear(d_model, n_args * args_dim),
                nn.Linear(n_args * args_dim, n_args)
            )
        else:
            # Classification head: one distribution of args_dim logits per argument.
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        """Return argument logits/values for input of shape [S, N, d_model]."""
        seq_len, batch_size, _ = out.shape

        args_logits = self.args_fcn(out)

        if self.abs_targets:
            return args_logits  # [S, N, n_args]

        # Unflatten into per-argument logit distributions.
        return args_logits.reshape(seq_len, batch_size, self.n_args, self.args_dim)


class HierarchFCN(nn.Module):
    """Hierarchical head predicting per-group visibility logits and a latent code.

    Both heads project from ``dim_z`` (an earlier revision projected from
    ``d_model`` instead); the input's last dimension must therefore be ``dim_z``.
    """

    def __init__(self, d_model, dim_z):
        super().__init__()

        self.visibility_fcn = nn.Linear(dim_z, 2)   # binary visibility logits
        self.z_fcn = nn.Linear(dim_z, dim_z)        # latent-code projection

    def forward(self, out):
        """Return ``(visibility_logits, z)`` with a leading singleton dim prepended.

        Input shape: [G, N, dim_z]; outputs: [1, G, N, 2] and [1, G, N, dim_z].
        """
        num_groups, batch_size, _ = out.shape

        visibility_logits = self.visibility_fcn(out)
        latent = self.z_fcn(out)

        return visibility_logits.unsqueeze(0), latent.unsqueeze(0)


class ResNet(nn.Module):
    """Stack of four residual Linear+ReLU blocks operating on the last dimension.

    Each stage computes ``z = z + ReLU(Linear(z))``; attribute names
    ``linear1``..``linear4`` are kept for checkpoint compatibility.
    """

    def __init__(self, d_model):
        super().__init__()

        def residual_branch():
            # One branch: affine projection followed by ReLU.
            return nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())

        self.linear1 = residual_branch()
        self.linear2 = residual_branch()
        self.linear3 = residual_branch()
        self.linear4 = residual_branch()

    def forward(self, z):
        """Apply the four residual stages in order and return the result."""
        for branch in (self.linear1, self.linear2, self.linear3, self.linear4):
            z = z + branch(z)
        return z