# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility helpers for Bert2Bert.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
# from __future__ import google_type_annotations | |
from __future__ import print_function | |
from absl import logging | |
import tensorflow as tf | |
from typing import Optional, Text | |
from official.modeling.hyperparams import params_dict | |
from official.nlp.bert import configs | |
from official.nlp.nhnet import configs as nhnet_configs | |


def get_bert_config_from_params(
    params: params_dict.ParamsDict) -> configs.BertConfig:
  """Converts a ParamsDict to a BertConfig."""
  return configs.BertConfig.from_dict(params.as_dict())
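# A minimal usage sketch, assuming `params` carries standard BertConfig
# fields; the field names below come from the BERT config schema and are
# shown only for illustration.
#
#   params = params_dict.ParamsDict(default_params={
#       "vocab_size": 30522,
#       "hidden_size": 768,
#       "num_hidden_layers": 12,
#   })
#   bert_config = get_bert_config_from_params(params)
#   assert bert_config.hidden_size == 768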


def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
  """Builds a small config from UNITTEST_CONFIG for fast unit tests."""
  return cls.from_args(**nhnet_configs.UNITTEST_CONFIG)
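# Hedged sketch of how a test might use this helper; chaining the result into
# get_bert_config_from_params assumes the returned config exposes as_dict(),
# which is an assumption about BERT2BERTConfig rather than a documented API.
#
#   test_params = get_test_params()
#   bert_config = get_bert_config_from_params(test_params)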


# pylint: disable=protected-access
def encoder_common_layers(transformer_block):
  """Lists the sublayers of an encoder block that a decoder block shares."""
  return [
      transformer_block._attention_layer,
      transformer_block._attention_layer_norm,
      transformer_block._intermediate_dense,
      transformer_block._output_dense,
      transformer_block._output_layer_norm,
  ]
# pylint: enable=protected-access
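# Illustrative sketch of the pairing this helper enables: each encoder block
# contributes the five sublayers above, matched positionally against the
# decoder's common_layers_with_encoder(). `encoder` and `decoder` are
# hypothetical model objects shaped like the arguments accepted below.
#
#   for enc_block, dec_block in zip(encoder.transformer_layers,
#                                   decoder.decoder.layers):
#     for dst, src in zip(dec_block.common_layers_with_encoder(),
#                         encoder_common_layers(enc_block)):
#       dst.set_weights(src.get_weights())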


def initialize_bert2bert_from_pretrained_bert(
    bert_encoder: tf.keras.layers.Layer,
    bert_decoder: tf.keras.layers.Layer,
    init_checkpoint: Optional[Text] = None) -> None:
  """Initializes a Bert2Bert model from a pretrained BERT checkpoint."""
  ckpt = tf.train.Checkpoint(model=bert_encoder)
  logging.info(
      "Checkpoint file %s found; restoring the core encoder from it.",
      init_checkpoint)
  status = ckpt.restore(init_checkpoint)
  # The BERT encoder is expected to be a subset of the checkpoint, since the
  # pooling layer is not used.
  status.assert_existing_objects_matched()
  logging.info("Loading from checkpoint file completed.")
  # Collects the encoder sublayers that the decoder can reuse.
  encoder_layers = []
  for transformer_block in bert_encoder.transformer_layers:
    encoder_layers.extend(encoder_common_layers(transformer_block))
  # Collects the decoder sublayers that mirror the encoder's structure.
  decoder_layers_to_initialize = []
  for decoder_block in bert_decoder.decoder.layers:
    decoder_layers_to_initialize.extend(
        decoder_block.common_layers_with_encoder())
  if len(decoder_layers_to_initialize) != len(encoder_layers):
    raise ValueError(
        "Source encoder layers with %d objects do not match destination "
        "decoder layers with %d objects." %
        (len(encoder_layers), len(decoder_layers_to_initialize)))
  for dest_layer, source_layer in zip(decoder_layers_to_initialize,
                                      encoder_layers):
    try:
      dest_layer.set_weights(source_layer.get_weights())
    except ValueError as e:
      logging.error(
          "dest_layer: %s failed to set weights from source_layer: %s as %s",
          dest_layer.name, source_layer.name, str(e))