Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np | |
# the number of attention input to each module | |
_module_input_num = { | |
'_key_find': 0, | |
'_key_filter': 1, | |
'_val_desc': 1} | |
_module_output_type = { | |
'_key_find': 'att', | |
'_key_filter': 'att', | |
'_val_desc': 'ans' | |
} | |
INVALID_EXPR = 'INVALID_EXPR' | |
class Assembler:
    """Assemble Reverse Polish Notation module layouts into expression trees.

    A layout is a sequence of integer tokens indexing ``module_names``. Each
    decoded module pops its attention inputs off a stack and pushes itself;
    decoding stops at the first ``<eos>`` token.
    """

    def __init__(self, config):
        """Record the module vocabulary and the token index of ``<eos>``.

        Args:
            config: object exposing a ``module_names`` sequence of module-name
                strings; it must include ``'<eos>'``.

        Raises:
            ValueError: if ``'<eos>'`` is not in ``config.module_names``.
        """
        self.module_names = config.module_names
        # Index of the first '<eos>' entry; fail fast instead of leaving
        # EOS_idx undefined for later methods to trip over.
        eos_idx = next(
            (n_s for n_s, name in enumerate(self.module_names)
             if name == '<eos>'),
            None)
        if eos_idx is None:
            raise ValueError("config.module_names must contain '<eos>'")
        self.EOS_idx = eos_idx
        # Module name -> token index (if a name repeats, its last index wins).
        self.name2idx_dict = {
            name: n_s for n_s, name in enumerate(self.module_names)
        }

    def module_list2tokens(self, module_list, max_len=None):
        """Convert a list of module names into token indices.

        Args:
            module_list: sequence of module names present in ``module_names``.
            max_len: optional total length; when given, the tokens are padded
                with ``<eos>`` up to this length, and at least one ``<eos>``
                must fit.

        Returns:
            list of int token indices.

        Raises:
            KeyError: if a name is not in ``module_names``.
            ValueError: if ``max_len`` leaves no room for an ``<eos>`` token.
        """
        layout_tokens = [self.name2idx_dict[name] for name in module_list]
        if max_len is not None:
            if len(module_list) >= max_len:
                raise ValueError('Not enough time steps to add <eos>')
            layout_tokens += [self.EOS_idx] * (max_len - len(module_list))
        return layout_tokens

    def _layout_tokens2str(self, layout_tokens):
        """Render token indices as a space-separated string of module names."""
        return ' '.join([self.module_names[idx] for idx in layout_tokens])

    def _invalid_expr(self, layout_tokens, error_str):
        """Build the record returned for a layout that failed to assemble.

        Returns:
            dict with 'module' set to INVALID_EXPR, 'expr_str' holding the
            whole layout rendered as names (including any <eos> tokens), and
            'error' set to ``error_str``.
        """
        return {
            'module': INVALID_EXPR,
            'expr_str': self._layout_tokens2str(layout_tokens),
            'error': error_str
        }

    def _assemble_layout_tokens(self, layout_tokens, batch_idx):
        """Decode one RPN layout into a nested expression dict.

        Every decoded module records the time step at which it appears
        (``time_idx``) and ``batch_idx``, whether or not it uses them. Its
        number of attention inputs comes from ``_module_input_num`` and its
        output type from ``_module_output_type``.

        Resulting expression grammar:
            expr := {'module': name, 'output_type': 'att' | 'ans',
                     'time_idx': t, 'batch_idx': batch_idx,
                     'input_0': expr, ..., 'input_{k-1}': expr}
                        where k = _module_input_num[name], every input is an
                        'att' expression, and 'input_{k-1}' was the stack top
                  | {'module': INVALID_EXPR, 'expr_str': '...', 'error': '...'}

        A layout is valid when it contains <eos>, leaves exactly one
        expression on the stack, and that expression's output type is 'ans'.

        Args:
            layout_tokens: 1-D sequence (list or numpy array) of token indices.
            batch_idx: position of this layout within its batch.

        Returns:
            The root expression dict, or an INVALID_EXPR record explaining why
            assembly failed.

        Raises:
            KeyError: if a token names a module missing from the module tables.
        """
        # A valid layout must contain <eos>. np.asarray makes the comparison
        # elementwise for plain lists too; otherwise ``list == int`` is just
        # False and every list would be rejected.
        if not np.any(np.asarray(layout_tokens) == self.EOS_idx):
            return self._invalid_expr(layout_tokens, 'cannot find <eos>')

        # Decode the Reverse Polish Notation with a stack.
        decoding_stack = []
        for t, module_idx in enumerate(layout_tokens):
            if module_idx == self.EOS_idx:
                break
            module_name = self.module_names[module_idx]
            expr = {
                'module': module_name,
                'output_type': _module_output_type[module_name],
                'time_idx': t,
                'batch_idx': batch_idx
            }
            input_num = _module_input_num[module_name]
            if len(decoding_stack) < input_num:
                # Invalid expression: not enough operands on the stack.
                return self._invalid_expr(layout_tokens,
                                          'not enough input for ' + module_name)
            # Pop the operands last-to-first so input_0 is the deepest one.
            for n_input in range(input_num - 1, -1, -1):
                stack_top = decoding_stack.pop()
                if stack_top['output_type'] != 'att':
                    # Invalid expression: module inputs must be attention.
                    return self._invalid_expr(
                        layout_tokens, 'input incompatible for ' + module_name)
                expr['input_%d' % n_input] = stack_top
            decoding_stack.append(expr)

        # A well-formed RPN expression reduces to exactly one stack entry.
        if len(decoding_stack) != 1:
            return self._invalid_expr(
                layout_tokens,
                'final stack size not equal to 1 (%d remains)' % len(decoding_stack))
        result = decoding_stack[0]
        # The root must produce an answer, not an attention map.
        if result['output_type'] != 'ans':
            return self._invalid_expr(layout_tokens,
                                      'result type must be ans, not att')
        return result

    def assemble(self, layout_tokens_batch):
        """Assemble every column of a batch of RPN layouts.

        Args:
            layout_tokens_batch: 2-D numpy array of shape
                [max_dec_len, batch_size] containing module tokens and <eos>
                in Reverse Polish Notation, one layout per column.

        Returns:
            (expr_list, expr_validity): one expression dict per column, and a
            boolean numpy array of shape [batch_size] that is True where
            assembly succeeded.
        """
        _, batch_size = layout_tokens_batch.shape
        expr_list = [
            self._assemble_layout_tokens(layout_tokens_batch[:, batch_i], batch_i)
            for batch_i in range(batch_size)
        ]
        # Builtin ``bool`` as dtype: ``np.bool`` was deprecated in NumPy 1.20
        # and removed in 1.24, where it raises AttributeError.
        expr_validity = np.array(
            [expr['module'] != INVALID_EXPR for expr in expr_list], dtype=bool)
        return expr_list, expr_validity