|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import torch.cuda |
|
|
import argparse |
|
|
|
|
|
|
|
|
from utils import generate_run_ID |
|
|
from place_cells import PlaceCells |
|
|
from trajectory_generator import TrajectoryGenerator |
|
|
from model import RNN |
|
|
from trainer import Trainer |
|
|
|
|
|
# Raise TensorFlow's (TF1-compat) logging threshold to ERROR so INFO and
# WARNING messages emitted by TensorFlow are suppressed for this process.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` hands option values over as raw strings, so without this
    converter ``--DoG False`` would produce the non-empty (hence truthy)
    string ``"False"`` and the flag could never be switched off.

    Args:
        value: the raw option text (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options controlling the training run.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--save_dir",
    default="models/",
    help="directory to save trained models",
)
parser.add_argument(
    "--n_epochs", default=100, help="number of training epochs", type=int
)
parser.add_argument("--n_steps", default=1000, help="batches per epoch", type=int)
parser.add_argument(
    "--batch_size", default=200, help="number of trajectories per batch", type=int
)
parser.add_argument(
    "--sequence_length", default=20, help="number of steps in trajectory", type=int
)
parser.add_argument(
    "--learning_rate", default=1e-4, help="gradient descent learning rate", type=float
)
parser.add_argument("--Np", default=512, help="number of place cells", type=int)
parser.add_argument("--Ng", default=4096, help="number of grid cells", type=int)
parser.add_argument(
    "--place_cell_rf",
    default=0.12,
    help="width of place cell center tuning curve (m)",
    type=float,
)
parser.add_argument(
    "--surround_scale",
    default=2,
    help="if DoG, ratio of sigma2^2 to sigma1^2",
    type=int,
)
# Restrict to the model types the script knows about, so a typo fails at
# parse time instead of leaving the model undefined later on.
parser.add_argument(
    "--RNN_type",
    default="RNN",
    choices=["RNN", "LSTM"],
    help="RNN or LSTM",
)
parser.add_argument("--activation", default="relu", help="recurrent nonlinearity")
parser.add_argument(
    "--weight_decay",
    default=1e-4,
    help="strength of weight decay on recurrent weights",
    type=float,
)
# Boolean switches: parse the text so "False"/"0"/"no" really mean False.
parser.add_argument(
    "--DoG",
    default=True,
    type=_str2bool,
    help="use difference of gaussians tuning curves",
)
parser.add_argument(
    "--periodic",
    default=False,
    type=_str2bool,
    help="trajectories with periodic boundary conditions",
)
parser.add_argument(
    "--box_width", default=2.2, help="width of training environment", type=float
)
parser.add_argument(
    "--box_height", default=2.2, help="height of training environment", type=float
)
# Default device is chosen once, when the parser is built.
parser.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="device to use for training",
)
# Left untyped on purpose: the value stays a string (or None when omitted).
parser.add_argument(
    "--seed", default=None, help="seed number for all numpy random number generator"
)
|
|
|
|
|
# Read the command line, then attach a run identifier derived from the options.
options = parser.parse_args()
options.run_ID = generate_run_ID(options)

print("Using device:", options.device)

# --seed has no type= conversion, so it is a string here (None when omitted).
# Any non-empty value is converted to int and seeds NumPy's global RNG.
if options.seed:
    np.random.seed(seed=int(options.seed))
|
|
|
|
|
# Place-cell population; shared by the model and the trajectory generator.
place_cells = PlaceCells(options)

if options.RNN_type == "RNN":
    model = RNN(options, place_cells)
elif options.RNN_type == "LSTM":
    # Accepted as an option value but not implemented in this script.
    raise NotImplementedError("RNN_type='LSTM' is not implemented yet")
else:
    # Fail fast with a clear message; otherwise `model` would be unbound and
    # the script would crash below with a confusing NameError.
    raise ValueError(
        f"Unknown RNN_type {options.RNN_type!r}; expected 'RNN' or 'LSTM'"
    )

# Move the model to the requested device. A prefix match also recognizes
# indexed CUDA devices such as "cuda:1", which an exact == "cuda" check
# missed (leaving the model on the CPU).
if options.device.startswith("cuda"):
    print("Using CUDA")
model = model.to(options.device)

trajectory_generator = TrajectoryGenerator(options, place_cells)

trainer = Trainer(options, model, trajectory_generator)

trainer.train(n_epochs=options.n_epochs, n_steps=options.n_steps)
|
|
|