{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "nUu3FjibHfQL", "outputId": "8f9e98eb-f627-49f9-dcae-47dfdde9cb1c" }, "outputs": [], "source": [ "!pip install keras==2.15.0 tensorflow==2.15.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3dlzV23EMp8e" }, "outputs": [], "source": [ "import os\n", "os.environ['KERAS_BACKEND'] = 'tensorflow'\n", "import pathlib\n", "import re" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vo6h7I0NNR2G" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import keras\n", "import numpy as np\n", "from keras import layers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FXa1ANO0NdPB", "outputId": "df1c87dd-c04f-469d-ec04-4b81e2c7917e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.15.0\n", "2.15.0\n", "tensorflow\n" ] } ], "source": [ "print(keras.__version__)\n", "print(tf.__version__)\n", "print(keras.backend.backend())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iGRrFrjo-VQL" }, "outputs": [], "source": [ "# Path to the images\n", "IMAGES_PATH = \"Flicker8k_Dataset\"\n", "\n", "# Desired image dimensions\n", "IMAGE_SIZE = (299, 299)\n", "\n", "# Vocabulary size\n", "VOCAB_SIZE = 10000\n", "\n", "# Fixed length allowed for any sequence\n", "SEQ_LENGTH = 25\n", "\n", "# Dimension for the image embeddings and token embeddings\n", "EMBED_DIM = 512\n", "\n", "# Per-layer units in the feed-forward network\n", "FF_DIM = 512\n", "\n", "# Other training parameters\n", "BATCH_SIZE = 64\n", "EPOCHS = 30\n", "AUTOTUNE = tf.data.AUTOTUNE" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "Yiy02CsNO12W", "outputId": "0f8b9a5e-c590-4b12-ec51-82741fadf4d9" }, "outputs": [], "source": [ "path = pathlib.Path(\".\")\n", "keras.utils.get_file(\n", " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',\n", " cache_dir='.',\n", " cache_subdir=path,\n", " extract=True)\n", "keras.utils.get_file(\n", " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',\n", " cache_dir='.',\n", " cache_subdir=path,\n", " extract=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ChgFWRLUzhJ" }, "outputs": [], "source": [ "dataset = pathlib.Path(path, \"Flickr8k.token.txt\").read_text(encoding='utf-8').splitlines()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Tp4TSNVZBlEt", "outputId": "42113a83-cc68-4ed3-a630-8aba8a2b20d6" }, "outputs": [], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qYtr55VMWTTy" }, "outputs": [], "source": [ "dataset = [line.split('\\t') for line in dataset]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9y18xIyfBovq", "outputId": "7896d304-5864-46d8-8b63-945dfebd3151" }, "outputs": [], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aMQbS8ueZzii" }, "outputs": [], "source": [ "dataset = [[os.path.join(IMAGES_PATH,fname.split('#')[0].strip()), caption] for (fname, caption) in dataset]" ] }, { "cell_type": "code", 
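"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A small sanity-check sketch (not in the original notebook): confirm the\n", "# downloaded archives extracted where the later cells expect them.\n", "print(\"Images folder exists:\", os.path.isdir(IMAGES_PATH))\n", "print(\"Caption entries:\", len(dataset))" ] }, { "cell_type": "code",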
"execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "l7w7CxPuZtMG", "outputId": "e3ca1665-277b-4c81-88d3-5c2c39ccd628" }, "outputs": [], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mkMvwXnZBFzD", "outputId": "774cd7f3-0ffc-4273-862d-feb4366c3e20" }, "outputs": [], "source": [ "for i in dataset:\n", " print(i)\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqthDk6gZV3i" }, "outputs": [], "source": [ "caption_mapping = {}\n", "text_data = []\n", "X_en_data = []\n", "X_de_data = []\n", "Y_data = []" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yWbcI72w9Xv2" }, "outputs": [], "source": [ "for img_name, caption in dataset:\n", " if img_name.endswith(\"jpg\"):\n", " X_de_data.append(\" \" + caption.strip().replace(\".\", \"\"))\n", " Y_data.append(caption.strip().replace(\".\", \"\") + \" \")\n", " text_data.append(\" \" + caption.strip().replace(\".\", \"\") + \" \")\n", " X_en_data.append(img_name)\n", "\n", "\n", " if img_name in caption_mapping:\n", " caption_mapping[img_name].append(caption)\n", " else:\n", " caption_mapping[img_name] = [caption]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqLKJwdVSydw" }, "outputs": [], "source": [ "for i in X_de_data:\n", " if len(i) <= 2:\n", " print(\"Y\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6RDJi5j4_C_z", "outputId": "a6b9a4e5-94a3-40cc-98e6-8454e4e268ab" }, "outputs": [], "source": [ "print(X_en_data[0])\n", "print(X_de_data[0])\n", "print(Y_data[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L_l0R04eJZnQ" }, "outputs": [], "source": [ "train_size=0.8\n", "shuffle=True\n", "np.random.seed(42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f-EnvEgh8sOB", "outputId": "0cc36fe3-6cac-48ea-bcd4-9ace27dd5325" }, "outputs": [], "source": [ "zipped = list(zip(X_en_data, X_de_data, Y_data))\n", "np.random.shuffle(zipped)\n", "X_en_data, X_de_data, Y_data = zip(*zipped)\n", "print(X_en_data[0])\n", "print(X_de_data[0])\n", "print(Y_data[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z7EFBHhdPeLs" }, "outputs": [], "source": [ "train_size = int(len(X_en_data)*train_size)\n", "X_train_en = list(X_en_data[:train_size])\n", "X_train_de = list(X_de_data[:train_size])\n", "Y_train = list(Y_data[:train_size])\n", "X_valid_en = list(X_en_data[train_size:])\n", "X_valid_de = list(X_de_data[train_size:])\n", "Y_valid = list(Y_data[train_size:])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VEACJzQ0Tccm", "outputId": "0c534618-1218-49df-ec73-54bdaf6c0452" }, "outputs": [], "source": [ "print(X_train_en[0])\n", "print(X_train_de[0])\n", "print(Y_train[0])\n", "print(X_valid_en[0])\n", "print(X_valid_de[0])\n", "print(Y_valid[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tZAf67OXAmxn" }, "outputs": [], "source": [ "strip_chars = \"!\\\"#$%&'()*+,-./:;=?@[\\]^_`{|}~\"\n", "def custom_standardization(input_string):\n", " lowercase = tf.strings.lower(input_string)\n", " return tf.strings.regex_replace(lowercase, f'{re.escape(strip_chars)}', '')" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "d6v0_3TuJZaG" }, "outputs": [], "source": [ "vectorization = keras.layers.TextVectorization(\n", " max_tokens=VOCAB_SIZE,\n", " output_mode=\"int\",\n", " output_sequence_length=SEQ_LENGTH,\n", " standardize=custom_standardization,\n", " )\n", "\n", "vectorization.adapt(text_data)\n", "vocab = np.array(vectorization.get_vocabulary())\n", "np.save('./artifacts/vocabulary.npy', vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WcXsq2LzNpda" }, "outputs": [], "source": [ "def decode_and_resize(img_path):\n", " img = tf.io.read_file(img_path)\n", " img = tf.image.decode_jpeg(img, channels=3)\n", " img = tf.image.resize(img, IMAGE_SIZE)\n", " img = tf.image.convert_image_dtype(img, tf.float32)\n", " return img\n", "\n", "\n", "def process_input(img_cap, y_captions):\n", " img_path, x_captions = img_cap\n", " return ((decode_and_resize(img_path), vectorization(x_captions)), vectorization(y_captions))\n", "\n", "\n", "def make_dataset(images, x_captions, y_captions):\n", " dataset = tf.data.Dataset.from_tensor_slices(((images, x_captions), y_captions))\n", " dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)\n", " dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)\n", "\n", " return dataset\n", "\n", "\n", "\n", "train_dataset = make_dataset(X_train_en, X_train_de, Y_train)\n", "\n", "valid_dataset = make_dataset(X_valid_en, X_valid_de, Y_valid)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CC_icKqG5xGg" }, "outputs": [], "source": [ "image_augmentation = keras.Sequential(\n", " [\n", " keras.layers.RandomFlip(\"horizontal\"),\n", " keras.layers.RandomRotation(0.2),\n", " keras.layers.RandomContrast(0.3),\n", " ]\n", ")\n", "@keras.saving.register_keras_serializable()\n", "def get_cnn_model():\n", " base_model = keras.applications.efficientnet.EfficientNetB0(\n", " input_shape=(*IMAGE_SIZE, 3),\n", " include_top=False,\n", " weights=\"imagenet\"\n", " )\n", " base_model.trainable = False\n", " base_model_out = base_model.output\n", " base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)\n", " cnn_model = keras.models.Model(base_model.input, base_model_out)\n", " return cnn_model\n", "\n", "@keras.saving.register_keras_serializable()\n", "class TransformerEncoderBlock(layers.Layer):\n", " def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n", " super().__init__(**kwargs)\n", " self.embed_dim = embed_dim\n", " self.dense_dim = dense_dim\n", " self.num_heads = num_heads\n", " self.attention_1 = layers.MultiHeadAttention(\n", " num_heads=num_heads, key_dim=embed_dim, dropout=0.0\n", " )\n", " self.layernorm_1 = layers.LayerNormalization()\n", " self.layernorm_2 = layers.LayerNormalization()\n", " self.dense_1 = layers.Dense(embed_dim, activation=\"relu\")\n", "\n", " def get_config(self):\n", " base_config = super().get_config()\n", " config = {\n", " \"embed_dim\": self.embed_dim,\n", " \"dense_dim\": self.dense_dim,\n", " \"num_heads\": self.num_heads,\n", " }\n", " return {**base_config, **config}\n", "\n", "\n", " def call(self, inputs, training):\n", " inputs = self.layernorm_1(inputs)\n", " inputs = self.dense_1(inputs)\n", "\n", " attention_output_1 = self.attention_1(\n", " query=inputs,\n", " value=inputs,\n", " key=inputs,\n", " training=training,\n", " )\n", " out_1 = self.layernorm_2(inputs + attention_output_1)\n", " return out_1\n", "\n", "@keras.saving.register_keras_serializable()\n", "class 
PositionalEmbedding(layers.Layer):\n", " def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):\n", " super().__init__(**kwargs)\n", " self.token_embeddings = layers.Embedding(\n", " input_dim=vocab_size, output_dim=embed_dim, mask_zero=True\n", " )\n", " self.position_embeddings = layers.Embedding(\n", " input_dim=sequence_length, output_dim=embed_dim\n", " )\n", " self.sequence_length = sequence_length\n", " self.vocab_size = vocab_size\n", " self.embed_dim = embed_dim\n", "\n", " self.add = layers.Add()\n", "\n", " def get_config(self):\n", " base_config = super().get_config()\n", " config = {\n", " \"sequence_length\": self.sequence_length,\n", " \"vocab_size\": self.vocab_size,\n", " \"embed_dim\": self.embed_dim,\n", " }\n", " return {**base_config, **config}\n", "\n", " def call(self, seq):\n", " seq = self.token_embeddings(seq)\n", "\n", " x = tf.range(tf.shape(seq)[1])\n", " x = x[tf.newaxis, :]\n", " x = self.position_embeddings(x)\n", "\n", " return self.add([seq,x])\n", "\n", "@keras.saving.register_keras_serializable()\n", "class TransformerDecoderBlock(layers.Layer):\n", " def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):\n", " super().__init__(**kwargs)\n", " self.embed_dim = embed_dim\n", " self.ff_dim = ff_dim\n", " self.num_heads = num_heads\n", " self.attention_1 = layers.MultiHeadAttention(\n", " num_heads=num_heads, key_dim=embed_dim, dropout=0.1\n", " )\n", " self.attention_2 = layers.MultiHeadAttention(\n", " num_heads=num_heads, key_dim=embed_dim, dropout=0.1\n", " )\n", " self.ffn_layer_1 = layers.Dense(ff_dim, activation=\"relu\")\n", " self.ffn_layer_2 = layers.Dense(embed_dim)\n", "\n", " self.layernorm_1 = layers.LayerNormalization()\n", " self.layernorm_2 = layers.LayerNormalization()\n", " self.layernorm_3 = layers.LayerNormalization()\n", "\n", " self.embedding = PositionalEmbedding(\n", " embed_dim=EMBED_DIM,\n", " sequence_length=SEQ_LENGTH,\n", " vocab_size=VOCAB_SIZE,\n", " )\n", " self.out = layers.Dense(VOCAB_SIZE, activation=\"softmax\")\n", "\n", " self.dropout_1 = layers.Dropout(0.3)\n", " self.dropout_2 = layers.Dropout(0.5)\n", " self.supports_masking = True\n", "\n", " def get_config(self):\n", " base_config = super().get_config()\n", " config = {\n", " \"embed_dim\": self.embed_dim,\n", " \"ff_dim\": self.ff_dim,\n", " \"num_heads\": self.num_heads,\n", "\n", " }\n", " return {**base_config, **config}\n", "\n", "\n", "\n", " def call(self, inputs, encoder_outputs, training, mask=None):\n", " inputs = self.embedding(inputs)\n", "\n", " attention_output_1 = self.attention_1(\n", " query=inputs,\n", " value=inputs,\n", " key=inputs,\n", " training=training,\n", " use_causal_mask=True\n", " )\n", " out_1 = self.layernorm_1(inputs + attention_output_1)\n", "\n", " attention_output_2 = self.attention_2(\n", " query=out_1,\n", " value=encoder_outputs,\n", " key=encoder_outputs,\n", " training=training,\n", " )\n", " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", "\n", " ffn_out = self.ffn_layer_1(out_2)\n", " ffn_out = self.dropout_1(ffn_out, training=training)\n", " ffn_out = self.ffn_layer_2(ffn_out)\n", "\n", " ffn_out = self.layernorm_3(ffn_out + out_2, training=training)\n", " ffn_out = self.dropout_2(ffn_out, training=training)\n", " preds = self.out(ffn_out)\n", " return preds\n", "\n", "\n", "@keras.saving.register_keras_serializable()\n", "class ImageCaptioningModel(keras.Model):\n", " def __init__(\n", " self,\n", " cnn_model,\n", " encoder,\n", " decoder,\n", " image_aug=None,\n", " **kwargs\n", " ):\n", " 
super().__init__(**kwargs)\n", "        self.cnn_model = cnn_model\n", "        self.encoder = encoder\n", "        self.decoder = decoder\n", "        self.image_aug = image_aug\n", "\n", "    def get_config(self):\n", "        base_config = super().get_config()\n", "        config = {\n", "            \"cnn_model\": self.cnn_model,\n", "            \"encoder\": self.encoder,\n", "            \"decoder\": self.decoder,\n", "            \"image_aug\": self.image_aug,\n", "        }\n", "        return {**base_config, **config}\n", "\n", "    @classmethod\n", "    def from_config(cls, config):\n", "        # Note that keras.saving.deserialize_keras_object can also be used here\n", "        config[\"cnn_model\"] = keras.saving.deserialize_keras_object(config[\"cnn_model\"])\n", "        config[\"encoder\"] = keras.saving.deserialize_keras_object(config[\"encoder\"])\n", "        config[\"decoder\"] = keras.saving.deserialize_keras_object(config[\"decoder\"])\n", "        config[\"image_aug\"] = keras.saving.deserialize_keras_object(config[\"image_aug\"])\n", "\n", "        # Instantiate the ImageCaptioningModel with the remaining configuration\n", "        return cls(**config)\n", "\n", "    def call(self, inputs, training):\n", "        img, caption = inputs\n", "        if self.image_aug:\n", "            img = self.image_aug(img)\n", "        img_embed = self.cnn_model(img)\n", "        encoder_out = self.encoder(img_embed, training=training)\n", "        pred = self.decoder(caption, encoder_out, training=training)\n", "        return pred\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MGCXWbEY6tTn" }, "outputs": [], "source": [ "cnn_model = get_cnn_model()\n", "encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)\n", "decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)\n", "caption_model = ImageCaptioningModel(\n", "    cnn_model=cnn_model,\n", "    encoder=encoder,\n", "    decoder=decoder,\n", "    image_aug=image_augmentation,\n", ")" ] }, { "cell_type": "code",
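"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sanity-check sketch (not in the original notebook): run a single\n", "# batch through the untrained model to confirm the output shape before\n", "# training. Uses train_dataset as defined above.\n", "(sample_imgs, sample_caps), sample_y = next(iter(train_dataset))\n", "sample_pred = caption_model((sample_imgs, sample_caps), training=False)\n", "print(sample_pred.shape)  # expected: (BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)" ] }, { "cell_type": "code",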
"execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 391 }, "id": "LtUx3PjMB6aJ", "outputId": "3cdc513a-321d-425b-a617-549e42fbf404" }, "outputs": [], "source": [ "early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)\n", "\n", "\n", "@keras.saving.register_keras_serializable()\n", "class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):\n", "    def __init__(self, post_warmup_learning_rate, warmup_steps, **kwargs):\n", "        super().__init__(**kwargs)\n", "        self.post_warmup_learning_rate = post_warmup_learning_rate\n", "        self.warmup_steps = warmup_steps\n", "\n", "    def get_config(self):\n", "        config = {\n", "            \"post_warmup_learning_rate\": self.post_warmup_learning_rate,\n", "            \"warmup_steps\": self.warmup_steps,\n", "        }\n", "        return config\n", "\n", "    def __call__(self, step):\n", "        # Linear warmup to the target learning rate, then hold it constant\n", "        global_step = tf.cast(step, tf.float32)\n", "        warmup_steps = tf.cast(self.warmup_steps, tf.float32)\n", "        warmup_progress = global_step / warmup_steps\n", "        warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress\n", "        return tf.cond(\n", "            global_step < warmup_steps,\n", "            lambda: warmup_learning_rate,\n", "            lambda: self.post_warmup_learning_rate,\n", "        )\n", "\n", "\n", "num_train_steps = len(train_dataset) * EPOCHS\n", "num_warmup_steps = num_train_steps // 15\n", "lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)\n", "\n", "caption_model.compile(\n", "    optimizer=keras.optimizers.Adam(lr_schedule),\n", "    loss='sparse_categorical_crossentropy',\n", "    metrics=['accuracy'],\n", ")\n", "\n", "caption_model.fit(\n", "    train_dataset,\n", "    epochs=EPOCHS,\n", "    validation_data=valid_dataset,\n", "    callbacks=[early_stopping],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M10k_8_gBKxz" }, "outputs": [], "source": [ "caption_model.save(\"caption_model.keras\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f1FD15MiBQSh" }, "outputs": [], "source": [ "loaded_model = keras.models.load_model(\"caption_model.keras\", compile=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ULoizN2kfR2W", "outputId": "fa3f0e8f-f1cc-4821-8f37-3977c6feb047" }, "outputs": [], "source": [ "caption_model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MUTEhm28fVdN", "outputId": "b020cfe6-b3e4-4c84-cf58-e59a233c6035" }, "outputs": [], "source": [ "loaded_model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "k4H_CsBUYBSi", "outputId": "590ecbca-1980-4456-89f1-ce0b2643506f" }, "outputs": [], "source": [ "caption_model.evaluate(valid_dataset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xzYtGvSPYA5H", "outputId": "5b103038-b02d-491a-d08c-c252682590fd" }, "outputs": [], "source": [ "loaded_model.evaluate(valid_dataset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UvRoJqZ-Xp7g" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 452 }, "id": "5edj0qS3YCTZ", "outputId": "fefb8277-944e-4a42-aa64-aef5df6ab92b" }, "outputs": [], "source": [ "vocab = vectorization.get_vocabulary()\n", "index_lookup = dict(zip(range(len(vocab)), vocab))\n", "max_decoded_sentence_length = SEQ_LENGTH - 1\n", "valid_images = list(X_valid_en)\n", "\n", "\n", "def generate_caption():\n", "    # Select a random image from the validation set\n", "    sample_img = np.random.choice(valid_images)\n", "\n", "    # Read the image from disk\n", "    sample_img = decode_and_resize(sample_img)\n", "    img = sample_img.numpy().clip(0, 255).astype(np.uint8)\n", "    plt.imshow(img)\n", "    plt.show()\n", "\n", "    # Pass the image to the CNN\n", "    img = tf.expand_dims(sample_img, 0)\n", "    img = caption_model.cnn_model(img)\n", "\n", "    # Pass the image features to the Transformer encoder\n", "    encoded_img = caption_model.encoder(img, training=False)\n", "\n", "    # Generate the caption greedily with the Transformer decoder\n", "    decoded_caption = \"<start> \"\n", "    for i in range(max_decoded_sentence_length):\n", "        tokenized_caption = vectorization([decoded_caption])\n", "        mask = tf.math.not_equal(tokenized_caption, 0)\n", "        predictions = caption_model.decoder(\n", "            tokenized_caption, encoded_img, training=False, mask=mask\n", "        )\n", "        sampled_token_index = np.argmax(predictions[0, i, :])\n", "        sampled_token = index_lookup[sampled_token_index]\n", "        if sampled_token == \"<end>\":\n", "            break\n", "        decoded_caption += \" \" + sampled_token\n", "\n", "    decoded_caption = decoded_caption.replace(\"<start> \", \"\")\n", "    decoded_caption = decoded_caption.replace(\" <end>\", \"\").strip()\n", "    print(\"Predicted Caption: \", decoded_caption)\n", "\n", "\n", "# Check predictions for a few samples\n", "generate_caption()\n" ] }, { "cell_type": "code",
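"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A reusable sketch (not in the original notebook) of the greedy decoding loop\n", "# above: caption an arbitrary image file by path. caption_image and its\n", "# arguments are hypothetical helpers, not part of the original notebook.\n", "def caption_image(img_path, model=caption_model):\n", "    img = tf.expand_dims(decode_and_resize(img_path), 0)\n", "    encoded_img = model.encoder(model.cnn_model(img), training=False)\n", "    decoded = \"<start> \"\n", "    for i in range(max_decoded_sentence_length):\n", "        tokenized = vectorization([decoded])\n", "        preds = model.decoder(\n", "            tokenized, encoded_img, training=False,\n", "            mask=tf.math.not_equal(tokenized, 0)\n", "        )\n", "        token = index_lookup[np.argmax(preds[0, i, :])]\n", "        if token == \"<end>\":\n", "            break\n", "        decoded += \" \" + token\n", "    return decoded.replace(\"<start> \", \"\").strip()\n", "\n", "\n", "print(caption_image(valid_images[0]))" ] }, { "cell_type": "code",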
"execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 715 }, "id": "3zRF5hAEbOdm", "outputId": "85611bc9-70b9-4282-aac4-53d3af5c2c18" }, "outputs": [], "source": [ "# Repeat the greedy decoding check, this time with the reloaded model\n", "vocab = vectorization.get_vocabulary()\n", "index_lookup = dict(zip(range(len(vocab)), vocab))\n", "max_decoded_sentence_length = SEQ_LENGTH - 1\n", "valid_images = list(X_valid_en)\n", "\n", "\n", "def generate_caption():\n", "    # Select a random image from the validation set\n", "    sample_img = np.random.choice(valid_images)\n", "\n", "    # Read the image from disk\n", "    sample_img = decode_and_resize(sample_img)\n", "    img = sample_img.numpy().clip(0, 255).astype(np.uint8)\n", "    plt.imshow(img)\n", "    plt.show()\n", "\n", "    # Pass the image to the CNN\n", "    img = tf.expand_dims(sample_img, 0)\n", "    img = loaded_model.cnn_model(img)\n", "\n", "    # Pass the image features to the Transformer encoder\n", "    encoded_img = loaded_model.encoder(img, training=False)\n", "\n", "    # Generate the caption greedily with the Transformer decoder\n", "    decoded_caption = \"<start> \"\n", "    for i in range(max_decoded_sentence_length):\n", "        tokenized_caption = vectorization([decoded_caption])\n", "        mask = tf.math.not_equal(tokenized_caption, 0)\n", "        predictions = loaded_model.decoder(\n", "            tokenized_caption, encoded_img, training=False, mask=mask\n", "        )\n", "        sampled_token_index = np.argmax(predictions[0, i, :])\n", "        sampled_token = index_lookup[sampled_token_index]\n", "        if sampled_token == \"<end>\":\n", "            break\n", "        decoded_caption += \" \" + sampled_token\n", "\n", "    decoded_caption = decoded_caption.replace(\"<start> \", \"\")\n", "    decoded_caption = decoded_caption.replace(\" <end>\", \"\").strip()\n", "    print(\"Predicted Caption: \", decoded_caption)\n", "\n", "\n", "# Check predictions for a few samples\n", "generate_caption()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4n5iXcJwwB9-" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }