|
""" |
|
Custom Dataset Generation for Code-Specialized Model Training. |
|
|
|
This module creates optimized training datasets from CodeSearchNet that are specifically |
|
designed to improve performance on code search evaluation tasks. |
|
|
|
Features: |
|
- High-quality doc-code pairs optimized for retrieval |
|
- Balanced sampling across programming languages |
|
- Multiple query formats per function (docstring, how-to, functional, implementation-specific)
|
- Quality filtering and data cleaning |
|
- Train/test splits stratified by programming language, plus a combined file
|
- Efficient parquet format output |
|
""" |
|
|
|
import json
import logging
import re
import time
from pathlib import Path
from typing import Annotated, Any

import pandas as pd
import typer
from datasets import load_dataset
from tqdm import tqdm

from .config import languages_config
|
|
|
|
|
# Module-wide logging: INFO level, each record prefixed with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(name=__name__)

# Default location for the generated parquet files and metadata.json.
DATASET_OUTPUT_DIR = Path("code_model2vec", "dataset")

# Default sampling and quality-filter limits (each is overridable via a CLI option).
DEFAULT_MAX_SAMPLES_PER_LANG = 50_000
DEFAULT_MIN_DOC_WORDS = 3
DEFAULT_MAX_DOC_WORDS = 100
DEFAULT_MIN_CODE_CHARS = 50
DEFAULT_MAX_CODE_CHARS = 2_000
|
|
|
|
|
def create_optimized_dataset(
    max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
    min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
    max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
    min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
    max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
    output_dir: Path | None = None,
    create_multiple_formats: bool = True,
) -> dict[str, Any]:
    """
    Create optimized training dataset from CodeSearchNet for code search tasks.

    Each language in ``languages_config.all`` is processed independently: its
    CodeSearchNet ``train`` split is streamed, examples that pass
    ``_passes_quality_filters`` are expanded into training rows by
    ``_create_training_samples``, and collection stops once
    ``max_samples_per_lang`` rows exist. A language that fails to load or
    process is logged and recorded in ``language_stats`` with zeroed counts
    and an ``"error"`` message; the remaining languages still run.

    Args:
        max_samples_per_lang: Maximum samples per programming language (must be >= 1)
        min_doc_words: Minimum words in documentation
        max_doc_words: Maximum words in documentation
        min_code_chars: Minimum characters in code
        max_code_chars: Maximum characters in code
        output_dir: Output directory for dataset (defaults to DATASET_OUTPUT_DIR)
        create_multiple_formats: Create several query variants per function instead of one

    Returns:
        Dictionary with dataset statistics and file paths; the same content is
        written to ``<output_dir>/metadata.json``.

    Raises:
        ValueError: If ``max_samples_per_lang`` is smaller than 1.
        RuntimeError: If no language produced any samples.
    """
    if max_samples_per_lang < 1:
        msg = f"max_samples_per_lang must be >= 1, got {max_samples_per_lang}"
        raise ValueError(msg)

    output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Materialize once so the value is JSON-serializable for metadata.json.
    languages = list(languages_config.all)

    logger.info("Starting optimized CodeSearchNet dataset creation...")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Target: {max_samples_per_lang} samples per language")
    logger.info(f"Languages: {', '.join(languages)}")

    start_time = time.perf_counter()
    all_samples: list[dict[str, Any]] = []
    language_stats: dict[str, dict[str, Any]] = {}

    for language in languages:
        logger.info(f"\nProcessing {language}...")
        try:
            language_samples, stats = _collect_language_samples(
                language,
                max_samples_per_lang=max_samples_per_lang,
                min_doc_words=min_doc_words,
                max_doc_words=max_doc_words,
                min_code_chars=min_code_chars,
                max_code_chars=max_code_chars,
                create_multiple_formats=create_multiple_formats,
            )
        except Exception as exc:
            # Best effort: one broken language must not abort the whole run.
            # The "error" key lets callers (e.g. the CLI summary) tell failures apart.
            logger.exception(f"Failed to process {language}")
            language_stats[language] = {
                "processed": 0,
                "quality_filtered": 0,
                "final_samples": 0,
                "quality_rate": 0.0,
                "error": str(exc),
            }
            continue

        all_samples.extend(language_samples)
        language_stats[language] = stats
        logger.info(
            f"{language}: {stats['final_samples']} samples from {stats['quality_filtered']} quality examples"
        )

    if not all_samples:
        # Fail loudly here instead of with an opaque KeyError inside the split step.
        msg = "No training samples were collected for any language; see the log for per-language errors."
        raise RuntimeError(msg)

    logger.info(f"\nCreating dataset with {len(all_samples)} total samples...")
    df = pd.DataFrame(all_samples)

    train_df, test_df = _create_stratified_splits(df)
    dataset_files = _save_datasets(output_dir, train_df, test_df)

    metadata = {
        "creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "total_samples": len(all_samples),
        "train_samples": len(train_df),
        "test_samples": len(test_df),
        "languages": languages,
        "language_stats": language_stats,
        "quality_filters": {
            "min_doc_words": min_doc_words,
            "max_doc_words": max_doc_words,
            "min_code_chars": min_code_chars,
            "max_code_chars": max_code_chars,
        },
        "files": dataset_files,
        "processing_time": time.perf_counter() - start_time,
    }

    metadata_file = output_dir / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"\nDataset creation completed in {metadata['processing_time']:.2f} seconds!")
    logger.info("Final statistics:")
    logger.info(f"  - Total samples: {metadata['total_samples']}")
    logger.info(f"  - Train: {metadata['train_samples']}")
    logger.info(f"  - Test: {metadata['test_samples']}")
    logger.info(f"Metadata saved to: {metadata_file}")

    return metadata


def _collect_language_samples(
    language: str,
    *,
    max_samples_per_lang: int,
    min_doc_words: int,
    max_doc_words: int,
    min_code_chars: int,
    max_code_chars: int,
    create_multiple_formats: bool,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Stream one CodeSearchNet language split and return ``(training rows, per-language stats)``."""
    dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)

    samples: list[dict[str, Any]] = []
    processed_count = 0
    # Counts examples that PASSED the filters; the stats key keeps its historical name.
    passed_count = 0

    for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
        processed_count += 1

        # `or ""` also covers fields that are present but None.
        doc_string = (example.get("func_documentation_string") or "").strip()
        code_string = (example.get("func_code_string") or "").strip()
        func_name = (example.get("func_name") or "").strip()

        if not _passes_quality_filters(
            doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
        ):
            continue
        passed_count += 1

        samples.extend(_create_training_samples(doc_string, code_string, func_name, language, create_multiple_formats))
        if len(samples) >= max_samples_per_lang:
            break

    # One example can add several rows, so the final one may overshoot the cap.
    samples = samples[:max_samples_per_lang]

    stats = {
        "processed": processed_count,
        "quality_filtered": passed_count,
        "final_samples": len(samples),
        "quality_rate": passed_count / processed_count if processed_count > 0 else 0,
    }
    return samples, stats
|
|
|
|
|
# Placeholder / maintenance markers that disqualify a docstring.
_SKIP_PHRASES = (
    "todo",
    "fixme",
    "hack",
    "temp",
    "test",
    "placeholder",
    "not implemented",
    "coming soon",
    "tbd",
    "xxx",
    "broken",
    "deprecated",
    "legacy",
    "old version",
    "outdated",
)
# Matched on word boundaries so ordinary words such as "template", "attempt"
# or "latest" are not mistaken for "temp"/"test".
_SKIP_PHRASE_RE = re.compile(r"\b(?:" + "|".join(map(re.escape, _SKIP_PHRASES)) + r")\b")

# Substrings indicating the snippet contains a real definition. "func " is Go's
# keyword; "protected " covers Java methods with that modifier.
_FUNCTION_MARKERS = ("def ", "function ", "func ", "class ", "public ", "protected ", "private ", "static ")

# Tokens that, in very short snippets, suggest a stub or debug-only body.
_TRIVIAL_CODE_PATTERNS = (
    "pass",
    "return None",
    "return;",
    "throw new Error",
    "# TODO",
    "// TODO",
    "print(",
    "console.log(",
)

# Boilerplate phrasing that carries little retrieval signal.
_GENERIC_DOC_PHRASES = (
    "returns a value",
    "does something",
    "helper function",
    "utility method",
    "this function",
    "this method",
    "returns the result",
    "performs operation",
)

# Verbs whose presence makes a short docstring acceptable.
_DESCRIPTIVE_WORDS = (
    "parse",
    "convert",
    "transform",
    "calculate",
    "validate",
    "format",
    "filter",
    "sort",
    "search",
    "find",
    "create",
    "generate",
    "process",
    "handle",
    "manage",
    "update",
    "modify",
    "remove",
    "delete",
    "add",
)

# Docstring words that promise a return value.
_RETURN_WORDS = ("return", "returns", "gives", "outputs")

# Actions that, when named by the docstring, must also appear in the code or function name.
_ALIGNMENT_ACTION_WORDS = ("sort", "parse", "convert", "validate", "format", "filter", "search", "calculate")


def _passes_quality_filters(
    doc_string: str,
    code_string: str,
    func_name: str,
    min_doc_words: int,
    max_doc_words: int,
    min_code_chars: int,
    max_code_chars: int,
) -> bool:
    """
    Decide whether a (docstring, code, name) triple is a useful retrieval pair.

    Checks, in order: all fields non-empty; docstring word count and code
    character count within the given inclusive bounds; no placeholder
    markers in the docstring; the docstring is more than a bare restatement
    of the name; the code contains a definition keyword; the code is not a
    tiny stub; the docstring is not boilerplate; short docstrings contain a
    descriptive verb; and finally ``_check_code_doc_alignment``.

    Returns:
        True if every check passes.
    """
    if not doc_string or not code_string or not func_name:
        return False

    doc_words = len(doc_string.split())
    if not min_doc_words <= doc_words <= max_doc_words:
        return False

    if not min_code_chars <= len(code_string) <= max_code_chars:
        return False

    doc_lower = doc_string.lower()

    if _SKIP_PHRASE_RE.search(doc_lower):
        return False

    # A very short docstring that mentions the function name mostly repeats it.
    if func_name.lower() in doc_lower and doc_words < 5:
        return False

    if not any(marker in code_string for marker in _FUNCTION_MARKERS):
        return False

    if len(code_string) < 100 and any(pattern in code_string for pattern in _TRIVIAL_CODE_PATTERNS):
        return False

    if any(phrase in doc_lower for phrase in _GENERIC_DOC_PHRASES):
        return False

    if doc_words < 8 and not any(word in doc_lower for word in _DESCRIPTIVE_WORDS):
        return False

    return _check_code_doc_alignment(doc_string, code_string, func_name)


def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
    """
    Reject pairs whose docstring promises something the code visibly lacks.

    Two heuristics:
    - the docstring mentions returning/outputting, but a multi-line (4+ line)
      snippet contains no ``return `` statement;
    - the docstring names an action from ``_ALIGNMENT_ACTION_WORDS`` that
      appears neither in the code nor in the function name.

    Returns:
        True if neither heuristic flags the pair.
    """
    doc_lower = doc_string.lower()
    code_lower = code_string.lower()
    func_base = func_name.lower().replace("_", " ").replace("-", " ")

    doc_promises_return = any(word in doc_lower for word in _RETURN_WORDS)
    if doc_promises_return and "return " not in code_lower and len(code_string.split("\n")) > 3:
        return False

    doc_actions = [word for word in _ALIGNMENT_ACTION_WORDS if word in doc_lower]
    return not doc_actions or any(action in code_lower or action in func_base for action in doc_actions)
|
|
|
|
|
def _create_training_samples(
    doc_string: str,
    code_string: str,
    func_name: str,
    language: str,
    create_multiple_formats: bool,
) -> list[dict[str, Any]]:
    """
    Build training rows that pair natural-language queries with one code snippet.

    The raw docstring always becomes a query. With ``create_multiple_formats``
    three synthetic variants follow it: a how-to query, a functional query
    and an implementation-specific query.

    Returns:
        One dict per query with keys ``language``, ``query``, ``code`` and
        ``text`` (query and code combined by ``_format_training_text``).
    """
    queries = [doc_string]
    if create_multiple_formats:
        queries += [
            _generate_how_to_query(doc_string, func_name, language),
            _generate_functional_query(doc_string, func_name),
            _generate_implementation_query(doc_string, func_name, language),
        ]

    return [
        {
            "language": language,
            "query": query,
            "code": code_string,
            "text": _format_training_text(query, code_string, language),
        }
        for query in queries
    ]
|
|
|
|
|
def _format_training_text(query: str, code: str, language: str) -> str:
    """
    Join a query and its code into one text chunk.

    Both parts are stripped. The code is wrapped in a markdown fence tagged
    with ``language``, and a blank line separates it from the query.
    """
    opening_fence = "```" + language
    parts = (query.strip(), "", opening_fence, code.strip(), "```")
    return "\n".join(parts)
|
|
|
|
|
def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
    """
    Phrase a "How to ... in <language>" query like one a developer might search for.

    The first matching topic rule wins. A rule matches when any of its
    docstring keywords or name keywords appears as a case-insensitive
    substring. If no rule matches, the query uses the function name with
    ``_``/``-`` turned into spaces, provided the name is longer than two
    characters. Failing that, it uses the docstring's first word, or
    "implement" when the docstring is empty.
    """
    doc_text = doc_string.lower()
    name_text = func_name.lower()

    # (docstring keywords, function-name keywords, topic phrase) -- order matters.
    topic_rules = (
        (("sort",), ("sort",), "sort data"),
        (("parse",), ("parse",), "parse data"),
        (("convert", "transform"), ("convert",), "convert data"),
        (("validate", "check"), ("validate",), "validate input"),
        (("calculate", "compute"), ("calc",), "calculate values"),
        (("format",), ("format",), "format output"),
        (("filter",), ("filter",), "filter data"),
        (("search", "find"), ("search", "find"), "search through data"),
    )
    for doc_keys, name_keys, topic in topic_rules:
        if any(key in doc_text for key in doc_keys) or any(key in name_text for key in name_keys):
            return f"How to {topic} in {language}"

    if func_name and len(func_name) > 2:
        spaced_name = func_name.replace("_", " ").replace("-", " ").strip()
        if spaced_name:
            return f"How to {spaced_name.lower()} in {language}"

    doc_tokens = doc_string.split()
    first_word = doc_tokens[0] if doc_tokens else "implement"
    return f"How to {first_word.lower()} in {language}"
|
|
|
|
|
def _to_base_verb(text: str, verb: str) -> str:
    """Turn a leading third-person ``<verb>s`` in lowercase *text* into the bare ``<verb>``."""
    first_word, separator, remainder = text.partition(" ")
    if first_word == f"{verb}s":
        first_word = verb
    return f"{first_word}{separator}{remainder}"


def _generate_functional_query(doc_string: str, func_name: str) -> str:
    """
    Generate a functional-requirement query describing what the code accomplishes.

    The docstring is stripped of surrounding whitespace and trailing periods
    and then lowercased. A docstring starting with "Return" becomes
    "Function that ...". One starting with "Create" becomes "Code to ...",
    and one starting with "Check" becomes "Function to ...". In both of
    those cases a leading "creates"/"checks" is reduced to the infinitive
    so the phrase stays grammatical. Short docstrings (< 30 chars) with a
    usable function name mention the name. Everything else becomes
    "Implementation that ...".
    """
    doc_clean = doc_string.strip().rstrip(".")
    doc_lower = doc_clean.lower()

    if doc_clean.startswith("Return"):
        # "Function that returns ..." already reads naturally.
        return f"Function that {doc_lower}"
    if doc_clean.startswith("Create"):
        # "Creates X" -> "Code to create X" (not "Code to creates X").
        return f"Code to {_to_base_verb(doc_lower, 'create')}"
    if doc_clean.startswith("Check"):
        return f"Function to {_to_base_verb(doc_lower, 'check')}"

    if func_name and len(func_name) > 2:
        # Only mention the name when it is not just separator characters.
        func_words = func_name.replace("_", " ").replace("-", " ").strip()
        if func_words and len(doc_clean) < 30:
            return f"Function named '{func_name}' that {doc_lower}"

    return f"Implementation that {doc_lower}"
|
|
|
|
|
def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
    """
    Generate a language-flavoured implementation query from the docstring.

    Python queries may mention lists or dictionaries when the docstring or
    name hints at them. Java, JavaScript, PHP, Ruby and Go get a fixed
    "<Lang> method/function to ..." prefix, with the function name appended
    in parentheses when it is longer than two characters. Any other language
    falls back to "<language> code to ...".
    """
    doc_lower = doc_string.lower()
    has_usable_name = bool(func_name) and len(func_name) > 2

    if language == "python":
        name_lower = func_name.lower() if func_name else ""
        if any(hint in doc_lower for hint in ("list", "array")) or "list" in name_lower:
            return f"Python function to {doc_lower} using lists"
        if any(hint in doc_lower for hint in ("dict", "hash")) or "dict" in name_lower:
            return f"Python function to {doc_lower} using dictionaries"
        if has_usable_name:
            return f"Python implementation of {func_name}: {doc_lower}"
        return f"Python implementation: {doc_lower}"

    prefixes = {
        "java": "Java method",
        "javascript": "JavaScript function",
        "php": "PHP function",
        "ruby": "Ruby method",
        "go": "Go function",
    }
    prefix = prefixes.get(language)
    if prefix is None:
        return f"{language} code to {doc_lower}"

    name_suffix = f" ({func_name})" if has_usable_name else ""
    return f"{prefix} to {doc_lower}{name_suffix}"
|
|
|
|
|
def _create_stratified_splits(
    df: pd.DataFrame,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Create stratified train/test splits preserving language distribution.

    Each language's rows are shuffled with *seed*. The first
    ``int(n * train_ratio)`` rows go to train and the rest to test. Both
    splits are then shuffled again with the same seed, so the result is
    deterministic.

    Args:
        df: Samples with a ``language`` column.
        train_ratio: Fraction of each language's rows assigned to train, in ``[0, 1]``.
        seed: ``random_state`` used for every shuffle.

    Returns:
        ``(train_df, test_df)``, each with a fresh 0..n-1 index. An empty
        input yields two empty frames.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        KeyError: If a non-empty *df* has no ``language`` column.
    """
    if not 0.0 <= train_ratio <= 1.0:
        msg = f"train_ratio must be between 0 and 1, got {train_ratio}"
        raise ValueError(msg)

    if df.empty:
        # An empty frame may not even have a "language" column to group by.
        logger.warning("No samples to split; returning empty train/test frames")
        return df.iloc[0:0].copy(), df.iloc[0:0].copy()

    train_parts = []
    test_parts = []
    for language in df["language"].unique():
        lang_df = df[df["language"] == language].sample(frac=1, random_state=seed).reset_index(drop=True)
        n_train = int(len(lang_df) * train_ratio)
        train_parts.append(lang_df.iloc[:n_train])
        test_parts.append(lang_df.iloc[n_train:])

    # Re-shuffle so languages are interleaved rather than in contiguous blocks.
    train_df = pd.concat(train_parts, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)
    test_df = pd.concat(test_parts, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)

    logger.info("Created stratified splits:")
    logger.info(f"  - Train: {len(train_df)} samples")
    logger.info(f"  - Test: {len(test_df)} samples")

    return train_df, test_df
|
|
|
|
|
def _save_datasets(
    output_dir: Path,
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
) -> dict[str, str]:
    """
    Write the train/test splits and their concatenation as snappy-compressed parquet.

    The files are ``train.parquet``, ``test.parquet`` and ``combined.parquet``
    (train rows followed by test rows) inside *output_dir*. They are written
    without the DataFrame index.

    Returns:
        Mapping of split name (``"train"``, ``"test"``, ``"combined"``) to file path.
    """
    combined_df = pd.concat([train_df, test_df], ignore_index=True)

    dataset_files: dict[str, str] = {}
    for split_name, split_df in (("train", train_df), ("test", test_df), ("combined", combined_df)):
        filepath = output_dir / f"{split_name}.parquet"
        split_df.to_parquet(filepath, compression="snappy", index=False)
        dataset_files[split_name] = str(filepath)
        logger.info(f"Saved {split_name}: {len(split_df)} samples -> {filepath}")

    return dataset_files
|
|
|
|
|
def load_optimized_dataset( |
|
output_dir: Path | None = None, |
|
split: str = "train", |
|
) -> pd.DataFrame: |
|
""" |
|
Load a previously created optimized dataset. |
|
|
|
Args: |
|
output_dir: Directory containing the dataset files |
|
split: Which split to load ('train', 'test', 'combined') |
|
|
|
Returns: |
|
DataFrame with the requested dataset split |
|
""" |
|
if output_dir is None: |
|
output_dir = DATASET_OUTPUT_DIR |
|
|
|
filepath = output_dir / f"{split}.parquet" |
|
|
|
if not filepath.exists(): |
|
available_files = list(output_dir.glob("*.parquet")) |
|
available_splits = [f.stem for f in available_files] |
|
msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}" |
|
raise FileNotFoundError(msg) |
|
|
|
logger.info(f"π Loading {split} dataset from {filepath}") |
|
df = pd.read_parquet(filepath) |
|
logger.info(f"β
Loaded {len(df)} samples") |
|
|
|
return df |
|
|
|
|
|
def main(
    max_samples_per_lang: Annotated[
        int, typer.Option(help="Maximum samples per language")
    ] = DEFAULT_MAX_SAMPLES_PER_LANG,
    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
    simple_format: Annotated[
        bool, typer.Option(help="Create only simple format (not multiple training formats)")
    ] = False,
) -> None:
    """Create optimized training dataset from CodeSearchNet for code search tasks."""
    # NOTE: Typer shows the docstring above as this command's --help text, so it stays a
    # one-liner. Options map 1:1 onto create_optimized_dataset; --simple-format disables the
    # extra query variants. Any failure is logged and turned into exit code 1.
    logger.info("Starting optimized dataset creation command...")

    output_path = Path(output_dir) if output_dir else None

    try:
        metadata = create_optimized_dataset(
            max_samples_per_lang=max_samples_per_lang,
            min_doc_words=min_doc_words,
            max_doc_words=max_doc_words,
            min_code_chars=min_code_chars,
            max_code_chars=max_code_chars,
            output_dir=output_path,
            create_multiple_formats=not simple_format,
        )
    except Exception as e:
        # Top-level CLI boundary: keep the traceback in the log, exit non-zero.
        logger.exception("Dataset creation failed")
        raise typer.Exit(1) from e

    resolved_dir = output_path or DATASET_OUTPUT_DIR
    logger.info("Dataset creation completed successfully!")
    logger.info(f"Output directory: {resolved_dir}")

    print("\n" + "=" * 60)
    print("DATASET CREATION SUMMARY")
    print("=" * 60)
    print(f"Total samples created: {metadata['total_samples']:,}")
    print(f"Processing time: {metadata['processing_time']:.2f} seconds")
    print("\nSplit distribution:")
    print(f"  - Train: {metadata['train_samples']:,} samples")
    print(f"  - Test: {metadata['test_samples']:,} samples")

    print("\nLanguage distribution:")
    for lang, stats in metadata["language_stats"].items():
        if "error" in stats:
            print(f"  - {lang}: failed ({stats['error']})")
        else:
            print(f"  - {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")

    print(f"\nDataset files saved to: {resolved_dir}")
    print("=" * 60)
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly exposes `main` as a single Typer command.
    typer.run(main)
|
|