# Filtering Variational Objectives

This folder contains a TensorFlow implementation of the algorithms from

Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. "Filtering Variational Objectives." NIPS 2017.

[https://arxiv.org/abs/1705.09279](https://arxiv.org/abs/1705.09279)

This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
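For reference, the three objectives can be summarized as below. This is a standard summary and its notation is not taken from this codebase. $x_{1:T}$ is a sequence, $z_{1:T}$ its latent variables, $p$ the model, $q$ the proposal, and $K$ the number of samples or particles.

```math
\begin{aligned}
\mathcal{L}_{\mathrm{ELBO}} &= \mathbb{E}_{q}\left[\log \frac{p(x_{1:T}, z_{1:T})}{q(z_{1:T} \mid x_{1:T})}\right] \\
\mathcal{L}_{\mathrm{IWAE}} &= \mathbb{E}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x_{1:T}, z^{k}_{1:T})}{q(z^{k}_{1:T} \mid x_{1:T})}\right], \quad z^{k}_{1:T} \overset{\mathrm{iid}}{\sim} q(z_{1:T} \mid x_{1:T}) \\
\mathcal{L}_{\mathrm{FIVO}} &= \mathbb{E}\left[\log \prod_{t=1}^{T} \frac{1}{K} \sum_{k=1}^{K} w^{k}_{t}\right], \quad w^{k}_{t} = \frac{p(x_t, z^{k}_{t} \mid x_{1:t-1}, z^{k}_{1:t-1})}{q(z^{k}_{t} \mid x_{1:t}, z^{k}_{1:t-1})}
\end{aligned}
```

The FIVO line covers the case where the particle filter resamples at every step. In that case $w^{k}_{t}$ is the incremental importance weight of particle $k$, and $z^{k}_{1:t-1}$ is that particle's history after resampling. All three quantities are lower bounds on $\log p(x_{1:T})$. With $K = 1$, both IWAE and FIVO reduce to the ELBO.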
Additionally, it contains several sequential latent variable model implementations:

* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)

The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset and is useful as a simple example of an analytically tractable model.
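The likelihood of a linear-Gaussian model is tractable, so such a model is a convenient place to compare the three bounds against the exact log-likelihood. The NumPy sketch below does this for a hypothetical one-dimensional Gaussian random walk. It is not the repository's `ghmm.py` or `smc.py`. It also assumes the simplest possible proposal, the prior transition, which turns the FIVO estimator into a bootstrap particle filter. `kalman_log_likelihood`, `iwae_estimate`, and `fivo_estimate` are illustrative names.

```python
import numpy as np

# Hypothetical toy model (not the repository's ghmm.py), a 1-D Gaussian random walk:
#   z_0 = 0,   z_t = z_{t-1} + N(0, q_var),   x_t = z_t + N(0, r_var).


def log_normal(x, mean, var):
    """Elementwise log density of N(mean, var) at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(a):
    """Numerically stable log(mean(exp(a)))."""
    m = np.max(a)
    return m + np.log(np.mean(np.exp(a - m)))


def kalman_log_likelihood(x, q_var, r_var):
    """Exact log p(x_{1:T}) from a scalar Kalman filter."""
    mean, var, ll = 0.0, 0.0, 0.0
    for x_t in x:
        pred_var = var + q_var                         # predict z_t from z_{t-1}
        ll += log_normal(x_t, mean, pred_var + r_var)  # add log p(x_t | x_{1:t-1})
        gain = pred_var / (pred_var + r_var)           # condition on x_t
        mean += gain * (x_t - mean)
        var = (1.0 - gain) * pred_var
    return ll


def iwae_estimate(x, q_var, r_var, num_samples, rng):
    """log (1/K) sum_k w^k over K independent trajectories drawn from the prior.
    With num_samples=1 this is a single-sample ELBO estimate."""
    z = np.zeros(num_samples)
    log_w = np.zeros(num_samples)
    for x_t in x:
        z = z + np.sqrt(q_var) * rng.standard_normal(num_samples)
        log_w += log_normal(x_t, z, r_var)
    return log_mean_exp(log_w)


def fivo_estimate(x, q_var, r_var, num_samples, rng):
    """log prod_t (1/K) sum_k w_t^k from a bootstrap particle filter
    that resamples (multinomially) after every step."""
    z = np.zeros(num_samples)
    log_z_hat = 0.0
    for x_t in x:
        z = z + np.sqrt(q_var) * rng.standard_normal(num_samples)
        log_w = log_normal(x_t, z, r_var)
        log_z_hat += log_mean_exp(log_w)
        probs = np.exp(log_w - np.max(log_w))
        z = z[rng.choice(num_samples, size=num_samples, p=probs / probs.sum())]
    return log_z_hat


rng = np.random.default_rng(0)
q_var, r_var, T = 1.0, 0.5, 20
x = np.cumsum(np.sqrt(q_var) * rng.standard_normal(T)) + np.sqrt(r_var) * rng.standard_normal(T)

exact = kalman_log_likelihood(x, q_var, r_var)
elbo = np.mean([iwae_estimate(x, q_var, r_var, 1, rng) for _ in range(200)])
iwae = np.mean([iwae_estimate(x, q_var, r_var, 100, rng) for _ in range(200)])
fivo = np.mean([fivo_estimate(x, q_var, r_var, 100, rng) for _ in range(200)])
print(f"exact {exact:.2f} | ELBO {elbo:.2f} | IWAE {iwae:.2f} | FIVO {fivo:.2f}")
```

With the prior as proposal, the importance weights are very uneven. In this toy setting, the ELBO and IWAE averages should therefore land well below the exact value, while the resampling particle filter should track it much more closely. The exact printed numbers depend on the seed.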
#### Directory Structure

The important parts of the code are organized as follows.

```
run_fivo.py                        # main script, contains flag definitions
fivo
├─smc.py                           # a sequential Monte Carlo implementation
├─bounds.py                        # code for computing each bound, uses smc.py
├─runners.py                       # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py                  # code for GHMM training and evaluation
├─data
| ├─datasets.py                    # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py    # preprocesses the pianoroll datasets
| └─create_timit_dataset.py        # preprocesses the TIMIT dataset
└─models
  ├─base.py                        # base classes used in other models
  ├─vrnn.py                        # VRNN implementation
  ├─srnn.py                        # SRNN implementation
  └─ghmm.py                        # Gaussian hidden Markov model (GHMM) implementation
bin
├─run_train.sh                     # an example script that runs training
├─run_eval.sh                      # an example script that runs evaluation
├─run_sample.sh                    # an example script that runs sampling
├─run_tests.sh                     # a script that runs all tests
└─download_pianorolls.sh           # a script that downloads pianoroll files
```
### Pianorolls

Requirements before we start:

* TensorFlow (see [tensorflow.org](http://tensorflow.org) for how to install)
* [scipy](https://www.scipy.org/)
* [sonnet](https://github.com/deepmind/sonnet)

#### Download the Data

The pianoroll datasets are encoded as pickled sparse arrays and are available at [http://www-etud.iro.umontreal.ca/~boulanni/icml2012](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). You can use the script `bin/download_pianorolls.sh` to download the files into a directory of your choosing.

```
export PIANOROLL_DIR=~/pianorolls
mkdir $PIANOROLL_DIR
sh bin/download_pianorolls.sh $PIANOROLL_DIR
```
#### Preprocess the Data

The script `calculate_pianoroll_mean.py` loads a pianoroll pickle file and calculates the mean of the training set. It then adds that mean to the pickle under the key `train_mean` and writes the file back to disk in place. Run it on every pianoroll dataset you wish to train on.

```
python fivo/data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/piano-midi.de.pkl
python fivo/data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/nottingham.pkl
python fivo/data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/musedata.pkl
python fivo/data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
```
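Below is a sketch of what this preprocessing step amounts to. It is not the source of `calculate_pianoroll_mean.py`. The sketch makes two assumptions. First, each pickle holds a dict with `train`, `valid`, and `test` splits, where every sequence is a list of timesteps and every timestep lists the active MIDI note numbers. Second, the 88 piano keys map to MIDI notes 21 through 108. `sparse_to_dense`, `compute_train_mean`, and `add_train_mean_in_place` are hypothetical names.

```python
import pickle

import numpy as np

# Assumed layout (not spelled out in this README): {"train": [...], "valid": [...],
# "test": [...]}, each sequence a list of timesteps, each timestep a list of the
# active MIDI note numbers. The 88 piano keys are assumed to be MIDI notes 21-108.
MIN_NOTE, NUM_NOTES = 21, 88


def sparse_to_dense(sequence, min_note=MIN_NOTE, num_notes=NUM_NOTES):
    """Convert a list of active-note lists into a binary [T, num_notes] array."""
    dense = np.zeros((len(sequence), num_notes))
    for t, notes in enumerate(sequence):
        for note in notes:
            dense[t, int(note) - min_note] = 1.0
    return dense


def compute_train_mean(pianorolls):
    """Per-key mean over every timestep of every training sequence."""
    dense = [sparse_to_dense(seq) for seq in pianorolls["train"]]
    return np.concatenate(dense, axis=0).mean(axis=0)


def add_train_mean_in_place(path):
    """Hypothetical helper: load the pickle, add "train_mean", overwrite the file."""
    with open(path, "rb") as f:
        pianorolls = pickle.load(f, encoding="latin1")  # tolerate Python 2 pickles
    pianorolls["train_mean"] = compute_train_mean(pianorolls)
    with open(path, "wb") as f:
        pickle.dump(pianorolls, f)


# Toy data: two training sequences with two timesteps each.
toy = {"train": [[[60, 64], [60]], [[], [67]]], "valid": [], "test": []}
train_mean = compute_train_mean(toy)
print(train_mean.shape)           # (88,)
print(train_mean[60 - MIN_NOTE])  # 0.5: note 60 is active in 2 of the 4 timesteps
```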
#### Training

Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:

```
python run_fivo.py \
  --mode=train \
  --logdir=/tmp/fivo \
  --model=vrnn \
  --bound=fivo \
  --summarize_every=100 \
  --batch_size=4 \
  --num_samples=4 \
  --learning_rate=0.0001 \
  --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
  --dataset_type="pianoroll"
```
You should see output that looks something like this, interleaved with additional logging output:

```
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
#### Evaluation

You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood bounds averaged over the dataset. For example, here is a command, taken from `bin/run_eval.sh`, that evaluates a JSB model on the test set:

```
python run_fivo.py \
  --mode=eval \
  --split=test \
  --alsologtostderr \
  --logdir=/tmp/fivo \
  --model=vrnn \
  --batch_size=4 \
  --num_samples=4 \
  --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
  --dataset_type="pianoroll"
```

You should see output like this:

```
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```

The evaluation script prints all three bounds on the log-likelihood, in both nats per timestep (ll/t) and nats per sequence (ll/seq).
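The two units are related by simple arithmetic. Assume that ll/t divides the summed bound by the total number of timesteps, and that ll/seq divides the same sum by the number of sequences. Under that assumption, their ratio is the average sequence length of the evaluated split. The snippet below checks this against the example output:

```python
# Bounds copied from the example evaluation output above.
ll_per_timestep = {"elbo": -12.198834, "iwae": -11.981187, "fivo": -11.579776}
ll_per_sequence = {"elbo": -748.564789, "iwae": -735.209206, "fivo": -710.577141}

# Under the stated assumption, ll/seq / ll/t = total timesteps / number of sequences.
ratios = {b: ll_per_sequence[b] / ll_per_timestep[b] for b in ll_per_timestep}
for bound, ratio in ratios.items():
    print(f"{bound}: {ratio:.2f} timesteps per sequence")
# elbo: 61.36 timesteps per sequence
# iwae: 61.36 timesteps per sequence
# fivo: 61.36 timesteps per sequence
```

All three ratios agree, which is consistent with that reading of the two units. In either unit the ordering is FIVO > IWAE > ELBO, and a higher lower bound is a tighter one.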
#### Sampling

You can also sample from trained models. The `sample` mode first loads a model checkpoint and conditions the model on prefixes of randomly chosen datapoints. It then samples a sequence of outputs from the conditioned model and writes the samples and prefixes to a `.npz` file in `logdir`. For example, here is a command, taken from `bin/run_sample.sh`, that samples from a model trained on JSB:

```
python run_fivo.py \
  --mode=sample \
  --alsologtostderr \
  --logdir="/tmp/fivo" \
  --model=vrnn \
  --bound=fivo \
  --batch_size=4 \
  --num_samples=4 \
  --split=test \
  --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
  --dataset_type="pianoroll" \
  --prefix_length=25 \
  --sample_length=50
```

Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.

You should see very little output:

```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```

Loading the samples with `np.load` confirms that we conditioned the model on 4 prefixes of length 25 and sampled 4 sequences of length 50 for each prefix. Recent NumPy versions need `allow_pickle=True` to load the stored dictionary.

```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz", allow_pickle=True)
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
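To inspect a sample, it helps to stitch a prefix and one of its continuations into a single pianoroll. The sketch below does this. Its axis order is an assumption inferred from the printed shapes: `prefixes` as [prefix_length, batch, notes] and `samples` as [sample_length, batch, num_samples, notes]. Because `--batch_size` and `--num_samples` are both 4 here, the shapes alone cannot confirm which of the two middle axes of `samples` is which. `full_sequence` is a hypothetical helper, and random arrays stand in for the real file.

```python
import numpy as np

# ASSUMED axis order (inferred from shapes, not documented):
#   prefixes: [prefix_length, batch, notes]
#   samples:  [sample_length, batch, num_samples, notes]


def full_sequence(prefixes, samples, prefix_index, sample_index):
    """Concatenate one conditioning prefix with one sampled continuation in time."""
    prefix = prefixes[:, prefix_index, :]                     # [prefix_length, notes]
    continuation = samples[:, prefix_index, sample_index, :]  # [sample_length, notes]
    return np.concatenate([prefix, continuation], axis=0)


# Random stand-ins with the shapes shown above. For real output, something like:
#   x = np.load("/tmp/fivo/samples.npz", allow_pickle=True)[()]
#   prefixes, samples = x["prefixes"], x["samples"]
rng = np.random.default_rng(0)
prefixes = (rng.random((25, 4, 88)) < 0.1).astype(np.float32)
samples = (rng.random((50, 4, 4, 88)) < 0.1).astype(np.float32)

seq = full_sequence(prefixes, samples, prefix_index=0, sample_index=2)
print(seq.shape)  # (75, 88)
active = [np.flatnonzero(step) for step in seq]  # active key indices per timestep
```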
### TIMIT

The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions assume you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.

#### Preprocess TIMIT

We preprocess TIMIT (as described in our paper) and write it out to a series of TFRecord files. To prepare the TIMIT dataset, use the script `create_timit_dataset.py`:

```
export TIMIT_DIR=~/timit_dataset
mkdir $TIMIT_DIR
python fivo/data/create_timit_dataset.py \
  --raw_timit_dir=$RAW_TIMIT_DIR \
  --out_dir=$TIMIT_DIR
```

You should see this exact output:

```
4389 train / 231 valid / 1680 test
train mean: 0.006060 train std: 548.136169
```
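Judging by their magnitude, the printed statistics appear to be the mean and standard deviation of the raw training waveforms. The sketch below shows one plausible way to apply such statistics: standardize the waveform, then cut it into fixed-length frames. Several details are assumptions made for illustration: the frame length of 200 samples, dropping the trailing partial frame, and the `frame_and_standardize` helper itself. None of them are taken from `create_timit_dataset.py`.

```python
import numpy as np

# Statistics printed by create_timit_dataset.py above.
TRAIN_MEAN, TRAIN_STD = 0.006060, 548.136169


def frame_and_standardize(waveform, frame_length=200):
    """Standardize a 1-D waveform and split it into non-overlapping frames.

    ASSUMPTIONS: frame_length=200 and dropping the trailing partial frame are
    illustrative choices, not documented behavior of the preprocessing script.
    """
    x = (np.asarray(waveform, dtype=np.float64) - TRAIN_MEAN) / TRAIN_STD
    num_frames = len(x) // frame_length
    return x[: num_frames * frame_length].reshape(num_frames, frame_length)


# Hypothetical int16 waveform stand-in; the last 30 samples do not fill a frame.
rng = np.random.default_rng(0)
waveform = rng.integers(-2000, 2000, size=1030, dtype=np.int16)
frames = frame_and_standardize(waveform)
print(frames.shape)  # (5, 200)
```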
#### Training on TIMIT

Training on TIMIT is very similar to training on the pianoroll datasets; only the `--dataset_path` and `--dataset_type` flags change.

```
python run_fivo.py \
  --mode=train \
  --logdir=/tmp/fivo \
  --model=vrnn \
  --bound=fivo \
  --summarize_every=100 \
  --batch_size=4 \
  --num_samples=4 \
  --learning_rate=0.0001 \
  --dataset_path="$TIMIT_DIR/train" \
  --dataset_type="speech"
```

Evaluation and sampling on TIMIT are similar.
### Tests

This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful as examples of how to use the code.

### Contact

This codebase is maintained by Dieterich Lawson. For questions and issues, please open an issue on the tensorflow/models issue tracker and assign it to @dieterichlawson.