import os
import pickle

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError


class NanoGPTTokenizer:
"""Lightweight wrapper over a tiktoken Encoding stored in tokenizer.pkl. |
|
|
|
Provides minimal encode/decode needed for inference and a from_pretrained |
|
constructor so it can be loaded via AutoTokenizer with trust_remote_code. |
|
""" |
|
|
|
def __init__(self, enc): |
|
self.enc = enc |
|
self.bos_token_id = enc.encode_single_token("<|bos|>") |
|
|
|
@classmethod |
|
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """Required for AutoTokenizer registration; intentionally a no-op."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load the tokenizer from either:

        - a local directory path
        - a Hugging Face Hub repo ID
        - a cached download (handled automatically by hf_hub_download)
        """
        # *args/**kwargs are accepted for AutoTokenizer compatibility but ignored.
        # Prefer a tokenizer.pkl that sits directly inside a local directory.
        local_tok_path = os.path.join(pretrained_model_name_or_path, "tokenizer.pkl")
        if os.path.isfile(local_tok_path):
            with open(local_tok_path, "rb") as f:
                enc = pickle.load(f)
        else:
            # Otherwise treat the argument as a Hub repo ID and fetch
            # tokenizer.pkl, reusing the local Hugging Face cache when possible.
            try:
                tok_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="tokenizer.pkl",
                )
                with open(tok_path, "rb") as f:
                    enc = pickle.load(f)
            except (HfHubHTTPError, OSError) as e:
                raise ValueError(
                    f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
                    "Make sure the path exists or the repo is accessible on the Hub."
                ) from e

        return cls(enc)

    def encode(self, text, prepend=None):
        # encode_ordinary ignores special tokens, so special-token strings in
        # the raw text are tokenized as plain text rather than as specials.
        ids = self.enc.encode_ordinary(text)
        if prepend is not None:
            # prepend may be either a token id or a special-token string
            # such as "<|bos|>".
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
            ids.insert(0, prepend_id)
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)

    def get_bos_token_id(self):
        return self.bos_token_id
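

# ---------------------------------------------------------------------------
# Usage sketch (a smoke test, not part of the class): this block is an
# illustrative assumption, not something shipped with any particular
# checkpoint. It builds a throwaway tiktoken Encoding that defines the
# "<|bos|>" special token (using the extension pattern shown in tiktoken's
# README), pickles it to a temporary directory as tokenizer.pkl, and
# round-trips text through NanoGPTTokenizer.from_pretrained. The encoding
# name and the extra token id are arbitrary choices for the demo; unpickling
# tokenizer.pkl requires tiktoken to be installed and recent enough to
# support pickling Encoding objects.
if __name__ == "__main__":
    import tempfile

    import tiktoken

    base = tiktoken.get_encoding("gpt2")
    demo_enc = tiktoken.Encoding(
        name="gpt2_with_bos",  # hypothetical name, used only for this demo
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        # Give "<|bos|>" the first id past the base vocabulary.
        special_tokens={**base._special_tokens, "<|bos|>": base.n_vocab},
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(os.path.join(tmp_dir, "tokenizer.pkl"), "wb") as f:
            pickle.dump(demo_enc, f)

        tokenizer = NanoGPTTokenizer.from_pretrained(tmp_dir)
        ids = tokenizer.encode("hello world", prepend="<|bos|>")
        assert ids[0] == tokenizer.get_bos_token_id()
        assert tokenizer.decode(ids[1:]) == "hello world"
        print("round-trip ok:", ids)

    # With the custom code pushed to a Hub repo, the same tokenizer could be
    # loaded via (hypothetical repo id):
    #   AutoTokenizer.from_pretrained("your-username/your-nanogpt-repo", trust_remote_code=True)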