# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf


def decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.io.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.cast(t, tf.int32)
    example[name] = t

  return example


def process_singledoc_dataset(dataset, batch_size, params):
  """Parses and batches single-doc dataset."""
  name_to_features = {
      "input_ids_a": tf.io.FixedLenFeature([params.len_title], tf.int64),
      "input_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
      "input_mask_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
      "segment_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
  }
  decode_fn = lambda record: decode_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Selects the features to use for pretraining."""
    return {
        "input_ids": record["input_ids_b"],
        "input_mask": record["input_mask_b"],
        "segment_ids": record["segment_ids_b"],
        "target_ids": record["input_ids_a"],
    }

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset
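
# After batching, each element of the returned dataset is a dict of int32
# tensors: "input_ids", "input_mask" and "segment_ids" have shape
# [batch_size, params.len_passage], and "target_ids" has shape
# [batch_size, params.len_title].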


def decode_sparse_record(record, name_to_features):
  """Decodes a sparse record to a TensorFlow example."""
  example = tf.io.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.cast(t, tf.int32)
    example[name] = tf.sparse.to_dense(t)

  return example


def _filter_max_length(example, max_title_length=256):
  """Indicates whether the title length does not exceed `max_title_length`."""
  return tf.size(example["targets"]) <= max_title_length


def process_singledoc_transformer_dataset(dataset, batch_size, params):
  """Parses, batches and pads single-doc dataset."""
  name_to_features = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64),
  }
  decode_fn = lambda record: decode_sparse_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Selects the features to use for pretraining."""
    input_ids = record["inputs"][:params.len_passage]
    target_ids = record["targets"]
    input_mask = tf.ones_like(input_ids)
    segment_ids = tf.zeros_like(input_ids)
    return {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        "target_ids": target_ids,
    }

  dataset = dataset.filter(lambda x: _filter_max_length(x, params.len_title))
  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.padded_batch(
      batch_size, {
          "input_ids": [params.len_passage],
          "input_mask": [params.len_passage],
          "segment_ids": [params.len_passage],
          "target_ids": [params.len_title],
      },
      padding_values={
          "input_ids": params.pad_token_id,
          "input_mask": 0,
          "segment_ids": 0,
          "target_ids": params.pad_token_id,
      },
      drop_remainder=True)
  return dataset
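

# A minimal sketch (not part of the original pipeline) of how a record
# compatible with the variable-length format above could be serialized using
# the standard tf.train.Example API; the helper name is illustrative only.
def _make_singledoc_example(input_ids, target_ids):
  """Builds a tf.train.Example with "inputs"/"targets" int64 features."""

  def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  return tf.train.Example(
      features=tf.train.Features(feature={
          "inputs": _int64_feature(input_ids),
          "targets": _int64_feature(target_ids),
      }))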


def multidoc_parse_spec(params, training=True):
  """Gets the multi-doc tf.Example parsing spec."""
  len_p = params.len_passage
  name_to_features = {}
  feature_list = ["input_ids", "input_mask", "segment_ids"]
  for idx in params.passage_list:
    for feature in feature_list:
      name_to_features["%s_%s" % (feature, idx)] = tf.io.FixedLenFeature(
          [len_p], tf.int64)
  if training:
    # Cluster title.
    name_to_features["input_ids_a"] = tf.io.FixedLenFeature([params.len_title],
                                                            tf.int64)
  return name_to_features, feature_list
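
# For example, with params.passage_list = ["b", "c"] (an illustrative value),
# the spec above contains "input_ids_b", "input_mask_b", "segment_ids_b",
# "input_ids_c", "input_mask_c" and "segment_ids_c", each a fixed-length int64
# feature of length params.len_passage, plus "input_ids_a" of length
# params.len_title for the cluster title when training=True.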


def process_multidoc_dataset(dataset, batch_size, params):
  """Parses, organizes and batches multi-doc dataset."""
  name_to_features, feature_list = multidoc_parse_spec(params)
  decode_fn = lambda record: decode_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Selects the features to use for pretraining."""
    features = {"target_ids": record["input_ids_a"]}
    for feature in feature_list:
      tensors = [record["%s_%s" % (feature, i)] for i in params.passage_list]
      features[feature] = tf.stack(tensors)
    return features

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset
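
# After tf.stack and batching, "input_ids", "input_mask" and "segment_ids"
# each have shape [batch_size, len(params.passage_list), params.len_passage],
# while "target_ids" has shape [batch_size, params.len_title].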


def create_dataset(file_paths,
                   batch_size,
                   params,
                   is_training=True,
                   input_pipeline_context=None):
  """Creates input dataset from (tf)records files for pretraining."""
  dataset = tf.data.Dataset.list_files(file_paths, shuffle=is_training)
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    if not is_training or params.input_sharding:
      dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                              input_pipeline_context.input_pipeline_id)

  if is_training:
    dataset = dataset.repeat()
    # We set the shuffle buffer to exactly match the total number of
    # training files to ensure that the training data is well shuffled.
    dataset = dataset.shuffle(len(file_paths))

  # In parallel, create a TFRecord dataset for each training file.
  # cycle_length = 8 means that up to 8 files will be read and deserialized in
  # parallel. You may want to increase this number if you have a large number
  # of CPU cores.
  dataset = dataset.interleave(
      tf.data.TFRecordDataset,
      cycle_length=8,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if is_training:
    dataset = dataset.shuffle(100)

  if params.get("multi_channel_cross_attention", value=False):
    dataset = process_multidoc_dataset(dataset, batch_size, params)
  else:
    if not params.input_data_not_padded:
      dataset = process_singledoc_dataset(dataset, batch_size, params)
    else:
      dataset = process_singledoc_transformer_dataset(dataset, batch_size,
                                                      params)

  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset


def get_input_dataset(input_file_pattern,
                      batch_size,
                      params,
                      is_training,
                      strategy=None):
  """Returns input dataset from input file string."""
  # When using TPU pods, we need to clone the dataset across workers and pass
  # in a function that returns the dataset, rather than passing the dataset
  # instance itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas: {}".format(
              strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning the dataset to multiple workers in eager mode,
    # we use the per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_files = []
    for input_pattern in input_file_pattern.split(","):
      input_files.extend(tf.io.gfile.glob(input_pattern))
    return create_dataset(
        input_files,
        batch_size,
        params,
        is_training=is_training,
        input_pipeline_context=ctx)

  if use_dataset_fn:
    return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
  else:
    return strategy.experimental_distribute_dataset(_dataset_fn())
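

# Example usage (a sketch, not part of the original module). It assumes
# `params` is the dict-like config object read above (len_title, len_passage,
# pad_token_id, input_data_not_padded, input_sharding, ...) and that a
# tf.distribute strategy is supplied:
#
#   strategy = tf.distribute.MirroredStrategy()
#   train_ds = get_input_dataset(
#       "/path/to/train-*.tfrecord",
#       batch_size=32,
#       params=params,
#       is_training=True,
#       strategy=strategy)
#   for batch in train_ds:
#     ...  # dict of int32 feature tensors, batched per replica.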