|
import tensorflow as tf |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class Augment(tf.keras.layers.Layer):
    """Keras layer that randomly flips an (inputs, labels) pair horizontally.

    Two separate ``RandomFlip`` layers are held, one for the inputs and one
    for the labels, and both are constructed with the same ``seed``.
    NOTE(review): presumably the shared seed is meant to keep the flip applied
    to the labels in step with the one applied to the inputs -- confirm this
    holds for the Keras version in use.
    """

    def __init__(self, seed=42):
        """Create the two horizontal-flip layers.

        Args:
            seed: Seed handed to both ``RandomFlip`` layers.
        """
        super().__init__()

        # Both layers are configured identically, so describe them once.
        flip_config = {"mode": "horizontal", "seed": seed}
        self.augment_inputs = tf.keras.layers.RandomFlip(**flip_config)
        self.augment_labels = tf.keras.layers.RandomFlip(**flip_config)

    def call(self, inputs, labels):
        """Pass each argument through its own flip layer.

        Args:
            inputs: Value fed to ``self.augment_inputs``.
            labels: Value fed to ``self.augment_labels``.

        Returns:
            Tuple ``(inputs, labels)`` holding the two layer outputs.
        """
        flipped_inputs = self.augment_inputs(inputs)
        flipped_labels = self.augment_labels(labels)
        return flipped_inputs, flipped_labels
|
|
|
def load_image(datapoint):
    """Resize a datapoint's image and mask to 128x128 and normalize them.

    Args:
        datapoint: Mapping that provides an ``'image'`` entry and a
            ``'segmentation_mask'`` entry.

    Returns:
        The ``(image, mask)`` pair produced by ``normalize_image`` from the
        resized image and the resized mask.
    """
    target_size = (128, 128)

    image = tf.image.resize(datapoint['image'], target_size)
    # The mask is resized with nearest-neighbour sampling, unlike the image,
    # which uses the default resize method.
    mask = tf.image.resize(
        datapoint['segmentation_mask'],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    return normalize_image(image, mask)
|
|
|
def normalize_image(input_image, input_mask):
    """Rescale the image by 1/255 and shift the mask values down by one.

    Args:
        input_image: Image data; cast to ``tf.float32`` and divided by 255.
            NOTE(review): assumes pixel values in 0..255, which would give a
            0..1 result -- confirm against the dataset.
        input_mask: Mask data; 1 is subtracted from every value.
            NOTE(review): presumably converts 1-based labels to 0-based.

    Returns:
        Tuple ``(image, mask)`` with the rescaled image and shifted mask.
    """
    as_float = tf.cast(input_image, tf.float32)
    scaled_image = as_float / 255.0

    input_mask -= 1

    return scaled_image, input_mask
|
|
|
def create_mask(pred_mask):
    """Turn per-class scores into a class-index mask for the first item.

    Args:
        pred_mask: Array or tensor whose last axis is reduced with argmax.
            NOTE(review): assumed to be batched model output shaped
            (batch, height, width, classes) -- confirm with callers.

    Returns:
        The first element along axis 0 of the argmax result, after a
        trailing axis of size 1 has been appended to it.
    """
    class_ids = tf.math.argmax(pred_mask, axis=-1)
    # Restore a trailing axis in place of the one argmax removed.
    class_ids = tf.expand_dims(class_ids, axis=-1)
    return class_ids[0]
|
|
|
def U_net_model(output_channels: int, down_stack, up_stack):
    """Assemble a Keras model from a down-sampling and an up-sampling stack.

    Args:
        output_channels: Number of filters of the final transposed
            convolution.
        down_stack: Callable applied to the input tensor. It must return an
            indexable sequence of tensors: the last entry is the starting
            point of the up-sampling path, the earlier entries are the skip
            connections.
        up_stack: Iterable of callables applied one after another; each
            output is concatenated with one skip connection, taking the
            skips in reverse order.

    Returns:
        A ``tf.keras.Model`` mapping a ``[128, 128, 3]`` input to the output
        of the final ``Conv2DTranspose`` layer.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    features = down_stack(inputs)
    x = features[-1]
    # Every feature map except the last one, walked from the end backwards.
    skip_connections = features[:-1][::-1]

    for upsample, skip in zip(up_stack, skip_connections):
        x = upsample(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final stride-2 transposed convolution with one filter per channel.
    x = tf.keras.layers.Conv2DTranspose(
        filters=output_channels,
        kernel_size=3,
        strides=2,
        padding='same',
    )(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
|
|
|
def display(display_list):
    """Show the given images side by side in one row of a 15x15 figure.

    Args:
        display_list: Sequence of images, each accepted by
            ``tf.keras.utils.array_to_img``. Titles are taken by position
            from a fixed two-entry list, so passing more than two images
            raises ``IndexError``.
    """
    titles = ['Input Image', 'Predicted Mask']
    column_count = len(display_list)

    plt.figure(figsize=(15, 15))
    for position, item in enumerate(display_list, start=1):
        # Subplot positions are 1-based; the title list is 0-based.
        plt.subplot(1, column_count, position)
        plt.title(titles[position - 1])
        plt.imshow(tf.keras.utils.array_to_img(item))
        plt.axis('off')
    plt.show()
|
|
|
def show_predictions(image_url, model):
    """Fetch an image, run the model on it and display image and mask.

    The image is loaded, resized to 128x128, cast to ``tf.float32``, divided
    by 255 and given a leading batch axis before being passed to
    ``model.predict``. The preprocessed image and the mask derived from the
    prediction by ``create_mask`` are then handed to ``display``.

    Args:
        image_url: Passed as ``origin`` to ``tf.keras.utils.get_file``.
        model: Object exposing ``predict``; called with the one-image batch.
    """
    local_path = tf.keras.utils.get_file(origin=image_url)
    pixels = tf.keras.utils.img_to_array(tf.keras.utils.load_img(local_path))

    resized = tf.image.resize(pixels, (128,128))
    scaled = tf.cast(resized, tf.float32) / 255.0
    batch = tf.expand_dims(scaled, axis=0)

    pred_mask = model.predict(batch)

    display([batch[0], create_mask(pred_mask)])