# CaroTTS-DE / char_tokenizers.py

import logging
import unicodedata
from abc import ABC, abstractmethod
from typing import List


def normalize_unicode_text(text: str) -> str:
    """Normalize text to NFC form so that composed and decomposed spellings map to the same codepoints."""
    if not unicodedata.is_normalized("NFC", text):
        text = unicodedata.normalize("NFC", text)
    return text
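
# Illustrative example (not part of the original file): a decomposed "u" plus a
# combining diaeresis (U+0308) becomes the single codepoint "ü" under NFC, so both
# spellings end up as the same token.
#   normalize_unicode_text("u\u0308")  ->  "ü"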


def any_locale_text_preprocessing(text: str) -> str:
    """Apply NFC normalization and replace the typographic apostrophe (’) with the plain ASCII apostrophe."""
    res = []
    for c in normalize_unicode_text(text):
        if c == '’':
            res.append("'")
        else:
            res.append(c)
    return ''.join(res)
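
# Illustrative example (not part of the original file): only the typographic
# apostrophe is rewritten; every other character passes through unchanged.
#   any_locale_text_preprocessing("Gru’s Plan")  ->  "Gru's Plan"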


class BaseTokenizer(ABC):
    PAD, BLANK, OOV = '<pad>', '<blank>', '<oov>'

    def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None):
        """Abstract base class for tokenizers that convert a string into a list of int tokens.

        Args:
            tokens: List of tokens.
            pad: Pad token as string.
            blank: Blank token as string.
            oov: OOV token as string.
            sep: Separation token as string.
            add_blank_at: Where to put the blank token: "last" places it at the very end of the
                vocabulary, any other non-None value places it right after the regular tokens;
                if None, no blank token is added.
        """
        super().__init__()

        tokens = list(tokens)
        # TODO @xueyang: in general, the IDs of pad, sil, blank, and oov are reserved up front
        # rather than assigned dynamically from the current vocabulary size. The downside of
        # dynamic assignment is that each tokenizer ends up with different IDs for these
        # utility tokens.
        self.pad, tokens = len(tokens), tokens + [pad]  # Padding

        if add_blank_at is not None:
            self.blank, tokens = len(tokens), tokens + [blank]  # Reserved for blank from asr-model
        else:
            # use add_blank_at=None only for ASR where blank is added automatically, disable blank here
            self.blank = None

        self.oov, tokens = len(tokens), tokens + [oov]  # Out Of Vocabulary

        if add_blank_at == "last":
            tokens[-1], tokens[-2] = tokens[-2], tokens[-1]
            self.oov, self.blank = self.blank, self.oov

        self.tokens = tokens
        self.sep = sep

        self._util_ids = {self.pad, self.blank, self.oov}
        self._token2id = {token: i for i, token in enumerate(tokens)}
        self._id2token = tokens
    def __call__(self, text: str) -> List[int]:
        return self.encode(text)

    @abstractmethod
    def encode(self, text: str) -> List[int]:
        """Turns str text into int tokens."""
        pass

    def decode(self, tokens: List[int]) -> str:
        """Turns int tokens back into str text, skipping the pad/blank/oov utility tokens."""
        return self.sep.join(self._id2token[t] for t in tokens if t not in self._util_ids)
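
# Illustrative vocabulary layout (assumed example, not from the original file): with
# tokens=['a', 'b'] and add_blank_at=None, the IDs are a=0, b=1, <pad>=2, <oov>=3 and
# blank is disabled. With add_blank_at="last", the final vocabulary becomes
# ['a', 'b', '<pad>', '<oov>', '<blank>'], i.e. pad=2, oov=3, blank=4.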


class GermanCharsTokenizer(BaseTokenizer):
    _PUNCT_LIST = [
        '!', '"', '(', ')', ',', '-', '.', '/', ':', ';', '?', '[', ']', '{', '}',
        '«', '»', '‒', '–', '—', '‘', '‚', '“', '„', '‹', '›',
    ]
    _CHARSET_STR = 'ABCDEFGHIJKLMNOPQRSTUVWXYZÄÖÜẞabcdefghijklmnopqrstuvwxyzäöüß'

    # Fallback punctuation set; replaced by `non_default_punct_list` in __init__ when one is given.
    PUNCT_LIST = (
        ',', '.', '!', '?', '-',
        ':', ';', '/', '"', '(',
        ')', '[', ']', '{', '}',
    )

    def __init__(
        self,
        chars=_CHARSET_STR,
        punct=True,
        apostrophe=True,
        add_blank_at=None,
        pad_with_space=True,
        non_default_punct_list=_PUNCT_LIST,
        text_preprocessing_func=any_locale_text_preprocessing,
    ):
        tokens = []
        self.space, tokens = len(tokens), tokens + [' ']  # Space is always token 0
        tokens.extend(chars)
        if apostrophe:
            tokens.append("'")  # Apostrophe, to keep contractions and possessives such as "don't" and "Joe's"
        if punct:
            if non_default_punct_list is not None:
                self.PUNCT_LIST = non_default_punct_list
            tokens.extend(self.PUNCT_LIST)

        super().__init__(tokens, add_blank_at=add_blank_at)

        self.punct = punct
        self.pad_with_space = pad_with_space
        self.text_preprocessing_func = text_preprocessing_func
    def encode(self, text):
        """See base class."""
        cs, space, tokens = [], self.tokens[self.space], set(self.tokens)

        text = self.text_preprocessing_func(text)
        for c in text:
            # Add a whitespace if the current char is a whitespace while the previous char is not a whitespace.
            if c == space and len(cs) > 0 and cs[-1] != space:
                cs.append(c)
            # Add the current char if it is alphanumeric or an apostrophe and part of the vocabulary.
            elif (c.isalnum() or c == "'") and c in tokens:
                cs.append(c)
            # Add a single-char punctuation mark.
            elif (c in self.PUNCT_LIST) and self.punct:
                cs.append(c)
            # Warn about unknown chars; they are skipped.
            elif c != space:
                logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")

        # Remove trailing spaces.
        if cs:
            while cs[-1] == space:
                cs.pop()

        if self.pad_with_space:
            cs = [space] + cs + [space]

        return [self._token2id[p] for p in cs]
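

# Minimal usage sketch (assumed, not part of the original file): encode a sentence and
# decode it back. With the default pad_with_space=True, the decoded string keeps the
# leading and trailing spaces that were added as padding.
if __name__ == "__main__":
    tokenizer = GermanCharsTokenizer()
    ids = tokenizer("Schöne Grüße aus München!")
    print(ids)                    # first and last IDs are 0, the space token, due to pad_with_space
    print(tokenizer.decode(ids))  # " Schöne Grüße aus München! "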