PyTorch · English · nanogpt · custom_code · Eval Results
burtenshaw committed · Commit 788c379 · verified · 1 Parent(s): da16226

Upload folder using huggingface_hub

__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .configuration_nanogpt import NanoGPTConfig
+ from .modeling_nanogpt import NanoGPTModel, NanoGPTChat
+ from .tokenizer_nanogpt import NanoGPTTokenizer, NanoGPTChatTokenizer
+
+
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "model_type": "nanogpt",
+   "architectures": [
+     "NanoGPTChat"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_nanogpt.NanoGPTConfig",
+     "AutoModel": "modeling_nanogpt.NanoGPTChat",
+     "AutoModelForCausalLM": "modeling_nanogpt.NanoGPTChat",
+     "AutoTokenizer": "tokenizer_nanogpt.NanoGPTChatTokenizer"
+   },
+   "bos_token": "<|bos|>",
+   "eos_token": "<|assistant_end|>",
+   "pad_token": "<|assistant_end|>",
+   "sequence_len": 2048,
+   "vocab_size": 65536,
+   "n_layer": 20,
+   "n_head": 10,
+   "n_kv_head": 10,
+   "n_embd": 1280,
+   "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}"
+ }
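
Note: the auto_map above routes AutoConfig, AutoModelForCausalLM and AutoTokenizer to the custom classes shipped in this commit, so loading requires trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder and is not part of this commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "<namespace>/<nanogpt-repo>"  # placeholder: substitute the actual Hub repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)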
configuration_nanogpt.py ADDED
@@ -0,0 +1,34 @@
+ from transformers import PretrainedConfig
+
+
+ class NanoGPTConfig(PretrainedConfig):
+     model_type = "nanogpt"
+
+     def __init__(
+         self,
+         sequence_len: int = 1024,
+         vocab_size: int = 50304,
+         n_layer: int = 12,
+         n_head: int = 6,
+         n_kv_head: int = 6,
+         n_embd: int = 768,
+         bos_token_id: int = 0,
+         eos_token_id: int = 1,
+         pad_token_id: int = 1,
+         **kwargs,
+     ):
+         self.sequence_len = sequence_len
+         self.vocab_size = vocab_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_kv_head = n_kv_head
+         self.n_embd = n_embd
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             **kwargs,
+         )
+
+
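
For reference, a small sketch of how this config is instantiated for the d20 checkpoint in this commit; the keyword values are copied from config.json above (the class defaults correspond to a smaller model), everything else is assumed:

    from configuration_nanogpt import NanoGPTConfig

    config = NanoGPTConfig(
        sequence_len=2048,
        vocab_size=65536,
        n_layer=20,
        n_head=10,
        n_kv_head=10,
        n_embd=1280,
    )
    assert config.n_embd % config.n_head == 0  # per-head dimension is 128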
d20/meta_000060.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000120.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000180.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000240.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000300.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000360.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000420.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/meta_000466.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280
+   }
+ }
d20/model_000060.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba1199debccab5e267bceaf87e4dfc0ecc479ae920f8867b4700bbbd52200bd6
+ size 2076230219
d20/model_000120.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:809d694a727c414d173dca22ee02333b2eb2fee522fe0d1dabec21518224e2cc
+ size 2076230219
d20/model_000180.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:703f5049b2e3804a3e2cc55a3af444820ada278a42771d4ef5e679da80fa8a88
+ size 2076230219
d20/model_000240.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e58a748174610eaa914aaea9372a47f5c501af757bd7c306d28ac795539d7a68
+ size 2076230219
d20/model_000300.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a03c2d2143f33c412e502a75dad2e505da01d04aa61ee5d4bab2c2c6a99669d
+ size 2076230219
d20/model_000360.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0dd3c35b34043ef1571ef81ca644d41e3e4f8aa722fffe47f4dd6a9eee9a5684
+ size 2076230219
d20/model_000420.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03a9c4557759a64c5ce1322107b3d57a295092d6ee62f994d1aec61fcb08d4e7
+ size 2076230219
d20/model_000466.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a7abb8f892b7aa004f3a54ac54988871e12d099996c371806708f5e9a0bea3c
+ size 2076230219
meta_000650.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "step": 650,
+   "val_loss": 1.0664211511611938,
+   "mmlu_acc": 0.3623046875,
+   "arc_easy_acc": 0.419921875,
+   "gsm8k_acc": 0.03125,
+   "humaneval_acc": 0.03125,
+   "model_config": {
+     "sequence_len": 2048,
+     "vocab_size": 65536,
+     "n_layer": 20,
+     "n_head": 10,
+     "n_kv_head": 10,
+     "n_embd": 1280,
+     "bos_token": "<|bos|>",
+     "eos_token": "<|assistant_end|>",
+     "pad_token": "<|assistant_end|>",
+     "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}"
+   }
+ }
model_000650.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff2eee182e2aa396615d3b481ddf17884a7dbabb3caa47f66eede343135accff
+ size 2076230219
modeling_nanogpt.py ADDED
@@ -0,0 +1,386 @@
+ import glob
+ import math
+ import os
+ import shutil
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import snapshot_download
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_nanogpt import NanoGPTConfig
+
+
+ def _rms_norm(x: torch.Tensor) -> torch.Tensor:
+     return F.rms_norm(x, (x.size(-1),))
+
+
+ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     assert x.ndim == 4
+     d = x.shape[3] // 2
+     x1, x2 = x[..., :d], x[..., d:]
+     y1 = x1 * cos + x2 * sin
+     y2 = x1 * (-sin) + x2 * cos
+     out = torch.cat([y1, y2], 3)
+     return out.to(x.dtype)
+
+
+ def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     if n_rep == 1:
+         return x
+     bs, n_kv_heads, slen, head_dim = x.shape
+     return (
+         x[:, :, None, :, :]
+         .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+         .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+     )
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config: NanoGPTConfig, layer_idx: int):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+         self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+         self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+     def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+         B, T, C = x.size()
+         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+         cos, sin = cos_sin
+         q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
+         q, k = _rms_norm(q), _rms_norm(k)
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+         Tq = q.size(2)
+         Tk = k.size(2)
+         nrep = self.n_head // self.n_kv_head
+         k, v = _repeat_kv(k, nrep), _repeat_kv(v, nrep)
+         if Tq == Tk:
+             y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         elif Tq == 1:
+             y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+         else:
+             attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
+             prefix_len = Tk - Tq
+             if prefix_len > 0:
+                 attn_mask[:, :prefix_len] = True
+             attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
+             y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+         y = y.transpose(1, 2).contiguous().view(B, T, -1)
+         y = self.c_proj(y)
+         return y
+
+     def forward_with_cache(
+         self,
+         x: torch.Tensor,
+         cos_sin,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         B, T, _ = x.size()
+         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+         cos, sin = cos_sin
+         q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
+         q, k = _rms_norm(q), _rms_norm(k)
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+         if past_key_value is not None:
+             past_k, past_v = past_key_value
+             if past_k is not None and past_v is not None:
+                 k = torch.cat([past_k, k], dim=2)
+                 v = torch.cat([past_v, v], dim=2)
+
+         present = (k, v) if use_cache else None
+
+         Tq = q.size(2)
+         Tk = k.size(2)
+         nrep = self.n_head // self.n_kv_head
+         k_rep = _repeat_kv(k, nrep)
+         v_rep = _repeat_kv(v, nrep)
+
+         attn_mask = None
+         if attention_mask is not None:
+             attn_mask = attention_mask.to(dtype=torch.bool, device=q.device)
+             if attn_mask.dim() == 2:
+                 attn_mask = attn_mask[:, None, None, :]
+             elif attn_mask.dim() == 4:
+                 pass
+             else:
+                 raise ValueError("Unsupported attention_mask dimensions")
+             if attn_mask.size(-1) != Tk:
+                 attn_mask = torch.nn.functional.pad(attn_mask, (Tk - attn_mask.size(-1), 0))
+             attn_mask = (~attn_mask).to(dtype=q.dtype) * -1e4
+
+         if Tq == Tk:
+             y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=attn_mask, is_causal=True)
+         else:
+             y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=attn_mask, is_causal=False)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, -1)
+         y = self.c_proj(y)
+         return y, present
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: NanoGPTConfig):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.c_fc(x)
+         x = F.relu(x).square()
+         x = self.c_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self, config: NanoGPTConfig, layer_idx: int):
+         super().__init__()
+         self.attn = CausalSelfAttention(config, layer_idx)
+         self.mlp = MLP(config)
+
+     def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+         x = x + self.attn(_rms_norm(x), cos_sin, kv_cache)
+         x = x + self.mlp(_rms_norm(x))
+         return x
+
+
+ class NanoGPTModel(PreTrainedModel):
+     config_class = NanoGPTConfig
+
+     _CANONICAL_WEIGHT_NAMES = (
+         "pytorch_model.bin",
+         "model.safetensors",
+         "model.ckpt.index",
+         "tf_model.h5",
+         "flax_model.msgpack",
+     )
+     _PT_PATTERN = "model_*.pt"
+
+     @classmethod
+     def _snapshot_kwargs(cls, source_kwargs: Dict) -> Dict:
+         keys = {
+             "cache_dir",
+             "force_download",
+             "local_files_only",
+             "proxies",
+             "resume_download",
+             "revision",
+             "token",
+             "use_auth_token",
+         }
+         return {k: source_kwargs[k] for k in keys if k in source_kwargs}
+
+     @classmethod
+     def _resolve_checkpoint_dir(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
+         if os.path.isdir(pretrained_model_name_or_path):
+             base_dir = pretrained_model_name_or_path
+         else:
+             snapshot_params = cls._snapshot_kwargs(kwargs)
+             token = snapshot_params.pop("token", None)
+             if token is None:
+                 token = snapshot_params.pop("use_auth_token", None)
+             if token is not None:
+                 snapshot_params["token"] = token
+             base_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_params)
+         if subfolder:
+             base_dir = os.path.join(base_dir, subfolder)
+         cls._ensure_canonical_weights(base_dir)
+         return base_dir
+
+     @classmethod
+     def _ensure_canonical_weights(cls, checkpoint_dir):
+         for name in cls._CANONICAL_WEIGHT_NAMES:
+             candidate = os.path.join(checkpoint_dir, name)
+             if os.path.isfile(candidate):
+                 return candidate
+         pt_candidates = sorted(
+             glob.glob(os.path.join(checkpoint_dir, cls._PT_PATTERN)),
+             reverse=True,
+         )
+         if not pt_candidates:
+             raise FileNotFoundError(
+                 f"No checkpoint weights found in {checkpoint_dir}. Expected one of {cls._CANONICAL_WEIGHT_NAMES} "
+                 f"or files matching {cls._PT_PATTERN}."
+             )
+         source_path = pt_candidates[0]
+         target_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+         if (
+             not os.path.isfile(target_path)
+             or os.path.getmtime(source_path) > os.path.getmtime(target_path)
+         ):
+             shutil.copyfile(source_path, target_path)
+         return target_path
+
+     def __init__(self, config: NanoGPTConfig):
+         super().__init__(config)
+         config.use_cache = getattr(config, "use_cache", True)
+         config.num_hidden_layers = config.n_layer
+         config.num_attention_heads = config.n_head
+         config.hidden_size = config.n_embd
+         self.transformer = nn.ModuleDict({
+             "wte": nn.Embedding(config.vocab_size, config.n_embd),
+             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
+         })
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.rotary_seq_len = config.sequence_len * 10
+         head_dim = config.n_embd // config.n_head
+         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+         self.register_buffer("cos", cos, persistent=False)
+         self.register_buffer("sin", sin, persistent=False)
+         # keep the token embedding in bfloat16; activations are upcast to fp32 in forward
+         self.transformer.wte.to(dtype=torch.bfloat16)
+
+         # following HF API expectations
+         self.post_init()
+
+     def _init_weights(self, module: nn.Module):
+         if isinstance(module, nn.Linear):
+             fan_out = module.weight.size(0)
+             fan_in = module.weight.size(1)
+             std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+     def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
+         if device is None:
+             device = self.transformer.wte.weight.device
+         # Handle meta device case - use CPU as fallback
+         if device.type == 'meta':
+             device = torch.device('cpu')
+         channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+         inv_freq = 1.0 / (base ** (channel_range / head_dim))
+         t = torch.arange(seq_len, dtype=torch.float32, device=device)
+         freqs = torch.outer(t, inv_freq)
+         cos, sin = freqs.cos(), freqs.sin()
+         cos, sin = cos.bfloat16(), sin.bfloat16()
+         cos, sin = cos[None, :, None, :], sin[None, :, None, :]
+         return cos, sin
+
+     def _apply_softcap(self, logits: torch.Tensor) -> torch.Tensor:
+         softcap = 15
+         return softcap * torch.tanh(logits / softcap)
+
+     def _forward_impl(self, idx: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+         x = self.transformer.wte(idx)
+         x = x.float()
+         x = _rms_norm(x)
+         for block in self.transformer.h:
+             x = block(x, cos_sin, kv_cache)
+         x = _rms_norm(x)
+         logits = self.lm_head(x)
+         return self._apply_softcap(logits)
+
+     def forward(self, input_ids: torch.Tensor, labels=None, loss_reduction: str = 'mean', **kwargs):
+         idx = input_ids
+         B, T = idx.size()
+         T0 = 0
+         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
+         logits = self._forward_impl(idx, cos_sin, kv_cache=None)
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 labels.view(-1),
+                 ignore_index=-1,
+                 reduction=loss_reduction,
+             )
+         return {"loss": loss, "logits": logits}
+
+
+ class NanoGPTChat(NanoGPTModel):
+     """Chat-optimized variant with HF-friendly generate and support for KV cache."""
+
+     def __init__(self, config: NanoGPTConfig):
+         super().__init__(config)
+         self.use_cache = getattr(config, "use_cache", True)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         if past_key_values is not None:
+             input_ids = input_ids[:, -1:]
+         return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
+
+     def _expand_past_length(self, past_key_values):
+         if not past_key_values:
+             return 0
+         past_k, _ = past_key_values[0]
+         if past_k is None:
+             return 0
+         return past_k.size(2)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         use_cache: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+         loss_reduction: str = "mean",
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         idx = input_ids
+         B, T = idx.size()
+         use_cache = self.use_cache if use_cache is None else use_cache
+         past_length = self._expand_past_length(past_key_values)
+         T0 = past_length
+         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
+
+         x = self.transformer.wte(idx)
+         x = x.float()
+         x = _rms_norm(x)
+
+         presents = [] if use_cache else None
+         for layer_idx, block in enumerate(self.transformer.h):
+             past = None
+             if past_key_values is not None and past_key_values[layer_idx] is not None:
+                 past = past_key_values[layer_idx]
+             attn_output, present = block.attn.forward_with_cache(
+                 _rms_norm(x),
+                 cos_sin,
+                 past_key_value=past,
+                 attention_mask=attention_mask,
+                 use_cache=use_cache,
+             )
+             x = x + attn_output
+             x = x + block.mlp(_rms_norm(x))
+             if use_cache:
+                 presents.append(present)
+
+         x = _rms_norm(x)
+         logits = self.lm_head(x)
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 labels.view(-1),
+                 ignore_index=-1,
+                 reduction=loss_reduction,
+             )
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=presents,
+         )
+
+
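
A hedged generation sketch against the classes above: NanoGPTChat returns CausalLMOutputWithPast and implements prepare_inputs_for_generation, so the stock generate loop should be able to drive it. The repo id and sampling settings are placeholders, not taken from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "<namespace>/<nanogpt-repo>"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    messages = [{"role": "user", "content": "Write a haiku about attention."}]
    # apply_chat_template(tokenize=True) returns a plain list of token ids (see tokenizer_nanogpt.py below)
    ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    input_ids = torch.tensor([ids])

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.8,
            eos_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(output_ids[0].tolist()[input_ids.shape[1]:]))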
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff2eee182e2aa396615d3b481ddf17884a7dbabb3caa47f66eede343135accff
+ size 2076230219
token_bytes.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1b6cdee5d02fe1018b2b1d2ae5b736be665f9c0e7d10c81dcf935e7efaf8cb5
+ size 263721
tokenizer.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8467414b90511a50c4dac438af25c075817e9d62d799a5ef613b186c977f5d1b
+ size 846518
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenizer_nanogpt.NanoGPTChatTokenizer",
+       null
+     ]
+   },
+   "tokenizer_class": "NanoGPTChatTokenizer",
+   "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant_start|>{% endif %}"
+ }
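
For reference, what the chat_template above renders for a system plus user turn with add_generation_prompt=True; this is derived by hand from the template string (the system message is folded into the first user turn), not from running the repo:

    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
    ]
    # rendered prompt (newlines shown escaped):
    # <|bos|><|user_start|>You are terse.\n\nHi<|user_end|><|assistant_start|>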
tokenizer_nanogpt.py ADDED
@@ -0,0 +1,362 @@
+ import os
+ import pickle
+ import shutil
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from huggingface_hub.utils import HfHubHTTPError
+ from transformers import PreTrainedTokenizer
+
+
+ class _BaseNanoGPTTokenizer:
+     """Lightweight wrapper used by the base (non-chat) checkpoints."""
+
+     special_tokens = {
+         "bos": "<|bos|>",
+         "user_start": "<|user_start|>",
+         "user_end": "<|user_end|>",
+         "assistant_start": "<|assistant_start|>",
+         "assistant_end": "<|assistant_end|>",
+         "python_start": "<|python_start|>",
+         "python_end": "<|python_end|>",
+         "output_start": "<|output_start|>",
+         "output_end": "<|output_end|>",
+     }
+
+     def __init__(self, enc):
+         self.enc = enc
+         self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"])
+
+     @classmethod
+     def register_for_auto_class(cls, auto_class="AutoTokenizer"):
+         pass
+
+     @classmethod
+     def _load_encoding(cls, pretrained_model_name_or_path, **kwargs):
+         subfolder = kwargs.get("subfolder")
+         base_path = (
+             os.path.join(pretrained_model_name_or_path, subfolder)
+             if subfolder
+             else pretrained_model_name_or_path
+         )
+         local_tok_path = os.path.join(base_path, "tokenizer.pkl")
+         if os.path.isfile(local_tok_path):
+             with open(local_tok_path, "rb") as f:
+                 return pickle.load(f)
+
+         snapshot_kwargs = {k: kwargs[k] for k in kwargs if k in {
+             "cache_dir",
+             "force_download",
+             "local_files_only",
+             "proxies",
+             "resume_download",
+             "revision",
+             "token",
+             "use_auth_token",
+         }}
+         token = snapshot_kwargs.pop("token", None)
+         if token is None:
+             token = snapshot_kwargs.pop("use_auth_token", None)
+         if token is not None:
+             snapshot_kwargs["token"] = token
+
+         snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs)
+         tok_path = os.path.join(snapshot_dir, subfolder, "tokenizer.pkl") if subfolder else os.path.join(snapshot_dir, "tokenizer.pkl")
+         if not os.path.isfile(tok_path):
+             try:
+                 tok_path = hf_hub_download(
+                     repo_id=pretrained_model_name_or_path,
+                     filename="tokenizer.pkl",
+                     subfolder=subfolder,
+                     **snapshot_kwargs,
+                 )
+             except (HfHubHTTPError, OSError) as e:
+                 raise ValueError(
+                     f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
+                     f"Make sure the path exists or the repo is accessible on the Hub."
+                 ) from e
+         with open(tok_path, "rb") as f:
+             return pickle.load(f)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+         enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs)
+         return cls(enc)
+
+     def encode(self, text, prepend=None):
+         ids = self.enc.encode_ordinary(text)
+         if prepend is not None:
+             prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
+             ids.insert(0, prepend_id)
+         return ids
+
+     def decode(self, ids):
+         return self.enc.decode(ids)
+
+     def get_bos_token_id(self):
+         return self.bos_token_id
+
+     def encode_special(self, token):
+         return self.enc.encode_single_token(token)
+
+
+ class NanoGPTTokenizer(_BaseNanoGPTTokenizer):
+     pass
+
+
+ class NanoGPTChatTokenizer(PreTrainedTokenizer):
+     """Transformers-compatible tokenizer with chat helpers."""
+
+     vocab_files_names = {"vocab_file": "tokenizer.pkl"}
+     model_input_names = ["input_ids"]
+
+     _special_tokens = {
+         "bos": "<|bos|>",
+         "user_start": "<|user_start|>",
+         "user_end": "<|user_end|>",
+         "assistant_start": "<|assistant_start|>",
+         "assistant_end": "<|assistant_end|>",
+         "python_start": "<|python_start|>",
+         "python_end": "<|python_end|>",
+         "output_start": "<|output_start|>",
+         "output_end": "<|output_end|>",
+     }
+
+     def __init__(
+         self,
+         vocab_file: str,
+         bos_token: str = "<|bos|>",
+         eos_token: str = "<|assistant_end|>",
+         pad_token: Optional[str] = None,
+         **kwargs,
+     ) -> None:
+         # Load encoding and build vocab mappings before parent init
+         with open(vocab_file, "rb") as f:
+             self.enc = pickle.load(f)
+         self.vocab_file = vocab_file
+
+         self.special_token_ids: Dict[str, int] = {
+             name: self.enc.encode_single_token(token)
+             for name, token in self._special_tokens.items()
+         }
+         self.bos_token_id = self.special_token_ids["bos"]
+         self.eos_token_id = self.special_token_ids["assistant_end"]
+         pad_token = pad_token or eos_token
+         self.pad_token_id = self.special_token_ids["assistant_end"]
+
+         self._build_vocabulary()
+
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             **kwargs,
+         )
+
+         additional_special_tokens = [
+             token
+             for key, token in self._special_tokens.items()
+             if token not in {bos_token, eos_token, pad_token}
+         ]
+         if additional_special_tokens:
+             self.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+         self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None))
+
+     # ------------------------------------------------------------------
+     # Core tokenizer API
+     # ------------------------------------------------------------------
+     def _build_vocabulary(self) -> None:
+         id_to_token: Dict[int, str] = {}
+         token_to_id: Dict[str, int] = {}
+         for idx in range(self.enc.n_vocab):
+             token_bytes = self.enc.decode_single_token_bytes(idx)
+             token_str = token_bytes.decode("utf-8", errors="replace")
+             id_to_token[idx] = token_str
+             token_to_id[token_str] = idx
+         self._id_to_token = id_to_token
+         self._token_to_id = token_to_id
+
+     def get_vocab(self) -> Dict[str, int]:
+         return dict(self._token_to_id)
+
+     @property
+     def vocab_size(self) -> int:  # type: ignore[override]
+         return self.enc.n_vocab
+
+     def _tokenize(self, text: str, **kwargs) -> List[str]:
+         ids = self.enc.encode_ordinary(text)
+         return [self._id_to_token[i] for i in ids]
+
+     def _convert_token_to_id(self, token: str) -> int:
+         if token in self._token_to_id:
+             return self._token_to_id[token]
+         raise KeyError(f"Token not found in vocabulary: {token}")
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._id_to_token[index]
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
+         ids = [self._token_to_id[token] for token in tokens]
+         return self.enc.decode(ids)
+
+     def build_inputs_with_special_tokens(  # type: ignore[override]
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+     ) -> List[int]:
+         if token_ids_1 is not None:
+             return token_ids_0 + token_ids_1
+         return token_ids_0
+
+     def get_special_tokens_mask(  # type: ignore[override]
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+     ) -> List[int]:
+         all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
+         # compare against the special-token ids, not the name keys of the mapping
+         special_ids = set(self.special_token_ids.values())
+         return [1 if token_id in special_ids else 0 for token_id in all_ids]
+
+     def num_special_tokens_to_add(self, pair: bool = False) -> int:  # type: ignore[override]
+         return 0
+
+     def save_vocabulary(
+         self,
+         save_directory: str,
+         filename_prefix: Optional[str] = None,
+     ) -> Tuple[str]:  # type: ignore[override]
+         os.makedirs(save_directory, exist_ok=True)
+         filename = "tokenizer.pkl"
+         if filename_prefix is not None:
+             filename = f"{filename_prefix}-{filename}"
+         save_path = os.path.join(save_directory, filename)
+         shutil.copyfile(self.vocab_file, save_path)
+         return (save_path,)
+
+     # ------------------------------------------------------------------
+     # Chat helpers
+     # ------------------------------------------------------------------
+     def encode_special(self, token: str) -> int:
+         if token in self.special_token_ids:
+             return self.special_token_ids[token]
+         return self._token_to_id[token]
+
+     def _encode_text(self, text: str) -> List[int]:
+         return self.enc.encode_ordinary(text)
+
+     def _encode_python_block(self, token_id: int, content: str) -> List[int]:
+         tokens = [token_id]
+         tokens.extend(self._encode_text(content))
+         closing = {
+             self.special_token_ids["python_start"]: self.special_token_ids["python_end"],
+             self.special_token_ids["output_start"]: self.special_token_ids["output_end"],
+         }[token_id]
+         tokens.append(closing)
+         return tokens
+
+     def _encode_assistant_content(self, content) -> List[int]:
+         if isinstance(content, str):
+             return self._encode_text(content)
+         if isinstance(content, list):
+             tokens: List[int] = []
+             for part in content:
+                 part_type = part.get("type", "text")
+                 text = part.get("text", "")
+                 if part_type == "text":
+                     tokens.extend(self._encode_text(text))
+                 elif part_type == "python":
+                     tokens.extend(
+                         self._encode_python_block(
+                             self.special_token_ids["python_start"],
+                             text,
+                         )
+                     )
+                 elif part_type == "python_output":
+                     tokens.extend(
+                         self._encode_python_block(
+                             self.special_token_ids["output_start"],
+                             text,
+                         )
+                     )
+                 else:
+                     raise ValueError(f"Unknown assistant content part: {part_type}")
+             return tokens
+         raise ValueError(f"Unsupported assistant content type: {type(content)}")
+
+     def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]:
+         if not conversation:
+             raise ValueError("Conversation must contain at least one message")
+         messages = list(conversation)
+         if messages[0]["role"] == "system":
+             if len(messages) < 2 or messages[1]["role"] != "user":
+                 raise ValueError("System message must be followed by a user message")
+             merged = dict(messages[1])
+             merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}"
+             messages = [merged] + messages[2:]
+         ids: List[int] = [self.bos_token_id]
+         for idx, message in enumerate(messages):
+             expected_role = "user" if idx % 2 == 0 else "assistant"
+             role = message.get("role")
+             if role != expected_role:
+                 raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}")
+             content = message.get("content")
+             if expected_role == "user":
+                 start = self.special_token_ids["user_start"]
+                 end = self.special_token_ids["user_end"]
+                 if not isinstance(content, str):
+                     raise ValueError("User messages must contain string content")
+                 ids.append(start)
+                 ids.extend(self._encode_text(content))
+                 ids.append(end)
+             else:
+                 start = self.special_token_ids["assistant_start"]
+                 end = self.special_token_ids["assistant_end"]
+                 ids.append(start)
+                 ids.extend(self._encode_assistant_content(content))
+                 ids.append(end)
+         return ids
+
+     def apply_chat_template(  # type: ignore[override]
+         self,
+         conversation,
+         tokenize: bool = False,
+         add_generation_prompt: bool = False,
+         return_tensors: Optional[str] = None,
+         padding: bool = False,
+         truncation: bool = False,
+         max_length: Optional[int] = None,
+         **kwargs,
+     ):
+         if isinstance(conversation, dict) and "messages" in conversation:
+             messages = conversation["messages"]
+         else:
+             messages = conversation
+         token_ids = self._render_conversation_ids(messages)
+         if add_generation_prompt:
+             token_ids.append(self.special_token_ids["assistant_start"])
+         if tokenize:
+             if return_tensors is not None:
+                 return self(
+                     [token_ids],
+                     add_special_tokens=False,
+                     return_tensors=return_tensors,
+                     padding=padding,
+                     truncation=truncation,
+                     max_length=max_length,
+                     **kwargs,
+                 )
+             return token_ids
+         return self.decode(token_ids, skip_special_tokens=False)
+
+     def encode_chat_message(self, role: str, content: str) -> List[int]:
+         rendered = self.apply_chat_template(
+             [
+                 {"role": role, "content": content},
+             ],
+             tokenize=True,
+             add_generation_prompt=False,
+         )
+         return rendered
+
+