# Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines the various loss functions in use by the PIXELDA model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow as tf slim = tf.contrib.slim def add_domain_classifier_losses(end_points, hparams): """Adds losses related to the domain-classifier. Args: end_points: A map of network end point names to `Tensors`. hparams: The hyperparameters struct. Returns: loss: A `Tensor` representing the total task-classifier loss. """ if hparams.domain_loss_weight == 0: tf.logging.info( 'Domain classifier loss weight is 0, so not creating losses.') return 0 # The domain prediction loss is minimized with respect to the domain # classifier features only. Its aim is to predict the domain of the images. # Note: 1 = 'real image' label, 0 = 'fake image' label transferred_domain_loss = tf.losses.sigmoid_cross_entropy( multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']), logits=end_points['transferred_domain_logits']) tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss) target_domain_loss = tf.losses.sigmoid_cross_entropy( multi_class_labels=tf.ones_like(end_points['target_domain_logits']), logits=end_points['target_domain_logits']) tf.summary.scalar('Domain_loss_target', target_domain_loss) # Compute the total domain loss: total_domain_loss = transferred_domain_loss + target_domain_loss total_domain_loss *= hparams.domain_loss_weight tf.summary.scalar('Domain_loss_total', total_domain_loss) return total_domain_loss def log_quaternion_loss_batch(predictions, labels, params): """A helper function to compute the error between quaternions. Args: predictions: A Tensor of size [batch_size, 4]. labels: A Tensor of size [batch_size, 4]. params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'. Returns: A Tensor of size [batch_size], denoting the error between the quaternions. """ use_logging = params['use_logging'] assertions = [] if use_logging: assertions.append( tf.Assert( tf.reduce_all( tf.less( tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1), 1e-4)), ['The l2 norm of each prediction quaternion vector should be 1.'])) assertions.append( tf.Assert( tf.reduce_all( tf.less( tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)), ['The l2 norm of each label quaternion vector should be 1.'])) with tf.control_dependencies(assertions): product = tf.multiply(predictions, labels) internal_dot_products = tf.reduce_sum(product, [1]) if use_logging: internal_dot_products = tf.Print(internal_dot_products, [ internal_dot_products, tf.shape(internal_dot_products) ], 'internal_dot_products:') logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products)) return logcost def log_quaternion_loss(predictions, labels, params): """A helper function to compute the mean error between batches of quaternions. The caller is expected to add the loss to the graph. Args: predictions: A Tensor of size [batch_size, 4]. labels: A Tensor of size [batch_size, 4]. params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'. Returns: A Tensor of size 1, denoting the mean error between batches of quaternions. """ use_logging = params['use_logging'] logcost = log_quaternion_loss_batch(predictions, labels, params) logcost = tf.reduce_sum(logcost, [0]) batch_size = params['batch_size'] logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss') if use_logging: logcost = tf.Print( logcost, [logcost], '[logcost]', name='log_quaternion_loss_print') return logcost def _quaternion_loss(labels, predictions, weight, batch_size, domain, add_summaries): """Creates a Quaternion Loss. Args: labels: The true quaternions. predictions: The predicted quaternions. weight: A scalar weight. batch_size: The size of the batches. domain: The name of the domain from which the labels were taken. add_summaries: Whether or not to add summaries for the losses. Returns: A `Tensor` representing the loss. """ assert domain in ['Source', 'Transferred'] params = {'use_logging': False, 'batch_size': batch_size} loss = weight * log_quaternion_loss(labels, predictions, params) if add_summaries: assert_op = tf.Assert(tf.is_finite(loss), [loss]) with tf.control_dependencies([assert_op]): tf.summary.histogram( 'Log_Quaternion_Loss_%s' % domain, loss, collections='losses') tf.summary.scalar( 'Task_Quaternion_Loss_%s' % domain, loss, collections='losses') return loss def _add_task_specific_losses(end_points, source_labels, num_classes, hparams, add_summaries=False): """Adds losses related to the task-classifier. Args: end_points: A map of network end point names to `Tensors`. source_labels: A dictionary of output labels to `Tensors`. num_classes: The number of classes used by the classifier. hparams: The hyperparameters struct. add_summaries: Whether or not to add the summaries. Returns: loss: A `Tensor` representing the total task-classifier loss. """ # TODO(ddohan): Make sure the l2 regularization is added to the loss one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes) total_loss = 0 if 'source_task_logits' in end_points: loss = tf.losses.softmax_cross_entropy( onehot_labels=one_hot_labels, logits=end_points['source_task_logits'], weights=hparams.source_task_loss_weight) if add_summaries: tf.summary.scalar('Task_Classifier_Loss_Source', loss) total_loss += loss if 'transferred_task_logits' in end_points: loss = tf.losses.softmax_cross_entropy( onehot_labels=one_hot_labels, logits=end_points['transferred_task_logits'], weights=hparams.transferred_task_loss_weight) if add_summaries: tf.summary.scalar('Task_Classifier_Loss_Transferred', loss) total_loss += loss ######################### # Pose specific losses. # ######################### if 'quaternion' in source_labels: total_loss += _quaternion_loss( source_labels['quaternion'], end_points['source_quaternion'], hparams.source_pose_weight, hparams.batch_size, 'Source', add_summaries) total_loss += _quaternion_loss( source_labels['quaternion'], end_points['transferred_quaternion'], hparams.transferred_pose_weight, hparams.batch_size, 'Transferred', add_summaries) if add_summaries: tf.summary.scalar('Task_Loss_Total', total_loss) return total_loss def _transferred_similarity_loss(reconstructions, source_images, weight=1.0, method='mse', max_diff=0.4, name='similarity'): """Computes a loss encouraging similarity between source and transferred. Args: reconstructions: A `Tensor` of shape [batch_size, height, width, channels] source_images: A `Tensor` of shape [batch_size, height, width, channels]. weight: Multiple similarity loss by this weight before returning method: One of: mpse = Mean Pairwise Squared Error mse = Mean Squared Error hinged_mse = Computes the mean squared error using squared differences greater than hparams.transferred_similarity_max_diff hinged_mae = Computes the mean absolute error using absolute differences greater than hparams.transferred_similarity_max_diff. max_diff: Maximum unpenalized difference for hinged losses name: Identifying name to use for creating summaries Returns: A `Tensor` representing the transferred similarity loss. Raises: ValueError: if `method` is not recognized. """ if weight == 0: return 0 source_channels = source_images.shape.as_list()[-1] reconstruction_channels = reconstructions.shape.as_list()[-1] # Convert grayscale source to RGB if target is RGB if source_channels == 1 and reconstruction_channels != 1: source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels]) if reconstruction_channels == 1 and source_channels != 1: reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels]) if method == 'mpse': reconstruction_similarity_loss_fn = ( tf.contrib.losses.mean_pairwise_squared_error) elif method == 'masked_mpse': def masked_mpse(predictions, labels, weight): """Masked mpse assuming we have a depth to create a mask from.""" assert labels.shape.as_list()[-1] == 4 mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99)) mask = tf.tile(mask, [1, 1, 1, 4]) predictions *= mask labels *= mask tf.image_summary('masked_pred', predictions) tf.image_summary('masked_label', labels) return tf.contrib.losses.mean_pairwise_squared_error( predictions, labels, weight) reconstruction_similarity_loss_fn = masked_mpse elif method == 'mse': reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error elif method == 'hinged_mse': def hinged_mse(predictions, labels, weight): diffs = tf.square(predictions - labels) diffs = tf.maximum(0.0, diffs - max_diff) return tf.reduce_mean(diffs) * weight reconstruction_similarity_loss_fn = hinged_mse elif method == 'hinged_mae': def hinged_mae(predictions, labels, weight): diffs = tf.abs(predictions - labels) diffs = tf.maximum(0.0, diffs - max_diff) return tf.reduce_mean(diffs) * weight reconstruction_similarity_loss_fn = hinged_mae else: raise ValueError('Unknown reconstruction loss %s' % method) reconstruction_similarity_loss = reconstruction_similarity_loss_fn( reconstructions, source_images, weight) name = '%s_Similarity_(%s)' % (name, method) tf.summary.scalar(name, reconstruction_similarity_loss) return reconstruction_similarity_loss def g_step_loss(source_images, source_labels, end_points, hparams, num_classes): """Configures the loss function which runs during the g-step. Args: source_images: A `Tensor` of shape [batch_size, height, width, channels]. source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys are 'class' and 'quaternion'. end_points: A map of the network end points. hparams: The hyperparameters struct. num_classes: Number of classes for classifier loss Returns: A `Tensor` representing a loss function. Raises: ValueError: if hparams.transferred_similarity_loss_weight is non-zero but hparams.transferred_similarity_loss is invalid. """ generator_loss = 0 ################################################################ # Adds a loss which encourages the discriminator probabilities # # to be high (near one). ################################################################ # As per the GAN paper, maximize the log probs, instead of minimizing # log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is # the same thing. style_transfer_loss = tf.losses.sigmoid_cross_entropy( logits=end_points['transferred_domain_logits'], multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']), weights=hparams.style_transfer_loss_weight) tf.summary.scalar('Style_transfer_loss', style_transfer_loss) generator_loss += style_transfer_loss # Optimizes the style transfer network to produce transferred images similar # to the source images. generator_loss += _transferred_similarity_loss( end_points['transferred_images'], source_images, weight=hparams.transferred_similarity_loss_weight, method=hparams.transferred_similarity_loss, name='transferred_similarity') # Optimizes the style transfer network to maximize classification accuracy. if source_labels is not None and hparams.task_tower_in_g_step: generator_loss += _add_task_specific_losses( end_points, source_labels, num_classes, hparams) * hparams.task_loss_in_g_weight return generator_loss def d_step_loss(end_points, source_labels, num_classes, hparams): """Configures the losses during the D-Step. Note that during the D-step, the model optimizes both the domain (binary) classifier and the task classifier. Args: end_points: A map of the network end points. source_labels: A dictionary of output labels to `Tensors`. num_classes: The number of classes used by the classifier. hparams: The hyperparameters struct. Returns: A `Tensor` representing the value of the D-step loss. """ domain_classifier_loss = add_domain_classifier_losses(end_points, hparams) task_classifier_loss = 0 if source_labels is not None: task_classifier_loss = _add_task_specific_losses( end_points, source_labels, num_classes, hparams, add_summaries=True) return domain_classifier_loss + task_classifier_loss