File size: 3,089 Bytes
00c2650 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import numpy as np
import tensorflow as tf
import torch.cuda
import argparse
from utils import generate_run_ID
from place_cells import PlaceCells
from trajectory_generator import TrajectoryGenerator
from model import RNN
from trainer import Trainer
# Command-line entry point: parse hyperparameters, build the place cells and
# the recurrent model, then train on generated trajectories.

# Silence TensorFlow's (v1-compat) logging below ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def _str2bool(value):
    """Convert a command-line flag value to a bool.

    argparse has no built-in boolean type: without a converter, ``--DoG False``
    stores the non-empty (hence truthy) string ``"False"``. This accepts the
    usual spellings and turns anything else into a usage error.

    Args:
        value: the raw string from the command line (or an already-bool
            default, which is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_dir",
    default="models/",
    help="directory to save trained models",
)
parser.add_argument(
    "--n_epochs", default=100, help="number of training epochs", type=int
)
parser.add_argument("--n_steps", default=1000, help="batches per epoch", type=int)
parser.add_argument(
    "--batch_size", default=200, help="number of trajectories per batch", type=int
)
parser.add_argument(
    "--sequence_length", default=20, help="number of steps in trajectory", type=int
)
parser.add_argument(
    "--learning_rate", default=1e-4, help="gradient descent learning rate", type=float
)
parser.add_argument("--Np", default=512, help="number of place cells", type=int)
parser.add_argument("--Ng", default=4096, help="number of grid cells", type=int)
parser.add_argument(
    "--place_cell_rf",
    default=0.12,
    help="width of place cell center tuning curve (m)",
    type=float,
)
parser.add_argument(
    "--surround_scale",
    default=2,
    help="if DoG, ratio of sigma2^2 to sigma1^2",
    # NOTE(review): parsed as int, so non-integer ratios are rejected; kept
    # as-is because the value may feed into run IDs / saved-model paths.
    type=int,
)
parser.add_argument(
    "--RNN_type",
    default="RNN",
    # Restrict to known values so an unknown type fails at parse time
    # instead of leaving `model` undefined further down.
    choices=("RNN", "LSTM"),
    help="RNN or LSTM",
)
parser.add_argument("--activation", default="relu", help="recurrent nonlinearity")
parser.add_argument(
    "--weight_decay",
    default=1e-4,
    help="strength of weight decay on recurrent weights",
    type=float,
)
parser.add_argument(
    "--DoG",
    default=True,
    type=_str2bool,
    help="use difference of gaussians tuning curves",
)
parser.add_argument(
    "--periodic",
    default=False,
    type=_str2bool,
    help="trajectories with periodic boundary conditions",
)
parser.add_argument(
    "--box_width", default=2.2, help="width of training environment", type=float
)
parser.add_argument(
    "--box_height", default=2.2, help="height of training environment", type=float
)
parser.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="device to use for training",
)
parser.add_argument(
    "--seed",
    default=None,
    type=int,
    help="seed number for all numpy random number generator",
)
options = parser.parse_args()
options.run_ID = generate_run_ID(options)
print(f"Using device: {options.device}")

# Only NumPy's global RNG is seeded here; torch's RNG is not seeded in this
# script.
if options.seed is not None:
    np.random.seed(options.seed)

place_cells = PlaceCells(options)
if options.RNN_type == "RNN":
    model = RNN(options, place_cells)
elif options.RNN_type == "LSTM":
    raise NotImplementedError("LSTM model is not implemented; use --RNN_type RNN")

# Always move the model to the requested device so values such as "cuda:1"
# or "mps" are honoured; moving to "cpu" is harmless for a CPU model.
if options.device.startswith("cuda"):
    print("Using CUDA")
model = model.to(options.device)

trajectory_generator = TrajectoryGenerator(options, place_cells)
trainer = Trainer(options, model, trajectory_generator)

# Train
trainer.train(n_epochs=options.n_epochs, n_steps=options.n_steps)
|