# rl-d20 / tokenizer_nanogpt.py
import os
import pickle
import shutil
from typing import Dict, List, Optional, Sequence, Tuple

from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from transformers import PreTrainedTokenizer


class _BaseNanoGPTTokenizer:
    """Lightweight wrapper used by the base (non-chat) checkpoints."""
special_tokens = {
"bos": "<|bos|>",
"user_start": "<|user_start|>",
"user_end": "<|user_end|>",
"assistant_start": "<|assistant_start|>",
"assistant_end": "<|assistant_end|>",
"python_start": "<|python_start|>",
"python_end": "<|python_end|>",
"output_start": "<|output_start|>",
"output_end": "<|output_end|>",
}
def __init__(self, enc):
self.enc = enc
self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"])
    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        # No-op hook: this lightweight wrapper is not a transformers
        # tokenizer subclass, but save/push flows may still call this.
        pass
@classmethod
def _load_encoding(cls, pretrained_model_name_or_path, **kwargs):
subfolder = kwargs.get("subfolder")
base_path = (
os.path.join(pretrained_model_name_or_path, subfolder)
if subfolder
else pretrained_model_name_or_path
)
local_tok_path = os.path.join(base_path, "tokenizer.pkl")
if os.path.isfile(local_tok_path):
with open(local_tok_path, "rb") as f:
return pickle.load(f)
        # Forward only the kwargs that snapshot_download understands.
        allowed_keys = {
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "token",
            "use_auth_token",
        }
        snapshot_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys}
        # Normalise the deprecated use_auth_token alias onto token, and
        # always pop both so neither leaks through to snapshot_download.
        token = snapshot_kwargs.pop("token", None)
        use_auth_token = snapshot_kwargs.pop("use_auth_token", None)
        if token is None:
            token = use_auth_token
        if token is not None:
            snapshot_kwargs["token"] = token
        snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs)
        tok_path = (
            os.path.join(snapshot_dir, subfolder, "tokenizer.pkl")
            if subfolder
            else os.path.join(snapshot_dir, "tokenizer.pkl")
        )
if not os.path.isfile(tok_path):
try:
tok_path = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="tokenizer.pkl",
subfolder=subfolder,
**snapshot_kwargs,
)
except (HfHubHTTPError, OSError) as e:
raise ValueError(
f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
f"Make sure the path exists or the repo is accessible on the Hub."
) from e
with open(tok_path, "rb") as f:
return pickle.load(f)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs)
return cls(enc)
def encode(self, text, prepend=None):
ids = self.enc.encode_ordinary(text)
if prepend is not None:
prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
ids.insert(0, prepend_id)
return ids
def decode(self, ids):
return self.enc.decode(ids)
def get_bos_token_id(self):
return self.bos_token_id
def encode_special(self, token):
return self.enc.encode_single_token(token)


class NanoGPTTokenizer(_BaseNanoGPTTokenizer):
    pass
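
# Usage sketch for the base tokenizer; the checkpoint path below is
# hypothetical, for illustration only:
#
#   tok = NanoGPTTokenizer.from_pretrained("./out/base_checkpoint")
#   ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
#   text = tok.decode(ids)
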
class NanoGPTChatTokenizer(PreTrainedTokenizer):
"""Transformers-compatible tokenizer with chat helpers."""
vocab_files_names = {"vocab_file": "tokenizer.pkl"}
model_input_names = ["input_ids"]
_special_tokens = {
"bos": "<|bos|>",
"user_start": "<|user_start|>",
"user_end": "<|user_end|>",
"assistant_start": "<|assistant_start|>",
"assistant_end": "<|assistant_end|>",
"python_start": "<|python_start|>",
"python_end": "<|python_end|>",
"output_start": "<|output_start|>",
"output_end": "<|output_end|>",
}
def __init__(
self,
vocab_file: str,
bos_token: str = "<|bos|>",
eos_token: str = "<|assistant_end|>",
pad_token: Optional[str] = None,
**kwargs,
) -> None:
        # Load the pickled tiktoken encoding and build the vocab mappings
        # before the parent constructor runs, since it calls back into
        # _convert_token_to_id / _convert_id_to_token.
        with open(vocab_file, "rb") as f:
            self.enc = pickle.load(f)
        self.vocab_file = vocab_file
        self._build_vocabulary()
        self.special_token_ids: Dict[str, int] = {
            name: self.enc.encode_single_token(token)
            for name, token in self._special_tokens.items()
        }
        self.bos_token_id = self.special_token_ids["bos"]
        self.eos_token_id = self.special_token_ids["assistant_end"]
        pad_token = pad_token or eos_token
        # Derive the pad id from the resolved pad token rather than
        # hard-coding assistant_end, so a custom pad_token is honoured.
        self.pad_token_id = self.enc.encode_single_token(pad_token)
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
**kwargs,
)
        additional_special_tokens = [
            token
            for token in self._special_tokens.values()
            if token not in {bos_token, eos_token, pad_token}
        ]
if additional_special_tokens:
self.add_special_tokens({"additional_special_tokens": additional_special_tokens})
self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None))
# ------------------------------------------------------------------
# Core tokenizer API
# ------------------------------------------------------------------
    def _build_vocabulary(self) -> None:
        # Materialise id<->string maps from the tiktoken encoding. Token
        # bytes that are not valid UTF-8 are decoded with errors="replace",
        # so a few byte-level tokens may collide on the same string; for
        # such strings token_to_id keeps the highest colliding id.
        id_to_token: Dict[int, str] = {}
        token_to_id: Dict[str, int] = {}
        for idx in range(self.enc.n_vocab):
            token_bytes = self.enc.decode_single_token_bytes(idx)
            token_str = token_bytes.decode("utf-8", errors="replace")
            id_to_token[idx] = token_str
            token_to_id[token_str] = idx
        self._id_to_token = id_to_token
        self._token_to_id = token_to_id
def get_vocab(self) -> Dict[str, int]:
return dict(self._token_to_id)
@property
def vocab_size(self) -> int: # type: ignore[override]
return self.enc.n_vocab
def _tokenize(self, text: str, **kwargs) -> List[str]:
ids = self.enc.encode_ordinary(text)
return [self._id_to_token[i] for i in ids]
def _convert_token_to_id(self, token: str) -> int:
if token in self._token_to_id:
return self._token_to_id[token]
raise KeyError(f"Token not found in vocabulary: {token}")
def _convert_id_to_token(self, index: int) -> str:
return self._id_to_token[index]
def convert_tokens_to_string(self, tokens: List[str]) -> str: # type: ignore[override]
ids = [self._token_to_id[token] for token in tokens]
return self.enc.decode(ids)
def build_inputs_with_special_tokens( # type: ignore[override]
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
) -> List[int]:
if token_ids_1 is not None:
return token_ids_0 + token_ids_1
return token_ids_0
def get_special_tokens_mask( # type: ignore[override]
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
) -> List[int]:
        all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
        # A position is special iff its *id* is one of the reserved ids;
        # special_token_ids maps names to ids, so compare against values.
        special_ids = set(self.special_token_ids.values())
        return [1 if token_id in special_ids else 0 for token_id in all_ids]
def num_special_tokens_to_add(self, pair: bool = False) -> int: # type: ignore[override]
return 0
def save_vocabulary(
self,
save_directory: str,
filename_prefix: Optional[str] = None,
) -> Tuple[str]: # type: ignore[override]
os.makedirs(save_directory, exist_ok=True)
filename = "tokenizer.pkl"
if filename_prefix is not None:
filename = f"{filename_prefix}-{filename}"
        save_path = os.path.join(save_directory, filename)
        # Avoid shutil.SameFileError when saving back into the directory
        # the tokenizer was originally loaded from.
        if os.path.abspath(self.vocab_file) != os.path.abspath(save_path):
            shutil.copyfile(self.vocab_file, save_path)
        return (save_path,)
# ------------------------------------------------------------------
# Chat helpers
# ------------------------------------------------------------------
def encode_special(self, token: str) -> int:
if token in self.special_token_ids:
return self.special_token_ids[token]
return self._token_to_id[token]
def _encode_text(self, text: str) -> List[int]:
return self.enc.encode_ordinary(text)
def _encode_python_block(self, token_id: int, content: str) -> List[int]:
tokens = [token_id]
tokens.extend(self._encode_text(content))
closing = {
self.special_token_ids["python_start"]: self.special_token_ids["python_end"],
self.special_token_ids["output_start"]: self.special_token_ids["output_end"],
}[token_id]
tokens.append(closing)
return tokens
def _encode_assistant_content(self, content) -> List[int]:
if isinstance(content, str):
return self._encode_text(content)
if isinstance(content, list):
tokens: List[int] = []
for part in content:
part_type = part.get("type", "text")
text = part.get("text", "")
if part_type == "text":
tokens.extend(self._encode_text(text))
elif part_type == "python":
tokens.extend(
self._encode_python_block(
self.special_token_ids["python_start"],
text,
)
)
elif part_type == "python_output":
tokens.extend(
self._encode_python_block(
self.special_token_ids["output_start"],
text,
)
)
else:
raise ValueError(f"Unknown assistant content part: {part_type}")
return tokens
raise ValueError(f"Unsupported assistant content type: {type(content)}")
def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]:
if not conversation:
raise ValueError("Conversation must contain at least one message")
messages = list(conversation)
if messages[0]["role"] == "system":
if len(messages) < 2 or messages[1]["role"] != "user":
raise ValueError("System message must be followed by a user message")
merged = dict(messages[1])
merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}"
messages = [merged] + messages[2:]
ids: List[int] = [self.bos_token_id]
for idx, message in enumerate(messages):
expected_role = "user" if idx % 2 == 0 else "assistant"
role = message.get("role")
if role != expected_role:
raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}")
content = message.get("content")
if expected_role == "user":
start = self.special_token_ids["user_start"]
end = self.special_token_ids["user_end"]
if not isinstance(content, str):
raise ValueError("User messages must contain string content")
ids.append(start)
ids.extend(self._encode_text(content))
ids.append(end)
else:
start = self.special_token_ids["assistant_start"]
end = self.special_token_ids["assistant_end"]
ids.append(start)
ids.extend(self._encode_assistant_content(content))
ids.append(end)
return ids
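    # Rendered layout for one exchange (any system text is folded into
    # the first user turn):
    #   <|bos|><|user_start|>...<|user_end|><|assistant_start|>...<|assistant_end|>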
def apply_chat_template( # type: ignore[override]
self,
conversation,
tokenize: bool = False,
add_generation_prompt: bool = False,
return_tensors: Optional[str] = None,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
**kwargs,
):
if isinstance(conversation, dict) and "messages" in conversation:
messages = conversation["messages"]
else:
messages = conversation
token_ids = self._render_conversation_ids(messages)
if add_generation_prompt:
token_ids.append(self.special_token_ids["assistant_start"])
        if tokenize:
            if truncation and max_length is not None:
                token_ids = token_ids[:max_length]
            if return_tensors is not None:
                # Batch the pre-rendered ids through pad(); calling
                # self(...) on raw id lists would try to re-tokenize them
                # as text and fail.
                return self.pad(
                    {"input_ids": [token_ids]},
                    padding=padding,
                    max_length=max_length,
                    return_tensors=return_tensors,
                )
            return token_ids
return self.decode(token_ids, skip_special_tokens=False)
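    # Generation-time sketch: append the assistant header so the model
    # continues the assistant turn (tensor output assumes PyTorch):
    #   batch = tok.apply_chat_template(
    #       msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    #   )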
    def encode_chat_message(self, role: str, content: str) -> List[int]:
        # Render a single message through the chat template. Note that the
        # template expects conversations to start with a user turn, so
        # role="assistant" alone will raise in _render_conversation_ids.
        return self.apply_chat_template(
            [{"role": role, "content": content}],
            tokenize=True,
            add_generation_prompt=False,
        )
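

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: assumes a local tokenizer.pkl
    # exists (the ./checkpoint path is hypothetical, not part of this repo).
    tokenizer = NanoGPTChatTokenizer(vocab_file="./checkpoint/tokenizer.pkl")
    conversation = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
    ids = tokenizer.apply_chat_template(conversation, tokenize=True)
    print(ids)
    print(tokenizer.decode(ids, skip_special_tokens=False))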