import os
import pickle
import shutil
from typing import Dict, List, Optional, Sequence, Tuple

from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from transformers import BatchEncoding, PreTrainedTokenizer


class _BaseNanoGPTTokenizer:
    """Lightweight wrapper used by the base (non-chat) checkpoints."""

    special_tokens = {
        "bos": "<|bos|>",
        "user_start": "<|user_start|>",
        "user_end": "<|user_end|>",
        "assistant_start": "<|assistant_start|>",
        "assistant_end": "<|assistant_end|>",
        "python_start": "<|python_start|>",
        "python_end": "<|python_end|>",
        "output_start": "<|output_start|>",
        "output_end": "<|output_end|>",
    }

    def __init__(self, enc):
        self.enc = enc
        self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"])

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        # No-op: present only so Hub auto-class registration calls succeed.
        pass

    @classmethod
    def _load_encoding(cls, pretrained_model_name_or_path, **kwargs):
        subfolder = kwargs.get("subfolder")
        base_path = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder
            else pretrained_model_name_or_path
        )
        # Prefer a local tokenizer.pkl when the path is a directory on disk.
        local_tok_path = os.path.join(base_path, "tokenizer.pkl")
        if os.path.isfile(local_tok_path):
            with open(local_tok_path, "rb") as f:
                return pickle.load(f)

        # Otherwise fall back to the Hub, forwarding only download-related kwargs.
        snapshot_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k
            in {
                "cache_dir",
                "force_download",
                "local_files_only",
                "proxies",
                "resume_download",
                "revision",
                "token",
                "use_auth_token",
            }
        }
        # Normalize the deprecated `use_auth_token` alias onto `token`.
        token = snapshot_kwargs.pop("token", None)
        use_auth_token = snapshot_kwargs.pop("use_auth_token", None)
        if token is None:
            token = use_auth_token
        if token is not None:
            snapshot_kwargs["token"] = token

        snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs)
        tok_path = (
            os.path.join(snapshot_dir, subfolder, "tokenizer.pkl")
            if subfolder
            else os.path.join(snapshot_dir, "tokenizer.pkl")
        )
        if not os.path.isfile(tok_path):
            try:
                tok_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="tokenizer.pkl",
                    subfolder=subfolder,
                    **snapshot_kwargs,
                )
            except (HfHubHTTPError, OSError) as e:
                raise ValueError(
                    f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
                    f"Make sure the path exists or the repo is accessible on the Hub."
                ) from e
        with open(tok_path, "rb") as f:
            return pickle.load(f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs)
        return cls(enc)

    def encode(self, text, prepend=None):
        ids = self.enc.encode_ordinary(text)
        if prepend is not None:
            # `prepend` may be a raw token id or a special-token string.
            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
            ids.insert(0, prepend_id)
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode_special(self, token):
        return self.enc.encode_single_token(token)


class NanoGPTTokenizer(_BaseNanoGPTTokenizer):
    pass
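# A minimal usage sketch for the base tokenizer. The path below is a
# placeholder: any local directory or Hub repo that contains a pickled
# tiktoken-style encoding saved as `tokenizer.pkl` works the same way.
#
#   tok = NanoGPTTokenizer.from_pretrained("path/or/repo-with-tokenizer.pkl")
#   ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
#   text = tok.decode(ids[1:])  # drop the prepended <|bos|> id before decoding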
class NanoGPTChatTokenizer(PreTrainedTokenizer):
    """Transformers-compatible tokenizer with chat helpers."""

    vocab_files_names = {"vocab_file": "tokenizer.pkl"}
    model_input_names = ["input_ids"]

    _special_tokens = {
        "bos": "<|bos|>",
        "user_start": "<|user_start|>",
        "user_end": "<|user_end|>",
        "assistant_start": "<|assistant_start|>",
        "assistant_end": "<|assistant_end|>",
        "python_start": "<|python_start|>",
        "python_end": "<|python_end|>",
        "output_start": "<|output_start|>",
        "output_end": "<|output_end|>",
    }

    def __init__(
        self,
        vocab_file: str,
        bos_token: str = "<|bos|>",
        eos_token: str = "<|assistant_end|>",
        pad_token: Optional[str] = None,
        **kwargs,
    ) -> None:
        # Load the encoding and build the vocab mappings before parent init so
        # the base class can resolve the special-token strings to ids.
        with open(vocab_file, "rb") as f:
            self.enc = pickle.load(f)
        self.vocab_file = vocab_file
        self.special_token_ids: Dict[str, int] = {
            name: self.enc.encode_single_token(token)
            for name, token in self._special_tokens.items()
        }
        self._build_vocabulary()
        # Padding falls back to the EOS token (<|assistant_end|>) when no pad
        # token is supplied. The bos/eos/pad ids are derived by the parent
        # class from these token strings, all of which exist in the vocab.
        pad_token = pad_token or eos_token
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )
        additional_special_tokens = [
            token
            for key, token in self._special_tokens.items()
            if token not in {bos_token, eos_token, pad_token}
        ]
        if additional_special_tokens:
            self.add_special_tokens({"additional_special_tokens": additional_special_tokens})
        self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None))

    # ------------------------------------------------------------------
    # Core tokenizer API
    # ------------------------------------------------------------------
    def _build_vocabulary(self) -> None:
        # Materialize id<->string maps over the whole tiktoken vocabulary so
        # the slow-tokenizer API (tokenize/convert) can operate on strings.
        id_to_token: Dict[int, str] = {}
        token_to_id: Dict[str, int] = {}
        for idx in range(self.enc.n_vocab):
            token_bytes = self.enc.decode_single_token_bytes(idx)
            token_str = token_bytes.decode("utf-8", errors="replace")
            id_to_token[idx] = token_str
            token_to_id[token_str] = idx
        self._id_to_token = id_to_token
        self._token_to_id = token_to_id

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._token_to_id)

    @property
    def vocab_size(self) -> int:  # type: ignore[override]
        return self.enc.n_vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        ids = self.enc.encode_ordinary(text)
        return [self._id_to_token[i] for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        if token in self._token_to_id:
            return self._token_to_id[token]
        raise KeyError(f"Token not found in vocabulary: {token}")

    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token[index]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
        ids = [self._token_to_id[token] for token in tokens]
        return self.enc.decode(ids)
    def build_inputs_with_special_tokens(  # type: ignore[override]
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        # The chat template already inserts every special token, so sequences
        # are concatenated unchanged.
        if token_ids_1 is not None:
            return token_ids_0 + token_ids_1
        return token_ids_0

    def get_special_tokens_mask(  # type: ignore[override]
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
        # Compare against the special-token *ids*, not the dict's name keys.
        special_ids = set(self.special_token_ids.values())
        return [1 if token_id in special_ids else 0 for token_id in all_ids]

    def num_special_tokens_to_add(self, pair: bool = False) -> int:  # type: ignore[override]
        return 0

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> Tuple[str]:  # type: ignore[override]
        os.makedirs(save_directory, exist_ok=True)
        filename = "tokenizer.pkl"
        if filename_prefix is not None:
            filename = f"{filename_prefix}-{filename}"
        save_path = os.path.join(save_directory, filename)
        shutil.copyfile(self.vocab_file, save_path)
        return (save_path,)

    # ------------------------------------------------------------------
    # Chat helpers
    # ------------------------------------------------------------------
    def encode_special(self, token: str) -> int:
        # Accepts either a special-token name ("bos") or a literal token
        # string ("<|bos|>").
        if token in self.special_token_ids:
            return self.special_token_ids[token]
        return self._token_to_id[token]

    def _encode_text(self, text: str) -> List[int]:
        return self.enc.encode_ordinary(text)

    def _encode_python_block(self, token_id: int, content: str) -> List[int]:
        tokens = [token_id]
        tokens.extend(self._encode_text(content))
        # Pair each opening tool token with its matching closing token.
        closing = {
            self.special_token_ids["python_start"]: self.special_token_ids["python_end"],
            self.special_token_ids["output_start"]: self.special_token_ids["output_end"],
        }[token_id]
        tokens.append(closing)
        return tokens

    def _encode_assistant_content(self, content) -> List[int]:
        if isinstance(content, str):
            return self._encode_text(content)
        if isinstance(content, list):
            tokens: List[int] = []
            for part in content:
                part_type = part.get("type", "text")
                text = part.get("text", "")
                if part_type == "text":
                    tokens.extend(self._encode_text(text))
                elif part_type == "python":
                    tokens.extend(
                        self._encode_python_block(
                            self.special_token_ids["python_start"],
                            text,
                        )
                    )
                elif part_type == "python_output":
                    tokens.extend(
                        self._encode_python_block(
                            self.special_token_ids["output_start"],
                            text,
                        )
                    )
                else:
                    raise ValueError(f"Unknown assistant content part: {part_type}")
            return tokens
        raise ValueError(f"Unsupported assistant content type: {type(content)}")

    def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]:
        if not conversation:
            raise ValueError("Conversation must contain at least one message")
        messages = list(conversation)
        # A leading system message is folded into the first user turn.
        if messages[0]["role"] == "system":
            if len(messages) < 2 or messages[1]["role"] != "user":
                raise ValueError("System message must be followed by a user message")
            merged = dict(messages[1])
            merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}"
            messages = [merged] + messages[2:]
        ids: List[int] = [self.bos_token_id]
        for idx, message in enumerate(messages):
            # Turns must strictly alternate user / assistant.
            expected_role = "user" if idx % 2 == 0 else "assistant"
            role = message.get("role")
            if role != expected_role:
                raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}")
            content = message.get("content")
            if expected_role == "user":
                start = self.special_token_ids["user_start"]
                end = self.special_token_ids["user_end"]
                if not isinstance(content, str):
                    raise ValueError("User messages must contain string content")
                ids.append(start)
                ids.extend(self._encode_text(content))
                ids.append(end)
            else:
                start = self.special_token_ids["assistant_start"]
                end = self.special_token_ids["assistant_end"]
                ids.append(start)
                ids.extend(self._encode_assistant_content(content))
                ids.append(end)
        return ids

    def apply_chat_template(  # type: ignore[override]
        self,
        conversation,
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        return_tensors: Optional[str] = None,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        **kwargs,
    ):
        # Accept either a bare list of messages or a {"messages": [...]} dict.
        if isinstance(conversation, dict) and "messages" in conversation:
            messages = conversation["messages"]
        else:
            messages = conversation
        token_ids = self._render_conversation_ids(messages)
        if add_generation_prompt:
            token_ids.append(self.special_token_ids["assistant_start"])
        if tokenize:
            if return_tensors is not None:
                if truncation and max_length is not None:
                    token_ids = token_ids[:max_length]
                # A single rendered sequence needs no padding; wrap it in a
                # batch of one and convert to the requested tensor type.
                return BatchEncoding({"input_ids": [token_ids]}, tensor_type=return_tensors)
            return token_ids
        return self.decode(token_ids, skip_special_tokens=False)

    def encode_chat_message(self, role: str, content: str) -> List[int]:
        # Note: the template requires conversations to open with a user turn,
        # so this helper is only valid for role="user" (rendered with a
        # leading BOS token).
        rendered = self.apply_chat_template(
            [{"role": role, "content": content}],
            tokenize=True,
            add_generation_prompt=False,
        )
        return rendered
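# A minimal sketch of the chat rendering implemented above, assuming a local
# `tokenizer.pkl` (the path is a placeholder). A leading system message is
# merged into the first user turn, and assistant turns may mix plain text with
# <|python_start|>/<|output_start|> tool blocks via structured content parts:
#
#   tok = NanoGPTChatTokenizer("tokenizer.pkl")
#   prompt_ids = tok.apply_chat_template(
#       [
#           {"role": "system", "content": "Answer briefly."},
#           {"role": "user", "content": "What is 2 + 2?"},
#       ],
#       tokenize=True,
#       add_generation_prompt=True,  # appends <|assistant_start|> for generation
#   )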