{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This exercise wasn't exactly smooth sailing for me, but I did try to understand most of it. I'll try to come back to it whenever I can." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# there is no change in the first several cells from the last lecture\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt # for making figures\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# download the names.txt file from github\n", "!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# read in all the words\n", "words = open('names.txt', 'r').read().splitlines()\n", "# print(len(words))\n", "# print(max(len(w) for w in words))\n", "# print(words[:8])\n", "\n", "# build the vocabulary of characters and mappings to/from integers\n", "chars = sorted(list(set(''.join(words))))\n", "stoi = {s:i+1 for i,s in enumerate(chars)}\n", "stoi['.'] = 0\n", "itos = {i:s for s,i in stoi.items()}\n", "vocab_size = len(itos)\n", "# print(itos)\n", "# print(vocab_size)\n", "\n", "# build the dataset\n", "block_size = 3 # context length: how many characters do we take to predict the next one?\n", "\n", "def build_dataset(words):\n", "  X, Y = [], []\n", "\n", "  for w in words:\n", "    context = [0] * block_size\n", "    for ch in w + '.':\n", "      ix = stoi[ch]\n", "      X.append(context)\n", "      Y.append(ix)\n", "      context = context[1:] + [ix] # crop and append\n", "\n", "  X = torch.tensor(X)\n", "  Y = torch.tensor(Y)\n", "  # print(X.shape, Y.shape)\n", "  return X, Y\n", "\n", "import random\n", "random.seed(42)\n", "random.shuffle(words)\n", "n1 = int(0.8*len(words))\n", "n2 = int(0.9*len(words))\n", "\n", "Xtr, Ytr = build_dataset(words[:n1]) # 80%\n", "Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n", "Xte, Yte = build_dataset(words[n2:]) # 10%" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# utility function we will use later when comparing manual gradients to PyTorch gradients\n", "def cmp(s, dt, t):\n", "  ex = torch.all(dt == t.grad).item()\n", "  app = torch.allclose(dt, t.grad)\n", "  maxdiff = (dt - t.grad).abs().max().item()\n", "  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4137\n" ] } ], "source": [ "n_embd = 10 # the dimensionality of the character embedding vectors\n", "n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n", "\n", "g = torch.Generator().manual_seed(2147483647) # for reproducibility\n", "C = torch.randn((vocab_size, n_embd), generator=g)\n", "# Layer 1\n", "W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n", "b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n", "# Layer 2\n", "W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n", "b2 = torch.randn(vocab_size, generator=g) * 0.1\n", "# BatchNorm parameters\n", "bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n", "bnbias = torch.randn((1, n_hidden))*0.1\n", "\n", "# Note: I am initializing many of these parameters in non-standard ways\n",
"# because sometimes initializing with e.g. all zeros could mask an incorrect\n", "# implementation of the backward pass.\n", "\n", "parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n", "print(sum(p.nelement() for p in parameters)) # number of parameters in total\n", "for p in parameters:\n", "  p.requires_grad = True" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "n = batch_size # a shorter variable also, for convenience\n", "# construct a minibatch\n", "ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n", "Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3.3479, grad_fn=)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n", "\n", "emb = C[Xb] # embed the characters into vectors\n", "embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n", "# Linear layer 1\n", "hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n", "# BatchNorm layer\n", "bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "bndiff = hprebn - bnmeani\n", "bndiff2 = bndiff**2\n", "bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n", "bnvar_inv = (bnvar + 1e-5)**-0.5\n", "bnraw = bndiff * bnvar_inv\n", "hpreact = bngain * bnraw + bnbias\n", "# Non-linearity\n", "h = torch.tanh(hpreact) # hidden layer\n", "# Linear layer 2\n", "logits = h @ W2 + b2 # output layer\n", "# cross entropy loss (same as F.cross_entropy(logits, Yb))\n", "logit_maxes = logits.max(1, keepdim=True).values\n", "norm_logits = logits - logit_maxes # subtract max for numerical stability\n", "counts = norm_logits.exp()\n", "counts_sum = counts.sum(1, keepdims=True)\n", "counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n", "probs = counts * counts_sum_inv\n", "logprobs = probs.log()\n", "loss = -logprobs[range(n), Yb].mean()\n", "\n", "# PyTorch backward pass\n", "for p in parameters:\n", "  p.grad = None\n", "for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n", "          norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n", "          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n", "          embcat, emb]:\n", "  t.retain_grad()\n", "loss.backward()\n", "loss" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# The entire Exercise 1 implementation combined: manually backprop through each intermediate tensor of the forward pass above\n", "\n", "dlogprobs = torch.zeros_like(logprobs)\n", "dlogprobs[range(n), Yb] = -1.0/n\n", "dprobs = (1.0 / probs) * dlogprobs\n", "dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)\n", "dcounts = counts_sum_inv * dprobs\n", "dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv\n", "dcounts += torch.ones_like(counts) * dcounts_sum\n", "dnorm_logits = counts * dcounts\n", "dlogits = dnorm_logits.clone()\n", "dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)\n", "dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes\n", "dh = dlogits @ W2.T\n", "dW2 = h.T @ dlogits\n", "db2 = dlogits.sum(0)\n", "dhpreact = (1.0 - h**2) * dh\n", "dbngain = (bnraw * dhpreact).sum(0, keepdim=True)\n", "dbnraw = bngain * dhpreact\n", "dbnbias = dhpreact.sum(0, keepdim=True)\n", "dbndiff = bnvar_inv * dbnraw\n", "dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n", "dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n", "dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n", "dbndiff += (2*bndiff) * dbndiff2\n", "dhprebn = dbndiff.clone()\n", "dbnmeani = (-dbndiff).sum(0)\n", "dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n", "dembcat = dhprebn @ W1.T\n", "dW1 = embcat.T @ dhprebn\n", "db1 = dhprebn.sum(0)\n", "demb = dembcat.view(emb.shape)\n", "dC = torch.zeros_like(C)\n", "for k in range(Xb.shape[0]):\n", "  for j in range(Xb.shape[1]):\n", "    ix = Xb[k,j]\n", "    dC[ix] += demb[k,j]" ] }
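, { "cell_type": "markdown", "metadata": {}, "source": [ "A sanity check on Exercise 1 (a sketch, not one of the original cells): the `cmp` helper defined near the top can compare every manual gradient above against the ones PyTorch filled in during `loss.backward()`. If the derivations are right, each line should report `approximate: True`; a few may show `exact: False` purely from floating-point associativity." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sketch: compare the manual Exercise 1 gradients to PyTorch's autograd using the cmp utility\n", "cmp('logprobs', dlogprobs, logprobs)\n", "cmp('probs', dprobs, probs)\n", "cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n", "cmp('counts_sum', dcounts_sum, counts_sum)\n", "cmp('counts', dcounts, counts)\n", "cmp('norm_logits', dnorm_logits, norm_logits)\n", "cmp('logit_maxes', dlogit_maxes, logit_maxes)\n", "cmp('logits', dlogits, logits)\n", "cmp('h', dh, h)\n", "cmp('W2', dW2, W2)\n", "cmp('b2', db2, b2)\n", "cmp('hpreact', dhpreact, hpreact)\n", "cmp('bngain', dbngain, bngain)\n", "cmp('bnbias', dbnbias, bnbias)\n", "cmp('bnraw', dbnraw, bnraw)\n", "cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n", "cmp('bnvar', dbnvar, bnvar)\n", "cmp('bndiff2', dbndiff2, bndiff2)\n", "cmp('bndiff', dbndiff, bndiff)\n", "cmp('bnmeani', dbnmeani, bnmeani)\n", "cmp('hprebn', dhprebn, hprebn)\n", "cmp('embcat', dembcat, embcat)\n", "cmp('W1', dW1, W1)\n", "cmp('b1', db1, b1)\n", "cmp('emb', demb, emb)\n", "cmp('C', dC, C)" ] }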
, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar boilerplate code to the previous exercise, provided in the starter code ^\n", "\n", "------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:36:38](https://youtu.be/q8SA3rM6ckI?si=Lo5Ly5jApvwIBfy9&t=6516) to 1:48:35 - Pen-and-paper derivation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[1:48:36](https://youtu.be/q8SA3rM6ckI?si=Lo5Ly5jApvwIBfy9&t=6516) onwards - Implementation of the derivation in code" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "max diff: tensor(7.1526e-07, grad_fn=)\n" ] } ], "source": [ "# Exercise 3: backprop through batchnorm but all in one go\n", "# to complete this challenge look at the mathematical expression of the output of batchnorm,\n", "# take the derivative w.r.t. its input, simplify the expression, and just write it out\n", "# BatchNorm paper: https://arxiv.org/abs/1502.03167\n", "\n", "# forward pass\n", "\n", "# before:\n", "# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "# bndiff = hprebn - bnmeani\n", "# bndiff2 = bndiff**2\n", "# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n", "# bnvar_inv = (bnvar + 1e-5)**-0.5\n", "# bnraw = bndiff * bnvar_inv\n", "# hpreact = bngain * bnraw + bnbias\n", "\n", "# now:\n", "hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias\n", "print('max diff:', (hpreact_fast - hpreact).abs().max())" ] }
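, { "cell_type": "markdown", "metadata": {}, "source": [ "A condensed statement of the pen-and-paper result from the video (just the end result; the step-by-step derivation is in the linked segment). With $\\mu = \\frac{1}{n}\\sum_i x_i$, $\\sigma^2 = \\frac{1}{n-1}\\sum_i (x_i - \\mu)^2$ (Bessel-corrected, matching the forward pass), $\\hat{x}_i = (x_i - \\mu)(\\sigma^2 + \\epsilon)^{-1/2}$ and $y_i = \\gamma \\hat{x}_i + \\beta$, pushing the gradient through $\\mu$ and $\\sigma^2$ and collecting terms gives, per column:\n", "\n", "$$\\frac{\\partial L}{\\partial x_i} = \\frac{\\gamma}{n\\sqrt{\\sigma^2 + \\epsilon}} \\left( n \\frac{\\partial L}{\\partial y_i} - \\sum_j \\frac{\\partial L}{\\partial y_j} - \\frac{n}{n-1} \\hat{x}_i \\sum_j \\frac{\\partial L}{\\partial y_j} \\hat{x}_j \\right)$$\n", "\n", "which is exactly the single `dhprebn` line in the next cell ($\\gamma$ = `bngain`, $(\\sigma^2 + \\epsilon)^{-1/2}$ = `bnvar_inv`, $\\hat{x}$ = `bnraw`)." ] }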
, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hprebn | exact: False | approximate: True | maxdiff: 9.313225746154785e-10\n" ] } ], "source": [ "# backward pass\n", "\n", "# before we had:\n", "# dbnraw = bngain * dhpreact\n", "# dbndiff = bnvar_inv * dbnraw\n", "# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n", "# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n", "# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n", "# dbndiff += (2*bndiff) * dbndiff2\n", "# dhprebn = dbndiff.clone()\n", "# dbnmeani = (-dbndiff).sum(0)\n", "# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n", "\n", "# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)\n", "# (you'll also need to use some of the variables from the forward pass up above)\n", "\n", "# This is a direct implementation of what sensei derived in the video; as he notes there, the expression\n", "# hides a lot of intermediate simplification steps, but this single line is what we end up with.\n", "dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))\n", "\n", "cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be True; my maxdiff is ~9e-10" ] }
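, { "cell_type": "markdown", "metadata": {}, "source": [ "The cross-entropy part of the loss collapses the same way as the batchnorm above, via the well-known softmax-minus-one-hot shortcut. This is a sketch added as an extra check (it is not derived in the cells above); it assumes the minibatch variables `logits`, `Yb` and `n` from the forward-pass cell, and that `loss.backward()` has populated `logits.grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sketch: condensed cross-entropy backward, dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n\n", "# (the division by n comes from the .mean() over the batch in the loss)\n", "with torch.no_grad():\n", "  dlogits_fast = F.softmax(logits, 1)\n", "  dlogits_fast[range(n), Yb] -= 1\n", "  dlogits_fast /= n\n", "\n", "cmp('logits (fast)', dlogits_fast, logits) # expect approximate: True, like the batchnorm shortcut" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }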