# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Domain Adaptation Loss Functions. | |
The following domain adaptation loss functions are defined: | |
- Maximum Mean Discrepancy (MMD). | |
Relevant paper: | |
Gretton, Arthur, et al., | |
"A kernel two-sample test." | |
The Journal of Machine Learning Research, 2012 | |
- Correlation Loss on a batch. | |
""" | |
from functools import partial

import tensorflow as tf

import grl_op_grads  # pylint: disable=unused-import
import grl_op_shapes  # pylint: disable=unused-import
import grl_ops
import utils

slim = tf.contrib.slim


################################################################################
# SIMILARITY LOSS
################################################################################
def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
  r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.

  Maximum Mean Discrepancy (MMD) is a distance measure between the samples of
  the distributions of x and y. Here we use the kernel two-sample estimate
  based on the empirical mean of the two distributions.

  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },

  where K = <\phi(x), \phi(y)> is the desired kernel function, in this case a
  radial basis kernel.

  Args:
    x: a tensor of shape [num_samples, num_features].
    y: a tensor of shape [num_samples, num_features].
    kernel: a function which computes the kernel in MMD. Defaults to the
      GaussianKernelMatrix.

  Returns:
    a scalar denoting the squared maximum mean discrepancy loss.
  """
  with tf.name_scope('MaximumMeanDiscrepancy'):
    # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
    cost = tf.reduce_mean(kernel(x, x))
    cost += tf.reduce_mean(kernel(y, y))
    cost -= 2 * tf.reduce_mean(kernel(x, y))

    # We do not allow the loss to become negative.
    cost = tf.where(cost > 0, cost, 0, name='value')
  return cost
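

# Illustrative usage sketch, not part of the original library. `_mmd_example`
# is a hypothetical helper added only for demonstration; it assumes that
# utils.gaussian_kernel_matrix(x, y, sigmas) accepts a 1-D tensor of kernel
# bandwidths, as in mmd_loss below.
def _mmd_example():
  """Evaluates MMD between two small random batches."""
  x = tf.random_normal([8, 16])
  y = tf.random_normal([8, 16]) + 1.0  # Shifted distribution => larger MMD.
  kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant([1.0, 5.0]))
  mmd = maximum_mean_discrepancy(x, y, kernel=kernel)
  with tf.Session() as sess:
    return sess.run(mmd)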


def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the MMD between two representations.

  This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
  different Gaussian kernels.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6
  ]
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  loss_value = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  loss_value = tf.maximum(1e-4, loss_value) * weight
  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([assert_op]):
    tag = 'MMD Loss'
    if scope:
      tag = scope + tag
    tf.summary.scalar(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value
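

# Illustrative sketch, not part of the original code: mmd_loss is applied to
# two feature representations of the same dimensionality. The placeholder
# shapes and the 0.1 weight below are arbitrary example values.
def _mmd_loss_example():
  source_features = tf.placeholder(tf.float32, [None, 128])
  target_features = tf.placeholder(tf.float32, [None, 128])
  loss = mmd_loss(source_features, target_features, weight=0.1,
                  scope='example_')
  # mmd_loss registers the term via tf.losses.add_loss, so it is included in
  # tf.losses.get_total_loss() along with any other registered losses.
  total_loss = tf.losses.get_total_loss()
  return loss, total_loss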


def correlation_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the correlation between two representations.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: a scalar weight for the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the correlation loss value.
  """
  with tf.name_scope('corr_loss'):
    source_samples -= tf.reduce_mean(source_samples, 0)
    target_samples -= tf.reduce_mean(target_samples, 0)

    source_samples = tf.nn.l2_normalize(source_samples, 1)
    target_samples = tf.nn.l2_normalize(target_samples, 1)

    source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
    target_cov = tf.matmul(tf.transpose(target_samples), target_samples)

    corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight

  assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
  with tf.control_dependencies([assert_op]):
    tag = 'Correlation Loss'
    if scope:
      tag = scope + tag
    tf.summary.scalar(tag, corr_loss)
    tf.losses.add_loss(corr_loss)

  return corr_loss
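

# Minimal sketch, not part of the original code, showing what correlation_loss
# penalizes: the squared difference between the per-feature correlation
# matrices of the two batches. The constants are arbitrary example data.
def _correlation_loss_example():
  source = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  target = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
  return correlation_loss(source, target, weight=1.0, scope='example_')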


def dann_loss(source_samples, target_samples, weight, scope=None):
  """Adds the domain adversarial (DANN) loss.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the domain adversarial (DANN) loss value.
  """
  with tf.variable_scope('dann'):
    batch_size = tf.shape(source_samples)[0]
    samples = tf.concat(axis=0, values=[source_samples, target_samples])
    samples = slim.flatten(samples)

    domain_selection_mask = tf.concat(
        axis=0, values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))])

    # Perform the gradient reversal and be careful with the shape.
    grl = grl_ops.gradient_reversal(samples)
    grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

    grl = slim.fully_connected(grl, 100, scope='fc1')
    logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')

    domain_predictions = tf.sigmoid(logits)

  domain_loss = tf.losses.log_loss(
      domain_selection_mask, domain_predictions, weights=weight)

  domain_accuracy = utils.accuracy(
      tf.round(domain_predictions), domain_selection_mask)

  assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
  with tf.control_dependencies([assert_op]):
    tag_loss = 'losses/domain_loss'
    tag_accuracy = 'losses/domain_accuracy'
    if scope:
      tag_loss = scope + tag_loss
      tag_accuracy = scope + tag_accuracy

    tf.summary.scalar(tag_loss, domain_loss)
    tf.summary.scalar(tag_accuracy, domain_accuracy)

  return domain_loss
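

# Illustrative sketch, not part of the original code. dann_loss builds a small
# domain classifier ('dann/fc1', 'dann/fc2') on top of the gradient-reversed
# features, so the variables it creates must be initialized before training,
# and the compiled gradient-reversal op from grl_ops must be available. Both
# batches are assumed to have the same size, since the zeros/ones mask above
# is built from the source batch size.
def _dann_loss_example(source_features, target_features):
  loss = dann_loss(source_features, target_features, weight=1.0,
                   scope='example_')
  init_op = tf.global_variables_initializer()
  return loss, init_op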


################################################################################
# DIFFERENCE LOSS
################################################################################
def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
  """Adds the difference loss between the private and shared representations.

  Args:
    private_samples: a tensor of shape [num_samples, num_features].
    shared_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the incoherence loss.
    name: the name of the tf summary.
  """
  private_samples -= tf.reduce_mean(private_samples, 0)
  shared_samples -= tf.reduce_mean(shared_samples, 0)

  private_samples = tf.nn.l2_normalize(private_samples, 1)
  shared_samples = tf.nn.l2_normalize(shared_samples, 1)

  correlation_matrix = tf.matmul(
      private_samples, shared_samples, transpose_a=True)

  cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
  cost = tf.where(cost > 0, cost, 0, name='value')

  tf.summary.scalar('losses/Difference Loss {}'.format(name), cost)
  assert_op = tf.Assert(tf.is_finite(cost), [cost])
  with tf.control_dependencies([assert_op]):
    tf.losses.add_loss(cost)
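

# Illustrative sketch, not part of the original code: difference_loss registers
# its term via tf.losses.add_loss rather than returning it, so the caller only
# needs to invoke it once per private/shared pair. The 0.05 weight and the
# 'private_source' summary name are arbitrary example values.
def _difference_loss_example(private_features, shared_features):
  difference_loss(private_features, shared_features, weight=0.05,
                  name='private_source')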


################################################################################
# TASK LOSS
################################################################################
def log_quaternion_loss_batch(predictions, labels, params):
  """A helper function to compute the error between quaternions.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']
  assertions = []
  if use_logging:
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
                    1e-4)),
            ['The l2 norm of each prediction quaternion vector should be 1.']))
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
            ['The l2 norm of each label quaternion vector should be 1.']))

  with tf.control_dependencies(assertions):
    product = tf.multiply(predictions, labels)
  internal_dot_products = tf.reduce_sum(product, [1])

  if use_logging:
    internal_dot_products = tf.Print(
        internal_dot_products,
        [internal_dot_products, tf.shape(internal_dot_products)],
        'internal_dot_products:')

  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return logcost
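

# Worked example, not part of the original code. For unit quaternions the
# per-sample cost is log(1e-4 + 1 - |<p, q>|): identical quaternions give
# log(1e-4), about -9.2 (the minimum), while orthogonal quaternions give
# log(1 + 1e-4), about 0 (the maximum).
def _log_quaternion_loss_batch_example():
  predictions = tf.constant([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
  labels = tf.constant([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
  params = {'use_logging': False, 'batch_size': 2}
  # Expected per-sample costs: roughly [-9.2, 1e-4].
  return log_quaternion_loss_batch(predictions, labels, params)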


def log_quaternion_loss(predictions, labels, params):
  """A helper function to compute the mean error between batches of quaternions.

  The caller is expected to add the loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size 1, denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  logcost = log_quaternion_loss_batch(predictions, labels, params)
  logcost = tf.reduce_sum(logcost, [0])
  batch_size = params['batch_size']
  logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
  if use_logging:
    logcost = tf.Print(
        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
  return logcost
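

# Illustrative sketch, not part of the original code: unlike the losses above,
# log_quaternion_loss does not call tf.losses.add_loss itself, so the caller
# is expected to register the returned value. The batch_size of 32 is an
# arbitrary example value.
def _log_quaternion_loss_example(predictions, labels):
  params = {'use_logging': False, 'batch_size': 32}
  loss = log_quaternion_loss(predictions, labels, params)
  tf.losses.add_loss(loss)
  return loss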