import os
import json
import pickle
import argparse
from collections import Counter, defaultdict
from typing import List, Dict, Set, Optional, Tuple
import re
import unicodedata


class TechnicalTokenizer:
    """
    Custom tokenizer optimized for technical content and conversations.
    """

    def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        # Reserved IDs: sequence/padding markers, chat roles, and markers for
        # whitespace, code, and math spans.
        self.special_tokens = {
            '<pad>': 0,
            '<unk>': 1,
            '<bos>': 2,
            '<eos>': 3,
            '<system>': 4,
            '<user>': 5,
            '<assistant>': 6,
            '<|endoftext|>': 7,
            '<|newline|>': 8,
            '<|tab|>': 9,
            '<|code|>': 10,
            '<|/code|>': 11,
            '<|math|>': 12,
            '<|/math|>': 13
        }
        self.vocab = {}
        self.id_to_token = {}
        self.token_frequencies = Counter()
        self.bpe_merges = []
        self.bpe_cache = {}
        # Patterns for content that is protected or collapsed before BPE.
        self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
        self.url_pattern = re.compile(r'https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
        self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
        self.number_pattern = re.compile(r'\b\d+\.?\d*\b')
        # Domain terms whose counts are boosted during BPE training so they
        # tend to survive as single tokens.
        self.technical_terms = {
            'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
            'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
            'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
            'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
            'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
            'equation', 'formula', 'theorem', 'proof', 'hypothesis',
            'derivative', 'integral', 'matrix', 'vector', 'polynomial',
            'probability', 'statistics', 'correlation', 'regression',
            'neural', 'network', 'model', 'training', 'validation', 'test',
            'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
            'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
            'transformer', 'attention', 'embedding', 'tokenization',
            'database', 'server', 'client', 'protocol', 'encryption', 'security',
            'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
            'microservice', 'architecture', 'scalability', 'performance'
        }
        self._init_vocab()

    def _init_vocab(self):
        self.vocab = self.special_tokens.copy()
        self.id_to_token = {v: k for k, v in self.special_tokens.items()}

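    # normalize_text temporarily swaps fenced/inline code out for numbered
    # placeholders so the URL/email substitutions cannot touch code content,
    # then restores each code block verbatim.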
    def normalize_text(self, text: str) -> str:
        text = re.sub(r'\r\n|\r', '\n', text)
        text = re.sub(r'\t', '<|tab|>', text)
        text = unicodedata.normalize('NFKC', text)
        code_blocks = []

        def replace_code(match):
            code_blocks.append(match.group())
            return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'

        text = self.code_pattern.sub(replace_code, text)
        text = self.url_pattern.sub('<URL>', text)
        text = self.email_pattern.sub('<EMAIL>', text)
        for i, code_block in enumerate(code_blocks):
            text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)
        return text

    def pre_tokenize(self, text: str) -> List[str]:
        text = self.normalize_text(text)
        text = re.sub(r'<\|system\|>', ' <system> ', text)
        text = re.sub(r'<\|user\|>', ' <user> ', text)
        text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
        text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
        tokens = re.findall(r'''
            <[^>]+>|                    # Special tokens
            \b\w+@\w+\.\w+\b|           # Email-like patterns
            https?://\S+|               # URLs
            ```[\s\S]*?```|             # Code blocks
            `[^`]+`|                    # Inline code
            \b\d+\.?\d*\b|              # Numbers
            \b[a-zA-Z]+(?:'[a-z]*)?|    # Words with optional apostrophes
            [^\w\s]                     # Punctuation
        ''', text, re.VERBOSE)
        return [token.strip() for token in tokens if token.strip()]

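    # BPE training helpers: get_pairs counts adjacent symbol pairs (weighted by
    # word frequency) and merge_symbols rewrites every word with one chosen
    # pair fused into a single symbol.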
    def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
        pairs = Counter()
        for word, freq in word_freqs.items():
            if len(word) < 2:
                continue
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pairs[pair] += freq
        return pairs

    def merge_symbols(self, pair: Tuple[str, str], word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
        new_word_freqs = {}
        bigram = pair
        for word, freq in word_freqs.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
                    new_word.append(word[i] + word[i + 1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word_freqs[tuple(new_word)] = freq
        return new_word_freqs

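    # train_bpe: pre-tokenize the corpus into words, split each word into
    # characters, then repeatedly merge the most frequent adjacent pair until
    # the vocabulary reaches the requested size or no pairs remain. The merge
    # order is recorded so apply_bpe can replay it at encoding time.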
    def train_bpe(self, texts: List[str]) -> None:
        print("Training BPE tokenizer...")
        word_freqs = Counter()
        for i, text in enumerate(texts):
            if i % 10000 == 0:
                print(f"Processing text {i}/{len(texts)}")
            tokens = self.pre_tokenize(text)
            for token in tokens:
                char_seq = tuple(token)
                if len(char_seq) > 0:
                    word_freqs[char_seq] += 1
        print(f"Found {len(word_freqs)} unique word patterns")
        word_freqs = {word: freq for word, freq in word_freqs.items() if freq >= self.min_freq}
        # Boost technical terms. Keys are character tuples, so each term must be
        # split into characters before the lookup.
        for term in self.technical_terms:
            term_seq = tuple(term)
            if term_seq in word_freqs:
                word_freqs[term_seq] *= 10
        all_chars = set()
        for word in word_freqs:
            all_chars.update(word)
        for char in sorted(all_chars):
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
                self.id_to_token[len(self.id_to_token)] = char
        # self.vocab already contains the special tokens, so the number of
        # merges is simply the remaining room up to the requested vocab size.
        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            if i % 1000 == 0:
                print(f"BPE merge {i}/{num_merges}")
            pairs = self.get_pairs(word_freqs)
            if not pairs:
                break
            best_pair = pairs.most_common(1)[0][0]
            word_freqs = self.merge_symbols(best_pair, word_freqs)
            merged_token = best_pair[0] + best_pair[1]
            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)
                self.id_to_token[len(self.id_to_token)] = merged_token
            self.bpe_merges.append(best_pair)
        print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")
        for word, freq in word_freqs.items():
            for token in word:
                self.token_frequencies[token] += freq

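    # apply_bpe replays the learned merges, in training order, over a single
    # pre-token. Results are memoized in bpe_cache since the same words recur
    # constantly.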
    def apply_bpe(self, word: str) -> List[str]:
        if word in self.bpe_cache:
            return self.bpe_cache[word]
        tokens = list(word)
        for merge in self.bpe_merges:
            i = 0
            while i < len(tokens) - 1:
                if tokens[i] == merge[0] and tokens[i + 1] == merge[1]:
                    tokens = tokens[:i] + [merge[0] + merge[1]] + tokens[i + 2:]
                else:
                    i += 1
        self.bpe_cache[word] = tokens
        return tokens

    def tokenize(self, text: str) -> List[str]:
        pre_tokens = self.pre_tokenize(text)
        final_tokens = []
        for token in pre_tokens:
            if token in self.special_tokens or token in self.vocab:
                final_tokens.append(token)
            else:
                bpe_tokens = self.apply_bpe(token)
                final_tokens.extend(bpe_tokens)
        return final_tokens

    def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
        tokens = self.tokenize(text)
        if add_special_tokens:
            tokens = ['<bos>'] + tokens + ['<eos>']
        ids = []
        for token in tokens:
            ids.append(self.vocab.get(token, self.vocab['<unk>']))
        return ids

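    # Note: decoding is lossy. Pre-tokenization drops the whitespace between
    # words and no end-of-word marker is stored, so tokens are joined back
    # without spaces; only <|tab|> and <|newline|> are restored explicitly.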
    def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        tokens = []
        for token_id in ids:
            token = self.id_to_token.get(token_id, '<unk>')
            if skip_special_tokens and token in self.special_tokens:
                continue
            tokens.append(token)
        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')
        return text

    def save(self, save_dir: str):
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)
        with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
            for merge in self.bpe_merges:
                f.write(f"{merge[0]} {merge[1]}\n")
        config = {
            'vocab_size': self.vocab_size,
            'min_freq': self.min_freq,
            'special_tokens': self.special_tokens,
            'technical_terms': list(self.technical_terms)
        }
        with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
        with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
            pickle.dump(dict(self.token_frequencies), f)
        print(f"Tokenizer saved to {save_dir}")

    def load(self, save_dir: str):
        with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
            self.bpe_merges = [tuple(line.strip().split()) for line in f if line.strip()]
        config_file = os.path.join(save_dir, 'tokenizer_config.json')
        if os.path.exists(config_file):
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.vocab_size = config.get('vocab_size', self.vocab_size)
            self.min_freq = config.get('min_freq', self.min_freq)
            if 'technical_terms' in config:
                self.technical_terms = set(config['technical_terms'])
        freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
        if os.path.exists(freq_file):
            with open(freq_file, 'rb') as f:
                self.token_frequencies = Counter(pickle.load(f))
        self.bpe_cache = {}
        print(f"Tokenizer loaded from {save_dir}")
        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Number of BPE merges: {len(self.bpe_merges)}")

    def get_vocab_size(self) -> int:
        return len(self.vocab)

    def get_token_frequency(self, token: str) -> int:
        return self.token_frequencies.get(token, 0)

    def analyze_tokenization(self, text: str):
        tokens = self.tokenize(text)
        ids = self.encode_ids(text, add_special_tokens=False)
        print(f"Original text: {text}")
        print(f"Tokens: {tokens}")
        print(f"Token IDs: {ids}")
        print(f"Number of tokens: {len(tokens)}")
        # Words-per-token ratio; guard against empty input.
        print(f"Compression ratio: {len(text.split()) / max(len(tokens), 1):.2f}")
        return tokens, ids


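# Example TechnicalTokenizer round trip (illustrative; assumes a tokenizer was
# previously saved to "tokenizer/"):
#   tok = TechnicalTokenizer()
#   tok.load("tokenizer")
#   ids = tok.encode_ids("How do I parse JSON in Python?")
#   print(tok.decode_ids(ids))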
class ConversationDataset:
    """Dataset class for handling conversation data with the custom tokenizer."""

    def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
        self.data_file = data_file
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.conversations = []
        self.load_conversations()

    def load_conversations(self):
        print(f"Loading conversations from {self.data_file}")
        if self.data_file.endswith('.jsonl'):
            self.load_jsonl()
        else:
            self.load_text()
        print(f"Loaded {len(self.conversations)} conversations")

    def load_jsonl(self):
        with open(self.data_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    conv = json.loads(line.strip())
                    messages = conv.get("messages", [])
                    if not messages:
                        continue
                    text_parts = []
                    for msg in messages:
                        role = msg.get("role", "")
                        content = msg.get("content", "").strip()
                        if not content:
                            continue
                        if role == "system":
                            continue
                        elif role == "user":
                            text_parts.append(f"<user> {content}")
                        elif role == "assistant":
                            text_parts.append(f"<assistant> {content}")
                    if len(text_parts) >= 2:
                        conversation_text = " ".join(text_parts) + " <|endoftext|>"
                        self.conversations.append(conversation_text)
                except json.JSONDecodeError:
                    continue

    def load_text(self):
        with open(self.data_file, 'r', encoding='utf-8') as f:
            content = f.read()
        conversations = content.split('<|endoftext|>\n')
        for conv in conversations:
            conv = conv.strip()
            if conv:
                self.conversations.append(conv + " <|endoftext|>")

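    # Tokenize every conversation, truncating to max_length, and optionally
    # return simple corpus statistics alongside the token lists.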
    def get_tokenized_conversations(self, include_stats: bool = False):
        tokenized = []
        stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}
        for conv in self.conversations:
            tokens = self.tokenizer.encode_ids(conv)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
                stats['truncated'] += 1
            tokenized.append(tokens)
            stats['total_tokens'] += len(tokens)
        if tokenized:
            stats['avg_length'] = stats['total_tokens'] / len(tokenized)
        if include_stats:
            return tokenized, stats
        return tokenized

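    # Long conversations are split into overlapping windows (default stride of
    # half the maximum length); windows shorter than 32 tokens are dropped.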
    def create_training_examples(self, stride: Optional[int] = None):
        if stride is None:
            stride = self.max_length // 2
        examples = []
        for conv in self.conversations:
            tokens = self.tokenizer.encode_ids(conv)
            if len(tokens) <= self.max_length:
                examples.append(tokens)
            else:
                for i in range(0, len(tokens), stride):
                    window = tokens[i:i + self.max_length]
                    if len(window) >= 32:
                        examples.append(window)
        return examples


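# Convenience entry point: load raw corpora, train a TechnicalTokenizer, save
# it, and sanity-check a few sample texts. Example (illustrative paths):
#   tokenizer = train_tokenizer_from_files(["data/train.jsonl"], vocab_size=32000)
#   tokenizer.analyze_tokenization("Explain gradient descent in two sentences.")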
def train_tokenizer_from_files(file_paths: List[str],
                               vocab_size: int = 32000,
                               min_freq: int = 2,
                               output_dir: str = "tokenizer",
                               max_texts: Optional[int] = None):
    print(f"Training tokenizer with vocab_size={vocab_size}")
    print(f"Input files: {file_paths}")
    all_texts = []
    for file_path in file_paths:
        print(f"Loading {file_path}...")
        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        conv = json.loads(line.strip())
                        messages = conv.get("messages", [])
                        text_parts = []
                        for msg in messages:
                            content = msg.get("content", "").strip()
                            if content:
                                text_parts.append(content)
                        if text_parts:
                            all_texts.append(" ".join(text_parts))
                    except json.JSONDecodeError:
                        continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            chunks = content.split('\n\n')
            for chunk in chunks:
                if chunk.strip():
                    all_texts.append(chunk.strip())
    print(f"Loaded {len(all_texts)} texts")
    if max_texts and len(all_texts) > max_texts:
        import random
        random.shuffle(all_texts)
        all_texts = all_texts[:max_texts]
        print(f"Limited to {len(all_texts)} texts")
    tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
    tokenizer.train_bpe(all_texts)
    tokenizer.save(output_dir)
    print("\nTesting tokenization on sample texts:")
    test_texts = [
        "Hello, how can I help you with your Python programming question?",
        "The neural network has 3 hidden layers with ReLU activation functions.",
        "```python\ndef fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```",
        "The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
    ]
    for text in test_texts:
        tokenizer.analyze_tokenization(text)
        print()
    return tokenizer


def main():
    parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
    parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
    parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
    parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
    parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
    parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
    parser.add_argument("--test_file", help="Test file for analyzing tokenization")
    parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")
    args = parser.parse_args()

    default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
    default_text_file = "/kaggle/working/text_data/training_data_chat.txt"
    if not args.input_files and not args.load_tokenizer:
        if os.path.exists(default_input_file):
            args.input_files = [default_input_file]
            print(f"No arguments provided, using default input file: {default_input_file}")
        elif os.path.exists(default_text_file):
            args.input_files = [default_text_file]
            print(f"No arguments provided, using default text file: {default_text_file}")
        else:
            parser.error("No input files or tokenizer directory provided, and default files not found. "
                         "Please specify --input_files or --load_tokenizer.")

    if args.load_tokenizer:
        tokenizer = TechnicalTokenizer()
        tokenizer.load(args.load_tokenizer)
    else:
        tokenizer = train_tokenizer_from_files(
            file_paths=args.input_files,
            vocab_size=args.vocab_size,
            min_freq=args.min_freq,
            output_dir=args.output_dir,
            max_texts=args.max_texts
        )

    # The same analysis runs whether the tokenizer was loaded or freshly trained.
    if args.test_file:
        print(f"\nTesting on {args.test_file}")
        dataset = ConversationDataset(args.test_file, tokenizer)
        tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
        print("Dataset statistics:")
        print(f"  Total conversations: {len(tokenized)}")
        print(f"  Total tokens: {stats['total_tokens']:,}")
        print(f"  Average tokens per conversation: {stats['avg_length']:.1f}")
        print(f"  Conversations truncated: {stats['truncated']}")


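# Example invocation (script and file names are illustrative):
#   python train_tokenizer.py --input_files data/train.jsonl --vocab_size 32000 --output_dir tokenizer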
if __name__ == "__main__":
    main()