NCTCMumbai's picture
Upload 2571 files
0b8359d
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training for Domain Separation Networks (DSNs)."""
from __future__ import division
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
import dsn
# Short aliases used throughout this script.
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

# Dataset selection: batch size plus the source, target and (optional)
# labeled-target datasets that main() reads through provide_batch_fn().
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')
tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')
tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')
# Semi-supervised mode is enabled whenever this is not 'none'.
tf.app.flags.DEFINE_string(
    'target_labeled_dataset', 'none',
    'Labeled target dataset mixed into the source batches for '
    'semi-supervised training; "none" disables it.')
# Data location, cluster master and log directory. dataset_dir is passed to
# every provide_batch_fn() call; master and train_log_dir go to
# slim.learning.train() in main().
tf.app.flags.DEFINE_string('dataset_dir', None,
'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_string('master', '',
'BNS name of the TensorFlow master to use.')
tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
'Directory where to write event logs.')
# Forwarded to dsn.create_model through the model_params dict built in main().
tf.app.flags.DEFINE_string(
'layers_to_regularize', 'fc3',
'Comma-separated list of layer names to use MMD regularization on.')
# Initial learning rate; main() decays it with tf.train.exponential_decay
# (staircase) using the decay_rate / decay_steps flags below.
tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')
# Loss weights, all forwarded to dsn.create_model via model_params.
tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
'The coefficient for scaling the reconstruction '
'loss.')
tf.app.flags.DEFINE_float(
'beta_weight', 1e-6,
'The coefficient for scaling the private/shared difference loss.')
tf.app.flags.DEFINE_float(
'gamma_weight', 1e-6,
'The coefficient for scaling the shared encoding similarity loss.')
tf.app.flags.DEFINE_float('pose_weight', 0.125,
'The coefficient for scaling the pose loss.')
tf.app.flags.DEFINE_float(
'weight_decay', 1e-6,
'The coefficient for the L2 regularization applied for all weights.')
# Summary/checkpoint cadence and step budget, passed to slim.learning.train().
tf.app.flags.DEFINE_integer(
'save_summaries_secs', 60,
'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
'save_interval_secs', 60,
'The frequency with which the model is saved, in seconds.')
tf.app.flags.DEFINE_integer(
'max_number_of_steps', None,
'The maximum number of gradient steps. Use None to train indefinitely.')
# Forwarded to dsn.create_model via model_params.
tf.app.flags.DEFINE_integer(
'domain_separation_startpoint', 1,
'The global step to add the domain separation losses.')
# NOTE(review): not referenced anywhere in the visible code of this file.
tf.app.flags.DEFINE_integer(
'bipartite_assignment_top_k', 3,
'The number of top-k matches to use in bipartite matching adaptation.')
# Learning-rate schedule and MomentumOptimizer settings used in main().
tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')
tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')
tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
# Model switches forwarded via model_params.
tf.app.flags.DEFINE_bool('use_separation', False,
'Use our domain separation model.')
tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
# Distributed-training and input-pipeline settings. ps_tasks feeds
# tf.train.replica_device_setter; task == 0 marks the chief worker.
tf.app.flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
tf.app.flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
# Encoder/decoder selection, forwarded via model_params.
tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
'The decoder to use.')
tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
'The encoder to use.')
################################################################################
# Flags that control the architecture and losses
################################################################################
# similarity_loss is passed positionally to dsn.create_model; basic_tower is
# passed as its basic_tower_name keyword; recon_loss_name goes in model_params.
tf.app.flags.DEFINE_string(
'similarity_loss', 'grl',
'The method to use for encouraging the common encoder codes to be '
'similar, one of "grl", "mmd", "corr".')
tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
'The name of the reconstruction loss.')
tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
'The basic tower building block.')
def provide_batch_fn():
  """Returns the function used to build input batches.

  Returns:
    `dataset_factory.provide_batch`, which main() calls once per dataset.
  """
  batch_provider = dataset_factory.provide_batch
  return batch_provider
def _mix_in_labeled_target(source_images, source_labels):
  """Mixes labeled target-domain samples into the source batch.

  Used for semi-supervised training (`--target_labeled_dataset` other than
  'none'). Each example of the source batch is kept with probability
  `proportion` (the share of source samples among all labeled samples) and
  otherwise replaced by the matching labeled target example. The mixed
  tensors are then re-batched through `tf.train.shuffle_batch`.

  Args:
    source_images: Batched source-domain images from `provide_batch_fn()`.
    source_labels: Dict of batched source labels with a 'classes' entry and,
      optionally, a 'quaternions' entry.

  Returns:
    A tuple `(images, labels, domain_selection_mask)`. `labels` contains
    'classes' and, when `source_labels` had one, 'quaternions'.
    `domain_selection_mask` is True for examples taken from the source domain.
  """
  # The labeled target data needs the same full argument list as the other
  # provide_batch_fn() calls in main(); with fewer arguments batch_size would
  # be consumed as dataset_dir.
  target_semi_images, target_semi_labels = provide_batch_fn()(
      FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
      FLAGS.num_readers, FLAGS.batch_size, FLAGS.num_preprocessing_threads)

  # Share of source samples among all labeled samples, so that mixed batches
  # keep the same source/target ratio as the combined training data.
  # NOTE(review): assumes both label dicts carry a 'num_train_samples'
  # entry; confirm against dataset_factory.provide_batch.
  proportion = float(source_labels['num_train_samples']) / (
      source_labels['num_train_samples'] +
      target_semi_labels['num_train_samples'])

  rnd_tensor = tf.random_uniform(
      (target_semi_images.get_shape().as_list()[0],))
  domain_selection_mask = rnd_tensor < proportion

  images = tf.where(domain_selection_mask, source_images, target_semi_images)
  class_labels = tf.where(domain_selection_mask, source_labels['classes'],
                          target_semi_labels['classes'])

  # Decide on pose labels from the incoming dict. Testing a freshly rebuilt
  # dict would always be False and silently drop the pose labels.
  has_quaternions = 'quaternions' in source_labels
  if has_quaternions:
    pose_labels = tf.where(domain_selection_mask,
                           source_labels['quaternions'],
                           target_semi_labels['quaternions'])
    (images, class_labels, pose_labels,
     domain_selection_mask) = tf.train.shuffle_batch(
         [images, class_labels, pose_labels, domain_selection_mask],
         FLAGS.batch_size,
         capacity=50000,
         min_after_dequeue=5000,
         num_threads=1,
         enqueue_many=True)
  else:
    (images, class_labels, domain_selection_mask) = tf.train.shuffle_batch(
        [images, class_labels, domain_selection_mask],
        FLAGS.batch_size,
        capacity=50000,
        min_after_dequeue=5000,
        num_threads=1,
        enqueue_many=True)

  labels = {'classes': class_labels}
  if has_quaternions:
    labels['quaternions'] = pose_labels
  return images, labels, domain_selection_mask


def main(_):
  """Builds the DSN training graph from the command-line flags and trains it.

  Args:
    _: Unused; the argument list handed over by `tf.app.run`.
  """
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      if FLAGS.target_labeled_dataset != 'none':
        # Semi-supervised: labeled target samples are included in the
        # labeled (source) batch.
        source_images, source_labels, domain_selection_mask = (
            _mix_in_labeled_target(source_images, source_labels))
      else:
        # Unsupervised: every sample in the labeled domain is a source sample.
        domain_selection_mask = tf.fill(
            (source_images.get_shape().as_list()[0],), True)

      global_step = slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme: staircase exponential decay of the
      # learning rate, driven by the global step.
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          global_step,
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')
      tf.summary.scalar('learning_rate', learning_rate)

      # Build the total loss once and reuse it for both the summary and the
      # train op (tf.losses is the non-deprecated equivalent of slim.losses).
      total_loss = tf.losses.get_total_loss()
      tf.summary.scalar('total_loss', total_loss)

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)

      # Run training.
      train_op = slim.learning.create_train_op(
          total_loss,
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
  # Parse the command-line flags, then hand control to main().
  tf.app.run(main=main)