Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Base head class. | |
All the different kinds of prediction heads in different models will inherit | |
from this class. What is in common between all head classes is that they have a | |
`predict` function that receives `features` as its first argument. | |
How to add a new prediction head to an existing meta architecture? | |
For example, how can we add a `3d shape` prediction head to Mask RCNN? | |
We have to take the following steps to add a new prediction head to an | |
existing meta arch: | |
(a) Add a class for predicting the head. This class should inherit from the | |
`Head` class below and have a `predict` function that receives the features | |
and predicts the output. The output is always a tf.float32 tensor. | |
(b) Add the head to the meta architecture. For example in case of Mask RCNN, | |
go to box_predictor_builder and put in the logic for adding the new head to the | |
Mask RCNN box predictor. | |
(c) Add the logic for computing the loss for the new head. | |
(d) Add the necessary metrics for the new head. | |
(e) (optional) Add visualization for the new head. | |
""" | |
from abc import abstractmethod | |
import tensorflow.compat.v1 as tf | |
class Head(object): | |
"""Mask RCNN head base class.""" | |
def __init__(self): | |
"""Constructor.""" | |
pass | |
def predict(self, features, num_predictions_per_location): | |
"""Returns the head's predictions. | |
Args: | |
features: A float tensor of features. | |
num_predictions_per_location: Int containing number of predictions per | |
location. | |
Returns: | |
A tf.float32 tensor. | |
""" | |
pass | |
class KerasHead(tf.keras.Model): | |
"""Keras head base class.""" | |
def call(self, features): | |
"""The Keras model call will delegate to the `_predict` method.""" | |
return self._predict(features) | |
def _predict(self, features): | |
"""Returns the head's predictions. | |
Args: | |
features: A float tensor of features. | |
Returns: | |
A tf.float32 tensor. | |
""" | |
pass | |