|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class RNN(torch.nn.Module):
    """Vanilla RNN that integrates 2d velocities into a place-cell code.

    The initial hidden state is a linear encoding of the starting place-cell
    activations; the hidden units ("grid cells") are linearly decoded back
    into place-cell logits at every time step.

    Note: ``torch.nn.RNN`` is constructed without ``batch_first``, so all
    sequence tensors handled by this class are time-major, i.e.
    ``[sequence_length, batch_size, features]``.
    """

    def __init__(self, options, place_cells):
        """Build the encoder, recurrent core and decoder.

        Args:
            options: Object exposing ``Ng`` (hidden / grid units), ``Np``
                (place cells), ``sequence_length``, ``weight_decay`` and
                ``activation`` (forwarded as ``nonlinearity`` to
                ``torch.nn.RNN``).
            place_cells: Object providing ``get_nearest_cell_pos``; only
                used by ``compute_loss``.
        """
        super().__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.place_cells = place_cells

        # Input weights: map initial place-cell activations to the initial
        # hidden state.
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        self.RNN = torch.nn.RNN(
            input_size=2,
            hidden_size=self.Ng,
            nonlinearity=options.activation,
            bias=False,
        )

        # Linear read-out weights: hidden state -> place-cell logits.
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)

        # Not used by the methods below (compute_loss works on raw logits);
        # kept so existing callers that reference ``model.softmax`` still work.
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        """Compute grid cell activations.

        Args:
            inputs: Pair ``(v, p0)`` where ``v`` holds the 2d velocities with
                shape ``[sequence_length, batch_size, 2]`` and ``p0`` holds
                the initial place-cell activations with shape
                ``[batch_size, Np]``.

        Returns:
            g: Hidden-unit activations for every time step with shape
                ``[sequence_length, batch_size, Ng]``.
        """
        v, p0 = inputs
        # ``[None]`` prepends the num_layers axis expected for the RNN's h_0.
        init_state = self.encoder(p0)[None]
        g, _ = self.RNN(v, init_state)
        return g

    def predict(self, inputs):
        """Predict place cell code.

        Args:
            inputs: Pair ``(v, p0)`` as accepted by ``g``.

        Returns:
            place_preds: Unnormalised place-cell predictions (logits, no
                softmax applied) with shape
                ``[sequence_length, batch_size, Np]``.
        """
        return self.decoder(self.g(inputs))

    def set_weights(self, weights):
        """Load weights stored in transposed (TF/Keras) layout.

        Args:
            weights: Indexable collection whose first four entries are, in
                order, ``[encoder, rnn_ih, rnn_hh, decoder]``. Each entry is
                a 2-D numpy array or torch tensor and is transposed before
                being copied into the corresponding parameter.

        Raises:
            IndexError: If fewer than four entries are supplied (for
                list-like ``weights``).
        """
        targets = (
            self.encoder.weight,
            self.RNN.weight_ih_l0,
            self.RNN.weight_hh_l0,
            self.decoder.weight,
        )
        with torch.no_grad():
            for idx, param in enumerate(targets):
                # as_tensor accepts both numpy arrays and torch tensors.
                param.copy_(torch.as_tensor(weights[idx].T).float())

    def compute_loss(self, inputs, pc_outputs, pos):
        """Compute avg. loss and decoding error.

        Args:
            inputs: Pair ``(v, p0)`` as accepted by ``g``.
            pc_outputs: Ground truth place cell activations with the same
                shape as the predictions,
                ``[sequence_length, batch_size, Np]``; used as
                class-probability targets for the cross-entropy.
            pos: Ground truth 2d position; must be broadcastable against the
                output of ``place_cells.get_nearest_cell_pos(preds)``
                (presumably ``[sequence_length, batch_size, 2]`` - confirm
                against the place-cell implementation).

        Returns:
            loss: Mean cross-entropy over all time steps and batch entries,
                plus an L2 penalty on the recurrent weights.
            err: Mean Euclidean distance between ``pos`` and the decoded
                position, in the units of ``pos``.
        """
        y: torch.Tensor = pc_outputs
        preds: torch.Tensor = self.predict(inputs)

        # Merge the two leading axes so every (time step, batch entry) pair
        # is one classification sample.
        loss = torch.nn.functional.cross_entropy(
            preds.flatten(0, 1), y.flatten(0, 1)
        )

        # Weight regularization on the recurrent matrix only.
        loss += self.weight_decay * (self.RNN.weight_hh_l0 ** 2).sum()

        # Compute decoding error.
        pred_pos = self.place_cells.get_nearest_cell_pos(preds)
        err = torch.sqrt(((pos - pred_pos) ** 2).sum(-1)).mean()

        return loss, err
|
|
|