from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, cast

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from distiller.model2vec.tokenizer.datamodels import Token
from distiller.model2vec.tokenizer.model import process_tokenizer
from distiller.model2vec.tokenizer.normalizer import replace_normalizer
from distiller.model2vec.tokenizer.pretokenizer import replace_pretokenizer

if TYPE_CHECKING:
    import re

    from tokenizers.normalizers import Normalizer
    from tokenizers.pre_tokenizers import (
        PreTokenizer,
    )

logger = logging.getLogger(__name__)

_DEFAULT_POST_PROCESSOR_TEMPLATE = {
    "type": "TemplateProcessing",
    "single": [{"Sequence": {"id": "A", "type_id": 0}}],
    "pair": [{"Sequence": {"id": "A", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 0}}],
    "special_tokens": {},
}


def _remap_added_tokens(
    special_tokens: list[dict[str, Any]],
    vocabulary: list[str],
) -> list[dict[str, Any]]:
    """
    Remap special tokens in the tokenizer.

    This function updates the id of each special token to its position in the new
    vocabulary. All special tokens must be present in the vocabulary.

    :param special_tokens: The special tokens to remap.
    :param vocabulary: The vocabulary as a list of tokens.
    :return: The updated special tokens.
    """
    # Shallow-copy each token dict so the originals are not modified.
    special_tokens = [{**x} for x in special_tokens]
    for token in special_tokens:
        token["id"] = vocabulary.index(token["content"])

    return special_tokens


def replace_vocabulary(
    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
) -> Tokenizer:
    """Replace the vocabulary of a tokenizer with a new one."""
    tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

    pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]

    # We need to remove the added tokens but keep the [UNK] and [PAD] tokens.
    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

    # Remove the old added tokens, keeping only [UNK] and [PAD].
    tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]

    tokenizer_json = process_tokenizer(
        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
    )

    # Remap the special tokens to their new ids.
    tokenizer_json["added_tokens"] = _remap_added_tokens(
        special_tokens=tokenizer_json["added_tokens"],
        vocabulary=pre_tokenized_tokens,
    )

    tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

    return Tokenizer.from_str(json.dumps(tokenizer_json))


def _rename_added_token(
    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
    """Rename an added token in the tokenizer and in the vocabulary."""
    if form is None:
        return added_tokens

    idx = vocabulary.index(form)
    added_token = [x for x in added_tokens if x["content"] == form]
    if added_token:
        added_token[0]["id"] = idx
        added_token[0]["content"] = new_form
        vocabulary[idx] = new_form

    return added_tokens


def clean_and_create_vocabulary(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], Tokenizer]:
    """Clean a vocabulary by removing duplicates and tokens that are already in the internal vocabulary."""
    seen_tokens = set()
    post_normalize_seen_tokens = set()
    n_empty = 0
    n_duplicates = 0

    backend_tokenizer = tokenizer.backend_tokenizer

    # Make a base list of tokens.
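    # The internal vocabulary maps token strings to ids; sorting by id keeps the
    # list in the same order as the original tokenizer.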
    internal_vocab: dict[str, int] = tokenizer.get_vocab()
    internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]
    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)

    # Copy the backend tokenizer to avoid modifying the original.
    backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
    backend_tokenizer = replace_normalizer(backend_tokenizer)
    internal_tokens_set = {token.form for token in cleaned_vocabulary}

    normalizer: Normalizer | None = backend_tokenizer.normalizer
    for token in vocabulary:
        if normalizer is not None:
            token = cast("str", normalizer.normalize_str(token))

        if not token:
            n_empty += 1
            continue

        pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
        normalized_token = token
        if pre_tokenizer is not None:
            normalized_token = _normalize_vocabulary_token(
                token=token,
                pre_tokenizer=pre_tokenizer,
            )

        # We need to check whether the pretokenized token is in the vocabulary.
        # But we need to return the original token, because that will be tokenized
        # again by the tokenizer during featurization.
        if normalized_token in seen_tokens or normalized_token in internal_tokens_set:
            n_duplicates += 1
            continue

        # Add the possibly pretokenized token to seen.
        seen_tokens.add(normalized_token)

        # After checking the token exists, we need to normalize it into the token
        # it will become. For byte tokens, this means we don't do anything. For
        # other types of tokens, we will insert a metaspace.
        # In the case of multiword tokens, we replace any spaces with the metaspace
        # or byte prefix token.
        if not normalized_token.startswith(("▁", "Ġ")):
            normalized_token = normalized_token.replace(" ", "▁")
            normalized_token = f"▁{normalized_token}"
        else:
            normalized_token = normalized_token.replace(" ", normalized_token[0])

        if normalized_token in post_normalize_seen_tokens:
            n_duplicates += 1
            continue

        post_normalize_seen_tokens.add(normalized_token)

        # Add the original string to the vocabulary.
        cleaned_vocabulary.append(
            Token(form=token, normalized_form=normalized_token, is_subword=False, is_internal=False)
        )

    if n_duplicates:
        logger.warning(f"Removed {n_duplicates} duplicate tokens.")
    if n_empty:
        logger.warning(f"Removed {n_empty} empty tokens.")

    return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)


def _process_internal_tokens(
    tokenizer: PreTrainedTokenizerFast,
    backend_tokenizer: Tokenizer,
    internal_tokens: list[str],
    token_remove_regex: re.Pattern | None,
) -> list[Token]:
    """Clean internal tokens."""
    # Get the pad and unk token from the tokenizer.
    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")  # type: ignore[assignment]
    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")  # type: ignore[assignment]
    # Empty set if no pad or unk token is set.
    added_tokens_to_keep: set[str] = {x for x in (pad_token, unk_token) if x is not None}
    added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep

    cleaned_internal_tokens: list[Token] = []

    # Figure out whether token is a subword or not.
    encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
    first_token, second_token, *_ = encoded.tokens
    # Isolate the prefix. We can't do first_token[0] because we don't know
    # how long the prefix is.
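    # Instead, we take everything up to the first "a" as the prefix,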
    # e.g., "Ġaaaa" -> "Ġ"
    a_index = None if "a" not in first_token else first_token.index("a")
    word_prefix = first_token[:a_index]
    is_byte_prefix = word_prefix == "Ġ"

    # The second token is the first subword token.
    # If a tokenizer uses subwords, this token will have been prefixed.
    # We don't know how long the prefix is.
    a_index = None if "a" not in second_token else second_token.index("a")
    subword_prefix = second_token[:a_index]

    pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
    for token in internal_tokens:
        # Create the token objects. If this returns None, it was unsuccessful for some reason.
        if token_object := _create_single_internal_token(
            token=token,
            subword_prefix=subword_prefix,
            word_prefix=word_prefix,
            pre_tokenizer=pre_tokenizer,
            is_byte_prefix=is_byte_prefix,
            token_remove_regex=token_remove_regex,
            added_tokens_to_keep=added_tokens_to_keep,
            added_tokens_to_remove=added_tokens_to_remove,
        ):
            cleaned_internal_tokens.append(token_object)

    if len(cleaned_internal_tokens) != len(internal_tokens):
        logger.info(
            f"Removed {len(internal_tokens) - len(cleaned_internal_tokens)} internal tokens from the vocabulary."
        )

    return cleaned_internal_tokens


def _create_single_internal_token(
    token: str,
    subword_prefix: str,
    word_prefix: str,
    pre_tokenizer: PreTokenizer | None,
    is_byte_prefix: bool,
    token_remove_regex: re.Pattern | None,
    added_tokens_to_keep: set[str],
    added_tokens_to_remove: set[str],
) -> Token | None:
    """Create a token object from a string."""
    if token in added_tokens_to_remove:
        # We remove any added tokens that aren't [UNK] or [PAD].
        return None
    if token in added_tokens_to_keep:
        # Don't put added tokens through the regular motions.
        return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
    if token_remove_regex and token_remove_regex.match(token):
        # If the regex matches, remove the token.
        return None

    # A token is a subword if there is a subword prefix and the token
    # starts with that prefix, or if there is a WORD prefix and the token
    # does not start with that prefix. For metaspace tokenizers, for example:
    # "doghouse" -> ["▁dog", "house"]
    # So we can only tell that "house" is a subword by knowing that it is not prefixed
    # and word-initial tokens are.
    is_subword = False
    if subword_prefix:
        is_subword = bool(token.startswith(subword_prefix))
    if word_prefix:
        is_subword = not bool(token.startswith(word_prefix))

    # Byte prefixed tokenizers don't need to be checked.
    if pre_tokenizer is not None and not is_byte_prefix:
        # We need to check the token without its prefix. If we have a word prefix,
        # we need to check tokens that are subwords. The other way around for subword
        # prefixes.
        if (subword_prefix and not is_subword) or (word_prefix and is_subword):
            # If this is True, the token is unreachable, even though it is a subword token.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                return None

    # Turn the token into a normalized form for later processing.
    normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)

    return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)


def _create_normalized_form(
    token: str, subword_prefix: str, word_prefix: str, is_byte_prefix: bool, is_subword: bool
) -> str:
    """Turn an internal token string into a normalized form."""
    # We don't need to check byte prefixed strings.
    if is_byte_prefix:
        return token

    # We need to check if the token is a subword or not and remove the prefix.
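    # For a WordPiece-style tokenizer, for example, the subword "##ing" becomes "ing",
    # while the word-initial token "dog" becomes "▁dog".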
    if is_subword:
        return token.removeprefix(subword_prefix)

    # If the token is not a subword, we need to remove the word prefix and add a metaspace.
    return f"▁{token.removeprefix(word_prefix)}"


def turn_tokens_into_ids(
    tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str | None
) -> list[list[int]]:
    """
    Convert a list of Token objects to their corresponding token ID sequences.

    :param tokens: List of Token objects to convert.
    :param tokenizer: The tokenizer to use for converting tokens to IDs.
    :param unk_token: The string form of the unk token.
    :return: List of token IDs corresponding to the input tokens.
    """
    unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
    prefix, suffix = find_eos_bos(tokenizer)

    token_ids: list[list[int]] = []
    for token in tokens:
        if token.is_internal:
            # Careful: any incorrect tokens will just get `[UNK]`, so this could go horribly wrong.
            # Cast because the return type is wrong.
            token_id: int = cast("int", tokenizer.convert_tokens_to_ids(token.form)) or 0
            # Explicitly check and warn if `unk_id` appears, but don't crash.
            if unk_id is not None and token_id == unk_id and token.form != unk_token:
                logger.warning(f"Token {token.form} was set to unk. This is wrong.")
            token_ids.append([*prefix, token_id, *suffix])
        else:
            token_ids.append(tokenizer.encode(token.form))

    return token_ids


def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
    """Find the bos (prefix) and eos (suffix) token ids for a tokenizer."""
    # A little bit complicated, because not all tokenizers have eos and bos tokens.
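    # For example, a BERT-style tokenizer encodes "a" as [cls_id, a_id, sep_id], giving a
    # prefix of [cls_id] and a suffix of [sep_id]; a tokenizer without special tokens
    # yields empty prefix and suffix lists.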
    encoding = tokenizer.encode("a", add_special_tokens=True)
    if len(encoding) != 3:
        a_encoded = tokenizer.encode("a", add_special_tokens=False)
        if len(a_encoded) != 1:
            msg = (
                "Error while encoding, couldn't determine eos and bos tokens. "
                f"The model tokenizes 'a' to '{a_encoded}'"
            )
            raise ValueError(msg)
        a_idx = encoding.index(a_encoded[0])
        prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
    else:
        prefix, suffix = encoding[:1], encoding[2:]

    return prefix, suffix


def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
    """Normalize a token that is not in the initial token vocabulary."""
    # Add prefix space for byte tokenizers.
    prefixed_token = f" {token}"
    pretokenized_tokens: tuple[str, ...]
    pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token), strict=False)
    # The first item is always the start of the token.
    new_token = [pretokenized_tokens[0]]
    # Loop over the subtokens and offsets.
    for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:], strict=False):
        # Do not prefix the token with a space if it starts with a metaspace.
        if t.startswith("▁"):
            new_token.append(t)
        # If the character before the subtoken is a space, we have a
        # multiword token, e.g., "room for the moon", which is split into
        # ["room", "for", "the", "moon"].
        # If it doesn't have a space, it is part of a complex multiword token,
        # e.g., "chat-gpt", which is split into ["chat", "-", "gpt"].
        elif prefixed_token[s - 1] == " ":
            new_token.append(f" {t}")
        else:
            new_token.append(t)

    return "".join(new_token)


def create_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None = None,
) -> PreTrainedTokenizerFast:
    """
    Create a tokenizer by adding tokens to the vocabulary.

    This function turns any tokenizer into a supertoken tokenizer. It does the following:
    1. Turns the tokenizer model into a unigram model.
    2. Adds a new pretokenizer, splitting on punctuation.
    3. Adds all tokens in the vocabulary to the model.
    4. Removes any internal tokens that match the regex.

    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to use.
    :param token_remove_regex: The regex to use to remove tokens from the vocabulary.
    :return: The created tokenizer.
    """
    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))

    cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
    new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

    return PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)
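

# A minimal usage sketch, not part of the module's public surface: build a supertoken
# tokenizer from a pretrained model and inspect how it tokenizes a multiword phrase.
# The model name and the example vocabulary below are placeholder assumptions, and the
# exact output depends on the underlying tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Add a couple of multiword supertokens on top of the internal vocabulary.
    supertoken_tokenizer = create_tokenizer(base_tokenizer, vocabulary=["new york", "machine learning"])
    print(supertoken_tokenizer.tokenize("machine learning in new york"))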