File size: 4,435 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Eval Cross Convolutional Model."""
import io
import os
import sys
import time

import numpy as np
from six.moves import xrange
import tensorflow as tf

import model as cross_conv_model
import reader

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
tf.flags.DEFINE_string('data_filepattern',
                       'est',
                       'training data file pattern.')
tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image')
tf.flags.DEFINE_float('scale', 10.0,
                      'Scale the image after norm_scale and move the diff '
                      'to the positive realm.')
tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
tf.flags.DEFINE_integer('eval_batch_count', 100,
                        'Average the result this number of examples.')
tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')

slim = tf.contrib.slim


def _Eval():
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['seq_len'] = FLAGS.sequence_length
  params['image_size'] = FLAGS.image_size
  params['is_training'] = False
  params['norm_scale'] = FLAGS.norm_scale
  params['scale'] = FLAGS.scale
  params['l2_loss'] = FLAGS.l2_loss
  params['reconstr_loss'] = FLAGS.reconstr_loss
  params['kl_loss'] = FLAGS.kl_loss

  eval_dir = os.path.join(FLAGS.log_root, 'eval')

  images = reader.ReadInput(
      FLAGS.data_filepattern, shuffle=False, params=params)
  images *= params['scale']
  # Increase the value makes training much faster.
  image_diff_list = reader.SequenceToImageAndDiff(images)
  model = cross_conv_model.CrossConvModel(image_diff_list, params)
  model.Build()

  summary_writer = tf.summary.FileWriter(eval_dir)
  saver = tf.train.Saver()
  sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
  tf.train.start_queue_runners(sess)

  while True:
    time.sleep(60)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
      continue
    sys.stderr.write('Loading checkpoint %s\n' %
                     ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)
    # Use the empirical distribution of z from training set.
    if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
      sys.stderr.write('No z at %s\n' % FLAGS.log_root)
      continue

    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy')) as f:
      sample_z_mean = np.load(io.BytesIO(f.read()))
    with tf.gfile.Open(
        os.path.join(FLAGS.log_root, 'z_stddev_log.npy')) as f:
      sample_z_stddev_log = np.load(io.BytesIO(f.read()))

    total_loss = 0.0
    for _ in xrange(FLAGS.eval_batch_count):
      loss_val, total_steps, summaries = sess.run(
          [model.loss, model.global_step, model.summary_op],
          feed_dict={model.z_mean: sample_z_mean,
                     model.z_stddev_log: sample_z_stddev_log})
      total_loss += loss_val

    summary_writer.add_summary(summaries, total_steps)
    sys.stderr.write('steps: %d, loss: %f\n' %
                     (total_steps, total_loss / FLAGS.eval_batch_count))


def main(_):
  _Eval()


if __name__ == '__main__':
  tf.app.run()