# NCTC / models / research / lfads / synth_data / generate_chaotic_rnn_data.py
# (Repository page header: uploaded by NCTCMumbai, "Upload 2571 files",
# commit 0b8359d.)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf # used for flags here
from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
# Draw images with nearest-neighbour interpolation (no smoothing).
matplotlib.rcParams['image.interpolation'] = 'nearest'

# Name of the directory (created under /tmp by default) that holds the output.
DATA_DIR = "rnn_synth_data_v1.0"

flags = tf.app.flags

# ---- Output location and noise model ----
flags.DEFINE_string(
    "save_dir", "/tmp/" + DATA_DIR + "/",
    "Directory for saving data.")
flags.DEFINE_string(
    "datafile_name", "thits_data",
    "Name of data file for input case.")
flags.DEFINE_string(
    "noise_type", "poisson",
    "Noise type for data.")

# ---- Size of the synthetic experiment ----
flags.DEFINE_integer(
    "synth_data_seed", 5,
    "Random seed for RNN generation.")
flags.DEFINE_float(
    "T", 1.0,
    "Time in seconds to generate.")
flags.DEFINE_integer(
    "C", 100,
    "Number of conditions")
flags.DEFINE_integer(
    "N", 50,
    "Number of units for the RNN")
flags.DEFINE_integer(
    "S", 50,
    "Number of sampled units from RNN")
flags.DEFINE_integer(
    "npcs", 10,
    "Number of PCS for multi-session case.")
flags.DEFINE_float(
    "train_percentage", 4.0/5.0,
    "Percentage of train vs validation trials")
flags.DEFINE_integer(
    "nreplications", 40,
    "Number of noise replications of the same underlying rates.")

# ---- RNN dynamics and input ----
flags.DEFINE_float(
    "g", 1.5,
    "Complexity of dynamics")
flags.DEFINE_float(
    "x0_std", 1.0,
    "Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float(
    "tau", 0.025,
    "Time constant of RNN")
flags.DEFINE_float(
    "dt", 0.010,
    "Time bin")
flags.DEFINE_float(
    "input_magnitude", 20.0,
    "For the input case, what is the value of the input?")
flags.DEFINE_float(
    "max_firing_rate", 30.0,
    "Map 1.0 of RNN to a spikes per second")

# Handle through which the rest of the script reads the flag values.
FLAGS = flags.FLAGS
# Note that with N small (it defaults to 50 above), finite-size effects
# will have pretty dramatic effects on the dynamics of the random RNN.
# If you want more complex dynamics, you'll have to run the script a
# lot, or increase N (or g).
# Getting hard vs. easy data can be a little stochastic, so the random
# number generator is seeded from a flag to keep runs reproducible.
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)

# Pull out some commonly used parameters.
# These are user parameters (configuration).
T = FLAGS.T  # time in seconds to generate
C = FLAGS.C  # number of conditions
N = FLAGS.N  # number of units in the RNN
S = FLAGS.S  # number of units sampled from the RNN per dataset
input_magnitude = FLAGS.input_magnitude
nreplications = FLAGS.nreplications
E = nreplications * C  # total number of trials
# S is the number of measurements in each dataset, with each dataset
# having a different set of observations. Floor division is used so the
# count stays an integer on Python 3 as well: plain `/` yields a float
# there, which range() rejects. Any remainder of N / S is intentionally
# dropped (rounding down is fine).
ndatasets = N // S
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters

rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)

# Check to make sure the RNN is the one we used in the paper. Only the
# N == 50 network has a known reference weight to compare against.
if N == 50:
    assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'

# The train/validation split is made per condition, so the number of
# training replications must come out as a whole number. This does not
# depend on N, so it is checked for every configuration.
rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
    'Train percentage * nreplications should be integral number.'
# Draw one initial state per condition and build the matching condition
# labels. This is done once, before the per-dataset loop, so that every
# dataset shares the same conditions (similar to a neurophys setup).
x0_per_condition = []
condition_labels = []
for condition_id in range(C):
    x0 = FLAGS.x0_std * rng.randn(N, 1)
    # All replications of a condition start from the same x0 ...
    x0_per_condition.append(np.tile(x0, nreplications))
    # ... and carry the same integer condition label.
    condition_labels.extend([condition_id] * nreplications)
# One column per trial, grouped by condition: N rows, C * nreplications columns.
x0s = np.concatenate(x0_per_condition, axis=1)
# Generated datasets, keyed by dataset name.
datasets = {}
for n in range(ndatasets):
    print(n+1, " of ", ndatasets)

    # The name encodes the RNN size and the number of sampled units; the
    # 1-based dataset index is appended only when units are sub-sampled.
    name_parts = ['dataset_N', str(N), '_S', str(S)]
    if S < N:
        name_parts += ['_n', str(n+1)]
    dataset_name = ''.join(name_parts)

    # Sample neuron subsets. The assumption is the PC axes of the RNN
    # are not unit aligned, so sampling units is adequate to sample all
    # the high-variance PCs. P_sxn is an S x N selection matrix whose
    # diagonal of ones is shifted right by n*S columns, so each dataset
    # observes a different group of S units.
    P_sxn = np.roll(np.eye(S, N), n * S, axis=1)

    # Without an input there are no "hit" times to record.
    input_times = None
    if input_magnitude > 0.0:
        # time of "hits" randomly chosen between [1/4 and 3/4] of total time
        input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)

    # First generate all the underlying firing rates; the noise
    # replications are drawn from them afterwards.
    rates, x0s, inputs = generate_data(
        rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
        input_magnitude=input_magnitude,
        input_times=input_times)

    # Turn the rates into noisy observations with the requested noise model.
    noise_generators = {"poisson": spikify_data, "gaussian": gaussify_data}
    if FLAGS.noise_type not in noise_generators:
        raise ValueError("Only noise types supported are poisson or gaussian")
    add_noise = noise_generators[FLAGS.noise_type]
    noisy_data = add_noise(rates, rng, rnn['dt'], rnn['max_firing_rate'])

    # Split the data, inputs, labels and times into train vs. validation.
    train_inds, valid_inds = get_train_n_valid_inds(
        E, train_percentage, nreplications)
    rates_train, rates_valid = split_list_by_inds(
        rates, train_inds, valid_inds)
    noisy_train, noisy_valid = split_list_by_inds(
        noisy_data, train_inds, valid_inds)
    inputs_train, inputs_valid = split_list_by_inds(
        inputs, train_inds, valid_inds)
    labels_train, labels_valid = split_list_by_inds(
        condition_labels, train_inds, valid_inds)
    times_train, times_valid = split_list_by_inds(
        input_times, train_inds, valid_inds)

    # Turn rates, noisy data, and inputs into numpy arrays.
    rates_train = nparray_and_transpose(rates_train)
    rates_valid = nparray_and_transpose(rates_valid)
    noisy_train = nparray_and_transpose(noisy_train)
    noisy_valid = nparray_and_transpose(noisy_valid)
    inputs_train = nparray_and_transpose(inputs_train)
    inputs_valid = nparray_and_transpose(inputs_valid)

    # Note that although the 'truth' rates and inputs are stored here,
    # the only data used in LFADS are the noisy data, e.g. spike trains.
    # The rest is either for printing or posterity.
    data = {'train_truth': rates_train,
            'valid_truth': rates_valid,
            'input_train_truth' : inputs_train,
            'input_valid_truth' : inputs_valid,
            'train_data' : noisy_train,
            'valid_data' : noisy_valid,
            'train_percentage' : train_percentage,
            'nreplications' : nreplications,
            'dt' : rnn['dt'],
            'input_magnitude' : input_magnitude,
            'input_times_train' : times_train,
            'input_times_valid' : times_valid,
            'P_sxn' : P_sxn,
            'condition_labels_train' : labels_train,
            'condition_labels_valid' : labels_valid,
            'conversion_factor': 1.0 / rnn['conversion_factor']}
    datasets[dataset_name] = data
# When only a subset of units is sampled per dataset (N > S), also attach
# alignment projections. This isn't necessary for this synthetic example,
# but it's useful to see how the input factor matrices were initialized
# for actual neurophysiology data.
if N > S:
    datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)

# Save all generated datasets under the configured directory and file name.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)