# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Masked language model network.""" | |
# pylint: disable=g-classes-have-attributes | |
from __future__ import absolute_import | |
from __future__ import division | |
# from __future__ import google_type_annotations | |
from __future__ import print_function | |
import tensorflow as tf | |
from official.modeling import tf_utils | |


class MaskedLM(tf.keras.layers.Layer):
  """Masked language model network head for BERT modeling.

  This layer implements a masked language model head: it transforms the
  encoder output at the masked positions and projects it back onto the
  vocabulary using the provided embedding table (typically the one returned
  by the encoder network's `get_embedding_table()` method).

  Arguments:
    embedding_table: The embedding table of the targets.
    activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
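
  Example:

    A minimal usage sketch; the vocabulary size, hidden size, and batch
    shapes below are illustrative assumptions, not values fixed by this
    layer:

      embedding_table = tf.Variable(tf.random.normal([100, 16]))
      masked_lm = MaskedLM(embedding_table)
      sequence_data = tf.random.normal([2, 8, 16])
      masked_positions = tf.constant([[1, 5], [0, 3]], dtype=tf.int32)
      logits = masked_lm(sequence_data, masked_positions)  # (2, 2, 100)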
""" | |

  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               name='cls/predictions',
               **kwargs):
    super(MaskedLM, self).__init__(name=name, **kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def build(self, input_shape):
    self._vocab_size, hidden_size = self.embedding_table.shape
    self.dense = tf.keras.layers.Dense(
        hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    # The output projection is tied to the input embedding table, so the only
    # new output-side variable is a per-token bias.
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)

    super(MaskedLM, self).build(input_shape)

  def call(self, sequence_data, masked_positions):
    # Gather only the hidden vectors at the masked positions, flattening the
    # batch and prediction dimensions.
    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
    lm_data = self.dense(masked_lm_input)
    lm_data = self.layer_norm(lm_data)
    # Project back onto the vocabulary with the (tied) embedding table.
    lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)

    masked_positions_shape = tf_utils.get_shape_list(
        masked_positions, name='masked_positions_tensor')
    logits = tf.reshape(logits,
                        [-1, masked_positions_shape[1], self._vocab_size])
    if self._output_type == 'logits':
      return logits
    return tf.nn.log_softmax(logits)

  def get_config(self):
    raise NotImplementedError('MaskedLM cannot be directly serialized because '
                              'it has variable sharing logic.')

  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of the `BertModel` layer, of shape
        (batch_size, seq_length, num_hidden), where `num_hidden` is the
        number of hidden units of the `BertModel` layer.
      positions: Position ids of the tokens to mask for pretraining, with
        shape (batch_size, num_predictions), where `num_predictions` is the
        maximum number of tokens to mask out and predict per sequence.

    Returns:
      Masked-out sequence tensor of shape (batch_size * num_predictions,
      num_hidden).
""" | |
    sequence_shape = tf_utils.get_shape_list(
        sequence_tensor, name='sequence_output_tensor')
    batch_size, seq_length, width = sequence_shape

    # Convert per-sequence positions into positions in the flattened
    # (batch_size * seq_length, width) tensor by adding a per-batch offset.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor
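

if __name__ == '__main__':
  # A minimal, self-contained usage sketch (not part of the original module).
  # The vocabulary size, hidden size, and batch shapes below are illustrative
  # assumptions, not values mandated by this layer.
  embedding_table = tf.Variable(
      tf.random.normal([100, 16]), name='word_embeddings')  # (vocab, hidden)
  masked_lm = MaskedLM(embedding_table, output='logits')

  sequence_data = tf.random.normal([2, 8, 16])  # (batch, seq_length, hidden)
  masked_positions = tf.constant([[1, 5], [0, 3]], dtype=tf.int32)

  logits = masked_lm(sequence_data, masked_positions)
  print(logits.shape)  # (2, 2, 100): (batch, num_predictions, vocab_size)

  # With output='predictions', the layer returns log-probabilities instead of
  # raw logits; the shape is the same.
  log_probs = MaskedLM(embedding_table, output='predictions')(
      sequence_data, masked_positions)
  print(log_probs.shape)  # (2, 2, 100)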