# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import math

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax


class VotingAttention(tf.keras.layers.Layer):
  """Voting Attention layer.

  Arguments:
    num_heads: the number of attention heads.
    head_size: per-head hidden size.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(self,
               num_heads,
               head_size,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(VotingAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._head_size = head_size
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    # Keep the activity regularizer so build() can forward it to the
    # DenseEinsum sub-layers.
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, unused_input_shapes):
    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        dtype=self.dtype,
        name="encdocatt_query")
    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        dtype=self.dtype,
        name="encdocatt_key")
    super(VotingAttention, self).build(unused_input_shapes)

  def call(self, encoder_outputs, doc_attention_mask):
    # `encoder_outputs` is rank 4: [batch_size, num_docs, seq_length, hidden].
    num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
    # Use each doc's [CLS] embedding as its summary: [B, A, hidden].
    cls_embeddings = encoder_outputs[:, :, 0, :]
    key = self._key_dense(cls_embeddings)
    query = self._query_dense(cls_embeddings)
    doc_attention_mask = tf.cast(doc_attention_mask, tf.float32)
    # Zero out the keys and queries of padded docs: [B, A, N, H].
    key = tf.einsum("BANH,BA->BANH", key, doc_attention_mask)
    query = tf.einsum("BANH,BA->BANH", query, doc_attention_mask)
    # Pairwise doc-to-doc scores per head: [B, N, A, A].
    attention_matrix = tf.einsum("BXNH,BYNH->BNXY", query, key)
    # Mask out the diagonal so a doc does not score against itself.
    mask = tf.ones([num_docs, num_docs])
    mask = tf.linalg.set_diag(mask, tf.zeros(num_docs))
    attention_matrix = tf.einsum("BNXY,XY->BNXY", attention_matrix, mask)
    # Aggregate the pairwise scores over the other docs and over the heads,
    # leaving one score per doc: [B, A].
    doc_attention_probs = tf.einsum("BNAY->BNA", attention_matrix)
    doc_attention_probs = tf.einsum("BNA->BA", doc_attention_probs)
    infadder = (1.0 - doc_attention_mask) * -100000.0
    return tf.nn.softmax(doc_attention_probs + infadder)
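

# A minimal usage sketch for `VotingAttention`: it builds the layer on random
# inputs and returns the per-doc scores. The shapes and values below are
# illustrative assumptions, not constants taken from this module.
def _voting_attention_example():
  """Runs VotingAttention on dummy inputs; returns [batch, num_docs] scores."""
  batch_size, num_docs, seq_length, hidden_size = 2, 3, 8, 16
  layer = VotingAttention(num_heads=2, head_size=8)
  # Rank-4 encoder outputs: one encoded sequence per (story, doc) pair.
  encoder_outputs = tf.random.uniform(
      [batch_size, num_docs, seq_length, hidden_size])
  # 1 marks a real doc, 0 a padded doc.
  doc_attention_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  doc_scores = layer(encoder_outputs, doc_attention_mask)
  return doc_scores  # [batch_size, num_docs]; a softmax distribution per row.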


class MultiChannelAttention(attention.MultiHeadAttention):
  """Multi-channel Attention layer.

  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
  cross-attention target sequences.
  """

  def build(self, input_shape):
    super(MultiChannelAttention, self).build(input_shape)
    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[2])

  def call(self, inputs, attention_mask=None):
    from_tensor = inputs[0]
    to_tensor = inputs[1]
    doc_attention_probs = inputs[2]

    # Scalar dimensions referenced here:
    #   B = batch size (number of stories)
    #   A = num_docs (number of docs)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    # `query_tensor` = [B, F, N, H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, A, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, A, T, N, H]
    value_tensor = self._value_dense(to_tensor)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, A, N, F, T]
    attention_probs = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout_layer(attention_probs)

    # Per-doc context vectors: `context_layer` = [B, A, F, N, H]
    context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                              value_tensor)

    # Mix the per-doc contexts using the doc-level attention probabilities
    # (`doc_attention_probs` = [B, N, F, A]): `attention_output` = [B, F, N, H]
    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
                                 context_layer)
    attention_output = self._output_dense(attention_output)
    return attention_output
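

# A minimal, hedged sketch of how the two layers can be combined. The shapes
# and the `num_heads`/`key_size` constructor arguments are assumptions based
# on the attributes referenced in `call` above, not values taken from this
# file. The [B, A] doc scores from `VotingAttention` are broadcast to the
# [B, N, F, A] layout that `MultiChannelAttention.call` expects for
# `doc_attention_probs`.
def _multi_channel_attention_example():
  """Runs MultiChannelAttention over several target sequences per example."""
  batch_size, num_docs, from_len, to_len, hidden_size = 2, 3, 5, 7, 16
  num_heads, head_size = 2, 8

  # [B, A] doc-level scores, e.g. the output of VotingAttention.
  doc_scores = tf.nn.softmax(tf.random.uniform([batch_size, num_docs]))
  # Broadcast to [B, N, F, A], as consumed by the "BNFA,BAFNH->BFNH" einsum.
  doc_attention_probs = tf.tile(
      doc_scores[:, tf.newaxis, tf.newaxis, :], [1, num_heads, from_len, 1])

  from_tensor = tf.random.uniform([batch_size, from_len, hidden_size])
  to_tensor = tf.random.uniform([batch_size, num_docs, to_len, hidden_size])
  # [B, A, F, T] mask; MaskedSoftmax expands it on axis 2 (the head axis).
  attention_mask = tf.ones([batch_size, num_docs, from_len, to_len])

  layer = MultiChannelAttention(num_heads=num_heads, key_size=head_size)
  output = layer([from_tensor, to_tensor, doc_attention_probs],
                 attention_mask=attention_mask)
  return output  # Expected shape: [batch_size, from_len, hidden_size].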