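"""Utilities for rebuilding a tokenizer around a new vocabulary.

The helpers in this module clean an existing tokenizer's vocabulary, add
user-supplied tokens, and rewrite the backend tokenizer so that it can be
wrapped in a new PreTrainedTokenizerFast.
"""
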
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, cast

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from distiller.model2vec.tokenizer.datamodels import Token
from distiller.model2vec.tokenizer.model import process_tokenizer
from distiller.model2vec.tokenizer.normalizer import replace_normalizer
from distiller.model2vec.tokenizer.pretokenizer import replace_pretokenizer

if TYPE_CHECKING:
    import re

    from tokenizers.normalizers import Normalizer
    from tokenizers.pre_tokenizers import PreTokenizer

logger = logging.getLogger(__name__)


# A pass-through post-processor: single sequences and pairs are emitted as-is,
# without inserting any special tokens.
_DEFAULT_POST_PROCESSOR_TEMPLATE = {
    "type": "TemplateProcessing",
    "single": [{"Sequence": {"id": "A", "type_id": 0}}],
    "pair": [{"Sequence": {"id": "A", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 0}}],
    "special_tokens": {},
}


def _remap_added_tokens(
    special_tokens: list[dict[str, Any]],
    vocabulary: list[str],
) -> list[dict[str, Any]]:
    """
    Remap the ids of special tokens to their positions in a new vocabulary.

    Each special token is looked up in the vocabulary and its "id" field is set to the
    index found there, so the special tokens must already be present in the vocabulary.

    :param special_tokens: The special tokens to remap.
    :param vocabulary: The vocabulary as a list of tokens.
    :return: The updated special tokens.
    """
    # Copy the token dicts so the caller's list is not mutated.
    special_tokens = [{**x} for x in special_tokens]
    for token in special_tokens:
        token["id"] = vocabulary.index(token["content"])

    return special_tokens


def replace_vocabulary(
    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
) -> Tokenizer:
    """Replace the vocabulary of a tokenizer with a new one."""
    tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

    pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]

    # Rename the original unk and pad tokens to the canonical [UNK] and [PAD] forms,
    # both in the added tokens and in the new vocabulary.
    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

    # Keep only [UNK] and [PAD] as added tokens, then rebuild the tokenizer model
    # around the new vocabulary.
    tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
    tokenizer_json = process_tokenizer(
        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
    )

    # Re-point the remaining added tokens at their ids in the new vocabulary and
    # install the pass-through post-processor.
    tokenizer_json["added_tokens"] = _remap_added_tokens(
        special_tokens=tokenizer_json["added_tokens"],
        vocabulary=pre_tokenized_tokens,
    )
    tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

    return Tokenizer.from_str(json.dumps(tokenizer_json))
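# Sketch of how replace_vocabulary is typically driven (the "[UNK]"/"[PAD]" strings here
# stand in for the tokenizer's own special tokens, if it has them):
#
#     cleaned, backend = clean_and_create_vocabulary(hf_tokenizer, extra_tokens, None)
#     backend = replace_vocabulary(backend, cleaned, unk_token="[UNK]", pad_token="[PAD]")
#
# hf_tokenizer and extra_tokens are placeholders for a PreTrainedTokenizerFast and a
# list of strings supplied by the caller.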


def _rename_added_token(
    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
    """Rename an added token, updating both the added-token entry and the vocabulary in place."""
    if form is None:
        return added_tokens

    idx = vocabulary.index(form)
    added_token = [x for x in added_tokens if x["content"] == form]
    if added_token:
        added_token[0]["id"] = idx
        added_token[0]["content"] = new_form
        vocabulary[idx] = new_form

    return added_tokens


def clean_and_create_vocabulary(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], Tokenizer]:
    """Clean a vocabulary by dropping empty tokens, duplicates, and tokens already in the internal vocabulary."""
    seen_tokens = set()
    post_normalize_seen_tokens = set()
    n_empty = 0
    n_duplicates = 0

    backend_tokenizer = tokenizer.backend_tokenizer

    # Sort the internal vocabulary by token id so it lines up with the original model.
    internal_vocab: dict[str, int] = tokenizer.get_vocab()
    internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]

    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)

    # Work on a copy of the backend tokenizer and swap in the replacement normalizer.
    backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
    backend_tokenizer = replace_normalizer(backend_tokenizer)

    internal_tokens_set = {token.form for token in cleaned_vocabulary}

    normalizer: Normalizer | None = backend_tokenizer.normalizer
    for token in vocabulary:
        if normalizer is not None:
            token = cast("str", normalizer.normalize_str(token))

        if not token:
            n_empty += 1
            continue

        pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
        normalized_token = token
        if pre_tokenizer is not None:
            normalized_token = _normalize_vocabulary_token(
                token=token,
                pre_tokenizer=pre_tokenizer,
            )

        # Skip tokens we have already seen, or that already exist in the internal vocabulary.
        if normalized_token in seen_tokens or normalized_token in internal_tokens_set:
            n_duplicates += 1
            continue

        seen_tokens.add(normalized_token)

        # Normalize the word-boundary marker: tokens without a "▁" or "Ġ" prefix get a
        # leading "▁" and have internal spaces replaced by "▁"; tokens that already carry
        # a marker have internal spaces replaced by that same marker.
        if not normalized_token.startswith(("▁", "Ġ")):
            normalized_token = normalized_token.replace(" ", "▁")
            normalized_token = f"▁{normalized_token}"
        else:
            normalized_token = normalized_token.replace(" ", normalized_token[0])

        if normalized_token in post_normalize_seen_tokens:
            n_duplicates += 1
            continue

        post_normalize_seen_tokens.add(normalized_token)

        cleaned_vocabulary.append(
            Token(form=token, normalized_form=normalized_token, is_subword=False, is_internal=False)
        )

    if n_duplicates:
        logger.warning(f"Removed {n_duplicates} duplicate tokens.")
    if n_empty:
        logger.warning(f"Removed {n_empty} empty tokens.")

    return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)
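# Example (sketch): clean a fast tokenizer's vocabulary and add two new multi-word
# tokens. The checkpoint name is illustrative only.
#
#     from transformers import AutoTokenizer
#
#     base = AutoTokenizer.from_pretrained("bert-base-uncased")
#     tokens, backend = clean_and_create_vocabulary(base, ["new york", "machine learning"], None)
#
# tokens contains Token objects for the surviving internal tokens plus the new ones;
# backend is a copy of the backend tokenizer with its normalizer and pre-tokenizer replaced.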


def _process_internal_tokens(
    tokenizer: PreTrainedTokenizerFast,
    backend_tokenizer: Tokenizer,
    internal_tokens: list[str],
    token_remove_regex: re.Pattern | None,
) -> list[Token]:
    """Clean the tokenizer's internal tokens."""
    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")
    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")

    added_tokens_to_keep: set[str] = {x for x in (pad_token, unk_token) if x is not None}
    added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep
    cleaned_internal_tokens: list[Token] = []

    # Encode a space followed by a long run of "a"s to discover how this tokenizer marks
    # word starts and continuations (e.g. "▁"/"Ġ" word prefixes or "##" subword prefixes).
    encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
    first_token, second_token, *_ = encoded.tokens

    # Everything before the first "a" in the first token is the word-start prefix.
    a_index = None if "a" not in first_token else first_token.index("a")
    word_prefix = first_token[:a_index]
    is_byte_prefix = word_prefix == "Ġ"

    # Everything before the first "a" in the second token is the continuation (subword) prefix.
    a_index = None if "a" not in second_token else second_token.index("a")
    subword_prefix = second_token[:a_index]

    pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer

    for token in internal_tokens:
        if token_object := _create_single_internal_token(
            token=token,
            subword_prefix=subword_prefix,
            word_prefix=word_prefix,
            pre_tokenizer=pre_tokenizer,
            is_byte_prefix=is_byte_prefix,
            token_remove_regex=token_remove_regex,
            added_tokens_to_keep=added_tokens_to_keep,
            added_tokens_to_remove=added_tokens_to_remove,
        ):
            cleaned_internal_tokens.append(token_object)

    if len(cleaned_internal_tokens) != len(internal_tokens):
        logger.info(
            f"Removed {len(internal_tokens) - len(cleaned_internal_tokens)} internal tokens from the vocabulary."
        )

    return cleaned_internal_tokens


def _create_single_internal_token(
    token: str,
    subword_prefix: str,
    word_prefix: str,
    pre_tokenizer: PreTokenizer | None,
    is_byte_prefix: bool,
    token_remove_regex: re.Pattern | None,
    added_tokens_to_keep: set[str],
    added_tokens_to_remove: set[str],
) -> Token | None:
    """Create a token object from a string, or None if the token should be dropped."""
    if token in added_tokens_to_remove:
        # Added tokens we do not want to keep are dropped outright.
        return None
    if token in added_tokens_to_keep:
        # Special tokens such as unk and pad are kept verbatim.
        return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
    if token_remove_regex and token_remove_regex.match(token):
        # Tokens matching the removal regex are dropped.
        return None

    # Decide whether this token continues a word. If the tokenizer marks continuations
    # explicitly (e.g. "##"), a token is a subword when it carries that prefix; if the
    # tokenizer marks word starts instead (e.g. "▁" or "Ġ"), a token is a subword when
    # it lacks the word-start prefix.
    is_subword = False
    if subword_prefix:
        is_subword = token.startswith(subword_prefix)
    if word_prefix:
        is_subword = not token.startswith(word_prefix)

    if pre_tokenizer is not None and not is_byte_prefix:
        if (subword_prefix and not is_subword) or (word_prefix and is_subword):
            # Drop tokens that the pre-tokenizer would split into multiple pieces,
            # since they can never be produced as a single token.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                return None

    normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)

    return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)


def _create_normalized_form(
    token: str, subword_prefix: str, word_prefix: str, is_byte_prefix: bool, is_subword: bool
) -> str:
    """Turn an internal token string into a normalized form."""
    # Byte-level tokens already encode word boundaries, so they are left untouched.
    if is_byte_prefix:
        return token

    # Subword continuations lose their continuation prefix.
    if is_subword:
        return token.removeprefix(subword_prefix)

    # Word-start tokens are marked with the metaspace character.
    return f"▁{token.removeprefix(word_prefix)}"


def turn_tokens_into_ids(
    tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str | None
) -> list[list[int]]:
    """
    Convert a list of Token objects to their corresponding token ID sequences.

    :param tokens: List of Token objects to convert.
    :param tokenizer: The tokenizer to use for converting tokens to IDs.
    :param unk_token: The string form of the unk token.
    :return: List of token ID sequences corresponding to the input tokens.
    """
    unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
    prefix, suffix = find_eos_bos(tokenizer)

    token_ids: list[list[int]] = []
    for token in tokens:
        if token.is_internal:
            # Internal tokens map to a single id; fall back to 0 if the lookup fails.
            token_id: int = cast("int", tokenizer.convert_tokens_to_ids(token.form)) or 0

            if unk_id is not None and token_id == unk_id and token.form != unk_token:
                logger.warning(f"Token {token.form} was mapped to the unk token. This should not happen.")
            token_ids.append([*prefix, token_id, *suffix])
        else:
            # New tokens are encoded in full by the tokenizer.
            token_ids.append(tokenizer.encode(token.form))

    return token_ids
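# For example, with a BERT-style tokenizer an internal token "dog" becomes
# [cls_id, dog_id, sep_id], while a new multi-word token such as "new york" is encoded
# in full and may span several ids. The exact ids depend on the tokenizer.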


def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
    """Find the special-token ids a tokenizer prepends and appends to a single sequence."""
    # Encode a single character and see which ids surround it.
    encoding = tokenizer.encode("a", add_special_tokens=True)
    if len(encoding) != 3:
        # The tokenizer does not add exactly one token on each side, so locate "a" explicitly.
        a_encoded = tokenizer.encode("a", add_special_tokens=False)
        if len(a_encoded) != 1:
            msg = (
                "Error while encoding, couldn't determine eos and bos tokens. "
                f"The model tokenizes 'a' to '{a_encoded}'"
            )
            raise ValueError(msg)
        a_idx = encoding.index(a_encoded[0])
        prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
    else:
        prefix, suffix = encoding[:1], encoding[2:]
    return prefix, suffix
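# For a BERT-style tokenizer this would typically return ([cls_token_id], [sep_token_id]);
# for a tokenizer that adds no special tokens it returns ([], []).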


def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
    """Normalize a token that is not in the initial token vocabulary."""
    # Add a leading space so the pre-tokenizer treats the token as word-initial.
    prefixed_token = f" {token}"
    pretokenized_tokens: tuple[str, ...]
    pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token), strict=False)

    new_token = [pretokenized_tokens[0]]

    for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:], strict=False):
        if t.startswith("▁"):
            # The piece already carries a word-boundary marker.
            new_token.append(t)
        elif prefixed_token[s - 1] == " ":
            # The piece followed a space that the pre-tokenizer stripped, so put it back.
            new_token.append(f" {t}")
        else:
            # The piece is directly adjacent to the previous one.
            new_token.append(t)
    return "".join(new_token)


def create_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    vocabulary: list[str],
    token_remove_regex: re.Pattern | None = None,
) -> PreTrainedTokenizerFast:
    """
    Create a tokenizer by adding tokens to the vocabulary.

    This function turns any tokenizer into a supertoken tokenizer. It does the following:
    1. Turns the tokenizer model into a unigram model.
    2. Adds a new pre-tokenizer that splits on punctuation.
    3. Adds all tokens in the vocabulary to the model.
    4. Removes any internal tokens that match the regex.

    :param tokenizer: The tokenizer to use.
    :param vocabulary: The vocabulary to add.
    :param token_remove_regex: The regex used to remove tokens from the internal vocabulary.
    :return: The created tokenizer.
    """
    unk_token = cast("str | None", tokenizer.special_tokens_map.get("unk_token"))
    pad_token = cast("str | None", tokenizer.special_tokens_map.get("pad_token"))
    cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
    new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)

    return PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)
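# A minimal usage sketch (the checkpoint name and regex are illustrative, not requirements):
#
#     import re
#     from transformers import AutoTokenizer
#
#     base = AutoTokenizer.from_pretrained("bert-base-uncased")
#     tok = create_tokenizer(base, ["new york", "machine learning"], token_remove_regex=re.compile(r"^\[unused"))
#     tok.tokenize("machine learning in new york")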