File size: 21,629 Bytes
7837959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
"""
Custom Dataset Generation for Code-Specialized Model Training.

This module creates optimized training datasets from CodeSearchNet that are specifically
designed to improve performance on code search evaluation tasks.

Features:
- High-quality doc-code pairs optimized for retrieval
- Per-language sample cap to keep languages balanced
- Multiple query formats per doc-code pair (docstring, how-to, functional, implementation-specific)
- Quality filtering and data cleaning
- Train/test splits stratified by language, plus a combined file
- Efficient parquet format output
"""

import json
import logging
import re
import time
from pathlib import Path
from typing import Annotated, Any

import pandas as pd
import typer
from datasets import load_dataset
from tqdm import tqdm

from .config import languages_config

# Configure root logging once at import time and grab this module's logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Defaults for dataset generation (used by create_optimized_dataset and the CLI).
DATASET_OUTPUT_DIR = Path("code_model2vec") / "dataset"
# Upper bound on samples kept per programming language.
DEFAULT_MAX_SAMPLES_PER_LANG = 50_000
# Documentation length bounds, in whitespace-separated words.
DEFAULT_MIN_DOC_WORDS = 3
DEFAULT_MAX_DOC_WORDS = 100
# Code length bounds, in characters.
DEFAULT_MIN_CODE_CHARS = 50
DEFAULT_MAX_CODE_CHARS = 2_000


def create_optimized_dataset(
	max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
	min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
	max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
	min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
	max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
	output_dir: Path | None = None,
	create_multiple_formats: bool = True,
) -> dict[str, Any]:
	"""
	Create optimized training dataset from CodeSearchNet for code search tasks.

	For every language in ``languages_config.all`` the CodeSearchNet ``train`` split is
	iterated, filtered with ``_passes_quality_filters`` and expanded into samples with
	``_create_training_samples``. A language whose loading/processing raises is logged,
	recorded in ``language_stats`` with an ``"error"`` entry, and skipped. The samples
	are split into train/test, written as parquet files, and summarized in
	``<output_dir>/metadata.json``.

	Args:
	    max_samples_per_lang: Maximum samples per programming language
	    min_doc_words: Minimum words in documentation
	    max_doc_words: Maximum words in documentation
	    min_code_chars: Minimum characters in code
	    max_code_chars: Maximum characters in code
	    output_dir: Output directory for dataset (defaults to ``DATASET_OUTPUT_DIR``)
	    create_multiple_formats: Create multiple training formats

	Returns:
	    Dictionary with dataset statistics and file paths (the same data written to
	    ``metadata.json``).

	Raises:
	    RuntimeError: If no language produced any samples.
	"""
	output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
	output_dir.mkdir(parents=True, exist_ok=True)

	logger.info("🚀 Starting optimized CodeSearchNet dataset creation...")
	logger.info(f"📁 Output directory: {output_dir}")
	logger.info(f"📊 Target: {max_samples_per_lang} samples per language")
	logger.info(f"🔍 Languages: {', '.join(languages_config.all)}")

	start_time = time.time()
	all_samples: list[dict[str, Any]] = []
	language_stats: dict[str, dict[str, Any]] = {}

	for language in languages_config.all:
		logger.info(f"\n🔄 Processing {language}...")

		try:
			language_samples, stats = _process_language(
				language,
				max_samples_per_lang=max_samples_per_lang,
				min_doc_words=min_doc_words,
				max_doc_words=max_doc_words,
				min_code_chars=min_code_chars,
				max_code_chars=max_code_chars,
				create_multiple_formats=create_multiple_formats,
			)
		except Exception as e:
			# Best effort: one failing language must not abort the whole run.
			logger.exception(f"❌ Failed to process {language}")
			language_stats[language] = {
				"processed": 0,
				"quality_filtered": 0,
				"final_samples": 0,
				"quality_rate": 0.0,
				# main() checks for this key when printing per-language statistics.
				"error": str(e),
			}
			continue

		all_samples.extend(language_samples)
		language_stats[language] = stats
		logger.info(
			f"✅ {language}: {len(language_samples)} samples from {stats['quality_filtered']} quality examples"
		)

	if not all_samples:
		# Without this guard an empty DataFrame (no "language" column) reaches the split step.
		msg = "No training samples were produced for any language; see the errors logged above."
		raise RuntimeError(msg)

	logger.info(f"\n📊 Creating dataset with {len(all_samples)} total samples...")
	df = pd.DataFrame(all_samples)

	train_df, test_df = _create_stratified_splits(df)
	dataset_files = _save_datasets(output_dir, train_df, test_df)

	metadata = {
		"creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
		"total_samples": len(all_samples),
		"train_samples": len(train_df),
		"test_samples": len(test_df),
		"languages": languages_config.all,
		"language_stats": language_stats,
		"quality_filters": {
			"min_doc_words": min_doc_words,
			"max_doc_words": max_doc_words,
			"min_code_chars": min_code_chars,
			"max_code_chars": max_code_chars,
		},
		"files": dataset_files,
		"processing_time": time.time() - start_time,
	}

	metadata_file = output_dir / "metadata.json"
	with metadata_file.open("w", encoding="utf-8") as f:
		json.dump(metadata, f, indent=2)

	logger.info(f"\n🎉 Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
	logger.info("📊 Final statistics:")
	logger.info(f"  - Total samples: {metadata['total_samples']}")
	logger.info(f"  - Train: {metadata['train_samples']}")
	logger.info(f"  - Test: {metadata['test_samples']}")
	logger.info(f"💾 Metadata saved to: {metadata_file}")

	return metadata


def _process_language(
	language: str,
	*,
	max_samples_per_lang: int,
	min_doc_words: int,
	max_doc_words: int,
	min_code_chars: int,
	max_code_chars: int,
	create_multiple_formats: bool,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
	"""
	Load one CodeSearchNet language and turn it into at most ``max_samples_per_lang`` samples.

	Returns:
	    ``(samples, stats)`` where ``stats`` has the keys ``processed``,
	    ``quality_filtered`` (examples that passed the filters), ``final_samples`` and
	    ``quality_rate``.
	"""
	dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)

	language_samples: list[dict[str, Any]] = []
	processed_count = 0
	passed_count = 0

	for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
		processed_count += 1

		# `or ""` also covers fields that are present but None.
		doc_string = (example.get("func_documentation_string") or "").strip()
		code_string = (example.get("func_code_string") or "").strip()
		func_name = (example.get("func_name") or "").strip()

		if not _passes_quality_filters(
			doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
		):
			continue

		passed_count += 1
		language_samples.extend(
			_create_training_samples(doc_string, code_string, func_name, language, create_multiple_formats)
		)

		# One example may yield several samples, so stop once the cap is reached ...
		if len(language_samples) >= max_samples_per_lang:
			break

	# ... and trim any overshoot from the last example.
	language_samples = language_samples[:max_samples_per_lang]

	stats = {
		"processed": processed_count,
		"quality_filtered": passed_count,
		"final_samples": len(language_samples),
		"quality_rate": passed_count / processed_count if processed_count > 0 else 0,
	}
	return language_samples, stats


def _passes_quality_filters(
	doc_string: str,
	code_string: str,
	func_name: str,
	min_doc_words: int,
	max_doc_words: int,
	min_code_chars: int,
	max_code_chars: int,
) -> bool:
	"""Apply quality filters optimized for code retrieval following RAG best practices."""
	# Basic existence checks
	if not doc_string or not code_string or not func_name:
		return False

	# Documentation quality filters for code retrieval
	doc_words = len(doc_string.split())
	if doc_words < min_doc_words or doc_words > max_doc_words:
		return False

	# Code quality filters
	code_length = len(code_string)
	if code_length < min_code_chars or code_length > max_code_chars:
		return False

	# Content quality filters for code retrieval
	doc_lower = doc_string.lower()
	code_string.lower()

	# Skip low-quality documentation (expanded for code context)
	skip_phrases = [
		"todo",
		"fixme",
		"hack",
		"temp",
		"test",
		"placeholder",
		"not implemented",
		"coming soon",
		"tbd",
		"xxx",
		"broken",
		"deprecated",
		"legacy",
		"old version",
		"outdated",
	]
	if any(phrase in doc_lower for phrase in skip_phrases):
		return False

	# Ensure meaningful documentation for code retrieval
	if func_name.lower() in doc_lower and doc_words < 5:
		return False

	# Code structure validation (more comprehensive for retrieval)
	has_function = any(
		pattern in code_string for pattern in ["def ", "function ", "class ", "public ", "private ", "static "]
	)
	if not has_function:
		return False

	# Skip trivial or incomplete code
	trivial_code_patterns = [
		"pass",
		"return None",
		"return;",
		"throw new Error",
		"# TODO",
		"// TODO",
		"print(",
		"console.log(",
	]
	if any(pattern in code_string for pattern in trivial_code_patterns) and len(code_string) < 100:
		return False

	# Ensure documentation describes functionality (not just naming)
	generic_docs = [
		"returns a value",
		"does something",
		"helper function",
		"utility method",
		"this function",
		"this method",
		"returns the result",
		"performs operation",
	]
	if any(generic in doc_lower for generic in generic_docs):
		return False

	# Ensure documentation has descriptive content for retrieval
	descriptive_words = [
		"parse",
		"convert",
		"transform",
		"calculate",
		"validate",
		"format",
		"filter",
		"sort",
		"search",
		"find",
		"create",
		"generate",
		"process",
		"handle",
		"manage",
		"update",
		"modify",
		"remove",
		"delete",
		"add",
	]
	if not any(word in doc_lower for word in descriptive_words) and doc_words < 8:
		return False

	# Code-documentation alignment check (key for retrieval quality)
	return _check_code_doc_alignment(doc_string, code_string, func_name)


def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
	"""Check if documentation and code are well-aligned for retrieval tasks."""
	doc_lower = doc_string.lower()
	code_lower = code_string.lower()

	# Function name should relate to documentation
	func_base = func_name.lower().replace("_", " ").replace("-", " ")

	# Check for obvious mismatches
	doc_has_return = any(word in doc_lower for word in ["return", "returns", "gives", "outputs"])
	code_has_return = "return " in code_lower

	# If doc mentions returning something, code should have returns
	if doc_has_return and not code_has_return and len(code_string.split("\n")) > 3:
		return False

	# Check for parameter mentions alignment
	any(word in doc_lower for word in ["parameter", "param", "argument", "input"])
	"(" in func_name and func_name.count("(") == 1

	# Basic semantic alignment
	action_words = ["sort", "parse", "convert", "validate", "format", "filter", "search", "calculate"]
	doc_actions = [word for word in action_words if word in doc_lower]
	[word for word in action_words if word in code_lower or word in func_base]

	# If documentation mentions specific actions, code or function name should reflect them
	return not (doc_actions and not any(action in code_lower or action in func_base for action in doc_actions))


def _create_training_samples(
	doc_string: str,
	code_string: str,
	func_name: str,
	language: str,
	create_multiple_formats: bool,
) -> list[dict[str, Any]]:
	"""
	Build training rows for one doc/code pair.

	The raw docstring is always used as a query. With ``create_multiple_formats`` three
	generated queries follow it: how-to, functional-requirement and
	implementation-specific. Every row pairs its query with the same code.

	Returns:
	    List of dicts with the keys ``language``, ``query``, ``code`` and ``text``.
	"""
	queries = [doc_string]
	if create_multiple_formats:
		queries += [
			_generate_how_to_query(doc_string, func_name, language),
			_generate_functional_query(doc_string, func_name),
			_generate_implementation_query(doc_string, func_name, language),
		]

	return [
		{
			"language": language,
			"query": query,
			"code": code_string,
			"text": _format_training_text(query, code_string, language),
		}
		for query in queries
	]


def _format_training_text(query: str, code: str, language: str) -> str:
	"""Format query and code into a single training text chunk with markdown-style code blocks."""
	# Clean up query but preserve internal code formatting
	query_clean = query.strip()
	code_clean = code.strip()

	# Create training text with proper markdown format and newline separation
	# Structure: query + empty line + markdown code block with language
	return f"{query_clean}\n\n```{language}\n{code_clean}\n```"


def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
	"""Generate realistic 'how to' queries that developers might actually search for."""
	# Extract key action words from documentation
	doc_lower = doc_string.lower()
	func_lower = func_name.lower()

	# Common developer query patterns
	if "sort" in doc_lower or "sort" in func_lower:
		return f"How to sort data in {language}"
	if "parse" in doc_lower or "parse" in func_lower:
		return f"How to parse data in {language}"
	if "convert" in doc_lower or "transform" in doc_lower or "convert" in func_lower:
		return f"How to convert data in {language}"
	if "validate" in doc_lower or "check" in doc_lower or "validate" in func_lower:
		return f"How to validate input in {language}"
	if "calculate" in doc_lower or "compute" in doc_lower or "calc" in func_lower:
		return f"How to calculate values in {language}"
	if "format" in doc_lower or "format" in func_lower:
		return f"How to format output in {language}"
	if "filter" in doc_lower or "filter" in func_lower:
		return f"How to filter data in {language}"
	if "search" in doc_lower or "find" in doc_lower or "search" in func_lower or "find" in func_lower:
		return f"How to search through data in {language}"
	# Use function name for more specific queries
	if func_name and len(func_name) > 2:
		# Extract meaningful words from function name
		func_words = func_name.replace("_", " ").replace("-", " ").strip()
		if func_words:
			return f"How to {func_words.lower()} in {language}"
	# Fallback to more generic query
	action = doc_string.split()[0] if doc_string.split() else "implement"
	return f"How to {action.lower()} in {language}"


def _generate_functional_query(doc_string: str, func_name: str) -> str:
	"""Generate functional requirement queries focusing on what the code accomplishes."""
	# Clean up documentation to create natural query
	doc_clean = doc_string.strip().rstrip(".")

	# Transform to question format
	if doc_clean.startswith(("Returns", "Return")):
		return f"Function that {doc_clean.lower()}"
	if doc_clean.startswith(("Creates", "Create")):
		return f"Code to {doc_clean.lower()}"
	if doc_clean.startswith(("Checks", "Check")):
		return f"Function to {doc_clean.lower()}"

	# Use function name to enhance the query if available
	if func_name and len(func_name) > 2:
		func_words = func_name.replace("_", " ").replace("-", " ").strip()
		if func_words and len(doc_clean) < 30:  # Only for short docs
			return f"Function named '{func_name}' that {doc_clean.lower()}"

	return f"Implementation that {doc_clean.lower()}"


def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
	"""Generate implementation-specific queries with technical details."""
	doc_lower = doc_string.lower()
	func_lower = func_name.lower() if func_name else ""

	# Add language-specific implementation details
	if language == "python":
		if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
			return f"Python function to {doc_string.lower()} using lists"
		if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
			return f"Python function to {doc_string.lower()} using dictionaries"
		# Include function name for context if available
		if func_name and len(func_name) > 2:
			return f"Python implementation of {func_name}: {doc_string.lower()}"
		return f"Python implementation: {doc_string.lower()}"
	if language == "java":
		func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
		return f"Java method to {doc_string.lower()}{func_suffix}"
	if language == "javascript":
		func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
		return f"JavaScript function to {doc_string.lower()}{func_suffix}"
	if language == "php":
		func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
		return f"PHP function to {doc_string.lower()}{func_suffix}"
	if language == "ruby":
		func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
		return f"Ruby method to {doc_string.lower()}{func_suffix}"
	if language == "go":
		func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
		return f"Go function to {doc_string.lower()}{func_suffix}"
	return f"{language} code to {doc_string.lower()}"


def _create_stratified_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
	"""Create stratified train/test splits preserving language distribution."""
	# Define split ratios
	train_ratio = 0.9
	# test_ratio = 0.1 (remainder)

	train_dfs = []
	test_dfs = []

	# Split by language to ensure balanced representation
	for language in df["language"].unique():
		lang_df = df[df["language"] == language].copy()
		n_samples = len(lang_df)

		# Calculate split sizes
		n_train = int(n_samples * train_ratio)
		# Remainder goes to test

		# Shuffle and split
		lang_df = lang_df.sample(frac=1, random_state=42).reset_index(drop=True)

		train_dfs.append(lang_df[:n_train])
		test_dfs.append(lang_df[n_train:])

	# Combine and shuffle again
	train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
	test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)

	logger.info("πŸ“Š Created stratified splits:")
	logger.info(f"  - Train: {len(train_df)} samples")
	logger.info(f"  - Test: {len(test_df)} samples")

	return train_df, test_df


def _save_datasets(
	output_dir: Path,
	train_df: pd.DataFrame,
	test_df: pd.DataFrame,
	compression: str = "snappy",
) -> dict[str, str]:
	"""
	Save the train/test splits and their concatenation as parquet files.

	Writes ``train.parquet``, ``test.parquet`` and ``combined.parquet`` (train rows
	followed by test rows) into ``output_dir``, without the DataFrame index.

	Args:
	    output_dir: Directory to write into.
	    train_df: Training split.
	    test_df: Test split.
	    compression: Parquet compression codec passed to ``DataFrame.to_parquet``.

	Returns:
	    Mapping of split name ("train", "test", "combined") to the written file path.
	"""
	# The combined file is provided for convenience alongside the two splits.
	combined_df = pd.concat([train_df, test_df], ignore_index=True)

	dataset_files: dict[str, str] = {}
	for split_name, df in (("train", train_df), ("test", test_df), ("combined", combined_df)):
		filepath = output_dir / f"{split_name}.parquet"
		df.to_parquet(filepath, compression=compression, index=False)
		dataset_files[split_name] = str(filepath)
		logger.info(f"💾 Saved {split_name}: {len(df)} samples → {filepath}")

	return dataset_files


def load_optimized_dataset(
	output_dir: Path | None = None,
	split: str = "train",
) -> pd.DataFrame:
	"""
	Load a previously created optimized dataset.

	Args:
	    output_dir: Directory containing the dataset files
	    split: Which split to load ('train', 'test', 'combined')

	Returns:
	    DataFrame with the requested dataset split
	"""
	if output_dir is None:
		output_dir = DATASET_OUTPUT_DIR

	filepath = output_dir / f"{split}.parquet"

	if not filepath.exists():
		available_files = list(output_dir.glob("*.parquet"))
		available_splits = [f.stem for f in available_files]
		msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
		raise FileNotFoundError(msg)

	logger.info(f"πŸ“‚ Loading {split} dataset from {filepath}")
	df = pd.read_parquet(filepath)
	logger.info(f"βœ… Loaded {len(df)} samples")

	return df


def main(
	max_samples_per_lang: Annotated[
		int, typer.Option(help="Maximum samples per language")
	] = DEFAULT_MAX_SAMPLES_PER_LANG,
	min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
	max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
	min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
	max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
	output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
	simple_format: Annotated[
		bool, typer.Option(help="Create only simple format (not multiple training formats)")
	] = False,
) -> None:
	"""Create optimized training dataset from CodeSearchNet for code search tasks."""
	logger.info("🚀 Starting optimized dataset creation command...")

	# create_optimized_dataset falls back to DATASET_OUTPUT_DIR when given None, so
	# resolved_dir is the directory the files actually end up in.
	output_path = Path(output_dir) if output_dir else None
	resolved_dir = output_path or DATASET_OUTPUT_DIR

	try:
		metadata = create_optimized_dataset(
			max_samples_per_lang=max_samples_per_lang,
			min_doc_words=min_doc_words,
			max_doc_words=max_doc_words,
			min_code_chars=min_code_chars,
			max_code_chars=max_code_chars,
			output_dir=output_path,
			create_multiple_formats=not simple_format,
		)

		logger.info("✅ Dataset creation completed successfully!")
		logger.info(f"📁 Output directory: {resolved_dir}")

		# Print summary statistics
		print("\n" + "=" * 60)
		print("📊 DATASET CREATION SUMMARY")
		print("=" * 60)
		print(f"Total samples created: {metadata['total_samples']:,}")
		print(f"Processing time: {metadata['processing_time']:.2f} seconds")
		print("\nSplit distribution:")
		print(f"  • Train: {metadata['train_samples']:,} samples")
		print(f"  • Test:  {metadata['test_samples']:,} samples")

		print("\nLanguage distribution:")
		for lang, stats in metadata["language_stats"].items():
			if "error" in stats:
				# Surface failed languages instead of silently omitting them.
				print(f"  • {lang}: failed ({stats['error']})")
			else:
				print(f"  • {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")

		print(f"\nDataset files saved to: {resolved_dir}")
		print("=" * 60)

	except Exception as e:
		logger.exception("❌ Dataset creation failed")
		raise typer.Exit(1) from e


if __name__ == "__main__":
	# Run main() as a Typer command-line entry point when executed as a script.
	typer.run(main)