Spaces:
Paused
Paused
# Copyright 2022 Google LLC | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
r"""Beam pipeline that generates Vimeo-90K (train or test) triplet TFRecords. | |
Vimeo-90K dataset is built upon 5,846 videos downloaded from vimeo.com. The list | |
of the original video links are available here: | |
https://github.com/anchen1011/toflow/blob/master/data/original_vimeo_links.txt. | |
Each video is further cropped into a fixed spatial size of (448 x 256) to create | |
89,000 video clips. | |
The Vimeo-90K dataset is designed for four video processing tasks. This script | |
creates the TFRecords of frame triplets for frame interpolation task. | |
Temporal frame interpolation triplet dataset: | |
- 73,171 triplets of size (448x256) extracted from 15K subsets of Vimeo-90K. | |
- The triplets are pre-split into (train,test) = (51313,3782) | |
- Download links: | |
Test-set: http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip | |
Train+test-set: http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip | |
For more information, see the arXiv paper, project page or the GitHub link. | |
@article{xue17toflow, | |
author = {Xue, Tianfan and | |
Chen, Baian and | |
Wu, Jiajun and | |
Wei, Donglai and | |
Freeman, William T}, | |
title = {Video Enhancement with Task-Oriented Flow}, | |
journal = {arXiv}, | |
year = {2017} | |
} | |
Project: http://toflow.csail.mit.edu/ | |
GitHub: https://github.com/anchen1011/toflow | |
Inputs to the script are (1) the directory to the downloaded and unzipped folder | |
(2) the filepath of the text-file that lists the subfolders of the triplets. | |
Output TFRecord is a tf.train.Example proto of each image triplet. | |
The feature_map takes the form: | |
feature_map { | |
'frame_0/encoded': | |
tf.io.FixedLenFeature((), tf.string, default_value=''), | |
'frame_0/format': | |
tf.io.FixedLenFeature((), tf.string, default_value='jpg'), | |
'frame_0/height': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0), | |
'frame_0/width': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0), | |
'frame_1/encoded': | |
tf.io.FixedLenFeature((), tf.string, default_value=''), | |
'frame_1/format': | |
tf.io.FixedLenFeature((), tf.string, default_value='jpg'), | |
'frame_1/height': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0), | |
'frame_1/width': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0), | |
'frame_2/encoded': | |
tf.io.FixedLenFeature((), tf.string, default_value=''), | |
'frame_2/format': | |
tf.io.FixedLenFeature((), tf.string, default_value='jpg'), | |
'frame_2/height': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0), | |
'frame_2/width': | |
tf.io.FixedLenFeature((), tf.int64, default_value=0) | |
'path': | |
tf.io.FixedLenFeature((), tf.string, default_value='') | |
} | |
Usage example: | |
python3 -m frame_interpolation.datasets.create_vimeo90K_tfrecord \ | |
--input_dir=<root folder of vimeo90K dataset> \ | |
--input_triplet_list_filepath=<filepath of tri_{test|train}list.txt> \ | |
--output_tfrecord_filepath=<output tfrecord filepath> | |
""" | |
import os | |
from . import util | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import apache_beam as beam | |
import numpy as np | |
import tensorflow as tf | |
_INPUT_DIR = flags.DEFINE_string( | |
'input_dir', | |
default='/path/to/raw_vimeo_interp/sequences', | |
help='Path to the root directory of the vimeo frame interpolation dataset. ' | |
'We expect the data to have been downloaded and unzipped.\n' | |
'Folder structures:\n' | |
'| raw_vimeo_dataset/\n' | |
'| sequences/\n' | |
'| | 00001\n' | |
'| | | 0389/\n' | |
'| | | | im1.png\n' | |
'| | | | im2.png\n' | |
'| | | | im3.png\n' | |
'| | | ...\n' | |
'| | 00002/\n' | |
'| | ...\n' | |
'| readme.txt\n' | |
'| tri_trainlist.txt\n' | |
'| tri_testlist.txt \n') | |
_INTPUT_TRIPLET_LIST_FILEPATH = flags.DEFINE_string( | |
'input_triplet_list_filepath', | |
default='/path/to/raw_vimeo_dataset/tri_{test|train}list.txt', | |
help='Text file containing a list of sub-directories of input triplets.') | |
_OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string( | |
'output_tfrecord_filepath', | |
default=None, | |
help='Filepath to the output TFRecord file.') | |
_NUM_SHARDS = flags.DEFINE_integer('num_shards', | |
default=200, # set to 3 for vimeo_test, and 200 for vimeo_train. | |
help='Number of shards used for the output.') | |
# Image key -> basename for frame interpolator: start / middle / end frames. | |
_INTERPOLATOR_IMAGES_MAP = { | |
'frame_0': 'im1.png', | |
'frame_1': 'im2.png', | |
'frame_2': 'im3.png', | |
} | |
def main(unused_argv): | |
"""Creates and runs a Beam pipeline to write frame triplets as a TFRecord.""" | |
with tf.io.gfile.GFile(_INTPUT_TRIPLET_LIST_FILEPATH.value, 'r') as fid: | |
triplets_list = np.loadtxt(fid, dtype=str) | |
triplet_dicts = [] | |
for triplet in triplets_list: | |
triplet_dict = { | |
image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename) | |
for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items() | |
} | |
triplet_dicts.append(triplet_dict) | |
p = beam.Pipeline('DirectRunner') | |
(p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned | |
| 'GenerateSingleExample' >> beam.ParDo( | |
util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP)) | |
| 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord( | |
file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value, | |
num_shards=_NUM_SHARDS.value, | |
coder=beam.coders.BytesCoder())) | |
result = p.run() | |
result.wait_until_finish() | |
logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.', | |
_OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value)) | |
if __name__ == '__main__': | |
app.run(main) | |