{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rToK0Tku8PPn" }, "source": [ "## makemore: becoming a backprop ninja" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "ChBbac4y8PPq" }, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt # for making figures\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "klmu3ZG08PPr" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32033\n", "15\n", "['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n" ] } ], "source": [ "# read in all the words\n", "words = open('names.txt', 'r').read().splitlines()\n", "print(len(words))\n", "print(max(len(w) for w in words))\n", "print(words[:8])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "BCQomLE_8PPs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n", "27\n" ] } ], "source": [ "# build the vocabulary of characters and mappings to/from integers\n", "chars = sorted(list(set(''.join(words))))\n", "stoi = {s:i+1 for i,s in enumerate(chars)}\n", "stoi['.'] = 0\n", "itos = {i:s for s,i in stoi.items()}\n", "vocab_size = len(itos)\n", "print(itos)\n", "print(vocab_size)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "V_zt2QHr8PPs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([182625, 3]) torch.Size([182625])\n", "torch.Size([22655, 3]) torch.Size([22655])\n", "torch.Size([22866, 3]) torch.Size([22866])\n" ] } ], "source": [ "# build the dataset\n", "block_size = 3 # context length: how many characters do we take to predict the next one?\n", "\n", "def build_dataset(words):\n", " X, Y = [], []\n", "\n", " for w in words:\n", " context = [0] * block_size\n", " for ch in w + '.':\n", " ix = stoi[ch]\n", " X.append(context)\n", " Y.append(ix)\n", " context = context[1:] + [ix] # crop and append\n", "\n", " X = torch.tensor(X)\n", " Y = torch.tensor(Y)\n", " print(X.shape, Y.shape)\n", " return X, Y\n", "\n", "import random\n", "random.seed(42)\n", "random.shuffle(words)\n", "n1 = int(0.8*len(words))\n", "n2 = int(0.9*len(words))\n", "\n", "Xtr, Ytr = build_dataset(words[:n1]) # 80%\n", "Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n", "Xte, Yte = build_dataset(words[n2:]) # 10%" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "eg20-vsg8PPt" }, "outputs": [], "source": [ "# ok biolerplate done, now we get to the action:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "MJPU8HT08PPu" }, "outputs": [], "source": [ "# utility function we will use later when comparing manual gradients to PyTorch gradients\n", "def cmp(s, dt, t):\n", " ex = torch.all(dt == t.grad).item()\n", " app = torch.allclose(dt, t.grad)\n", " maxdiff = (dt - t.grad).abs().max().item()\n", " print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "ZlFLjQyT8PPu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4137\n" ] } ], "source": [ "n_embd = 10 # the dimensionality of the character embedding vectors\n", "n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n", "\n", "g = torch.Generator().manual_seed(2147483647) # for reproducibility\n", "C = torch.randn((vocab_size, n_embd), generator=g)\n", "# Layer 1\n", "W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n", "b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n", "# Layer 2\n", "W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n", "b2 = torch.randn(vocab_size, generator=g) * 0.1\n", "# BatchNorm parameters\n", "bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n", "bnbias = torch.randn((1, n_hidden))*0.1\n", "\n", "# Note: I am initializating many of these parameters in non-standard ways\n", "# because sometimes initializating with e.g. all zeros could mask an incorrect\n", "# implementation of the backward pass.\n", "\n", "parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n", "print(sum(p.nelement() for p in parameters)) # number of parameters in total\n", "for p in parameters:\n", " p.requires_grad = True" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "QY-y96Y48PPv" }, "outputs": [], "source": [ "batch_size = 32\n", "n = batch_size # a shorter variable also, for convenience\n", "# construct a minibatch\n", "ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n", "Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "8ofj1s6d8PPv" }, "outputs": [ { "data": { "text/plain": [ "tensor(3.3221, grad_fn=)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n", "\n", "emb = C[Xb] # embed the characters into vectors\n", "embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n", "# Linear layer 1\n", "hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n", "# BatchNorm layer\n", "bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "bndiff = hprebn - bnmeani\n", "bndiff2 = bndiff**2\n", "bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n", "bnvar_inv = (bnvar + 1e-5)**-0.5\n", "bnraw = bndiff * bnvar_inv\n", "hpreact = bngain * bnraw + bnbias\n", "# Non-linearity\n", "h = torch.tanh(hpreact) # hidden layer\n", "# Linear layer 2\n", "logits = h @ W2 + b2 # output layer\n", "# cross entropy loss (same as F.cross_entropy(logits, Yb))\n", "logit_maxes = logits.max(1, keepdim=True).values\n", "norm_logits = logits - logit_maxes # subtract max for numerical stability\n", "counts = norm_logits.exp()\n", "counts_sum = counts.sum(1, keepdims=True) #DONE\n", "counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact... #DONE\n", "probs = counts * counts_sum_inv #DONE\n", "logprobs = probs.log() #DONE\n", "loss = -logprobs[range(n), Yb].mean() #DONE\n", "\n", "# PyTorch backward pass\n", "for p in parameters:\n", " p.grad = None\n", "for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n", " norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n", " bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n", " embcat, emb]:\n", " t.retain_grad()\n", "loss.backward()\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---------\n", "\n", "### **EXERCISE 1**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[13:01](https://www.youtube.com/watch?v=q8SA3rM6ckI&t=781s) to [19:05](https://youtu.be/q8SA3rM6ckI?si=mm8M8feWFToF4STA&t=1145) `cmp('logprobs', dlogprobs, logprobs)`" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 27])\n" ] }, { "data": { "text/plain": [ "tensor([-4.0562, -3.0820, -3.6629, -3.2621, -4.1229, -3.4201, -3.2428, -3.9554,\n", " -3.1259, -4.2500, -3.1582, -1.6256, -2.8483, -2.9654, -2.9990, -3.1882,\n", " -3.9132, -3.0643, -3.5065, -3.5153, -2.8832, -3.0837, -4.2941, -4.0007,\n", " -3.4440, -2.9220, -3.1386, -3.8946, -2.6488, -3.5292, -3.3408, -3.1560],\n", " grad_fn=)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(logprobs.shape)\n", "logprobs[range(n), Yb]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32])\n" ] }, { "data": { "text/plain": [ "tensor([ 8, 14, 15, 22, 0, 19, 9, 14, 5, 1, 20, 3, 8, 14, 12, 0, 11, 0,\n", " 26, 9, 25, 0, 1, 1, 7, 18, 9, 3, 5, 9, 0, 18])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(Yb.shape)\n", "Yb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#simple breakdown\n", "#now here we know there are 32 examples, for explaination lets assume we only have 3 in total i.e. a,b,c\n", "\n", "#loss = - (a + b + c) / 3 ==> so we are basically doing the mean calculation here\n", "#loss = - (1/3a + 1/3b + 1/3c) ==> same equation\n", "#so now, when we take the derivative wrt a\n", "#dloss/da = -1/3 ==>where 3 is the number of elements we consider, so we can also say that it is -1/n, therefore\n", "#dloss/dn = -1/n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logprobs | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "#So based on our calculation above\n", "dlogprobs = torch.zeros_like(logprobs) #same as torch.zeros((32, 27)) as we need the same shape as logprobs. So instead of hardcoding the values we did this\n", "dlogprobs[range(n), Yb] = -1.0/n #as we need to do it for each of the elements in the array\n", "\n", "#Now, lets check\n", "cmp('logprobs', dlogprobs, logprobs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[19:06](https://youtu.be/q8SA3rM6ckI?si=mO61nJLwtQpxsjju&t=1146) to [20:55](https://youtu.be/q8SA3rM6ckI?si=fgJsPGOCdJIIRYC9&t=1255) `cmp('probs', dprobs, probs)`" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "probs | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "dprobs = (1.0/probs) * dlogprobs #we had to take the derivative of logprobs, which was 1/x --> d/dx(log(x)) = 1/x \n", "#then we multiplied it with dlogprobs (the one we calculated before this for the chainrule)\n", "\n", "cmp('probs', dprobs, probs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[20:56](https://youtu.be/q8SA3rM6ckI?si=sNM67lNSfsmUke2Y&t=1256) to [26:21](https://youtu.be/q8SA3rM6ckI?si=5MWGHdf1v-72g5ib&t=1581) `cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)`" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "counts_sum_inv | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# probs = counts * counts_sum_inv, now here before we do the multiplication, take a look at the matrix dimensions using `.shape`\n", "# You would see that `counts` would have 3x3 and `counts_sum_inv` will have 3x1\n", "# So before the backpropagation calculation, there is 'broadcasting' happening where the value of b is been replicated/broadcasted multiple time across the matrix\n", "\n", "# Example\n", "# c = a * b\n", "# a[3x3] * b[3x1] ---->\n", "# a[1,1]*b1 + a[1,2]*b1 + a[1,3]*b1\n", "# a[2,1]*b2 + a[2,2]*b2 + a[2,3]*b2\n", "# a[3,1]*b3 + a[3,2]*b3 + a[2,3]*b3\n", "# ====> c[3x3]\n", "\n", "# The point of this is just to show that there are two things happening internally: The broadcasting and then the backpropagation\n", "\n", "# (first case) The derivative of c wrt b will be a\n", "# So, here just `counts` will remain -> then `dprobs` is multiplied because chain rule.\n", "# Finally, in order to make `dcounts_sum_inv` the same dimension as `counts_sum_inv` we sum all of them by 1 and also keepdims as true\n", "\n", "dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True) # So this is our final manually calcualted equation\n", "\n", "cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[26:26](https://youtu.be/q8SA3rM6ckI?si=TBwv2QkGmkp-d8JR&t=1586) to [27:56](https://youtu.be/q8SA3rM6ckI?si=awbZx9fZ_-WB_q5M&t=1676) first contribution of `counts`" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Here we have to calculate the second half of `dcounts` i.e. (Second case) The derivative of c wrt a will be b\n", "\n", "dcounts = counts_sum_inv * dprobs\n", "\n", "#but we cant compare it yet as `counts` is later depended on top again as well, which we will check" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[27:57](https://youtu.be/q8SA3rM6ckI?si=APAFn28Pf8HVpbM3&t=1677) to [28:59](https://youtu.be/q8SA3rM6ckI?si=O5ja7cEm2xS_yuzN&t=1740) `cmp('counts_sum', dcounts_sum, counts_sum)`" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "counts_sum | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# counts_sum_inv = counts_sum**-1\n", "\n", "# Okay so for this, the derivative of x^-1 is -(x^-2)\n", "\n", "dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv #Remember for this its the one before the `26:26 to 27:56 first contribution of counts` section\n", "\n", "cmp('counts_sum', dcounts_sum, counts_sum)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[29:00](https://youtu.be/q8SA3rM6ckI?si=UsxgAcBfiU5GAHaz&t=1740) to [32:26](https://youtu.be/q8SA3rM6ckI?si=nsXvTD-8RWvUAubq&t=1947) `cmp('counts', dcounts, counts)`" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "counts | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# counts_sum = counts.sum(1, keepdims=True)\n", "\n", "# Now here we know the shape of counts_sum is 32 by 1 and the shape of counts is 32 by 27. So we need to broadcast counts_sum 27 times\n", "# We are dirctly using a PyTorch function where it keeps adding numbers from `counts`\n", "\n", "dcounts += torch.ones_like(counts) * dcounts_sum #Also here we are adding `dcounts` as remember this is the second iteration of it, we had calculated one more value of it at the top\n", "\n", "cmp('counts', dcounts, counts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[32:27](https://youtu.be/q8SA3rM6ckI?si=nsXvTD-8RWvUAubq&t=1947) to [33:13](https://youtu.be/q8SA3rM6ckI?si=Ydk-b_pmKybrrnxe&t=1994) `cmp('norm_logits', dnorm_logits, norm_logits)`" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "norm_logits | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# counts = norm_logits.exp()\n", "\n", "# Now here, the derivative of `norm_logits.exp()`, now the derivate of e^x is (famously) e^x, so its just `norm_logits.exp()` itself\n", "# so we can also just write it as `counts` directly as it holds that value\n", "\n", "dnorm_logits = counts * dcounts\n", "\n", "cmp('norm_logits', dnorm_logits, norm_logits)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[33:14](https://youtu.be/q8SA3rM6ckI?si=GIbBvHKGbW0RvlWf&t=1994) to [36:20](https://youtu.be/q8SA3rM6ckI?si=LGenDRNCeOVsWIkY&t=2180) `cmp('logit_maxes', dlogit_maxes, logit_maxes)`" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logit_maxes | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# norm_logits = logits - logit_maxes\n", "\n", "# Now here if you would look at the shape of all these variables, you would notice that there is internal broadcasting happening here (logit_maxes)\n", "\n", "dlogits = dnorm_logits.clone()\n", "dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True) #WILL HAVE TO REWATCH THIS PART AGAIN, DIDN'T COMPLETELY GET IT\n", "\n", "cmp('logit_maxes', dlogit_maxes, logit_maxes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[38:27](https://youtu.be/q8SA3rM6ckI?si=sVCg29V84Ua56x3H&t=2307) to [41:44](https://youtu.be/q8SA3rM6ckI?si=yHhzlWlaR9J4VBo_&t=2504) `cmp('logits', dlogits, logits)`" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logits | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# logit_maxes = logits.max(1, keepdim=True).values\n", "\n", "# Here, this step is similar to that of the first one in `dlogprobs` where we used torch.zeros_like() function\n", "# So we are doing another alternative way of doing that\n", "\n", "dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes #Just remember the += here as we already have one dlogits above\n", "\n", "cmp('logits', dlogits, logits)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[41:45](https://youtu.be/q8SA3rM6ckI?si=wJvhK8v1Hj2sEhc6&t=2505) to [53:25](https://youtu.be/q8SA3rM6ckI?si=xg15htmnJE03afh5&t=3216) `cmp('h', dh, h)`, `cmp('W2', dW2, W2)` and `cmp('b2', db2, b2)` - Bckpropagation through a linear layer\n", "\n", "( Till [49:56](https://youtu.be/q8SA3rM6ckI?si=nX-tCDJWXFHTgqi3&t=2996) had theoritical proofs on the matrix multiplication )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # Linear layer 2\n", "# logits = h @ W2 + b2 # output layer\n", "\n", "# in `b2` broadcasting is happening" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 27]),\n", " torch.Size([32, 64]),\n", " torch.Size([64, 27]),\n", " torch.Size([27]))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Need these for understanding the matrix mulitplication why we are multiplying with what\n", "dlogits.shape, h.shape, W2.shape, b2.shape" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "h | exact: True | approximate: True | maxdiff: 0.0\n", "W2 | exact: True | approximate: True | maxdiff: 0.0\n", "b2 | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# watch the last few minutes, probably from 51 to see how he broke down this based on the matrix sizes\n", "dh = dlogits @ W2.T\n", "dW2 = h.T @ dlogits\n", "db2 = dlogits.sum(0)\n", "\n", "cmp('h', dh, h)\n", "cmp('W2', dW2, W2)\n", "cmp('b2', db2, b2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[53:37](https://youtu.be/q8SA3rM6ckI?si=xASEEmeuBmpZwd6B&t=3217) to 55:12 `cmp('hpreact', dhpreact, hpreact)`" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hpreact | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# h = torch.tanh(hpreact) # hidden layer\n", "\n", "dhpreact = (1.0 - h**2)*dh #we saw that the derivative of tanh is also (1-a^2) where a was the external variable `a`, not the input `z` to tanh i.e. a = tanh(z)\n", "\n", "cmp('hpreact', dhpreact, hpreact)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[55:13](https://youtu.be/q8SA3rM6ckI?si=7v0ZQ9alRi52gD9s&t=3313) to 59:38 `cmp('bngain', dbngain, bngain)`" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 64]),\n", " torch.Size([1, 64]),\n", " torch.Size([1, 64]),\n", " torch.Size([32, 64]))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bnraw.shape, bngain.shape, bnbias.shape, dhpreact.shape" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bngain | exact: True | approximate: True | maxdiff: 0.0\n", "bnbias | exact: True | approximate: True | maxdiff: 0.0\n", "bnraw | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# hpreact = bngain * bnraw + bnbias\n", "\n", "dbngain = (bnraw * dhpreact).sum(0, keepdim=True) #because dbraw and dhpreact are 32by64, but dbngain expects 1by64 (we also keep the dimension)\n", "dbnraw = (bngain * dhpreact)\n", "dbnbias = (dhpreact).sum(0, keepdim=True) #because dhpreact is 32by64 but the dbnbias expects 1by64 (we also keep the dimension)\n", "\n", "cmp('bngain', dbngain, bngain)\n", "cmp('bnbias', dbnbias, bnbias)\n", "cmp('bnraw', dbnraw, bnraw)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[59:40](https://youtu.be/q8SA3rM6ckI?si=RNb8T5WGla37958Q&t=3580) to 1:04:1 `cmp('bnvar_inv', dbnvar_inv, bnvar_inv)`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# From here we are working on the batch norm layer\n", "# the code has been spread out and broken down to different parts (based on the equations on the \"bottom right corner box\" in the paper for batch norm - See prev lecture) inorder to perform manual backprop more easily" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 64]), torch.Size([32, 64]), torch.Size([1, 64]))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bnraw.shape, bndiff.shape, bnvar_inv.shape" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bnvar_inv | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# bnraw = bndiff * bnvar_inv\n", "\n", "dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n", "dbndiff = bnvar_inv * dbnraw #We will come back to this in 1:12:43 - (1)\n", "\n", "cmp('bnvar_inv', dbnvar_inv, bnvar_inv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:04:15](https://youtu.be/q8SA3rM6ckI?si=Mj6mc99YFmqYxo_l&t=3855) to 1:05:16 `cmp('bnvar', dbnvar, bnvar)`" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bnvar | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# bnvar_inv = (bnvar + 1e-5)**-0.5\n", "#This is a direct equation of derivative of x^n so the output should be n*x^n-1\n", "\n", "dbnvar = (-0.5 * ((bnvar + 1e-5) ** (-1.5))) * dbnvar_inv\n", "\n", "cmp('bnvar', dbnvar, bnvar)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:05:17](https://youtu.be/q8SA3rM6ckI?si=vjAXVF6w3BoZMC04&t=3917) to 1:09:01 - Why he implemented the bessel's correction (as there seem to be some problem/issue in the paper. Using Bias during training time and Unbiased during testing). But we prefer to use Unbiased during both training and testing and that is what we went ahead with." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:09:02](https://youtu.be/q8SA3rM6ckI?si=WxOg7f0S-mqLiZfD&t=4142) to 1:12:42 `cmp('bndiff2', dbndiff2, bndiff2)`" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bndiff2 | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)\n", "\n", "dbndiff2 = 1/(n-1) * torch.ones_like(bndiff2) * dbnvar\n", "\n", "cmp('bndiff2', dbndiff2, bndiff2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:12:43](https://youtu.be/q8SA3rM6ckI?si=HkT46KjpcZoit33H&t=4363) to 1:13:58 `cmp('bndiff', dbndiff, bndiff)`" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bndiff | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# bndiff2 = bndiff**2\n", "\n", "dbndiff += 2*bndiff * dbndiff2 #This is the (2)nd occurance of dbndiff - 59:40 so, we add it here\n", "\n", "cmp('bndiff', dbndiff, bndiff)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:13:59](https://youtu.be/q8SA3rM6ckI?si=t03BQ_sro2n6X0a2&t=4439) to 1:18:35 `cmp('bnmeani', dbnmeani, bnmeani)` and `cmp('hprebn', dhprebn, hprebn)`" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bnmeani | exact: True | approximate: True | maxdiff: 0.0\n", "hprebn | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "## Please go thorugh this one again, i didnt completely get it\n", "\n", "# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "# bndiff = hprebn - bnmeani\n", "\n", "dhprebn = dbndiff.clone() #we are making a copy of it\n", "dbnmeani = (-dbndiff).sum(0)\n", "\n", "dhprebn += (1.0/n)*(torch.ones_like(hprebn) * dbnmeani)\n", "\n", "cmp('bnmeani', dbnmeani, bnmeani)\n", "cmp('hprebn', dhprebn, hprebn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:18:36](https://youtu.be/q8SA3rM6ckI?si=j_uFOOB3AsbrkbwM&t=4716) to 1:20:34 `cmp('embcat', dembcat, embcat)`, `cmp('W1', dW1, W1)` and `cmp('b1', db1, b1)`" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 64]),\n", " torch.Size([32, 30]),\n", " torch.Size([30, 64]),\n", " torch.Size([64]))" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hprebn.shape, embcat.shape, W1.shape, b1.shape" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "embcat | exact: True | approximate: True | maxdiff: 0.0\n", "W1 | exact: True | approximate: True | maxdiff: 0.0\n", "b1 | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "# Forward pass: hprebn = embcat @ W1 + b1\n", "\n", "dembcat = dhprebn @ W1.T\n", "dW1 = embcat.T @ dhprebn\n", "db1 = dhprebn.sum(0)\n", "\n", "cmp('embcat', dembcat, embcat)\n", "cmp('W1', dW1, W1)\n", "cmp('b1', db1, b1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:20:35](https://youtu.be/q8SA3rM6ckI?si=F8arFi8ee8a9eAvv&t=4835) to 1:21:58 `cmp('emb', demb, emb)`" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "emb | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "## Please rewatch this as well\n", "\n", "# embcat = emb.view(emb.shape[0], -1)\n", "\n", "demb = dembcat.view(emb.shape)\n", "\n", "cmp('emb', demb, emb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:21:59](https://youtu.be/q8SA3rM6ckI?si=cPimgFWzBgjrkpAr&t=4919) to `cmp('C', dC, C)`" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C | exact: True | approximate: True | maxdiff: 0.0\n" ] } ], "source": [ "## Please rewatch this as well\n", "# emb = C[Xb]\n", "\n", "dC = torch.zeros_like(C)\n", "for k in range(Xb.shape[0]):\n", " for j in range(Xb.shape[1]):\n", " ix = Xb[k,j]\n", " dC[ix] += demb[k,j]\n", "\n", "cmp('C', dC, C)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we are done with the first exercise!!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mO-8aqxK8PPw" }, "outputs": [], "source": [ "# Exercise 1: backprop through the whole thing manually,\n", "# backpropagating through exactly all of the variables\n", "# as they are defined in the forward pass above, one by one\n", "\n", "# -----------------\n", "# YOUR CODE HERE :)\n", "# -----------------\n", "\n", "# cmp('logprobs', dlogprobs, logprobs)\n", "# cmp('probs', dprobs, probs)\n", "# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n", "# cmp('counts_sum', dcounts_sum, counts_sum)\n", "# cmp('counts', dcounts, counts)\n", "# cmp('norm_logits', dnorm_logits, norm_logits)\n", "# cmp('logit_maxes', dlogit_maxes, logit_maxes)\n", "# cmp('logits', dlogits, logits)\n", "# cmp('h', dh, h)\n", "# cmp('W2', dW2, W2)\n", "# cmp('b2', db2, b2)\n", "# cmp('hpreact', dhpreact, hpreact)\n", "# cmp('bngain', dbngain, bngain)\n", "# cmp('bnbias', dbnbias, bnbias)\n", "# cmp('bnraw', dbnraw, bnraw)\n", "# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n", "# cmp('bnvar', dbnvar, bnvar)\n", "# cmp('bndiff2', dbndiff2, bndiff2)\n", "# cmp('bndiff', dbndiff, bndiff)\n", "# cmp('bnmeani', dbnmeani, bnmeani)\n", "# cmp('hprebn', dhprebn, hprebn)\n", "# cmp('embcat', dembcat, embcat)\n", "# cmp('W1', dW1, W1)\n", "# cmp('b1', db1, b1)\n", "# cmp('emb', demb, emb)\n", "# cmp('C', dC, C)" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 0 }