# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run NHNet model training and eval."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf

from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

FLAGS = flags.FLAGS


def define_flags():
  """Defines command line flags used by the NHNet trainer."""
  ## Required parameters
  flags.DEFINE_enum("mode", "train", ["train", "eval", "train_and_eval"],
                    "Execution mode.")
  flags.DEFINE_string("train_file_pattern", "", "Train file pattern.")
  flags.DEFINE_string("eval_file_pattern", "", "Eval file pattern.")
  flags.DEFINE_string(
      "model_dir", None,
      "The output directory where the model checkpoints will be written.")

  # Model training specific flags.
  flags.DEFINE_enum(
      "distribution_strategy", "mirrored", ["tpu", "mirrored"],
      "Distribution Strategy type to use for training. `tpu` uses TPUStrategy "
      "for running on TPUs; `mirrored` uses GPUs on a single host.")
  flags.DEFINE_string("tpu", "", "TPU address to connect to.")
  flags.DEFINE_string(
      "init_checkpoint", None,
      "Initial checkpoint (usually from a pre-trained BERT model).")
  flags.DEFINE_integer("train_steps", 100000, "Max train steps.")
  flags.DEFINE_integer("eval_steps", 32, "Number of eval steps per run.")
  flags.DEFINE_integer("eval_timeout", 3000, "Timeout waiting for checkpoints.")
  flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
  flags.DEFINE_integer("eval_batch_size", 4, "Total batch size for evaluation.")
  flags.DEFINE_integer(
      "steps_per_loop", 1000,
      "Number of steps per graph-mode loop. Only the training step "
      "happens inside the loop.")
  flags.DEFINE_integer("checkpoint_interval", 2000, "Checkpointing interval.")
  flags.DEFINE_integer("len_title", 15, "Title length.")
  flags.DEFINE_integer("len_passage", 200, "Passage length.")
  flags.DEFINE_integer("num_encoder_layers", 12,
                       "Number of hidden layers of encoder.")
  flags.DEFINE_integer("num_decoder_layers", 12,
                       "Number of hidden layers of decoder.")
  flags.DEFINE_string("model_type", "nhnet",
                      "Model type to choose a model configuration.")
  flags.DEFINE_integer(
      "num_nhnet_articles", 5,
      "Maximum number of articles in NHNet; only used when model_type=nhnet.")
  flags.DEFINE_string(
      "params_override",
      default=None,
      help=("A YAML/JSON string or a YAML file which specifies additional "
            "overrides over the default parameters."))


# pylint: disable=protected-access
class Trainer(tf.keras.Model):
  """A training only model."""

  def __init__(self, model, params):
    super(Trainer, self).__init__()
    self.model = model
    self.params = params
    self._num_replicas_in_sync = tf.distribute.get_strategy(
    ).num_replicas_in_sync

  def call(self, inputs, mode="train"):
    return self.model(inputs, mode)

  def train_step(self, inputs):
    """The logic for one training step."""
    with tf.GradientTape() as tape:
      logits, _, _ = self(inputs, mode="train", training=True)
      targets = models.remove_sos_from_seq(inputs["target_ids"],
                                           self.params.pad_token_id)
      loss = transformer_metrics.transformer_loss(logits, targets,
                                                  self.params.label_smoothing,
                                                  self.params.vocab_size)
      # Scales the loss, which results in using the average loss across all
      # of the replicas for backprop.
      scaled_loss = loss / self._num_replicas_in_sync

    tvars = self.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    self.optimizer.apply_gradients(list(zip(grads, tvars)))
    return {
        "training_loss": loss,
        "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
    }
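

# A minimal usage sketch of Trainer (illustrative only; assumes `params` comes
# from models.get_model_params() and `dataset` is built by
# input_pipeline.get_input_dataset with a matching batch size, as done in
# train() below):
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     model = models.create_model("nhnet", params)
#     trainer = Trainer(model, params)
#     trainer.compile(optimizer=optimizer.create_optimizer(params))
#   trainer.fit(x=dataset, steps_per_epoch=1000, epochs=10)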


def train(params, strategy, dataset=None):
  """Runs training."""
  if not dataset:
    dataset = input_pipeline.get_input_dataset(
        FLAGS.train_file_pattern,
        FLAGS.train_batch_size,
        params,
        is_training=True,
        strategy=strategy)

  with strategy.scope():
    model = models.create_model(
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    opt = optimizer.create_optimizer(params)
    trainer = Trainer(model, params)
    model.global_step = opt.iterations

    trainer.compile(
        optimizer=opt,
        experimental_steps_per_execution=FLAGS.steps_per_loop)
    summary_dir = os.path.join(FLAGS.model_dir, "summaries")
    summary_callback = tf.keras.callbacks.TensorBoard(
        summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
    checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=FLAGS.model_dir,
        max_to_keep=10,
        step_counter=model.global_step,
        checkpoint_interval=FLAGS.checkpoint_interval)
    if checkpoint_manager.restore_or_initialize():
      logging.info("Training restored from the checkpoints in: %s",
                   FLAGS.model_dir)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    # Trains the model.
    steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
    epochs = FLAGS.train_steps // steps_per_epoch
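    # For example, with the default flags (train_steps=100000,
    # checkpoint_interval=2000), steps_per_epoch is 2000 and epochs is 50, so
    # one checkpoint is written at the end of each Keras "epoch".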
    history = trainer.fit(
        x=dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[summary_callback, checkpoint_callback],
        verbose=2)
    train_hist = history.history
    # Gets final loss from training.
    stats = dict(training_loss=float(train_hist["training_loss"][-1]))
  return stats


def run():
  """Runs NHNet using Keras APIs."""
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)

  params = models.get_model_params(FLAGS.model_type)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          "len_title":
              FLAGS.len_title,
          "len_passage":
              FLAGS.len_passage,
          "num_hidden_layers":
              FLAGS.num_encoder_layers,
          "num_decoder_layers":
              FLAGS.num_decoder_layers,
          "passage_list":
              [chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
      },
      is_strict=False)

  stats = {}
  if "train" in FLAGS.mode:
    stats = train(params, strategy)
  if "eval" in FLAGS.mode:
    timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
    # Uses padded decoding for TPU. Always uses cache.
    padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
    params.override({
        "padded_decode": padded_decode,
    }, is_strict=False)
    stats = evaluation.continuous_eval(
        strategy,
        params,
        model_type=FLAGS.model_type,
        eval_file_pattern=FLAGS.eval_file_pattern,
        batch_size=FLAGS.eval_batch_size,
        eval_steps=FLAGS.eval_steps,
        model_dir=FLAGS.model_dir,
        timeout=timeout)
  return stats


def main(_):
  stats = run()
  if stats:
    logging.info("Stats:\n%s", stats)


if __name__ == "__main__":
  define_flags()
  app.run(main)