from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from tokenizers import Tokenizer

# Pre-tokenizers that split text (on whitespace, punctuation, delimiters, or
# script boundaries). Any split breaks multiword units, so these are removed.
_FORBIDDEN_PRETOKENIZERS = (
    "Whitespace",
    "WhitespaceSplit",
    "BertPreTokenizer",
    "CharDelimiterSplit",
    "Punctuation",
    "Split",
    "UnicodeScripts",
)
# Fallback pre-tokenizer: marks word boundaries with "▁" but never splits on them.
_BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
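# A minimal sketch of why "split": False matters, assuming a recent `tokenizers`
# release (one whose Metaspace accepts `prepend_scheme`); illustrative only:
#
#     from tokenizers import pre_tokenizers
#     pt = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=False)
#     pt.pre_tokenize_str("new york")  # [("▁new▁york", (0, 8))] -- one piece
#
# With the default split=True, "new york" would be cut into two pieces, so a
# multiword unit could never surface as a single token.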


def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
    """Fixes a single pretokenizer to allow multiword units, or returns None if it must be removed."""
    if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS:
        return None
    if pre_tokenizer["type"] == "ByteLevel":
        # Keep byte-level encoding, but disable the regex that splits on
        # whitespace and punctuation, and always add the prefix space.
        pre_tokenizer["add_prefix_space"] = True
        pre_tokenizer["use_regex"] = False
    if pre_tokenizer["type"] == "Metaspace":
        # Keep the "▁" replacement, but stop splitting on it.
        pre_tokenizer["split"] = False
        pre_tokenizer["prepend_scheme"] = "always"

    return pre_tokenizer
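
# Behavior at a glance (these values follow directly from the rules above):
#
#     _fix_single_pretokenizer({"type": "Whitespace"})
#     # -> None (forbidden: it splits on whitespace)
#
#     _fix_single_pretokenizer({"type": "Metaspace", "replacement": "▁",
#                               "prepend_scheme": "first", "split": True})
#     # -> {"type": "Metaspace", "replacement": "▁",
#     #     "prepend_scheme": "always", "split": False}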


def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
    """Fixes a tokenizer's pretokenizer(s) so that multiword units are preserved."""
    tokenizer_json = json.loads(tokenizer.to_str())
    pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None)

    if pre_tokenizer_json is None:
        # No pre-tokenizer at all: fall back to the basic Metaspace. Copy the
        # constant so later mutation cannot leak into the module-level dict.
        pre_tokenizer_json = dict(_BASIC_METASPACE)
    elif pre_tokenizer_json["type"] == "Sequence":
        # Fix every pre-tokenizer in the sequence, dropping the forbidden ones.
        new_pretokenizers = []
        for single_pretokenizer in pre_tokenizer_json["pretokenizers"]:
            new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
            if new_pretokenizer is not None:
                new_pretokenizers.append(new_pretokenizer)

        if new_pretokenizers:
            pre_tokenizer_json["pretokenizers"] = new_pretokenizers
        else:
            # Every member was forbidden: fall back to the basic Metaspace.
            pre_tokenizer_json = dict(_BASIC_METASPACE)
    else:
        # A single pre-tokenizer: fix it, or fall back if it is forbidden.
        pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or dict(_BASIC_METASPACE)

    tokenizer_json["pre_tokenizer"] = pre_tokenizer_json

    return tokenizer.from_str(json.dumps(tokenizer_json))
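

# A minimal usage sketch, not part of the module above. It assumes the
# `tokenizers` package is installed and network access to the Hugging Face Hub;
# the model name "bert-base-uncased" is purely illustrative.
if __name__ == "__main__":
    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")
    patched = replace_pretokenizer(tok)
    # BERT's pre-tokenizer is in the forbidden list, so the patched tokenizer
    # should now carry the basic Metaspace pre-tokenizer instead.
    print(json.loads(patched.to_str())["pre_tokenizer"])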
|