Spaces:
Paused
Paused
# Copyright 2022 Google LLC | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
r"""Converts TF2 training checkpoint to a saved model. | |
The model must match the checkpoint, so the gin config must be given. | |
Usage example: | |
python3 -m frame_interpolation.training.build_saved_model_cli \ | |
--gin_config <filepath of the gin config the training session was based> \ | |
--base_folder <base folder of training sessions> \ | |
--label <the name of the run> | |
This will produce a saved model into: <base_folder>/<label>/saved_model | |
""" | |
import os | |
from typing import Sequence | |
from . import model_lib | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import gin.tf | |
import tensorflow as tf | |
# Only ERROR (and above) messages from TensorFlow's Python logger are shown.
tf.get_logger().setLevel('ERROR')

_GIN_CONFIG = flags.DEFINE_string(
    name='gin_config',
    default='config.gin',
    help='Gin config file, saved in the training session <root folder>.')
_LABEL = flags.DEFINE_string(
    name='label',
    default=None,
    required=True,
    help='Descriptive label for the training session.')
# Required: main() joins this value with --label to locate the run folder,
# which cannot work when it is left as None.
_BASE_FOLDER = flags.DEFINE_string(
    name='base_folder',
    default=None,
    required=True,
    help='Path to all training sessions.')
# NOTE(review): --mode is not read anywhere in this script; it is kept so
# existing command lines that pass it continue to parse.
_MODE = flags.DEFINE_enum(
    name='mode',
    default=None,
    enum_values=['cpu', 'gpu', 'tpu'],
    help='Distributed strategy approach.')
def _build_saved_model(checkpoint_path: str, config_files: Sequence[str],
                       output_model_path: str) -> bool:
  """Builds a saved model based on the checkpoint directory.

  Parses the gin config, creates the model with `model_lib.create_model()`,
  restores the latest checkpoint found in `checkpoint_path` into it and saves
  the result to `output_model_path`.

  Args:
    checkpoint_path: Directory searched with `tf.train.latest_checkpoint`.
    config_files: Gin config files describing the model; the model they build
      must match the checkpoint.
    output_model_path: Destination passed to `model.save`.

  Returns:
    True if the checkpoint was restored and the model saved. False if no
    checkpoint was found or restoring it failed; the reason is logged at ERROR
    level.
  """
  gin.parse_config_files_and_bindings(
      config_files=config_files,
      bindings=None,
      # Ignore bindings for configurables that are not registered here.
      skip_unknown=True)
  model = model_lib.create_model()
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
  if checkpoint_file is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # report that directly instead of handing None to restore().
    logging.error('No checkpoint found in %s.', checkpoint_path)
    return False

  logging.info('Restoring from %s', checkpoint_file)
  try:
    status = checkpoint.restore(checkpoint_file)
    # Every object of the model must be found in the checkpoint...
    status.assert_existing_objects_matched()
    # ...but the checkpoint may hold values the model does not use; do not
    # warn about those.
    status.expect_partial()
  except (tf.errors.NotFoundError, AssertionError) as err:
    logging.error('Failed to restore checkpoint from %s. Error:\n%s',
                  checkpoint_file, err)
    return False

  # Kept outside the try block so save failures are not reported as restore
  # failures.
  model.save(output_model_path)
  return True


def main(argv):
  """Resolves paths from the flags and builds the saved model.

  Args:
    argv: Remaining command-line arguments after absl flag parsing.

  Returns:
    None on success, or 1 if the saved model could not be built. `app.run`
    passes this value to `sys.exit`, so failure yields a non-zero exit status.

  Raises:
    app.UsageError: On extra positional arguments or a missing --base_folder.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if _BASE_FOLDER.value is None:
    raise app.UsageError('--base_folder is required.')

  run_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value)
  checkpoint_path = os.path.join(run_folder, 'train')
  # --gin_config may be a path that exists as given; otherwise it is resolved
  # inside the run folder.
  if tf.io.gfile.exists(_GIN_CONFIG.value):
    config_file = _GIN_CONFIG.value
  else:
    config_file = os.path.join(run_folder, _GIN_CONFIG.value)
  output_model_path = os.path.join(run_folder, 'saved_model')

  if not _build_saved_model(
      checkpoint_path=checkpoint_path,
      config_files=[config_file],
      output_model_path=output_model_path):
    return 1
  logging.info('The saved model stored into %s/.', output_model_path)
if __name__ == '__main__':
  # app.run parses the absl flags, then calls main(); main's return value is
  # used as the process exit status.
  app.run(main)