|
import os |
|
import pickle |
|
import shutil |
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
from huggingface_hub.utils import HfHubHTTPError |
|
from transformers import PreTrainedTokenizer |
|
|
|
|
|
class _BaseNanoGPTTokenizer: |
|
"""Lightweight wrapper used by the base (non-chat) checkpoints.""" |
|
|
|
special_tokens = { |
|
"bos": "<|bos|>", |
|
"user_start": "<|user_start|>", |
|
"user_end": "<|user_end|>", |
|
"assistant_start": "<|assistant_start|>", |
|
"assistant_end": "<|assistant_end|>", |
|
"python_start": "<|python_start|>", |
|
"python_end": "<|python_end|>", |
|
"output_start": "<|output_start|>", |
|
"output_end": "<|output_end|>", |
|
} |
|
|
|
def __init__(self, enc): |
|
        self.enc = enc  # expected to be a tiktoken-style Encoding

        self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"])
|
|
|
@classmethod |
|
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):

        # No-op stub so callers written against transformers' auto-class
        # registration API can call this safely.

        pass
|
|
|
@classmethod |
|
def _load_encoding(cls, pretrained_model_name_or_path, **kwargs): |
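        """Load the pickled tiktoken-style encoding for this checkpoint.

        Resolution order: (1) tokenizer.pkl next to a local path, (2) a full
        Hub snapshot_download, then (3) a targeted hf_hub_download of
        tokenizer.pkl as a fallback.
        """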
|
subfolder = kwargs.get("subfolder") |
|
base_path = ( |
|
os.path.join(pretrained_model_name_or_path, subfolder) |
|
if subfolder |
|
else pretrained_model_name_or_path |
|
) |
|
local_tok_path = os.path.join(base_path, "tokenizer.pkl") |
|
if os.path.isfile(local_tok_path): |
|
with open(local_tok_path, "rb") as f: |
|
return pickle.load(f) |
|
|
|
        hub_keys = {

            "cache_dir",

            "force_download",

            "local_files_only",

            "proxies",

            "resume_download",

            "revision",

            "token",

            "use_auth_token",

        }

        snapshot_kwargs = {k: v for k, v in kwargs.items() if k in hub_keys}
|
        # Prefer `token`; fall back to the deprecated `use_auth_token` alias.

        token = snapshot_kwargs.pop("token", None)
|
if token is None: |
|
token = snapshot_kwargs.pop("use_auth_token", None) |
|
if token is not None: |
|
snapshot_kwargs["token"] = token |
|
|
|
snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs) |
|
        tok_path = os.path.join(snapshot_dir, subfolder or "", "tokenizer.pkl")
|
if not os.path.isfile(tok_path): |
|
try: |
|
tok_path = hf_hub_download( |
|
repo_id=pretrained_model_name_or_path, |
|
filename="tokenizer.pkl", |
|
subfolder=subfolder, |
|
**snapshot_kwargs, |
|
) |
|
except (HfHubHTTPError, OSError) as e: |
|
raise ValueError( |
|
f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. " |
|
f"Make sure the path exists or the repo is accessible on the Hub." |
|
) from e |
|
with open(tok_path, "rb") as f: |
|
return pickle.load(f) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): |
|
enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs) |
|
return cls(enc) |
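
    # Minimal usage sketch (the repo id is illustrative, and the repo is
    # assumed to contain a pickled tiktoken-style encoding as tokenizer.pkl):
    #
    #   tok = NanoGPTTokenizer.from_pretrained("your-org/nanogpt-base")
    #   ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
    #   text = tok.decode(ids)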
|
|
|
def encode(self, text, prepend=None): |
|
ids = self.enc.encode_ordinary(text) |
|
if prepend is not None: |
|
prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend) |
|
ids.insert(0, prepend_id) |
|
return ids |
|
|
|
def decode(self, ids): |
|
return self.enc.decode(ids) |
|
|
|
def get_bos_token_id(self): |
|
return self.bos_token_id |
|
|
|
def encode_special(self, token): |
|
return self.enc.encode_single_token(token) |
|
|
|
|
|
class NanoGPTTokenizer(_BaseNanoGPTTokenizer): |
|
pass |
|
|
|
|
|
class NanoGPTChatTokenizer(PreTrainedTokenizer): |
|
"""Transformers-compatible tokenizer with chat helpers.""" |
|
|
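    # `vocab_files_names` is the standard transformers hook: it maps the
    # `vocab_file` __init__ argument to the tokenizer.pkl filename used by
    # save_pretrained / from_pretrained.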
|
vocab_files_names = {"vocab_file": "tokenizer.pkl"} |
|
model_input_names = ["input_ids"] |
|
|
|
_special_tokens = { |
|
"bos": "<|bos|>", |
|
"user_start": "<|user_start|>", |
|
"user_end": "<|user_end|>", |
|
"assistant_start": "<|assistant_start|>", |
|
"assistant_end": "<|assistant_end|>", |
|
"python_start": "<|python_start|>", |
|
"python_end": "<|python_end|>", |
|
"output_start": "<|output_start|>", |
|
"output_end": "<|output_end|>", |
|
} |
|
|
|
def __init__( |
|
self, |
|
vocab_file: str, |
|
bos_token: str = "<|bos|>", |
|
eos_token: str = "<|assistant_end|>", |
|
pad_token: Optional[str] = None, |
|
**kwargs, |
|
) -> None: |
|
|
|
with open(vocab_file, "rb") as f: |
|
self.enc = pickle.load(f) |
|
self.vocab_file = vocab_file |
|
|
|
self.special_token_ids: Dict[str, int] = { |
|
name: self.enc.encode_single_token(token) |
|
for name, token in self._special_tokens.items() |
|
} |
|
        # bos/eos/pad token ids are read-write properties on PreTrainedTokenizer,
        # so they are left to be derived from the token strings passed to
        # super().__init__ below rather than assigned here.

        pad_token = pad_token or eos_token  # default padding to the eos token
|
|
|
self._build_vocabulary() |
|
|
|
super().__init__( |
|
bos_token=bos_token, |
|
eos_token=eos_token, |
|
pad_token=pad_token, |
|
**kwargs, |
|
) |
|
|
|
        additional_special_tokens = [

            token

            for token in self._special_tokens.values()

            if token not in {bos_token, eos_token, pad_token}

        ]
|
if additional_special_tokens: |
|
self.add_special_tokens({"additional_special_tokens": additional_special_tokens}) |
|
self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None)) |
|
|
|
def _build_vocabulary(self) -> None: |
|
id_to_token: Dict[int, str] = {} |
|
token_to_id: Dict[str, int] = {} |
|
for idx in range(self.enc.n_vocab): |
|
token_bytes = self.enc.decode_single_token_bytes(idx) |
|
            # errors="replace" means two distinct byte-level tokens can decode to
            # the same string; in that case the later id wins in token_to_id.

            token_str = token_bytes.decode("utf-8", errors="replace")
|
id_to_token[idx] = token_str |
|
token_to_id[token_str] = idx |
|
self._id_to_token = id_to_token |
|
self._token_to_id = token_to_id |
|
|
|
def get_vocab(self) -> Dict[str, int]: |
|
return dict(self._token_to_id) |
|
|
|
@property |
|
def vocab_size(self) -> int: |
|
return self.enc.n_vocab |
|
|
|
def _tokenize(self, text: str, **kwargs) -> List[str]: |
|
ids = self.enc.encode_ordinary(text) |
|
return [self._id_to_token[i] for i in ids] |
|
|
|
def _convert_token_to_id(self, token: str) -> int: |
|
if token in self._token_to_id: |
|
return self._token_to_id[token] |
|
raise KeyError(f"Token not found in vocabulary: {token}") |
|
|
|
def _convert_id_to_token(self, index: int) -> str: |
|
return self._id_to_token[index] |
|
|
|
def convert_tokens_to_string(self, tokens: List[str]) -> str: |
|
ids = [self._token_to_id[token] for token in tokens] |
|
return self.enc.decode(ids) |
|
|
|
def build_inputs_with_special_tokens( |
|
self, |
|
token_ids_0: List[int], |
|
token_ids_1: Optional[List[int]] = None, |
|
) -> List[int]: |
|
if token_ids_1 is not None: |
|
return token_ids_0 + token_ids_1 |
|
return token_ids_0 |
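
    # Special tokens are emitted by the chat template itself, so
    # `build_inputs_with_special_tokens` passes sequences through unchanged
    # and `num_special_tokens_to_add` below reports 0 accordingly.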
|
|
|
def get_special_tokens_mask( |
|
self, |
|
token_ids_0: List[int], |
|
token_ids_1: Optional[List[int]] = None, |
|
) -> List[int]: |
|
        all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1

        special_ids = set(self.special_token_ids.values())

        return [1 if token_id in special_ids else 0 for token_id in all_ids]
|
|
|
def num_special_tokens_to_add(self, pair: bool = False) -> int: |
|
return 0 |
|
|
|
def save_vocabulary( |
|
self, |
|
save_directory: str, |
|
filename_prefix: Optional[str] = None, |
|
) -> Tuple[str]: |
|
os.makedirs(save_directory, exist_ok=True) |
|
filename = "tokenizer.pkl" |
|
if filename_prefix is not None: |
|
filename = f"{filename_prefix}-{filename}" |
|
save_path = os.path.join(save_directory, filename) |
|
        if os.path.abspath(self.vocab_file) != os.path.abspath(save_path):

            shutil.copyfile(self.vocab_file, save_path)
|
return (save_path,) |
|
|
|
def encode_special(self, token: str) -> int: |
|
if token in self.special_token_ids: |
|
return self.special_token_ids[token] |
|
return self._token_to_id[token] |
|
|
|
def _encode_text(self, text: str) -> List[int]: |
|
return self.enc.encode_ordinary(text) |
|
|
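    # `_encode_python_block` brackets content with matching specials, producing
    # <|python_start|>...<|python_end|> or <|output_start|>...<|output_end|>.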
|
def _encode_python_block(self, token_id: int, content: str) -> List[int]: |
|
tokens = [token_id] |
|
tokens.extend(self._encode_text(content)) |
|
closing = { |
|
self.special_token_ids["python_start"]: self.special_token_ids["python_end"], |
|
self.special_token_ids["output_start"]: self.special_token_ids["output_end"], |
|
}[token_id] |
|
tokens.append(closing) |
|
return tokens |
|
|
|
def _encode_assistant_content(self, content) -> List[int]: |
|
if isinstance(content, str): |
|
return self._encode_text(content) |
|
if isinstance(content, list): |
|
tokens: List[int] = [] |
|
for part in content: |
|
part_type = part.get("type", "text") |
|
text = part.get("text", "") |
|
if part_type == "text": |
|
tokens.extend(self._encode_text(text)) |
|
elif part_type == "python": |
|
tokens.extend( |
|
self._encode_python_block( |
|
self.special_token_ids["python_start"], |
|
text, |
|
) |
|
) |
|
elif part_type == "python_output": |
|
tokens.extend( |
|
self._encode_python_block( |
|
self.special_token_ids["output_start"], |
|
text, |
|
) |
|
) |
|
else: |
|
raise ValueError(f"Unknown assistant content part: {part_type}") |
|
return tokens |
|
raise ValueError(f"Unsupported assistant content type: {type(content)}") |
|
|
|
def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]: |
|
if not conversation: |
|
raise ValueError("Conversation must contain at least one message") |
|
messages = list(conversation) |
|
if messages[0]["role"] == "system": |
|
if len(messages) < 2 or messages[1]["role"] != "user": |
|
raise ValueError("System message must be followed by a user message") |
|
merged = dict(messages[1]) |
|
merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}" |
|
messages = [merged] + messages[2:] |
|
ids: List[int] = [self.bos_token_id] |
|
for idx, message in enumerate(messages): |
|
expected_role = "user" if idx % 2 == 0 else "assistant" |
|
role = message.get("role") |
|
if role != expected_role: |
|
raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}") |
|
content = message.get("content") |
|
if expected_role == "user": |
|
start = self.special_token_ids["user_start"] |
|
end = self.special_token_ids["user_end"] |
|
if not isinstance(content, str): |
|
raise ValueError("User messages must contain string content") |
|
ids.append(start) |
|
ids.extend(self._encode_text(content)) |
|
ids.append(end) |
|
else: |
|
start = self.special_token_ids["assistant_start"] |
|
end = self.special_token_ids["assistant_end"] |
|
ids.append(start) |
|
ids.extend(self._encode_assistant_content(content)) |
|
ids.append(end) |
|
return ids |
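
    # Rendered id layout:
    #   <|bos|> <|user_start|>...<|user_end|> <|assistant_start|>...<|assistant_end|> ...
    # with strictly alternating user/assistant turns; a leading system message
    # is merged into the first user turn.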
|
|
|
def apply_chat_template( |
|
self, |
|
conversation, |
|
tokenize: bool = False, |
|
add_generation_prompt: bool = False, |
|
return_tensors: Optional[str] = None, |
|
padding: bool = False, |
|
truncation: bool = False, |
|
max_length: Optional[int] = None, |
|
**kwargs, |
|
): |
|
if isinstance(conversation, dict) and "messages" in conversation: |
|
messages = conversation["messages"] |
|
else: |
|
messages = conversation |
|
token_ids = self._render_conversation_ids(messages) |
|
if add_generation_prompt: |
|
token_ids.append(self.special_token_ids["assistant_start"]) |
|
if tokenize: |
|
            if return_tensors is not None:

                # `__call__` only accepts text inputs, so tensorize the
                # pre-encoded ids with `prepare_for_model` instead.

                return self.prepare_for_model(

                    token_ids,

                    add_special_tokens=False,

                    return_tensors=return_tensors,

                    padding=padding,

                    truncation=truncation,

                    max_length=max_length,

                    prepend_batch_axis=True,

                    **kwargs,

                )
|
return token_ids |
|
return self.decode(token_ids, skip_special_tokens=False) |
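
    # Usage sketch (assuming `tok` is a loaded NanoGPTChatTokenizer; the
    # message is illustrative):
    #
    #   ids = tok.apply_chat_template(
    #       [{"role": "user", "content": "What is 2 + 2?"}],
    #       tokenize=True,
    #       add_generation_prompt=True,
    #   )
    #   # `ids` ends with <|assistant_start|>, ready for generation.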
|
|
|
def encode_chat_message(self, role: str, content: str) -> List[int]: |
|
rendered = self.apply_chat_template( |
|
[ |
|
{"role": role, "content": content}, |
|
], |
|
tokenize=True, |
|
add_generation_prompt=False, |
|
) |
|
return rendered |
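
    # Note: `_render_conversation_ids` requires turns to start with "user", so
    # this helper only supports role="user" for a lone message.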
|