""" Unified Code-Specialized Model2Vec Distillation Script. This script provides a unified approach for creating code-specialized embeddings using Model2Vec distillation with optional code-specific training. Features: - Basic distillation (default): Simple Model2Vec distillation - Advanced training (--train flag): Additional CodeSearchNet fine-tuning - Checkpoint support with Beam sync utilities - Multi-teacher model processing - Smart resume capabilities - Hierarchical storage: base → final Directory Structure: - code_model2vec/base: Basic distilled models (first step) - code_model2vec/final: Final models (copied from base or after training) Usage: distiller distill [--use-beam] [--train] # Basic distillation or with training """ import importlib.util import json import logging import os import time from pathlib import Path from typing import Annotated, Any import torch import typer from beam import function from sentence_transformers import SentenceTransformer from distiller.model2vec.distill import distill # Try to import flash_attn to check if it's available from .beam_utils import ( BeamCheckpointManager, create_beam_utilities, download_model_from_beam, sync_checkpoints_from_beam, sync_checkpoints_to_beam, upload_model_to_beam, ) from .config import ( codesearchnet_config, directories, distillation_config, get_distillation_function_kwargs, get_training_function_kwargs, get_volume_config, languages_config, ) # Check if flash_attn is available and compatible FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None # ============================================================================= # CONFIGURATION # ============================================================================= VOLUME_CONFIG = get_volume_config() LOCAL_BASE_DIR = directories.base LOCAL_FINAL_DIR = directories.final logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # Teacher models for distillation DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models) # ============================================================================= # FLASH ATTENTION UTILITIES # ============================================================================= def configure_flash_attention() -> dict[str, Any]: """Configure flash attention settings and return model kwargs.""" model_kwargs: dict[str, Any] = {} if not FLASH_ATTN_AVAILABLE: logger.info("⚠️ Flash attention not available - using standard attention") return model_kwargs # Set environment variables for flash attention os.environ["FLASH_ATTENTION_FORCE_USE"] = "1" # Disable torch compile for flash attention compatibility os.environ["TORCH_COMPILE_DISABLE"] = "1" # Enable flash attention in transformers os.environ["TOKENIZERS_PARALLELISM"] = "false" # Check if we're on a compatible GPU try: if torch.cuda.is_available(): device_capability = torch.cuda.get_device_capability() # Flash attention requires compute capability >= 7.5 (Turing, Ampere, Ada, Hopper) if device_capability[0] >= 7 and (device_capability[0] > 7 or device_capability[1] >= 5): logger.info("✅ Flash attention enabled - compatible GPU detected") model_kwargs.update( { "model_kwargs": { "attn_implementation": "flash_attention_2", "torch_dtype": torch.float16, # Flash attention works best with fp16 "use_flash_attention_2": True, "_attn_implementation": "flash_attention_2", # Alternative key for some models } } ) else: logger.info(f"⚠️ GPU compute capability {device_capability} < 7.5 - flash attention disabled") else: logger.info("⚠️ No CUDA available - flash attention disabled") except Exception as e: logger.warning(f"⚠️ Failed to check GPU compatibility: {e} - flash attention disabled") return model_kwargs def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer: """Load a SentenceTransformer model with flash attention if available.""" flash_kwargs = configure_flash_attention() try: # Try loading with flash attention first if flash_kwargs and "model_kwargs" in flash_kwargs: logger.info(f"🚀 Loading model with flash attention: {Path(model_path).name}") model = SentenceTransformer(model_path, device=device, trust_remote_code=True, **flash_kwargs) logger.info("✅ Model loaded successfully with flash attention") return model except Exception as e: logger.warning(f"⚠️ Failed to load with flash attention: {e}") logger.info("🔄 Falling back to standard attention") # Fallback to standard loading logger.info(f"📂 Loading model with standard attention: {Path(model_path).name}") model = SentenceTransformer(model_path, device=device, trust_remote_code=True) logger.info("✅ Model loaded successfully with standard attention") return model # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= def get_current_config_hash(enable_training: bool) -> str: """Generate a hash of current configuration parameters for checkpoint validation.""" import hashlib config_params = { "pca_dims": distillation_config.optimal_pca_dims, "sif_coefficient": distillation_config.sif_coefficient, "apply_zipf": distillation_config.apply_zipf, "enable_training": enable_training, } if enable_training: # Add a simple hash of tokenlearn parameters for config validation tokenlearn_hash = hash( f"{distillation_config.tokenlearn_dataset}_{distillation_config.tokenlearn_dataset_name}_{distillation_config.tokenlearn_text_key}" ) config_params["tokenlearn_hash"] = float(abs(tokenlearn_hash) % 1000000) # Convert to float for consistency config_str = str(sorted(config_params.items())) return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324 def check_existing_base_model(teacher_name: str) -> str | None: """Check if base distilled model already exists locally.""" base_dir = Path(LOCAL_BASE_DIR) model_dir = base_dir / f"code_model2vec_{teacher_name}" if model_dir.exists(): # Check for essential model files has_config = (model_dir / "config.json").exists() has_model_file = any( [ (model_dir / "model.safetensors").exists(), (model_dir / "model.bin").exists(), (model_dir / "pytorch_model.bin").exists(), ] ) if has_config and has_model_file: logger.info(f"✅ Found existing base model: {teacher_name}") return str(model_dir) return None def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None: """Check if final model already exists locally.""" final_dir = Path(LOCAL_FINAL_DIR) # Add suffix for trained models model_name = f"code_model2vec_{teacher_name}" if enable_training: model_name += "_fine_tuned" final_path = final_dir / model_name if final_path.exists(): # Check for essential model files has_config = (final_path / "config.json").exists() has_model_file = any( [ (final_path / "model.safetensors").exists(), (final_path / "model.bin").exists(), (final_path / "pytorch_model.bin").exists(), ] ) if has_config and has_model_file: logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}") return str(final_path) return None def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool: """Copy base model to final directory.""" import shutil base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" # Add suffix for trained models final_model_name = f"code_model2vec_{teacher_name}" if enable_training: final_model_name += "_fine_tuned" final_path = Path(LOCAL_FINAL_DIR) / final_model_name try: final_path.parent.mkdir(parents=True, exist_ok=True) if final_path.exists(): shutil.rmtree(final_path) shutil.copytree(base_path, final_path) logger.info(f"📁 Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}") return True except Exception: logger.exception(f"❌ Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}") return False def sync_model_from_beam( teacher_name: str, target_dir: str, use_beam_utilities: bool = False, ) -> bool: """Sync model from Beam volume to local directory.""" if not use_beam_utilities: return False try: target_path = Path(target_dir) target_path.mkdir(parents=True, exist_ok=True) beam_model_name = f"{teacher_name}_model" success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path)) if success: logger.info(f"📥 Synced {teacher_name} from Beam to {target_dir}") return True logger.warning(f"⚠️ Failed to sync {teacher_name} from Beam") return False except Exception as e: logger.warning(f"Failed to sync {teacher_name} from Beam: {e}") return False def sync_model_to_beam( teacher_name: str, source_dir: str, use_beam_utilities: bool = False, ) -> bool: """Sync model from local directory to Beam volume.""" if not use_beam_utilities: return False try: beam_model_name = f"{teacher_name}_model" success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir) if success: logger.info(f"📤 Synced {teacher_name} to Beam from {source_dir}") return True logger.warning(f"⚠️ Failed to sync {teacher_name} to Beam") return False except Exception as e: logger.warning(f"Failed to sync {teacher_name} to Beam: {e}") return False # ============================================================================= # DISTILLATION FUNCTIONS # ============================================================================= def simple_distillation( teacher_model: str, output_dir: str, pca_dims: int | None = None, retry_with_cache_clear: bool = False, ) -> Any: """ Perform simple Model2Vec distillation without additional training. Args: teacher_model: Name of teacher model output_dir: Output directory for the distilled model pca_dims: PCA dimensions (uses config default if None) retry_with_cache_clear: Whether this is a retry after clearing cache Returns: Distilled model or None if failed """ if pca_dims is None: pca_dims = int(distillation_config.optimal_pca_dims) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else "" logger.info(f"🔄 Simple distillation{retry_suffix}: {teacher_model} → {output_dir}") logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") start_time = time.time() try: # Perform distillation with optimal parameters model = distill( model_name=teacher_model, pca_dims=int(pca_dims), apply_zipf=bool(distillation_config.apply_zipf), sif_coefficient=float(distillation_config.sif_coefficient), trust_remote_code=True, ) logger.info("✅ Core distillation completed successfully") # Validate model before saving if hasattr(model, "tokenizer") and hasattr(model, "embedding"): vocab_size = len(model.tokenizer.get_vocab()) embedding_size = model.embedding.shape[0] logger.info("📊 Model validation:") logger.info(f" - Vocabulary size: {vocab_size}") logger.info(f" - Embedding matrix size: {embedding_size}") if vocab_size != embedding_size: logger.warning(f"⚠️ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}") logger.warning("⚠️ This may cause issues in downstream usage") else: logger.info("✅ Vocabulary and embedding sizes match") # Save the model model.save_pretrained(str(output_path)) logger.info(f"💾 Model saved to {output_path}") # Log model info logger.info(f"Model type: {type(model)}") if hasattr(model, "embedding"): logger.info(f"Embedding shape: {model.embedding.shape}") logger.info(f"Embedding dtype: {model.embedding.dtype}") total_time = time.time() - start_time logger.info(f"🎉 Simple distillation completed in {total_time:.2f} seconds") return model except ValueError as e: if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue") logger.warning(f"Error details: {e}") logger.warning("💡 This model has incompatible tokenization. Skipping...") return None if "weight is on the meta device" in str(e): logger.warning(f"⚠️ Device placement issue with {teacher_model} - model weights on meta device") logger.warning(f"Error details: {e}") logger.warning("💡 This model has device placement issues. Skipping...") return None raise except AttributeError as e: if "backend_tokenizer" in str(e): logger.warning(f"⚠️ Tokenizer compatibility issue with {teacher_model}") logger.warning(f"Error details: {e}") logger.warning("💡 This model's tokenizer is incompatible with Model2Vec. Skipping...") return None raise except FileNotFoundError as e: if "transformers_modules" in str(e) or "xlm_padding.py" in str(e): logger.warning(f"⚠️ Missing custom model files for {teacher_model}") logger.warning(f"Error details: {e}") # Try clearing cache and retrying once if not retry_with_cache_clear: logger.info("🔧 Attempting to clear cache and retry...") if clear_model_cache(teacher_model): logger.info("🔄 Retrying distillation after cache clear...") return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True) logger.warning("💡 This model has missing dependencies. Manual intervention may be required.") return None raise except Exception: logger.exception(f"❌ Simple distillation failed for {teacher_model}") return None def load_optimized_dataset( max_samples: int | None = None, checkpoint_manager: BeamCheckpointManager | None = None, dataset_path: str | None = None, ) -> list[str]: """Load our pre-created optimized dataset for tokenlearn training.""" from .dataset import DATASET_OUTPUT_DIR from .dataset import load_optimized_dataset as load_dataset_func # Use configuration if not provided as parameter if dataset_path is None: dataset_path = distillation_config.custom_dataset_path dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR # Use configuration default if not specified if max_samples is None: max_samples = distillation_config.tokenlearn_max_samples logger.info(f"🎯 Loading optimized dataset from {dataset_dir}") logger.info(f"📊 Target samples: {max_samples}") try: # Load the training split of our optimized dataset df = load_dataset_func(output_dir=dataset_dir, split="train") # Extract the text column (which contains our formatted query + code) texts = df["text"].tolist() # Shuffle for better training distribution import random random.seed(42) random.shuffle(texts) # Limit to max_samples if len(texts) > max_samples: texts = texts[:max_samples] logger.info(f"✅ Loaded {len(texts)} optimized training samples") # Log language distribution languages = df["language"].value_counts() logger.info("📊 Language distribution:") for lang, count in languages.items(): percentage = (count / len(df)) * 100 logger.info(f" {lang}: {count} samples ({percentage:.1f}%)") return texts except Exception as e: logger.warning(f"⚠️ Failed to load optimized dataset: {e}") logger.info("🔄 Falling back to original CodeSearchNet loading...") return load_codesearchnet_dataset(max_samples, checkpoint_manager) def load_codesearchnet_dataset( max_samples: int | None = None, checkpoint_manager: BeamCheckpointManager | None = None, ) -> list[str]: """Load and format the CodeSearchNet dataset for token frequency computation.""" from datasets import load_dataset # Use configuration default if not specified if max_samples is None: max_samples = distillation_config.tokenlearn_max_samples logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}") logger.info(f"Limiting to {max_samples} samples for training efficiency") logger.info(f"Languages: {', '.join(languages_config.all)}") # Check for existing dataset checkpoint texts = [] start_from = 0 if checkpoint_manager: checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0) if checkpoint_data: cached_texts = checkpoint_data.get("data", {}).get("texts", []) if len(cached_texts) >= max_samples: logger.info(f"✅ Resumed dataset loading: {len(cached_texts)} texts from checkpoint") return cached_texts[:max_samples] logger.info(f"📋 Partial dataset found: {len(cached_texts)} texts, continuing...") texts = cached_texts start_from = len(texts) try: # Calculate samples per language for balanced distribution num_languages = len(languages_config.all) samples_per_language = max_samples // num_languages remaining_samples = max_samples % num_languages logger.info(f"📊 Target distribution: {samples_per_language} samples per language") if remaining_samples > 0: logger.info(f"📊 Extra {remaining_samples} samples will be distributed to first languages") # Load training data from each language separately for balanced distribution language_texts: dict[str, list[str]] = {} total_collected = len(texts) for i, language in enumerate(languages_config.all): if total_collected >= max_samples: break logger.info(f"🔍 Loading {language} training data...") # Determine how many samples to collect for this language target_for_lang = samples_per_language if i < remaining_samples: # Distribute extra samples to first languages target_for_lang += 1 # Skip if we already have enough from this language if language in language_texts and len(language_texts[language]) >= target_for_lang: continue try: # Load training split for the specific language (same format as evaluate.py) from datasets import load_dataset dataset = load_dataset( codesearchnet_config.dataset_name, language, split="train", trust_remote_code=True, ) lang_texts: list[str] = [] processed_count = 0 for processed_count, example in enumerate(dataset, 1): if len(lang_texts) >= target_for_lang: break # Use same field names as evaluate.py doc_string = example.get("func_documentation_string", "").strip() code_string = example.get("func_code_string", "").strip() if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50: # Format as documentation-code pair for training (same as evaluate.py) text = f"Documentation: {doc_string}\nCode:\n{code_string}" # Ensure reasonable length for embedding models if len(text) <= 2048: lang_texts.append(text) if processed_count % 5000 == 0: logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}") language_texts[language] = lang_texts total_collected += len(lang_texts) logger.info(f"✅ {language}: collected {len(lang_texts)} samples") except Exception as e: logger.warning(f"⚠️ Failed to load {language} data: {e}") continue # Combine all language texts in a balanced way combined_texts = [] # Add existing texts first (from checkpoint) if start_from > 0: combined_texts = texts[:start_from] # Interleave texts from different languages for better training distribution max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0 for sample_idx in range(max_lang_samples): for language in languages_config.all: if len(combined_texts) >= max_samples: break if language in language_texts and sample_idx < len(language_texts[language]): combined_texts.append(language_texts[language][sample_idx]) if len(combined_texts) >= max_samples: break # Truncate to exact max_samples combined_texts = combined_texts[:max_samples] # Log final distribution logger.info("📊 Final dataset distribution:") lang_counts: dict[str, int] = {} for text in combined_texts: # Simple heuristic to identify language from code patterns if "def " in text and ":" in text: lang_counts["python"] = lang_counts.get("python", 0) + 1 elif "function " in text and "{" in text: lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1 elif "public " in text and "class " in text: lang_counts["java"] = lang_counts.get("java", 0) + 1 elif " torch.Tensor: """Generate teacher embeddings for code training with checkpoint support.""" logger.info(f"Generating teacher embeddings for {len(texts)} texts...") # Check for existing embeddings checkpoint if checkpoint_manager: volume_path = Path(VOLUME_CONFIG.mount_path) embeddings_path = volume_path / "embeddings_cache.pt" config_path = volume_path / "embeddings_config.json" if embeddings_path.exists() and config_path.exists(): try: # Load config first to validate compatibility with config_path.open("r") as f: config_data = json.load(f) current_hash = get_current_config_hash(enable_training=True) if config_data.get("config_hash") == current_hash: # Load the embeddings tensor final_embeddings = torch.load(embeddings_path, map_location="cpu") num_expected = config_data.get("num_texts", len(texts)) if final_embeddings.shape[0] >= num_expected: logger.info(f"✅ Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)") return final_embeddings[: len(texts)] except Exception as e: logger.warning(f"Failed to load embeddings cache: {e}, regenerating...") # Generate embeddings from scratch logger.info("Generating fresh teacher embeddings...") batch_size = 16 # Fixed batch size for teacher embedding generation embeddings_list = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i : i + batch_size] try: batch_embeddings = teacher_model.encode( batch_texts, convert_to_tensor=True, batch_size=batch_size, show_progress_bar=False, normalize_embeddings=True, ) embeddings_list.append(batch_embeddings) if i % (batch_size * 10) == 0: logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts") except torch.cuda.OutOfMemoryError: logger.warning(f"GPU OOM with batch size {batch_size}, reducing...") torch.cuda.empty_cache() batch_size = max(1, batch_size // 2) # Retry with smaller batch size batch_embeddings = teacher_model.encode( batch_texts, convert_to_tensor=True, batch_size=batch_size, show_progress_bar=False, normalize_embeddings=True, ) embeddings_list.append(batch_embeddings) # Combine all embeddings teacher_embeddings = torch.cat(embeddings_list, dim=0) # Ensure fp32 precision if teacher_embeddings.dtype != torch.float32: teacher_embeddings = teacher_embeddings.to(torch.float32) logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}") # Save embeddings cache for future runs if checkpoint_manager: try: volume_path = Path(VOLUME_CONFIG.mount_path) embeddings_path = volume_path / "embeddings_cache.pt" config_path = volume_path / "embeddings_config.json" # Save embeddings tensor torch.save(teacher_embeddings, embeddings_path) # Save configuration config_data = { "config_hash": get_current_config_hash(enable_training=True), "num_texts": len(texts), "embedding_shape": list(teacher_embeddings.shape), "timestamp": time.time(), } with config_path.open("w") as f: json.dump(config_data, f, indent=2) logger.info("💾 Saved embeddings cache for future runs") except Exception as e: logger.warning(f"Failed to save embeddings cache: {e}") return teacher_embeddings def tokenlearn_training( student_model: Any, teacher_model: SentenceTransformer, checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001 ) -> Any: """ Perform tokenlearn training following the official POTION approach. This follows the 4-step process: 1. Model2Vec distillation (already done - student_model) 2. Sentence transformer inference (create features) 3. Tokenlearn training """ from pathlib import Path logger.info("🧪 Starting tokenlearn training (POTION approach)...") # Create persistent directories for tokenlearn workflow (for checkpoint preservation) teacher_model_name = getattr(teacher_model, "model_name", None) if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001 # Try to extract from the first module if it's a SentenceTransformer first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001 if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): teacher_model_name = first_module.auto_model.name_or_path if not teacher_model_name: teacher_model_name = "unknown_teacher" # Use persistent directory for tokenlearn checkpoints teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_") persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug features_dir = persistent_tokenlearn_dir / "features" model_dir = persistent_tokenlearn_dir / "base_model" trained_dir = persistent_tokenlearn_dir / "trained_model" features_dir.mkdir(parents=True, exist_ok=True) model_dir.mkdir(parents=True, exist_ok=True) trained_dir.mkdir(parents=True, exist_ok=True) logger.info(f"📁 Using persistent tokenlearn directory: {persistent_tokenlearn_dir}") # Save the base distilled model for tokenlearn student_model.save_pretrained(str(model_dir)) logger.info(f"💾 Saved base model to {model_dir}") # Step 2: Create features using sentence transformer logger.info("🔍 Step 2: Creating features using sentence transformer...") # Get teacher model name/path for tokenlearn teacher_model_name = getattr(teacher_model, "model_name", None) if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001 # Try to extract from the first module if it's a SentenceTransformer # _modules is a dict-like container, get the first module by iterating first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001 if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): teacher_model_name = first_module.auto_model.name_or_path logger.info(f"📊 Using teacher model: {teacher_model_name}") # Prepare dataset for tokenlearn featurization dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir) # Check if featurization already completed (checkpoint detection) featurization_complete_marker = features_dir / ".featurization_complete" if featurization_complete_marker.exists() and verify_featurization_output(features_dir): logger.info("✅ Found existing featurization checkpoint with valid output files") logger.info(f"📂 Using cached features from: {features_dir}") # Verify marker is still valid output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json")) logger.info(f"📁 Found {len(output_files)} cached feature files") else: if featurization_complete_marker.exists(): logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization") featurization_complete_marker.unlink() logger.info("🔄 No valid featurization checkpoint found - starting featurization...") if not teacher_model_name: logger.warning("⚠️ Could not determine teacher model name, using fallback") teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model logger.info(f"📊 Using teacher model: {teacher_model_name}") try: # Use direct function call instead of subprocess from datasets import load_dataset from distiller.tokenlearn.featurize import featurize logger.info("🔄 Running tokenlearn featurization...") logger.info(f"📊 Dataset: {dataset_path} (config: {dataset_name})") logger.info(f"📝 Text field: {text_key}") # Load the dataset if dataset_name is None: # For local JSON files, don't pass name parameter dataset = load_dataset( "json", data_files=dataset_path, split="train", streaming=True, ) else: # For remote datasets with specific configurations dataset = load_dataset( dataset_path, name=dataset_name, split="train", streaming=True, ) # Call featurization function directly featurize( dataset=iter(dataset), model=teacher_model, output_dir=str(features_dir), max_means=50000, # IMPROVEMENT: Limit means to prevent overfitting batch_size=512, # IMPROVEMENT: Smaller batch for better gradients text_key=text_key, ) logger.info("✅ Featurization completed successfully") # Create checkpoint marker to indicate featurization is complete featurization_complete_marker.touch() logger.info(f"💾 Created featurization checkpoint: {featurization_complete_marker}") except Exception as e: logger.exception("💥 Tokenlearn featurization failed") logger.exception("💥 Tokenlearn featurization is required for training - cannot proceed") msg = f"Tokenlearn featurization failed: {e}" raise RuntimeError(msg) from e # Step 3: Train using tokenlearn-train logger.info("🎓 Step 3: Training using tokenlearn...") # Check if training already completed (checkpoint detection) training_complete_marker = trained_dir / ".training_complete" training_fallback_marker = trained_dir / ".training_fallback" if training_complete_marker.exists() and verify_training_output(trained_dir): logger.info("✅ Found existing training checkpoint with valid model files") logger.info(f"📂 Using cached trained model from: {trained_dir}") # Show available model files model_files = [] for pattern in ["*.json", "*.safetensors", "*.bin"]: model_files.extend(list(trained_dir.glob(pattern))) for subdir in ["model", "model_weighted"]: subdir_path = trained_dir / subdir if subdir_path.exists(): model_files.extend(list(subdir_path.glob(pattern))) logger.info(f"📁 Found {len(model_files)} cached model files") elif training_fallback_marker.exists(): logger.warning("⚠️ Training fallback marker found - tokenlearn failed previously") logger.info("🔄 Proceeding with fallback to base model (simple distillation)") # Skip training and proceed to model loading (will fallback to base model) else: if training_complete_marker.exists(): logger.warning("⚠️ Training marker exists but model files are missing - re-running training") training_complete_marker.unlink() logger.info("🔄 No valid training checkpoint found - starting training...") try: # Use direct function call instead of subprocess from distiller.tokenlearn.train import train_model from distiller.tokenlearn.utils import collect_means_and_texts # IMPROVED APPROACH: Try optimized parameters first logger.info("🚀 Attempting IMPROVED tokenlearn training with optimized parameters...") logger.info("📊 Using smaller vocabulary and conservative PCA to prevent overfitting") # Collect training data from features directory paths = sorted(features_dir.glob("*.json")) train_txt, train_vec = collect_means_and_texts(paths) logger.info(f"📊 Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training") try: # Try improved parameters first trained_model = train_model( model_name=str(teacher_model_name), train_txt=train_txt, train_vec=train_vec, device="cuda" if torch.cuda.is_available() else "cpu", vocab_size=25000, # IMPROVEMENT: Smaller vocabulary to prevent overfitting pca_dims=256, # IMPROVEMENT: Conservative PCA dimensions ) # Save the trained model trained_model.save_pretrained(str(trained_dir)) logger.info("✅ IMPROVED tokenlearn training completed successfully") training_complete_marker.touch() logger.info(f"💾 Created improved training checkpoint: {training_complete_marker}") except Exception as e: logger.warning(f"⚠️ Improved training failed: {e}") logger.info("🔄 Falling back to CONSERVATIVE tokenlearn training...") # FALLBACK: Ultra-conservative training approach try: trained_model = train_model( model_name=str(teacher_model_name), train_txt=train_txt, train_vec=train_vec, device="cuda" if torch.cuda.is_available() else "cpu", vocab_size=15000, # FALLBACK: Even smaller vocabulary pca_dims=128, # FALLBACK: Smaller PCA dimensions ) # Save the trained model trained_model.save_pretrained(str(trained_dir)) logger.info("✅ Conservative tokenlearn training completed successfully") training_complete_marker.touch() logger.info(f"💾 Created conservative training checkpoint: {training_complete_marker}") except Exception as e2: logger.exception("❌ Conservative tokenlearn training also failed") logger.exception("💥 All training approaches failed - check output above for details") # Create training marker to indicate we tried but failed training_fallback_marker = trained_dir / ".training_fallback" training_fallback_marker.touch() logger.exception("💥 Tokenlearn training failed completely") msg = f"All tokenlearn training approaches failed: {e2}" raise RuntimeError(msg) from e2 except Exception as e: logger.warning("💥 All tokenlearn training approaches failed") logger.exception("💥 All training approaches failed completely - cannot proceed") msg = f"All training approaches failed: {e}" raise RuntimeError(msg) from e # Step 4: Load the trained model and apply post-training re-regularization logger.info("📦 Step 4: Loading trained model and applying post-training re-regularization...") # Check if we need to use fallback due to tokenlearn failure training_fallback_marker = trained_dir / ".training_fallback" if training_fallback_marker.exists(): logger.error("❌ Tokenlearn training failed previously - cannot return trained model") logger.error("💥 Training was requested but failed - this would be misleading to return base model") msg = "Tokenlearn training failed - cannot proceed with training pipeline" raise RuntimeError(msg) try: from distiller.model2vec.model import StaticModel # Load the trained model from tokenlearn trained_model_path = trained_dir / "model" if not trained_model_path.exists(): # Try alternative paths possible_paths = [ trained_dir / "model_weighted", trained_dir, ] for path in possible_paths: if path.exists() and any(path.glob("*.json")): trained_model_path = path break else: logger.error(f"❌ Could not find trained model in {trained_dir}") logger.error("💥 Training was requested but no trained model found - cannot proceed") msg = f"Trained model not found in {trained_dir} - training pipeline failed" raise RuntimeError(msg) # Load the model before re-regularization logger.info("🔄 Loading model from tokenlearn training...") trained_model = StaticModel.from_pretrained(str(trained_model_path)) # Return the trained model directly logger.info("✅ Tokenlearn training pipeline completed successfully") return trained_model except ValueError as e: if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): logger.exception("💥 Token-vector mismatch in tokenlearn training") logger.exception("Error details") logger.exception("🔧 This is a known issue with tokenlearn/Model2Vec integration") logger.exception("💥 Training was requested but failed due to token-vector mismatch") msg = f"Tokenlearn training failed due to token-vector mismatch: {e}" raise RuntimeError(msg) from e logger.exception("💥 Failed to load tokenlearn trained model") msg = f"Failed to load tokenlearn trained model: {e}" raise RuntimeError(msg) from e except Exception as e: logger.exception("💥 Failed to load tokenlearn trained model") logger.exception("💥 Cannot load trained model - training failed") msg = f"Failed to load tokenlearn trained model: {e}" raise RuntimeError(msg) from e def distill_single_teacher( teacher_model: str, enable_training: bool = False, use_beam_utilities: bool = False, pca_dims: int | None = None, ) -> dict[str, Any]: """ Distill a single teacher model with optional training. Args: teacher_model: Name of teacher model enable_training: Whether to enable advanced training use_beam_utilities: Whether to use Beam utilities pca_dims: PCA dimensions Returns: Dictionary with distillation results """ teacher_name = teacher_model.split("/")[-1].replace("-", "_") base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" # Add suffix for trained models final_model_name = f"code_model2vec_{teacher_name}" if enable_training: final_model_name += "_fine_tuned" final_dir = Path(LOCAL_FINAL_DIR) / final_model_name logger.info(f"\n{'=' * 60}") logger.info(f"🔄 Processing teacher model: {teacher_model}") logger.info(f"📁 Teacher name: {teacher_name}") logger.info(f"🎓 Training enabled: {enable_training}") logger.info(f"{'=' * 60}") # Check model compatibility first is_compatible, warning_msg = check_model_compatibility(teacher_model) if not is_compatible: logger.warning(f"⚠️ Known compatibility issue: {warning_msg}") logger.info("🔧 Attempting distillation anyway, but may fail...") # Try model-specific workarounds workaround_type = try_model_workarounds(teacher_model) # Don't skip if we have a workaround - we'll use it later start_time = time.time() # Initialize Beam utilities if requested checkpoint_mgr = None if use_beam_utilities: try: _, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path) except Exception as e: logger.warning(f"Failed to initialize Beam utilities: {e}") try: # Step 1: Check for existing final model existing_final = check_existing_final_model(teacher_name, enable_training) if existing_final: logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}") total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "skipped_existing_final", "final_path": existing_final, "distillation_time": total_time, } # Step 1.5: Sync existing checkpoints from Beam if using Beam utilities if use_beam_utilities and checkpoint_mgr: logger.info(f"🔄 Syncing existing checkpoints for {teacher_name}...") sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints) if enable_training: sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints) # Step 2: Check for existing base model or create it existing_base = check_existing_base_model(teacher_name) base_model = None if existing_base: logger.info(f"✅ Found existing base model: {teacher_name}") if enable_training: # Load base model for training from distiller.model2vec.model import StaticModel base_model = StaticModel.from_pretrained(existing_base) elif use_beam_utilities: synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities) if synced: existing_base = str(base_dir) if enable_training: from distiller.model2vec.model import StaticModel base_model = StaticModel.from_pretrained(existing_base) if not existing_base: # Perform simple distillation to create base model logger.info(f"🔄 Creating base model for {teacher_name}") # Check if we need specialized distillation workaround_type = try_model_workarounds(teacher_model) if workaround_type == "salesforce": base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims) elif workaround_type == "baai": base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims) else: base_model = simple_distillation(teacher_model, str(base_dir), pca_dims) if base_model is None: total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "failed_base_distillation", "error": "Simple distillation failed", "distillation_time": total_time, } # Sync base model and checkpoints to Beam if use_beam_utilities: sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities) if checkpoint_mgr: sync_checkpoints_to_beam( VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints ) existing_base = str(base_dir) # Step 3: Handle final model creation if enable_training and base_model is not None: # Perform tokenlearn training (POTION approach) logger.info(f"🧪 Starting tokenlearn training for {teacher_name}") try: # Load teacher model for training device = "cuda" if torch.cuda.is_available() else "cpu" teacher_st_model = load_model_with_flash_attention(teacher_model, device) # Perform tokenlearn training (POTION approach) final_model = tokenlearn_training( base_model, teacher_st_model, checkpoint_mgr, ) # Save final model final_dir.mkdir(parents=True, exist_ok=True) final_model.save_pretrained(str(final_dir)) # Sync final model and training checkpoints to Beam if use_beam_utilities: sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities) if checkpoint_mgr: sync_checkpoints_to_beam( VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints ) del teacher_st_model if torch.cuda.is_available(): torch.cuda.empty_cache() except RuntimeError as e: # Training failed - clean up and return failure logger.exception(f"❌ Training failed for {teacher_name}") # Clean up teacher model if it was loaded if "teacher_st_model" in locals(): del teacher_st_model if torch.cuda.is_available(): torch.cuda.empty_cache() total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "failed_training", "error": f"Training failed: {e!s}", "base_path": existing_base, # Base model was created successfully "distillation_time": total_time, } else: # Copy base to final (no training) logger.info(f"📁 Copying base to final for {teacher_name}") if not copy_base_to_final(teacher_name, enable_training): total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "failed_copy_to_final", "error": "Failed to copy base to final", "distillation_time": total_time, } total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "success", "enable_training": enable_training, "base_path": existing_base, "final_path": str(final_dir), "distillation_time": total_time, } except Exception as e: logger.exception(f"❌ Failed to process {teacher_model}") total_time = time.time() - start_time return { "teacher_model": teacher_model, "teacher_name": teacher_name, "status": "failed", "error": str(e), "distillation_time": total_time, } # ============================================================================= # MAIN EXECUTION FUNCTIONS # ============================================================================= def run_local_distillation( teacher_models: list[str] | None = None, enable_training: bool = False, pca_dims: int | None = None, clear_cache: bool = False, ) -> dict[str, Any]: """Run distillation locally.""" logger.info("🖥️ Running distillation locally") if teacher_models is None: teacher_models = DEFAULT_TEACHER_MODELS results = {} successful_models = [] logger.info("🚀 Starting distillation workflow") logger.info(f"📊 Processing {len(teacher_models)} teacher models") logger.info(f"🎓 Training enabled: {enable_training}") # Use default models if none specified models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS logger.info(f"📊 Teacher models to process: {len(models_to_distill)}") for i, model in enumerate(models_to_distill, 1): logger.info(f" {i}. {model}") # Clear cache for problematic models if requested if clear_cache: logger.info("🧹 Clearing cache for known problematic models...") problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] for model in problematic_models: if model in models_to_distill: clear_model_cache(model) # Clear tokenlearn checkpoints if requested (for training mode) # Note: Checkpoint clearing is handled at the main function level # Run distillation workflow for teacher_model in models_to_distill: result = distill_single_teacher( teacher_model=teacher_model, enable_training=enable_training, use_beam_utilities=False, pca_dims=pca_dims, ) teacher_name = result["teacher_name"] results[teacher_name] = result if result["status"] == "success" or result["status"].startswith("skipped"): successful_models.append(teacher_name) elif result["status"] == "failed_training": # Note: Training failed but base model may still be available logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded") # Summary logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!") logger.info(f"📊 Successful models: {len(successful_models)}") logger.info(f"🎓 Training mode: {'Enabled' if enable_training else 'Basic distillation only'}") for model_name in successful_models: result = results[model_name] logger.info(f"✅ {model_name}: {result['teacher_model']}") # Save results summary results_summary = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "enable_training": enable_training, "successful_models": successful_models, "all_results": results, "total_successful": len(successful_models), "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), } # Save results to file results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json" results_file.parent.mkdir(parents=True, exist_ok=True) with results_file.open("w") as f: json.dump(results_summary, f, indent=2) logger.info(f"📊 Results summary saved to: {results_file}") return results_summary def _beam_distill_internal( teacher_models: list[str] | None = None, enable_training: bool = False, pca_dims: int | None = None, clear_cache: bool = False, ) -> dict[str, Any]: """Shared internal implementation for beam distillation.""" if teacher_models is None: teacher_models = DEFAULT_TEACHER_MODELS # Clear cache for problematic models if requested if clear_cache: logger.info("🧹 Clearing cache for known problematic models...") problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] for model in problematic_models: if model in teacher_models: clear_model_cache(model) results = {} successful_models = [] logger.info("🚀 Starting Beam distillation workflow") logger.info(f"📊 Processing {len(teacher_models)} teacher models") logger.info(f"🎓 Training enabled: {enable_training}") # Use default models if none specified models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS logger.info(f"📊 Teacher models to process: {len(models_to_distill)}") for i, model in enumerate(models_to_distill, 1): logger.info(f" {i}. {model}") for teacher_model in models_to_distill: result = distill_single_teacher( teacher_model=teacher_model, enable_training=enable_training, use_beam_utilities=True, pca_dims=pca_dims, ) teacher_name = result["teacher_name"] results[teacher_name] = result if result["status"] == "success" or result["status"].startswith("skipped"): successful_models.append(teacher_name) elif result["status"] == "failed_training": # Note: Training failed but base model may still be available logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded") # Summary logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!") logger.info(f"📊 Successful models: {len(successful_models)}") # Save results to Beam volume volume_path = Path(VOLUME_CONFIG.mount_path) results_summary = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "enable_training": enable_training, "successful_models": successful_models, "all_results": results, "total_successful": len(successful_models), "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), } results_file = volume_path / "distillation_results.json" with results_file.open("w") as f: json.dump(results_summary, f, indent=2) logger.info(f"📊 Beam results saved to: {results_file}") return results_summary @function(**get_training_function_kwargs()) def _beam_train_models( teacher_models: list[str] | None = None, enable_training: bool = True, pca_dims: int | None = None, clear_cache: bool = False, ) -> dict[str, Any]: """Beam function for training (distillation + tokenlearn).""" logger.info("☁️ Running training on Beam") return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) @function(**get_distillation_function_kwargs()) def _beam_distill_models( teacher_models: list[str] | None = None, enable_training: bool = False, pca_dims: int | None = None, clear_cache: bool = False, ) -> dict[str, Any]: """Beam function for basic distillation only.""" logger.info("☁️ Running distillation on Beam") return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) def run_beam_distillation( teacher_models: list[str] | None = None, enable_training: bool = False, pca_dims: int | None = None, clear_cache: bool = False, ) -> dict[str, Any]: """Run distillation on Beam and sync results.""" logger.info("☁️ Running distillation on Beam with local sync") try: # Choose appropriate beam function based on training flag beam_function = _beam_train_models if enable_training else _beam_distill_models # Run distillation on Beam results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache) # Check if Beam execution was successful if not results: logger.error("❌ Beam execution failed or returned no results") return { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "enable_training": enable_training, "successful_models": [], "all_results": {}, "total_successful": 0, "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), "error": "Beam execution failed", } # Sync models back to local directories if results.get("successful_models"): logger.info("📥 Syncing models from Beam to local directories...") for teacher_name in results["successful_models"]: # Sync base model base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True) # Sync final model if training was enabled if enable_training: final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}" sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True) else: # Copy base to final copy_base_to_final(teacher_name, enable_training) logger.info("✅ All models synced from Beam") return results except Exception as e: logger.exception("❌ Beam distillation failed with exception") return { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "enable_training": enable_training, "successful_models": [], "all_results": {}, "total_successful": 0, "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), "error": str(e), } # ============================================================================= # CLI INTERFACE # ============================================================================= def main( use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False, train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False, teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None, pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None, clear_cache: Annotated[ bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation") ] = False, clear_checkpoints: Annotated[ bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training") ] = False, use_optimized_dataset: Annotated[ bool, typer.Option( "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset" ), ] = False, dataset_path: Annotated[ str | None, typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"), ] = None, ) -> None: """Unified distillation command with optional training.""" logger.info("🚀 Starting unified Model2Vec distillation workflow") # Set dataset configuration distillation_config.use_optimized_dataset = use_optimized_dataset distillation_config.custom_dataset_path = dataset_path if use_optimized_dataset and train: dataset_source = dataset_path or "code_model2vec/dataset" logger.info(f"🎯 Using optimized dataset from: {dataset_source}") elif train: logger.info("🎯 Using C4 dataset for training (following POTION approach)") logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}") logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}") # Use default models if none specified models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS logger.info(f"📊 Teacher models to process: {len(models_to_distill)}") for i, model in enumerate(models_to_distill, 1): logger.info(f" {i}. {model}") # Clear cache for problematic models if requested if clear_cache: logger.info("🧹 Clearing cache for known problematic models...") problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] for model in problematic_models: if model in models_to_distill: clear_model_cache(model) # Clear tokenlearn checkpoints if requested (for training mode) if clear_checkpoints and train: logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...") for teacher_model in models_to_distill: teacher_model.split("/")[-1].replace("-", "_") # Use the same persistent directory structure as the training function teacher_slug = teacher_model.replace("/", "_").replace("-", "_") persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug features_dir = persistent_tokenlearn_dir / "features" trained_dir = persistent_tokenlearn_dir / "trained_model" # Clear persistent tokenlearn checkpoints if features_dir.exists() or trained_dir.exists(): clear_tokenlearn_checkpoints(features_dir, trained_dir) logger.info(f"🗑️ Cleared persistent tokenlearn checkpoints for {teacher_model}") else: logger.info(f"ℹ️ No tokenlearn checkpoints found for {teacher_model}") elif clear_checkpoints and not train: logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)") # Run distillation workflow if use_beam: results = run_beam_distillation( teacher_models=models_to_distill, enable_training=train, pca_dims=pca_dims, clear_cache=clear_cache, ) else: results = run_local_distillation( teacher_models=models_to_distill, enable_training=train, pca_dims=pca_dims, clear_cache=clear_cache, ) # Handle case where results might be None or invalid if not results or not isinstance(results, dict): logger.error("❌ Distillation workflow failed - no valid results returned") results = { "total_successful": 0, "total_attempted": len(models_to_distill), "error": "Workflow failed", } # Final summary successful_count = results.get("total_successful", 0) total_attempted = results.get("total_attempted", 0) logger.info("\n🎉 UNIFIED DISTILLATION WORKFLOW COMPLETED!") logger.info(f"📊 Successfully processed: {successful_count}/{total_attempted} models") logger.info(f"📁 Base models saved to: {LOCAL_BASE_DIR}") logger.info(f"📁 Final models saved to: {LOCAL_FINAL_DIR}") if train: logger.info("🎓 Advanced training was enabled - models include CodeSearchNet specialization") else: logger.info("📖 Basic distillation only - use --train flag to enable advanced training") def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]: """ Check if a model has known compatibility issues with Model2Vec. Returns: Tuple of (is_compatible, warning_message) """ known_incompatible = { "BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute", "jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies", "Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors", } if teacher_model in known_incompatible: return False, known_incompatible[teacher_model] # Check for model families that might have issues if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower(): return False, "BGE models with Qwen2 tokenizers may have compatibility issues" if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower(): return False, "Jina embeddings v3 models may have missing dependencies" if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): return False, "Salesforce SFR embedding models may have device placement issues" return True, None def clear_model_cache(model_name: str) -> bool: """Clear HuggingFace cache for a specific model.""" try: import shutil from pathlib import Path # Get HuggingFace cache directory cache_dir = Path.home() / ".cache" / "huggingface" # Find model-specific cache directories model_slug = model_name.replace("/", "--") # Clear transformers cache transformers_cache = cache_dir / "transformers" / model_slug if transformers_cache.exists(): shutil.rmtree(transformers_cache) logger.info(f"🗑️ Cleared transformers cache for {model_name}") # Clear hub cache hub_cache = cache_dir / "hub" / f"models--{model_slug}" if hub_cache.exists(): shutil.rmtree(hub_cache) logger.info(f"🗑️ Cleared hub cache for {model_name}") # Clear modules cache modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0] if modules_cache.exists(): shutil.rmtree(modules_cache) logger.info(f"🗑️ Cleared modules cache for {model_name}") return True except Exception as e: logger.warning(f"Failed to clear cache for {model_name}: {e}") return False def try_model_workarounds(teacher_model: str) -> str | None: """ Try specific workarounds for problematic models. Returns: The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available """ if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): logger.info("🔧 Salesforce SFR model detected - will use specialized distillation") return "salesforce" if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()): logger.info("🔧 BAAI BGE model detected - will use specialized distillation") return "baai" return None def salesforce_model_distillation( teacher_model: str, output_dir: str, pca_dims: int | None = None, ) -> Any: """Special distillation function for Salesforce SFR models that handles device placement issues.""" if pca_dims is None: pca_dims = int(distillation_config.optimal_pca_dims) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) logger.info(f"🔄 Salesforce-specific distillation: {teacher_model} → {output_dir}") logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") start_time = time.time() try: import torch from transformers import AutoModel, AutoTokenizer # Enhanced custom model loading for Salesforce models logger.info("🔧 Loading model with enhanced device settings...") # Method 1: Try with to_empty() for meta tensor handling try: logger.info("🔄 Attempting with to_empty() method...") # Load tokenizer first tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) # Load model with meta device initially model = AutoModel.from_pretrained( teacher_model, trust_remote_code=True, torch_dtype=torch.float16, device_map="meta", # Load on meta device first ) # Move from meta to actual device using to_empty() if torch.cuda.is_available(): device = torch.device("cuda") # Create empty tensors on target device and copy weights model = model.to_empty(device=device) else: device = torch.device("cpu") model = model.to_empty(device=device) # Ensure model is in the right dtype model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32) logger.info("✅ Successfully loaded with to_empty() method") except Exception as e: logger.warning(f"to_empty() method failed: {e}") # Method 2: Try SentenceTransformer with specific settings logger.info("🔄 Falling back to SentenceTransformer method...") sentence_model = load_model_with_flash_attention( teacher_model, device="cpu", # Force CPU loading first ) # Move to GPU if available if torch.cuda.is_available(): sentence_model = sentence_model.to("cuda") # Extract components model = sentence_model[0].auto_model tokenizer = sentence_model.tokenizer logger.info("✅ Successfully loaded with SentenceTransformer method") # Now use Model2Vec's distill_from_model function directly from distiller.model2vec.distill.distillation import distill_from_model distilled_model = distill_from_model( model=model, tokenizer=tokenizer, pca_dims=int(pca_dims), apply_zipf=bool(distillation_config.apply_zipf), sif_coefficient=float(distillation_config.sif_coefficient), ) logger.info("✅ Core distillation completed successfully") # Save the model distilled_model.save_pretrained(str(output_path)) logger.info(f"💾 Model saved to {output_path}") # Log model info logger.info(f"Model type: {type(distilled_model)}") if hasattr(distilled_model, "embedding"): logger.info(f"Embedding shape: {distilled_model.embedding.shape}") logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") total_time = time.time() - start_time logger.info(f"🎉 Salesforce distillation completed in {total_time:.2f} seconds") # Clean up if "sentence_model" in locals(): del sentence_model del model if torch.cuda.is_available(): torch.cuda.empty_cache() return distilled_model except Exception: logger.exception(f"❌ Salesforce-specific distillation failed for {teacher_model}") return None def baai_bge_model_distillation( teacher_model: str, output_dir: str, pca_dims: int | None = None, ) -> Any: """Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues.""" if pca_dims is None: pca_dims = int(distillation_config.optimal_pca_dims) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) logger.info(f"🔄 BAAI BGE-specific distillation: {teacher_model} → {output_dir}") logger.info(f"📊 PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") start_time = time.time() try: import torch from transformers import AutoModel, AutoTokenizer logger.info("🔧 Loading BAAI model with tokenizer workaround...") # Try multiple approaches for BAAI models success = False # Method 1: Try SentenceTransformer first (often handles tokenizer issues better) try: logger.info("🔄 Attempting with SentenceTransformer wrapper...") sentence_model = load_model_with_flash_attention(teacher_model) # Extract components model = sentence_model[0].auto_model tokenizer = sentence_model.tokenizer # Test if tokenizer works by encoding a simple text test_encoding = tokenizer.encode("test", return_tensors="pt") logger.info("✅ SentenceTransformer method successful") success = True except Exception as e: logger.warning(f"SentenceTransformer method failed: {e}") # Method 2: Try direct loading with tokenizer replacement try: logger.info("🔄 Attempting with tokenizer replacement...") from transformers import BertTokenizerFast # Load model directly model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True) # Try to use a compatible tokenizer instead try: # First try the original tokenizer tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) except Exception: # Fall back to BERT tokenizer for BGE models logger.info("🔄 Falling back to BERT tokenizer...") tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") logger.info("✅ Tokenizer replacement method successful") success = True except Exception as e2: logger.warning(f"Tokenizer replacement method failed: {e2}") if not success: logger.error("❌ All BAAI model loading methods failed") return None # Now use Model2Vec's distill_from_model function directly from distiller.model2vec.distill.distillation import distill_from_model distilled_model = distill_from_model( model=model, tokenizer=tokenizer, pca_dims=int(pca_dims), apply_zipf=bool(distillation_config.apply_zipf), sif_coefficient=float(distillation_config.sif_coefficient), ) logger.info("✅ Core distillation completed successfully") # Save the model distilled_model.save_pretrained(str(output_path)) logger.info(f"💾 Model saved to {output_path}") # Log model info logger.info(f"Model type: {type(distilled_model)}") if hasattr(distilled_model, "embedding"): logger.info(f"Embedding shape: {distilled_model.embedding.shape}") logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") total_time = time.time() - start_time logger.info(f"🎉 BAAI BGE distillation completed in {total_time:.2f} seconds") # Clean up if "sentence_model" in locals(): del sentence_model del model if torch.cuda.is_available(): torch.cuda.empty_cache() return distilled_model except Exception: logger.exception(f"❌ BAAI BGE-specific distillation failed for {teacher_model}") return None def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None: """Clear tokenlearn checkpoint markers to force re-execution of steps.""" featurization_marker = features_dir / ".featurization_complete" training_marker = trained_dir / ".training_complete" if featurization_marker.exists(): featurization_marker.unlink() logger.info(f"🗑️ Cleared featurization checkpoint: {featurization_marker}") if training_marker.exists(): training_marker.unlink() logger.info(f"🗑️ Cleared training checkpoint: {training_marker}") def verify_featurization_output(features_dir: Path) -> bool: """Verify that featurization output files actually exist and are valid.""" if not features_dir.exists(): return False # Check for expected tokenlearn output files # Check if any expected files exist return any(list(features_dir.glob(file_pattern)) for file_pattern in ["*.npy", "*.json", "*.pt", "*.pkl"]) def verify_training_output(trained_dir: Path) -> bool: """Verify that training output files actually exist and are valid.""" if not trained_dir.exists(): return False # Check for model files model_files = ["config.json", "model.safetensors", "modules.json", "tokenizer.json"] for model_file in model_files: if (trained_dir / model_file).exists(): return True # Check for alternative model directory structure for subdir in ["model", "model_weighted"]: subdir_path = trained_dir / subdir if subdir_path.exists(): for model_file in model_files: if (subdir_path / model_file).exists(): return True return False def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]: """ Prepare dataset for tokenlearn featurization. Returns: Tuple of (dataset_path, dataset_name, text_key) for tokenlearn """ if distillation_config.use_optimized_dataset: return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir) return _prepare_original_dataset_for_tokenlearn() def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]: """Prepare custom optimized dataset for tokenlearn featurization.""" logger.info("🎯 Preparing custom optimized dataset for tokenlearn...") # Import the dataset module from .dataset import create_optimized_dataset, load_optimized_dataset # Define paths custom_dataset_dir = ( Path(distillation_config.custom_dataset_path) if distillation_config.custom_dataset_path else Path("code_model2vec/dataset") ) tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset" # Check if we need to create the custom dataset if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists(): logger.info("📊 Custom dataset not found - creating optimized dataset...") create_optimized_dataset( max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, # Divide by number of languages output_dir=custom_dataset_dir, create_multiple_formats=False, # Use simple format for tokenlearn ) # Load the custom dataset logger.info(f"📂 Loading custom dataset from {custom_dataset_dir}") train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train") # Prepare dataset for tokenlearn (save as JSON files that load_dataset can read) tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True) # Save as JSON file that tokenlearn can load with load_dataset() train_json_path = tokenlearn_dataset_dir / "train.json" # Create JSON lines format import json with train_json_path.open("w") as f: for text in train_df["text"]: json.dump({"text": text}, f) f.write("\n") logger.info(f"✅ Prepared custom dataset with {len(train_df)} samples for tokenlearn") logger.info(f"💾 Saved JSON dataset to {train_json_path}") # Return the JSON file path directly (not directory) and no config name for JSON loading return str(train_json_path), None, "text" def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]: """Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach).""" logger.info("📊 Using C4 dataset for tokenlearn (following POTION approach)...") return ( str(distillation_config.tokenlearn_dataset), # "allenai/c4" str(distillation_config.tokenlearn_dataset_name), # "en" str(distillation_config.tokenlearn_text_key), # "text" ) if __name__ == "__main__": typer.run(main)