"""Utilities for patching a Hugging Face `tokenizers` pre-tokenizer so that
vocabulary entries containing spaces (multiword units) are not split apart
before the model sees them."""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from tokenizers import Tokenizer

# Pre-tokenizers that split on whitespace or punctuation, which makes
# multiword tokens impossible. Any of these is dropped outright.
_FORBIDDEN_PRETOKENIZERS = (
    "Whitespace",
    "WhitespaceSplit",
    "BertPreTokenizer",
    "CharDelimiterSplit",
    "Punctuation",
    "Split",
    "UnicodeScripts",
)

# Fallback pre-tokenizer: Metaspace marks word boundaries with "▁" but, with
# `split` disabled, never breaks the input into separate pieces.
_BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}


def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
    """Fix a single pre-tokenizer so it allows multiword units, or drop it.

    Returns the (possibly modified) pre-tokenizer JSON, or None if the
    pre-tokenizer splits on whitespace/punctuation and must be removed.
    """
    if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS:
        return None
    if pre_tokenizer["type"] == "ByteLevel":
        # Keep the byte-level encoding, but stop splitting on the GPT-2 regex.
        pre_tokenizer["add_prefix_space"] = True
        pre_tokenizer["use_regex"] = False
    if pre_tokenizer["type"] == "Metaspace":
        pre_tokenizer["split"] = False
        pre_tokenizer["prepend_scheme"] = "always"
    return pre_tokenizer


def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
    """Rewrite a tokenizer's pre-tokenizer(s) so that multiword units survive.

    The tokenizer is round-tripped through its JSON representation: each
    pre-tokenizer is fixed or removed, and if nothing usable remains, a basic
    non-splitting Metaspace pre-tokenizer is installed instead.
    """
    tokenizer_json = json.loads(tokenizer.to_str())
    pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None)
    if pre_tokenizer_json is None:
        # Copy so the module-level default is never mutated downstream.
        pre_tokenizer_json = dict(_BASIC_METASPACE)
    elif pre_tokenizer_json["type"] == "Sequence":
        # Fix every member of the sequence, dropping the forbidden ones.
        new_pretokenizers = []
        for single_pretokenizer in pre_tokenizer_json["pretokenizers"]:
            new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
            if new_pretokenizer is not None:
                new_pretokenizers.append(new_pretokenizer)
        if new_pretokenizers:
            pre_tokenizer_json["pretokenizers"] = new_pretokenizers
        else:
            # Every member was forbidden: fall back to the basic Metaspace.
            pre_tokenizer_json = dict(_BASIC_METASPACE)
    pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or dict(_BASIC_METASPACE)
    tokenizer_json["pre_tokenizer"] = pre_tokenizer_json
    return tokenizer.from_str(json.dumps(tokenizer_json))
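

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of this module's API).
    # Assumes the `tokenizers` package is installed and that the example
    # checkpoint "bert-base-uncased" can be fetched from the Hugging Face Hub.
    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")
    patched = replace_pretokenizer(tok)

    # BERT's stock pre-tokenizer splits "new york" at the space; the patched
    # tokenizer keeps it as a single span ("▁new▁york"), so a multiword entry
    # added to the vocabulary could match it.
    print(patched.pre_tokenizer.pre_tokenize_str("new york"))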