# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=line-too-long
"""Evaluation for Domain Separation Networks (DSNs)."""
# pylint: enable=line-too-long
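
# Example invocation (a sketch: the script filename "dsn_eval.py" and the
# dataset path are assumptions, not taken from this file; the flags are the
# ones defined below):
#
#   python dsn_eval.py \
#     --dataset_dir=/tmp/datasets \
#     --dataset=mnist_m \
#     --split=valid \
#     --checkpoint_dir=/tmp/da/ \
#     --eval_dir=/tmp/da/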

import math

import numpy as np
from six.moves import xrange
import tensorflow as tf

from domain_adaptation.datasets import dataset_factory
from domain_adaptation.domain_separation import losses
from domain_adaptation.domain_separation import models

slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
                           'Directory where the model was written to.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/da/',
    'Directory where we should write the tf summaries to.')

tf.app.flags.DEFINE_string('dataset_dir', None,
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('dataset', 'mnist_m',
                           'Which dataset to test on: "mnist", "mnist_m".')

tf.app.flags.DEFINE_string('split', 'valid',
                           'Which portion to test on: "valid", "test".')

tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')

tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
                           'The basic tower building block.')

tf.app.flags.DEFINE_bool('enable_precision_recall', False,
                         'If True, precision and recall for each class will '
                         'be added to the metrics.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')


def quaternion_metric(predictions, labels):
  """Builds a streaming mean of the log-quaternion loss over batches."""
  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
  return slim.metrics.streaming_mean(logcost)


def angle_diff(true_q, pred_q):
  """Angle (in degrees) between unit quaternions: 2 * arccos(|<q1, q2>|)."""
  angles = 2 * (
      180.0 /
      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
  return angles
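
# A minimal sanity check for angle_diff (a sketch to run separately, not part
# of this script): identical unit quaternions should give an error of zero
# degrees, and orthogonal ones 180 degrees.
#
#   q1 = np.array([[0., 0., 0., 1.]])
#   q2 = np.array([[1., 0., 0., 0.]])
#   assert np.allclose(angle_diff(q1, q1), 0.0)
#   assert np.allclose(angle_diff(q1, q2), 180.0)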


def provide_batch_fn():
  """The provide_batch function to use."""
  return dataset_factory.provide_batch


def main(_):
  g = tf.Graph()
  with g.as_default():
    # Load the data.
    images, labels = provide_batch_fn()(
        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)

    num_classes = labels['classes'].get_shape().as_list()[1]

    tf.summary.image('eval_images', images, max_outputs=3)

    # Define the model:
    with tf.variable_scope('towers'):
      basic_tower = getattr(models, FLAGS.basic_tower)
      predictions, endpoints = basic_tower(
          images,
          num_classes=num_classes,
          is_training=False,
          batch_norm_params=None)
    metric_names_to_values = {}

    # Define the metrics:
    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
      quaternion_loss = quaternion_metric(labels['quaternions'],
                                          endpoints['quaternion_pred'])

      angle_errors, = tf.py_func(
          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
          [tf.float32])

      metric_names_to_values[
          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
      metric_names_to_values['Quaternion Loss'] = quaternion_loss

    accuracy = tf.contrib.metrics.streaming_accuracy(
        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))

    predictions = tf.argmax(predictions, 1)
    labels = tf.argmax(labels['classes'], 1)
    metric_names_to_values['Accuracy'] = accuracy

    if FLAGS.enable_precision_recall:
      for i in xrange(num_classes):
        # index_map is the one-hot row for class i; gathering it with integer
        # class ids turns predictions/labels into binary indicators for i.
        index_map = tf.one_hot(i, depth=num_classes)
        name = 'PR/Precision_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_precision(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))
        name = 'PR/Recall_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_recall(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))
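    # For example, with num_classes = 3 and i = 1, index_map is [0., 1., 0.];
    # tf.gather(index_map, [0, 1, 1]) then yields [0., 1., 1.], which the
    # streaming precision/recall ops treat as binary predictions/labels.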

    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
        metric_names_to_values)
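    # aggregate_metric_map splits the {name: (value_op, update_op)} map into
    # two dicts keyed by name: value ops, which read the current metric value,
    # and update ops, which fold one batch into the streaming metric.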

    # Create the summary ops such that they also print out to std output:
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
      op = tf.summary.scalar(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    # This ensures that we make a single pass over all of the data.
    num_batches = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))
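    # With the default flags this is ceil(1000 / 32.0) = 32 batches (1024
    # images), so the evaluation pass slightly overshoots num_examples.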

    # Set up the global step.
    slim.get_or_create_global_step()

    slim.evaluation.evaluation_loop(
        FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        summary_op=tf.summary.merge(summary_ops))


if __name__ == '__main__':
  tf.app.run()