File size: 3,464 Bytes
00c2650 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# -*- coding: utf-8 -*-
import torch
class RNN(torch.nn.Module):
    """Bias-free vanilla RNN mapping 2-D velocity sequences to place-cell logits.

    The initial hidden state is a linear encoding of the starting place-cell
    activations. A bias-free ``torch.nn.RNN`` then integrates the velocity
    sequence, and a bias-free linear read-out maps the hidden ("grid cell")
    activity to place-cell logits.

    ``torch.nn.RNN`` is constructed with its default ``batch_first=False``, so
    all sequence tensors are time-major: ``[sequence_length, batch_size, ...]``.
    """

    def __init__(self, options, place_cells):
        """
        Args:
            options: Config object read for ``Ng`` (hidden/grid units), ``Np``
                (place cells), ``sequence_length``, ``weight_decay`` and
                ``activation``. ``activation`` is passed to ``torch.nn.RNN`` as
                ``nonlinearity``, so it must be ``"tanh"`` or ``"relu"``.
            place_cells: Object exposing ``get_nearest_cell_pos(preds)``. It is
                used only by ``compute_loss`` to decode positions.
        """
        super().__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        # Stored for callers; not used by any method of this class.
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.place_cells = place_cells

        # Input weights: place code -> initial hidden state.
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        self.RNN = torch.nn.RNN(
            input_size=2,
            hidden_size=self.Ng,
            nonlinearity=options.activation,
            bias=False,
        )
        # Linear read-out weights: hidden state -> place-cell logits.
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
        # Public attribute kept for callers. compute_loss does not use it,
        # because cross_entropy applies log-softmax internally.
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        """
        Compute grid cell (hidden-state) activations.

        Args:
            inputs: Tuple ``(v, p0)``.
                ``v`` is the velocity sequence, shape [sequence_length, batch_size, 2].
                ``p0`` is the initial place-cell activation, shape [batch_size, Np].

        Returns:
            g: Hidden activations for every time step, shape
                [sequence_length, batch_size, Ng].
        """
        v, p0 = inputs
        # Encode the starting place code into the initial hidden state. The
        # added leading axis of size 1 is the RNN's num_layers dimension.
        init_state = self.encoder(p0)[None]
        g, _ = self.RNN(v, init_state)
        return g

    def predict(self, inputs):
        """
        Predict the place cell code (unnormalized logits).

        Args:
            inputs: Tuple ``(v, p0)`` as accepted by :meth:`g`.

        Returns:
            place_preds: Place-cell logits, shape
                [sequence_length, batch_size, Np].
        """
        return self.decoder(self.g(inputs))

    def set_weights(self, weights):
        """
        Load weights stored in TF/Keras layout.

        Keras stores kernels as (in_features, out_features), and its recurrent
        kernel multiplies the hidden state from the right. PyTorch stores
        (out_features, in_features). Every matrix is therefore transposed
        before it is copied.

        Args:
            weights: Indexable whose first four entries are, in order:
                - the encoder kernel, (Np, Ng)
                - the RNN input kernel, (2, Ng)
                - the RNN recurrent kernel, (Ng, Ng)
                - the decoder kernel, (Ng, Np)
                Entries may be numpy arrays, tensors, or anything
                ``torch.as_tensor`` accepts. Any further entries are ignored.

        Raises:
            ValueError: If an entry does not have the expected Keras shape.
                All shapes are validated before any parameter is modified, so
                a failed call leaves the model unchanged.
        """
        targets = (
            ("encoder", self.encoder.weight),
            ("rnn_ih", self.RNN.weight_ih_l0),
            ("rnn_hh", self.RNN.weight_hh_l0),
            ("decoder", self.decoder.weight),
        )
        sources = []
        for index, (name, target) in enumerate(targets):
            src = torch.as_tensor(weights[index])
            expected = tuple(reversed(target.shape))
            # Check explicitly, because Tensor.copy_ would otherwise broadcast
            # a mis-shaped array (e.g. a bias vector) without complaint.
            if tuple(src.shape) != expected:
                raise ValueError(
                    f"{name} weights have shape {tuple(src.shape)}, "
                    f"expected {expected} (Keras layout)"
                )
            sources.append(src)

        with torch.no_grad():
            for (_, target), src in zip(targets, sources):
                # copy_ handles the dtype and device conversion.
                target.copy_(src.T)

    def compute_loss(self, inputs, pc_outputs, pos):
        """
        Compute the average loss and the decoding error.

        Args:
            inputs: Tuple ``(v, p0)`` as accepted by :meth:`g`.
            pc_outputs: Ground-truth place-cell activations used as soft
                targets, shape [sequence_length, batch_size, Np]. Presumably
                each row is a probability distribution (TODO confirm with the
                data pipeline).
            pos: Ground-truth 2-D positions, shape
                [sequence_length, batch_size, 2].

        Returns:
            loss: Scalar made of the mean cross-entropy plus
                ``weight_decay`` times the sum of squared recurrent weights.
            err: Mean Euclidean distance between ``pos`` and the position
                decoded by ``place_cells.get_nearest_cell_pos``, in the same
                units as ``pos``.
        """
        preds = self.predict(inputs)
        # A floating-point target with the same shape as the logits is
        # interpreted by cross_entropy as per-class probabilities.
        loss = torch.nn.functional.cross_entropy(
            preds.flatten(0, 1), pc_outputs.flatten(0, 1)
        )
        # L2 regularization on the recurrent weights only.
        loss = loss + self.weight_decay * (self.RNN.weight_hh_l0 ** 2).sum()

        # Decoding error: distance to the position read out from the logits.
        pred_pos = self.place_cells.get_nearest_cell_pos(preds)
        err = torch.sqrt(((pos - pred_pos) ** 2).sum(-1)).mean()
        return loss, err
|