import os
import json
import pickle
import random
import argparse
from collections import Counter, defaultdict
from typing import List, Dict, Set, Optional, Tuple
import re
import unicodedata


class TechnicalTokenizer:
    """
    Custom tokenizer optimized for technical content and conversations.

    The tokenizer pre-splits text with a regular expression (special
    markers, e-mail-like strings, URLs, code spans, numbers, words,
    punctuation) and then applies character-level BPE merges learned by
    ``train_bpe``. Whitespace between pre-tokens is discarded, so
    ``decode_ids`` does not reproduce the original spacing.
    """

    def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
        """
        Args:
            vocab_size: Upper bound on the number of vocabulary entries
                (special tokens + characters + merged symbols).
            min_freq: Pre-tokens seen fewer than this many times are
                dropped before BPE training.
        """
        self.vocab_size = vocab_size
        self.min_freq = min_freq

        # NOTE(review): in the reviewed source the first seven keys were all
        # the empty string, which collapses them into a single entry and
        # makes learned ids collide with the special ids 8..13. The names
        # below are reconstructed from how the ids are used elsewhere in
        # this file (unknown fallback, sequence wrapping, role markers);
        # confirm against the original tokenizer artifacts.
        self.special_tokens = {
            '<pad>': 0,
            '<unk>': 1,
            '<bos>': 2,
            '<eos>': 3,
            '<system>': 4,
            '<user>': 5,
            '<assistant>': 6,
            '<|endoftext|>': 7,
            '<|newline|>': 8,
            '<|tab|>': 9,
            '<|code|>': 10,
            '<|/code|>': 11,
            '<|math|>': 12,
            '<|/math|>': 13
        }

        self.vocab = {}
        self.id_to_token = {}
        self.token_frequencies = Counter()
        self.bpe_merges = []
        self.bpe_cache = {}

        self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
        self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
        # The TLD class was `[A-Z|a-z]`, which also accepted a literal '|'.
        self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
        self.number_pattern = re.compile(r'\b\d+\.?\d*\b')

        # Words whose frequency is multiplied during training so that they
        # are merged into single symbols early.
        self.technical_terms = {
            'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
            'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
            'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
            'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
            'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
            'equation', 'formula', 'theorem', 'proof', 'hypothesis',
            'derivative', 'integral', 'matrix', 'vector', 'polynomial',
            'probability', 'statistics', 'correlation', 'regression',
            'neural', 'network', 'model', 'training', 'validation', 'test',
            'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
            'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
            'transformer', 'attention', 'embedding', 'tokenization',
            'database', 'server', 'client', 'protocol', 'encryption', 'security',
            'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
            'microservice', 'architecture', 'scalability', 'performance'
        }

        self._init_vocab()

    def _init_vocab(self):
        """Reset the vocabulary so that it contains only the special tokens."""
        self.vocab = self.special_tokens.copy()
        self.id_to_token = {v: k for k, v in self.special_tokens.items()}

    def _add_token(self, token: str) -> None:
        """Append ``token`` to the vocabulary unless it is already present."""
        if token not in self.vocab:
            # Use one id for both maps so they cannot drift apart.
            new_id = len(self.vocab)
            self.vocab[token] = new_id
            self.id_to_token[new_id] = token

    def normalize_text(self, text: str) -> str:
        """
        Normalize line endings, tabs and Unicode form, and remove URLs and
        e-mail addresses that appear outside of code spans.

        Code spans (fenced or inline back-tick) are swapped for
        placeholders before the URL/e-mail substitution and restored
        afterwards, so their contents are left untouched.
        """
        text = re.sub(r'\r\n|\r', '\n', text)
        text = re.sub(r'\t', '<|tab|>', text)
        text = unicodedata.normalize('NFKC', text)

        code_blocks = []

        def replace_code(match):
            code_blocks.append(match.group())
            return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'

        text = self.code_pattern.sub(replace_code, text)
        # NOTE(review): URLs and e-mails are replaced by the empty string,
        # i.e. deleted. A placeholder token may have been intended here;
        # confirm before changing, the behavior is kept as found.
        text = self.url_pattern.sub('', text)
        text = self.email_pattern.sub('', text)

        for i, code_block in enumerate(code_blocks):
            text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)

        return text

    def pre_tokenize(self, text: str) -> List[str]:
        """
        Split ``text`` into pre-tokens.

        ``<|system|>``, ``<|user|>`` and ``<|assistant|>`` markers are
        rewritten to the corresponding special tokens, ``<|endoftext|>`` is
        padded with spaces, and the result is scanned with a regular
        expression. Whitespace between matches is dropped.
        """
        text = self.normalize_text(text)

        # Map the chat-style role markers onto the special-token spellings.
        text = re.sub(r'<\|system\|>', ' <system> ', text)
        text = re.sub(r'<\|user\|>', ' <user> ', text)
        text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
        text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)

        tokens = re.findall(r'''
            <[^>]+>|                      # Special tokens
            \b\w+@\w+\.\w+\b|             # Email-like patterns
            https?://\S+|                 # URLs
            ```[\s\S]*?```|               # Code blocks
            `[^`]+`|                      # Inline code
            \b\d+\.?\d*\b|                # Numbers
            \b[a-zA-Z]+(?:'[a-z]*)?|      # Words with optional apostrophes
            [^\w\s]                       # Punctuation
        ''', text, re.VERBOSE)

        return [token.strip() for token in tokens if token.strip()]

    def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
        """Count adjacent symbol pairs, weighted by the frequency of each word."""
        pairs = Counter()
        for word, freq in word_freqs.items():
            for pair in zip(word, word[1:]):
                pairs[pair] += freq
        return pairs

    def merge_symbols(self, pair: Tuple[str, str],
                      word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
        """
        Return a copy of ``word_freqs`` in which every left-to-right,
        non-overlapping occurrence of ``pair`` is fused into one symbol.
        """
        new_word_freqs = {}
        for word, freq in word_freqs.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                    new_word.append(word[i] + word[i + 1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            key = tuple(new_word)
            # Accumulate rather than overwrite in case two words ever map
            # to the same symbol sequence.
            new_word_freqs[key] = new_word_freqs.get(key, 0) + freq
        return new_word_freqs

    def train_bpe(self, texts: List[str]) -> None:
        """
        Learn BPE merges from ``texts`` and extend the vocabulary.

        Pre-tokens below ``min_freq`` are discarded, the frequency of words
        listed in ``technical_terms`` is multiplied by 10, every remaining
        character is added to the vocabulary, and the most frequent
        adjacent pair is merged repeatedly until ``vocab_size`` entries
        exist or no pairs remain.
        """
        print("Training BPE tokenizer...")

        word_freqs = Counter()
        for i, text in enumerate(texts):
            if i % 10000 == 0:
                print(f"Processing text {i}/{len(texts)}")
            for token in self.pre_tokenize(text):
                char_seq = tuple(token)
                if len(char_seq) > 0:
                    word_freqs[char_seq] += 1

        print(f"Found {len(word_freqs)} unique word patterns")

        word_freqs = {word: freq for word, freq in word_freqs.items()
                      if freq >= self.min_freq}

        # Keys are tuples of characters, so the term must be exploded the
        # same way; looking up `(term,)` never matched a multi-letter term.
        for term in self.technical_terms:
            key = tuple(term)
            if key in word_freqs:
                word_freqs[key] *= 10

        all_chars = set()
        for word in word_freqs:
            all_chars.update(word)

        for char in sorted(all_chars):
            self._add_token(char)

        # self.vocab already holds the special tokens, so the remaining
        # budget is simply the distance to vocab_size.
        num_merges = max(0, self.vocab_size - len(self.vocab))

        for i in range(num_merges):
            if i % 1000 == 0:
                print(f"BPE merge {i}/{num_merges}")

            pairs = self.get_pairs(word_freqs)
            if not pairs:
                break

            best_pair = pairs.most_common(1)[0][0]
            word_freqs = self.merge_symbols(best_pair, word_freqs)

            self._add_token(best_pair[0] + best_pair[1])
            self.bpe_merges.append(best_pair)

        print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")

        for word, freq in word_freqs.items():
            for token in word:
                self.token_frequencies[token] += freq

        # Segmentations cached before training are stale now.
        self.bpe_cache = {}

    def apply_bpe(self, word: str) -> List[str]:
        """
        Segment ``word`` by replaying the learned merges in training order.

        Results are memoized in ``bpe_cache``.
        """
        if word in self.bpe_cache:
            return self.bpe_cache[word]

        tokens = list(word)
        for left, right in self.bpe_merges:
            i = 0
            while i < len(tokens) - 1:
                if tokens[i] == left and tokens[i + 1] == right:
                    tokens[i:i + 2] = [left + right]
                else:
                    i += 1

        self.bpe_cache[word] = tokens
        return tokens

    def tokenize(self, text: str) -> List[str]:
        """Return the token strings for ``text`` (pre-tokenize, then BPE)."""
        final_tokens = []
        for token in self.pre_tokenize(text):
            if token in self.special_tokens or token in self.vocab:
                final_tokens.append(token)
            else:
                final_tokens.extend(self.apply_bpe(token))
        return final_tokens

    def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        Convert ``text`` to vocabulary ids.

        Args:
            text: Input text.
            add_special_tokens: When true, wrap the sequence in the
                begin/end-of-sequence tokens.

        Returns:
            List of ids; tokens missing from the vocabulary map to the
            unknown-token id.
        """
        tokens = self.tokenize(text)

        if add_special_tokens:
            tokens = ['<bos>'] + tokens + ['<eos>']

        # Fall back to the built-in id when a loaded vocabulary lacks the
        # unknown token.
        unk_id = self.vocab.get('<unk>', self.special_tokens['<unk>'])
        return [self.vocab.get(token, unk_id) for token in tokens]

    def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Convert ids back to text by concatenating their token strings.

        Unknown ids decode to the unknown token. Because pre-tokenization
        drops whitespace, the result has no spaces between words.
        """
        tokens = []
        for token_id in ids:
            token = self.id_to_token.get(token_id, '<unk>')
            if skip_special_tokens and token in self.special_tokens:
                continue
            tokens.append(token)

        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')
        return text

    def save(self, save_dir: str):
        """
        Write vocabulary, merges, configuration and token frequencies to
        ``save_dir`` (created if missing).
        """
        os.makedirs(save_dir, exist_ok=True)

        with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
            for merge in self.bpe_merges:
                f.write(f"{merge[0]} {merge[1]}\n")

        # merges.txt is space-delimited and therefore cannot represent
        # symbols that contain whitespace (code spans are pre-tokens and
        # may contain spaces/newlines). Keep it for compatibility and also
        # store a lossless JSON copy that load() prefers.
        with open(os.path.join(save_dir, 'merges.json'), 'w', encoding='utf-8') as f:
            json.dump([list(merge) for merge in self.bpe_merges], f, ensure_ascii=False)

        config = {
            'vocab_size': self.vocab_size,
            'min_freq': self.min_freq,
            'special_tokens': self.special_tokens,
            'technical_terms': list(self.technical_terms)
        }
        with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
            pickle.dump(dict(self.token_frequencies), f)

        print(f"Tokenizer saved to {save_dir}")

    def load(self, save_dir: str):
        """
        Restore a tokenizer previously written by ``save``.

        ``vocab.json`` and a merges file are required; the configuration
        and frequency files are optional.
        """
        with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

        self.id_to_token = {v: k for k, v in self.vocab.items()}

        merges_json = os.path.join(save_dir, 'merges.json')
        if os.path.exists(merges_json):
            with open(merges_json, 'r', encoding='utf-8') as f:
                self.bpe_merges = [tuple(pair) for pair in json.load(f)]
        else:
            # Legacy format: skip lines that do not split into exactly two
            # symbols, they cannot be applied as a merge.
            with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
                candidates = (tuple(line.split()) for line in f)
                self.bpe_merges = [pair for pair in candidates if len(pair) == 2]

        config_file = os.path.join(save_dir, 'tokenizer_config.json')
        if os.path.exists(config_file):
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.vocab_size = config.get('vocab_size', self.vocab_size)
            self.min_freq = config.get('min_freq', self.min_freq)
            if 'technical_terms' in config:
                self.technical_terms = set(config['technical_terms'])

        freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
        if os.path.exists(freq_file):
            # SECURITY: pickle can execute arbitrary code while loading.
            # Only load tokenizer directories from trusted sources.
            with open(freq_file, 'rb') as f:
                self.token_frequencies = Counter(pickle.load(f))

        self.bpe_cache = {}

        print(f"Tokenizer loaded from {save_dir}")
        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Number of BPE merges: {len(self.bpe_merges)}")

    def get_vocab_size(self) -> int:
        """Return the current number of vocabulary entries."""
        return len(self.vocab)

    def get_token_frequency(self, token: str) -> int:
        """Return the training frequency recorded for ``token`` (0 if unseen)."""
        return self.token_frequencies.get(token, 0)

    def analyze_tokenization(self, text: str):
        """
        Print the tokens, ids and a words-per-token ratio for ``text``.

        Returns:
            Tuple ``(tokens, ids)``; ids are produced without the
            begin/end-of-sequence wrappers.
        """
        tokens = self.tokenize(text)
        ids = self.encode_ids(text, add_special_tokens=False)

        # Empty input yields no tokens; avoid dividing by zero.
        ratio = len(text.split()) / len(tokens) if tokens else 0.0

        print(f"Original text: {text}")
        print(f"Tokens: {tokens}")
        print(f"Token IDs: {ids}")
        print(f"Number of tokens: {len(tokens)}")
        print(f"Compression ratio: {ratio:.2f}")

        return tokens, ids


class ConversationDataset:
    """Dataset class for handling conversation data with the custom tokenizer"""

    def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
        """
        Args:
            data_file: Path to a ``.jsonl`` file or a plain-text file.
            tokenizer: Tokenizer used to encode conversations.
            max_length: Maximum number of ids per sequence/window.
        """
        self.data_file = data_file
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.conversations = []

        self.load_conversations()

    def load_conversations(self):
        """Load ``data_file``, choosing the parser from the file extension."""
        print(f"Loading conversations from {self.data_file}")

        if self.data_file.endswith('.jsonl'):
            self.load_jsonl()
        else:
            self.load_text()

        print(f"Loaded {len(self.conversations)} conversations")

    def load_jsonl(self):
        """
        Read one JSON object per line with a ``messages`` list.

        System messages and empty contents are skipped; user/assistant
        contents are prefixed with their role token. Conversations with
        fewer than two kept messages, and lines that are not valid JSON
        objects, are ignored.
        """
        with open(self.data_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    conv = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                if not isinstance(conv, dict):
                    continue

                messages = conv.get("messages", [])
                if not messages:
                    continue

                text_parts = []
                for msg in messages:
                    role = msg.get("role", "")
                    content = msg.get("content") or ""
                    if not isinstance(content, str):
                        continue
                    content = content.strip()
                    if not content:
                        continue

                    # Role prefixes use the special-token spellings; they
                    # were empty strings in the reviewed source.
                    if role == "user":
                        text_parts.append(f"<user> {content}")
                    elif role == "assistant":
                        text_parts.append(f"<assistant> {content}")

                if len(text_parts) >= 2:
                    conversation_text = " ".join(text_parts) + " <|endoftext|>"
                    self.conversations.append(conversation_text)

    def load_text(self):
        """
        Read a text file whose conversations are separated by
        ``<|endoftext|>`` followed by a newline.
        """
        with open(self.data_file, 'r', encoding='utf-8') as f:
            content = f.read()

        marker = '<|endoftext|>'
        for conv in content.split('<|endoftext|>\n'):
            conv = conv.strip()
            # The final record keeps its marker when the file does not end
            # with a newline; drop it so it is not emitted twice.
            if conv.endswith(marker):
                conv = conv[:-len(marker)].rstrip()
            if conv:
                self.conversations.append(conv + " <|endoftext|>")

    def get_tokenized_conversations(self, include_stats=False):
        """
        Encode every conversation, truncating to ``max_length`` ids.

        Returns:
            The list of id lists, or ``(list, stats)`` when
            ``include_stats`` is true. ``stats`` has the keys
            ``total_tokens``, ``truncated`` and ``avg_length``.
        """
        tokenized = []
        stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}

        for conv in self.conversations:
            tokens = self.tokenizer.encode_ids(conv)

            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
                stats['truncated'] += 1

            tokenized.append(tokens)
            stats['total_tokens'] += len(tokens)

        if tokenized:
            stats['avg_length'] = stats['total_tokens'] / len(tokenized)

        if include_stats:
            return tokenized, stats
        return tokenized

    def create_training_examples(self, stride: Optional[int] = None, min_length: int = 32):
        """
        Build training sequences of at most ``max_length`` ids.

        Conversations that fit are emitted whole. Longer ones are cut into
        overlapping windows.

        Args:
            stride: Step between window starts; defaults to half of
                ``max_length``. Values below 1 are raised to 1.
            min_length: Windows shorter than this are discarded.
        """
        if stride is None:
            stride = self.max_length // 2
        # range() rejects a zero step (e.g. max_length == 1).
        stride = max(1, stride)

        examples = []
        for conv in self.conversations:
            tokens = self.tokenizer.encode_ids(conv)

            if len(tokens) <= self.max_length:
                examples.append(tokens)
                continue

            for start in range(0, len(tokens), stride):
                window = tokens[start:start + self.max_length]
                if len(window) >= min_length:
                    examples.append(window)
                # Once a window reaches the end, later windows would only
                # repeat a suffix of it.
                if start + self.max_length >= len(tokens):
                    break

        return examples


def train_tokenizer_from_files(file_paths: List[str],
                               vocab_size: int = 32000,
                               min_freq: int = 2,
                               output_dir: str = "tokenizer",
                               max_texts: Optional[int] = None):
    """
    Train a ``TechnicalTokenizer`` on the given files and save it.

    ``.jsonl`` files contribute one text per line (all message contents
    joined by spaces); other files are split on blank lines.

    Args:
        file_paths: Input files.
        vocab_size: Target vocabulary size.
        min_freq: Minimum pre-token frequency.
        output_dir: Directory the tokenizer is saved to.
        max_texts: If set, train on a random sample of at most this many
            texts.

    Returns:
        The trained tokenizer.
    """
    print(f"Training tokenizer with vocab_size={vocab_size}")
    print(f"Input files: {file_paths}")

    all_texts = []

    for file_path in file_paths:
        print(f"Loading {file_path}...")

        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        conv = json.loads(line.strip())
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(conv, dict):
                        continue

                    text_parts = []
                    for msg in conv.get("messages", []):
                        content = msg.get("content") or ""
                        if isinstance(content, str) and content.strip():
                            text_parts.append(content.strip())

                    if text_parts:
                        all_texts.append(" ".join(text_parts))
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            for chunk in content.split('\n\n'):
                if chunk.strip():
                    all_texts.append(chunk.strip())

    print(f"Loaded {len(all_texts)} texts")

    if max_texts and len(all_texts) > max_texts:
        random.shuffle(all_texts)
        all_texts = all_texts[:max_texts]
        print(f"Limited to {len(all_texts)} texts")

    tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
    tokenizer.train_bpe(all_texts)

    tokenizer.save(output_dir)

    print("\nTesting tokenization on sample texts:")
    test_texts = [
        "Hello, how can I help you with your Python programming question?",
        "The neural network has 3 hidden layers with ReLU activation functions.",
        "```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```",
        "The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
    ]

    for text in test_texts:
        tokenizer.analyze_tokenization(text)
        print()

    return tokenizer


def _report_dataset_stats(test_file: str, tokenizer: TechnicalTokenizer) -> None:
    """Tokenize ``test_file`` with ``tokenizer`` and print summary statistics."""
    print(f"\nTesting on {test_file}")
    dataset = ConversationDataset(test_file, tokenizer)
    tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)

    print("Dataset statistics:")
    print(f" Total conversations: {len(tokenized)}")
    print(f" Total tokens: {stats['total_tokens']:,}")
    print(f" Average tokens per conversation: {stats['avg_length']:.1f}")
    print(f" Conversations truncated: {stats['truncated']}")


def main():
    """
    Command-line entry point: train a tokenizer from input files, or load
    an existing one, and optionally report statistics on a test file.
    """
    parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
    parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
    parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
    parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
    parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
    parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
    parser.add_argument("--test_file", help="Test file for analyzing tokenization")
    parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")

    args = parser.parse_args()

    default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
    default_text_file = "/kaggle/working/text_data/training_data_chat.txt"

    # With no explicit source, fall back to the first default file that exists.
    if not args.input_files and not args.load_tokenizer:
        if os.path.exists(default_input_file):
            args.input_files = [default_input_file]
            print(f"No arguments provided, using default input file: {default_input_file}")
        elif os.path.exists(default_text_file):
            args.input_files = [default_text_file]
            print(f"No arguments provided, using default text file: {default_text_file}")
        else:
            parser.error("No input files or tokenizer directory provided, and default files not found. "
                         "Please specify --input_files or --load_tokenizer.")

    if args.load_tokenizer:
        tokenizer = TechnicalTokenizer()
        tokenizer.load(args.load_tokenizer)
    else:
        tokenizer = train_tokenizer_from_files(
            file_paths=args.input_files,
            vocab_size=args.vocab_size,
            min_freq=args.min_freq,
            output_dir=args.output_dir,
            max_texts=args.max_texts
        )

    if args.test_file:
        _report_dataset_stats(args.test_file, tokenizer)


if __name__ == "__main__":
    main()