from typing import Dict, List, Union
import logging

from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


class Byt5LangTokenizer(PreTrainedTokenizer):
    """
    Custom byte-level tokenizer for ByT5 models with table-recognition support.
    Used for the vikp/surya_table model.

    Byte values map directly to IDs 0-255 and the special tokens sit immediately
    after the byte range, so no sentencepiece vocabulary file is required; the
    class therefore builds directly on PreTrainedTokenizer.
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs,
    ):
        # vocab_file, tokenizer_file and sp_model_kwargs are accepted only for
        # interface compatibility; a byte-level tokenizer does not use them.

        # Mirror the T5/ByT5 behaviour of exposing extra_ids as additional special tokens.
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]

        # byte_decoder maps token IDs back to raw bytes -- important for ByT5.
        self.byte_decoder = {i: bytes([i]) for i in range(256)}

        # Special tokens are placed right after the byte range (bytes occupy IDs 0-255).
        self.special_tokens = {
            eos_token: 256,
            unk_token: 257,
            pad_token: 258,
        }

        # Lookup tables used by the conversion methods below; they must exist
        # before super().__init__(), which already calls get_vocab().
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return 256 + len(self.special_tokens_encoder)

    def get_vocab(self) -> Dict[str, int]:
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        # Tokens are the raw UTF-8 byte values of the text.
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        # Byte values produced by _tokenize are already IDs.
        if isinstance(token, int):
            return token
        if token in self.special_tokens_encoder:
            return self.special_tokens_encoder[token]
        try:
            return ord(token)
        except TypeError:
            # Multi-character string that is not a known special token.
            return self.special_tokens_encoder[str(self.unk_token)]

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            elif token in self.special_tokens_encoder:
                decoded += token.encode("utf-8")
            elif len(token) == 1 and ord(token) < 256:
                # Single characters produced by _convert_id_to_token stand for raw byte values.
                decoded += bytes([ord(token)])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")
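

# --- Usage sketch ---------------------------------------------------------
# A minimal smoke test, assuming the byte/ID layout defined above (bytes at
# 0-255, </s>/<unk>/<pad> at 256-258). The sample string is illustrative only;
# in practice the tokenizer is paired with the vikp/surya_table model.
if __name__ == "__main__":
    tokenizer = Byt5LangTokenizer()

    encoded = tokenizer("1,2\n3,4")
    print(encoded["input_ids"])  # raw UTF-8 byte values, e.g. [49, 44, 50, ...]

    decoded = tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)
    print(decoded)  # the text round-trips through its UTF-8 bytes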