Upload folder using huggingface_hub
- __init__.py +5 -0
- config.json +22 -0
- configuration_nanogpt.py +34 -0
- d20/meta_000060.json +10 -0
- d20/meta_000120.json +10 -0
- d20/meta_000180.json +10 -0
- d20/meta_000240.json +10 -0
- d20/meta_000300.json +10 -0
- d20/meta_000360.json +10 -0
- d20/meta_000420.json +10 -0
- d20/meta_000466.json +10 -0
- d20/model_000060.pt +3 -0
- d20/model_000120.pt +3 -0
- d20/model_000180.pt +3 -0
- d20/model_000240.pt +3 -0
- d20/model_000300.pt +3 -0
- d20/model_000360.pt +3 -0
- d20/model_000420.pt +3 -0
- d20/model_000466.pt +3 -0
- meta_000650.json +20 -0
- model_000650.pt +3 -0
- modeling_nanogpt.py +386 -0
- pytorch_model.bin +3 -0
- token_bytes.pt +3 -0
- tokenizer.pkl +3 -0
- tokenizer_config.json +10 -0
- tokenizer_nanogpt.py +362 -0
__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .configuration_nanogpt import NanoGPTConfig
+from .modeling_nanogpt import NanoGPTModel, NanoGPTChat
+from .tokenizer_nanogpt import NanoGPTTokenizer, NanoGPTChatTokenizer
+
+
config.json
ADDED
@@ -0,0 +1,22 @@
+{
+  "model_type": "nanogpt",
+  "architectures": [
+    "NanoGPTChat"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_nanogpt.NanoGPTConfig",
+    "AutoModel": "modeling_nanogpt.NanoGPTChat",
+    "AutoModelForCausalLM": "modeling_nanogpt.NanoGPTChat",
+    "AutoTokenizer": "tokenizer_nanogpt.NanoGPTChatTokenizer"
+  },
+  "bos_token": "<|bos|>",
+  "eos_token": "<|assistant_end|>",
+  "pad_token": "<|assistant_end|>",
+  "sequence_len": 2048,
+  "vocab_size": 65536,
+  "n_layer": 20,
+  "n_head": 10,
+  "n_kv_head": 10,
+  "n_embd": 1280,
+  "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}"
+}
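
The auto_map above registers the custom classes with the transformers Auto* loaders, so the repository can be used with trust_remote_code. A minimal loading sketch (the repo id is a placeholder, not part of this upload):

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "user/nanogpt-d20"  # hypothetical repo id; substitute the actual Hub path
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # NanoGPTConfig
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)     # NanoGPTChatTokenizer
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # NanoGPTChat
print(config.n_layer, config.n_head, config.n_embd)  # 20 10 1280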
configuration_nanogpt.py
ADDED
@@ -0,0 +1,34 @@
+from transformers import PretrainedConfig
+
+
+class NanoGPTConfig(PretrainedConfig):
+    model_type = "nanogpt"
+
+    def __init__(
+        self,
+        sequence_len: int = 1024,
+        vocab_size: int = 50304,
+        n_layer: int = 12,
+        n_head: int = 6,
+        n_kv_head: int = 6,
+        n_embd: int = 768,
+        bos_token_id: int = 0,
+        eos_token_id: int = 1,
+        pad_token_id: int = 1,
+        **kwargs,
+    ):
+        self.sequence_len = sequence_len
+        self.vocab_size = vocab_size
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_kv_head = n_kv_head
+        self.n_embd = n_embd
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+
+
+
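
The defaults above describe a smaller base configuration; config.json overrides them for the uploaded d20 checkpoint. A short sketch constructing the same configuration programmatically (values taken from config.json):

from configuration_nanogpt import NanoGPTConfig  # drop the leading dot when importing outside the package

d20_config = NanoGPTConfig(
    sequence_len=2048,
    vocab_size=65536,
    n_layer=20,
    n_head=10,
    n_kv_head=10,
    n_embd=1280,
)
assert d20_config.model_type == "nanogpt"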
d20/meta_000060.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000120.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000180.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000240.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000300.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000360.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000420.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/meta_000466.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280
+  }
+}
d20/model_000060.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba1199debccab5e267bceaf87e4dfc0ecc479ae920f8867b4700bbbd52200bd6
+size 2076230219
d20/model_000120.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:809d694a727c414d173dca22ee02333b2eb2fee522fe0d1dabec21518224e2cc
+size 2076230219
d20/model_000180.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:703f5049b2e3804a3e2cc55a3af444820ada278a42771d4ef5e679da80fa8a88
+size 2076230219
d20/model_000240.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e58a748174610eaa914aaea9372a47f5c501af757bd7c306d28ac795539d7a68
+size 2076230219
d20/model_000300.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a03c2d2143f33c412e502a75dad2e505da01d04aa61ee5d4bab2c2c6a99669d
+size 2076230219
d20/model_000360.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dd3c35b34043ef1571ef81ca644d41e3e4f8aa722fffe47f4dd6a9eee9a5684
+size 2076230219
d20/model_000420.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03a9c4557759a64c5ce1322107b3d57a295092d6ee62f994d1aec61fcb08d4e7
+size 2076230219
d20/model_000466.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a7abb8f892b7aa004f3a54ac54988871e12d099996c371806708f5e9a0bea3c
+size 2076230219
meta_000650.json
ADDED
@@ -0,0 +1,20 @@
+{
+  "step": 650,
+  "val_loss": 1.0664211511611938,
+  "mmlu_acc": 0.3623046875,
+  "arc_easy_acc": 0.419921875,
+  "gsm8k_acc": 0.03125,
+  "humaneval_acc": 0.03125,
+  "model_config": {
+    "sequence_len": 2048,
+    "vocab_size": 65536,
+    "n_layer": 20,
+    "n_head": 10,
+    "n_kv_head": 10,
+    "n_embd": 1280,
+    "bos_token": "<|bos|>",
+    "eos_token": "<|assistant_end|>",
+    "pad_token": "<|assistant_end|>",
+    "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}"
+  }
+}
model_000650.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff2eee182e2aa396615d3b481ddf17884a7dbabb3caa47f66eede343135accff
+size 2076230219
modeling_nanogpt.py
ADDED
@@ -0,0 +1,386 @@
+import glob
+import math
+import os
+import shutil
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub import snapshot_download
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_nanogpt import NanoGPTConfig
+
+
+def _rms_norm(x: torch.Tensor) -> torch.Tensor:
+    return F.rms_norm(x, (x.size(-1),))
+
+
+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+    assert x.ndim == 4
+    d = x.shape[3] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    out = torch.cat([y1, y2], 3)
+    return out.to(x.dtype)
+
+
+def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    if n_rep == 1:
+        return x
+    bs, n_kv_heads, slen, head_dim = x.shape
+    return (
+        x[:, :, None, :, :]
+        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+    )
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config: NanoGPTConfig, layer_idx: int):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head
+        self.n_embd = config.n_embd
+        self.head_dim = self.n_embd // self.n_head
+        assert self.n_embd % self.n_head == 0
+        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+        B, T, C = x.size()
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        cos, sin = cos_sin
+        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
+        q, k = _rms_norm(q), _rms_norm(k)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+        Tq = q.size(2)
+        Tk = k.size(2)
+        nrep = self.n_head // self.n_kv_head
+        k, v = _repeat_kv(k, nrep), _repeat_kv(v, nrep)
+        if Tq == Tk:
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        elif Tq == 1:
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+        else:
+            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
+            prefix_len = Tk - Tq
+            if prefix_len > 0:
+                attn_mask[:, :prefix_len] = True
+            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        y = y.transpose(1, 2).contiguous().view(B, T, -1)
+        y = self.c_proj(y)
+        return y
+
+    def forward_with_cache(
+        self,
+        x: torch.Tensor,
+        cos_sin,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, _ = x.size()
+        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+        cos, sin = cos_sin
+        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
+        q, k = _rms_norm(q), _rms_norm(k)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        if past_key_value is not None:
+            past_k, past_v = past_key_value
+            if past_k is not None and past_v is not None:
+                k = torch.cat([past_k, k], dim=2)
+                v = torch.cat([past_v, v], dim=2)
+
+        present = (k, v) if use_cache else None
+
+        Tq = q.size(2)
+        Tk = k.size(2)
+        nrep = self.n_head // self.n_kv_head
+        k_rep = _repeat_kv(k, nrep)
+        v_rep = _repeat_kv(v, nrep)
+
+        attn_mask = None
+        if attention_mask is not None:
+            attn_mask = attention_mask.to(dtype=torch.bool, device=q.device)
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask[:, None, None, :]
+            elif attn_mask.dim() == 4:
+                pass
+            else:
+                raise ValueError("Unsupported attention_mask dimensions")
+            if attn_mask.size(-1) != Tk:
+                attn_mask = torch.nn.functional.pad(attn_mask, (Tk - attn_mask.size(-1), 0))
+            attn_mask = (~attn_mask).to(dtype=q.dtype) * -1e4
+
+        if Tq == Tk:
+            y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=attn_mask, is_causal=True)
+        else:
+            y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=attn_mask, is_causal=False)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, -1)
+        y = self.c_proj(y)
+        return y, present
+
+
+class MLP(nn.Module):
+    def __init__(self, config: NanoGPTConfig):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.c_fc(x)
+        x = F.relu(x).square()
+        x = self.c_proj(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, config: NanoGPTConfig, layer_idx: int):
+        super().__init__()
+        self.attn = CausalSelfAttention(config, layer_idx)
+        self.mlp = MLP(config)
+
+    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+        x = x + self.attn(_rms_norm(x), cos_sin, kv_cache)
+        x = x + self.mlp(_rms_norm(x))
+        return x
+
+
+class NanoGPTModel(PreTrainedModel):
+    config_class = NanoGPTConfig
+
+    _CANONICAL_WEIGHT_NAMES = (
+        "pytorch_model.bin",
+        "model.safetensors",
+        "model.ckpt.index",
+        "tf_model.h5",
+        "flax_model.msgpack",
+    )
+    _PT_PATTERN = "model_*.pt"
+
+    @classmethod
+    def _snapshot_kwargs(cls, source_kwargs: Dict) -> Dict:
+        keys = {
+            "cache_dir",
+            "force_download",
+            "local_files_only",
+            "proxies",
+            "resume_download",
+            "revision",
+            "token",
+            "use_auth_token",
+        }
+        return {k: source_kwargs[k] for k in keys if k in source_kwargs}
+
+    @classmethod
+    def _resolve_checkpoint_dir(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
+        if os.path.isdir(pretrained_model_name_or_path):
+            base_dir = pretrained_model_name_or_path
+        else:
+            snapshot_params = cls._snapshot_kwargs(kwargs)
+            token = snapshot_params.pop("token", None)
+            if token is None:
+                token = snapshot_params.pop("use_auth_token", None)
+            if token is not None:
+                snapshot_params["token"] = token
+            base_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_params)
+        if subfolder:
+            base_dir = os.path.join(base_dir, subfolder)
+        cls._ensure_canonical_weights(base_dir)
+        return base_dir
+
+    @classmethod
+    def _ensure_canonical_weights(cls, checkpoint_dir):
+        for name in cls._CANONICAL_WEIGHT_NAMES:
+            candidate = os.path.join(checkpoint_dir, name)
+            if os.path.isfile(candidate):
+                return candidate
+        pt_candidates = sorted(
+            glob.glob(os.path.join(checkpoint_dir, cls._PT_PATTERN)),
+            reverse=True,
+        )
+        if not pt_candidates:
+            raise FileNotFoundError(
+                f"No checkpoint weights found in {checkpoint_dir}. Expected one of {cls._CANONICAL_WEIGHT_NAMES} "
+                f"or files matching {cls._PT_PATTERN}."
+            )
+        source_path = pt_candidates[0]
+        target_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+        if (
+            not os.path.isfile(target_path)
+            or os.path.getmtime(source_path) > os.path.getmtime(target_path)
+        ):
+            shutil.copyfile(source_path, target_path)
+        return target_path
+
+    def __init__(self, config: NanoGPTConfig):
+        super().__init__(config)
+        config.use_cache = getattr(config, "use_cache", True)
+        config.num_hidden_layers = config.n_layer
+        config.num_attention_heads = config.n_head
+        config.hidden_size = config.n_embd
+        self.transformer = nn.ModuleDict({
+            "wte": nn.Embedding(config.vocab_size, config.n_embd),
+            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
+        })
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.rotary_seq_len = config.sequence_len * 10
+        head_dim = config.n_embd // config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.register_buffer("cos", cos, persistent=False)
+        self.register_buffer("sin", sin, persistent=False)
+        # keep the token embedding table in bfloat16; activations are cast back to fp32 in forward
+        self.transformer.wte.to(dtype=torch.bfloat16)
+
+        # following HF API expectations
+        self.post_init()
+
+    def _init_weights(self, module: nn.Module):
+        if isinstance(module, nn.Linear):
+            fan_out = module.weight.size(0)
+            fan_in = module.weight.size(1)
+            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
+        if device is None:
+            device = self.transformer.wte.weight.device
+        # Handle meta device case - use CPU as fallback
+        if device.type == 'meta':
+            device = torch.device('cpu')
+        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (base ** (channel_range / head_dim))
+        t = torch.arange(seq_len, dtype=torch.float32, device=device)
+        freqs = torch.outer(t, inv_freq)
+        cos, sin = freqs.cos(), freqs.sin()
+        cos, sin = cos.bfloat16(), sin.bfloat16()
+        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
+        return cos, sin
+
+    def _apply_softcap(self, logits: torch.Tensor) -> torch.Tensor:
+        softcap = 15
+        return softcap * torch.tanh(logits / softcap)
+
+    def _forward_impl(self, idx: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
+        x = self.transformer.wte(idx)
+        x = x.float()
+        x = _rms_norm(x)
+        for block in self.transformer.h:
+            x = block(x, cos_sin, kv_cache)
+        x = _rms_norm(x)
+        logits = self.lm_head(x)
+        return self._apply_softcap(logits)
+
+    def forward(self, input_ids: torch.Tensor, labels=None, loss_reduction: str = 'mean', **kwargs):
+        idx = input_ids
+        B, T = idx.size()
+        T0 = 0
+        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
+        logits = self._forward_impl(idx, cos_sin, kv_cache=None)
+        loss = None
+        if labels is not None:
+            loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)),
+                labels.view(-1),
+                ignore_index=-1,
+                reduction=loss_reduction,
+            )
+        return {"loss": loss, "logits": logits}
+
+
+class NanoGPTChat(NanoGPTModel):
+    """Chat-optimized variant with HF-friendly generate and support for KV cache."""
+
+    def __init__(self, config: NanoGPTConfig):
+        super().__init__(config)
+        self.use_cache = getattr(config, "use_cache", True)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
+
+    def _expand_past_length(self, past_key_values):
+        if not past_key_values:
+            return 0
+        past_k, _ = past_key_values[0]
+        if past_k is None:
+            return 0
+        return past_k.size(2)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+        loss_reduction: str = "mean",
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        idx = input_ids
+        B, T = idx.size()
+        use_cache = self.use_cache if use_cache is None else use_cache
+        past_length = self._expand_past_length(past_key_values)
+        T0 = past_length
+        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
+
+        x = self.transformer.wte(idx)
+        x = x.float()
+        x = _rms_norm(x)
+
+        presents = [] if use_cache else None
+        for layer_idx, block in enumerate(self.transformer.h):
+            past = None
+            if past_key_values is not None and past_key_values[layer_idx] is not None:
+                past = past_key_values[layer_idx]
+            attn_output, present = block.attn.forward_with_cache(
+                _rms_norm(x),
+                cos_sin,
+                past_key_value=past,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+            )
+            x = x + attn_output
+            x = x + block.mlp(_rms_norm(x))
+            if use_cache:
+                presents.append(present)
+
+        x = _rms_norm(x)
+        logits = self.lm_head(x)
+        loss = None
+        if labels is not None:
+            loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)),
+                labels.view(-1),
+                ignore_index=-1,
+                reduction=loss_reduction,
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=presents,
+        )
+
+
+
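
NanoGPTChat.forward returns a CausalLMOutputWithPast and prepare_inputs_for_generation trims the input to the last token once a cache exists, so incremental decoding can reuse the per-layer (k, v) cache. A minimal greedy-decoding sketch, assuming model and tokenizer were loaded as in the earlier example and prompt_ids is the list of ids produced by apply_chat_template(..., tokenize=True, add_generation_prompt=True):

import torch

# prompt_ids: list[int] assumed from tokenizer.apply_chat_template(..., tokenize=True, add_generation_prompt=True)
input_ids = torch.tensor([prompt_ids], dtype=torch.long)  # shape (1, prompt_len)
past_key_values = None
generated = []

with torch.no_grad():
    for _ in range(256):
        out = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        next_id = out.logits[:, -1, :].argmax(dim=-1)  # greedy pick from the last position
        past_key_values = out.past_key_values          # list of per-layer (k, v) tensors
        if next_id.item() == tokenizer.eos_token_id:   # <|assistant_end|>
            break
        generated.append(next_id.item())
        input_ids = next_id.unsqueeze(0)               # feed only the newly generated token

print(tokenizer.decode(generated))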
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff2eee182e2aa396615d3b481ddf17884a7dbabb3caa47f66eede343135accff
+size 2076230219
token_bytes.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1b6cdee5d02fe1018b2b1d2ae5b736be665f9c0e7d10c81dcf935e7efaf8cb5
+size 263721
tokenizer.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8467414b90511a50c4dac438af25c075817e9d62d799a5ef613b186c977f5d1b
+size 846518
tokenizer_config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenizer_nanogpt.NanoGPTChatTokenizer",
+      null
+    ]
+  },
+  "tokenizer_class": "NanoGPTChatTokenizer",
+  "chat_template": "{% if messages[0]['role'] == 'system' %}<|bos|><|user_start|>{{ messages[0]['content'] }}\n\n{{ messages[1]['content'] }}<|user_end|>{% set messages = messages[2:] %}{% else %}<|bos|>{% endif %}{% for message in messages %}{% if loop.index0 % 2 == 0 %}<|user_start|>{{ message['content'] }}<|user_end|>{% else %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant_start|>{% endif %}"
+}
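
This chat_template matches the one in config.json, plus an add_generation_prompt branch that opens an assistant turn. A quick way to inspect what it produces is to render it directly with jinja2 (transformers does the equivalent internally):

from jinja2 import Template

chat_template = "..."  # the "chat_template" string from this file
rendered = Template(chat_template).render(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    add_generation_prompt=True,
)
# rendered == "<|bos|><|user_start|>What is 2 + 2?<|user_end|><|assistant_start|>"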
tokenizer_nanogpt.py
ADDED
@@ -0,0 +1,362 @@
+import os
+import pickle
+import shutil
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+
+from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
+from transformers import PreTrainedTokenizer
+
+
+class _BaseNanoGPTTokenizer:
+    """Lightweight wrapper used by the base (non-chat) checkpoints."""
+
+    special_tokens = {
+        "bos": "<|bos|>",
+        "user_start": "<|user_start|>",
+        "user_end": "<|user_end|>",
+        "assistant_start": "<|assistant_start|>",
+        "assistant_end": "<|assistant_end|>",
+        "python_start": "<|python_start|>",
+        "python_end": "<|python_end|>",
+        "output_start": "<|output_start|>",
+        "output_end": "<|output_end|>",
+    }
+
+    def __init__(self, enc):
+        self.enc = enc
+        self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"])
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
+        pass
+
+    @classmethod
+    def _load_encoding(cls, pretrained_model_name_or_path, **kwargs):
+        subfolder = kwargs.get("subfolder")
+        base_path = (
+            os.path.join(pretrained_model_name_or_path, subfolder)
+            if subfolder
+            else pretrained_model_name_or_path
+        )
+        local_tok_path = os.path.join(base_path, "tokenizer.pkl")
+        if os.path.isfile(local_tok_path):
+            with open(local_tok_path, "rb") as f:
+                return pickle.load(f)
+
+        snapshot_kwargs = {k: kwargs[k] for k in kwargs if k in {
+            "cache_dir",
+            "force_download",
+            "local_files_only",
+            "proxies",
+            "resume_download",
+            "revision",
+            "token",
+            "use_auth_token",
+        }}
+        token = snapshot_kwargs.pop("token", None)
+        if token is None:
+            token = snapshot_kwargs.pop("use_auth_token", None)
+        if token is not None:
+            snapshot_kwargs["token"] = token
+
+        snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs)
+        tok_path = os.path.join(snapshot_dir, subfolder, "tokenizer.pkl") if subfolder else os.path.join(snapshot_dir, "tokenizer.pkl")
+        if not os.path.isfile(tok_path):
+            try:
+                tok_path = hf_hub_download(
+                    repo_id=pretrained_model_name_or_path,
+                    filename="tokenizer.pkl",
+                    subfolder=subfolder,
+                    **snapshot_kwargs,
+                )
+            except (HfHubHTTPError, OSError) as e:
+                raise ValueError(
+                    f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
+                    f"Make sure the path exists or the repo is accessible on the Hub."
+                ) from e
+        with open(tok_path, "rb") as f:
+            return pickle.load(f)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs)
+        return cls(enc)
+
+    def encode(self, text, prepend=None):
+        ids = self.enc.encode_ordinary(text)
+        if prepend is not None:
+            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
+            ids.insert(0, prepend_id)
+        return ids
+
+    def decode(self, ids):
+        return self.enc.decode(ids)
+
+    def get_bos_token_id(self):
+        return self.bos_token_id
+
+    def encode_special(self, token):
+        return self.enc.encode_single_token(token)
+
+
+class NanoGPTTokenizer(_BaseNanoGPTTokenizer):
+    pass
+
+
+class NanoGPTChatTokenizer(PreTrainedTokenizer):
+    """Transformers-compatible tokenizer with chat helpers."""
+
+    vocab_files_names = {"vocab_file": "tokenizer.pkl"}
+    model_input_names = ["input_ids"]
+
+    _special_tokens = {
+        "bos": "<|bos|>",
+        "user_start": "<|user_start|>",
+        "user_end": "<|user_end|>",
+        "assistant_start": "<|assistant_start|>",
+        "assistant_end": "<|assistant_end|>",
+        "python_start": "<|python_start|>",
+        "python_end": "<|python_end|>",
+        "output_start": "<|output_start|>",
+        "output_end": "<|output_end|>",
+    }
+
+    def __init__(
+        self,
+        vocab_file: str,
+        bos_token: str = "<|bos|>",
+        eos_token: str = "<|assistant_end|>",
+        pad_token: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        # Load encoding and build vocab mappings before parent init
+        with open(vocab_file, "rb") as f:
+            self.enc = pickle.load(f)
+        self.vocab_file = vocab_file
+
+        self.special_token_ids: Dict[str, int] = {
+            name: self.enc.encode_single_token(token)
+            for name, token in self._special_tokens.items()
+        }
+        self.bos_token_id = self.special_token_ids["bos"]
+        self.eos_token_id = self.special_token_ids["assistant_end"]
+        pad_token = pad_token or eos_token
+        self.pad_token_id = self.special_token_ids["assistant_end"]
+
+        self._build_vocabulary()
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            **kwargs,
+        )
+
+        additional_special_tokens = [
+            token
+            for key, token in self._special_tokens.items()
+            if token not in {bos_token, eos_token, pad_token}
+        ]
+        if additional_special_tokens:
+            self.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+        self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None))
+
+    # ------------------------------------------------------------------
+    # Core tokenizer API
+    # ------------------------------------------------------------------
+    def _build_vocabulary(self) -> None:
+        id_to_token: Dict[int, str] = {}
+        token_to_id: Dict[str, int] = {}
+        for idx in range(self.enc.n_vocab):
+            token_bytes = self.enc.decode_single_token_bytes(idx)
+            token_str = token_bytes.decode("utf-8", errors="replace")
+            id_to_token[idx] = token_str
+            token_to_id[token_str] = idx
+        self._id_to_token = id_to_token
+        self._token_to_id = token_to_id
+
+    def get_vocab(self) -> Dict[str, int]:
+        return dict(self._token_to_id)
+
+    @property
+    def vocab_size(self) -> int:  # type: ignore[override]
+        return self.enc.n_vocab
+
+    def _tokenize(self, text: str, **kwargs) -> List[str]:
+        ids = self.enc.encode_ordinary(text)
+        return [self._id_to_token[i] for i in ids]
+
+    def _convert_token_to_id(self, token: str) -> int:
+        if token in self._token_to_id:
+            return self._token_to_id[token]
+        raise KeyError(f"Token not found in vocabulary: {token}")
+
+    def _convert_id_to_token(self, index: int) -> str:
+        return self._id_to_token[index]
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
+        ids = [self._token_to_id[token] for token in tokens]
+        return self.enc.decode(ids)
+
+    def build_inputs_with_special_tokens(  # type: ignore[override]
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+    ) -> List[int]:
+        if token_ids_1 is not None:
+            return token_ids_0 + token_ids_1
+        return token_ids_0
+
+    def get_special_tokens_mask(  # type: ignore[override]
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+    ) -> List[int]:
+        all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
+        return [1 if token in self.special_token_ids.values() else 0 for token in all_ids]
+
+    def num_special_tokens_to_add(self, pair: bool = False) -> int:  # type: ignore[override]
+        return 0
+
+    def save_vocabulary(
+        self,
+        save_directory: str,
+        filename_prefix: Optional[str] = None,
+    ) -> Tuple[str]:  # type: ignore[override]
+        os.makedirs(save_directory, exist_ok=True)
+        filename = "tokenizer.pkl"
+        if filename_prefix is not None:
+            filename = f"{filename_prefix}-{filename}"
+        save_path = os.path.join(save_directory, filename)
+        shutil.copyfile(self.vocab_file, save_path)
+        return (save_path,)
+
+    # ------------------------------------------------------------------
+    # Chat helpers
+    # ------------------------------------------------------------------
+    def encode_special(self, token: str) -> int:
+        if token in self.special_token_ids:
+            return self.special_token_ids[token]
+        return self._token_to_id[token]
+
+    def _encode_text(self, text: str) -> List[int]:
+        return self.enc.encode_ordinary(text)
+
+    def _encode_python_block(self, token_id: int, content: str) -> List[int]:
+        tokens = [token_id]
+        tokens.extend(self._encode_text(content))
+        closing = {
+            self.special_token_ids["python_start"]: self.special_token_ids["python_end"],
+            self.special_token_ids["output_start"]: self.special_token_ids["output_end"],
+        }[token_id]
+        tokens.append(closing)
+        return tokens
+
+    def _encode_assistant_content(self, content) -> List[int]:
+        if isinstance(content, str):
+            return self._encode_text(content)
+        if isinstance(content, list):
+            tokens: List[int] = []
+            for part in content:
+                part_type = part.get("type", "text")
+                text = part.get("text", "")
+                if part_type == "text":
+                    tokens.extend(self._encode_text(text))
+                elif part_type == "python":
+                    tokens.extend(
+                        self._encode_python_block(
+                            self.special_token_ids["python_start"],
+                            text,
+                        )
+                    )
+                elif part_type == "python_output":
+                    tokens.extend(
+                        self._encode_python_block(
+                            self.special_token_ids["output_start"],
+                            text,
+                        )
+                    )
+                else:
+                    raise ValueError(f"Unknown assistant content part: {part_type}")
+            return tokens
+        raise ValueError(f"Unsupported assistant content type: {type(content)}")
+
+    def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]:
+        if not conversation:
+            raise ValueError("Conversation must contain at least one message")
+        messages = list(conversation)
+        if messages[0]["role"] == "system":
+            if len(messages) < 2 or messages[1]["role"] != "user":
+                raise ValueError("System message must be followed by a user message")
+            merged = dict(messages[1])
+            merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}"
+            messages = [merged] + messages[2:]
+        ids: List[int] = [self.bos_token_id]
+        for idx, message in enumerate(messages):
+            expected_role = "user" if idx % 2 == 0 else "assistant"
+            role = message.get("role")
+            if role != expected_role:
+                raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}")
+            content = message.get("content")
+            if expected_role == "user":
+                start = self.special_token_ids["user_start"]
+                end = self.special_token_ids["user_end"]
+                if not isinstance(content, str):
+                    raise ValueError("User messages must contain string content")
+                ids.append(start)
+                ids.extend(self._encode_text(content))
+                ids.append(end)
+            else:
+                start = self.special_token_ids["assistant_start"]
+                end = self.special_token_ids["assistant_end"]
+                ids.append(start)
+                ids.extend(self._encode_assistant_content(content))
+                ids.append(end)
+        return ids
+
+    def apply_chat_template(  # type: ignore[override]
+        self,
+        conversation,
+        tokenize: bool = False,
+        add_generation_prompt: bool = False,
+        return_tensors: Optional[str] = None,
+        padding: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        **kwargs,
+    ):
+        if isinstance(conversation, dict) and "messages" in conversation:
+            messages = conversation["messages"]
+        else:
+            messages = conversation
+        token_ids = self._render_conversation_ids(messages)
+        if add_generation_prompt:
+            token_ids.append(self.special_token_ids["assistant_start"])
+        if tokenize:
+            if return_tensors is not None:
+                return self(
+                    [token_ids],
+                    add_special_tokens=False,
+                    return_tensors=return_tensors,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    **kwargs,
+                )
+            return token_ids
+        return self.decode(token_ids, skip_special_tokens=False)
+
+    def encode_chat_message(self, role: str, content: str) -> List[int]:
+        rendered = self.apply_chat_template(
+            [
+                {"role": role, "content": content},
+            ],
+            tokenize=True,
+            add_generation_prompt=False,
+        )
+        return rendered
+
+
+
+
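
Besides plain strings, _encode_assistant_content accepts a list of typed parts, so assistant turns that include python and python_output blocks are wrapped in their special tokens. A small sketch, assuming a local tokenizer.pkl is available (from_pretrained can resolve it from the Hub as well):

from tokenizer_nanogpt import NanoGPTChatTokenizer  # drop the leading dot when importing outside the package

tok = NanoGPTChatTokenizer("tokenizer.pkl")  # path to the pickled encoding object shipped in this repo
conversation = [
    {"role": "user", "content": "Compute 3 * 7 with the tool."},
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me run that."},
            {"type": "python", "text": "print(3 * 7)"},
            {"type": "python_output", "text": "21"},
            {"type": "text", "text": "The answer is 21."},
        ],
    },
]
ids = tok.apply_chat_template(conversation, tokenize=True)
# ids starts with <|bos|> and wraps the code parts in <|python_start|>/<|python_end|> and <|output_start|>/<|output_end|>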