# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for instantiating the model for training frame interpolation.

All models are expected to use three inputs: the input image batches 'x0' and
'x1', and 'time', the fractional time at which the output frame should be
generated.

The models are expected to output the prediction as a dictionary that contains
at least the predicted image batch under 'image', plus optional data for
debugging, analysis or custom losses.
"""
import gin.tf | |
from ..models.film_net import interpolator as film_net_interpolator | |
from ..models.film_net import options as film_net_options | |
import tensorflow as tf | |
def create_model(name: str) -> tf.keras.Model:
    """Builds a frame-interpolation model selected by name.

    Args:
      name: Identifier of the model architecture to build. The only value
        handled here is 'film_net'.

    Returns:
      The Keras model produced by the matching builder.

    Raises:
      ValueError: If `name` is not a supported model identifier.
    """
    # NOTE(review): `gin.tf` is imported at module level, but nothing visible
    # in this file is gin-registered. If gin configs are expected to bind
    # `name`, a `@gin.configurable` decorator may be missing here — confirm.
    if name == 'film_net':
        return _create_film_net_model()  # pylint: disable=no-value-for-parameter
    # Any other identifier has no builder in this module.
    raise ValueError(f'Model {name} not implemented.')
def _create_film_net_model() -> tf.keras.Model:
    """Builds the film_net interpolator as a Keras model.

    Three symbolic float32 inputs named 'x0', 'x1' and 'time' are created,
    each with an unspecified batch size. They are passed, together with a
    film_net `Options` instance, to `film_net_interpolator.create_model`.

    Returns:
      The Keras model returned by `film_net_interpolator.create_model`.
    """
    # No arguments are passed: per the original note, the Options class is
    # configured through gin directly.
    film_options = film_net_options.Options()

    # Both frames share one per-example shape (None, None, 3). The two
    # unspecified dims plus a trailing 3 are presumably H x W x RGB.
    frame_shape = (None, None, 3)
    x0, x1 = (
        tf.keras.Input(
            shape=frame_shape, batch_size=None, dtype=tf.float32, name=tag)
        for tag in ('x0', 'x1'))

    # One value per example: the fractional time of the frame to synthesize.
    time = tf.keras.Input(
        shape=(1,), batch_size=None, dtype=tf.float32, name='time')

    return film_net_interpolator.create_model(x0, x1, time, film_options)