import os
import pickle
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError


class NanoGPTTokenizer:
    """Lightweight wrapper over a tiktoken Encoding stored in tokenizer.pkl.

    Provides minimal encode/decode needed for inference and a from_pretrained
    constructor so it can be loaded via AutoTokenizer with trust_remote_code.
    """

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token("<|bos|>")

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """Required for AutoTokenizer registration."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load tokenizer from either:
        - Local directory path
        - Hugging Face Hub repo ID
        - Cached directory (handled automatically)
        """
        # First, try to load from local path
        local_tok_path = os.path.join(pretrained_model_name_or_path, "tokenizer.pkl")

        if os.path.isfile(local_tok_path):
            # Local file exists, load it directly
            with open(local_tok_path, "rb") as f:
                enc = pickle.load(f)
        else:
            # Try to download from Hugging Face Hub
            try:
                # This handles cache automatically and returns the cached file path
                tok_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="tokenizer.pkl"
                )
                with open(tok_path, "rb") as f:
                    enc = pickle.load(f)
            except (HfHubHTTPError, OSError) as e:
                raise ValueError(
                    f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
                    f"Make sure the path exists or the repo is accessible on the Hub."
                ) from e

        return cls(enc)

    def encode(self, text, prepend=None):
        """Encode text to token ids; optionally prepend a token given as an id or a special-token string."""
        ids = self.enc.encode_ordinary(text)
        if prepend is not None:
            # Accept either a raw token id or a special token string such as "<|bos|>".
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
            ids.insert(0, prepend_id)
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)

    def get_bos_token_id(self):
        return self.bos_token_id
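

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of loading and using this tokenizer directly via
# from_pretrained, as described in the class docstring. It assumes a local
# checkpoint directory or Hub repo that contains tokenizer.pkl; the path
# "path/to/checkpoint" is a placeholder, not a real location.
if __name__ == "__main__":
    tok = NanoGPTTokenizer.from_pretrained("path/to/checkpoint")
    # Encode with the BOS token prepended, then round-trip back to text.
    ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
    print(ids)
    print(tok.decode(ids))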