# codemalt/src/distiller/dataset.py
"""
Custom Dataset Generation for Code-Specialized Model Training.
This module creates optimized training datasets from CodeSearchNet that are specifically
designed to improve performance on code search evaluation tasks.
Features:
- High-quality doc-code pairs optimized for retrieval
- Balanced sampling across programming languages
- Multiple training formats (doc-only, code-only, combined)
- Quality filtering and data cleaning
- Train/test/eval splits with proper stratification
- Efficient parquet format output
"""
import json
import logging
import time
from pathlib import Path
from typing import Annotated, Any
import pandas as pd
import typer
from datasets import load_dataset
from tqdm import tqdm
from .config import languages_config
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Dataset configuration
DATASET_OUTPUT_DIR = Path("code_model2vec/dataset")
DEFAULT_MAX_SAMPLES_PER_LANG = 50000
DEFAULT_MIN_DOC_WORDS = 3
DEFAULT_MAX_DOC_WORDS = 100
DEFAULT_MIN_CODE_CHARS = 50
DEFAULT_MAX_CODE_CHARS = 2000
def create_optimized_dataset(
max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
output_dir: Path | None = None,
create_multiple_formats: bool = True,
) -> dict[str, Any]:
"""
Create optimized training dataset from CodeSearchNet for code search tasks.
Args:
max_samples_per_lang: Maximum samples per programming language
min_doc_words: Minimum words in documentation
max_doc_words: Maximum words in documentation
min_code_chars: Minimum characters in code
max_code_chars: Maximum characters in code
output_dir: Output directory for dataset
create_multiple_formats: Create multiple training formats
Returns:
Dictionary with dataset statistics and file paths
"""
output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logger.info("πŸš€ Starting optimized CodeSearchNet dataset creation...")
logger.info(f"πŸ“ Output directory: {output_dir}")
logger.info(f"πŸ“Š Target: {max_samples_per_lang} samples per language")
logger.info(f"πŸ” Languages: {', '.join(languages_config.all)}")
start_time = time.time()
all_samples = []
language_stats = {}
# Process each programming language
for language in languages_config.all:
logger.info(f"\nπŸ”„ Processing {language}...")
try:
# Load CodeSearchNet dataset for this language
dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)
language_samples = []
processed_count = 0
quality_filtered = 0
# Process examples with quality filtering
for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
processed_count += 1
# Extract documentation and code
doc_string = example.get("func_documentation_string", "").strip()
code_string = example.get("func_code_string", "").strip()
func_name = example.get("func_name", "").strip()
# Quality filters
if not _passes_quality_filters(
doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
):
continue
quality_filtered += 1
# Create optimized training samples
samples = _create_training_samples(
doc_string, code_string, func_name, language, create_multiple_formats
)
language_samples.extend(samples)
# Stop if we have enough samples
if len(language_samples) >= max_samples_per_lang:
break
# Truncate to exact target size
language_samples = language_samples[:max_samples_per_lang]
all_samples.extend(language_samples)
# Track statistics
language_stats[language] = {
"processed": processed_count,
"quality_filtered": quality_filtered,
"final_samples": len(language_samples),
"quality_rate": quality_filtered / processed_count if processed_count > 0 else 0,
}
logger.info(f"βœ… {language}: {len(language_samples)} samples from {quality_filtered} quality examples")
        except Exception:
            logger.exception(f"❌ Failed to process {language}")
            language_stats[language] = {
                "processed": 0,
                "quality_filtered": 0,
                "final_samples": 0,
                "quality_rate": 0.0,
                "error": True,  # flagged so the summary printer can skip failed languages
            }
# Create DataFrame
logger.info(f"\nπŸ“Š Creating dataset with {len(all_samples)} total samples...")
df = pd.DataFrame(all_samples)
# Create stratified splits
train_df, test_df = _create_stratified_splits(df)
# Save datasets
dataset_files = _save_datasets(output_dir, train_df, test_df)
# Save metadata
metadata = {
"creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
"total_samples": len(all_samples),
"train_samples": len(train_df),
"test_samples": len(test_df),
"languages": languages_config.all,
"language_stats": language_stats,
"quality_filters": {
"min_doc_words": min_doc_words,
"max_doc_words": max_doc_words,
"min_code_chars": min_code_chars,
"max_code_chars": max_code_chars,
},
"files": dataset_files,
"processing_time": time.time() - start_time,
}
metadata_file = output_dir / "metadata.json"
with metadata_file.open("w") as f:
json.dump(metadata, f, indent=2)
logger.info(f"\nπŸŽ‰ Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
logger.info("πŸ“Š Final statistics:")
logger.info(f" - Total samples: {metadata['total_samples']}")
logger.info(f" - Train: {metadata['train_samples']}")
logger.info(f" - Test: {metadata['test_samples']}")
logger.info(f"πŸ’Ύ Metadata saved to: {metadata_file}")
return metadata
def _passes_quality_filters(
doc_string: str,
code_string: str,
func_name: str,
min_doc_words: int,
max_doc_words: int,
min_code_chars: int,
max_code_chars: int,
) -> bool:
"""Apply quality filters optimized for code retrieval following RAG best practices."""
# Basic existence checks
if not doc_string or not code_string or not func_name:
return False
# Documentation quality filters for code retrieval
doc_words = len(doc_string.split())
if doc_words < min_doc_words or doc_words > max_doc_words:
return False
# Code quality filters
code_length = len(code_string)
if code_length < min_code_chars or code_length > max_code_chars:
return False
    # Content quality filters for code retrieval
    doc_lower = doc_string.lower()
# Skip low-quality documentation (expanded for code context)
skip_phrases = [
"todo",
"fixme",
"hack",
"temp",
"test",
"placeholder",
"not implemented",
"coming soon",
"tbd",
"xxx",
"broken",
"deprecated",
"legacy",
"old version",
"outdated",
]
if any(phrase in doc_lower for phrase in skip_phrases):
return False
    # Reject documentation that merely restates the function name
    if func_name.lower() in doc_lower and doc_words < 5:
        return False
# Code structure validation (more comprehensive for retrieval)
has_function = any(
pattern in code_string for pattern in ["def ", "function ", "class ", "public ", "private ", "static "]
)
if not has_function:
return False
# Skip trivial or incomplete code
trivial_code_patterns = [
"pass",
"return None",
"return;",
"throw new Error",
"# TODO",
"// TODO",
"print(",
"console.log(",
]
if any(pattern in code_string for pattern in trivial_code_patterns) and len(code_string) < 100:
return False
# Ensure documentation describes functionality (not just naming)
generic_docs = [
"returns a value",
"does something",
"helper function",
"utility method",
"this function",
"this method",
"returns the result",
"performs operation",
]
if any(generic in doc_lower for generic in generic_docs):
return False
# Ensure documentation has descriptive content for retrieval
descriptive_words = [
"parse",
"convert",
"transform",
"calculate",
"validate",
"format",
"filter",
"sort",
"search",
"find",
"create",
"generate",
"process",
"handle",
"manage",
"update",
"modify",
"remove",
"delete",
"add",
]
if not any(word in doc_lower for word in descriptive_words) and doc_words < 8:
return False
# Code-documentation alignment check (key for retrieval quality)
return _check_code_doc_alignment(doc_string, code_string, func_name)
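# Illustrative filter behavior (hypothetical inputs, not dataset samples):
#   _passes_quality_filters(
#       "Parse an ISO-8601 date string into a datetime object.",
#       "def parse_date(s):\n    return datetime.fromisoformat(s)",
#       "parse_date", 3, 100, 50, 2000,
#   )  # True: descriptive doc, structured code, aligned on "parse"
#   _passes_quality_filters("TODO fix this", "pass", "tmp", 3, 100, 50, 2000)  # False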
def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
"""Check if documentation and code are well-aligned for retrieval tasks."""
doc_lower = doc_string.lower()
code_lower = code_string.lower()
# Function name should relate to documentation
func_base = func_name.lower().replace("_", " ").replace("-", " ")
# Check for obvious mismatches
doc_has_return = any(word in doc_lower for word in ["return", "returns", "gives", "outputs"])
code_has_return = "return " in code_lower
# If doc mentions returning something, code should have returns
if doc_has_return and not code_has_return and len(code_string.split("\n")) > 3:
return False
    # Basic semantic alignment: if the documentation mentions specific actions,
    # the code or the function name should reflect at least one of them
    action_words = ["sort", "parse", "convert", "validate", "format", "filter", "search", "calculate"]
    doc_actions = [word for word in action_words if word in doc_lower]
    if doc_actions and not any(action in code_lower or action in func_base for action in doc_actions):
        return False
    return True
def _create_training_samples(
doc_string: str,
code_string: str,
func_name: str,
language: str,
create_multiple_formats: bool,
) -> list[dict[str, Any]]:
"""Create optimized training samples for code retrieval with proper training schema."""
samples = []
if create_multiple_formats:
        # Format 1: Documentation query → code (direct evaluation format)
query_1 = doc_string
text_1 = _format_training_text(query_1, code_string, language)
samples.append(
{
"language": language,
"query": query_1,
"code": code_string,
"text": text_1,
}
)
# Format 2: How-to query (realistic developer search)
query_2 = _generate_how_to_query(doc_string, func_name, language)
text_2 = _format_training_text(query_2, code_string, language)
samples.append(
{
"language": language,
"query": query_2,
"code": code_string,
"text": text_2,
}
)
# Format 3: Functional requirement query
query_3 = _generate_functional_query(doc_string, func_name)
text_3 = _format_training_text(query_3, code_string, language)
samples.append(
{
"language": language,
"query": query_3,
"code": code_string,
"text": text_3,
}
)
# Format 4: Implementation-specific query
query_4 = _generate_implementation_query(doc_string, func_name, language)
text_4 = _format_training_text(query_4, code_string, language)
samples.append(
{
"language": language,
"query": query_4,
"code": code_string,
"text": text_4,
}
)
else:
# Simple format - direct documentation to code
query = doc_string
text = _format_training_text(query, code_string, language)
samples.append(
{
"language": language,
"query": query,
"code": code_string,
"text": text,
}
)
return samples
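# With create_multiple_formats=True, a single doc-code pair expands into four
# samples. For a hypothetical Python function sort_list documented as
# "Sorts a list in ascending order.", the four queries come out roughly as:
#   1. "Sorts a list in ascending order."                     (direct)
#   2. "How to sort data in python"                           (how-to)
#   3. "Implementation that sorts a list in ascending order"  (functional)
#   4. "Python function to sorts a list in ascending order. using lists"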
def _format_training_text(query: str, code: str, language: str) -> str:
"""Format query and code into a single training text chunk with markdown-style code blocks."""
# Clean up query but preserve internal code formatting
query_clean = query.strip()
code_clean = code.strip()
# Create training text with proper markdown format and newline separation
# Structure: query + empty line + markdown code block with language
return f"{query_clean}\n\n```{language}\n{code_clean}\n```"
def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
"""Generate realistic 'how to' queries that developers might actually search for."""
# Extract key action words from documentation
doc_lower = doc_string.lower()
func_lower = func_name.lower()
# Common developer query patterns
if "sort" in doc_lower or "sort" in func_lower:
return f"How to sort data in {language}"
if "parse" in doc_lower or "parse" in func_lower:
return f"How to parse data in {language}"
if "convert" in doc_lower or "transform" in doc_lower or "convert" in func_lower:
return f"How to convert data in {language}"
if "validate" in doc_lower or "check" in doc_lower or "validate" in func_lower:
return f"How to validate input in {language}"
if "calculate" in doc_lower or "compute" in doc_lower or "calc" in func_lower:
return f"How to calculate values in {language}"
if "format" in doc_lower or "format" in func_lower:
return f"How to format output in {language}"
if "filter" in doc_lower or "filter" in func_lower:
return f"How to filter data in {language}"
if "search" in doc_lower or "find" in doc_lower or "search" in func_lower or "find" in func_lower:
return f"How to search through data in {language}"
# Use function name for more specific queries
if func_name and len(func_name) > 2:
# Extract meaningful words from function name
func_words = func_name.replace("_", " ").replace("-", " ").strip()
if func_words:
return f"How to {func_words.lower()} in {language}"
# Fallback to more generic query
action = doc_string.split()[0] if doc_string.split() else "implement"
return f"How to {action.lower()} in {language}"
def _generate_functional_query(doc_string: str, func_name: str) -> str:
"""Generate functional requirement queries focusing on what the code accomplishes."""
# Clean up documentation to create natural query
doc_clean = doc_string.strip().rstrip(".")
    # Rephrase the doc string as a natural-language retrieval query
if doc_clean.startswith(("Returns", "Return")):
return f"Function that {doc_clean.lower()}"
if doc_clean.startswith(("Creates", "Create")):
return f"Code to {doc_clean.lower()}"
if doc_clean.startswith(("Checks", "Check")):
return f"Function to {doc_clean.lower()}"
# Use function name to enhance the query if available
if func_name and len(func_name) > 2:
func_words = func_name.replace("_", " ").replace("-", " ").strip()
if func_words and len(doc_clean) < 30: # Only for short docs
return f"Function named '{func_name}' that {doc_clean.lower()}"
return f"Implementation that {doc_clean.lower()}"
def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
"""Generate implementation-specific queries with technical details."""
doc_lower = doc_string.lower()
func_lower = func_name.lower() if func_name else ""
# Add language-specific implementation details
if language == "python":
if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
return f"Python function to {doc_string.lower()} using lists"
if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
return f"Python function to {doc_string.lower()} using dictionaries"
# Include function name for context if available
if func_name and len(func_name) > 2:
return f"Python implementation of {func_name}: {doc_string.lower()}"
return f"Python implementation: {doc_string.lower()}"
if language == "java":
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
return f"Java method to {doc_string.lower()}{func_suffix}"
if language == "javascript":
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
return f"JavaScript function to {doc_string.lower()}{func_suffix}"
if language == "php":
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
return f"PHP function to {doc_string.lower()}{func_suffix}"
if language == "ruby":
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
return f"Ruby method to {doc_string.lower()}{func_suffix}"
if language == "go":
func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
return f"Go function to {doc_string.lower()}{func_suffix}"
return f"{language} code to {doc_string.lower()}"
def _create_stratified_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Create stratified train/test splits preserving language distribution."""
# Define split ratios
train_ratio = 0.9
# test_ratio = 0.1 (remainder)
train_dfs = []
test_dfs = []
# Split by language to ensure balanced representation
for language in df["language"].unique():
lang_df = df[df["language"] == language].copy()
n_samples = len(lang_df)
# Calculate split sizes
n_train = int(n_samples * train_ratio)
# Remainder goes to test
# Shuffle and split
lang_df = lang_df.sample(frac=1, random_state=42).reset_index(drop=True)
train_dfs.append(lang_df[:n_train])
test_dfs.append(lang_df[n_train:])
# Combine and shuffle again
train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
logger.info("πŸ“Š Created stratified splits:")
logger.info(f" - Train: {len(train_df)} samples")
logger.info(f" - Test: {len(test_df)} samples")
return train_df, test_df
def _save_datasets(
output_dir: Path,
train_df: pd.DataFrame,
test_df: pd.DataFrame,
) -> dict[str, str]:
"""Save datasets in parquet format with compression."""
dataset_files = {}
# Save each split
for split_name, df in [("train", train_df), ("test", test_df)]:
filepath = output_dir / f"{split_name}.parquet"
df.to_parquet(
filepath,
compression="snappy",
index=False,
)
dataset_files[split_name] = str(filepath)
logger.info(f"πŸ’Ύ Saved {split_name}: {len(df)} samples β†’ {filepath}")
# Also save a combined dataset for convenience
combined_df = pd.concat([train_df, test_df], ignore_index=True)
combined_filepath = output_dir / "combined.parquet"
combined_df.to_parquet(combined_filepath, compression="snappy", index=False)
dataset_files["combined"] = str(combined_filepath)
logger.info(f"πŸ’Ύ Saved combined: {len(combined_df)} samples β†’ {combined_filepath}")
return dataset_files
def load_optimized_dataset(
output_dir: Path | None = None,
split: str = "train",
) -> pd.DataFrame:
"""
Load a previously created optimized dataset.
Args:
output_dir: Directory containing the dataset files
split: Which split to load ('train', 'test', 'combined')
Returns:
DataFrame with the requested dataset split
"""
if output_dir is None:
output_dir = DATASET_OUTPUT_DIR
filepath = output_dir / f"{split}.parquet"
if not filepath.exists():
available_files = list(output_dir.glob("*.parquet"))
available_splits = [f.stem for f in available_files]
msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
raise FileNotFoundError(msg)
logger.info(f"πŸ“‚ Loading {split} dataset from {filepath}")
df = pd.read_parquet(filepath)
logger.info(f"βœ… Loaded {len(df)} samples")
return df
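# Example usage (assumes a dataset was already created in the default directory):
#   train_df = load_optimized_dataset(split="train")
#   print(train_df[["language", "query"]].head())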
def main(
max_samples_per_lang: Annotated[
int, typer.Option(help="Maximum samples per language")
] = DEFAULT_MAX_SAMPLES_PER_LANG,
min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
simple_format: Annotated[
bool, typer.Option(help="Create only simple format (not multiple training formats)")
] = False,
) -> None:
"""Create optimized training dataset from CodeSearchNet for code search tasks."""
logger.info("πŸš€ Starting optimized dataset creation command...")
# Convert output_dir to Path if provided
output_path = Path(output_dir) if output_dir else None
# Create the dataset
try:
metadata = create_optimized_dataset(
max_samples_per_lang=max_samples_per_lang,
min_doc_words=min_doc_words,
max_doc_words=max_doc_words,
min_code_chars=min_code_chars,
max_code_chars=max_code_chars,
output_dir=output_path,
create_multiple_formats=not simple_format,
)
logger.info("βœ… Dataset creation completed successfully!")
logger.info(f"πŸ“ Output directory: {metadata['files']['train']}")
# Print summary statistics
print("\n" + "=" * 60)
print("πŸ“Š DATASET CREATION SUMMARY")
print("=" * 60)
print(f"Total samples created: {metadata['total_samples']:,}")
print(f"Processing time: {metadata['processing_time']:.2f} seconds")
print("\nSplit distribution:")
print(f" β€’ Train: {metadata['train_samples']:,} samples")
print(f" β€’ Test: {metadata['test_samples']:,} samples")
print("\nLanguage distribution:")
for lang, stats in metadata["language_stats"].items():
if "error" not in stats:
print(f" β€’ {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")
print(f"\nDataset files saved to: {output_path or DATASET_OUTPUT_DIR}")
print("=" * 60)
except Exception as e:
logger.exception("❌ Dataset creation failed")
raise typer.Exit(1) from e
if __name__ == "__main__":
typer.run(main)