loocorez committed on
Commit
210c84c
·
verified ·
1 Parent(s): a9ae1bb

Upload tokenizer_nanogpt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenizer_nanogpt.py +70 -0
tokenizer_nanogpt.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from huggingface_hub import hf_hub_download
4
+ from huggingface_hub.utils import HfHubHTTPError
5
+
6
+
7
class NanoGPTTokenizer:
    """Lightweight wrapper over a tiktoken Encoding stored in tokenizer.pkl.

    Provides the minimal encode/decode surface needed for inference and a
    ``from_pretrained`` constructor so it can be loaded via AutoTokenizer
    with ``trust_remote_code``.

    Attributes are documented inline in ``__init__``.
    """

    # Name of the pickled encoding file, both locally and on the Hub.
    _TOKENIZER_FILENAME = "tokenizer.pkl"

    # Keyword arguments of ``from_pretrained`` that are forwarded to
    # ``hf_hub_download``. Anything else in ``**kwargs`` is accepted and
    # ignored so that extra loader arguments do not break the call.
    _HUB_DOWNLOAD_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "force_download",
        "local_files_only",
    )

    def __init__(self, enc):
        """Wrap an already-loaded encoding object.

        Args:
            enc: Object exposing ``encode_single_token``, ``encode_ordinary``
                and ``decode`` (per the class docstring, a tiktoken Encoding).

        Raises:
            Whatever ``enc.encode_single_token("<|bos|>")`` raises when the
            encoding does not define the ``<|bos|>`` special token.
        """
        # The underlying encoding; all encode/decode calls delegate to it.
        self.enc = enc
        # Resolved once up front so get_bos_token_id() is a plain lookup.
        self.bos_token_id = enc.encode_single_token("<|bos|>")

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """No-op hook kept so AutoTokenizer-style registration calls succeed.

        Args:
            auto_class: Accepted for signature compatibility; unused.
        """

    @staticmethod
    def _load_pickle(path):
        """Read and unpickle the encoding object stored at ``path``."""
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load tokenizer.pkl from directories or Hub repositories you
        # trust (this is the same trust boundary as trust_remote_code).
        with open(path, "rb") as f:
            return pickle.load(f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the tokenizer from a local directory or a Hugging Face Hub repo.

        Resolution order:
          1. ``<pretrained_model_name_or_path>/tokenizer.pkl`` on local disk.
          2. ``tokenizer.pkl`` downloaded (or served from cache) from the Hub
             repo whose id is ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Local directory (``str`` or
                ``os.PathLike``) or Hub repo id.
            *args: Accepted for signature compatibility; unused.
            **kwargs: ``revision``, ``cache_dir``, ``token``,
                ``force_download`` and ``local_files_only`` are forwarded to
                ``hf_hub_download`` when present and not ``None``; all other
                keys are ignored.

        Returns:
            A new ``NanoGPTTokenizer`` wrapping the unpickled encoding.

        Raises:
            ValueError: If the Hub download (or reading the downloaded file)
                fails with ``HfHubHTTPError`` or ``OSError``.
        """
        # Normalise Path-like inputs so the same value works both for the
        # filesystem check and as a repo id string.
        name_or_path = os.fspath(pretrained_model_name_or_path)

        # First, try the local directory.
        local_tok_path = os.path.join(name_or_path, cls._TOKENIZER_FILENAME)
        if os.path.isfile(local_tok_path):
            return cls(cls._load_pickle(local_tok_path))

        # Otherwise fall back to the Hub, honouring caller-supplied download
        # options (e.g. a pinned revision or an auth token for private repos).
        hub_kwargs = {
            key: kwargs[key]
            for key in cls._HUB_DOWNLOAD_KWARGS
            if kwargs.get(key) is not None
        }
        try:
            # hf_hub_download handles caching and returns the cached file path.
            tok_path = hf_hub_download(
                repo_id=name_or_path,
                filename=cls._TOKENIZER_FILENAME,
                **hub_kwargs,
            )
            enc = cls._load_pickle(tok_path)
        except (HfHubHTTPError, OSError) as e:
            raise ValueError(
                f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
                f"Make sure the path exists or the repo is accessible on the Hub."
            ) from e

        return cls(enc)

    def encode(self, text, prepend=None):
        """Encode ``text`` into token ids, optionally prefixed by one token.

        Args:
            text: Text passed to ``enc.encode_ordinary``.
            prepend: ``None`` for no prefix, an ``int`` token id used as-is,
                or a token string resolved via ``enc.encode_single_token``.

        Returns:
            List of token ids, with the prepend id first when given.
        """
        ids = self.enc.encode_ordinary(text)
        if prepend is None:
            return ids
        if isinstance(prepend, int):
            prepend_id = prepend
        else:
            prepend_id = self.enc.encode_single_token(prepend)
        # Build a new list rather than shifting every element with insert(0, …).
        return [prepend_id, *ids]

    def decode(self, ids):
        """Decode token ids back to text via ``enc.decode``."""
        return self.enc.decode(ids)

    def get_bos_token_id(self):
        """Return the id of the ``<|bos|>`` token resolved at construction."""
        return self.bos_token_id
68
+
69
+
70
+