Spaces:
Running
Running
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Processes crawled content from news URLs by generating tfrecords."""
import os | |
from absl import app | |
from absl import flags | |
from official.nlp.nhnet import raw_data_processor | |
FLAGS = flags.FLAGS

# Command-line flags read by main() and transform_as_tfrecords().
flags.DEFINE_string(
    name="crawled_articles",
    default="/tmp/nhnet/",
    help="Folder path to the crawled articles using news-please.")
flags.DEFINE_string(
    name="vocab",
    default=None,
    help="Filepath of the BERT vocabulary.")
flags.DEFINE_bool(
    name="do_lower_case",
    default=True,
    help="Whether the vocabulary is uncased or not.")
flags.DEFINE_integer(
    name="len_title",
    default=15,
    help="Maximum number of tokens in story headline.")
flags.DEFINE_integer(
    name="len_passage",
    default=200,
    help="Maximum number of tokens in article passage.")
flags.DEFINE_integer(
    name="max_num_articles",
    default=5,
    help="Maximum number of articles in a story.")
flags.DEFINE_bool(
    name="include_article_title_in_passage",
    default=False,
    help="Whether to include article title in article passage.")
flags.DEFINE_string(
    name="data_folder",
    default=None,
    help="Folder path to the downloaded data folder (output).")
flags.DEFINE_integer(
    name="num_tfrecords_shards",
    default=20,
    help="Number of shards for train/valid/test.")


def transform_as_tfrecords(data_processor, filename):
    """Converts one split's story json into sharded tfrecord files.

    The input is read from `<FLAGS.data_folder>/<filename>.json`. The output
    shards are written under `<FLAGS.data_folder>/processed`, which is created
    when it does not exist yet. Progress is reported on stdout.

    Args:
      data_processor: Instance of RawDataProcessor.
      filename: 'train', 'valid', or 'test'.
    """
    print("Transforming json to tfrecord for %s..." % filename)

    num_shards = FLAGS.num_tfrecords_shards
    json_path = os.path.join(FLAGS.data_folder, filename + ".json")
    processed_dir = os.path.join(FLAGS.data_folder, "processed")
    os.makedirs(processed_dir, exist_ok=True)

    # One output path per shard, e.g. "train.tfrecord-00003-of-00020".
    shard_paths = [
        os.path.join(
            processed_dir,
            "%s.tfrecord-%.5d-of-%.5d" % (filename, shard_index, num_shards))
        for shard_index in range(num_shards)
    ]

    counts = data_processor.generate_examples(json_path, shard_paths)
    total_num_examples, generated_num_examples = counts
    print("For %s, %d examples have been generated from %d stories in json." %
          (filename, generated_num_examples, total_num_examples))


def main(_):
    """Builds the data processor from flags and writes all three splits.

    Loads the crawled articles from `FLAGS.crawled_articles`, then converts
    the 'train', 'valid' and 'test' splits to tfrecords in that order.

    Raises:
      ValueError: If the `data_folder` or `vocab` flag is unset or empty.
    """
    # Both flags default to None, so reject a missing value up front.
    if not FLAGS.data_folder:
        raise ValueError("data_folder must be set as the downloaded folder path.")
    if not FLAGS.vocab:
        raise ValueError("vocab must be set as the filepath of BERT vocabulary.")

    processor_options = dict(
        vocab=FLAGS.vocab,
        do_lower_case=FLAGS.do_lower_case,
        len_title=FLAGS.len_title,
        len_passage=FLAGS.len_passage,
        max_num_articles=FLAGS.max_num_articles,
        include_article_title_in_passage=FLAGS.include_article_title_in_passage,
        include_text_snippet_in_example=True,
    )
    data_processor = raw_data_processor.RawDataProcessor(**processor_options)

    print("Loading crawled articles...")
    article_count = data_processor.read_crawled_articles(FLAGS.crawled_articles)
    print("Total number of articles loaded: %d" % article_count)
    print()

    for split_name in ("train", "valid", "test"):
        transform_as_tfrecords(data_processor, split_name)


# Hand control to absl's app runner only when executed as a script.
if __name__ == "__main__":
    app.run(main)