"""
Custom Dataset Generation for Code-Specialized Model Training.

This module creates optimized training datasets from CodeSearchNet that are
specifically designed to improve performance on code search evaluation tasks.

Features:
- High-quality doc-code pairs optimized for retrieval
- Per-language sample caps so no single language dominates
- Multiple query formats per example (docstring, how-to, functional, implementation-specific)
- Quality filtering and data cleaning
- Train/test splits stratified by language, with no code snippet shared between splits
- Efficient parquet format output
"""

import json
import logging
import re
import time
from pathlib import Path
from typing import Annotated, Any

import pandas as pd
import typer
from datasets import load_dataset
from tqdm import tqdm

from .config import languages_config

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Dataset configuration
DATASET_OUTPUT_DIR = Path("code_model2vec/dataset")
DEFAULT_MAX_SAMPLES_PER_LANG = 50000
DEFAULT_MIN_DOC_WORDS = 3
DEFAULT_MAX_DOC_WORDS = 100
DEFAULT_MIN_CODE_CHARS = 50
DEFAULT_MAX_CODE_CHARS = 2000

# ---------------------------------------------------------------------------
# Quality-filter vocabularies (built once at import time, not once per example)
# ---------------------------------------------------------------------------

# Docstring phrases marking unfinished, throwaway or obsolete code. Matched as whole words so that
# e.g. "temp" does not reject "attempt"/"template" and "test" does not reject "latest".
_LOW_QUALITY_DOC_PHRASES = (
    "todo",
    "fixme",
    "hack",
    "temp",
    "test",
    "placeholder",
    "not implemented",
    "coming soon",
    "tbd",
    "xxx",
    "broken",
    "deprecated",
    "legacy",
    "old version",
    "outdated",
)
_LOW_QUALITY_DOC_RE = re.compile(r"\b(?:" + "|".join(map(re.escape, _LOW_QUALITY_DOC_PHRASES)) + r")\b")

# Substrings showing that a snippet contains an actual definition. "func " covers Go, and
# "protected " covers Java methods declared with neither "public " nor "private ".
_FUNCTION_MARKERS = ("def ", "function ", "func ", "class ", "public ", "private ", "protected ", "static ")

# Substrings that, inside a short snippet (< 100 chars), suggest a stub or debug-only code.
_TRIVIAL_CODE_PATTERNS = (
    "pass",
    "return None",
    "return;",
    "throw new Error",
    "# TODO",
    "// TODO",
    "print(",
    "console.log(",
)

# Docstring phrases that name a function without saying what it does.
_GENERIC_DOC_PHRASES = (
    "returns a value",
    "does something",
    "helper function",
    "utility method",
    "this function",
    "this method",
    "returns the result",
    "performs operation",
)

# Verbs that make a short docstring descriptive enough to serve as a search query.
_DESCRIPTIVE_WORDS = (
    "parse",
    "convert",
    "transform",
    "calculate",
    "validate",
    "format",
    "filter",
    "sort",
    "search",
    "find",
    "create",
    "generate",
    "process",
    "handle",
    "manage",
    "update",
    "modify",
    "remove",
    "delete",
    "add",
)

# Docstring words that promise a return value (checked against "return " in the code).
_RETURN_WORDS = ("return", "returns", "gives", "outputs")

# Actions that, when named in a docstring, should also appear in the code or function name.
_ACTION_WORDS = ("sort", "parse", "convert", "validate", "format", "filter", "search", "calculate")

# (docstring keywords, function-name keywords, task phrase) for "How to ..." queries.
# The first rule with any keyword hit wins, so order matters.
_HOW_TO_RULES: tuple[tuple[tuple[str, ...], tuple[str, ...], str], ...] = (
    (("sort",), ("sort",), "sort data"),
    (("parse",), ("parse",), "parse data"),
    (("convert", "transform"), ("convert",), "convert data"),
    (("validate", "check"), ("validate",), "validate input"),
    (("calculate", "compute"), ("calc",), "calculate values"),
    (("format",), ("format",), "format output"),
    (("filter",), ("filter",), "filter data"),
    (("search", "find"), ("search", "find"), "search through data"),
)

# Language-specific subject used in implementation queries (Python is handled separately).
_IMPLEMENTATION_NOUNS = {
    "java": "Java method",
    "javascript": "JavaScript function",
    "php": "PHP function",
    "ruby": "Ruby method",
    "go": "Go function",
}


def create_optimized_dataset(
    max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
    min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
    max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
    min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
    max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
    output_dir: Path | None = None,
    create_multiple_formats: bool = True,
) -> dict[str, Any]:
    """
    Create optimized training dataset from CodeSearchNet for code search tasks.

    For every language in ``languages_config.all`` the CodeSearchNet ``train`` split is streamed,
    quality-filtered, and turned into training samples until ``max_samples_per_lang`` is reached.
    The result is split into train/test, written as parquet, and described in ``metadata.json``.

    Args:
        max_samples_per_lang: Maximum samples per programming language
        min_doc_words: Minimum words in documentation
        max_doc_words: Maximum words in documentation
        min_code_chars: Minimum characters in code
        max_code_chars: Maximum characters in code
        output_dir: Output directory for dataset (defaults to ``DATASET_OUTPUT_DIR``)
        create_multiple_formats: Emit four query variants per example instead of one

    Returns:
        Dictionary with dataset statistics and file paths (the same content is written to
        ``metadata.json``). Languages that failed to load have an ``"error"`` entry in their stats.

    Raises:
        RuntimeError: If no language produced any sample.
    """
    output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("šŸš€ Starting optimized CodeSearchNet dataset creation...")
    logger.info(f"šŸ“ Output directory: {output_dir}")
    logger.info(f"šŸ“Š Target: {max_samples_per_lang} samples per language")
    logger.info(f"šŸ” Languages: {', '.join(languages_config.all)}")

    start_time = time.perf_counter()
    all_samples: list[dict[str, Any]] = []
    language_stats: dict[str, dict[str, Any]] = {}

    # Process each programming language
    for language in languages_config.all:
        logger.info(f"\nšŸ”„ Processing {language}...")

        try:
            # Load CodeSearchNet dataset for this language.
            # NOTE(review): trust_remote_code executes the dataset's loading script; keep it only if the
            # installed `datasets` version still needs it for "code_search_net".
            dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)

            language_samples: list[dict[str, Any]] = []
            processed_count = 0
            # Number of examples that PASSED the quality filters (key name kept for metadata compatibility).
            quality_filtered = 0

            for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
                processed_count += 1

                # `or ""` guards against fields that are present but None.
                doc_string = (example.get("func_documentation_string") or "").strip()
                code_string = (example.get("func_code_string") or "").strip()
                func_name = (example.get("func_name") or "").strip()

                if not _passes_quality_filters(
                    doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
                ):
                    continue
                quality_filtered += 1

                language_samples.extend(
                    _create_training_samples(doc_string, code_string, func_name, language, create_multiple_formats)
                )

                # Stop if we have enough samples
                if len(language_samples) >= max_samples_per_lang:
                    break

            # One example can yield several samples, so the last one may overshoot the cap.
            language_samples = language_samples[:max_samples_per_lang]
            all_samples.extend(language_samples)

            language_stats[language] = {
                "processed": processed_count,
                "quality_filtered": quality_filtered,
                "final_samples": len(language_samples),
                "quality_rate": quality_filtered / processed_count if processed_count > 0 else 0,
            }

            logger.info(f"āœ… {language}: {len(language_samples)} samples from {quality_filtered} quality examples")

        except Exception as e:
            # Best effort: one failing language must not abort the whole build, but record why it failed.
            logger.exception(f"āŒ Failed to process {language}")
            language_stats[language] = {
                "processed": 0,
                "quality_filtered": 0,
                "final_samples": 0,
                "quality_rate": 0.0,
                "error": str(e),
            }

    if not all_samples:
        msg = "No samples were produced for any language; check dataset access and the quality filter settings."
        raise RuntimeError(msg)

    logger.info(f"\nšŸ“Š Creating dataset with {len(all_samples)} total samples...")
    df = pd.DataFrame(all_samples)

    train_df, test_df = _create_stratified_splits(df)
    dataset_files = _save_datasets(output_dir, train_df, test_df)

    metadata = {
        "creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "total_samples": len(all_samples),
        "train_samples": len(train_df),
        "test_samples": len(test_df),
        "languages": languages_config.all,
        "language_stats": language_stats,
        "quality_filters": {
            "min_doc_words": min_doc_words,
            "max_doc_words": max_doc_words,
            "min_code_chars": min_code_chars,
            "max_code_chars": max_code_chars,
        },
        "files": dataset_files,
        "processing_time": time.perf_counter() - start_time,
    }

    metadata_file = output_dir / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"\nšŸŽ‰ Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
    logger.info("šŸ“Š Final statistics:")
    logger.info(f" - Total samples: {metadata['total_samples']}")
    logger.info(f" - Train: {metadata['train_samples']}")
    logger.info(f" - Test: {metadata['test_samples']}")
    logger.info(f"šŸ’¾ Metadata saved to: {metadata_file}")

    return metadata


def _passes_quality_filters(
    doc_string: str,
    code_string: str,
    func_name: str,
    min_doc_words: int,
    max_doc_words: int,
    min_code_chars: int,
    max_code_chars: int,
) -> bool:
    """
    Decide whether a (doc, code, name) triple is good enough to train a code retriever on.

    Checks, in order: non-empty fields, doc word count and code length bounds, low-quality
    doc phrases, name-only docs, presence of a definition, trivial short snippets, generic
    docs, descriptiveness of short docs, and finally doc/code alignment.

    Returns:
        True if the example should be kept.
    """
    # Basic existence checks
    if not doc_string or not code_string or not func_name:
        return False

    doc_words = len(doc_string.split())
    if not min_doc_words <= doc_words <= max_doc_words:
        return False

    code_length = len(code_string)
    if not min_code_chars <= code_length <= max_code_chars:
        return False

    doc_lower = doc_string.lower()

    # Skip unfinished / obsolete / throwaway documentation (whole-word match).
    if _LOW_QUALITY_DOC_RE.search(doc_lower):
        return False

    # A very short doc that mostly restates the function name adds little retrieval signal.
    if func_name.lower() in doc_lower and doc_words < 5:
        return False

    # The snippet must contain an actual definition.
    if not any(marker in code_string for marker in _FUNCTION_MARKERS):
        return False

    # Skip trivial or incomplete code (only judged for short snippets).
    if code_length < 100 and any(pattern in code_string for pattern in _TRIVIAL_CODE_PATTERNS):
        return False

    # Documentation must describe functionality, not just name it.
    if any(phrase in doc_lower for phrase in _GENERIC_DOC_PHRASES):
        return False

    # Short docs must contain at least one concrete action verb.
    if doc_words < 8 and not any(word in doc_lower for word in _DESCRIPTIVE_WORDS):
        return False

    return _check_code_doc_alignment(doc_string, code_string, func_name)


def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
    """
    Reject obvious mismatches between documentation and code.

    - A doc that promises a return value must correspond to code containing ``"return "``,
      unless the snippet has three or fewer lines.
    - Every action verb from ``_ACTION_WORDS`` found in the doc is tolerated as long as at
      least one of them also appears in the code or in the function name.
    """
    doc_lower = doc_string.lower()
    code_lower = code_string.lower()
    func_base = func_name.lower().replace("_", " ").replace("-", " ")

    doc_has_return = any(word in doc_lower for word in _RETURN_WORDS)
    code_has_return = "return " in code_lower
    # More than three lines <=> at least three newlines.
    if doc_has_return and not code_has_return and code_string.count("\n") >= 3:
        return False

    doc_actions = [word for word in _ACTION_WORDS if word in doc_lower]
    return not doc_actions or any(action in code_lower or action in func_base for action in doc_actions)


def _create_training_samples(
    doc_string: str,
    code_string: str,
    func_name: str,
    language: str,
    create_multiple_formats: bool,
) -> list[dict[str, Any]]:
    """
    Build training rows (``language``, ``query``, ``code``, ``text``) for one example.

    With ``create_multiple_formats`` four query variants are produced, in this order:
    the raw docstring, a "how to" query, a functional-requirement query and an
    implementation-specific query. Otherwise only the raw docstring is used.
    """
    if create_multiple_formats:
        queries = [
            doc_string,  # Format 1: documentation query -> code (direct evaluation format)
            _generate_how_to_query(doc_string, func_name, language),  # Format 2: realistic developer search
            _generate_functional_query(doc_string, func_name),  # Format 3: functional requirement
            _generate_implementation_query(doc_string, func_name, language),  # Format 4: implementation-specific
        ]
    else:
        queries = [doc_string]

    return [
        {
            "language": language,
            "query": query,
            "code": code_string,
            "text": _format_training_text(query, code_string, language),
        }
        for query in queries
    ]


def _format_training_text(query: str, code: str, language: str) -> str:
    """Join a query and its code as ``query``, a blank line, then a fenced markdown code block."""
    query_clean = query.strip()
    code_clean = code.strip()
    return f"{query_clean}\n\n```{language}\n{code_clean}\n```"


def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
    """
    Generate a realistic "How to ... in <language>" developer query.

    The first matching rule in ``_HOW_TO_RULES`` wins. Otherwise the words of the function name are
    used (names longer than 2 chars); failing that, the first word of the docstring, or "implement".
    """
    doc_lower = doc_string.lower()
    func_lower = func_name.lower()

    for doc_keywords, name_keywords, task in _HOW_TO_RULES:
        if any(keyword in doc_lower for keyword in doc_keywords) or any(
            keyword in func_lower for keyword in name_keywords
        ):
            return f"How to {task} in {language}"

    # Use function name for more specific queries
    if func_name and len(func_name) > 2:
        func_words = func_name.replace("_", " ").replace("-", " ").strip()
        if func_words:
            return f"How to {func_words.lower()} in {language}"

    # Fallback to more generic query
    doc_words = doc_string.split()
    action = doc_words[0] if doc_words else "implement"
    return f"How to {action.lower()} in {language}"


def _generate_functional_query(doc_string: str, func_name: str) -> str:
    """Rephrase the docstring (trailing periods removed, lowercased) as a functional-requirement query."""
    doc_clean = doc_string.strip().rstrip(".")

    if doc_clean.startswith(("Returns", "Return")):
        return f"Function that {doc_clean.lower()}"
    if doc_clean.startswith(("Creates", "Create")):
        return f"Code to {doc_clean.lower()}"
    if doc_clean.startswith(("Checks", "Check")):
        return f"Function to {doc_clean.lower()}"

    # Mention the function name, but only for short docs where it adds context.
    if func_name and len(func_name) > 2:
        func_words = func_name.replace("_", " ").replace("-", " ").strip()
        if func_words and len(doc_clean) < 30:
            return f"Function named '{func_name}' that {doc_clean.lower()}"

    return f"Implementation that {doc_clean.lower()}"


def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
    """
    Generate a language-specific implementation query from the lowercased docstring.

    Python queries may mention lists/dictionaries and the function name. Languages in
    ``_IMPLEMENTATION_NOUNS`` get "<noun> to <doc> (<func_name>)". Any other language falls back to
    "<language> code to <doc>". The name is only used when it is longer than 2 characters.
    """
    doc_lower = doc_string.lower()
    func_lower = func_name.lower() if func_name else ""
    has_name = bool(func_name) and len(func_name) > 2

    if language == "python":
        if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
            return f"Python function to {doc_lower} using lists"
        if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
            return f"Python function to {doc_lower} using dictionaries"
        if has_name:
            return f"Python implementation of {func_name}: {doc_lower}"
        return f"Python implementation: {doc_lower}"

    noun = _IMPLEMENTATION_NOUNS.get(language)
    if noun is not None:
        func_suffix = f" ({func_name})" if has_name else ""
        return f"{noun} to {doc_lower}{func_suffix}"

    return f"{language} code to {doc_lower}"


def _create_stratified_splits(
    df: pd.DataFrame,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Create train/test splits stratified by language, without leaking code across splits.

    With multiple formats every code snippet appears in several rows (one per query variant).
    Splitting row-by-row would place variants of the same snippet in both train and test, so
    unique snippets are assigned to a split instead, separately for each language.

    Args:
        df: Samples with at least ``language`` and ``code`` columns.
        train_ratio: Fraction of each language's unique snippets assigned to train; the rest go to test.
        seed: Random seed for all shuffles, so splits are reproducible.

    Returns:
        ``(train_df, test_df)``, each shuffled and re-indexed from 0.
    """
    if df.empty:
        return df.copy(), df.copy()

    train_parts = []
    test_parts = []

    for _, lang_df in df.groupby("language", sort=False):
        unique_codes = lang_df["code"].drop_duplicates().sample(frac=1, random_state=seed)
        n_train = int(len(unique_codes) * train_ratio)
        train_codes = set(unique_codes.iloc[:n_train])

        in_train = lang_df["code"].isin(train_codes)
        train_parts.append(lang_df[in_train])
        test_parts.append(lang_df[~in_train])

    # Combine and shuffle again so languages are interleaved
    train_df = pd.concat(train_parts, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)
    test_df = pd.concat(test_parts, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)

    logger.info("šŸ“Š Created stratified splits:")
    logger.info(f" - Train: {len(train_df)} samples")
    logger.info(f" - Test: {len(test_df)} samples")

    return train_df, test_df


def _save_datasets(
    output_dir: Path,
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
) -> dict[str, str]:
    """
    Write train, test and combined (train followed by test) splits as snappy-compressed parquet.

    Returns:
        Mapping of split name (``train``, ``test``, ``combined``) to the written file path.
    """
    dataset_files: dict[str, str] = {}

    for split_name, split_df in (("train", train_df), ("test", test_df)):
        filepath = output_dir / f"{split_name}.parquet"
        split_df.to_parquet(filepath, compression="snappy", index=False)
        dataset_files[split_name] = str(filepath)
        logger.info(f"šŸ’¾ Saved {split_name}: {len(split_df)} samples → {filepath}")

    # Also save a combined dataset for convenience
    combined_df = pd.concat([train_df, test_df], ignore_index=True)
    combined_filepath = output_dir / "combined.parquet"
    combined_df.to_parquet(combined_filepath, compression="snappy", index=False)
    dataset_files["combined"] = str(combined_filepath)
    logger.info(f"šŸ’¾ Saved combined: {len(combined_df)} samples → {combined_filepath}")

    return dataset_files


def load_optimized_dataset(
    output_dir: Path | None = None,
    split: str = "train",
) -> pd.DataFrame:
    """
    Load a previously created optimized dataset.

    Args:
        output_dir: Directory containing the dataset files (a ``str`` is accepted too;
            defaults to ``DATASET_OUTPUT_DIR``)
        split: Which split to load ('train', 'test', 'combined')

    Returns:
        DataFrame with the requested dataset split

    Raises:
        FileNotFoundError: If ``<output_dir>/<split>.parquet`` does not exist; the message
            lists the parquet files that are available.
    """
    output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)

    filepath = output_dir / f"{split}.parquet"

    if not filepath.exists():
        available_splits = [f.stem for f in output_dir.glob("*.parquet")]
        msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
        raise FileNotFoundError(msg)

    logger.info(f"šŸ“‚ Loading {split} dataset from {filepath}")
    df = pd.read_parquet(filepath)
    logger.info(f"āœ… Loaded {len(df)} samples")

    return df


def main(
    max_samples_per_lang: Annotated[
        int, typer.Option(help="Maximum samples per language")
    ] = DEFAULT_MAX_SAMPLES_PER_LANG,
    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
    simple_format: Annotated[
        bool, typer.Option(help="Create only simple format (not multiple training formats)")
    ] = False,
) -> None:
    """Create optimized training dataset from CodeSearchNet for code search tasks."""
    # (The docstring above is the CLI help text shown by typer.)
    logger.info("šŸš€ Starting optimized dataset creation command...")

    output_path = Path(output_dir) if output_dir else None
    # Directory the files actually end up in (create_optimized_dataset applies the same default).
    resolved_output_dir = output_path or DATASET_OUTPUT_DIR

    try:
        metadata = create_optimized_dataset(
            max_samples_per_lang=max_samples_per_lang,
            min_doc_words=min_doc_words,
            max_doc_words=max_doc_words,
            min_code_chars=min_code_chars,
            max_code_chars=max_code_chars,
            output_dir=output_path,
            create_multiple_formats=not simple_format,
        )

        logger.info("āœ… Dataset creation completed successfully!")
        logger.info(f"šŸ“ Output directory: {resolved_output_dir}")

        # Print summary statistics
        print("\n" + "=" * 60)
        print("šŸ“Š DATASET CREATION SUMMARY")
        print("=" * 60)
        print(f"Total samples created: {metadata['total_samples']:,}")
        print(f"Processing time: {metadata['processing_time']:.2f} seconds")
        print("\nSplit distribution:")
        print(f" • Train: {metadata['train_samples']:,} samples")
        print(f" • Test: {metadata['test_samples']:,} samples")

        print("\nLanguage distribution:")
        for lang, stats in metadata["language_stats"].items():
            if "error" in stats:
                print(f" • {lang}: FAILED ({stats['error']})")
            else:
                print(f" • {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")

        print(f"\nDataset files saved to: {resolved_output_dir}")
        print("=" * 60)

    except Exception as e:
        logger.exception("āŒ Dataset creation failed")
        raise typer.Exit(1) from e


if __name__ == "__main__":
    typer.run(main)