# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from collections import defaultdict | |
import numpy as np | |
import tensorflow as tf | |
import bounds | |
import data | |
import models | |
import summary_utils as summ | |
tf.logging.set_verbosity(tf.logging.INFO) | |
tf.app.flags.DEFINE_integer("random_seed", None, | |
"A random seed for the data generating process. Same seed " | |
"-> same data generating process and initialization.") | |
tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"], | |
"The bound to optimize.") | |
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"], | |
"The model to use.") | |
tf.app.flags.DEFINE_enum("q_type", "normal", | |
["normal", "simple_mean", "prev_state", "observation"], | |
"The parameterization to use for q") | |
tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"], | |
"The type of prior.") | |
tf.app.flags.DEFINE_boolean("train_p", True, | |
"If false, do not train the model p.") | |
tf.app.flags.DEFINE_integer("state_size", 1, | |
"The dimensionality of the state space.") | |
tf.app.flags.DEFINE_float("variance", 1.0, | |
"The variance of the data generating process.") | |
tf.app.flags.DEFINE_boolean("use_bs", True, | |
"If False, initialize all bs to 0.") | |
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5, | |
"The weight assigned to the positive mode of the prior in " | |
"both the data generating process and p.") | |
tf.app.flags.DEFINE_float("bimodal_prior_mean", None, | |
"If supplied, sets the mean of the 2 modes of the prior to " | |
"be 1 and -1 times the supplied value. This is for both the " | |
"data generating process and p.") | |
tf.app.flags.DEFINE_float("fixed_observation", None, | |
"If supplied, fix the observation to a constant value in the" | |
" data generating process only.") | |
tf.app.flags.DEFINE_float("r_sigma_init", 1., | |
"Value to initialize variance of r to.") | |
tf.app.flags.DEFINE_enum("observation_type", | |
models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES, | |
"The type of observation for the long chain model.") | |
tf.app.flags.DEFINE_enum("transition_type", | |
models.STANDARD_TRANSITION, models.TRANSITION_TYPES, | |
"The type of transition for the long chain model.") | |
tf.app.flags.DEFINE_float("observation_variance", None, | |
"The variance of the observation. Defaults to 'variance'") | |
tf.app.flags.DEFINE_integer("num_timesteps", 5, | |
"Number of timesteps in the sequence.") | |
tf.app.flags.DEFINE_integer("num_observations", 1, | |
"The number of observations.") | |
tf.app.flags.DEFINE_integer("steps_per_observation", 5, | |
"The number of timesteps between each observation.") | |
tf.app.flags.DEFINE_integer("batch_size", 4, | |
"The number of examples per batch.") | |
tf.app.flags.DEFINE_integer("num_samples", 4, | |
"The number particles to use.") | |
tf.app.flags.DEFINE_integer("num_eval_samples", 512, | |
"The batch size and # of particles to use for eval.") | |
tf.app.flags.DEFINE_string("resampling", "always", | |
"How to resample. Accepts 'always','never', or a " | |
"comma-separated list of booleans like 'true,true,false'.") | |
tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial", | |
"stratified", | |
"systematic", | |
"relaxed-logblend", | |
"relaxed-stateblend", | |
"relaxed-linearblend", | |
"relaxed-stateblend-st",], | |
"Type of resampling method to use.") | |
tf.app.flags.DEFINE_boolean("use_resampling_grads", True, | |
"Whether or not to use resampling grads to optimize FIVO." | |
"Disabled automatically if resampling_method=relaxed.") | |
tf.app.flags.DEFINE_boolean("disable_r", False, | |
"If false, r is not used for fivo-aux and is set to zeros.") | |
tf.app.flags.DEFINE_float("learning_rate", 1e-4, | |
"The learning rate to use for ADAM or SGD.") | |
tf.app.flags.DEFINE_integer("decay_steps", 25000, | |
"The number of steps before the learning rate is halved.") | |
tf.app.flags.DEFINE_integer("max_steps", int(1e6), | |
"The number of steps to run training for.") | |
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux", | |
"Directory for summaries and checkpoints.") | |
tf.app.flags.DEFINE_integer("summarize_every", int(1e3), | |
"The number of steps between each evaluation.") | |
FLAGS = tf.app.flags.FLAGS | |
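
# A sketch of a typical invocation; the script name and flag values below are
# illustrative assumptions, not taken from this file:
#   python run.py \
#     --bound=fivo \
#     --model=forward \
#     --num_timesteps=5 \
#     --num_samples=4 \
#     --resampling=always \
#     --logdir=/tmp/fivo-aux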

def combine_grad_lists(grad_lists):
  """Sums gradients per variable across a list of per-loss gradient lists.

  grad_lists is num_losses by num_variables, and each list may contain
  different variables. For each variable, the grads are summed across
  all losses.
  """
  grads_dict = defaultdict(list)
  var_dict = {}
  for grad_list in grad_lists:
    for grad, var in grad_list:
      if grad is not None:
        grads_dict[var.name].append(grad)
      var_dict[var.name] = var

  final_grads = []
  for var_name, var in var_dict.iteritems():
    grads = grads_dict[var_name]
    if len(grads) > 0:
      tf.logging.info("Var %s has combined grads from %s." %
                      (var_name, [g.name for g in grads]))
      grad = tf.reduce_sum(grads, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % var_name)
      grad = None
    final_grads.append((grad, var))
  return final_grads
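
# A minimal sketch of the expected behaviour (g1, g2, g3, w, b are hypothetical
# gradient tensors and variables, not defined in this file):
#   grad_lists = [[(g1, w), (g2, b)],    # grads of loss 1
#                 [(g3, w), (None, b)]]  # grads of loss 2
#   combine_grad_lists(grad_lists)       # -> [(g1 + g3, w), (g2, b)], up to ordering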

def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  """Computes gradients for each loss, combines them, and applies them."""
  for l in losses:
    assert isinstance(l, bounds.Loss)

  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)

  ema_ops = []
  grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    g = opt.compute_gradients(loss,
                              var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(g, loss_name))
    grads.append(g)

  all_grads = combine_grad_lists(grads)
  apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads_op]):
    train_op = tf.group(*ema_ops)
  return train_op
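
# `losses` is expected to be an iterable of bounds.Loss tuples of the form
# (name, loss, var_collection), matching the unpacking above. A hypothetical call:
#   losses = [bounds.Loss("q_loss", q_loss, "Q_VARS"),
#             bounds.Loss("p_loss", p_loss, "P_VARS")]
#   train_op = make_apply_grads_op(losses, global_step, 1e-4, 25000)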

def add_check_numerics_ops():
  """Adds a `tf.check_numerics` op for every floating point tensor in the graph."""
  check_op = []
  # Op name substrings to skip when adding checks.
  bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
         "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
         "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
  for op in tf.get_default_graph().get_operations():
    if all([x not in op.name for x in bad]):
      for output in op.outputs:
        if output.dtype in [tf.float16, tf.float32, tf.float64]:
          if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
            raise ValueError("`tf.add_check_numerics_ops()` is not compatible "
                             "with TensorFlow control flow operations such as "
                             "`tf.cond()` or `tf.while_loop()`.")
          message = op.name + ":" + str(output.value_index)
          with tf.control_dependencies(check_op):
            check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)

def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
                            batch_size, num_samples, num_eval_samples,
                            resampling_schedule, use_resampling_grads,
                            learning_rate, lr_decay_steps, dtype="float64"):
  """Creates the training and eval graph for the long chain model."""
  num_timesteps = num_obs * steps_per_obs + 1

  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()

  # Make the dataset for eval.
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)

  # Compute the bound and loss.
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif bound in ("fivo", "fivo-aux"):
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)
  tf.summary.scalar("log_p_hat", eval_log_p_hat)

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)

  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood

def create_graph(bound, state_size, num_timesteps, batch_size,
                 num_samples, num_eval_samples, resampling_schedule,
                 use_resampling_grads, learning_rate, lr_decay_steps,
                 train_p, dtype='float64'):
  """Creates the training and eval graph for the forward model."""
  if FLAGS.use_bs:
    true_bs = None
  else:
    true_bs = [np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)]

  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()

  # Make the dataset for eval.
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)

  # Compute the bound and loss.
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  # if FLAGS.p_type == "bimodal":
  #   # Create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # Run the model on the explainable observations.
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound in ("fivo", "fivo-aux"):
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)

  # Calculate the true likelihood.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=(bound != "fivo-aux-td"))

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    # train_op = tf.group(ema_op, add_check_numerics_ops())
  return global_step, train_op, eval_log_p_hat, eval_likelihood

def parse_resampling_schedule(schedule, num_timesteps):
  """Parses the resampling flag into a list of per-timestep booleans."""
  schedule = schedule.strip().lower()
  if schedule == "always":
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    n = int(schedule.split("_")[1])
    return [(i + 1) % n == 0 for i in xrange(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(sched) == num_timesteps, (
        "Wrong number of timesteps in resampling schedule.")
    return sched
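
# Examples of the accepted formats, for num_timesteps=4:
#   parse_resampling_schedule("always", 4)   -> [True, True, True, False]
#   parse_resampling_schedule("never", 4)    -> [False, False, False, False]
#   parse_resampling_schedule("every_2", 4)  -> [False, True, False, True]
#   parse_resampling_schedule("true,false,true,false", 4)
#                                            -> [True, False, True, False]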

def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  """Creates a LoggingTensorHook that prints eval stats during training."""
  def summ_formatter(d):
    return ("Step {step}, log p_hat: {log_p_hat:.5f} "
            "likelihood: {likelihood:.5f}".format(**d))
  hook = tf.train.LoggingTensorHook(
      {
          "step": step,
          "log_p_hat": eval_log_p_hat,
          "likelihood": eval_likelihood,
      },
      every_n_iter=FLAGS.summarize_every,
      formatter=summ_formatter)
  return hook

def create_infrequent_summary_hook():
  """Creates a SummarySaverHook that saves infrequent summaries every 10000 steps."""
  infrequent_summary_hook = tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=tf.summary.merge_all(key="infrequent_summaries"))
  return infrequent_summary_hook
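
# Infrequent summaries are created elsewhere (e.g. in summary_utils) by adding
# summary ops to the "infrequent_summaries" collection; a hypothetical registration:
#   tf.summary.scalar("expensive_metric", metric,
#                     collections=["infrequent_summaries"])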

def main(unused_argv):
  """Entry point: checks flags, builds the graph, and runs training."""
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)

  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", (
        "Q type %s not supported for long chain models" % FLAGS.q_type)
    assert FLAGS.p_type == "unimodal", (
        "Bimodal priors are not supported for long chain models")
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, (
        "Num timesteps does not match.")
    assert FLAGS.bound != "fivo-aux-td", (
        "TD training is not compatible with long chain models.")

  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, (
          "Non-standard transitions not supported by the forward model.")
    assert FLAGS.observation_type == models.STANDARD_OBSERVATION, (
        "Non-standard observations not supported by the forward model.")
    assert FLAGS.observation_variance is None, (
        "Forward model does not support observation variance.")
    assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."

  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", (
        "TD training is not compatible with relaxed resampling.")

  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance

  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, (
        "Must specify prior mean if using bimodal p.")

  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."

  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)

    if FLAGS.model == "long_chain":
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op,
       eval_log_p_hat, eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)

    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())

    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])

    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # Run a step.
        _, cur_step = sess.run([train_op, global_step])


if __name__ == "__main__":
  tf.app.run(main)