# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training for Domain Separation Networks (DSNs)."""
from __future__ import division

import tensorflow as tf

from domain_adaptation.datasets import dataset_factory
import dsn

slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')

tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
                           'Labeled target dataset to mix into the source '
                           'data for semi-supervised training, or "none".')

tf.app.flags.DEFINE_string('dataset_dir', None,
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
                           'Directory where to write event logs.')

tf.app.flags.DEFINE_string(
    'layers_to_regularize', 'fc3',
    'Comma-separated list of layer names to use MMD regularization on.')

tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate.')

tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
                          'The coefficient for scaling the reconstruction '
                          'loss.')

tf.app.flags.DEFINE_float(
    'beta_weight', 1e-6,
    'The coefficient for scaling the private/shared difference loss.')

tf.app.flags.DEFINE_float(
    'gamma_weight', 1e-6,
    'The coefficient for scaling the shared encoding similarity loss.')

tf.app.flags.DEFINE_float('pose_weight', 0.125,
                          'The coefficient for scaling the pose loss.')

tf.app.flags.DEFINE_float(
    'weight_decay', 1e-6,
    'The coefficient for the L2 regularization applied to all weights.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 60,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 60,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'max_number_of_steps', None,
    'The maximum number of gradient steps. Use None to train indefinitely.')

tf.app.flags.DEFINE_integer(
    'domain_separation_startpoint', 1,
    'The global step at which to add the domain separation losses.')

tf.app.flags.DEFINE_integer(
    'bipartite_assignment_top_k', 3,
    'The number of top-k matches to use in bipartite matching adaptation.')

tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')

tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')

tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')

tf.app.flags.DEFINE_bool('use_separation', False,
                         'Use our domain separation model.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')

tf.app.flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
                           'The decoder to use.')

tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
                           'The encoder to use.')

################################################################################
# Flags that control the architecture and losses
################################################################################
tf.app.flags.DEFINE_string(
    'similarity_loss', 'grl',
    'The method to use for encouraging the common encoder codes to be '
    'similar, one of "grl", "mmd", "corr".')

tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
                           'The name of the reconstruction loss.')

tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
                           'The basic tower building block.')
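
# Example invocation (illustrative only: the script filename and all flag
# values below are assumptions, not canonical settings; see the flag help
# strings above for the full set of options):
#
#   python dsn_train.py \
#     --source_dataset=pose_synthetic \
#     --target_dataset=pose_real \
#     --dataset_dir=/path/to/tfrecords \
#     --similarity_loss=grl \
#     --use_separation=true \
#     --train_log_dir=/tmp/da/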


def provide_batch_fn():
  """The provide_batch function to use."""
  return dataset_factory.provide_batch
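

# A minimal stand-in sketching the batch-provider contract that main() relies
# on. This is a hypothetical stub for smoke-testing the graph without real
# data, not the real dataset_factory.provide_batch; the image shape, class
# count, and sample count below are assumptions.
def _debug_provide_batch(dataset_name, split_name, dataset_dir, num_readers,
                         batch_size, num_preprocessing_threads):
  """Returns random (images, labels) with the shapes main() expects."""
  del dataset_name, split_name, dataset_dir  # Unused by the stub.
  del num_readers, num_preprocessing_threads  # Unused by the stub.
  images = tf.random_uniform((batch_size, 32, 32, 3))  # Assumed image shape.
  labels = {
      # One-hot class labels; 10 classes is an arbitrary placeholder.
      'classes': tf.one_hot(tf.zeros((batch_size,), dtype=tf.int32), 10),
      # Scalar per-dataset metadata used to compute the mixing proportion
      # in the semi-supervised branch of main().
      'num_train_samples': 1000,
  }
  return images, labels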


def main(_):
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      # In the unsupervised case all the samples in the labeled
      # domain are from the source domain.
      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
                                      True)

      # When using the semi-supervised model we include labeled target data in
      # the source labeled data.
      if FLAGS.target_labeled_dataset != 'none':
        # 1000 is the maximum number of labeled target samples that exists in
        # the datasets.
        target_semi_images, target_semi_labels = provide_batch_fn()(
            FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
            FLAGS.num_readers, FLAGS.batch_size,
            FLAGS.num_preprocessing_threads)

        # Calculate the proportion of source domain samples in the semi-
        # supervised setting, so that the proportion is set accordingly in the
        # batches.
        proportion = float(source_labels['num_train_samples']) / (
            source_labels['num_train_samples'] +
            target_semi_labels['num_train_samples'])

        rnd_tensor = tf.random_uniform(
            (target_semi_images.get_shape().as_list()[0],))

        domain_selection_mask = rnd_tensor < proportion
        source_images = tf.where(domain_selection_mask, source_images,
                                 target_semi_images)
        source_class_labels = tf.where(domain_selection_mask,
                                       source_labels['classes'],
                                       target_semi_labels['classes'])

        # Record whether pose labels are present *before* source_labels is
        # rebuilt below; checking the rebuilt dict would always be False.
        has_pose_labels = 'quaternions' in source_labels

        if has_pose_labels:
          source_pose_labels = tf.where(domain_selection_mask,
                                        source_labels['quaternions'],
                                        target_semi_labels['quaternions'])
          (source_images, source_class_labels, source_pose_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [
                   source_images, source_class_labels, source_pose_labels,
                   domain_selection_mask
               ],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)
        else:
          (source_images, source_class_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [source_images, source_class_labels, domain_selection_mask],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)

        source_labels = {}
        source_labels['classes'] = source_class_labels
        if has_pose_labels:
          source_labels['quaternions'] = source_pose_labels

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme:
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')

      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('total_loss', tf.losses.get_total_loss())

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)

      # Run training.
      loss_tensor = slim.learning.create_train_op(
          tf.losses.get_total_loss(),
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=loss_tensor,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)


if __name__ == '__main__':
  tf.app.run()