# Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Trains the PixelDA model.""" from functools import partial import os # Dependency imports import tensorflow as tf from domain_adaptation.datasets import dataset_factory from domain_adaptation.pixel_domain_adaptation import pixelda_losses from domain_adaptation.pixel_domain_adaptation import pixelda_model from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess from domain_adaptation.pixel_domain_adaptation import pixelda_utils from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams slim = tf.contrib.slim flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.') flags.DEFINE_integer( 'ps_tasks', 0, 'The number of parameter servers. If the value is 0, then the parameters ' 'are handled locally by the worker.') flags.DEFINE_integer( 'task', 0, 'The Task ID. This value is used when training with multiple workers to ' 'identify each worker.') flags.DEFINE_string('train_log_dir', '/tmp/pixelda/', 'Directory where to write event logs.') flags.DEFINE_integer( 'save_summaries_steps', 500, 'The frequency with which summaries are saved, in seconds.') flags.DEFINE_integer('save_interval_secs', 300, 'The frequency with which the model is saved, in seconds.') flags.DEFINE_boolean('summarize_gradients', False, 'Whether to summarize model gradients') flags.DEFINE_integer( 'print_loss_steps', 100, 'The frequency with which the losses are printed, in steps.') flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.' ' If hparams="arch=dcgan", this flag is ignored.') flags.DEFINE_string('target_dataset', 'mnist_m', 'The name of the target dataset.') flags.DEFINE_string('source_split_name', 'train', 'Name of the train split for the source.') flags.DEFINE_string('target_split_name', 'train', 'Name of the train split for the target.') flags.DEFINE_string('dataset_dir', '', 'The directory where the datasets can be found.') flags.DEFINE_integer( 'num_readers', 4, 'The number of parallel readers that read data from the dataset.') flags.DEFINE_integer('num_preprocessing_threads', 4, 'The number of threads used to create the batches.') # HParams flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values') def _get_vars_and_update_ops(hparams, scope): """Returns the variables and update ops for a particular variable scope. Args: hparams: The hyperparameters struct. scope: The variable scope. Returns: A tuple consisting of trainable variables and update ops. """ is_trainable = lambda x: x in tf.trainable_variables() var_list = filter(is_trainable, slim.get_model_variables(scope)) global_step = slim.get_or_create_global_step() update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope) tf.logging.info('All variables for scope: %s', slim.get_model_variables(scope)) tf.logging.info('Trainable variables for scope: %s', var_list) return var_list, update_ops def _train(discriminator_train_op, generator_train_op, logdir, master='', is_chief=True, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, hparams=None): """Runs the training loop. Args: discriminator_train_op: A `Tensor` that, when executed, will apply the gradients and return the loss value for the discriminator. generator_train_op: A `Tensor` that, when executed, will apply the gradients and return the loss value for the generator. logdir: The directory where the graph and checkpoints are saved. master: The URL of the master. is_chief: Specifies whether or not the training is being run by the primary replica during replica training. scaffold: An tf.train.Scaffold instance. hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the training loop. chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run inside the training loop for the chief trainer only. save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If `save_checkpoint_secs` is set to `None`, then the default checkpoint saver isn't used. save_summaries_steps: The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If `save_summaries_steps` is set to `None`, then the default summary saver isn't used. hparams: The hparams struct. Returns: the value of the loss function after training. Raises: ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or `save_summaries_steps` are `None. """ global_step = slim.get_or_create_global_step() scaffold = scaffold or tf.train.Scaffold() hooks = hooks or [] if is_chief: session_creator = tf.train.ChiefSessionCreator( scaffold=scaffold, checkpoint_dir=logdir, master=master) if chief_only_hooks: hooks.extend(chief_only_hooks) hooks.append(tf.train.StepCounterHook(output_dir=logdir)) if save_summaries_steps: if logdir is None: raise ValueError( 'logdir cannot be None when save_summaries_steps is None') hooks.append( tf.train.SummarySaverHook( scaffold=scaffold, save_steps=save_summaries_steps, output_dir=logdir)) if save_checkpoint_secs: if logdir is None: raise ValueError( 'logdir cannot be None when save_checkpoint_secs is None') hooks.append( tf.train.CheckpointSaverHook( logdir, save_secs=save_checkpoint_secs, scaffold=scaffold)) else: session_creator = tf.train.WorkerSessionCreator( scaffold=scaffold, master=master) with tf.train.MonitoredSession( session_creator=session_creator, hooks=hooks) as session: loss = None while not session.should_stop(): # Run the domain classifier op X times. for _ in range(hparams.discriminator_steps): if session.should_stop(): return loss loss, np_global_step = session.run( [discriminator_train_op, global_step]) if np_global_step % FLAGS.print_loss_steps == 0: tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step, loss) # Run the generator op X times. for _ in range(hparams.generator_steps): if session.should_stop(): return loss loss, np_global_step = session.run([generator_train_op, global_step]) if np_global_step % FLAGS.print_loss_steps == 0: tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step, loss) return loss def run_training(run_dir, checkpoint_dir, hparams): """Runs the training loop. Args: run_dir: The directory where training specific logs are placed checkpoint_dir: The directory where the checkpoints and log files are stored. hparams: The hyperparameters struct. Raises: ValueError: if hparams.arch is not recognized. """ for path in [run_dir, checkpoint_dir]: if not tf.gfile.Exists(path): tf.gfile.MakeDirs(path) # Serialize hparams to log dir hparams_filename = os.path.join(checkpoint_dir, 'hparams.json') with tf.gfile.FastGFile(hparams_filename, 'w') as f: f.write(hparams.to_json()) with tf.Graph().as_default(): with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): global_step = slim.get_or_create_global_step() ######################### # Preprocess the inputs # ######################### target_dataset = dataset_factory.get_dataset( FLAGS.target_dataset, split_name='train', dataset_dir=FLAGS.dataset_dir) target_images, _ = dataset_factory.provide_batch( FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers, hparams.batch_size, FLAGS.num_preprocessing_threads) num_target_classes = target_dataset.num_classes if hparams.arch not in ['dcgan']: source_dataset = dataset_factory.get_dataset( FLAGS.source_dataset, split_name='train', dataset_dir=FLAGS.dataset_dir) num_source_classes = source_dataset.num_classes source_images, source_labels = dataset_factory.provide_batch( FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers, hparams.batch_size, FLAGS.num_preprocessing_threads) # Data provider provides 1 hot labels, but we expect categorical. source_labels['class'] = tf.argmax(source_labels['classes'], 1) del source_labels['classes'] if num_source_classes != num_target_classes: raise ValueError( 'Source and Target datasets must have same number of classes. ' 'Are %d and %d' % (num_source_classes, num_target_classes)) else: source_images = None source_labels = None #################### # Define the model # #################### end_points = pixelda_model.create_model( hparams, target_images, source_images=source_images, source_labels=source_labels, is_training=True, num_classes=num_target_classes) ################################# # Get the variables to optimize # ################################# generator_vars, generator_update_ops = _get_vars_and_update_ops( hparams, 'generator') discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops( hparams, 'discriminator') ######################## # Configure the losses # ######################## generator_loss = pixelda_losses.g_step_loss( source_images, source_labels, end_points, hparams, num_classes=num_target_classes) discriminator_loss = pixelda_losses.d_step_loss( end_points, source_labels, num_target_classes, hparams) ########################### # Create the training ops # ########################### learning_rate = hparams.learning_rate if hparams.lr_decay_steps: learning_rate = tf.train.exponential_decay( learning_rate, slim.get_or_create_global_step(), decay_steps=hparams.lr_decay_steps, decay_rate=hparams.lr_decay_rate, staircase=True) tf.summary.scalar('Learning_rate', learning_rate) if hparams.discriminator_steps == 0: discriminator_train_op = tf.no_op() else: discriminator_optimizer = tf.train.AdamOptimizer( learning_rate, beta1=hparams.adam_beta1) discriminator_train_op = slim.learning.create_train_op( discriminator_loss, discriminator_optimizer, update_ops=discriminator_update_ops, variables_to_train=discriminator_vars, clip_gradient_norm=hparams.clip_gradient_norm, summarize_gradients=FLAGS.summarize_gradients) if hparams.generator_steps == 0: generator_train_op = tf.no_op() else: generator_optimizer = tf.train.AdamOptimizer( learning_rate, beta1=hparams.adam_beta1) generator_train_op = slim.learning.create_train_op( generator_loss, generator_optimizer, update_ops=generator_update_ops, variables_to_train=generator_vars, clip_gradient_norm=hparams.clip_gradient_norm, summarize_gradients=FLAGS.summarize_gradients) ############# # Summaries # ############# pixelda_utils.summarize_model(end_points) pixelda_utils.summarize_transferred_grid( end_points['transferred_images'], source_images, name='Transferred') if 'source_images_recon' in end_points: pixelda_utils.summarize_transferred_grid( end_points['source_images_recon'], source_images, name='Source Reconstruction') pixelda_utils.summaries_color_distributions(end_points['transferred_images'], 'Transferred') pixelda_utils.summaries_color_distributions(target_images, 'Target') if source_images is not None: pixelda_utils.summarize_transferred(source_images, end_points['transferred_images']) pixelda_utils.summaries_color_distributions(source_images, 'Source') pixelda_utils.summaries_color_distributions( tf.abs(source_images - end_points['transferred_images']), 'Abs(Source_minus_Transferred)') number_of_steps = None if hparams.num_training_examples: # Want to control by amount of data seen, not # steps number_of_steps = hparams.num_training_examples / hparams.batch_size hooks = [tf.train.StepCounterHook(),] chief_only_hooks = [ tf.train.CheckpointSaverHook( saver=tf.train.Saver(), checkpoint_dir=run_dir, save_secs=FLAGS.save_interval_secs) ] if number_of_steps: hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps)) _train( discriminator_train_op, generator_train_op, logdir=run_dir, master=FLAGS.master, is_chief=FLAGS.task == 0, hooks=hooks, chief_only_hooks=chief_only_hooks, save_checkpoint_secs=None, save_summaries_steps=FLAGS.save_summaries_steps, hparams=hparams) def main(_): tf.logging.set_verbosity(tf.logging.INFO) hparams = create_hparams(FLAGS.hparams) run_training( run_dir=FLAGS.train_log_dir, checkpoint_dir=FLAGS.train_log_dir, hparams=hparams) if __name__ == '__main__': tf.app.run()