# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Model using memory component. | |
The model embeds images using a standard CNN architecture. | |
These embeddings are used as keys to the memory component, | |
which returns nearest neighbors. | |
""" | |
import tensorflow as tf | |
import memory | |
FLAGS = tf.flags.FLAGS | |


class BasicClassifier(object):
  """Trivial classifier that returns the memory-retrieved values as predictions."""

  def __init__(self, output_dim):
    self.output_dim = output_dim

  def core_builder(self, memory_val, x, y):
    del x, y
    y_pred = memory_val
    loss = 0.0

    return loss, y_pred


class LeNet(object):
  """Standard CNN architecture."""

  def __init__(self, image_size, num_channels, hidden_dim):
    self.image_size = image_size
    self.num_channels = num_channels
    self.hidden_dim = hidden_dim
    self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
    self.vector_init = tf.constant_initializer(0.0)

  def core_builder(self, x):
    """Embeds x using standard CNN architecture.

    Args:
      x: Batch of images as a 2-d Tensor [batch_size, -1].

    Returns:
      A 2-d Tensor [batch_size, hidden_dim] of embedded images.
    """

    ch1 = 32 * 2  # number of channels in 1st layer
    ch2 = 64 * 2  # number of channels in 2nd layer

    conv1_weights = tf.get_variable('conv1_w',
                                    [3, 3, self.num_channels, ch1],
                                    initializer=self.matrix_init)
    conv1_biases = tf.get_variable('conv1_b', [ch1],
                                   initializer=self.vector_init)
    conv1a_weights = tf.get_variable('conv1a_w',
                                     [3, 3, ch1, ch1],
                                     initializer=self.matrix_init)
    conv1a_biases = tf.get_variable('conv1a_b', [ch1],
                                    initializer=self.vector_init)

    conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
                                    initializer=self.matrix_init)
    conv2_biases = tf.get_variable('conv2_b', [ch2],
                                   initializer=self.vector_init)
    conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
                                     initializer=self.matrix_init)
    conv2a_biases = tf.get_variable('conv2a_b', [ch2],
                                    initializer=self.vector_init)

    # fully connected
    fc1_weights = tf.get_variable(
        'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
                  self.hidden_dim], initializer=self.matrix_init)
    fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
                                 initializer=self.vector_init)

    # define model
    x = tf.reshape(x,
                   [-1, self.image_size, self.image_size, self.num_channels])
    batch_size = tf.shape(x)[0]

    conv1 = tf.nn.conv2d(x, conv1_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
    conv1 = tf.nn.conv2d(relu1, conv1a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))

    pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    conv2 = tf.nn.conv2d(pool1, conv2_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
    conv2 = tf.nn.conv2d(relu2, conv2a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))

    pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    reshape = tf.reshape(pool2, [batch_size, -1])
    hidden = tf.matmul(reshape, fc1_weights) + fc1_biases

    return hidden
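
# A minimal standalone sketch of the embedder (illustrative, not part of the
# original training path); the 28x28 image size and 128-d embedding dimension
# are assumed values:
#
#   images = tf.placeholder(tf.float32, [None, 28 * 28])
#   with tf.variable_scope('embedder'):
#     embeddings = LeNet(28, 1, 128).core_builder(images)  # [batch_size, 128]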


class Model(object):
  """Model for coordinating between CNN embedder and Memory module."""

  def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
               learning_rate=0.0001, use_lsh=False):
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.rep_dim = rep_dim
    self.memory_size = memory_size
    self.vocab_size = vocab_size
    self.learning_rate = learning_rate
    self.use_lsh = use_lsh

    self.embedder = self.get_embedder()
    self.memory = self.get_memory()
    self.classifier = self.get_classifier()

    self.global_step = tf.train.get_or_create_global_step()

  def get_embedder(self):
    return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)

  def get_memory(self):
    cls = memory.LSHMemory if self.use_lsh else memory.Memory
    return cls(self.rep_dim, self.memory_size, self.vocab_size)

  def get_classifier(self):
    return BasicClassifier(self.output_dim)

  def core_builder(self, x, y, keep_prob, use_recent_idx=True):
    embeddings = self.embedder.core_builder(x)
    if keep_prob < 1.0:
      embeddings = tf.nn.dropout(embeddings, keep_prob)
    memory_val, _, teacher_loss = self.memory.query(
        embeddings, y, use_recent_idx=use_recent_idx)
    loss, y_pred = self.classifier.core_builder(memory_val, x, y)

    return loss + teacher_loss, y_pred

  def train(self, x, y):
    loss, _ = self.core_builder(x, y, keep_prob=0.3)
    gradient_ops = self.training_ops(loss)
    return loss, gradient_ops

  def eval(self, x, y):
    _, y_preds = self.core_builder(x, y, keep_prob=1.0,
                                   use_recent_idx=False)
    return y_preds

  def get_xy_placeholders(self):
    return (tf.placeholder(tf.float32, [None, self.input_dim]),
            tf.placeholder(tf.int32, [None]))

  def setup(self):
    """Sets up all components of the computation graph."""

    self.x, self.y = self.get_xy_placeholders()

    # This context creates variables
    with tf.variable_scope('core', reuse=None):
      self.loss, self.gradient_ops = self.train(self.x, self.y)
    # And this one re-uses them (thus the `reuse=True`)
    with tf.variable_scope('core', reuse=True):
      self.y_preds = self.eval(self.x, self.y)

  def training_ops(self, loss):
    opt = self.get_optimizer()
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    return opt.apply_gradients(zip(clipped_gradients, params),
                               global_step=self.global_step)

  def get_optimizer(self):
    return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                  epsilon=1e-4)
  def one_step(self, sess, x, y):
    outputs = [self.loss, self.gradient_ops]
    return sess.run(outputs, feed_dict={self.x: x, self.y: y})

  def episode_step(self, sess, x, y, clear_memory=False):
    """Performs training steps on episodic input.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images defining the episode.
      y: A list of batches of labels corresponding to x.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of losses the same length as the episode.
    """

    outputs = [self.loss, self.gradient_ops]

    if clear_memory:
      self.clear_memory(sess)

    losses = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      loss = out[0]
      losses.append(loss)

    return losses
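
  # Illustrative call (a sketch with hypothetical batch names; episode
  # construction itself lives in the data-generation code, not in this file):
  #
  #   losses = model.episode_step(sess, [x_batch_1, x_batch_2],
  #                               [y_batch_1, y_batch_2], clear_memory=True)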
  def predict(self, sess, x, y=None):
    """Predict the labels on a single batch of examples.

    Args:
      sess: A Tensorflow Session.
      x: A batch of images.
      y: The labels for the images in x.
        This allows for updating the memory.

    Returns:
      Predicted y.
    """

    # Storing current memory state to restore it after prediction
    mem_keys, mem_vals, mem_age, _ = self.memory.get()
    cur_memory = (
        tf.identity(mem_keys),
        tf.identity(mem_vals),
        tf.identity(mem_age),
        None,
    )

    outputs = [self.y_preds]
    if y is None:
      ret = sess.run(outputs, feed_dict={self.x: x})
    else:
      ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})

    # Restoring memory state
    self.memory.set(*cur_memory)

    return ret

  def episode_predict(self, sess, x, y, clear_memory=False):
    """Predict the labels on an episode of examples.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images.
      y: A list of labels for the images in x.
        This allows for updating the memory.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of predicted y.
    """

    # Storing current memory state to restore it after prediction
    mem_keys, mem_vals, mem_age, _ = self.memory.get()
    cur_memory = (
        tf.identity(mem_keys),
        tf.identity(mem_vals),
        tf.identity(mem_age),
        None,
    )

    if clear_memory:
      self.clear_memory(sess)

    outputs = [self.y_preds]
    y_preds = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      y_pred = out[0]
      y_preds.append(y_pred)

    # Restoring memory state
    self.memory.set(*cur_memory)

    return y_preds

  def clear_memory(self, sess):
    sess.run([self.memory.clear()])
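

# A minimal end-to-end sketch (illustrative only, with assumed MNIST-like 28x28
# inputs, 10 classes, and random data); it relies on the accompanying memory.py
# that defines memory.Memory / memory.LSHMemory.
if __name__ == '__main__':
  import numpy as np

  model = Model(input_dim=28 * 28, output_dim=10, rep_dim=128,
                memory_size=512, vocab_size=10)
  model.setup()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One training step on a random batch, just to show the call pattern.
    x_batch = np.random.rand(16, 28 * 28).astype(np.float32)
    y_batch = np.random.randint(0, 10, size=16).astype(np.int32)
    loss, _ = model.one_step(sess, x_batch, y_batch)
    print('single-step loss:', loss)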