Spaces:
Running
Running
import keras.backend as K | |
import tensorflow as tf | |
from keras.layers import * | |
def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'):
    """Resize the images contained in a 4D tensor using bilinear interpolation.

    The tensor is expected to have shape
    - [batch, channels, height, width] (for 'channels_first' data_format)
    - [batch, height, width, channels] (for 'channels_last' data_format)

    Args:
        X: 4D image tensor.
        height_factor: Integer multiplier applied to the height. Ignored
            when both target sizes are given.
        width_factor: Integer multiplier applied to the width. Ignored
            when both target sizes are given.
        target_height: Explicit output height. Only used when
            `target_width` is also given (both must be truthy).
        target_width: Explicit output width. Only used when
            `target_height` is also given (both must be truthy).
        data_format: 'channels_first', 'channels_last', or 'default' to use
            `K.image_data_format()`.

    Returns:
        The resized tensor, in the same data format as the input.

    Raises:
        ValueError: If `data_format` is not one of the supported values.
    """
    if data_format == 'default':
        data_format = K.image_data_format()

    # Validate the layout up front so no graph ops are created for bad input.
    if data_format == 'channels_first':
        height_axis, width_axis = 2, 3
    elif data_format == 'channels_last':
        height_axis, width_axis = 1, 2
    else:
        raise ValueError('Invalid data_format: ' + str(data_format))

    use_target_size = bool(target_height and target_width)
    if use_target_size:
        static_height, static_width = target_height, target_width
        new_shape = tf.constant([int(target_height), int(target_width)], dtype=tf.int32)
    else:
        height_factor, width_factor = int(height_factor), int(width_factor)
        # The dynamic shape works even when the static spatial size is unknown.
        new_shape = tf.shape(X)[height_axis:width_axis + 1]
        new_shape *= tf.constant([height_factor, width_factor], dtype=tf.int32)

        # A static dimension can be None (unknown); keep it unknown instead of
        # multiplying None by the factor, which raises TypeError.
        original_shape = K.int_shape(X)
        in_height = original_shape[height_axis] if original_shape is not None else None
        in_width = original_shape[width_axis] if original_shape is not None else None
        static_height = None if in_height is None else in_height * height_factor
        static_width = None if in_width is None else in_width * width_factor

    if data_format == 'channels_first':
        # The resize op is applied on a channels-last layout, so transpose
        # around it and restore the original layout afterwards.
        X = K.permute_dimensions(X, [0, 2, 3, 1])
        X = tf.image.resize_bilinear(X, new_shape)
        X = K.permute_dimensions(X, [0, 3, 1, 2])
        X.set_shape((None, None, static_height, static_width))
    else:
        X = tf.image.resize_bilinear(X, new_shape)
        X.set_shape((None, static_height, static_width, None))
    return X
class BilinearUpSampling2D(Layer):
    """Layer that resizes 4D image tensors with bilinear interpolation.

    The output size is either the input size multiplied by `size`
    (height factor, width factor) or, when `target_size` is given, exactly
    `target_size` (height, width).

    Args:
        size: Pair of (height_factor, width_factor).
        target_size: Optional pair of (target_height, target_width). When
            set, it takes precedence over `size`.
        data_format: 'channels_first', 'channels_last', or 'default' to use
            `K.image_data_format()`.
        **kwargs: Forwarded to the base `Layer` initializer.
    """

    def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
        # Initialize the base layer before assigning attributes.
        # NOTE(review): some Keras versions appear to reset `input_spec`
        # inside Layer.__init__, which would discard a value assigned
        # beforehand - confirm against the pinned Keras version.
        super(BilinearUpSampling2D, self).__init__(**kwargs)
        if data_format == 'default':
            data_format = K.image_data_format()
        assert data_format in {'channels_last', 'channels_first'}, \
            'data_format must be in {channels_last, channels_first}'
        self.size = tuple(size)
        self.target_size = tuple(target_size) if target_size is not None else None
        self.data_format = data_format
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given 4D `input_shape`.

        Spatial dimensions that are unknown (None) stay None unless
        `target_size` fixes them.
        """
        if self.data_format == 'channels_first':
            height_axis, width_axis = 2, 3
        elif self.data_format == 'channels_last':
            height_axis, width_axis = 1, 2
        else:
            raise Exception('Invalid data_format: ' + self.data_format)

        if self.target_size is not None:
            height, width = self.target_size[0], self.target_size[1]
        else:
            in_height = input_shape[height_axis]
            in_width = input_shape[width_axis]
            # int(None) raises TypeError, so only scale dimensions that are known.
            height = int(self.size[0] * in_height) if in_height is not None else None
            width = int(self.size[1] * in_width) if in_width is not None else None

        if self.data_format == 'channels_first':
            return (input_shape[0], input_shape[1], height, width)
        return (input_shape[0], height, width, input_shape[3])

    def call(self, x, mask=None):
        """Resize `x` to `target_size` if set, otherwise scale it by `size`."""
        if self.target_size is not None:
            return resize_images_bilinear(
                x,
                target_height=self.target_size[0],
                target_width=self.target_size[1],
                data_format=self.data_format)
        return resize_images_bilinear(
            x,
            height_factor=self.size[0],
            width_factor=self.size[1],
            data_format=self.data_format)

    def get_config(self):
        """Return the layer configuration, including the constructor arguments."""
        # data_format is a constructor argument; without it a layer rebuilt
        # from this config would fall back to the backend default.
        config = {'size': self.size,
                  'target_size': self.target_size,
                  'data_format': self.data_format}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))