codemalt / src /distiller /distill.py
Sarthak
chore: update README and REPORT with performance insights and dataset changes
0dbb356
"""
Unified Code-Specialized Model2Vec Distillation Script.
This script provides a unified approach for creating code-specialized embeddings
using Model2Vec distillation with optional code-specific training.
Features:
- Basic distillation (default): Simple Model2Vec distillation
- Advanced training (--train flag): Additional CodeSearchNet fine-tuning
- Checkpoint support with Beam sync utilities
- Multi-teacher model processing
- Smart resume capabilities
- Hierarchical storage: base → final
Directory Structure:
- code_model2vec/base: Basic distilled models (first step)
- code_model2vec/final: Final models (copied from base or after training)
Usage:
distiller distill [--use-beam] [--train] # Basic distillation or with training
"""
import importlib.util
import json
import logging
import os
import time
from pathlib import Path
from typing import Annotated, Any
import torch
import typer
from beam import function
from sentence_transformers import SentenceTransformer
from distiller.model2vec.distill import distill
# Project-local helpers: Beam checkpoint/sync utilities and configuration
from .beam_utils import (
BeamCheckpointManager,
create_beam_utilities,
download_model_from_beam,
sync_checkpoints_from_beam,
sync_checkpoints_to_beam,
upload_model_to_beam,
)
from .config import (
codesearchnet_config,
directories,
distillation_config,
get_distillation_function_kwargs,
get_training_function_kwargs,
get_volume_config,
languages_config,
)
# Whether the flash_attn package is installed. find_spec only checks that the module can be
# located (it does not import it); GPU compatibility is checked at runtime by
# configure_flash_attention().
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
# =============================================================================
# CONFIGURATION
# =============================================================================
# Beam volume settings; this module reads `.name` and `.mount_path` from it.
VOLUME_CONFIG = get_volume_config()
# Local output roots: distilled base models, and final (copied or trained) models.
LOCAL_BASE_DIR = directories.base
LOCAL_FINAL_DIR = directories.final
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Teacher models for distillation (materialized as a list from the config).
DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models)
# =============================================================================
# FLASH ATTENTION UTILITIES
# =============================================================================
def supports_flash_attention_2(
    capability: tuple[int, int],
    min_capability: tuple[int, int] = (8, 0),
) -> bool:
    """Return whether a CUDA compute capability can run FlashAttention-2 kernels.

    FlashAttention-2 supports Ampere, Ada and Hopper GPUs (compute capability >= 8.0);
    Turing GPUs such as the T4 (7.5) are only supported by FlashAttention 1.x.

    Args:
        capability: ``(major, minor)`` pair, as returned by ``torch.cuda.get_device_capability()``.
        min_capability: Minimum required ``(major, minor)``; defaults to ``(8, 0)``.

    Returns:
        True if ``capability`` is at least ``min_capability`` (lexicographic comparison).
    """
    return tuple(capability) >= tuple(min_capability)


def configure_flash_attention() -> dict[str, Any]:
    """Configure flash attention settings and return SentenceTransformer kwargs.

    When the ``flash_attn`` package is installed, this sets the process-wide environment
    variables ``FLASH_ATTENTION_FORCE_USE``, ``TORCH_COMPILE_DISABLE`` and
    ``TOKENIZERS_PARALLELISM``. It does so even if the GPU check below later disables
    flash attention. If a CUDA device that supports FlashAttention-2 is present, it
    returns kwargs requesting ``flash_attention_2`` in fp16.

    Returns:
        ``{}`` when flash attention cannot be used; otherwise a dict with a single
        ``"model_kwargs"`` entry, meant to be splatted into ``SentenceTransformer(...)``.
    """
    model_kwargs: dict[str, Any] = {}
    if not FLASH_ATTN_AVAILABLE:
        logger.info("⚠️ Flash attention not available - using standard attention")
        return model_kwargs
    # Set environment variables for flash attention
    os.environ["FLASH_ATTENTION_FORCE_USE"] = "1"
    # Disable torch compile for flash attention compatibility
    os.environ["TORCH_COMPILE_DISABLE"] = "1"
    # Silence HuggingFace tokenizers' fork-parallelism warning (unrelated to attention itself)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    try:
        if torch.cuda.is_available():
            device_capability = torch.cuda.get_device_capability()
            # FlashAttention-2 needs compute capability >= 8.0. The old >= 7.5 gate let Turing
            # GPUs request flash_attention_2, which then fails inside the kernels at runtime.
            if supports_flash_attention_2(device_capability):
                logger.info("βœ… Flash attention enabled - compatible GPU detected")
                model_kwargs["model_kwargs"] = {
                    "attn_implementation": "flash_attention_2",
                    "torch_dtype": torch.float16,  # Flash attention works best with fp16
                    "use_flash_attention_2": True,
                    "_attn_implementation": "flash_attention_2",  # Alternative key for some models
                }
            else:
                logger.info(f"⚠️ GPU compute capability {device_capability} < 8.0 - flash attention disabled")
        else:
            logger.info("⚠️ No CUDA available - flash attention disabled")
    except Exception as e:
        logger.warning(f"⚠️ Failed to check GPU compatibility: {e} - flash attention disabled")
    return model_kwargs
def load_model_with_flash_attention(model_path: str, device: str | None = "auto") -> SentenceTransformer:
    """Load a SentenceTransformer model with flash attention if available.

    First tries to load the model with the kwargs from ``configure_flash_attention()``
    (when it returns any). If that load raises, the error is logged and the model is
    loaded again with standard attention. Both loads pass ``trust_remote_code=True``.

    Args:
        model_path: Hub model id or local path.
        device: Torch device string (e.g. ``"cuda"``, ``"cpu"``). ``"auto"`` (default) or
            ``None`` lets SentenceTransformer pick the best available device.

    Returns:
        The loaded SentenceTransformer.

    Raises:
        Whatever ``SentenceTransformer`` raises if the standard-attention load fails.
    """
    # SentenceTransformer has no "auto" device (torch rejects it); None triggers its
    # own device auto-detection, which is what "auto" is meant to express here.
    resolved_device = None if device in (None, "auto") else device
    flash_kwargs = configure_flash_attention()
    if "model_kwargs" in flash_kwargs:
        try:
            logger.info(f"πŸš€ Loading model with flash attention: {Path(model_path).name}")
            model = SentenceTransformer(model_path, device=resolved_device, trust_remote_code=True, **flash_kwargs)
            logger.info("βœ… Model loaded successfully with flash attention")
            return model
        except Exception as e:
            logger.warning(f"⚠️ Failed to load with flash attention: {e}")
            logger.info("πŸ”„ Falling back to standard attention")
    # Fallback to standard loading
    logger.info(f"πŸ“‚ Loading model with standard attention: {Path(model_path).name}")
    model = SentenceTransformer(model_path, device=resolved_device, trust_remote_code=True)
    logger.info("βœ… Model loaded successfully with standard attention")
    return model
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def get_current_config_hash(enable_training: bool, *, config: Any = None) -> str:
    """Generate a hash of current configuration parameters for checkpoint validation.

    The digest covers the PCA dimensions, SIF coefficient, Zipf flag and
    ``enable_training``. When training is enabled, it also covers the tokenlearn
    dataset, dataset name and text key. The result is stable across processes, so
    it can be compared with hashes stored by earlier runs.

    Args:
        enable_training: Whether the tokenlearn training stage is part of the run.
        config: Object exposing the same attributes as ``distillation_config``
            (``optimal_pca_dims``, ``sif_coefficient``, ``apply_zipf`` and, when
            training, ``tokenlearn_dataset``, ``tokenlearn_dataset_name``,
            ``tokenlearn_text_key``). Defaults to the module's ``distillation_config``.

    Returns:
        The first 12 hex characters of an MD5 digest (a fingerprint, not a security hash).
    """
    import hashlib

    cfg = distillation_config if config is None else config
    config_params: dict[str, Any] = {
        "pca_dims": cfg.optimal_pca_dims,
        "sif_coefficient": cfg.sif_coefficient,
        "apply_zipf": cfg.apply_zipf,
        "enable_training": enable_training,
    }
    if enable_training:
        tokenlearn_key = f"{cfg.tokenlearn_dataset}_{cfg.tokenlearn_dataset_name}_{cfg.tokenlearn_text_key}"
        # The builtin hash() of a str is salted per process (PYTHONHASHSEED). The old value
        # therefore changed on every run and stored checkpoints never validated. Use a
        # stable digest instead.
        tokenlearn_digest = int(hashlib.md5(tokenlearn_key.encode("utf-8")).hexdigest(), 16)  # noqa: S324
        # Kept as a float in [0, 1e6) to match the previous value format.
        config_params["tokenlearn_hash"] = float(tokenlearn_digest % 1000000)
    config_str = str(sorted(config_params.items()))
    return hashlib.md5(config_str.encode()).hexdigest()[:12]  # noqa: S324
def check_existing_base_model(teacher_name: str) -> str | None:
    """Return the path of an already-distilled base model for ``teacher_name``, if usable.

    The directory ``LOCAL_BASE_DIR/code_model2vec_<teacher_name>`` counts as usable when it
    contains ``config.json`` plus at least one weights file (``model.safetensors``,
    ``model.bin`` or ``pytorch_model.bin``).

    Args:
        teacher_name: Teacher identifier embedded in the directory name.

    Returns:
        The directory path as a string, or None when it is missing or incomplete.
    """
    candidate = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
    if not candidate.exists():
        return None
    weight_names = ("model.safetensors", "model.bin", "pytorch_model.bin")
    config_present = (candidate / "config.json").exists()
    weights_present = any((candidate / weight_name).exists() for weight_name in weight_names)
    if not (config_present and weights_present):
        return None
    logger.info(f"βœ… Found existing base model: {teacher_name}")
    return str(candidate)
def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None:
    """Return the path of an existing final model for ``teacher_name``, if usable.

    Looks in ``LOCAL_FINAL_DIR`` for ``code_model2vec_<teacher_name>``, with a
    ``_fine_tuned`` suffix when ``enable_training`` is set. The directory must hold
    ``config.json`` and at least one of ``model.safetensors``, ``model.bin`` or
    ``pytorch_model.bin``.

    Args:
        teacher_name: Teacher identifier embedded in the directory name.
        enable_training: Look for the fine-tuned variant instead of the plain one.

    Returns:
        The directory path as a string, or None when it is missing or incomplete.
    """
    suffix = "_fine_tuned" if enable_training else ""
    candidate = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}{suffix}"
    if not candidate.exists():
        return None
    config_present = (candidate / "config.json").exists()
    weights_present = any(
        (candidate / weight_name).exists()
        for weight_name in ("model.safetensors", "model.bin", "pytorch_model.bin")
    )
    if config_present and weights_present:
        logger.info(f"βœ… Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
        return str(candidate)
    return None
def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool:
    """Copy the distilled base model into the final-models directory.

    Copies ``LOCAL_BASE_DIR/code_model2vec_<teacher_name>`` to
    ``LOCAL_FINAL_DIR/code_model2vec_<teacher_name>[_fine_tuned]``, replacing any
    existing final model with that name.

    Args:
        teacher_name: Teacher identifier embedded in the model directory names.
        enable_training: Append ``_fine_tuned`` to the destination directory name.

    Returns:
        True if the copy succeeded, False otherwise (failures are logged, not raised).
    """
    import shutil

    suffix = "_fine_tuned" if enable_training else ""
    base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
    final_path = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}{suffix}"

    # Validate the source before deleting anything. Previously a missing base model led to
    # the existing final model being removed and the copy then failing, losing both.
    if not base_path.is_dir():
        logger.error(f"❌ Base model for {teacher_name} not found at {base_path}; final model left untouched")
        return False

    try:
        final_path.parent.mkdir(parents=True, exist_ok=True)
        if final_path.exists():
            shutil.rmtree(final_path)
        shutil.copytree(base_path, final_path)
        logger.info(f"πŸ“ Copied {teacher_name} from base to final{suffix}")
        return True
    except Exception:
        logger.exception(f"❌ Failed to copy {teacher_name} to final{suffix}")
        return False
def sync_model_from_beam(
    teacher_name: str,
    target_dir: str,
    use_beam_utilities: bool = False,
) -> bool:
    """Download ``<teacher_name>_model`` from the Beam volume into ``target_dir``.

    Does nothing and returns False unless ``use_beam_utilities`` is set. Otherwise it
    creates ``target_dir`` if needed and calls ``download_model_from_beam``. A failed
    download or any exception is logged as a warning and reported as False.

    Returns:
        True only if the download reported success.
    """
    if not use_beam_utilities:
        return False
    try:
        remote_name = f"{teacher_name}_model"
        destination = Path(target_dir)
        destination.mkdir(parents=True, exist_ok=True)
        downloaded = download_model_from_beam(VOLUME_CONFIG.name, remote_name, str(destination))
        if not downloaded:
            logger.warning(f"⚠️ Failed to sync {teacher_name} from Beam")
            return False
        logger.info(f"πŸ“₯ Synced {teacher_name} from Beam to {target_dir}")
        return True
    except Exception as e:
        logger.warning(f"Failed to sync {teacher_name} from Beam: {e}")
        return False
def sync_model_to_beam(
    teacher_name: str,
    source_dir: str,
    use_beam_utilities: bool = False,
) -> bool:
    """Upload the model in ``source_dir`` to the Beam volume as ``<teacher_name>_model``.

    Returns True only when ``use_beam_utilities`` is set and ``upload_model_to_beam``
    reports success. An unsuccessful upload or any exception is logged as a warning
    and yields False.
    """
    if use_beam_utilities:
        try:
            remote_name = f"{teacher_name}_model"
            uploaded = upload_model_to_beam(VOLUME_CONFIG.name, remote_name, source_dir)
            if uploaded:
                logger.info(f"πŸ“€ Synced {teacher_name} to Beam from {source_dir}")
                return True
            logger.warning(f"⚠️ Failed to sync {teacher_name} to Beam")
        except Exception as e:
            logger.warning(f"Failed to sync {teacher_name} to Beam: {e}")
    return False
# =============================================================================
# DISTILLATION FUNCTIONS
# =============================================================================
def simple_distillation(
    teacher_model: str,
    output_dir: str,
    pca_dims: int | None = None,
    retry_with_cache_clear: bool = False,
) -> Any:
    """
    Perform simple Model2Vec distillation without additional training.

    Calls ``distill`` with the configured Zipf/SIF settings. It then logs a
    vocabulary-vs-embedding size check, saves the model to ``output_dir`` and
    returns it.

    Known incompatibilities are logged and turned into a ``None`` return:
    - token/vector count mismatch;
    - weights on the meta device;
    - missing ``backend_tokenizer``.

    Missing remote-code files (``transformers_modules`` / ``xlm_padding.py``)
    trigger one retry, provided ``clear_model_cache(teacher_model)`` (defined
    outside this function) returns truthy.

    Args:
        teacher_model: Name of teacher model
        output_dir: Output directory for the distilled model (created if missing)
        pca_dims: PCA dimensions (uses config default if None)
        retry_with_cache_clear: Whether this is a retry after clearing cache

    Returns:
        Distilled model or None if failed

    Raises:
        ValueError, AttributeError, FileNotFoundError: re-raised when the message does
        not match one of the handled cases above (these bypass the generic handler).
    """
    if pca_dims is None:
        pca_dims = int(distillation_config.optimal_pca_dims)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else ""
    logger.info(f"πŸ”„ Simple distillation{retry_suffix}: {teacher_model} β†’ {output_dir}")
    logger.info(f"πŸ“Š PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")
    start_time = time.time()
    try:
        # Perform distillation with optimal parameters
        model = distill(
            model_name=teacher_model,
            pca_dims=int(pca_dims),
            apply_zipf=bool(distillation_config.apply_zipf),
            sif_coefficient=float(distillation_config.sif_coefficient),
            trust_remote_code=True,
        )
        logger.info("βœ… Core distillation completed successfully")
        # Validate model before saving: a vocab/embedding row mismatch is only warned
        # about here, not treated as fatal.
        if hasattr(model, "tokenizer") and hasattr(model, "embedding"):
            vocab_size = len(model.tokenizer.get_vocab())
            embedding_size = model.embedding.shape[0]
            logger.info("πŸ“Š Model validation:")
            logger.info(f" - Vocabulary size: {vocab_size}")
            logger.info(f" - Embedding matrix size: {embedding_size}")
            if vocab_size != embedding_size:
                logger.warning(f"⚠️ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}")
                logger.warning("⚠️ This may cause issues in downstream usage")
            else:
                logger.info("βœ… Vocabulary and embedding sizes match")
        # Save the model
        model.save_pretrained(str(output_path))
        logger.info(f"πŸ’Ύ Model saved to {output_path}")
        # Log model info
        logger.info(f"Model type: {type(model)}")
        if hasattr(model, "embedding"):
            logger.info(f"Embedding shape: {model.embedding.shape}")
            logger.info(f"Embedding dtype: {model.embedding.dtype}")
        total_time = time.time() - start_time
        logger.info(f"πŸŽ‰ Simple distillation completed in {total_time:.2f} seconds")
        return model
    except ValueError as e:
        # Known Model2Vec failure modes are matched by message text; anything else propagates.
        if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
            logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
            logger.warning(f"Error details: {e}")
            logger.warning("πŸ’‘ This model has incompatible tokenization. Skipping...")
            return None
        if "weight is on the meta device" in str(e):
            logger.warning(f"⚠️ Device placement issue with {teacher_model} - model weights on meta device")
            logger.warning(f"Error details: {e}")
            logger.warning("πŸ’‘ This model has device placement issues. Skipping...")
            return None
        raise
    except AttributeError as e:
        if "backend_tokenizer" in str(e):
            logger.warning(f"⚠️ Tokenizer compatibility issue with {teacher_model}")
            logger.warning(f"Error details: {e}")
            logger.warning("πŸ’‘ This model's tokenizer is incompatible with Model2Vec. Skipping...")
            return None
        raise
    except FileNotFoundError as e:
        if "transformers_modules" in str(e) or "xlm_padding.py" in str(e):
            logger.warning(f"⚠️ Missing custom model files for {teacher_model}")
            logger.warning(f"Error details: {e}")
            # Try clearing cache and retrying once; the retry flag prevents infinite recursion.
            if not retry_with_cache_clear:
                logger.info("πŸ”§ Attempting to clear cache and retry...")
                if clear_model_cache(teacher_model):
                    logger.info("πŸ”„ Retrying distillation after cache clear...")
                    return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True)
            logger.warning("πŸ’‘ This model has missing dependencies. Manual intervention may be required.")
            return None
        raise
    except Exception:
        # Any other failure (including save errors) is logged with traceback and reported as None.
        logger.exception(f"❌ Simple distillation failed for {teacher_model}")
        return None
def load_optimized_dataset(
    max_samples: int | None = None,
    checkpoint_manager: BeamCheckpointManager | None = None,
    dataset_path: str | None = None,
) -> list[str]:
    """Load our pre-created optimized dataset for tokenlearn training.

    Reads the "train" split through ``.dataset.load_optimized_dataset``, takes its
    ``text`` column, shuffles it deterministically (seed 42) and keeps at most
    ``max_samples`` texts. If the dataset cannot be loaded, this falls back to
    ``load_codesearchnet_dataset``.

    Args:
        max_samples: Maximum number of texts; defaults to
            ``distillation_config.tokenlearn_max_samples``.
        checkpoint_manager: Only forwarded to the CodeSearchNet fallback.
        dataset_path: Dataset directory; defaults to ``distillation_config.custom_dataset_path``,
            or ``DATASET_OUTPUT_DIR`` when that is empty.

    Returns:
        The selected training texts.
    """
    import random

    from .dataset import DATASET_OUTPUT_DIR
    from .dataset import load_optimized_dataset as load_dataset_func

    # Use configuration if not provided as parameter
    if dataset_path is None:
        dataset_path = distillation_config.custom_dataset_path
    dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
    # Use configuration default if not specified
    if max_samples is None:
        max_samples = distillation_config.tokenlearn_max_samples
    logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
    logger.info(f"πŸ“Š Target samples: {max_samples}")

    try:
        # Load the training split of our optimized dataset (text = formatted query + code)
        df = load_dataset_func(output_dir=dataset_dir, split="train")
        texts = df["text"].tolist()
        # A private seeded RNG gives the same permutation as `random.seed(42)` followed by
        # `random.shuffle`, without resetting the global random state for the whole process.
        random.Random(42).shuffle(texts)
        if len(texts) > max_samples:
            texts = texts[:max_samples]
    except Exception as e:
        logger.warning(f"⚠️ Failed to load optimized dataset: {e}")
        logger.info("πŸ”„ Falling back to original CodeSearchNet loading...")
        return load_codesearchnet_dataset(max_samples, checkpoint_manager)

    logger.info(f"βœ… Loaded {len(texts)} optimized training samples")

    # The distribution (over the whole split) is informational only. A missing or odd
    # "language" column must not discard loaded texts; it used to trigger the fallback.
    try:
        languages = df["language"].value_counts()
        logger.info("πŸ“Š Language distribution:")
        for lang, count in languages.items():
            percentage = (count / len(df)) * 100
            logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
    except Exception as e:
        logger.warning(f"⚠️ Could not compute language distribution: {e}")
    return texts
def load_codesearchnet_dataset(
    max_samples: int | None = None,
    checkpoint_manager: BeamCheckpointManager | None = None,
) -> list[str]:
    """Load and format the CodeSearchNet dataset for token frequency computation.

    Each text has the form ``"Documentation: <doc>\\nCode:\\n<code>"``. Only pairs are
    kept whose docstring has at least 3 words, whose code is longer than 50 characters,
    and whose formatted text is at most 2048 characters.

    Samples are collected per language in ``languages_config.all``; the first
    ``max_samples % n_languages`` languages get one extra sample. They are then
    interleaved round-robin so the languages are mixed throughout the list.

    Checkpoint handling (only when ``checkpoint_manager`` is given):
    - a complete "dataset" checkpoint is returned directly;
    - a partial one is kept as the prefix of the result;
    - the final list is saved back as a checkpoint.

    Args:
        max_samples: Maximum number of texts; defaults to
            ``distillation_config.tokenlearn_max_samples``.
        checkpoint_manager: Optional checkpoint store used for resuming.

    Returns:
        Up to ``max_samples`` texts. On an unexpected error, the texts recovered from the
        checkpoint (possibly none) are returned instead of raising.
    """
    from datasets import load_dataset

    # Use configuration default if not specified
    if max_samples is None:
        max_samples = distillation_config.tokenlearn_max_samples
    languages = list(languages_config.all)
    logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
    logger.info(f"Limiting to {max_samples} samples for training efficiency")
    logger.info(f"Languages: {', '.join(languages)}")

    # Check for existing dataset checkpoint
    texts: list[str] = []
    if checkpoint_manager:
        checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
        if checkpoint_data:
            cached_texts = checkpoint_data.get("data", {}).get("texts", [])
            if len(cached_texts) >= max_samples:
                logger.info(f"βœ… Resumed dataset loading: {len(cached_texts)} texts from checkpoint")
                return cached_texts[:max_samples]
            logger.info(f"πŸ“‹ Partial dataset found: {len(cached_texts)} texts, continuing...")
            texts = list(cached_texts)

    if not languages:
        # Previously this surfaced as a ZeroDivisionError swallowed by the broad handler below.
        logger.warning("⚠️ No languages configured - returning checkpoint texts only")
        return texts

    try:
        # Calculate samples per language for balanced distribution
        samples_per_language, remaining_samples = divmod(max_samples, len(languages))
        logger.info(f"πŸ“Š Target distribution: {samples_per_language} samples per language")
        if remaining_samples > 0:
            logger.info(f"πŸ“Š Extra {remaining_samples} samples will be distributed to first languages")

        # Load training data from each language separately for balanced distribution
        language_texts: dict[str, list[str]] = {}
        total_collected = len(texts)
        for i, language in enumerate(languages):
            if total_collected >= max_samples:
                break
            logger.info(f"πŸ” Loading {language} training data...")
            # The first `remaining_samples` languages each take one extra sample.
            target_for_lang = samples_per_language + (1 if i < remaining_samples else 0)
            try:
                dataset = load_dataset(
                    codesearchnet_config.dataset_name,
                    language,
                    split="train",
                    trust_remote_code=True,
                )
                lang_texts: list[str] = []
                for processed_count, example in enumerate(dataset, 1):
                    if len(lang_texts) >= target_for_lang:
                        break
                    # `or ""` also covers fields that are present but None. Those used to
                    # raise on .strip() and drop the whole language.
                    doc_string = (example.get("func_documentation_string") or "").strip()
                    code_string = (example.get("func_code_string") or "").strip()
                    if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50:
                        text = f"Documentation: {doc_string}\nCode:\n{code_string}"
                        # Ensure reasonable length for embedding models
                        if len(text) <= 2048:
                            lang_texts.append(text)
                    if processed_count % 5000 == 0:
                        logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to load {language} data: {e}")
                continue
            language_texts[language] = lang_texts
            total_collected += len(lang_texts)
            logger.info(f"βœ… {language}: collected {len(lang_texts)} samples")

        # Checkpoint texts stay at the front; new texts are interleaved round-robin.
        combined_texts = list(texts)
        # Track each text's source so the logged distribution is exact. The old keyword
        # heuristic mislabelled languages, e.g. Ruby `def` was counted as Python.
        lang_counts: dict[str, int] = {}
        if combined_texts:
            lang_counts["checkpoint"] = len(combined_texts)
        max_lang_samples = max((len(samples) for samples in language_texts.values()), default=0)
        for sample_idx in range(max_lang_samples):
            for language in languages:
                if len(combined_texts) >= max_samples:
                    break
                lang_samples = language_texts.get(language, [])
                if sample_idx < len(lang_samples):
                    combined_texts.append(lang_samples[sample_idx])
                    lang_counts[language] = lang_counts.get(language, 0) + 1
            if len(combined_texts) >= max_samples:
                break

        # Log final distribution
        logger.info("πŸ“Š Final dataset distribution:")
        for lang, count in lang_counts.items():
            percentage = (count / len(combined_texts)) * 100
            logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")

        # Final checkpoint save
        if checkpoint_manager:
            dataset_checkpoint = {
                "config_hash": get_current_config_hash(enable_training=True),
                "stage": "dataset",
                "step": 0,
                "timestamp": time.time(),
                "data": {"texts": combined_texts},
            }
            try:
                checkpoint_manager.save_checkpoint("dataset", dataset_checkpoint, 0)
            except Exception as e:
                # A failed save must not discard the freshly collected texts.
                logger.warning(f"⚠️ Failed to save dataset checkpoint: {e}")
        logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet")
        return combined_texts
    except Exception:
        logger.exception("Error loading CodeSearchNet dataset")
        return texts  # Return what we have so far
def _embeddings_cache_fingerprint(teacher_model: Any, texts: list[str]) -> str:
    """Return a SHA-256 digest identifying the teacher model and the exact texts to encode.

    The teacher is identified by a ``model_name`` attribute if present, else by the first
    module's ``auto_model.name_or_path``, else by the model's class name.
    """
    import hashlib

    teacher_id = getattr(teacher_model, "model_name", None)
    if not teacher_id:
        modules = getattr(teacher_model, "_modules", None) or {}
        first_module = next(iter(modules.values()), None)
        teacher_id = getattr(getattr(first_module, "auto_model", None), "name_or_path", None)
    digest = hashlib.sha256(str(teacher_id or type(teacher_model).__name__).encode("utf-8"))
    for text in texts:
        # A separator byte keeps ["ab", "c"] and ["a", "bc"] from hashing identically.
        digest.update(b"\x00")
        digest.update(text.encode("utf-8", errors="surrogatepass"))
    return digest.hexdigest()


def _encode_with_oom_backoff(
    teacher_model: SentenceTransformer,
    batch_texts: list[str],
    batch_size: int,
) -> tuple[torch.Tensor, int]:
    """Encode ``batch_texts`` (normalized), halving ``batch_size`` after each CUDA OOM.

    Returns the embeddings together with the batch size that succeeded, so callers can
    keep using it. Re-raises the OOM once ``batch_size`` is already 1.
    """
    while True:
        try:
            embeddings = teacher_model.encode(
                batch_texts,
                convert_to_tensor=True,
                batch_size=batch_size,
                show_progress_bar=False,
                normalize_embeddings=True,
            )
        except torch.cuda.OutOfMemoryError:
            if batch_size <= 1:
                raise
            logger.warning(f"GPU OOM with batch size {batch_size}, reducing...")
            torch.cuda.empty_cache()
            batch_size = max(1, batch_size // 2)
        else:
            return embeddings, batch_size


def generate_teacher_embeddings(
    teacher_model: SentenceTransformer,
    texts: list[str],
    checkpoint_manager: BeamCheckpointManager | None = None,
) -> torch.Tensor:
    """Generate teacher embeddings for code training with checkpoint support.

    Texts are encoded in chunks of 16 with ``normalize_embeddings=True``, concatenated
    and cast to float32.

    When ``checkpoint_manager`` is given, results are cached on the Beam volume as
    ``embeddings_cache.pt`` plus ``embeddings_config.json``. A cached tensor is reused
    only if the config hash, the teacher/texts fingerprint and the row count all match.

    Args:
        teacher_model: SentenceTransformer used to encode the texts.
        texts: Texts to embed; one output row per text.
        checkpoint_manager: Enables the on-volume embeddings cache when provided.

    Returns:
        A float32 tensor with one embedding row per text. A cache hit is loaded on CPU;
        fresh results stay on the device ``encode`` produced them on.
    """
    logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
    if not texts:
        # torch.cat([]) would raise below; return an empty float32 matrix instead.
        dim = teacher_model.get_sentence_embedding_dimension() or 0
        return torch.empty((0, dim), dtype=torch.float32)

    # Check for existing embeddings checkpoint
    fingerprint: str | None = None
    if checkpoint_manager:
        volume_path = Path(VOLUME_CONFIG.mount_path)
        embeddings_path = volume_path / "embeddings_cache.pt"
        config_path = volume_path / "embeddings_config.json"
        fingerprint = _embeddings_cache_fingerprint(teacher_model, texts)
        if embeddings_path.exists() and config_path.exists():
            try:
                with config_path.open("r") as f:
                    config_data = json.load(f)
                # The cache lives at one fixed path on the shared volume, so it must match the
                # teacher and the exact texts, not just the config. Previously a cache built
                # for another teacher, or for fewer texts, could be returned here.
                cache_matches = (
                    config_data.get("config_hash") == get_current_config_hash(enable_training=True)
                    and config_data.get("texts_fingerprint") == fingerprint
                )
                if cache_matches:
                    final_embeddings = torch.load(embeddings_path, map_location="cpu", weights_only=True)
                    if final_embeddings.shape[0] == len(texts):
                        logger.info(f"βœ… Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)")
                        return final_embeddings
            except Exception as e:
                logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")

    # Generate embeddings from scratch
    logger.info("Generating fresh teacher embeddings...")
    chunk_size = 16  # Fixed number of texts passed to each encode() call
    encode_batch_size = chunk_size  # Shrinks for the rest of the run after a CUDA OOM
    embeddings_list = []
    for i in range(0, len(texts), chunk_size):
        batch_texts = texts[i : i + chunk_size]
        batch_embeddings, encode_batch_size = _encode_with_oom_backoff(teacher_model, batch_texts, encode_batch_size)
        embeddings_list.append(batch_embeddings)
        if i % (chunk_size * 10) == 0:
            logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")

    # Combine all embeddings
    teacher_embeddings = torch.cat(embeddings_list, dim=0)
    # Ensure fp32 precision
    if teacher_embeddings.dtype != torch.float32:
        teacher_embeddings = teacher_embeddings.to(torch.float32)
    logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")

    # Save embeddings cache for future runs
    if checkpoint_manager:
        try:
            volume_path = Path(VOLUME_CONFIG.mount_path)
            embeddings_path = volume_path / "embeddings_cache.pt"
            config_path = volume_path / "embeddings_config.json"
            # Store a CPU copy so the cache does not depend on the producing device.
            torch.save(teacher_embeddings.detach().cpu(), embeddings_path)
            config_data = {
                "config_hash": get_current_config_hash(enable_training=True),
                "texts_fingerprint": fingerprint,
                "num_texts": len(texts),
                "embedding_shape": list(teacher_embeddings.shape),
                "timestamp": time.time(),
            }
            with config_path.open("w") as f:
                json.dump(config_data, f, indent=2)
            logger.info("πŸ’Ύ Saved embeddings cache for future runs")
        except Exception as e:
            logger.warning(f"Failed to save embeddings cache: {e}")
    return teacher_embeddings
def tokenlearn_training(
student_model: Any,
teacher_model: SentenceTransformer,
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
) -> Any:
"""
Perform tokenlearn training following the official POTION approach.
This follows the 4-step process:
1. Model2Vec distillation (already done - student_model)
2. Sentence transformer inference (create features)
3. Tokenlearn training
"""
from pathlib import Path
logger.info("πŸ§ͺ Starting tokenlearn training (POTION approach)...")
# Create persistent directories for tokenlearn workflow (for checkpoint preservation)
teacher_model_name = getattr(teacher_model, "model_name", None)
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
# Try to extract from the first module if it's a SentenceTransformer
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
teacher_model_name = first_module.auto_model.name_or_path
if not teacher_model_name:
teacher_model_name = "unknown_teacher"
# Use persistent directory for tokenlearn checkpoints
teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_")
persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug
features_dir = persistent_tokenlearn_dir / "features"
model_dir = persistent_tokenlearn_dir / "base_model"
trained_dir = persistent_tokenlearn_dir / "trained_model"
features_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
trained_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"πŸ“ Using persistent tokenlearn directory: {persistent_tokenlearn_dir}")
# Save the base distilled model for tokenlearn
student_model.save_pretrained(str(model_dir))
logger.info(f"πŸ’Ύ Saved base model to {model_dir}")
# Step 2: Create features using sentence transformer
logger.info("πŸ” Step 2: Creating features using sentence transformer...")
# Get teacher model name/path for tokenlearn
teacher_model_name = getattr(teacher_model, "model_name", None)
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
# Try to extract from the first module if it's a SentenceTransformer
# _modules is a dict-like container, get the first module by iterating
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
teacher_model_name = first_module.auto_model.name_or_path
logger.info(f"πŸ“Š Using teacher model: {teacher_model_name}")
# Prepare dataset for tokenlearn featurization
dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir)
# Check if featurization already completed (checkpoint detection)
featurization_complete_marker = features_dir / ".featurization_complete"
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
logger.info("βœ… Found existing featurization checkpoint with valid output files")
logger.info(f"πŸ“‚ Using cached features from: {features_dir}")
# Verify marker is still valid
output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
logger.info(f"πŸ“ Found {len(output_files)} cached feature files")
else:
if featurization_complete_marker.exists():
logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization")
featurization_complete_marker.unlink()
logger.info("πŸ”„ No valid featurization checkpoint found - starting featurization...")
if not teacher_model_name:
logger.warning("⚠️ Could not determine teacher model name, using fallback")
teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
logger.info(f"πŸ“Š Using teacher model: {teacher_model_name}")
try:
# Use direct function call instead of subprocess
from datasets import load_dataset
from distiller.tokenlearn.featurize import featurize
logger.info("πŸ”„ Running tokenlearn featurization...")
logger.info(f"πŸ“Š Dataset: {dataset_path} (config: {dataset_name})")
logger.info(f"πŸ“ Text field: {text_key}")
# Load the dataset
if dataset_name is None:
# For local JSON files, don't pass name parameter
dataset = load_dataset(
"json",
data_files=dataset_path,
split="train",
streaming=True,
)
else:
# For remote datasets with specific configurations
dataset = load_dataset(
dataset_path,
name=dataset_name,
split="train",
streaming=True,
)
# Call featurization function directly
featurize(
dataset=iter(dataset),
model=teacher_model,
output_dir=str(features_dir),
max_means=50000, # IMPROVEMENT: Limit means to prevent overfitting
batch_size=512, # IMPROVEMENT: Smaller batch for better gradients
text_key=text_key,
)
logger.info("βœ… Featurization completed successfully")
# Create checkpoint marker to indicate featurization is complete
featurization_complete_marker.touch()
logger.info(f"πŸ’Ύ Created featurization checkpoint: {featurization_complete_marker}")
except Exception as e:
logger.exception("πŸ’₯ Tokenlearn featurization failed")
logger.exception("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
msg = f"Tokenlearn featurization failed: {e}"
raise RuntimeError(msg) from e
# Step 3: Train using tokenlearn-train
logger.info("πŸŽ“ Step 3: Training using tokenlearn...")
# Check if training already completed (checkpoint detection)
training_complete_marker = trained_dir / ".training_complete"
training_fallback_marker = trained_dir / ".training_fallback"
if training_complete_marker.exists() and verify_training_output(trained_dir):
logger.info("βœ… Found existing training checkpoint with valid model files")
logger.info(f"πŸ“‚ Using cached trained model from: {trained_dir}")
# Show available model files
model_files = []
for pattern in ["*.json", "*.safetensors", "*.bin"]:
model_files.extend(list(trained_dir.glob(pattern)))
for subdir in ["model", "model_weighted"]:
subdir_path = trained_dir / subdir
if subdir_path.exists():
model_files.extend(list(subdir_path.glob(pattern)))
logger.info(f"πŸ“ Found {len(model_files)} cached model files")
elif training_fallback_marker.exists():
logger.warning("⚠️ Training fallback marker found - tokenlearn failed previously")
logger.info("πŸ”„ Proceeding with fallback to base model (simple distillation)")
# Skip training and proceed to model loading (will fallback to base model)
else:
if training_complete_marker.exists():
logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
training_complete_marker.unlink()
logger.info("πŸ”„ No valid training checkpoint found - starting training...")
try:
# Use direct function call instead of subprocess
from distiller.tokenlearn.train import train_model
from distiller.tokenlearn.utils import collect_means_and_texts
# IMPROVED APPROACH: Try optimized parameters first
logger.info("πŸš€ Attempting IMPROVED tokenlearn training with optimized parameters...")
logger.info("πŸ“Š Using smaller vocabulary and conservative PCA to prevent overfitting")
# Collect training data from features directory
paths = sorted(features_dir.glob("*.json"))
train_txt, train_vec = collect_means_and_texts(paths)
logger.info(f"πŸ“Š Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training")
try:
# Try improved parameters first
trained_model = train_model(
model_name=str(teacher_model_name),
train_txt=train_txt,
train_vec=train_vec,
device="cuda" if torch.cuda.is_available() else "cpu",
vocab_size=25000, # IMPROVEMENT: Smaller vocabulary to prevent overfitting
pca_dims=256, # IMPROVEMENT: Conservative PCA dimensions
)
# Save the trained model
trained_model.save_pretrained(str(trained_dir))
logger.info("βœ… IMPROVED tokenlearn training completed successfully")
training_complete_marker.touch()
logger.info(f"πŸ’Ύ Created improved training checkpoint: {training_complete_marker}")
except Exception as e:
logger.warning(f"⚠️ Improved training failed: {e}")
logger.info("πŸ”„ Falling back to CONSERVATIVE tokenlearn training...")
# FALLBACK: Ultra-conservative training approach
try:
trained_model = train_model(
model_name=str(teacher_model_name),
train_txt=train_txt,
train_vec=train_vec,
device="cuda" if torch.cuda.is_available() else "cpu",
vocab_size=15000, # FALLBACK: Even smaller vocabulary
pca_dims=128, # FALLBACK: Smaller PCA dimensions
)
# Save the trained model
trained_model.save_pretrained(str(trained_dir))
logger.info("βœ… Conservative tokenlearn training completed successfully")
training_complete_marker.touch()
logger.info(f"πŸ’Ύ Created conservative training checkpoint: {training_complete_marker}")
except Exception as e2:
logger.exception("❌ Conservative tokenlearn training also failed")
logger.exception("πŸ’₯ All training approaches failed - check output above for details")
# Create training marker to indicate we tried but failed
training_fallback_marker = trained_dir / ".training_fallback"
training_fallback_marker.touch()
logger.exception("πŸ’₯ Tokenlearn training failed completely")
msg = f"All tokenlearn training approaches failed: {e2}"
raise RuntimeError(msg) from e2
except Exception as e:
logger.warning("πŸ’₯ All tokenlearn training approaches failed")
logger.exception("πŸ’₯ All training approaches failed completely - cannot proceed")
msg = f"All training approaches failed: {e}"
raise RuntimeError(msg) from e
# Step 4: Load the trained model and apply post-training re-regularization
logger.info("πŸ“¦ Step 4: Loading trained model and applying post-training re-regularization...")
# Check if we need to use fallback due to tokenlearn failure
training_fallback_marker = trained_dir / ".training_fallback"
if training_fallback_marker.exists():
logger.error("❌ Tokenlearn training failed previously - cannot return trained model")
logger.error("πŸ’₯ Training was requested but failed - this would be misleading to return base model")
msg = "Tokenlearn training failed - cannot proceed with training pipeline"
raise RuntimeError(msg)
try:
from distiller.model2vec.model import StaticModel
# Load the trained model from tokenlearn
trained_model_path = trained_dir / "model"
if not trained_model_path.exists():
# Try alternative paths
possible_paths = [
trained_dir / "model_weighted",
trained_dir,
]
for path in possible_paths:
if path.exists() and any(path.glob("*.json")):
trained_model_path = path
break
else:
logger.error(f"❌ Could not find trained model in {trained_dir}")
logger.error("πŸ’₯ Training was requested but no trained model found - cannot proceed")
msg = f"Trained model not found in {trained_dir} - training pipeline failed"
raise RuntimeError(msg)
# Load the model before re-regularization
logger.info("πŸ”„ Loading model from tokenlearn training...")
trained_model = StaticModel.from_pretrained(str(trained_model_path))
# Return the trained model directly
logger.info("βœ… Tokenlearn training pipeline completed successfully")
return trained_model
except ValueError as e:
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
logger.exception("πŸ’₯ Token-vector mismatch in tokenlearn training")
logger.exception("Error details")
logger.exception("πŸ”§ This is a known issue with tokenlearn/Model2Vec integration")
logger.exception("πŸ’₯ Training was requested but failed due to token-vector mismatch")
msg = f"Tokenlearn training failed due to token-vector mismatch: {e}"
raise RuntimeError(msg) from e
logger.exception("πŸ’₯ Failed to load tokenlearn trained model")
msg = f"Failed to load tokenlearn trained model: {e}"
raise RuntimeError(msg) from e
except Exception as e:
logger.exception("πŸ’₯ Failed to load tokenlearn trained model")
logger.exception("πŸ’₯ Cannot load trained model - training failed")
msg = f"Failed to load tokenlearn trained model: {e}"
raise RuntimeError(msg) from e
def distill_single_teacher(
    teacher_model: str,
    enable_training: bool = False,
    use_beam_utilities: bool = False,
    pca_dims: int | None = None,
) -> dict[str, Any]:
    """
    Distill a single teacher model, optionally followed by tokenlearn training.

    Workflow:
    1. Return early if the final model already exists.
    2. Reuse an existing base model (local, or synced from Beam), or create one with
       the distillation routine chosen by ``try_model_workarounds``.
    3. Run tokenlearn training on the base model (``enable_training``) or copy the
       base model to the final directory.

    Args:
        teacher_model: HuggingFace id of the teacher model.
        enable_training: Whether to run tokenlearn training. The final model is then
            stored under a ``_fine_tuned`` suffix.
        use_beam_utilities: Whether to sync models and checkpoints with the Beam volume.
        pca_dims: PCA dimensions for distillation (``None`` lets the distillation
            routine use its configured default).

    Returns:
        Dictionary describing the outcome. ``status`` is one of ``"success"``,
        ``"skipped_existing_final"``, ``"failed_base_distillation"``,
        ``"failed_training"``, ``"failed_copy_to_final"`` or ``"failed"``. Errors
        raised while processing are caught and reported through ``status``/``error``.
    """
    teacher_name = teacher_model.split("/")[-1].replace("-", "_")
    base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
    # Trained models get their own name so they never overwrite plain distillations.
    final_model_name = f"code_model2vec_{teacher_name}"
    if enable_training:
        final_model_name += "_fine_tuned"
    final_dir = Path(LOCAL_FINAL_DIR) / final_model_name

    logger.info(f"\n{'=' * 60}")
    logger.info(f"πŸ”„ Processing teacher model: {teacher_model}")
    logger.info(f"πŸ“ Teacher name: {teacher_name}")
    logger.info(f"πŸŽ“ Training enabled: {enable_training}")
    logger.info(f"{'=' * 60}")

    # Known-problematic teachers are only warned about; distillation is still attempted.
    is_compatible, warning_msg = check_model_compatibility(teacher_model)
    if not is_compatible:
        logger.warning(f"⚠️ Known compatibility issue: {warning_msg}")
        logger.info("πŸ”§ Attempting distillation anyway, but may fail...")

    # Decide once whether this teacher needs a specialised distillation routine.
    # The answer is reused below if the base model has to be created.
    workaround_type = try_model_workarounds(teacher_model)

    start_time = time.time()

    def _result(status: str, **details: Any) -> dict[str, Any]:
        """Build the per-teacher result record, with the elapsed time as the last key."""
        return {
            "teacher_model": teacher_model,
            "teacher_name": teacher_name,
            "status": status,
            **details,
            "distillation_time": time.time() - start_time,
        }

    # Beam utilities are optional: if they cannot be created we continue without checkpoint syncing.
    checkpoint_mgr = None
    if use_beam_utilities:
        try:
            _, checkpoint_mgr, _, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
        except Exception as e:
            logger.warning(f"Failed to initialize Beam utilities: {e}")

    try:
        # Step 1: Nothing to do if the final model already exists
        existing_final = check_existing_final_model(teacher_name, enable_training)
        if existing_final:
            logger.info(f"βœ… Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
            return _result("skipped_existing_final", final_path=existing_final)

        # Step 1.5: Pull existing checkpoints from Beam so interrupted runs can resume
        if use_beam_utilities and checkpoint_mgr:
            logger.info(f"πŸ”„ Syncing existing checkpoints for {teacher_name}...")
            sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints)
            if enable_training:
                sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)

        # Step 2: Locate an existing base model (locally first, then on Beam)
        existing_base = check_existing_base_model(teacher_name)
        if existing_base:
            logger.info(f"βœ… Found existing base model: {teacher_name}")
        elif use_beam_utilities and sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities):
            existing_base = str(base_dir)

        base_model = None
        if existing_base:
            # The in-memory base model is only needed as the starting point for training.
            if enable_training:
                from distiller.model2vec.model import StaticModel

                base_model = StaticModel.from_pretrained(existing_base)
        else:
            logger.info(f"πŸ”„ Creating base model for {teacher_name}")
            if workaround_type == "salesforce":
                base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims)
            elif workaround_type == "baai":
                base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims)
            else:
                base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)

            if base_model is None:
                error = (
                    "Simple distillation failed"
                    if workaround_type is None
                    else f"Specialized {workaround_type} distillation failed"
                )
                return _result("failed_base_distillation", error=error)

            # Sync base model and checkpoints to Beam
            if use_beam_utilities:
                sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities)
                if checkpoint_mgr:
                    sync_checkpoints_to_beam(
                        VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints
                    )
            existing_base = str(base_dir)

        # Step 3: Produce the final model
        if enable_training and base_model is not None:
            logger.info(f"πŸ§ͺ Starting tokenlearn training for {teacher_name}")
            teacher_st_model = None
            try:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                teacher_st_model = load_model_with_flash_attention(teacher_model, device)

                # Perform tokenlearn training (POTION approach)
                final_model = tokenlearn_training(
                    base_model,
                    teacher_st_model,
                    checkpoint_mgr,
                )

                final_dir.mkdir(parents=True, exist_ok=True)
                final_model.save_pretrained(str(final_dir))

                # Sync final model and training checkpoints to Beam
                if use_beam_utilities:
                    sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
                    if checkpoint_mgr:
                        sync_checkpoints_to_beam(
                            VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
                        )
            except RuntimeError as e:
                # The base model was created successfully; only the training step failed.
                logger.exception(f"❌ Training failed for {teacher_name}")
                return _result("failed_training", error=f"Training failed: {e!s}", base_path=existing_base)
            finally:
                # Release the teacher's weights on every exit path (success, training
                # failure, or any other exception) so GPU memory does not leak.
                teacher_st_model = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        else:
            # Copy base to final (no training)
            logger.info(f"πŸ“ Copying base to final for {teacher_name}")
            if not copy_base_to_final(teacher_name, enable_training):
                return _result("failed_copy_to_final", error="Failed to copy base to final")

        return _result(
            "success",
            enable_training=enable_training,
            base_path=existing_base,
            final_path=str(final_dir),
        )

    except Exception as e:
        logger.exception(f"❌ Failed to process {teacher_model}")
        return _result("failed", error=str(e))
# =============================================================================
# MAIN EXECUTION FUNCTIONS
# =============================================================================
def run_local_distillation(
    teacher_models: list[str] | None = None,
    enable_training: bool = False,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """
    Run the distillation workflow on the local machine.

    Each teacher is processed in turn with ``distill_single_teacher`` (Beam
    utilities disabled). A JSON summary is written to
    ``<parent of LOCAL_BASE_DIR>/distillation_results.json``.

    Args:
        teacher_models: Teacher model ids to distill. ``None`` or an empty list
            selects ``DEFAULT_TEACHER_MODELS``.
        enable_training: Whether to run tokenlearn training after distillation.
        pca_dims: PCA dimensions forwarded to the distillation step.
        clear_cache: Clear the HuggingFace cache of known-problematic teachers
            before processing.

    Returns:
        The results summary that is also written to disk.
    """
    logger.info("πŸ–₯️ Running distillation locally")

    # Resolve the model list exactly once so logging, cache clearing, processing
    # and "total_attempted" all agree (an empty list means "use the defaults").
    models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS

    logger.info("πŸš€ Starting distillation workflow")
    logger.info(f"πŸ“Š Processing {len(models_to_distill)} teacher models")
    logger.info(f"πŸŽ“ Training enabled: {enable_training}")
    logger.info(f"πŸ“Š Teacher models to process: {len(models_to_distill)}")
    for i, model in enumerate(models_to_distill, 1):
        logger.info(f" {i}. {model}")

    # Clear cache for problematic models if requested
    if clear_cache:
        logger.info("🧹 Clearing cache for known problematic models...")
        problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
        for model in problematic_models:
            if model in models_to_distill:
                clear_model_cache(model)

    # Note: tokenlearn checkpoint clearing is handled at the CLI (main) level.
    results: dict[str, dict[str, Any]] = {}
    successful_models: list[str] = []
    failed_models: list[str] = []
    for teacher_model in models_to_distill:
        result = distill_single_teacher(
            teacher_model=teacher_model,
            enable_training=enable_training,
            use_beam_utilities=False,
            pca_dims=pca_dims,
        )
        teacher_name = result["teacher_name"]
        results[teacher_name] = result
        status = result["status"]
        if status == "success" or status.startswith("skipped"):
            successful_models.append(teacher_name)
        else:
            failed_models.append(teacher_name)
            if status == "failed_training":
                # Training failed but base model may still be available
                logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")

    # Summary
    logger.info("\nπŸ† DISTILLATION WORKFLOW COMPLETE!")
    logger.info(f"πŸ“Š Successful models: {len(successful_models)}")
    logger.info(f"πŸŽ“ Training mode: {'Enabled' if enable_training else 'Basic distillation only'}")
    for model_name in successful_models:
        logger.info(f"βœ… {model_name}: {results[model_name]['teacher_model']}")
    for model_name in failed_models:
        failed = results[model_name]
        logger.warning(f"❌ {model_name}: {failed['status']} - {failed.get('error', 'unknown error')}")

    results_summary = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "enable_training": enable_training,
        "successful_models": successful_models,
        "all_results": results,
        "total_successful": len(successful_models),
        "total_attempted": len(models_to_distill),
    }

    # Save results to file
    results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json"
    results_file.parent.mkdir(parents=True, exist_ok=True)
    with results_file.open("w", encoding="utf-8") as f:
        json.dump(results_summary, f, indent=2)
    logger.info(f"πŸ“Š Results summary saved to: {results_file}")

    return results_summary
def _beam_distill_internal(
    teacher_models: list[str] | None = None,
    enable_training: bool = False,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """
    Shared implementation behind the Beam distillation and training functions.

    Processes each teacher with ``distill_single_teacher`` (Beam utilities
    enabled) and writes a JSON summary to ``distillation_results.json`` at the
    root of the mounted Beam volume.

    Args:
        teacher_models: Teacher model ids. ``None`` or an empty list selects
            ``DEFAULT_TEACHER_MODELS``.
        enable_training: Whether to run tokenlearn training after distillation.
        pca_dims: PCA dimensions forwarded to the distillation step.
        clear_cache: Clear the HuggingFace cache of known-problematic teachers
            (on the Beam worker) before processing.

    Returns:
        The results summary that is also written to the volume.
    """
    # Resolve the model list once. Previously an empty list skipped cache
    # clearing (it was checked against []) but then processed the defaults.
    models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS

    # Clear cache for problematic models if requested
    if clear_cache:
        logger.info("🧹 Clearing cache for known problematic models...")
        problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"]
        for model in problematic_models:
            if model in models_to_distill:
                clear_model_cache(model)

    logger.info("πŸš€ Starting Beam distillation workflow")
    logger.info(f"πŸ“Š Processing {len(models_to_distill)} teacher models")
    logger.info(f"πŸŽ“ Training enabled: {enable_training}")
    logger.info(f"πŸ“Š Teacher models to process: {len(models_to_distill)}")
    for i, model in enumerate(models_to_distill, 1):
        logger.info(f" {i}. {model}")

    results: dict[str, dict[str, Any]] = {}
    successful_models: list[str] = []
    failed_models: list[str] = []
    for teacher_model in models_to_distill:
        result = distill_single_teacher(
            teacher_model=teacher_model,
            enable_training=enable_training,
            use_beam_utilities=True,
            pca_dims=pca_dims,
        )
        teacher_name = result["teacher_name"]
        results[teacher_name] = result
        status = result["status"]
        if status == "success" or status.startswith("skipped"):
            successful_models.append(teacher_name)
        else:
            failed_models.append(teacher_name)
            if status == "failed_training":
                # Training failed but base model may still be available
                logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")

    # Summary
    logger.info("\nπŸ† BEAM DISTILLATION WORKFLOW COMPLETE!")
    logger.info(f"πŸ“Š Successful models: {len(successful_models)}")
    for model_name in failed_models:
        failed = results[model_name]
        logger.warning(f"❌ {model_name}: {failed['status']} - {failed.get('error', 'unknown error')}")

    results_summary = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "enable_training": enable_training,
        "successful_models": successful_models,
        "all_results": results,
        "total_successful": len(successful_models),
        "total_attempted": len(models_to_distill),
    }

    # Save results to Beam volume
    results_file = Path(VOLUME_CONFIG.mount_path) / "distillation_results.json"
    with results_file.open("w", encoding="utf-8") as f:
        json.dump(results_summary, f, indent=2)
    logger.info(f"πŸ“Š Beam results saved to: {results_file}")

    return results_summary
@function(**get_training_function_kwargs())
def _beam_train_models(
    teacher_models: list[str] | None = None,
    enable_training: bool = True,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """
    Beam entry point for distillation followed by tokenlearn training.

    Runs with the training function configuration and delegates all work to
    ``_beam_distill_internal``.
    """
    logger.info("☁️ Running training on Beam")
    return _beam_distill_internal(
        teacher_models=teacher_models,
        enable_training=enable_training,
        pca_dims=pca_dims,
        clear_cache=clear_cache,
    )
@function(**get_distillation_function_kwargs())
def _beam_distill_models(
    teacher_models: list[str] | None = None,
    enable_training: bool = False,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """
    Beam entry point for basic distillation only.

    Runs with the distillation function configuration and delegates all work
    to ``_beam_distill_internal``.
    """
    logger.info("☁️ Running distillation on Beam")
    return _beam_distill_internal(
        teacher_models=teacher_models,
        enable_training=enable_training,
        pca_dims=pca_dims,
        clear_cache=clear_cache,
    )
def run_beam_distillation(
    teacher_models: list[str] | None = None,
    enable_training: bool = False,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """
    Run distillation on Beam, then mirror the resulting models locally.

    The training or distillation-only Beam function is chosen from
    ``enable_training``. For every successful teacher, the base model is synced
    back into ``LOCAL_BASE_DIR``. The final model is then either synced from Beam
    (training) or produced locally by copying the base model (no training).

    Args:
        teacher_models: Teacher model ids (``None`` lets the Beam side use defaults).
        enable_training: Whether to run tokenlearn training.
        pca_dims: PCA dimensions forwarded to distillation.
        clear_cache: Clear problematic model caches on the Beam worker.

    Returns:
        The results summary from Beam, or a failure summary with an ``error``
        key if the remote call failed or returned nothing.
    """
    logger.info("☁️ Running distillation on Beam with local sync")

    def _failure(error: str) -> dict[str, Any]:
        """Build an empty results summary describing a failed Beam run."""
        return {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "enable_training": enable_training,
            "successful_models": [],
            "all_results": {},
            "total_successful": 0,
            "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS),
            "error": error,
        }

    try:
        # Choose appropriate beam function based on training flag
        beam_function = _beam_train_models if enable_training else _beam_distill_models
        results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache)

        if not results:
            logger.error("❌ Beam execution failed or returned no results")
            return _failure("Beam execution failed")

        if results.get("successful_models"):
            logger.info("πŸ“₯ Syncing models from Beam to local directories...")
            failed_syncs: list[str] = []
            for teacher_name in results["successful_models"]:
                base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}"
                base_ok = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True)

                if enable_training:
                    # Must match the naming in distill_single_teacher: trained models
                    # live under a "_fine_tuned" suffix, so they neither overwrite the
                    # untrained final model nor go unnoticed by check_existing_final_model.
                    final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}_fine_tuned"
                    final_ok = sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True)
                else:
                    # Copy base to final
                    final_ok = copy_base_to_final(teacher_name, enable_training)

                if not (base_ok and final_ok):
                    failed_syncs.append(teacher_name)

            if failed_syncs:
                logger.warning(f"⚠️ Some models could not be synced from Beam: {', '.join(failed_syncs)}")
            else:
                logger.info("βœ… All models synced from Beam")

        return results

    except Exception as e:
        logger.exception("❌ Beam distillation failed with exception")
        return _failure(str(e))
# =============================================================================
# CLI INTERFACE
# =============================================================================
def main(
    use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False,
    train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
    teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
    pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
    clear_cache: Annotated[
        bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
    ] = False,
    clear_checkpoints: Annotated[
        bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
    ] = False,
    use_optimized_dataset: Annotated[
        bool,
        typer.Option(
            "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
        ),
    ] = False,
    dataset_path: Annotated[
        str | None,
        typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
    ] = None,
) -> None:
    """
    Unified distillation command with optional training.

    Configures the dataset source and optionally clears local tokenlearn
    checkpoints. It then runs the workflow locally or on Beam.

    HuggingFace cache clearing (``--clear-cache``) is delegated to the selected
    runner, which clears the cache where the models are actually loaded: the
    local machine for local runs, the Beam worker for Beam runs. This avoids
    clearing the local cache twice.
    """
    logger.info("πŸš€ Starting unified Model2Vec distillation workflow")

    # Set dataset configuration (read by the training pipeline)
    distillation_config.use_optimized_dataset = use_optimized_dataset
    distillation_config.custom_dataset_path = dataset_path

    if train:
        if use_optimized_dataset:
            dataset_source = dataset_path or "code_model2vec/dataset"
            logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
        else:
            logger.info("🎯 Using C4 dataset for training (following POTION approach)")

    logger.info(f"πŸŽ“ Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
    logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")

    # Use default models if none specified
    models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS
    logger.info(f"πŸ“Š Teacher models to process: {len(models_to_distill)}")
    for i, model in enumerate(models_to_distill, 1):
        logger.info(f" {i}. {model}")

    # Clear tokenlearn checkpoints if requested (for training mode)
    if clear_checkpoints and train:
        logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
        tokenlearn_root = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache"
        for teacher_model in models_to_distill:
            # Use the same persistent directory structure as the training function
            teacher_slug = teacher_model.replace("/", "_").replace("-", "_")
            persistent_tokenlearn_dir = tokenlearn_root / teacher_slug
            features_dir = persistent_tokenlearn_dir / "features"
            trained_dir = persistent_tokenlearn_dir / "trained_model"

            if features_dir.exists() or trained_dir.exists():
                clear_tokenlearn_checkpoints(features_dir, trained_dir)
                logger.info(f"πŸ—‘οΈ Cleared persistent tokenlearn checkpoints for {teacher_model}")
            else:
                logger.info(f"ℹ️ No tokenlearn checkpoints found for {teacher_model}")
    elif clear_checkpoints:
        logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")

    # Run distillation workflow (each runner handles --clear-cache itself)
    runner = run_beam_distillation if use_beam else run_local_distillation
    results = runner(
        teacher_models=models_to_distill,
        enable_training=train,
        pca_dims=pca_dims,
        clear_cache=clear_cache,
    )

    # Handle case where results might be None or invalid
    if not results or not isinstance(results, dict):
        logger.error("❌ Distillation workflow failed - no valid results returned")
        results = {
            "total_successful": 0,
            "total_attempted": len(models_to_distill),
            "error": "Workflow failed",
        }

    # Final summary
    successful_count = results.get("total_successful", 0)
    total_attempted = results.get("total_attempted", 0)

    logger.info("\nπŸŽ‰ UNIFIED DISTILLATION WORKFLOW COMPLETED!")
    logger.info(f"πŸ“Š Successfully processed: {successful_count}/{total_attempted} models")
    logger.info(f"πŸ“ Base models saved to: {LOCAL_BASE_DIR}")
    logger.info(f"πŸ“ Final models saved to: {LOCAL_FINAL_DIR}")

    if train:
        logger.info("πŸŽ“ Advanced training was enabled - models include CodeSearchNet specialization")
    else:
        logger.info("πŸ“– Basic distillation only - use --train flag to enable advanced training")
def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]:
"""
Check if a model has known compatibility issues with Model2Vec.
Returns:
Tuple of (is_compatible, warning_message)
"""
known_incompatible = {
"BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute",
"jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies",
"Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors",
}
if teacher_model in known_incompatible:
return False, known_incompatible[teacher_model]
# Check for model families that might have issues
if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower():
return False, "BGE models with Qwen2 tokenizers may have compatibility issues"
if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower():
return False, "Jina embeddings v3 models may have missing dependencies"
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
return False, "Salesforce SFR embedding models may have device placement issues"
return True, None
def clear_model_cache(model_name: str) -> bool:
"""Clear HuggingFace cache for a specific model."""
try:
import shutil
from pathlib import Path
# Get HuggingFace cache directory
cache_dir = Path.home() / ".cache" / "huggingface"
# Find model-specific cache directories
model_slug = model_name.replace("/", "--")
# Clear transformers cache
transformers_cache = cache_dir / "transformers" / model_slug
if transformers_cache.exists():
shutil.rmtree(transformers_cache)
logger.info(f"πŸ—‘οΈ Cleared transformers cache for {model_name}")
# Clear hub cache
hub_cache = cache_dir / "hub" / f"models--{model_slug}"
if hub_cache.exists():
shutil.rmtree(hub_cache)
logger.info(f"πŸ—‘οΈ Cleared hub cache for {model_name}")
# Clear modules cache
modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0]
if modules_cache.exists():
shutil.rmtree(modules_cache)
logger.info(f"πŸ—‘οΈ Cleared modules cache for {model_name}")
return True
except Exception as e:
logger.warning(f"Failed to clear cache for {model_name}: {e}")
return False
def try_model_workarounds(teacher_model: str) -> str | None:
"""
Try specific workarounds for problematic models.
Returns:
The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available
"""
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower():
logger.info("πŸ”§ Salesforce SFR model detected - will use specialized distillation")
return "salesforce"
if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()):
logger.info("πŸ”§ BAAI BGE model detected - will use specialized distillation")
return "baai"
return None
def salesforce_model_distillation(
    teacher_model: str,
    output_dir: str,
    pca_dims: int | None = None,
) -> Any:
    """
    Distill a Salesforce SFR embedding model, working around device placement issues.

    Two loading strategies are tried in order:

    1. ``transformers.AutoModel`` with real (materialised) weights, loaded in
       float16 on CUDA or float32 on CPU and moved to that device.
    2. ``load_model_with_flash_attention`` on CPU (moved to CUDA when available),
       extracting the underlying transformer and tokenizer.

    The loaded model is distilled with ``distill_from_model`` using the configured
    Zipf/SIF settings and saved to ``output_dir``.

    Args:
        teacher_model: HuggingFace id of the Salesforce model.
        output_dir: Directory to save the distilled model to (created if missing).
        pca_dims: PCA dimensions. ``None`` uses ``distillation_config.optimal_pca_dims``.

    Returns:
        The distilled model, or ``None`` if loading or distillation failed (the
        error is logged).
    """
    if pca_dims is None:
        pca_dims = int(distillation_config.optimal_pca_dims)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"πŸ”„ Salesforce-specific distillation: {teacher_model} β†’ {output_dir}")
    logger.info(f"πŸ“Š PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")

    start_time = time.time()
    sentence_model = None
    model = None

    try:
        from transformers import AutoModel, AutoTokenizer

        logger.info("πŸ”§ Loading model with enhanced device settings...")
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        dtype = torch.float16 if use_cuda else torch.float32

        # Method 1: plain AutoModel load with real weights.
        # The previous implementation loaded onto the meta device and then called
        # Module.to_empty(). That allocates *uninitialised* storage on the target
        # device without copying any weights, so the distilled embeddings were
        # computed from garbage.
        try:
            logger.info("πŸ”„ Attempting direct AutoModel load with materialized weights...")
            tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
            model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True, torch_dtype=dtype)
            if any(param.is_meta for param in model.parameters()):
                msg = "model parameters are still on the meta device after loading"
                raise RuntimeError(msg)
            model = model.to(device)
            logger.info("βœ… Successfully loaded with AutoModel method")
        except Exception as e:
            logger.warning(f"AutoModel loading failed: {e}")
            model = None

            # Method 2: SentenceTransformer loaded on CPU first, then moved to GPU
            logger.info("πŸ”„ Falling back to SentenceTransformer method...")
            sentence_model = load_model_with_flash_attention(
                teacher_model,
                device="cpu",  # Force CPU loading first
            )
            if use_cuda:
                sentence_model = sentence_model.to("cuda")

            model = sentence_model[0].auto_model
            tokenizer = sentence_model.tokenizer
            logger.info("βœ… Successfully loaded with SentenceTransformer method")

        from distiller.model2vec.distill.distillation import distill_from_model

        distilled_model = distill_from_model(
            model=model,
            tokenizer=tokenizer,
            pca_dims=int(pca_dims),
            apply_zipf=bool(distillation_config.apply_zipf),
            sif_coefficient=float(distillation_config.sif_coefficient),
        )
        logger.info("βœ… Core distillation completed successfully")

        distilled_model.save_pretrained(str(output_path))
        logger.info(f"πŸ’Ύ Model saved to {output_path}")

        logger.info(f"Model type: {type(distilled_model)}")
        if hasattr(distilled_model, "embedding"):
            logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
            logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")

        total_time = time.time() - start_time
        logger.info(f"πŸŽ‰ Salesforce distillation completed in {total_time:.2f} seconds")
        return distilled_model

    except Exception:
        logger.exception(f"❌ Salesforce-specific distillation failed for {teacher_model}")
        return None
    finally:
        # Drop the teacher's weights and release cached GPU memory on every exit path.
        sentence_model = None
        model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def baai_bge_model_distillation(
    teacher_model: str,
    output_dir: str,
    pca_dims: int | None = None,
) -> Any:
    """
    Distill a BAAI BGE model, working around Qwen2Tokenizer compatibility issues.

    Two loading strategies are tried in order:

    1. ``load_model_with_flash_attention`` (SentenceTransformer wrapper). The
       tokenizer is smoke-tested with a single ``encode`` call.
    2. Direct ``AutoModel`` / ``AutoTokenizer`` loading with ``trust_remote_code``.

    Only the model's own tokenizer is ever used. Substituting an unrelated
    tokenizer (the former ``bert-base-uncased`` fallback) maps ids from a different
    vocabulary onto the model's embeddings and silently yields meaningless vectors.
    If neither strategy works, ``None`` is returned instead.

    Args:
        teacher_model: HuggingFace id of the BAAI model.
        output_dir: Directory to save the distilled model to (created if missing).
        pca_dims: PCA dimensions. ``None`` uses ``distillation_config.optimal_pca_dims``.

    Returns:
        The distilled model, or ``None`` if loading or distillation failed (the
        error is logged).
    """
    if pca_dims is None:
        pca_dims = int(distillation_config.optimal_pca_dims)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"πŸ”„ BAAI BGE-specific distillation: {teacher_model} β†’ {output_dir}")
    logger.info(f"πŸ“Š PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}")

    start_time = time.time()
    sentence_model = None
    model = None
    tokenizer = None

    try:
        from transformers import AutoModel, AutoTokenizer

        logger.info("πŸ”§ Loading BAAI model with tokenizer workaround...")

        # Method 1: SentenceTransformer often handles tokenizer issues better
        try:
            logger.info("πŸ”„ Attempting with SentenceTransformer wrapper...")
            sentence_model = load_model_with_flash_attention(teacher_model)
            model = sentence_model[0].auto_model
            tokenizer = sentence_model.tokenizer
            # Smoke-test the tokenizer: this is where Qwen2-based tokenizers tend to fail.
            tokenizer.encode("test", return_tensors="pt")
            logger.info("βœ… SentenceTransformer method successful")
        except Exception as e:
            logger.warning(f"SentenceTransformer method failed: {e}")
            sentence_model = None
            model = None
            tokenizer = None

            # Method 2: direct loading with the model's own tokenizer
            try:
                logger.info("πŸ”„ Attempting direct loading with the model's own tokenizer...")
                model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True)
                logger.info("βœ… Direct loading method successful")
            except Exception as e2:
                logger.warning(f"Direct loading method failed: {e2}")
                model = None
                tokenizer = None

        if model is None or tokenizer is None:
            logger.error("❌ All BAAI model loading methods failed")
            return None

        from distiller.model2vec.distill.distillation import distill_from_model

        distilled_model = distill_from_model(
            model=model,
            tokenizer=tokenizer,
            pca_dims=int(pca_dims),
            apply_zipf=bool(distillation_config.apply_zipf),
            sif_coefficient=float(distillation_config.sif_coefficient),
        )
        logger.info("βœ… Core distillation completed successfully")

        distilled_model.save_pretrained(str(output_path))
        logger.info(f"πŸ’Ύ Model saved to {output_path}")

        logger.info(f"Model type: {type(distilled_model)}")
        if hasattr(distilled_model, "embedding"):
            logger.info(f"Embedding shape: {distilled_model.embedding.shape}")
            logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}")

        total_time = time.time() - start_time
        logger.info(f"πŸŽ‰ BAAI BGE distillation completed in {total_time:.2f} seconds")
        return distilled_model

    except Exception:
        logger.exception(f"❌ BAAI BGE-specific distillation failed for {teacher_model}")
        return None
    finally:
        # Drop the teacher's weights and release cached GPU memory on every exit path.
        sentence_model = None
        model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None:
    """
    Remove the tokenlearn completion markers so both pipeline steps run again.

    Deletes ``features_dir/.featurization_complete`` and ``trained_dir/.training_complete``
    when they exist and logs each removal. Markers that are absent are ignored.

    Args:
        features_dir: Directory holding the featurization marker.
        trained_dir: Directory holding the training marker.
    """
    step_markers = (
        ("featurization", features_dir / ".featurization_complete"),
        ("training", trained_dir / ".training_complete"),
    )
    for step, marker in step_markers:
        if not marker.exists():
            continue
        marker.unlink()
        logger.info(f"πŸ—‘οΈ Cleared {step} checkpoint: {marker}")
def verify_featurization_output(features_dir: Path) -> bool:
    """
    Check whether tokenlearn featurization appears to have produced output.

    Only file presence is checked: the directory must exist and directly contain at
    least one entry matching ``*.npy``, ``*.json``, ``*.pt`` or ``*.pkl``. File contents
    are not validated.

    Args:
        features_dir: Directory the featurization step writes into.

    Returns:
        ``True`` if at least one matching entry is found, ``False`` otherwise.
    """
    if not features_dir.exists():
        return False
    # ``Path.glob`` is lazy: stop at the first hit instead of materializing a list of
    # every feature shard for each pattern.
    return any(
        next(features_dir.glob(pattern), None) is not None
        for pattern in ("*.npy", "*.json", "*.pt", "*.pkl")
    )
def verify_training_output(trained_dir: Path) -> bool:
    """
    Check whether tokenlearn training appears to have written a model.

    Looks for any of ``config.json``, ``model.safetensors``, ``modules.json`` or
    ``tokenizer.json`` directly inside ``trained_dir``, or inside its ``model`` or
    ``model_weighted`` subdirectories. Only file presence is checked, not validity.

    Args:
        trained_dir: Output directory of the training step.

    Returns:
        ``True`` if any expected file exists in any searched location, otherwise
        ``False`` (including when ``trained_dir`` itself is missing).
    """
    if not trained_dir.exists():
        return False

    expected_names = ("config.json", "model.safetensors", "modules.json", "tokenizer.json")
    search_roots = (trained_dir, trained_dir / "model", trained_dir / "model_weighted")
    # A missing subdirectory simply yields no hits, so it needs no separate check.
    return any((root / name).exists() for root in search_roots for name in expected_names)
def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
    """
    Pick and prepare the dataset tokenlearn featurization will read.

    Delegates to the custom (optimized) dataset preparation when
    ``distillation_config.use_optimized_dataset`` is truthy, and to the original
    configured dataset otherwise.

    Args:
        tokenlearn_dir: Working directory for tokenlearn artifacts (only used by the
            custom-dataset path).

    Returns:
        ``(dataset_path, dataset_name, text_key)`` as produced by the chosen helper.
    """
    if not distillation_config.use_optimized_dataset:
        return _prepare_original_dataset_for_tokenlearn()
    return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir)
def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
    """
    Write the optimized code dataset as a JSON Lines file for tokenlearn.

    The optimized dataset lives in ``distillation_config.custom_dataset_path`` when that
    is set, otherwise in ``code_model2vec/dataset``. If its ``train.parquet`` is missing,
    the dataset is generated first with ``create_optimized_dataset``. The ``text`` column
    of the train split is then written, one ``{"text": ...}`` object per line, to
    ``<tokenlearn_dir>/custom_dataset/train.json``. That file is rewritten on every call.

    Args:
        tokenlearn_dir: Working directory for tokenlearn artifacts.

    Returns:
        ``(path_to_train_json, None, "text")``: the JSON file path, no dataset config
        name, and the text column key.
    """
    logger.info("🎯 Preparing custom optimized dataset for tokenlearn...")

    from .dataset import create_optimized_dataset, load_optimized_dataset

    configured_path = distillation_config.custom_dataset_path
    source_dir = Path(configured_path) if configured_path else Path("code_model2vec/dataset")
    json_dir = tokenlearn_dir / "custom_dataset"

    # An existing train.parquet implies the directory exists, so checking the file suffices.
    if not (source_dir / "train.parquet").exists():
        logger.info("πŸ“Š Custom dataset not found - creating optimized dataset...")
        # NOTE(review): the per-language budget divides by a hard-coded 6, intended as the
        # number of languages. Keep it in sync with the configured language list.
        create_optimized_dataset(
            max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6,
            output_dir=source_dir,
            create_multiple_formats=False,  # the simple format is all tokenlearn needs
        )

    logger.info(f"πŸ“‚ Loading custom dataset from {source_dir}")
    train_df = load_optimized_dataset(output_dir=source_dir, split="train")

    json_dir.mkdir(parents=True, exist_ok=True)
    json_path = json_dir / "train.json"

    # JSON Lines layout (one object per row) so tokenlearn can read it via load_dataset().
    with json_path.open("w") as handle:
        for text in train_df["text"]:
            json.dump({"text": text}, handle)
            handle.write("\n")

    logger.info(f"βœ… Prepared custom dataset with {len(train_df)} samples for tokenlearn")
    logger.info(f"πŸ’Ύ Saved JSON dataset to {json_path}")

    # The file path itself (not its directory) is returned, with no config name.
    return str(json_path), None, "text"
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]:
    """
    Return the configured dataset coordinates for tokenlearn featurization.

    Nothing is downloaded or written here; the values come straight from
    ``distillation_config`` (per the log message, C4 is the default, following the
    POTION approach, presumably ``allenai/c4`` / ``en`` / ``text``; confirm in config).

    Returns:
        ``(dataset_path, dataset_name, text_key)``. ``dataset_name`` is ``None`` when no
        dataset config name is configured, rather than the string ``"None"``.
    """
    logger.info("πŸ“Š Using C4 dataset for tokenlearn (following POTION approach)...")
    dataset_name = distillation_config.tokenlearn_dataset_name
    return (
        str(distillation_config.tokenlearn_dataset),
        # str(None) would yield the bogus config name "None"; keep None as "no config",
        # matching the return annotation and the custom-dataset path.
        None if dataset_name is None else str(dataset_name),
        str(distillation_config.tokenlearn_text_key),
    )
if __name__ == "__main__":
    # Script entry point: expose ``main`` as a single-command Typer CLI.
    typer.run(main)