"""Test harness for an EZ-Tokenizer ``tokenizer.json`` file.

Encodes and decodes sample files, measures throughput and roundtrip
fidelity, and writes a summary report to a results file.
"""

import argparse
import logging
import os
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional


def get_project_root() -> Path:
    """Get the project root directory."""
    return Path.cwd()


def ensure_directory(path: Path) -> None:
    """Ensure directory exists, create if it doesn't."""
    path.mkdir(parents=True, exist_ok=True)


log_dir = Path('test_result')
ensure_directory(log_dir)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_dir / 'tokenizer_test.log')
    ]
)
logger = logging.getLogger(__name__)


class Tokenizer:
    def __init__(self, tokenizer_path: str):
        """Initialize the EZ-Tokenizer with enhanced error handling and validation."""
        try:
            from tokenizers import Tokenizer as HFTokenizer

            logger.info(f"Loading EZ-Tokenizer from {tokenizer_path}")
            if not os.path.exists(tokenizer_path):
                raise FileNotFoundError(f"EZ-Tokenizer file not found: {tokenizer_path}")

            start_time = time.time()
            self.tokenizer = HFTokenizer.from_file(tokenizer_path)
            load_time = time.time() - start_time

            self.vocab_size = self.tokenizer.get_vocab_size()
            logger.info(f"EZ-Tokenizer loaded in {load_time:.2f} seconds. Vocabulary size: {self.vocab_size:,}")

            self._run_smoke_tests()

        except Exception as e:
            logger.error(f"Failed to initialize EZ-Tokenizer: {e}", exc_info=True)
            raise

    def _run_smoke_tests(self):
        """Run basic smoke tests to verify tokenizer functionality."""
        test_cases = [
            "Hello, world!",
            "こんにちは世界",
            "안녕하세요",
            "Привет, мир!",
            "12345 !@#$%^&*()_+{}|:<>?",
            ""
        ]

        logger.info("Running smoke tests...")
        for text in test_cases:
            try:
                tokens = self.encode(text)
                decoded = self.decode(tokens)
                if text != decoded:
                    logger.warning(f"Roundtrip mismatch for {text!r} -> {decoded!r}")
            except Exception as e:
                logger.error(f"Smoke test failed for {text!r}: {e}")
                raise
        logger.info("Smoke tests completed successfully")

    def encode(self, text: str, chunk_size: int = 10000) -> List[int]:
        """Encode text to token IDs with chunking for large inputs."""
        try:
            if not isinstance(text, str):
                raise ValueError(f"Expected string, got {type(text).__name__}")

            if len(text) <= chunk_size:
                return self.tokenizer.encode(text).ids

            tokens = []
            for i in range(0, len(text), chunk_size):
                chunk = text[i:i + chunk_size]
                tokens.extend(self.tokenizer.encode(chunk).ids)
            return tokens

        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            # len(text) is only meaningful when the input really is a string.
            text_len = len(text) if isinstance(text, str) else 'n/a'
            raise RuntimeError(f"Failed to encode text (length: {text_len}): {e}") from e

    def decode(self, token_ids: List[int], chunk_size: int = 10000) -> str:
        """Decode token IDs back to text with memory-efficient chunking."""
        try:
            if not token_ids:
                return ""

            if not all(isinstance(t, int) for t in token_ids):
                raise ValueError("All token IDs must be integers")

            if len(token_ids) <= chunk_size:
                return self.tokenizer.decode(token_ids)

            chunks = []
            for i in range(0, len(token_ids), chunk_size):
                chunk = token_ids[i:i + chunk_size]
                chunks.append(self.tokenizer.decode(chunk))

                if (i // chunk_size) % 10 == 0:
                    logger.info(f"Decoded {min(i + chunk_size, len(token_ids)):,}/{len(token_ids):,} tokens")

            return "".join(chunks)

        except Exception as e:
            logger.error(f"Decoding failed: {e}")
            raise RuntimeError(f"Failed to decode {len(token_ids)} tokens: {e}")

    def get_vocab_size(self) -> int:
        """Return the size of the tokenizer's vocabulary."""
        return self.vocab_size


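# Illustrative usage of the Tokenizer wrapper above (not executed by this
# script; the path is a placeholder for a trained tokenizer.json):
#
#   tok = Tokenizer("output/tokenizer.json")
#   ids = tok.encode("def add(a, b):\n    return a + b")
#   text = tok.decode(ids)
#   # A lossless tokenizer should reproduce the input exactly:
#   assert text == "def add(a, b):\n    return a + b"
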
def process_file_in_chunks(file_path: str, chunk_size: int = 1024 * 1024) -> str:
    """Read a file in chunks to avoid memory issues."""
    chunks = []
    try:
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                chunks.append(chunk)
        return "".join(chunks)
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        raise


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in code for more meaningful comparison."""
    import re

    text = re.sub(r'\s+', ' ', text)
    return text.strip()


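# Example for normalize_whitespace (illustrative): runs of spaces, tabs and
# newlines collapse to single spaces before comparison, e.g.
#   normalize_whitespace("def f(x):\n    return  x")  ->  "def f(x): return x"
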
def calculate_token_metrics(original_tokens, decoded_tokens):
    """Calculate token-level accuracy metrics."""
    min_len = min(len(original_tokens), len(decoded_tokens))
    exact_matches = sum(1 for a, b in zip(original_tokens, decoded_tokens) if a == b)

    return {
        'token_accuracy': exact_matches / max(len(original_tokens), 1),
        'token_precision': exact_matches / max(len(decoded_tokens), 1),
        'token_recall': exact_matches / max(len(original_tokens), 1),
        'token_f1': 2 * exact_matches / (len(original_tokens) + len(decoded_tokens))
        if (len(original_tokens) + len(decoded_tokens)) > 0 else 0
    }


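# Worked example for calculate_token_metrics (illustrative): with
# original_tokens = ["a", "b", "c", "d"] and decoded_tokens = ["a", "b", "x"],
# positional matches = 2, so token_accuracy = 2/4 = 0.50,
# token_precision = 2/3 ~= 0.67, token_recall = 2/4 = 0.50 and
# token_f1 = 2*2/(4+3) ~= 0.57.
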
def enhanced_char_metrics(original: str, decoded: str) -> dict:
    """Calculate enhanced character-level metrics."""
    norm_original = normalize_whitespace(original)
    norm_decoded = normalize_whitespace(decoded)

    min_len = min(len(norm_original), len(norm_decoded))
    max_len = max(len(norm_original), len(norm_decoded))

    if max_len == 0:
        return {
            'char_accuracy': 1.0,
            'char_similarity': 1.0,
            'length_diff_ratio': 0.0
        }

    matches = sum(1 for a, b in zip(norm_original, norm_decoded) if a == b)

    try:
        from Levenshtein import ratio
        similarity = ratio(norm_original, norm_decoded)
    except ImportError:
        similarity = matches / max_len if max_len > 0 else 1.0

    return {
        'char_accuracy': matches / max_len if max_len > 0 else 1.0,
        'char_similarity': similarity,
        'length_diff_ratio': abs(len(norm_original) - len(norm_decoded)) / max_len if max_len > 0 else 0.0
    }


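# Example for enhanced_char_metrics (illustrative): both inputs of
# enhanced_char_metrics("a  b\nc", "a b c") normalize to "a b c", so
# char_accuracy and char_similarity are 1.0 and length_diff_ratio is 0.0;
# whitespace-only differences are not penalized.
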
def validate_code_integrity(original: str, decoded: str) -> dict:
    """Validate code-specific integrity metrics."""
    import ast

    def can_parse(code: str) -> bool:
        try:
            ast.parse(code)
            return True
        except Exception:
            return False

    original_parses = can_parse(original)
    decoded_parses = can_parse(decoded)

    return {
        'original_parses': original_parses,
        'decoded_parses': decoded_parses,
        'both_parse': original_parses and decoded_parses
    }


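# Example for validate_code_integrity (illustrative): a roundtrip that drops
# indentation breaks parseability, e.g.
#   validate_code_integrity("def add(a, b):\n    return a + b",
#                           "def add(a, b):\nreturn a + b")
# returns {'original_parses': True, 'decoded_parses': False, 'both_parse': False}.
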
def calculate_metrics(original_text: str, decoded_text: str, tokens,
                      start_time: float, end_time: float,
                      tokenizer: Optional[Tokenizer] = None) -> Dict[str, Any]:
    """Enhanced metrics calculation for tokenizer evaluation."""
    token_count = len(tokens) if tokens else 0
    char_count = len(original_text) if original_text else 0
    process_time = max(end_time - start_time, 0.001)

    metrics = {
        'tokens': token_count,
        'chars': char_count,
        'processing_time': process_time,
        'tokens_per_second': token_count / process_time,
        'chars_per_token': char_count / (token_count or 1)
    }

    metrics.update({
        'tokens_per_sec': len(tokens) / metrics['processing_time'],
        'chars_per_sec': len(original_text) / metrics['processing_time']
    })

    metrics.update(enhanced_char_metrics(original_text, decoded_text))

    # Token-level metrics need the original Encoding (for its .tokens) and a
    # tokenizer to re-encode the decoded text; callers that pass plain ID lists
    # and no tokenizer skip this block.
    if tokenizer is not None and hasattr(tokens, 'tokens'):
        original_tokens = tokens.tokens
        decoded_tokens = tokenizer.tokenizer.encode(decoded_text).tokens
        metrics.update(calculate_token_metrics(original_tokens, decoded_tokens))

    # Crude heuristic: treat the text as Python source if it looks like code.
    if original_text.strip().endswith('.py') or 'def ' in original_text or 'import ' in original_text:
        metrics.update(validate_code_integrity(original_text, decoded_text))

    return metrics


def print_metrics_summary(metrics: Dict[str, Any]):
    """Print a clean summary of the metrics."""
    print("\n=== Tokenizer Test Results ===")
    print(f"Processing Speed: {metrics.get('tokens_per_second', metrics.get('tokens_per_sec', 0)):,.0f} tokens/sec")
    print(f"Characters per Token: {metrics.get('chars_per_token', 0):.2f}")
    print("\nCharacter-Level Metrics:")
    print(f"  • Accuracy: {metrics.get('char_accuracy', 0)*100:.2f}%")
    print(f"  • Similarity: {metrics.get('char_similarity', 0)*100:.2f}%")

    print("\nCode Integrity:")
    print(f"  • Original parses: {'✓' if metrics.get('original_parses', False) else '✗'}")
    print(f"  • Decoded parses: {'✓' if metrics.get('decoded_parses', False) else '✗'}")
    print(f"  • Both parse: {'✓' if metrics.get('both_parse', False) else '✗'}")


def process_file(file_path: Path, tokenizer: Tokenizer, max_chunk_size: int = 100_000, sample_size: int = 100_000) -> Optional[Dict[str, Any]]:
    """Process a single file in chunks and return metrics."""
    try:
        logger.info(f"\nProcessing file: {file_path}")
        file_size = file_path.stat().st_size
        logger.info(f"File size: {file_size / (1024*1024):.2f} MB")

        total_tokens = 0
        total_chars = 0
        total_time = 0
        chunk_metrics = []

        total_read = 0
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            max_to_read = sample_size if sample_size > 0 else float('inf')
            logger.info(f"Processing up to {max_to_read if max_to_read != float('inf') else 'all'} characters")

            def read_next_chunk() -> str:
                """Read the next chunk while respecting the sampling budget."""
                nonlocal total_read
                to_read = min(max_chunk_size, max_to_read - total_read)
                if to_read <= 0:
                    return ""
                next_chunk = f.read(to_read)
                total_read += len(next_chunk)
                return next_chunk

            chunk = read_next_chunk()

            while chunk:
                if not chunk.strip():
                    chunk = read_next_chunk()
                    continue

                start_time = time.time()
                try:
                    tokens = tokenizer.encode(chunk)
                    token_ids = tokens.ids if hasattr(tokens, 'ids') else tokens
                    decoded_text = tokenizer.decode(token_ids)
                except Exception as e:
                    logger.error(f"Error in tokenization: {e}")
                    chunk = read_next_chunk()
                    continue

                end_time = time.time()

                if not token_ids:
                    chunk = read_next_chunk()
                    continue

                metrics = calculate_metrics(chunk, decoded_text, token_ids, start_time, end_time)
                chunk_metrics.append(metrics)

                total_tokens += len(token_ids)
                total_chars += len(chunk)
                total_time += (end_time - start_time)

                if total_tokens % 1_000_000 == 0:
                    logger.info(f" Processed {total_tokens:,} tokens ({total_chars/1024/1024:.2f} MB)")

                chunk = read_next_chunk()

        if not chunk_metrics:
            logger.warning(f"No valid content found in file: {file_path}")
            return None

        total_weight = sum(m.get('tokens', 0) for m in chunk_metrics) or 1

        avg_metrics = {
            'chars_per_token': sum(m.get('chars_per_token', 0) * m.get('tokens', 0) for m in chunk_metrics) / total_weight,
            'tokens_per_second': sum(m.get('tokens', 0) for m in chunk_metrics) / (total_time or 1),
            'char_accuracy': sum(m.get('char_accuracy', 0) * m.get('tokens', 0) for m in chunk_metrics) / total_weight,
            'tokens': total_tokens,
            'chars': total_chars,
            'processing_time': total_time,
            'file_path': str(file_path)
        }

        logger.info(f" Total tokens: {total_tokens:,}")
        logger.info(f" Total chars: {total_chars:,}")
        logger.info(f" Avg chars/token: {avg_metrics['chars_per_token']:.2f}")
        logger.info(f" Avg tokens/sec: {avg_metrics['tokens_per_second']:,.2f}")

        return avg_metrics

    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        logger.error(traceback.format_exc())
        return None


def process_single_file(tokenizer: Tokenizer, file_path: Path, sample_size: int = 0) -> Dict[str, Any]:
    """Process a single file and return metrics."""
    logger.info(f"\nProcessing file: {file_path}")

    try:
        metrics = process_file(file_path, tokenizer, sample_size=sample_size)

        if not metrics:
            logger.warning(f"Empty file or no valid content found: {file_path}")
            return {}

        metrics['file'] = os.path.basename(file_path)
        metrics['file_size_mb'] = os.path.getsize(file_path) / (1024 * 1024)

        logger.info(
            f"Processed {metrics['file_size_mb']:.2f}MB: "
            f"{metrics['tokens']:,} tokens, "
            f"{metrics['chars_per_token']:.2f} chars/token, "
            f"{metrics['tokens_per_second']:,.2f} tokens/sec"
        )

        print_metrics_summary(metrics)

        return metrics

    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}", exc_info=True)
        return {'file': os.path.basename(file_path), 'error': str(e)}


def main():
    project_root = get_project_root()

    root_dir = project_root.parent
    default_tokenizer = root_dir / 'output' / 'tokenizer.json'
    default_input = root_dir / 'Dataset'
    default_output = root_dir / 'test_result'

    ensure_directory(default_output)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    default_output_file = default_output / f'test_results_{timestamp}.txt'

    parser = argparse.ArgumentParser(description='Test tokenizer on code files')
    parser.add_argument('--tokenizer', type=str, default=str(default_tokenizer),
                        help=f'Path to tokenizer.json file (default: {default_tokenizer})')
    parser.add_argument('--input', type=str, default=str(default_input),
                        help=f'Input directory or file (default: {default_input})')
    parser.add_argument('--output', type=str, default=str(default_output_file),
                        help=f'Output text file for results (default: {default_output_file})')
    parser.add_argument('--sample', type=int, default=100000,
                        help='Only process this many characters from each file (0 for full file)')
    parser.add_argument('--max-files', type=int, default=10,
                        help='Maximum number of files to process (default: 10)')
    parser.add_argument('--file-types', type=str, default='*',
                        help='Comma-separated list of file extensions to process (e.g., "py,js,json"). Default: all files')

    args = parser.parse_args()

    output_dir = Path(args.output).parent
    ensure_directory(output_dir)

    logger.info(f"Initializing tokenizer from {args.tokenizer}")
    tokenizer = Tokenizer(args.tokenizer)

    file_extensions = []
    if args.file_types != '*':
        file_extensions = [ext.strip().lower() for ext in args.file_types.split(',')]
        logger.info(f"Filtering by file extensions: {', '.join(file_extensions)}")

    input_path = Path(args.input)
    file_paths = []

    if input_path.is_dir():
        if file_extensions:
            for ext in file_extensions:
                pattern = f'*.{ext.lstrip(".")}'
                file_paths.extend(input_path.rglob(pattern))
        else:
            file_paths = list(input_path.rglob('*'))

        file_paths = [
            f for f in file_paths
            if f.is_file() and not f.name.startswith(('.', '_'))
        ]

        file_paths.sort(key=lambda x: x.stat().st_size)

        logger.info(f"Found {len(file_paths)} files in {input_path}")
        if file_paths:
            logger.info(f"Sample files: {', '.join(f.name for f in file_paths[:min(5, len(file_paths))])}" +
                        ('...' if len(file_paths) > 5 else ''))
    else:
        file_paths = [input_path] if input_path.exists() else []
        logger.info(f"Processing single file: {input_path}")

    if not file_paths:
        logger.warning(f"No files found in {input_path}")
        return

    all_metrics = []
    processed_count = 0
    skipped_files = 0

    # De-duplicate file paths before processing.
    unique_file_paths = []
    seen_paths = set()

    for file_path in file_paths:
        abs_path = str(file_path.absolute())
        if abs_path not in seen_paths:
            seen_paths.add(abs_path)
            unique_file_paths.append(file_path)

    if len(unique_file_paths) < len(file_paths):
        logger.info(f"Removed {len(file_paths) - len(unique_file_paths)} duplicate file paths")

    if args.max_files > 0:
        unique_file_paths = unique_file_paths[:args.max_files]

    for file_path in unique_file_paths:
        try:
            if not file_path.exists():
                logger.warning(f"File not found: {file_path}")
                skipped_files += 1
                continue

            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            logger.info(f"\nProcessing: {file_path.name} ({file_size_mb:.2f} MB)")

            metrics = process_single_file(tokenizer, file_path, args.sample)
            if metrics:
                all_metrics.append(metrics)
                processed_count += 1
                logger.info(f"Processed {processed_count}/{len(unique_file_paths)} files")
        except Exception as e:
            logger.error(f"Error processing {file_path}: {str(e)}")
            skipped_files += 1

    if skipped_files > 0:
        logger.warning(f"Skipped {skipped_files} files due to errors")

    if all_metrics:
        avg_metrics = {}
        for key in all_metrics[0].keys():
            if isinstance(all_metrics[0][key], (int, float)):
                values = [r[key] for r in all_metrics if key in r]
                if values:
                    avg_metrics[f'avg_{key}'] = sum(values) / len(values)

    with open(args.output, 'w', encoding='utf-8') as f:
        f.write("=== Tokenizer Test Results ===\n")
        f.write(f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Tokenizer: {args.tokenizer}\n")
        f.write(f"Input: {args.input}\n")
        f.write(f"Sample size: {args.sample if args.sample > 0 else 'Full file'}\n\n")

        f.write("=== Summary ===\n")
        if all_metrics:
            for key, value in avg_metrics.items():
                if isinstance(value, float):
                    f.write(f"{key}: {value:.4f}\n")
                else:
                    f.write(f"{key}: {value}\n")
        else:
            f.write("No files were successfully processed\n")

        f.write("\n=== File Details ===\n")
        for result in all_metrics:
            f.write(f"\nFile: {result.get('file', 'unknown')}\n")
            for key, value in result.items():
                if key != 'file':
                    if isinstance(value, float):
                        f.write(f"  {key}: {value:.4f}\n")
                    else:
                        f.write(f"  {key}: {value}\n")

    logger.info(f"Results saved to {args.output}")
    print(f"\nTest results saved to: {args.output}")

    if all_metrics:
        logger.info("\n=== Test Complete ===")
        logger.info(f"Processed {processed_count} files")
        logger.info(f"Average chars/token: {avg_metrics.get('avg_chars_per_token', 0):.2f}")
        # Per-file metrics expose 'tokens_per_second', so the averaged key is 'avg_tokens_per_second'.
        logger.info(f"Average tokens/sec: {avg_metrics.get('avg_tokens_per_second', 0):,.0f}")
    else:
        logger.warning("No files were successfully processed")


if __name__ == "__main__":
    try:
        # Optional dependency check: python-Levenshtein improves the similarity metric.
        try:
            import Levenshtein
        except ImportError:
            logger.warning("python-Levenshtein not found. Install with: pip install python-Levenshtein")
            logger.warning("Falling back to basic similarity metrics")

        main()
    except KeyboardInterrupt:
        logger.info("\nProcess interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"An error occurred: {e}", exc_info=True)
        sys.exit(1)
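
# Example invocation (illustrative; the script name and paths are placeholders,
# not taken from the source):
#
#   python test_tokenizer.py --tokenizer output/tokenizer.json \
#       --input Dataset --sample 100000 --file-types py,js --max-files 5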