Spaces:
Running
Running
File size: 4,648 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../../')))
import numpy as np
import tensorflow as tf
from config import get_config
from model_n2nmn.assembler import Assembler
from model_n2nmn.model import Model
from util.data_reader import DataReader
from util.data_reader import SampleBuilder
from util.misc import prepare_dirs_and_logger
# Command-line flag naming the checkpoint (relative to config.model_dir)
# that main() restores before running the test-set evaluation.
tf.flags.DEFINE_string('snapshot_name', '00001000', 'snapshot file name')
# Parsed flag values; read as FLAGS.snapshot_name inside main().
FLAGS = tf.flags.FLAGS
def _evaluate_batch(sess, model, assembler, text_seq_batch, seq_len_batch,
                    batch):
  """Predict layouts for one batch, run the NMN on them and count hits.

  Args:
    sess: the tf.Session holding the restored model.
    model: the N2NMN Model; its predicted_tokens, scores and compiler are
      used.
    assembler: Assembler that turns predicted tokens into expressions.
    text_seq_batch: placeholder fed with batch['input_seq_batch'].
    seq_len_batch: placeholder fed with batch['seq_len_batch'].
    batch: one batch dict produced by DataReader.batches().

  Returns:
    A tuple (answer_correct, layout_correct, layout_valid) of counts for
    this batch.
  """
  compiler = model.compiler
  scores = model.scores
  # A partial run lets the layout be predicted, assembled in Python, and then
  # fed back into the same graph execution for the module network.
  handle = sess.partial_run_setup(
      fetches=[model.predicted_tokens, scores],
      feeds=[text_seq_batch, seq_len_batch, compiler.loom_input_tensor])

  # Part 1: generate the module layout.
  tokens = sess.partial_run(
      handle,
      fetches=model.predicted_tokens,
      feed_dict={
          text_seq_batch: batch['input_seq_batch'],
          seq_len_batch: batch['seq_len_batch']
      })

  # A layout counts as correct when every token matches the ground truth,
  # treating positions where the ground-truth token is <eos> as matches.
  # np.all reduces over axis 0 (presumably the time axis -- TODO confirm).
  gt_tokens = batch['gt_layout_batch']
  layout_correct = np.sum(
      np.all(
          np.logical_or(tokens == gt_tokens, gt_tokens == assembler.EOS_idx),
          axis=0))

  # Assemble the layout tokens into network structure.
  expr_list, expr_validity_array = assembler.assemble(tokens)
  layout_valid = np.sum(expr_validity_array)

  # Part 2: run the NMN on the assembled layouts (inference only).
  expr_feed = compiler.build_feed_dict(expr_list)
  scores_val = sess.partial_run(handle, scores, feed_dict=expr_feed)
  predictions = np.argmax(scores_val, axis=1)
  labels = batch['ans_label_batch']
  # Only answers produced by a valid layout are counted as correct.
  answer_correct = np.sum(
      np.logical_and(expr_validity_array, predictions == labels))
  return answer_correct, layout_correct, layout_valid


def main(_):
  """Restore a trained snapshot and log test-set metrics.

  Reads the parsed configuration from the module-level ``config_raw`` (set
  in the ``__main__`` guard), restores ``FLAGS.snapshot_name`` from
  ``config.model_dir``, evaluates the 'test' split and logs answer accuracy,
  layout accuracy and layout validity. If the test split is empty, a warning
  is logged and no metrics are computed.
  """
  config = prepare_dirs_and_logger(config_raw)
  rng = np.random.RandomState(config.random_seed)
  tf.set_random_seed(config.random_seed)
  config.rng = rng
  config.module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']
  config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', '<eos>']
  assembler = Assembler(config)

  sample_builder = SampleBuilder(config)
  config = sample_builder.config  # update T_encoder according to data
  data_test = sample_builder.data_all['test']
  data_reader_test = DataReader(
      config, data_test, assembler, shuffle=False, one_pass=True)

  num_vocab_txt = len(sample_builder.dict_all)
  num_vocab_nmn = len(assembler.module_names)
  # Answer choices are sized by the same dictionary as the input text.
  num_choices = len(sample_builder.dict_all)

  # Network inputs
  text_seq_batch = tf.placeholder(tf.int32, [None, None])
  seq_len_batch = tf.placeholder(tf.int32, [None])

  # The model
  model = Model(
      config,
      sample_builder.kb,
      text_seq_batch,
      seq_len_batch,
      num_vocab_txt=num_vocab_txt,
      num_vocab_nmn=num_vocab_nmn,
      EOS_idx=assembler.EOS_idx,
      num_choices=num_choices,
      decoder_sampling=False)

  # Use the session as a context manager so it is closed once evaluation ends.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    snapshot_file = os.path.join(config.model_dir, FLAGS.snapshot_name)
    tf.logging.info('Snapshot file: %s', snapshot_file)
    snapshot_saver = tf.train.Saver()
    snapshot_saver.restore(sess, snapshot_file)

    # Evaluation metrics
    num_questions = len(data_test.Y)
    tf.logging.info('# of test questions: %d', num_questions)
    if num_questions == 0:
      # Avoid dividing by zero when computing the accuracies below.
      tf.logging.warning('No test questions found; skipping evaluation.')
      return

    answer_correct = 0
    layout_correct = 0
    layout_valid = 0
    for batch in data_reader_test.batches():
      batch_answer, batch_layout, batch_valid = _evaluate_batch(
          sess, model, assembler, text_seq_batch, seq_len_batch, batch)
      answer_correct += batch_answer
      layout_correct += batch_layout
      layout_valid += batch_valid

  answer_accuracy = answer_correct * 1.0 / num_questions
  layout_accuracy = layout_correct * 1.0 / num_questions
  layout_validity = layout_valid * 1.0 / num_questions
  tf.logging.info('test answer accuracy = %f, '
                  'test layout accuracy = %f, '
                  'test layout validity = %f',
                  answer_accuracy, layout_accuracy, layout_validity)
if __name__ == '__main__':
  # Project-level options are parsed first and published as the module
  # global config_raw, which main() reads; leftover arguments are handed to
  # tf.app.run so TensorFlow's own flags (e.g. --snapshot_name) see them.
  config_raw, unparsed = get_config()
  forwarded_argv = [sys.argv[0]] + unparsed
  tf.app.run(main=main, argv=forwarded_argv)
|