# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main entry to train and evaluate DeepSpeech model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
# pylint: disable=g-bad-import-order | |
from absl import app as absl_app | |
from absl import flags | |
import tensorflow as tf | |
# pylint: enable=g-bad-import-order | |
import data.dataset as dataset | |
import decoder | |
import deep_speech_model | |
from official.utils.flags import core as flags_core | |
from official.utils.misc import distribution_utils | |
from official.utils.misc import model_helpers | |
# Default vocabulary file | |
_VOCABULARY_FILE = os.path.join( | |
os.path.dirname(__file__), "data/vocabulary.txt") | |
# Evaluation metrics | |
_WER_KEY = "WER" | |
_CER_KEY = "CER" | |


def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
  """Computes the time_steps/ctc_input_length after convolution.

  Suppose that the original feature contains two parts:
  1) Real spectrogram signals, spanning input_length steps.
  2) Padded part with all 0s.
  The total length of those two parts is denoted as max_time_steps, which is
  the padded length of the current batch. After the convolution layers, the
  time steps of a spectrogram feature will be decreased. Since we know what
  fraction of the padded length the real signal occupies, we can compute the
  time steps of the signal after convolution (denoted ctc_input_length) as:
  ctc_input_length = (input_length / max_time_steps) * output_length_of_conv.
  This length is then fed into the ctc loss function to compute loss.
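
  For example (illustrative numbers): with max_time_steps=200,
  ctc_time_steps=100 and input_length=160, this gives
  floor(160 * 100 / 200) = 80 valid time steps to feed into the CTC loss.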

  Args:
    max_time_steps: max_time_steps for the batch, after padding.
    ctc_time_steps: number of timesteps after convolution.
    input_length: actual length of the original spectrogram, without padding.

  Returns:
    the ctc_input_length after convolution layer.
  """
  ctc_input_length = tf.to_float(tf.multiply(
      input_length, ctc_time_steps))
  return tf.to_int32(tf.floordiv(
      ctc_input_length, tf.to_float(max_time_steps)))


def ctc_loss(label_length, ctc_input_length, labels, logits):
  """Computes the ctc loss for the current batch of predictions."""
  label_length = tf.to_int32(tf.squeeze(label_length))
  ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
  sparse_labels = tf.to_int32(
      tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
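
  # tf.nn.ctc_loss expects time-major inputs in log space, so transpose the
  # batch-major probabilities to [time, batch, num_classes] and take the log;
  # the epsilon guards against log(0).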
  y_pred = tf.log(tf.transpose(
      logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())

  return tf.expand_dims(
      tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
                     sequence_length=ctc_input_length),
      axis=1)


def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
  """Evaluate the model performance using WER and CER as metrics.

  WER: Word Error Rate
  CER: Character Error Rate
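
  For example (illustrative transcripts): decoding "hello wold" against the
  reference "hello world" gives WER = 1/2 (one of two reference words is
  wrong) and CER = 1/11 (an edit distance of one over eleven reference
  characters, counting the space).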

  Args:
    estimator: estimator to evaluate.
    speech_labels: a string specifying all the characters in the vocabulary.
    entries: a list of data entries (audio_file, file_size, transcript) for the
      given dataset.
    input_fn_eval: data input function for evaluation.

  Returns:
    Evaluation result containing 'wer' and 'cer' as two metrics.
  """
  # Get predictions
  predictions = estimator.predict(input_fn=input_fn_eval)

  # Get probabilities of each predicted class
  probs = [pred["probabilities"] for pred in predictions]
  num_of_examples = len(probs)
  targets = [entry[2] for entry in entries]  # The ground truth transcript

  total_wer, total_cer = 0, 0
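  # Greedy (best-path) decoder that maps the frame-level character
  # probabilities back to transcripts.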
  greedy_decoder = decoder.DeepSpeechDecoder(speech_labels)
  for i in range(num_of_examples):
    # Decode string.
    decoded_str = greedy_decoder.decode(probs[i])
    # Compute CER.
    total_cer += greedy_decoder.cer(decoded_str, targets[i]) / float(
        len(targets[i]))
    # Compute WER.
    total_wer += greedy_decoder.wer(decoded_str, targets[i]) / float(
        len(targets[i].split()))

  # Get mean value
  total_cer /= num_of_examples
  total_wer /= num_of_examples

  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
  eval_results = {
      _WER_KEY: total_wer,
      _CER_KEY: total_cer,
      tf.GraphKeys.GLOBAL_STEP: global_step,
  }

  return eval_results


def model_fn(features, labels, mode, params):
  """Define model function for deep speech model.

  Args:
    features: a dictionary of input_data features. It includes the data
      input_length, label_length and the spectrogram features.
    labels: a list of labels for the input data.
    mode: current estimator mode; should be one of
      `tf.estimator.ModeKeys.TRAIN`, `EVAL`, `PREDICT`.
    params: a dict of hyper parameters to be passed to model_fn.

  Returns:
    EstimatorSpec parameterized according to the input params and the
    current mode.
  """
  num_classes = params["num_classes"]
  input_length = features["input_length"]
  label_length = features["label_length"]
  features = features["features"]

  # Create DeepSpeech2 model.
  model = deep_speech_model.DeepSpeech2(
      flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
      flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
      num_classes, flags_obj.use_bias)

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(features, training=False)
    predictions = {
        "classes": tf.argmax(logits, axis=2),
        "probabilities": tf.nn.softmax(logits),
        "logits": logits
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions)

  # In training mode.
  logits = model(features, training=True)
  probs = tf.nn.softmax(logits)
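  # Scale the unpadded input length by the conv layers' time reduction so the
  # CTC loss only covers frames backed by real (non-padded) audio.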
  ctc_input_length = compute_length_after_conv(
      tf.shape(features)[1], tf.shape(probs)[1], input_length)
  # Compute CTC loss
  loss = tf.reduce_mean(ctc_loss(
      label_length, ctc_input_length, labels, probs))

  optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
  global_step = tf.train.get_or_create_global_step()
  minimize_op = optimizer.minimize(loss, global_step=global_step)
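  # UPDATE_OPS holds graph update operations (e.g. batch-normalization
  # moving-average updates) that must run alongside the gradient step.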
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Create the train_op that groups both minimize_ops and update_ops
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op)


def generate_dataset(data_dir):
  """Generate a speech dataset."""
  audio_conf = dataset.AudioConfig(sample_rate=flags_obj.sample_rate,
                                   window_ms=flags_obj.window_ms,
                                   stride_ms=flags_obj.stride_ms,
                                   normalize=True)
  train_data_conf = dataset.DatasetConfig(
      audio_conf,
      data_dir,
      flags_obj.vocabulary_file,
      flags_obj.sortagrad
  )
  speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
  return speech_dataset


def per_device_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. When used with Estimator, we need to compute the per-GPU batch size
  ourselves.
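
  For example (illustrative values): batch_size=128 with num_gpus=4 yields a
  per-device batch size of 32, while batch_size=100 with num_gpus=3 raises a
  ValueError because 100 is not divisible by 3.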

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(batch_size / num_gpus)


def run_deep_speech(_):
  """Run deep speech training and eval loop."""
  tf.set_random_seed(flags_obj.seed)
  # Data preprocessing
  tf.logging.info("Data preprocessing...")
  train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
  eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)

  # Number of label classes. Label string is "[a-z]' -"
  num_classes = len(train_speech_dataset.speech_labels)

  # Use distribution strategy for multi-gpu training
  num_gpus = flags_core.get_num_gpus(flags_obj)
  distribution_strategy = distribution_utils.get_distribution_strategy(
      num_gpus=num_gpus)
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags_obj.model_dir,
      config=run_config,
      params={
          "num_classes": num_classes,
      }
  )

  # Benchmark logging
  run_params = {
      "batch_size": flags_obj.batch_size,
      "train_epochs": flags_obj.train_epochs,
      "rnn_hidden_size": flags_obj.rnn_hidden_size,
      "rnn_hidden_layers": flags_obj.rnn_hidden_layers,
      "rnn_type": flags_obj.rnn_type,
      "is_bidirectional": flags_obj.is_bidirectional,
      "use_bias": flags_obj.use_bias
  }

  dataset_name = "LibriSpeech"
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  per_replica_batch_size = per_device_batch_size(flags_obj.batch_size, num_gpus)

  def input_fn_train():
    return dataset.input_fn(
        per_replica_batch_size, train_speech_dataset)

  def input_fn_eval():
    return dataset.input_fn(
        per_replica_batch_size, eval_speech_dataset)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
    tf.logging.info("Starting a training cycle: %d/%d",
                    cycle_index + 1, total_training_cycle)

    # Perform batch_wise dataset shuffling
    train_speech_dataset.entries = dataset.batch_wise_dataset_shuffle(
        train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
        flags_obj.batch_size)

    estimator.train(input_fn=input_fn_train)

    # Evaluation
    tf.logging.info("Starting to evaluate...")

    eval_results = evaluate_model(
        estimator, eval_speech_dataset.speech_labels,
        eval_speech_dataset.entries, input_fn_eval)

    # Log the WER and CER results.
    benchmark_logger.log_evaluation_result(eval_results)
    tf.logging.info(
        "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
            cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))

    # If some evaluation threshold is met
    if model_helpers.past_stop_threshold(
        flags_obj.wer_threshold, eval_results[_WER_KEY]):
      break


def define_deep_speech_flags():
  """Add flags for run_deep_speech."""
  # Add common flags
  flags_core.define_base(
      data_dir=False,  # we use train_data_dir and eval_data_dir instead
      export_dir=True,
      train_epochs=True,
      hooks=True,
      num_gpu=True,
      epochs_between_evals=True
  )
  flags_core.define_performance(
      num_parallel_calls=False,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=False
  )
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  flags_core.set_defaults(
      model_dir="/tmp/deep_speech_model/",
      export_dir="/tmp/deep_speech_saved_model/",
      train_epochs=10,
      batch_size=128,
      hooks="")

  # Deep speech flags
  flags.DEFINE_integer(
      name="seed", default=1,
      help=flags_core.help_wrap("The random seed."))

  flags.DEFINE_string(
      name="train_data_dir",
      default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean.csv",
      help=flags_core.help_wrap("The csv file path of train dataset."))

  flags.DEFINE_string(
      name="eval_data_dir",
      default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean.csv",
      help=flags_core.help_wrap("The csv file path of evaluation dataset."))

  flags.DEFINE_bool(
      name="sortagrad", default=True,
      help=flags_core.help_wrap(
          "If true, sort examples by audio length and perform no "
          "batch_wise shuffling for the first epoch."))

  flags.DEFINE_integer(
      name="sample_rate", default=16000,
      help=flags_core.help_wrap("The sample rate for audio."))

  flags.DEFINE_integer(
      name="window_ms", default=20,
      help=flags_core.help_wrap("The frame length for spectrogram."))

  flags.DEFINE_integer(
      name="stride_ms", default=10,
      help=flags_core.help_wrap("The frame step."))

  flags.DEFINE_string(
      name="vocabulary_file", default=_VOCABULARY_FILE,
      help=flags_core.help_wrap("The file path of vocabulary file."))

  # RNN related flags
  flags.DEFINE_integer(
      name="rnn_hidden_size", default=800,
      help=flags_core.help_wrap("The hidden size of RNNs."))

  flags.DEFINE_integer(
      name="rnn_hidden_layers", default=5,
      help=flags_core.help_wrap("The number of RNN layers."))

  flags.DEFINE_bool(
      name="use_bias", default=True,
      help=flags_core.help_wrap("Use bias in the last fully-connected layer."))

  flags.DEFINE_bool(
      name="is_bidirectional", default=True,
      help=flags_core.help_wrap("If the rnn unit is bidirectional."))

  flags.DEFINE_enum(
      name="rnn_type", default="gru",
      enum_values=deep_speech_model.SUPPORTED_RNNS.keys(),
      case_sensitive=False,
      help=flags_core.help_wrap("Type of RNN cell."))

  # Training related flags
  flags.DEFINE_float(
      name="learning_rate", default=5e-4,
      help=flags_core.help_wrap("The initial learning rate."))

  # Evaluation metrics threshold
  flags.DEFINE_float(
      name="wer_threshold", default=None,
      help=flags_core.help_wrap(
          "If passed, training will stop when the evaluation metric WER is "
          "greater than or equal to wer_threshold. For the LibriSpeech "
          "dataset, the desired wer_threshold is 0.23, which is the result "
          "achieved by the MLPerf implementation."))


def main(_):
  run_deep_speech(flags_obj)
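

# Example invocation (illustrative; assumes this module is saved as
# deep_speech.py and that the LibriSpeech csv files already exist at the
# given paths):
#   python deep_speech.py \
#       --train_data_dir=/tmp/librispeech_data/test-clean/LibriSpeech/test-clean.csv \
#       --eval_data_dir=/tmp/librispeech_data/test-clean/LibriSpeech/test-clean.csv \
#       --num_gpus=1 --wer_threshold=0.23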
if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  define_deep_speech_flags()
  flags_obj = flags.FLAGS
  absl_app.run(main)