import torch
import torch.nn as nn
import math
# RMSNorm normalizes the input by its root-mean-square (no mean subtraction as in
# LayerNorm); eps guards against division by zero.
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):  # hidden_size: number of features in the input; eps: small constant for numerical stability
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable per-feature scale
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)  # mean of squares over the feature dimension
        return x * torch.rsqrt(mean_square + self.eps) * self.weight  # normalize, then apply the learnable scale
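
# Illustrative sketch (not part of the original file): a quick shape check for
# LlamaRMSNorm, assuming a hidden size of 16.
def _demo_rmsnorm():
    norm = LlamaRMSNorm(hidden_size=16)
    x = torch.randn(2, 8, 16)        # [batch, seq_len, hidden]
    out = norm(x)
    assert out.shape == x.shape      # normalization preserves the input shape
    return out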
# Rotary position embedding: encodes position by rotating pairs of features through
# fixed, position-dependent angles (the angles are not learned).
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):  # dim: feature dimension; base: frequency base; device: where to store the buffer
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))  # inverse frequency for each feature pair
        self.register_buffer("inv_freq", inv_freq)  # register the inverse frequencies as a buffer

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)  # positions 0..seq_len-1
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # outer product: one angle per (position, frequency) pair
        emb = torch.cat((freqs, freqs), dim=-1)  # duplicate so the table spans the full feature dimension
        return emb
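
# Illustrative sketch (assumed sizes): build the angle table for an 8-token sequence
# with dim=64; the result has shape (seq_len, dim).
def _demo_rotary():
    rope = LlamaRotaryEmbedding(dim=64)
    x = torch.randn(1, 8, 64)        # only the device of x is used by forward()
    emb = rope(x, seq_len=8)
    assert emb.shape == (8, 64)
    return emb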
class LlamaMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)  # gate branch: dim -> hidden_dim
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)    # up branch: dim -> hidden_dim
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)  # project back: hidden_dim -> dim
        self.act_fn = nn.SiLU()  # SiLU (swish) activation

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))  # activated gate branch
        hidden = self.up_proj(x)                # linear up branch
        return self.down_proj(gated * hidden)   # SwiGLU: multiply gate and up branches, then project down
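
# Illustrative sketch (assumed sizes): the gated MLP maps dim -> hidden_dim -> dim,
# so the output shape matches the input shape.
def _demo_mlp():
    mlp = LlamaMLP(dim=64, hidden_dim=256)
    x = torch.randn(2, 8, 64)
    out = mlp(x)
    assert out.shape == x.shape
    return out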
class LlamaAttention(nn.Module):
    def __init__(self, dim, num_heads=8, max_seq_len=2048):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # Lower-triangular causal mask, precomputed up to max_seq_len
        self.register_buffer("bias", torch.tril(torch.ones(max_seq_len, max_seq_len)).view(1, 1, max_seq_len, max_seq_len))

    def forward(self, x):
        batch_size, seq_len, dim = x.size()  # [batch_size, seq_len, dim]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # Split heads: [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with a causal mask
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.bias[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)
        # Combine heads back to [batch_size, seq_len, dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, dim)
        return self.o_proj(context)
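
# Illustrative sketch (assumed sizes): multi-head causal self-attention keeps the
# input shape; dim must be divisible by num_heads.
def _demo_attention():
    attn = LlamaAttention(dim=64, num_heads=8, max_seq_len=32)
    x = torch.randn(2, 16, 64)       # [batch, seq_len, dim], seq_len <= max_seq_len
    out = attn(x)
    assert out.shape == x.shape
    return out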
class LlamaDecoderLayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_heads, max_position_embeddings):
        super().__init__()
        self.self_attn = LlamaAttention(dim, num_heads, max_position_embeddings)
        self.mlp = LlamaMLP(dim, hidden_dim)
        self.input_layernorm = LlamaRMSNorm(dim)
        self.post_attention_layernorm = LlamaRMSNorm(dim)

    def forward(self, x):
        # Pre-norm attention block with residual connection
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + residual
        # Pre-norm MLP block with residual connection
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + residual
        return x
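
# Illustrative sketch (assumed sizes): one pre-norm decoder block (attention + MLP,
# each with a residual connection) also preserves the input shape.
def _demo_decoder_layer():
    layer = LlamaDecoderLayer(dim=64, hidden_dim=256, num_heads=8, max_position_embeddings=32)
    x = torch.randn(2, 16, 64)
    out = layer(x)
    assert out.shape == x.shape
    return out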
class LlamaModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads, max_position_embeddings):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.embed_tokens = nn.Embedding(vocab_size, dim)  # token embeddings
        self.wpe = nn.Embedding(max_position_embeddings, dim)  # learned absolute position embeddings
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(dim, hidden_dim, num_heads, max_position_embeddings) for _ in range(num_layers)
        ])
        self.norm = LlamaRMSNorm(dim)
        # Rotary embedding table; not applied in this simplified model, which uses
        # learned absolute position embeddings instead.
        self.rotary_emb = LlamaRotaryEmbedding(dim)

    def forward(self, tokens):
        B, T = tokens.size()
        assert T <= self.max_position_embeddings, f"Cannot forward sequence of length {T}, block size is only {self.max_position_embeddings}"
        pos = torch.arange(0, T, dtype=torch.long, device=tokens.device)  # positions, shape (T,)
        pos_emb = self.wpe(pos)               # position embeddings, shape (T, dim)
        tok_emb = self.embed_tokens(tokens)   # token embeddings, shape (B, T, dim)
        x = tok_emb + pos_emb
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
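
# Illustrative sketch (assumed sizes): the full stack maps token ids of shape (B, T)
# to hidden states of shape (B, T, dim).
def _demo_model():
    model = LlamaModel(vocab_size=100, dim=64, num_layers=2, hidden_dim=256,
                       num_heads=8, max_position_embeddings=32)
    tokens = torch.randint(0, 100, (2, 16))
    hidden = model(tokens)
    assert hidden.shape == (2, 16, 64)
    return hidden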
class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        vocab_size = config.vocab_size
        dim = config.hidden_size
        num_layers = config.num_layers
        hidden_dim = config.intermediate_size
        num_heads = config.num_attention_heads
        max_position_embeddings = config.max_position_embeddings
        self.model = LlamaModel(vocab_size, dim, num_layers, hidden_dim, num_heads, max_position_embeddings)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)  # project hidden states to vocabulary logits

    def forward(self, x):
        x = self.model(x)
        return self.lm_head(x)
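
# Illustrative end-to-end sketch (not part of the original file): the config object
# below is a stand-in built with SimpleNamespace; any object exposing the same
# attribute names would work.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=100,
        hidden_size=64,
        num_layers=2,
        intermediate_size=256,
        num_attention_heads=8,
        max_position_embeddings=32,
    )
    model = LlamaForCausalLM(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch of token ids
    logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 16, 100]) -> per-token vocabulary logits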