import logging
import json
import re
import torch
from pathlib import Path
from unicodedata import category
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
logger = logging.getLogger(__name__)
class EnTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    def text_to_tokens(self, text: str):
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, verbose=False):
        """
        clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
        """
        txt = txt.replace(' ', SPACE)
        code = self.tokenizer.encode(txt)
        ids = code.ids
        return ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
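

# Example usage (sketch): round-trip a sentence through EnTokenizer. The
# "tokenizer.json" default below is an assumption; point it at whichever vocab
# file ships with your checkpoint.
def _example_en_tokenizer(vocab_file_path: str = "tokenizer.json") -> str:
    tok = EnTokenizer(vocab_file_path)
    ids = tok.text_to_tokens("Hello world")     # IntTensor of shape (1, num_tokens)
    return tok.decode(ids.squeeze(0).tolist())  # spaces restored from [SPACE]

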
# Model repository
REPO_ID = "ResembleAI/chatterbox"

# Global instances for optional dependencies
_kakasi = None
_dicta = None


def is_kanji(c: str) -> bool:
    """Check if a character is a kanji (CJK Unified Ideographs, U+4E00 to U+9FFF)."""
    return 19968 <= ord(c) <= 40959


def is_katakana(c: str) -> bool:
    """Check if a character is katakana (U+30A1 to U+30FA)."""
    return 12449 <= ord(c) <= 12538
def hiragana_normalize(text: str) -> str:
    """Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
    global _kakasi
    try:
        if _kakasi is None:
            import pykakasi
            _kakasi = pykakasi.kakasi()

        result = _kakasi.convert(text)
        out = []
        for r in result:
            inp = r['orig']
            hira = r['hira']
            # Any kanji in the phrase
            if any(is_kanji(c) for c in inp):
                if hira and hira[0] in ["は", "へ"]:  # Safety check for empty hira
                    hira = " " + hira
                out.append(hira)
            # All katakana (safety check for empty inp)
            elif inp and all(is_katakana(c) for c in inp):
                out.append(r['orig'])
            else:
                out.append(inp)
        normalized_text = "".join(out)

        # Decompose Japanese characters for tokenizer compatibility
        import unicodedata
        normalized_text = unicodedata.normalize('NFKD', normalized_text)
        return normalized_text
    except ImportError:
        logger.warning("pykakasi not available - Japanese text processing skipped")
        return text
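

# Quick sketch of how hiragana_normalize is meant to be used (requires pykakasi;
# the sample string is arbitrary). Kanji are replaced by their hiragana reading and
# the result is NFKD-decomposed, so voiced marks (dakuten) become combining characters.
def _example_japanese(text: str = "日本語") -> str:
    return hiragana_normalize(text)  # falls back to the raw input if pykakasi is missing

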
def add_hebrew_diacritics(text: str) -> str:
    """Hebrew text normalization: adds diacritics to Hebrew text."""
    global _dicta
    try:
        if _dicta is None:
            from dicta_onnx import Dicta
            _dicta = Dicta()
        return _dicta.add_diacritics(text)
    except ImportError:
        logger.warning("dicta_onnx not available - Hebrew text processing skipped")
        return text
    except Exception as e:
        logger.warning(f"Hebrew diacritization failed: {e}")
        return text
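

# Companion sketch for Hebrew: adds diacritics via dicta_onnx when it is installed,
# otherwise returns the input unchanged (the sample string is arbitrary).
def _example_hebrew(text: str = "שלום עולם") -> str:
    return add_hebrew_diacritics(text)

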
def korean_normalize(text: str) -> str:
    """Korean text normalization: decompose syllables into Jamo for tokenization."""

    def decompose_hangul(char):
        """Decompose a Korean syllable into its Jamo components."""
        if not ('\uac00' <= char <= '\ud7af'):
            return char
        # Hangul decomposition formula
        base = ord(char) - 0xAC00
        initial = chr(0x1100 + base // (21 * 28))
        medial = chr(0x1161 + (base % (21 * 28)) // 28)
        final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
        return initial + medial + final

    # Decompose each syllable and strip leading/trailing whitespace
    result = ''.join(decompose_hangul(char) for char in text)
    return result.strip()
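

# Worked example of the decomposition formula above: '한' is U+D55C, so
#   base    = 0xD55C - 0xAC00 = 10588
#   initial = 0x1100 + 10588 // 588        -> U+1112 (choseong hieuh)
#   medial  = 0x1161 + (10588 % 588) // 28 -> U+1161 (jungseong a)
#   final   = 0x11A7 + 10588 % 28          -> U+11AB (jongseong nieun)
# so korean_normalize("한") yields "\u1112\u1161\u11ab".
def _example_korean() -> bool:
    return korean_normalize("한") == "\u1112\u1161\u11ab"

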
class ChineseCangjieConverter:
    """Converts Chinese characters to Cangjie codes for tokenization."""

    def __init__(self, model_dir=None):
        self.word2cj = {}
        self.cj2word = {}
        self.segmenter = None
        self._load_cangjie_mapping(model_dir)
        self._init_segmenter()

    def _load_cangjie_mapping(self, model_dir=None):
        """Load Cangjie mapping from HuggingFace model repository."""
        try:
            cangjie_file = hf_hub_download(
                repo_id=REPO_ID,
                filename="Cangjie5_TC.json",
                cache_dir=model_dir
            )
            with open(cangjie_file, "r", encoding="utf-8") as fp:
                data = json.load(fp)
            for entry in data:
                word, code = entry.split("\t")[:2]
                self.word2cj[word] = code
                if code not in self.cj2word:
                    self.cj2word[code] = [word]
                else:
                    self.cj2word[code].append(word)
        except Exception as e:
            logger.warning(f"Could not load Cangjie mapping: {e}")

    def _init_segmenter(self):
        """Initialize pkuseg segmenter."""
        try:
            from pkuseg import pkuseg
            self.segmenter = pkuseg()
        except ImportError:
            logger.warning("pkuseg not available - Chinese segmentation will be skipped")
            self.segmenter = None

    def _cangjie_encode(self, glyph: str):
        """Encode a single Chinese glyph to its Cangjie code."""
        normed_glyph = glyph
        code = self.word2cj.get(normed_glyph, None)
        if code is None:  # e.g. Japanese hiragana
            return None
        # Disambiguate glyphs sharing the same Cangjie code by appending their index
        index = self.cj2word[code].index(normed_glyph)
        index = str(index) if index > 0 else ""
        return code + str(index)

    def __call__(self, text):
        """Convert Chinese characters in text to Cangjie tokens."""
        output = []
        if self.segmenter is not None:
            segmented_words = self.segmenter.cut(text)
            full_text = " ".join(segmented_words)
        else:
            full_text = text
        for t in full_text:
            if category(t) == "Lo":
                cangjie = self._cangjie_encode(t)
                if cangjie is None:
                    output.append(t)
                    continue
                code = []
                for c in cangjie:
                    code.append(f"[cj_{c}]")
                code.append("[cj_.]")
                code = "".join(code)
                output.append(code)
            else:
                output.append(t)
        return "".join(output)
class MTLTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        model_dir = Path(vocab_file_path).parent
        self.cangjie_converter = ChineseCangjieConverter(model_dir)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    def text_to_tokens(self, text: str, language_id: str = None):
        text_tokens = self.encode(text, language_id=language_id)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, language_id: str = None):
        # Language-specific text processing
        if language_id == 'zh':
            txt = self.cangjie_converter(txt)
        elif language_id == 'ja':
            txt = hiragana_normalize(txt)
        elif language_id == 'he':
            txt = add_hebrew_diacritics(txt)
        elif language_id == 'ko':
            txt = korean_normalize(txt)

        # Prepend language token
        if language_id:
            txt = f"[{language_id.lower()}]{txt}"

        txt = txt.replace(' ', SPACE)
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
        return txt
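

# End-to-end sketch for the multilingual tokenizer. The vocab filename below is an
# assumption (use whatever tokenizer file the checkpoint provides); language_id selects
# the per-language preprocessing above and is prepended as a "[ko]"/"[ja]"/... token.
def _example_mtl(vocab_file_path: str = "mtl_tokenizer.json") -> str:
    tok = MTLTokenizer(vocab_file_path)
    ids = tok.text_to_tokens("안녕하세요", language_id="ko")  # IntTensor, shape (1, num_tokens)
    return tok.decode(ids.squeeze(0).tolist())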