# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function

import h5py
import numpy as np
import os
import tensorflow as tf  # used for flags here

from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds

import matplotlib
import matplotlib.pyplot as plt
import scipy.signal

matplotlib.rcParams['image.interpolation'] = 'nearest'
DATA_DIR = "rnn_synth_data_v1.0"

flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
                    "Name of data file for input case.")
flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions.")
flags.DEFINE_integer("N", 50, "Number of units for the RNN.")
flags.DEFINE_integer("S", 50, "Number of sampled units from the RNN.")
flags.DEFINE_integer("npcs", 10, "Number of PCs for the multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Fraction of trials used for training (the rest are "
                   "validation).")
flags.DEFINE_integer("nreplications", 40,
                     "Number of noise replications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics.")
flags.DEFINE_float("x0_std", 1.0,
                   "Volume from which to pull initial conditions (affects "
                   "diversity of dynamics).")
flags.DEFINE_float("tau", 0.025, "Time constant of the RNN.")
flags.DEFINE_float("dt", 0.010, "Time bin, in seconds.")
flags.DEFINE_float("input_magnitude", 20.0,
                   "For the input case, the magnitude of the input pulse.")
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Map an RNN rate of 1.0 to this many spikes per second.")
FLAGS = flags.FLAGS
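
# Example invocation (the script filename here is hypothetical; the flags
# themselves are the ones defined above):
#   python generate_chaotic_rnn_data.py --save_dir=/tmp/rnn_synth_data_v1.0/ \
#       --datafile_name=thits_data --noise_type=poisson --synth_data_seed=5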

# Note that when N is small (e.g. the default of 50 above), finite-size
# effects will have a pretty dramatic impact on the dynamics of the random
# RNN. If you want more complex dynamics, you'll have to run the script
# many times, or increase N (or g).

# Getting hard vs. easy data can be a little stochastic, so we set the seed.

# Pull out some commonly used parameters.
# These are user parameters (configuration).
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
nreplications = FLAGS.nreplications
E = nreplications * C  # total number of trials
# S is the number of measurements in each dataset, with each dataset
# having a different set of observations.
ndatasets = N // S  # integer division; ok if rounded down
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters.
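
# With the default flags: E = 40 * 100 = 4000 trials,
# ndatasets = 50 // 50 = 1 dataset, and
# ntime_steps = int(1.0 / 0.010) = 100 time bins per trial.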
rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)

# Check to make sure the RNN is the one we used in the paper.
if N == 50:
  assert abs(rnn['W'][0, 0] - 0.06239899) < 1e-8, 'Error in random seed?'

rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
    'nreplications * train_percentage should be an integer.'
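# With the defaults this is 40 * 0.8 = 32, i.e. each condition contributes
# 32 training and 8 validation replications.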

# Initial condition generation, and condition label generation. This
# happens outside of the dataset loop so that all datasets have the
# same conditions, which is similar to a neurophysiology setup.
condition_number = 0
x0s = []
condition_labels = []
for c in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)
  x0s.append(np.tile(x0, nreplications))  # replicate x0 nreplications times
  # Replicate the condition label nreplications times.
  for ns in range(nreplications):
    condition_labels.append(condition_number)
  condition_number += 1
x0s = np.concatenate(x0s, axis=1)
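# Sanity check: x0s stacks C condition blocks of nreplications identical
# initial states each, so it should now have shape (N, E).
assert x0s.shape == (N, E)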

# Container for storing data across datasets.
datasets = {}
for n in range(ndatasets):
  print(n+1, " of ", ndatasets)

  # First generate all firing rates; the replications are generated in the
  # next loop. This keeps the random state for rate generation independent
  # of nreplications.
  dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
  if S < N:
    dataset_name += '_n' + str(n+1)

  # Sample neuron subsets. The assumption is that the PC axes of the RNN
  # are not unit aligned, so sampling units is adequate to sample all
  # the high-variance PCs.
  P_sxn = np.eye(S, N)
  for m in range(n):
    P_sxn = np.roll(P_sxn, S, axis=1)
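  # Row i of P_sxn now has its single 1 in column (i + n*S) mod N, so this
  # projection reads out the contiguous block of S units starting at unit n*S.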

  if input_magnitude > 0.0:
    # Times of the "hits" are randomly chosen between 1/4 and 3/4 of the
    # total time.
    input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
  else:
    input_times = None

  rates, x0s, inputs = \
      generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
                    input_magnitude=input_magnitude,
                    input_times=input_times)

  if FLAGS.noise_type == "poisson":
    noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  elif FLAGS.noise_type == "gaussian":
    noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  else:
    raise ValueError("Only noise types supported are poisson or gaussian.")
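  # spikify_data and gaussify_data (from synthetic_data_utils) presumably
  # turn the underlying rates (scaled by dt and max_firing_rate) into
  # Poisson spike counts or Gaussian-noised observations, respectively.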

  # Split into train and validation sets.
  train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                  nreplications)

  # Split the data, inputs, labels and times into train vs. validation.
  rates_train, rates_valid = \
      split_list_by_inds(rates, train_inds, valid_inds)
  noisy_data_train, noisy_data_valid = \
      split_list_by_inds(noisy_data, train_inds, valid_inds)
  inputs_train, inputs_valid = \
      split_list_by_inds(inputs, train_inds, valid_inds)
  condition_labels_train, condition_labels_valid = \
      split_list_by_inds(condition_labels, train_inds, valid_inds)
  input_times_train, input_times_valid = \
      split_list_by_inds(input_times, train_inds, valid_inds)

  # Turn rates, noisy_data, and inputs into numpy arrays.
  rates_train = nparray_and_transpose(rates_train)
  rates_valid = nparray_and_transpose(rates_valid)
  noisy_data_train = nparray_and_transpose(noisy_data_train)
  noisy_data_valid = nparray_and_transpose(noisy_data_valid)
  inputs_train = nparray_and_transpose(inputs_train)
  inputs_valid = nparray_and_transpose(inputs_valid)
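  # nparray_and_transpose presumably stacks each per-trial list into a 3D
  # array in the trials x time x channels layout that LFADS consumes.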

  # Note that while we store these 'truth' rates and inputs in this
  # structure, the only data LFADS actually uses are the noisy data,
  # e.g. the spike trains. The rest is either for printing or posterity.
  data = {'train_truth': rates_train,
          'valid_truth': rates_valid,
          'input_train_truth': inputs_train,
          'input_valid_truth': inputs_valid,
          'train_data': noisy_data_train,
          'valid_data': noisy_data_valid,
          'train_percentage': train_percentage,
          'nreplications': nreplications,
          'dt': rnn['dt'],
          'input_magnitude': input_magnitude,
          'input_times_train': input_times_train,
          'input_times_valid': input_times_valid,
          'P_sxn': P_sxn,
          'condition_labels_train': condition_labels_train,
          'condition_labels_valid': condition_labels_valid,
          'conversion_factor': 1.0 / rnn['conversion_factor']}
  datasets[dataset_name] = data

if S < N:
  # Note that this isn't necessary for this synthetic example, but it's
  # useful to see how the input factor matrices were initialized for
  # actual neurophysiology data.
  datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)

# Write out the datasets.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
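
# write_datasets (from utils) is expected to serialize each entry of
# `datasets` to HDF5 (hence the h5py import), writing files named from
# FLAGS.datafile_name and the dataset keys under FLAGS.save_dir.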