# original vision transformer from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
# https://einops.rocks/pytorch-examples.html
from einops import rearrange

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, sequence length, dim)
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # split each of q, k, v into h heads of size d = dim_head: (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        # for each batch and each head, multiply each query position (i) with each key position (j),
        # summing over the embedding dimension (d), to get the attention scores (b, h, i, j)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)
        # weighted sum of the value vectors: (b, h, n, d)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # concatenate the heads back into a single embedding dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class PositionalEncoding(nn.Module):
    # https://discuss.pytorch.org/t/positional-encoding/175953
    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # alternate sine and cosine waves whose frequency decreases geometrically across the embedding dimensions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # slice with x.size(1) (sequence length), not x.size(0) (batch size),
        # since x is (batch, sequence length, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x

class ViT(nn.Module):
    """
    input_dim: number of channels (features) per input position
    output_dim: number of outputs produced by the MLP head
    dim: last dimension of the token embedding after the linear projection nn.Linear(input_dim, dim)
    depth: number of Transformer blocks
    heads: number of heads in the Multi-head Attention layer
    mlp_dim: hidden dimension of the MLP (FeedForward) layer
    dim_head: dimension of each attention head
    dropout: dropout rate inside the Transformer
    emb_dropout: embedding dropout rate
    pool: 'mean' for mean pooling over positions, otherwise the first position is used as a cls token
    """
    # * to force keyword-only args
    def __init__(self, *, input_dim, output_dim, dim, depth, heads, mlp_dim, pool = 'mean', dim_head = 64, dropout, emb_dropout):
        super().__init__()
        self.project = nn.Linear(input_dim, dim)
        self.pos_encoder = PositionalEncoding(dim)
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, output_dim)
        )
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        # x: (batch, sequence length, input_dim)
        x = self.project(x)
        b, n, _ = x.shape

        x = self.pos_encoder(x)
        x = self.dropout(x)
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        return self.tanh(self.mlp_head(x))
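
# A minimal smoke test sketching how this module might be used. The hyperparameters below
# (dim = 64, depth = 2, heads = 4, ...) are illustrative assumptions, not values taken from the
# original repository: unlike the patch-based ViT it was adapted from, this variant projects raw
# per-position features with nn.Linear, so the input is a (batch, sequence, input_dim) tensor.
if __name__ == "__main__":
    # attention block alone: output shape matches input shape
    attention = Attention(dim = 64, heads = 4, dim_head = 32)
    print(attention(torch.randn(2, 10, 64)).shape)   # torch.Size([2, 10, 64])

    # full model: (batch, sequence, input_dim) -> (batch, output_dim), squashed to (-1, 1) by Tanh
    model = ViT(
        input_dim = 16,     # channels per input position (assumed)
        output_dim = 3,     # number of outputs (assumed)
        dim = 64,
        depth = 2,
        heads = 4,
        mlp_dim = 128,
        pool = 'mean',
        dim_head = 32,
        dropout = 0.1,
        emb_dropout = 0.1,
    )
    x = torch.randn(8, 100, 16)   # sequence length 100 must not exceed max_len = 500 in PositionalEncoding
    print(model(x).shape)         # torch.Size([8, 3])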