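"""
custom_tokenizer.py

A character-level BPE tokenizer tuned for technical and conversational text,
plus a small conversation-dataset helper and a training CLI.

Illustrative invocations (file paths are examples, not shipped data):

    python custom_tokenizer.py --input_files data/train.jsonl --vocab_size 32000
    python custom_tokenizer.py --load_tokenizer tokenizer --test_file data/val.jsonl
"""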
import os
import json
import pickle
import argparse
import random
from collections import Counter, defaultdict
from typing import List, Dict, Optional, Tuple
import re
import unicodedata
class TechnicalTokenizer:
"""
Custom tokenizer optimized for technical content and conversations
"""
def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
self.vocab_size = vocab_size
self.min_freq = min_freq
self.special_tokens = {
'<pad>': 0,
'<unk>': 1,
'<bos>': 2,
'<eos>': 3,
'<system>': 4,
'<user>': 5,
'<assistant>': 6,
'<|endoftext|>': 7,
'<|newline|>': 8,
'<|tab|>': 9,
'<|code|>': 10,
'<|/code|>': 11,
'<|math|>': 12,
'<|/math|>': 13
}
self.vocab = {}
self.id_to_token = {}
self.token_frequencies = Counter()
self.bpe_merges = []
self.bpe_cache = {}
self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
        self.url_pattern = re.compile(r'https?://(?:[a-zA-Z0-9$\-_@.&+!*(),]|%[0-9a-fA-F]{2})+')
        self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
self.number_pattern = re.compile(r'\b\d+\.?\d*\b')
self.technical_terms = {
'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
'equation', 'formula', 'theorem', 'proof', 'hypothesis',
'derivative', 'integral', 'matrix', 'vector', 'polynomial',
'probability', 'statistics', 'correlation', 'regression',
'neural', 'network', 'model', 'training', 'validation', 'test',
'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
'transformer', 'attention', 'embedding', 'tokenization',
'database', 'server', 'client', 'protocol', 'encryption', 'security',
'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
'microservice', 'architecture', 'scalability', 'performance'
}
self._init_vocab()
def _init_vocab(self):
self.vocab = self.special_tokens.copy()
self.id_to_token = {v: k for k, v in self.special_tokens.items()}
    def normalize_text(self, text: str) -> str:
        """Normalize newlines and tabs, apply NFKC, and mask URLs/emails
        while protecting code spans from substitution."""
        text = re.sub(r'\r\n|\r', '\n', text)
        text = text.replace('\t', '<|tab|>')
        text = unicodedata.normalize('NFKC', text)
        # Swap code spans for placeholders so the URL/email masking below
        # cannot alter their contents.
        code_blocks = []
        def replace_code(match):
            code_blocks.append(match.group())
            return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'
        text = self.code_pattern.sub(replace_code, text)
        text = self.url_pattern.sub('<URL>', text)
        text = self.email_pattern.sub('<EMAIL>', text)
        # Restore the original code spans verbatim.
        for i, code_block in enumerate(code_blocks):
            text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)
        return text
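    # Sketch of the normalization above (output approximate): URLs and emails
    # are masked, while inline code is protected and restored intact, e.g.
    #   normalize_text("See https://example.com or `x = 1`")
    #   -> "See <URL> or `x = 1`"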
def pre_tokenize(self, text: str) -> List[str]:
text = self.normalize_text(text)
text = re.sub(r'<\|system\|>', ' <system> ', text)
text = re.sub(r'<\|user\|>', ' <user> ', text)
text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
tokens = re.findall(r'''
<[^>]+>| # Special tokens
\b\w+@\w+\.\w+\b| # Email-like patterns
https?://\S+| # URLs
```[\s\S]*?```| # Code blocks
`[^`]+`| # Inline code
\b\d+\.?\d*\b| # Numbers
\b[a-zA-Z]+(?:'[a-z]*)?| # Words with optional apostrophes
[^\w\s] # Punctuation
''', text, re.VERBOSE)
return [token.strip() for token in tokens if token.strip()]
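    # For instance (roughly, after normalization):
    #   pre_tokenize("The API returned 3.14!")
    #   -> ['The', 'API', 'returned', '3.14', '!']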
def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
pairs = Counter()
for word, freq in word_freqs.items():
if len(word) < 2:
continue
for i in range(len(word) - 1):
pair = (word[i], word[i + 1])
pairs[pair] += freq
return pairs
    def merge_symbols(self, pair: Tuple[str, str], word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
        """Apply one BPE merge, fusing every adjacent occurrence of `pair`."""
        new_word_freqs = {}
        for word, freq in word_freqs.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                    new_word.append(word[i] + word[i + 1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            key = tuple(new_word)
            # Accumulate rather than overwrite: distinct words can collapse
            # to the same symbol sequence after a merge.
            new_word_freqs[key] = new_word_freqs.get(key, 0) + freq
        return new_word_freqs
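    # One merge step on hypothetical frequencies:
    #   word_freqs = {('l','o','w'): 5, ('l','o','g'): 2}
    #   get_pairs(...) -> ('l','o'): 7, ('o','w'): 5, ('o','g'): 2
    #   merge_symbols(('l','o'), ...) -> {('lo','w'): 5, ('lo','g'): 2}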
    def train_bpe(self, texts: List[str]) -> None:
        """Learn BPE merges from raw texts until vocab_size is reached."""
        print("Training BPE tokenizer...")
word_freqs = Counter()
for i, text in enumerate(texts):
if i % 10000 == 0:
print(f"Processing text {i}/{len(texts)}")
tokens = self.pre_tokenize(text)
for token in tokens:
char_seq = tuple(token)
if len(char_seq) > 0:
word_freqs[char_seq] += 1
print(f"Found {len(word_freqs)} unique word patterns")
word_freqs = {word: freq for word, freq in word_freqs.items() if freq >= self.min_freq}
        # Up-weight whole technical terms; word_freqs keys are character
        # tuples, so each term must be converted with tuple(term) to match.
        for term in self.technical_terms:
            key = tuple(term)
            if key in word_freqs:
                word_freqs[key] *= 10
all_chars = set()
for word in word_freqs:
all_chars.update(word)
        for char in sorted(all_chars):
            if char not in self.vocab:
                idx = len(self.vocab)
                self.vocab[char] = idx
                self.id_to_token[idx] = char
        # The vocabulary already includes the special tokens, so the merge
        # budget is simply the room left up to vocab_size.
        num_merges = max(0, self.vocab_size - len(self.vocab))
for i in range(num_merges):
if i % 1000 == 0:
print(f"BPE merge {i}/{num_merges}")
pairs = self.get_pairs(word_freqs)
if not pairs:
break
best_pair = pairs.most_common(1)[0][0]
word_freqs = self.merge_symbols(best_pair, word_freqs)
merged_token = best_pair[0] + best_pair[1]
            if merged_token not in self.vocab:
                idx = len(self.vocab)
                self.vocab[merged_token] = idx
                self.id_to_token[idx] = merged_token
self.bpe_merges.append(best_pair)
print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")
for word, freq in word_freqs.items():
for token in word:
self.token_frequencies[token] += freq
    def apply_bpe(self, word: str) -> List[str]:
        """Split a word into subword tokens by replaying learned merges in order."""
if word in self.bpe_cache:
return self.bpe_cache[word]
tokens = list(word)
for merge in self.bpe_merges:
i = 0
while i < len(tokens) - 1:
if tokens[i] == merge[0] and tokens[i + 1] == merge[1]:
tokens = tokens[:i] + [merge[0] + merge[1]] + tokens[i + 2:]
else:
i += 1
self.bpe_cache[word] = tokens
return tokens
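    # Assuming the merges ('l','o') and ('lo','w') were learned (hypothetical),
    # apply_bpe replays them in order:
    #   'lower' -> ['l','o','w','e','r'] -> ['lo','w','e','r'] -> ['low','e','r']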
def tokenize(self, text: str) -> List[str]:
pre_tokens = self.pre_tokenize(text)
final_tokens = []
for token in pre_tokens:
if token in self.special_tokens or token in self.vocab:
final_tokens.append(token)
else:
bpe_tokens = self.apply_bpe(token)
final_tokens.extend(bpe_tokens)
return final_tokens
def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
tokens = self.tokenize(text)
if add_special_tokens:
tokens = ['<bos>'] + tokens + ['<eos>']
ids = []
for token in tokens:
ids.append(self.vocab.get(token, self.vocab['<unk>']))
return ids
    def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        tokens = []
        for token_id in ids:
            token = self.id_to_token.get(token_id, '<unk>')
            if skip_special_tokens and token in self.special_tokens:
                continue
            tokens.append(token)
        # Pre-tokenization discards whitespace, so decoding joins tokens
        # directly and cannot restore the original spacing.
        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')
        return text
def save(self, save_dir: str):
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
json.dump(self.vocab, f, indent=2, ensure_ascii=False)
with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
for merge in self.bpe_merges:
f.write(f"{merge[0]} {merge[1]}\n")
config = {
'vocab_size': self.vocab_size,
'min_freq': self.min_freq,
'special_tokens': self.special_tokens,
'technical_terms': list(self.technical_terms)
}
with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)
with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
pickle.dump(dict(self.token_frequencies), f)
print(f"Tokenizer saved to {save_dir}")
def load(self, save_dir: str):
with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
self.vocab = json.load(f)
self.id_to_token = {v: k for k, v in self.vocab.items()}
with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
self.bpe_merges = [tuple(line.strip().split()) for line in f if line.strip()]
config_file = os.path.join(save_dir, 'tokenizer_config.json')
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
self.vocab_size = config.get('vocab_size', self.vocab_size)
self.min_freq = config.get('min_freq', self.min_freq)
if 'technical_terms' in config:
self.technical_terms = set(config['technical_terms'])
freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
if os.path.exists(freq_file):
with open(freq_file, 'rb') as f:
self.token_frequencies = Counter(pickle.load(f))
self.bpe_cache = {}
print(f"Tokenizer loaded from {save_dir}")
print(f"Vocabulary size: {len(self.vocab)}")
print(f"Number of BPE merges: {len(self.bpe_merges)}")
def get_vocab_size(self) -> int:
return len(self.vocab)
def get_token_frequency(self, token: str) -> int:
return self.token_frequencies.get(token, 0)
def analyze_tokenization(self, text: str):
tokens = self.tokenize(text)
ids = self.encode_ids(text, add_special_tokens=False)
print(f"Original text: {text}")
print(f"Tokens: {tokens}")
print(f"Token IDs: {ids}")
print(f"Number of tokens: {len(tokens)}")
print(f"Compression ratio: {len(text.split())/len(tokens):.2f}")
return tokens, ids
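# A minimal end-to-end sketch (not invoked by the CLI): trains on a toy
# in-memory corpus rather than the JSONL files the CLI expects.
def _demo_tokenizer() -> None:
    texts = [
        "def add(a, b): return a + b",
        "The neural network model trains quickly.",
    ]
    tok = TechnicalTokenizer(vocab_size=300, min_freq=1)
    tok.train_bpe(texts)
    print(tok.tokenize("return a + b"))
    # Decoding joins tokens without spaces (see decode_ids).
    print(tok.decode_ids(tok.encode_ids("model training")))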
class ConversationDataset:
"""Dataset class for handling conversation data with the custom tokenizer"""
def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
self.data_file = data_file
self.tokenizer = tokenizer
self.max_length = max_length
self.conversations = []
self.load_conversations()
def load_conversations(self):
print(f"Loading conversations from {self.data_file}")
if self.data_file.endswith('.jsonl'):
self.load_jsonl()
else:
self.load_text()
print(f"Loaded {len(self.conversations)} conversations")
def load_jsonl(self):
with open(self.data_file, 'r', encoding='utf-8') as f:
for line in f:
try:
conv = json.loads(line.strip())
messages = conv.get("messages", [])
if not messages:
continue
text_parts = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "").strip()
if not content:
continue
if role == "system":
continue
elif role == "user":
text_parts.append(f"<user> {content}")
elif role == "assistant":
text_parts.append(f"<assistant> {content}")
if len(text_parts) >= 2:
conversation_text = " ".join(text_parts) + " <|endoftext|>"
self.conversations.append(conversation_text)
except json.JSONDecodeError:
continue
def load_text(self):
with open(self.data_file, 'r', encoding='utf-8') as f:
content = f.read()
conversations = content.split('<|endoftext|>\n')
for conv in conversations:
conv = conv.strip()
if conv:
self.conversations.append(conv + " <|endoftext|>")
def get_tokenized_conversations(self, include_stats=False):
tokenized = []
stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}
for conv in self.conversations:
tokens = self.tokenizer.encode_ids(conv)
if len(tokens) > self.max_length:
tokens = tokens[:self.max_length]
stats['truncated'] += 1
tokenized.append(tokens)
stats['total_tokens'] += len(tokens)
if tokenized:
stats['avg_length'] = stats['total_tokens'] / len(tokenized)
if include_stats:
return tokenized, stats
return tokenized
    def create_training_examples(self, stride: Optional[int] = None):
if stride is None:
stride = self.max_length // 2
examples = []
for conv in self.conversations:
tokens = self.tokenizer.encode_ids(conv)
if len(tokens) <= self.max_length:
examples.append(tokens)
else:
for i in range(0, len(tokens), stride):
window = tokens[i:i + self.max_length]
if len(window) >= 32:
examples.append(window)
return examples
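# Example (hypothetical path): turn a JSONL chat file into fixed-length,
# overlapping training windows with an already-trained tokenizer:
#   dataset = ConversationDataset("data/chat.jsonl", tokenizer, max_length=512)
#   examples = dataset.create_training_examples(stride=256)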
def train_tokenizer_from_files(file_paths: List[str],
                               vocab_size: int = 32000,
                               min_freq: int = 2,
                               output_dir: str = "tokenizer",
                               max_texts: Optional[int] = None):
print(f"Training tokenizer with vocab_size={vocab_size}")
print(f"Input files: {file_paths}")
all_texts = []
for file_path in file_paths:
print(f"Loading {file_path}...")
if file_path.endswith('.jsonl'):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
conv = json.loads(line.strip())
messages = conv.get("messages", [])
text_parts = []
for msg in messages:
content = msg.get("content", "").strip()
if content:
text_parts.append(content)
if text_parts:
all_texts.append(" ".join(text_parts))
except json.JSONDecodeError:
continue
else:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
chunks = content.split('\n\n')
for chunk in chunks:
if chunk.strip():
all_texts.append(chunk.strip())
print(f"Loaded {len(all_texts)} texts")
if max_texts and len(all_texts) > max_texts:
random.shuffle(all_texts)
all_texts = all_texts[:max_texts]
print(f"Limited to {len(all_texts)} texts")
tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
tokenizer.train_bpe(all_texts)
tokenizer.save(output_dir)
print("\nTesting tokenization on sample texts:")
test_texts = [
"Hello, how can I help you with your Python programming question?",
"The neural network has 3 hidden layers with ReLU activation functions.",
"```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```",
"The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
]
for text in test_texts:
tokenizer.analyze_tokenization(text)
print()
return tokenizer
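# Programmatic use mirrors the CLI below (path illustrative):
#   tokenizer = train_tokenizer_from_files(["data/train.jsonl"], vocab_size=16000)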
def main():
parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
parser.add_argument("--test_file", help="Test file for analyzing tokenization")
parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")
args = parser.parse_args()
default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
default_text_file = "/kaggle/working/text_data/training_data_chat.txt"
if not args.input_files and not args.load_tokenizer:
if os.path.exists(default_input_file):
args.input_files = [default_input_file]
print(f"No arguments provided, using default input file: {default_input_file}")
elif os.path.exists(default_text_file):
args.input_files = [default_text_file]
print(f"No arguments provided, using default text file: {default_text_file}")
else:
parser.error("No input files or tokenizer directory provided, and default files not found. "
"Please specify --input_files or --load_tokenizer.")
    if args.load_tokenizer:
        tokenizer = TechnicalTokenizer()
        tokenizer.load(args.load_tokenizer)
    else:
        tokenizer = train_tokenizer_from_files(
            file_paths=args.input_files,
            vocab_size=args.vocab_size,
            min_freq=args.min_freq,
            output_dir=args.output_dir,
            max_texts=args.max_texts
        )
    # Report tokenization statistics on an optional held-out file.
    if args.test_file:
        print(f"\nTesting on {args.test_file}")
        dataset = ConversationDataset(args.test_file, tokenizer)
        tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
        print("Dataset statistics:")
        print(f"  Total conversations: {len(tokenized)}")
        print(f"  Total tokens: {stats['total_tokens']:,}")
        print(f"  Average tokens per conversation: {stats['avg_length']:.1f}")
        print(f"  Conversations truncated: {stats['truncated']}")
if __name__ == "__main__":
main()