{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "bD0MZfilv6Cc" }, "outputs": [], "source": [ "import numpy as np\n", "def task_seperate(x,y):\n", "\n", " indx0 = np.where(y==0)[0]\n", " y0 = y[indx0]\n", " x0 = x[indx0,:,:,:]\n", "\n", " indx1 = np.where(y==1)[0]\n", " y1 = y[indx1]\n", " x1 = x[indx1,:,:,:]\n", "\n", " indx2 = np.where(y==2)[0]\n", " y2 = y[indx2]\n", " x2 = x[indx2,:,:,:]\n", "\n", " indx3 = np.where(y==3)[0]\n", " y3 = y[indx3]\n", " x3 = x[indx3,:,:,:]\n", "\n", " indx4 = np.where(y==4)[0]\n", " y4 = y[indx4]\n", " x4 = x[indx4,:,:,:]\n", "\n", " indx5 = np.where(y==5)[0]\n", " y5 = y[indx5]\n", " x5 = x[indx5,:,:,:]\n", "\n", " y_task1 = np.concatenate((y0,y1),axis=0)\n", " x_task1 = np.concatenate((x0,x1),axis=0)\n", "\n", " y_task2 = np.concatenate((y2,y3),axis=0)\n", " x_task2 = np.concatenate((x2,x3),axis=0)\n", "\n", " y_task3 = np.concatenate((y4,y5),axis=0)\n", " x_task3 = np.concatenate((x4,x5),axis=0)\n", "\n", " Y = [y_task1, y_task2, y_task3]\n", " X = [x_task1, x_task2, x_task3]\n", "\n", " return X,Y" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Q2qu1tupEhIf" }, "outputs": [], "source": [ "def compile_model(model, learning_rate, extra_losses=None):\n", " def custom_loss(y_true, y_pred):\n", " loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)\n", " if extra_losses is not None:\n", " for fn in extra_losses:\n", " loss += fn(model)\n", "\n", " return loss\n", "\n", " model.compile(\n", " loss=custom_loss,\n", " optimizer=keras.optimizers.Adam(learning_rate=learning_rate),\n", " metrics=[\"accuracy\"]\n", " )\n", "\n", "def report(model, epoch, validation_datasets, batch_size):\n", " result = []\n", " for inputs, labels in validation_datasets:\n", " _, accuracy = model.evaluate(inputs, labels, verbose=0,\n", " batch_size=batch_size)\n", " result.append(\"{:.2f}\".format(accuracy * 100))\n", "\n", " # Add 1: assuming that we report after training has finished for this epoch.\n", " print(epoch + 1, \"\\t\", \"\\t\".join(result))\n", "\n", "def train_epoch(model, train_data, batch_size,\n", " gradient_mask=None, incdet_threshold=None):\n", " \"\"\"Need a custom training loop for when we modify the gradients.\"\"\"\n", " dataset = tf.data.Dataset.from_tensor_slices(train_data)\n", " dataset = dataset.shuffle(len(train_data[0])).batch(batch_size)\n", "\n", " for inputs, labels in dataset:\n", " with tf.GradientTape() as tape:\n", " outputs = model(inputs)\n", " loss = model.compiled_loss(labels, outputs)\n", "\n", " gradients = tape.gradient(loss, model.trainable_weights)\n", "\n", " model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3GlD4yJOvr9n" }, "outputs": [], "source": [ "def fisher_matrix(model, dataset, samples):\n", " \"\"\"\n", " Compute the Fisher matrix, representing the importance of each weight in the\n", " model. This is approximated using the variance of the gradient of each\n", " weight, for some number of samples from the dataset.\n", "\n", " :param model: Model whose Fisher matrix is to be computed.\n", " :param dataset: Dataset which the model has been trained on, but which will\n", " not be seen in the future. Formatted as (inputs, labels).\n", " :param samples: Number of samples to take from the dataset. More samples\n", " gives a better approximation of the true variance.\n", " :return: The main diagonal of the Fisher matrix, shaped to match the weights\n", " returned by `model.trainable_weights`.\n", " \"\"\"\n", " inputs, labels = dataset\n", " weights = model.trainable_weights\n", " variance = [tf.zeros_like(tensor) for tensor in weights]\n", "\n", " for _ in range(samples):\n", " # Select a random element from the dataset.\n", " index = np.random.randint(len(inputs))\n", " data = inputs[index]\n", "\n", " # When extracting from the array we lost a dimension so put it back.\n", " data = tf.expand_dims(data, axis=0)\n", "\n", " # Collect gradients.\n", " with tf.GradientTape() as tape:\n", " output = model(data)\n", " log_likelihood = tf.math.log(output)\n", "\n", " gradients = tape.gradient(log_likelihood, weights)\n", "\n", " # If the model has converged, we can assume that the current weights\n", " # are the mean, and each gradient we see is a deviation. The variance is\n", " # the average of the square of this deviation.\n", " variance = [var + (grad ** 2) for var, grad in zip(variance, gradients)]\n", "\n", " fisher_diagonal = [tensor / samples for tensor in variance]\n", " return fisher_diagonal\n", "\n", "\n", "def ewc_loss(lam, model, dataset, samples):\n", " \"\"\"\n", " Generate a loss function which will penalise divergence from the current\n", " state. It is assumed that the model achieves good accuracy on `dataset`,\n", " and we want to preserve this behaviour.\n", "\n", " The penalty is scaled according to how important each weight is for the\n", " given dataset, and `lam` (lambda) applies equally to all weights.\n", "\n", " :param lam: Weight of this cost function compared to the other losses.\n", " :param model: Model optimised for the given dataset.\n", " :param dataset: NumPy arrays (inputs, labels).\n", " :param samples: Number of samples of dataset to take when estimating\n", " importance of weights. More samples improves estimates.\n", " :return: A loss function.\n", " \"\"\"\n", " optimal_weights = deepcopy(model.trainable_weights)\n", " fisher_diagonal = fisher_matrix(model, dataset, samples)\n", "\n", " def loss_fn(new_model):\n", " # We're computing:\n", " # sum [(lambda / 2) * F * (current weights - optimal weights)^2]\n", " loss = 0\n", " current = new_model.trainable_weights\n", " for f, c, o in zip(fisher_diagonal, current, optimal_weights):\n", " loss += tf.reduce_sum(f * ((c - o) ** 2))\n", "\n", " return loss * (lam / 2)\n", "\n", " return loss_fn\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "K8M29Gfrtwwe", "outputId": "97a40ad6-b9ed-432d-c889-e313b0c64bb1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", "11490434/11490434 [==============================] - 0s 0us/step\n", "Model Trained on Task 0\n", "67/67 [==============================] - 0s 3ms/step - loss: 0.0013 - accuracy: 0.9995\n", "Test Accuracy on Task 0 = 1.00\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "WARNING:tensorflow:5 out of the last 1585 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Model Trained on Task 1\n", "67/67 [==============================] - 1s 2ms/step - loss: 4.1902 - accuracy: 0.1225\n", "Test Accuracy on Task 0 = 0.12\n", "64/64 [==============================] - 0s 3ms/step - loss: 0.0728 - accuracy: 0.9799\n", "Test Accuracy on Task 1 = 0.98\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "WARNING:tensorflow:5 out of the last 1513 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Model Trained on Task 2\n", "67/67 [==============================] - 0s 3ms/step - loss: 7.8403 - accuracy: 0.0473\n", "Test Accuracy on Task 0 = 0.05\n", "64/64 [==============================] - 0s 2ms/step - loss: 6.0971 - accuracy: 0.0034\n", "Test Accuracy on Task 1 = 0.00\n", "59/59 [==============================] - 0s 2ms/step - loss: 0.3602 - accuracy: 0.9536\n", "Test Accuracy on Task 2 = 0.95\n" ] } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from keras.datasets import mnist\n", "from copy import deepcopy\n", "\n", "# Hyperparameters\n", "learning_rate = 0.001\n", "epochs = 2\n", "lambda_ewc = 10 # Importance of past tasks\n", "\n", "# Load MNIST dataset\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "x_train = x_train.reshape(-1, 28, 28, 1).astype(\"float32\") / 255.0\n", "x_test = x_test.reshape(-1, 28, 28, 1).astype(\"float32\") / 255.0\n", "\n", "x_train_task, y_train_task = task_seperate(x_train,y_train)\n", "x_test_task, y_test_task = task_seperate(x_test,y_test)\n", "\n", "# Define model (replace with your desired architecture)\n", "model = keras.Sequential([\n", " keras.layers.Flatten(input_shape=(28, 28, 1)),\n", " keras.layers.Dense(128, activation=\"relu\"),\n", " keras.layers.Dense(10, activation=\"softmax\")\n", "])\n", "\n", "# Compile model with Adam optimizer\n", "compile_model(model, learning_rate)\n", "\n", "regularisers = []\n", "\n", "for task in range(3):\n", " inputs = x_train_task[task]\n", " labels = y_train_task[task]\n", "\n", " for epoch in range(epochs):\n", " train_epoch(model, (inputs, labels), batch_size=64)\n", " valid_sets = [(x_test_task[task], y_test_task[task])]\n", "\n", " print(f\"Model Trained on Task {task}\")\n", "\n", "\n", " for iTask in range(task+1):\n", " test_loss, test_acc = model.evaluate(x_test_task[iTask], y_test_task[iTask])\n", " print(f\"Test Accuracy on Task {iTask} = {test_acc:.2f}\")\n", "\n", " loss_fn = ewc_loss(lambda_ewc, model, (inputs, labels), x_train_task[task].shape[0])\n", " regularisers.append(loss_fn)\n", " compile_model(model, learning_rate, extra_losses=regularisers)\n", "\n" ] } ], "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }