import argparse

import numpy as np
import tensorflow as tf
import torch


from utils import generate_run_ID
from place_cells import PlaceCells
from trajectory_generator import TrajectoryGenerator
from model import RNN
from trainer import Trainer

# Only show TensorFlow log messages at ERROR level or above
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_dir",
    # default='/mnt/fs2/bsorsch/grid_cells/models/',
    default="models/",
    help="directory to save trained models",
)
parser.add_argument(
    "--n_epochs", default=100, help="number of training epochs", type=int
)
parser.add_argument("--n_steps", default=1000, help="batches per epoch", type=int)
parser.add_argument(
    "--batch_size", default=200, help="number of trajectories per batch", type=int
)
parser.add_argument(
    "--sequence_length", default=20, help="number of steps in trajectory", type=int
)
parser.add_argument(
    "--learning_rate", default=1e-4, help="gradient descent learning rate", type=float
)
parser.add_argument("--Np", default=512, help="number of place cells", type=int)
parser.add_argument("--Ng", default=4096, help="number of grid cells", type=int)
parser.add_argument(
    "--place_cell_rf",
    default=0.12,
    help="width of place cell center tuning curve (m)",
    type=float,
)
parser.add_argument(
    "--surround_scale",
    default=2,
    help="if DoG, ratio of sigma2^2 to sigma1^2",
    type=int,
)
parser.add_argument("--RNN_type", default="RNN", help="RNN or LSTM")
parser.add_argument("--activation", default="relu", help="recurrent nonlinearity")
parser.add_argument(
    "--weight_decay",
    default=1e-4,
    help="strength of weight decay on recurrent weights",
    type=float,
)
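def str2bool(value):
    """Parse boolean flags; without this, argparse would treat the string "False" as truthy."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser.add_argument(
    "--DoG",
    default=True,
    type=str2bool,
    help="use difference-of-Gaussians tuning curves",
)
parser.add_argument(
    "--periodic",
    default=False,
    type=str2bool,
    help="use periodic boundary conditions for trajectories",
)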
parser.add_argument(
    "--box_width", default=2.2, help="width of training environment (m)", type=float
)
parser.add_argument(
    "--box_height", default=2.2, help="height of training environment (m)", type=float
)
parser.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="device to use for training",
)
parser.add_argument(
    "--seed", default=None, help="seed for NumPy's global random number generator"
)

options = parser.parse_args()
options.run_ID = generate_run_ID(options)

print(f"Using device: {options.device}")

if options.seed:
    np.random.seed(int(options.seed))

place_cells = PlaceCells(options)
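if options.RNN_type == "RNN":
    model = RNN(options, place_cells)
elif options.RNN_type == "LSTM":
    # model = LSTM(options, place_cells)
    raise NotImplementedError("LSTM model is not implemented; use --RNN_type RNN")
else:
    raise ValueError(
        f"Unknown --RNN_type {options.RNN_type!r}; expected 'RNN' or 'LSTM'"
    )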

# Move the model to the device selected by --device (CUDA if available, else CPU)
if options.device == "cuda":
    print("Using CUDA")
model = model.to(options.device)

trajectory_generator = TrajectoryGenerator(options, place_cells)

trainer = Trainer(options, model, trajectory_generator)

# Train
trainer.train(n_epochs=options.n_epochs, n_steps=options.n_steps)