# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os import sys sys.path.append(os.path.abspath(os.path.join(__file__, '../../'))) import numpy as np import tensorflow as tf from config import get_config from model_n2nmn.assembler import Assembler from model_n2nmn.model import Model from util.data_reader import DataReader from util.data_reader import SampleBuilder from util.misc import prepare_dirs_and_logger FLAGS = tf.flags.FLAGS tf.flags.DEFINE_string('snapshot_name', '00001000', 'snapshot file name') def main(_): config = prepare_dirs_and_logger(config_raw) rng = np.random.RandomState(config.random_seed) tf.set_random_seed(config.random_seed) config.rng = rng config.module_names = ['_key_find', '_key_filter', '_val_desc', ''] config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', ''] assembler = Assembler(config) sample_builder = SampleBuilder(config) config = sample_builder.config # update T_encoder according to data data_test = sample_builder.data_all['test'] data_reader_test = DataReader( config, data_test, assembler, shuffle=False, one_pass=True) num_vocab_txt = len(sample_builder.dict_all) num_vocab_nmn = len(assembler.module_names) num_choices = len(sample_builder.dict_all) # Network inputs text_seq_batch = tf.placeholder(tf.int32, [None, None]) seq_len_batch = tf.placeholder(tf.int32, [None]) # The model model = Model( config, sample_builder.kb, text_seq_batch, seq_len_batch, num_vocab_txt=num_vocab_txt, num_vocab_nmn=num_vocab_nmn, EOS_idx=assembler.EOS_idx, num_choices=num_choices, decoder_sampling=False) compiler = model.compiler scores = model.scores sess = tf.Session() sess.run(tf.global_variables_initializer()) snapshot_file = os.path.join(config.model_dir, FLAGS.snapshot_name) tf.logging.info('Snapshot file: %s' % snapshot_file) snapshot_saver = tf.train.Saver() snapshot_saver.restore(sess, snapshot_file) # Evaluation metrics num_questions = len(data_test.Y) tf.logging.info('# of test questions: %d' % num_questions) answer_correct = 0 layout_correct = 0 layout_valid = 0 for batch in data_reader_test.batches(): # set up input and output tensors h = sess.partial_run_setup( fetches=[model.predicted_tokens, scores], feeds=[text_seq_batch, seq_len_batch, compiler.loom_input_tensor]) # Part 1: Generate module layout tokens = sess.partial_run( h, fetches=model.predicted_tokens, feed_dict={ text_seq_batch: batch['input_seq_batch'], seq_len_batch: batch['seq_len_batch'] }) # Compute accuracy of the predicted layout gt_tokens = batch['gt_layout_batch'] layout_correct += np.sum( np.all( np.logical_or(tokens == gt_tokens, gt_tokens == assembler.EOS_idx), axis=0)) # Assemble the layout tokens into network structure expr_list, expr_validity_array = assembler.assemble(tokens) layout_valid += np.sum(expr_validity_array) labels = batch['ans_label_batch'] # Build TensorFlow Fold input for NMN expr_feed = compiler.build_feed_dict(expr_list) # Part 2: Run NMN and learning steps scores_val = sess.partial_run(h, scores, feed_dict=expr_feed) # Compute accuracy predictions = np.argmax(scores_val, axis=1) answer_correct += np.sum( np.logical_and(expr_validity_array, predictions == labels)) answer_accuracy = answer_correct * 1.0 / num_questions layout_accuracy = layout_correct * 1.0 / num_questions layout_validity = layout_valid * 1.0 / num_questions tf.logging.info('test answer accuracy = %f, ' 'test layout accuracy = %f, ' 'test layout validity = %f' % (answer_accuracy, layout_accuracy, layout_validity)) if __name__ == '__main__': config_raw, unparsed = get_config() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)