"""
Unified Code-Specialized Model2Vec Distillation Script.

This script provides a unified approach for creating code-specialized embeddings
using Model2Vec distillation with optional code-specific training.

Features:
- Basic distillation (default): Simple Model2Vec distillation
- Advanced training (--train flag): Additional CodeSearchNet fine-tuning
- Checkpoint support with Beam sync utilities
- Multi-teacher model processing
- Smart resume capabilities
- Hierarchical storage: base → final

Directory Structure:
- code_model2vec/base: Basic distilled models (first step)
- code_model2vec/final: Final models (copied from base or after training)

Usage:
    distiller distill [--use-beam] [--train]  # Basic distillation or with training
"""
|
|
|
import importlib.util |
|
import json |
|
import logging |
|
import os |
|
import time |
|
from pathlib import Path |
|
from typing import Annotated, Any |
|
|
|
import torch |
|
import typer |
|
from beam import function |
|
from sentence_transformers import SentenceTransformer |
|
|
|
from distiller.model2vec.distill import distill |
|
|
|
|
|
from .beam_utils import ( |
|
BeamCheckpointManager, |
|
create_beam_utilities, |
|
download_model_from_beam, |
|
sync_checkpoints_from_beam, |
|
sync_checkpoints_to_beam, |
|
upload_model_to_beam, |
|
) |
|
from .config import ( |
|
codesearchnet_config, |
|
directories, |
|
distillation_config, |
|
get_distillation_function_kwargs, |
|
get_training_function_kwargs, |
|
get_volume_config, |
|
languages_config, |
|
) |
|
|
|
|
|
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None |
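# flash_attn is optional: when the package is missing, configure_flash_attention() simply
# returns empty kwargs and models are loaded with standard attention.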
|
|
|
|
|
|
|
|
|
|
|
VOLUME_CONFIG = get_volume_config() |
|
LOCAL_BASE_DIR = directories.base |
|
LOCAL_FINAL_DIR = directories.final |
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_flash_attention() -> dict[str, Any]: |
|
"""Configure flash attention settings and return model kwargs.""" |
|
model_kwargs: dict[str, Any] = {} |
|
|
|
if not FLASH_ATTN_AVAILABLE: |
|
logger.info("β οΈ Flash attention not available - using standard attention") |
|
return model_kwargs |
|
|
|
|
|
os.environ["FLASH_ATTENTION_FORCE_USE"] = "1" |
|
|
|
os.environ["TORCH_COMPILE_DISABLE"] = "1" |
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
try: |
|
if torch.cuda.is_available(): |
|
device_capability = torch.cuda.get_device_capability() |
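            # SentenceTransformer forwards the nested "model_kwargs" dict below to
            # AutoModel.from_pretrained, which is where attn_implementation takes effect.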
|
|
|
            if device_capability >= (7, 5):  # compute capability 7.5 (Turing) or newer
|
logger.info("β
Flash attention enabled - compatible GPU detected") |
|
model_kwargs.update( |
|
{ |
|
"model_kwargs": { |
|
"attn_implementation": "flash_attention_2", |
|
"torch_dtype": torch.float16, |
|
"use_flash_attention_2": True, |
|
"_attn_implementation": "flash_attention_2", |
|
} |
|
} |
|
) |
|
else: |
|
logger.info(f"β οΈ GPU compute capability {device_capability} < 7.5 - flash attention disabled") |
|
else: |
|
logger.info("β οΈ No CUDA available - flash attention disabled") |
|
except Exception as e: |
|
logger.warning(f"β οΈ Failed to check GPU compatibility: {e} - flash attention disabled") |
|
|
|
return model_kwargs |
|
|
|
|
|
def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer: |
|
"""Load a SentenceTransformer model with flash attention if available.""" |
|
flash_kwargs = configure_flash_attention() |
|
|
|
try: |
|
|
|
if flash_kwargs and "model_kwargs" in flash_kwargs: |
|
logger.info(f"π Loading model with flash attention: {Path(model_path).name}") |
|
model = SentenceTransformer(model_path, device=device, trust_remote_code=True, **flash_kwargs) |
|
logger.info("β
Model loaded successfully with flash attention") |
|
return model |
|
except Exception as e: |
|
logger.warning(f"β οΈ Failed to load with flash attention: {e}") |
|
logger.info("π Falling back to standard attention") |
|
|
|
|
|
logger.info(f"π Loading model with standard attention: {Path(model_path).name}") |
|
model = SentenceTransformer(model_path, device=device, trust_remote_code=True) |
|
logger.info("β
Model loaded successfully with standard attention") |
|
return model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_config_hash(enable_training: bool) -> str: |
|
"""Generate a hash of current configuration parameters for checkpoint validation.""" |
|
import hashlib |
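    # The digest returned here is stored with checkpoints and the embeddings cache
    # (e.g. embeddings_config.json) and compared when those artifacts are reloaded,
    # so changing any parameter below invalidates stale caches.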
|
|
|
config_params = { |
|
"pca_dims": distillation_config.optimal_pca_dims, |
|
"sif_coefficient": distillation_config.sif_coefficient, |
|
"apply_zipf": distillation_config.apply_zipf, |
|
"enable_training": enable_training, |
|
} |
|
|
|
    if enable_training:
        # Built-in hash() is randomized per process, so it cannot validate checkpoints
        # across runs; derive a stable digest from the tokenlearn settings instead.
        tokenlearn_id = (
            f"{distillation_config.tokenlearn_dataset}_"
            f"{distillation_config.tokenlearn_dataset_name}_"
            f"{distillation_config.tokenlearn_text_key}"
        )
        config_params["tokenlearn_hash"] = hashlib.md5(tokenlearn_id.encode()).hexdigest()[:12]
|
|
|
config_str = str(sorted(config_params.items())) |
|
return hashlib.md5(config_str.encode()).hexdigest()[:12] |
|
|
|
|
|
def check_existing_base_model(teacher_name: str) -> str | None: |
|
"""Check if base distilled model already exists locally.""" |
|
base_dir = Path(LOCAL_BASE_DIR) |
|
model_dir = base_dir / f"code_model2vec_{teacher_name}" |
|
|
|
if model_dir.exists(): |
|
|
|
has_config = (model_dir / "config.json").exists() |
|
has_model_file = any( |
|
[ |
|
(model_dir / "model.safetensors").exists(), |
|
(model_dir / "model.bin").exists(), |
|
(model_dir / "pytorch_model.bin").exists(), |
|
] |
|
) |
|
|
|
if has_config and has_model_file: |
|
logger.info(f"β
Found existing base model: {teacher_name}") |
|
return str(model_dir) |
|
|
|
return None |
|
|
|
|
|
def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None: |
|
"""Check if final model already exists locally.""" |
|
final_dir = Path(LOCAL_FINAL_DIR) |
|
|
|
|
|
model_name = f"code_model2vec_{teacher_name}" |
|
if enable_training: |
|
model_name += "_fine_tuned" |
|
final_path = final_dir / model_name |
|
|
|
if final_path.exists(): |
|
|
|
has_config = (final_path / "config.json").exists() |
|
has_model_file = any( |
|
[ |
|
(final_path / "model.safetensors").exists(), |
|
(final_path / "model.bin").exists(), |
|
(final_path / "pytorch_model.bin").exists(), |
|
] |
|
) |
|
|
|
if has_config and has_model_file: |
|
logger.info(f"β
Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}") |
|
return str(final_path) |
|
|
|
return None |
|
|
|
|
|
def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool: |
|
"""Copy base model to final directory.""" |
|
import shutil |
|
|
|
base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
|
|
|
|
|
final_model_name = f"code_model2vec_{teacher_name}" |
|
if enable_training: |
|
final_model_name += "_fine_tuned" |
|
final_path = Path(LOCAL_FINAL_DIR) / final_model_name |
|
|
|
try: |
|
final_path.parent.mkdir(parents=True, exist_ok=True) |
|
if final_path.exists(): |
|
shutil.rmtree(final_path) |
|
shutil.copytree(base_path, final_path) |
|
logger.info(f"π Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}") |
|
return True |
|
except Exception: |
|
logger.exception(f"β Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}") |
|
return False |
|
|
|
|
|
def sync_model_from_beam( |
|
teacher_name: str, |
|
target_dir: str, |
|
use_beam_utilities: bool = False, |
|
) -> bool: |
|
"""Sync model from Beam volume to local directory.""" |
|
if not use_beam_utilities: |
|
return False |
|
|
|
try: |
|
target_path = Path(target_dir) |
|
target_path.mkdir(parents=True, exist_ok=True) |
|
|
|
beam_model_name = f"{teacher_name}_model" |
|
success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path)) |
|
|
|
if success: |
|
logger.info(f"π₯ Synced {teacher_name} from Beam to {target_dir}") |
|
return True |
|
logger.warning(f"β οΈ Failed to sync {teacher_name} from Beam") |
|
return False |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to sync {teacher_name} from Beam: {e}") |
|
return False |
|
|
|
|
|
def sync_model_to_beam( |
|
teacher_name: str, |
|
source_dir: str, |
|
use_beam_utilities: bool = False, |
|
) -> bool: |
|
"""Sync model from local directory to Beam volume.""" |
|
if not use_beam_utilities: |
|
return False |
|
|
|
try: |
|
beam_model_name = f"{teacher_name}_model" |
|
success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir) |
|
|
|
if success: |
|
logger.info(f"π€ Synced {teacher_name} to Beam from {source_dir}") |
|
return True |
|
logger.warning(f"β οΈ Failed to sync {teacher_name} to Beam") |
|
return False |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to sync {teacher_name} to Beam: {e}") |
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def simple_distillation( |
|
teacher_model: str, |
|
output_dir: str, |
|
pca_dims: int | None = None, |
|
retry_with_cache_clear: bool = False, |
|
) -> Any: |
|
""" |
|
Perform simple Model2Vec distillation without additional training. |
|
|
|
Args: |
|
teacher_model: Name of teacher model |
|
output_dir: Output directory for the distilled model |
|
pca_dims: PCA dimensions (uses config default if None) |
|
retry_with_cache_clear: Whether this is a retry after clearing cache |
|
|
|
Returns: |
|
Distilled model or None if failed |
|
""" |
|
if pca_dims is None: |
|
pca_dims = int(distillation_config.optimal_pca_dims) |
|
|
|
output_path = Path(output_dir) |
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else "" |
|
logger.info(f"π Simple distillation{retry_suffix}: {teacher_model} β {output_dir}") |
|
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
|
|
|
start_time = time.time() |
|
|
|
try: |
|
|
|
model = distill( |
|
model_name=teacher_model, |
|
pca_dims=int(pca_dims), |
|
apply_zipf=bool(distillation_config.apply_zipf), |
|
sif_coefficient=float(distillation_config.sif_coefficient), |
|
trust_remote_code=True, |
|
) |
|
|
|
logger.info("β
Core distillation completed successfully") |
|
|
|
|
|
if hasattr(model, "tokenizer") and hasattr(model, "embedding"): |
|
vocab_size = len(model.tokenizer.get_vocab()) |
|
embedding_size = model.embedding.shape[0] |
|
|
|
logger.info("π Model validation:") |
|
logger.info(f" - Vocabulary size: {vocab_size}") |
|
logger.info(f" - Embedding matrix size: {embedding_size}") |
|
|
|
if vocab_size != embedding_size: |
|
logger.warning(f"β οΈ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}") |
|
logger.warning("β οΈ This may cause issues in downstream usage") |
|
else: |
|
logger.info("β
Vocabulary and embedding sizes match") |
|
|
|
|
|
model.save_pretrained(str(output_path)) |
|
logger.info(f"πΎ Model saved to {output_path}") |
|
|
|
|
|
logger.info(f"Model type: {type(model)}") |
|
if hasattr(model, "embedding"): |
|
logger.info(f"Embedding shape: {model.embedding.shape}") |
|
logger.info(f"Embedding dtype: {model.embedding.dtype}") |
|
|
|
total_time = time.time() - start_time |
|
logger.info(f"π Simple distillation completed in {total_time:.2f} seconds") |
|
return model |
|
|
|
except ValueError as e: |
|
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): |
|
logger.warning(f"β οΈ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue") |
|
logger.warning(f"Error details: {e}") |
|
logger.warning("π‘ This model has incompatible tokenization. Skipping...") |
|
return None |
|
if "weight is on the meta device" in str(e): |
|
logger.warning(f"β οΈ Device placement issue with {teacher_model} - model weights on meta device") |
|
logger.warning(f"Error details: {e}") |
|
logger.warning("π‘ This model has device placement issues. Skipping...") |
|
return None |
|
raise |
|
except AttributeError as e: |
|
if "backend_tokenizer" in str(e): |
|
logger.warning(f"β οΈ Tokenizer compatibility issue with {teacher_model}") |
|
logger.warning(f"Error details: {e}") |
|
logger.warning("π‘ This model's tokenizer is incompatible with Model2Vec. Skipping...") |
|
return None |
|
raise |
|
except FileNotFoundError as e: |
|
if "transformers_modules" in str(e) or "xlm_padding.py" in str(e): |
|
logger.warning(f"β οΈ Missing custom model files for {teacher_model}") |
|
logger.warning(f"Error details: {e}") |
|
|
|
|
|
if not retry_with_cache_clear: |
|
logger.info("π§ Attempting to clear cache and retry...") |
|
if clear_model_cache(teacher_model): |
|
logger.info("π Retrying distillation after cache clear...") |
|
return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True) |
|
|
|
logger.warning("π‘ This model has missing dependencies. Manual intervention may be required.") |
|
return None |
|
raise |
|
except Exception: |
|
logger.exception(f"β Simple distillation failed for {teacher_model}") |
|
return None |
|
|
|
|
|
def load_optimized_dataset( |
|
max_samples: int | None = None, |
|
checkpoint_manager: BeamCheckpointManager | None = None, |
|
dataset_path: str | None = None, |
|
) -> list[str]: |
|
"""Load our pre-created optimized dataset for tokenlearn training.""" |
|
from .dataset import DATASET_OUTPUT_DIR |
|
from .dataset import load_optimized_dataset as load_dataset_func |
|
|
|
|
|
if dataset_path is None: |
|
dataset_path = distillation_config.custom_dataset_path |
|
|
|
dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR |
|
|
|
|
|
if max_samples is None: |
|
max_samples = distillation_config.tokenlearn_max_samples |
|
|
|
logger.info(f"π― Loading optimized dataset from {dataset_dir}") |
|
logger.info(f"π Target samples: {max_samples}") |
|
|
|
try: |
|
|
|
df = load_dataset_func(output_dir=dataset_dir, split="train") |
|
|
|
|
|
texts = df["text"].tolist() |
|
|
|
|
|
import random |
|
|
|
random.seed(42) |
|
random.shuffle(texts) |
|
|
|
|
|
if len(texts) > max_samples: |
|
texts = texts[:max_samples] |
|
|
|
logger.info(f"β
Loaded {len(texts)} optimized training samples") |
|
|
|
|
|
languages = df["language"].value_counts() |
|
logger.info("π Language distribution:") |
|
for lang, count in languages.items(): |
|
percentage = (count / len(df)) * 100 |
|
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)") |
|
|
|
return texts |
|
|
|
except Exception as e: |
|
logger.warning(f"β οΈ Failed to load optimized dataset: {e}") |
|
logger.info("π Falling back to original CodeSearchNet loading...") |
|
return load_codesearchnet_dataset(max_samples, checkpoint_manager) |
|
|
|
|
|
def load_codesearchnet_dataset( |
|
max_samples: int | None = None, |
|
checkpoint_manager: BeamCheckpointManager | None = None, |
|
) -> list[str]: |
|
"""Load and format the CodeSearchNet dataset for token frequency computation.""" |
|
from datasets import load_dataset |
|
|
|
|
|
if max_samples is None: |
|
max_samples = distillation_config.tokenlearn_max_samples |
|
|
|
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}") |
|
logger.info(f"Limiting to {max_samples} samples for training efficiency") |
|
logger.info(f"Languages: {', '.join(languages_config.all)}") |
|
|
|
|
|
texts = [] |
|
start_from = 0 |
|
|
|
if checkpoint_manager: |
|
checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0) |
|
if checkpoint_data: |
|
cached_texts = checkpoint_data.get("data", {}).get("texts", []) |
|
if len(cached_texts) >= max_samples: |
|
logger.info(f"β
Resumed dataset loading: {len(cached_texts)} texts from checkpoint") |
|
return cached_texts[:max_samples] |
|
logger.info(f"π Partial dataset found: {len(cached_texts)} texts, continuing...") |
|
texts = cached_texts |
|
start_from = len(texts) |
|
|
|
try: |
|
|
|
num_languages = len(languages_config.all) |
|
samples_per_language = max_samples // num_languages |
|
remaining_samples = max_samples % num_languages |
|
|
|
logger.info(f"π Target distribution: {samples_per_language} samples per language") |
|
if remaining_samples > 0: |
|
logger.info(f"π Extra {remaining_samples} samples will be distributed to first languages") |
|
|
|
|
|
language_texts: dict[str, list[str]] = {} |
|
total_collected = len(texts) |
|
|
|
for i, language in enumerate(languages_config.all): |
|
if total_collected >= max_samples: |
|
break |
|
|
|
logger.info(f"π Loading {language} training data...") |
|
|
|
|
|
target_for_lang = samples_per_language |
|
if i < remaining_samples: |
|
target_for_lang += 1 |
|
|
|
|
|
if language in language_texts and len(language_texts[language]) >= target_for_lang: |
|
continue |
|
|
|
try: |
|
|
|
from datasets import load_dataset |
|
|
|
dataset = load_dataset( |
|
codesearchnet_config.dataset_name, |
|
language, |
|
split="train", |
|
trust_remote_code=True, |
|
) |
|
|
|
lang_texts: list[str] = [] |
|
processed_count = 0 |
|
|
|
for processed_count, example in enumerate(dataset, 1): |
|
if len(lang_texts) >= target_for_lang: |
|
break |
|
|
|
|
|
doc_string = example.get("func_documentation_string", "").strip() |
|
code_string = example.get("func_code_string", "").strip() |
|
|
|
if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50: |
|
|
|
text = f"Documentation: {doc_string}\nCode:\n{code_string}" |
|
|
|
|
|
if len(text) <= 2048: |
|
lang_texts.append(text) |
|
|
|
if processed_count % 5000 == 0: |
|
logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}") |
|
|
|
language_texts[language] = lang_texts |
|
total_collected += len(lang_texts) |
|
logger.info(f"β
{language}: collected {len(lang_texts)} samples") |
|
|
|
except Exception as e: |
|
logger.warning(f"β οΈ Failed to load {language} data: {e}") |
|
continue |
|
|
|
|
|
combined_texts = [] |
|
|
|
|
|
if start_from > 0: |
|
combined_texts = texts[:start_from] |
|
|
|
|
|
max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0 |
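        # Interleave languages round-robin so that truncating to max_samples keeps the
        # language mix balanced instead of dropping whole languages at the tail.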
|
|
|
for sample_idx in range(max_lang_samples): |
|
for language in languages_config.all: |
|
if len(combined_texts) >= max_samples: |
|
break |
|
|
|
if language in language_texts and sample_idx < len(language_texts[language]): |
|
combined_texts.append(language_texts[language][sample_idx]) |
|
|
|
if len(combined_texts) >= max_samples: |
|
break |
|
|
|
|
|
combined_texts = combined_texts[:max_samples] |
|
|
|
|
|
logger.info("π Final dataset distribution:") |
|
lang_counts: dict[str, int] = {} |
|
for text in combined_texts: |
|
|
|
if "def " in text and ":" in text: |
|
lang_counts["python"] = lang_counts.get("python", 0) + 1 |
|
elif "function " in text and "{" in text: |
|
lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1 |
|
elif "public " in text and "class " in text: |
|
lang_counts["java"] = lang_counts.get("java", 0) + 1 |
|
elif "<?php" in text or "$" in text: |
|
lang_counts["php"] = lang_counts.get("php", 0) + 1 |
|
elif "func " in text and "end" in text: |
|
lang_counts["ruby"] = lang_counts.get("ruby", 0) + 1 |
|
elif "func " in text and "}" in text: |
|
lang_counts["go"] = lang_counts.get("go", 0) + 1 |
|
else: |
|
lang_counts["other"] = lang_counts.get("other", 0) + 1 |
|
|
|
for lang, count in lang_counts.items(): |
|
percentage = (count / len(combined_texts)) * 100 |
|
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)") |
|
|
|
|
|
if checkpoint_manager: |
|
checkpoint_data = { |
|
"config_hash": get_current_config_hash(enable_training=True), |
|
"stage": "dataset", |
|
"step": 0, |
|
"timestamp": time.time(), |
|
"data": {"texts": combined_texts}, |
|
} |
|
checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0) |
|
|
|
logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet") |
|
return combined_texts |
|
|
|
except Exception: |
|
logger.exception("Error loading CodeSearchNet dataset") |
|
return texts |
|
|
|
|
|
def generate_teacher_embeddings( |
|
teacher_model: SentenceTransformer, |
|
texts: list[str], |
|
checkpoint_manager: BeamCheckpointManager | None = None, |
|
) -> torch.Tensor: |
|
"""Generate teacher embeddings for code training with checkpoint support.""" |
|
logger.info(f"Generating teacher embeddings for {len(texts)} texts...") |
|
|
|
|
|
if checkpoint_manager: |
|
volume_path = Path(VOLUME_CONFIG.mount_path) |
|
embeddings_path = volume_path / "embeddings_cache.pt" |
|
config_path = volume_path / "embeddings_config.json" |
|
|
|
if embeddings_path.exists() and config_path.exists(): |
|
try: |
|
|
|
with config_path.open("r") as f: |
|
config_data = json.load(f) |
|
|
|
current_hash = get_current_config_hash(enable_training=True) |
|
if config_data.get("config_hash") == current_hash: |
|
|
|
final_embeddings = torch.load(embeddings_path, map_location="cpu") |
|
num_expected = config_data.get("num_texts", len(texts)) |
|
|
|
if final_embeddings.shape[0] >= num_expected: |
|
logger.info(f"β
Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)") |
|
return final_embeddings[: len(texts)] |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to load embeddings cache: {e}, regenerating...") |
|
|
|
|
|
logger.info("Generating fresh teacher embeddings...") |
|
|
|
batch_size = 16 |
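    # Modest starting batch size; it is halved on CUDA OOM below and the reduced size is
    # kept for all remaining batches.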
|
embeddings_list = [] |
|
|
|
for i in range(0, len(texts), batch_size): |
|
batch_texts = texts[i : i + batch_size] |
|
|
|
try: |
|
batch_embeddings = teacher_model.encode( |
|
batch_texts, |
|
convert_to_tensor=True, |
|
batch_size=batch_size, |
|
show_progress_bar=False, |
|
normalize_embeddings=True, |
|
) |
|
embeddings_list.append(batch_embeddings) |
|
|
|
if i % (batch_size * 10) == 0: |
|
logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts") |
|
|
|
except torch.cuda.OutOfMemoryError: |
|
logger.warning(f"GPU OOM with batch size {batch_size}, reducing...") |
|
torch.cuda.empty_cache() |
|
batch_size = max(1, batch_size // 2) |
|
|
|
|
|
batch_embeddings = teacher_model.encode( |
|
batch_texts, |
|
convert_to_tensor=True, |
|
batch_size=batch_size, |
|
show_progress_bar=False, |
|
normalize_embeddings=True, |
|
) |
|
embeddings_list.append(batch_embeddings) |
|
|
|
|
|
teacher_embeddings = torch.cat(embeddings_list, dim=0) |
|
|
|
|
|
if teacher_embeddings.dtype != torch.float32: |
|
teacher_embeddings = teacher_embeddings.to(torch.float32) |
|
|
|
logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}") |
|
|
|
|
|
if checkpoint_manager: |
|
try: |
|
volume_path = Path(VOLUME_CONFIG.mount_path) |
|
embeddings_path = volume_path / "embeddings_cache.pt" |
|
config_path = volume_path / "embeddings_config.json" |
|
|
|
|
|
torch.save(teacher_embeddings, embeddings_path) |
|
|
|
|
|
config_data = { |
|
"config_hash": get_current_config_hash(enable_training=True), |
|
"num_texts": len(texts), |
|
"embedding_shape": list(teacher_embeddings.shape), |
|
"timestamp": time.time(), |
|
} |
|
|
|
with config_path.open("w") as f: |
|
json.dump(config_data, f, indent=2) |
|
|
|
logger.info("πΎ Saved embeddings cache for future runs") |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to save embeddings cache: {e}") |
|
|
|
return teacher_embeddings |
|
|
|
|
|
def tokenlearn_training( |
|
student_model: Any, |
|
teacher_model: SentenceTransformer, |
|
checkpoint_manager: BeamCheckpointManager | None = None, |
|
) -> Any: |
|
""" |
|
Perform tokenlearn training following the official POTION approach. |
|
|
|
This follows the 4-step process: |
|
1. Model2Vec distillation (already done - student_model) |
|
2. Sentence transformer inference (create features) |
|
3. Tokenlearn training |
|
""" |
|
from pathlib import Path |
|
|
|
logger.info("π§ͺ Starting tokenlearn training (POTION approach)...") |
|
|
|
|
|
teacher_model_name = getattr(teacher_model, "model_name", None) |
|
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: |
|
|
|
first_module = next(iter(teacher_model._modules.values())) |
|
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): |
|
teacher_model_name = first_module.auto_model.name_or_path |
|
|
|
if not teacher_model_name: |
|
teacher_model_name = "unknown_teacher" |
|
|
|
|
|
teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_") |
|
persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug |
|
|
|
features_dir = persistent_tokenlearn_dir / "features" |
|
model_dir = persistent_tokenlearn_dir / "base_model" |
|
trained_dir = persistent_tokenlearn_dir / "trained_model" |
|
|
|
features_dir.mkdir(parents=True, exist_ok=True) |
|
model_dir.mkdir(parents=True, exist_ok=True) |
|
trained_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
logger.info(f"π Using persistent tokenlearn directory: {persistent_tokenlearn_dir}") |
|
|
|
|
|
student_model.save_pretrained(str(model_dir)) |
|
logger.info(f"πΎ Saved base model to {model_dir}") |
|
|
|
|
|
logger.info("π Step 2: Creating features using sentence transformer...") |
|
|
|
|
|
teacher_model_name = getattr(teacher_model, "model_name", None) |
|
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: |
|
|
|
|
|
first_module = next(iter(teacher_model._modules.values())) |
|
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): |
|
teacher_model_name = first_module.auto_model.name_or_path |
|
|
|
logger.info(f"π Using teacher model: {teacher_model_name}") |
|
|
|
|
|
dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir) |
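    # Featurization and training each drop an empty marker file on success; the marker plus
    # a verify_*_output() content check lets re-runs skip these expensive steps.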
|
|
|
|
|
featurization_complete_marker = features_dir / ".featurization_complete" |
|
if featurization_complete_marker.exists() and verify_featurization_output(features_dir): |
|
logger.info("β
Found existing featurization checkpoint with valid output files") |
|
logger.info(f"π Using cached features from: {features_dir}") |
|
|
|
|
|
output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json")) |
|
logger.info(f"π Found {len(output_files)} cached feature files") |
|
else: |
|
if featurization_complete_marker.exists(): |
|
logger.warning("β οΈ Featurization marker exists but output files are missing - re-running featurization") |
|
featurization_complete_marker.unlink() |
|
logger.info("π No valid featurization checkpoint found - starting featurization...") |
|
|
|
if not teacher_model_name: |
|
logger.warning("β οΈ Could not determine teacher model name, using fallback") |
|
teacher_model_name = "BAAI/bge-base-en-v1.5" |
|
|
|
logger.info(f"π Using teacher model: {teacher_model_name}") |
|
|
|
try: |
|
|
|
from datasets import load_dataset |
|
|
|
from distiller.tokenlearn.featurize import featurize |
|
|
|
logger.info("π Running tokenlearn featurization...") |
|
logger.info(f"π Dataset: {dataset_path} (config: {dataset_name})") |
|
logger.info(f"π Text field: {text_key}") |
|
|
|
|
|
if dataset_name is None: |
|
|
|
dataset = load_dataset( |
|
"json", |
|
data_files=dataset_path, |
|
split="train", |
|
streaming=True, |
|
) |
|
else: |
|
|
|
dataset = load_dataset( |
|
dataset_path, |
|
name=dataset_name, |
|
split="train", |
|
streaming=True, |
|
) |
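            # max_means is assumed to cap how many pooled feature vectors tokenlearn collects
            # from the streamed dataset; batch_size sets the teacher's encode batch size.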
|
|
|
|
|
featurize( |
|
dataset=iter(dataset), |
|
model=teacher_model, |
|
output_dir=str(features_dir), |
|
max_means=50000, |
|
batch_size=512, |
|
text_key=text_key, |
|
) |
|
|
|
logger.info("β
Featurization completed successfully") |
|
|
|
|
|
featurization_complete_marker.touch() |
|
logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}") |
|
|
|
except Exception as e: |
|
logger.exception("π₯ Tokenlearn featurization failed") |
|
logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed") |
|
msg = f"Tokenlearn featurization failed: {e}" |
|
raise RuntimeError(msg) from e |
|
|
|
|
|
logger.info("π Step 3: Training using tokenlearn...") |
|
|
|
|
|
training_complete_marker = trained_dir / ".training_complete" |
|
training_fallback_marker = trained_dir / ".training_fallback" |
|
|
|
if training_complete_marker.exists() and verify_training_output(trained_dir): |
|
logger.info("β
Found existing training checkpoint with valid model files") |
|
logger.info(f"π Using cached trained model from: {trained_dir}") |
|
|
|
|
|
model_files = [] |
|
for pattern in ["*.json", "*.safetensors", "*.bin"]: |
|
model_files.extend(list(trained_dir.glob(pattern))) |
|
for subdir in ["model", "model_weighted"]: |
|
subdir_path = trained_dir / subdir |
|
if subdir_path.exists(): |
|
model_files.extend(list(subdir_path.glob(pattern))) |
|
logger.info(f"π Found {len(model_files)} cached model files") |
|
elif training_fallback_marker.exists(): |
|
logger.warning("β οΈ Training fallback marker found - tokenlearn failed previously") |
|
logger.info("π Proceeding with fallback to base model (simple distillation)") |
|
|
|
else: |
|
if training_complete_marker.exists(): |
|
logger.warning("β οΈ Training marker exists but model files are missing - re-running training") |
|
training_complete_marker.unlink() |
|
logger.info("π No valid training checkpoint found - starting training...") |
|
|
|
try: |
|
|
|
from distiller.tokenlearn.train import train_model |
|
from distiller.tokenlearn.utils import collect_means_and_texts |
|
|
|
|
|
logger.info("π Attempting IMPROVED tokenlearn training with optimized parameters...") |
|
logger.info("π Using smaller vocabulary and conservative PCA to prevent overfitting") |
|
|
|
|
|
paths = sorted(features_dir.glob("*.json")) |
|
train_txt, train_vec = collect_means_and_texts(paths) |
|
|
|
logger.info(f"π Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training") |
|
|
|
try: |
|
|
|
trained_model = train_model( |
|
model_name=str(teacher_model_name), |
|
train_txt=train_txt, |
|
train_vec=train_vec, |
|
device="cuda" if torch.cuda.is_available() else "cpu", |
|
vocab_size=25000, |
|
pca_dims=256, |
|
) |
|
|
|
|
|
trained_model.save_pretrained(str(trained_dir)) |
|
logger.info("β
IMPROVED tokenlearn training completed successfully") |
|
training_complete_marker.touch() |
|
logger.info(f"πΎ Created improved training checkpoint: {training_complete_marker}") |
|
|
|
except Exception as e: |
|
logger.warning(f"β οΈ Improved training failed: {e}") |
|
logger.info("π Falling back to CONSERVATIVE tokenlearn training...") |
|
|
|
|
|
try: |
|
trained_model = train_model( |
|
model_name=str(teacher_model_name), |
|
train_txt=train_txt, |
|
train_vec=train_vec, |
|
device="cuda" if torch.cuda.is_available() else "cpu", |
|
vocab_size=15000, |
|
pca_dims=128, |
|
) |
|
|
|
|
|
trained_model.save_pretrained(str(trained_dir)) |
|
logger.info("β
Conservative tokenlearn training completed successfully") |
|
training_complete_marker.touch() |
|
logger.info(f"πΎ Created conservative training checkpoint: {training_complete_marker}") |
|
|
|
except Exception as e2: |
|
logger.exception("β Conservative tokenlearn training also failed") |
|
logger.exception("π₯ All training approaches failed - check output above for details") |
|
|
|
|
|
training_fallback_marker = trained_dir / ".training_fallback" |
|
training_fallback_marker.touch() |
|
|
|
logger.exception("π₯ Tokenlearn training failed completely") |
|
msg = f"All tokenlearn training approaches failed: {e2}" |
|
raise RuntimeError(msg) from e2 |
|
|
|
except Exception as e: |
|
logger.warning("π₯ All tokenlearn training approaches failed") |
|
logger.exception("π₯ All training approaches failed completely - cannot proceed") |
|
msg = f"All training approaches failed: {e}" |
|
raise RuntimeError(msg) from e |
|
|
|
|
|
logger.info("π¦ Step 4: Loading trained model and applying post-training re-regularization...") |
|
|
|
|
|
training_fallback_marker = trained_dir / ".training_fallback" |
|
if training_fallback_marker.exists(): |
|
logger.error("β Tokenlearn training failed previously - cannot return trained model") |
|
logger.error("π₯ Training was requested but failed - this would be misleading to return base model") |
|
msg = "Tokenlearn training failed - cannot proceed with training pipeline" |
|
raise RuntimeError(msg) |
|
|
|
try: |
|
from distiller.model2vec.model import StaticModel |
|
|
|
|
|
trained_model_path = trained_dir / "model" |
|
if not trained_model_path.exists(): |
|
|
|
possible_paths = [ |
|
trained_dir / "model_weighted", |
|
trained_dir, |
|
] |
|
|
|
for path in possible_paths: |
|
if path.exists() and any(path.glob("*.json")): |
|
trained_model_path = path |
|
break |
|
else: |
|
logger.error(f"β Could not find trained model in {trained_dir}") |
|
logger.error("π₯ Training was requested but no trained model found - cannot proceed") |
|
msg = f"Trained model not found in {trained_dir} - training pipeline failed" |
|
raise RuntimeError(msg) |
|
|
|
|
|
logger.info("π Loading model from tokenlearn training...") |
|
trained_model = StaticModel.from_pretrained(str(trained_model_path)) |
|
|
|
|
|
logger.info("β
Tokenlearn training pipeline completed successfully") |
|
return trained_model |
|
|
|
except ValueError as e: |
|
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): |
|
logger.exception("π₯ Token-vector mismatch in tokenlearn training") |
|
logger.exception("Error details") |
|
logger.exception("π§ This is a known issue with tokenlearn/Model2Vec integration") |
|
logger.exception("π₯ Training was requested but failed due to token-vector mismatch") |
|
msg = f"Tokenlearn training failed due to token-vector mismatch: {e}" |
|
raise RuntimeError(msg) from e |
|
logger.exception("π₯ Failed to load tokenlearn trained model") |
|
msg = f"Failed to load tokenlearn trained model: {e}" |
|
raise RuntimeError(msg) from e |
|
except Exception as e: |
|
logger.exception("π₯ Failed to load tokenlearn trained model") |
|
logger.exception("π₯ Cannot load trained model - training failed") |
|
msg = f"Failed to load tokenlearn trained model: {e}" |
|
raise RuntimeError(msg) from e |
|
|
|
|
|
def distill_single_teacher( |
|
teacher_model: str, |
|
enable_training: bool = False, |
|
use_beam_utilities: bool = False, |
|
pca_dims: int | None = None, |
|
) -> dict[str, Any]: |
|
""" |
|
Distill a single teacher model with optional training. |
|
|
|
Args: |
|
teacher_model: Name of teacher model |
|
enable_training: Whether to enable advanced training |
|
use_beam_utilities: Whether to use Beam utilities |
|
pca_dims: PCA dimensions |
|
|
|
Returns: |
|
Dictionary with distillation results |
|
""" |
|
teacher_name = teacher_model.split("/")[-1].replace("-", "_") |
|
base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
|
|
|
|
|
final_model_name = f"code_model2vec_{teacher_name}" |
|
if enable_training: |
|
final_model_name += "_fine_tuned" |
|
final_dir = Path(LOCAL_FINAL_DIR) / final_model_name |
|
|
|
logger.info(f"\n{'=' * 60}") |
|
logger.info(f"π Processing teacher model: {teacher_model}") |
|
logger.info(f"π Teacher name: {teacher_name}") |
|
logger.info(f"π Training enabled: {enable_training}") |
|
logger.info(f"{'=' * 60}") |
|
|
|
|
|
is_compatible, warning_msg = check_model_compatibility(teacher_model) |
|
if not is_compatible: |
|
logger.warning(f"β οΈ Known compatibility issue: {warning_msg}") |
|
logger.info("π§ Attempting distillation anyway, but may fail...") |
|
|
|
|
|
workaround_type = try_model_workarounds(teacher_model) |
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
checkpoint_mgr = None |
|
if use_beam_utilities: |
|
try: |
|
_, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path) |
|
except Exception as e: |
|
logger.warning(f"Failed to initialize Beam utilities: {e}") |
|
|
|
try: |
|
|
|
existing_final = check_existing_final_model(teacher_name, enable_training) |
|
if existing_final: |
|
logger.info(f"β
Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}") |
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "skipped_existing_final", |
|
"final_path": existing_final, |
|
"distillation_time": total_time, |
|
} |
|
|
|
|
|
if use_beam_utilities and checkpoint_mgr: |
|
logger.info(f"π Syncing existing checkpoints for {teacher_name}...") |
|
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints) |
|
if enable_training: |
|
sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints) |
|
|
|
|
|
existing_base = check_existing_base_model(teacher_name) |
|
base_model = None |
|
|
|
if existing_base: |
|
logger.info(f"β
Found existing base model: {teacher_name}") |
|
if enable_training: |
|
|
|
from distiller.model2vec.model import StaticModel |
|
|
|
base_model = StaticModel.from_pretrained(existing_base) |
|
elif use_beam_utilities: |
|
synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities) |
|
if synced: |
|
existing_base = str(base_dir) |
|
if enable_training: |
|
from distiller.model2vec.model import StaticModel |
|
|
|
base_model = StaticModel.from_pretrained(existing_base) |
|
|
|
if not existing_base: |
|
|
|
logger.info(f"π Creating base model for {teacher_name}") |
|
|
|
|
|
workaround_type = try_model_workarounds(teacher_model) |
|
|
|
if workaround_type == "salesforce": |
|
base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims) |
|
elif workaround_type == "baai": |
|
base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims) |
|
else: |
|
base_model = simple_distillation(teacher_model, str(base_dir), pca_dims) |
|
|
|
if base_model is None: |
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "failed_base_distillation", |
|
"error": "Simple distillation failed", |
|
"distillation_time": total_time, |
|
} |
|
|
|
|
|
if use_beam_utilities: |
|
sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities) |
|
if checkpoint_mgr: |
|
sync_checkpoints_to_beam( |
|
VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints |
|
) |
|
|
|
existing_base = str(base_dir) |
|
|
|
|
|
if enable_training and base_model is not None: |
|
|
|
logger.info(f"π§ͺ Starting tokenlearn training for {teacher_name}") |
|
|
|
try: |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
teacher_st_model = load_model_with_flash_attention(teacher_model, device) |
|
|
|
|
|
final_model = tokenlearn_training( |
|
base_model, |
|
teacher_st_model, |
|
checkpoint_mgr, |
|
) |
|
|
|
|
|
final_dir.mkdir(parents=True, exist_ok=True) |
|
final_model.save_pretrained(str(final_dir)) |
|
|
|
|
|
if use_beam_utilities: |
|
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities) |
|
if checkpoint_mgr: |
|
sync_checkpoints_to_beam( |
|
VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints |
|
) |
|
|
|
del teacher_st_model |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
except RuntimeError as e: |
|
|
|
logger.exception(f"β Training failed for {teacher_name}") |
|
|
|
|
|
if "teacher_st_model" in locals(): |
|
del teacher_st_model |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "failed_training", |
|
"error": f"Training failed: {e!s}", |
|
"base_path": existing_base, |
|
"distillation_time": total_time, |
|
} |
|
|
|
else: |
|
|
|
logger.info(f"π Copying base to final for {teacher_name}") |
|
if not copy_base_to_final(teacher_name, enable_training): |
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "failed_copy_to_final", |
|
"error": "Failed to copy base to final", |
|
"distillation_time": total_time, |
|
} |
|
|
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "success", |
|
"enable_training": enable_training, |
|
"base_path": existing_base, |
|
"final_path": str(final_dir), |
|
"distillation_time": total_time, |
|
} |
|
|
|
except Exception as e: |
|
logger.exception(f"β Failed to process {teacher_model}") |
|
total_time = time.time() - start_time |
|
return { |
|
"teacher_model": teacher_model, |
|
"teacher_name": teacher_name, |
|
"status": "failed", |
|
"error": str(e), |
|
"distillation_time": total_time, |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_local_distillation( |
|
teacher_models: list[str] | None = None, |
|
enable_training: bool = False, |
|
pca_dims: int | None = None, |
|
clear_cache: bool = False, |
|
) -> dict[str, Any]: |
|
"""Run distillation locally.""" |
|
logger.info("π₯οΈ Running distillation locally") |
|
|
|
if teacher_models is None: |
|
teacher_models = DEFAULT_TEACHER_MODELS |
|
|
|
results = {} |
|
successful_models = [] |
|
|
|
logger.info("π Starting distillation workflow") |
|
logger.info(f"π Processing {len(teacher_models)} teacher models") |
|
logger.info(f"π Training enabled: {enable_training}") |
|
|
|
|
|
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
|
|
|
logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
|
for i, model in enumerate(models_to_distill, 1): |
|
logger.info(f" {i}. {model}") |
|
|
|
|
|
if clear_cache: |
|
logger.info("π§Ή Clearing cache for known problematic models...") |
|
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
|
for model in problematic_models: |
|
if model in models_to_distill: |
|
clear_model_cache(model) |
|
|
|
|
|
|
|
|
|
for teacher_model in models_to_distill: |
|
result = distill_single_teacher( |
|
teacher_model=teacher_model, |
|
enable_training=enable_training, |
|
use_beam_utilities=False, |
|
pca_dims=pca_dims, |
|
) |
|
|
|
teacher_name = result["teacher_name"] |
|
results[teacher_name] = result |
|
|
|
if result["status"] == "success" or result["status"].startswith("skipped"): |
|
successful_models.append(teacher_name) |
|
elif result["status"] == "failed_training": |
|
|
|
logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded") |
|
|
|
|
|
logger.info("\nπ DISTILLATION WORKFLOW COMPLETE!") |
|
logger.info(f"π Successful models: {len(successful_models)}") |
|
logger.info(f"π Training mode: {'Enabled' if enable_training else 'Basic distillation only'}") |
|
|
|
for model_name in successful_models: |
|
result = results[model_name] |
|
logger.info(f"β
{model_name}: {result['teacher_model']}") |
|
|
|
|
|
results_summary = { |
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
|
"enable_training": enable_training, |
|
"successful_models": successful_models, |
|
"all_results": results, |
|
"total_successful": len(successful_models), |
|
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
|
} |
|
|
|
|
|
results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json" |
|
results_file.parent.mkdir(parents=True, exist_ok=True) |
|
with results_file.open("w") as f: |
|
json.dump(results_summary, f, indent=2) |
|
|
|
logger.info(f"π Results summary saved to: {results_file}") |
|
|
|
return results_summary |
|
|
|
|
|
def _beam_distill_internal( |
|
teacher_models: list[str] | None = None, |
|
enable_training: bool = False, |
|
pca_dims: int | None = None, |
|
clear_cache: bool = False, |
|
) -> dict[str, Any]: |
|
"""Shared internal implementation for beam distillation.""" |
|
if teacher_models is None: |
|
teacher_models = DEFAULT_TEACHER_MODELS |
|
|
|
|
|
if clear_cache: |
|
logger.info("π§Ή Clearing cache for known problematic models...") |
|
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
|
for model in problematic_models: |
|
if model in teacher_models: |
|
clear_model_cache(model) |
|
|
|
results = {} |
|
successful_models = [] |
|
|
|
logger.info("π Starting Beam distillation workflow") |
|
logger.info(f"π Processing {len(teacher_models)} teacher models") |
|
logger.info(f"π Training enabled: {enable_training}") |
|
|
|
|
|
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
|
|
|
logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
|
for i, model in enumerate(models_to_distill, 1): |
|
logger.info(f" {i}. {model}") |
|
|
|
for teacher_model in models_to_distill: |
|
result = distill_single_teacher( |
|
teacher_model=teacher_model, |
|
enable_training=enable_training, |
|
use_beam_utilities=True, |
|
pca_dims=pca_dims, |
|
) |
|
|
|
teacher_name = result["teacher_name"] |
|
results[teacher_name] = result |
|
|
|
if result["status"] == "success" or result["status"].startswith("skipped"): |
|
successful_models.append(teacher_name) |
|
elif result["status"] == "failed_training": |
|
|
|
logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded") |
|
|
|
|
|
logger.info("\nπ BEAM DISTILLATION WORKFLOW COMPLETE!") |
|
logger.info(f"π Successful models: {len(successful_models)}") |
|
|
|
|
|
volume_path = Path(VOLUME_CONFIG.mount_path) |
|
results_summary = { |
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
|
"enable_training": enable_training, |
|
"successful_models": successful_models, |
|
"all_results": results, |
|
"total_successful": len(successful_models), |
|
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
|
} |
|
|
|
results_file = volume_path / "distillation_results.json" |
|
with results_file.open("w") as f: |
|
json.dump(results_summary, f, indent=2) |
|
|
|
logger.info(f"π Beam results saved to: {results_file}") |
|
|
|
return results_summary |
|
|
|
|
|
@function(**get_training_function_kwargs()) |
|
def _beam_train_models( |
|
teacher_models: list[str] | None = None, |
|
enable_training: bool = True, |
|
pca_dims: int | None = None, |
|
clear_cache: bool = False, |
|
) -> dict[str, Any]: |
|
"""Beam function for training (distillation + tokenlearn).""" |
|
logger.info("βοΈ Running training on Beam") |
|
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) |
|
|
|
|
|
@function(**get_distillation_function_kwargs()) |
|
def _beam_distill_models( |
|
teacher_models: list[str] | None = None, |
|
enable_training: bool = False, |
|
pca_dims: int | None = None, |
|
clear_cache: bool = False, |
|
) -> dict[str, Any]: |
|
"""Beam function for basic distillation only.""" |
|
logger.info("βοΈ Running distillation on Beam") |
|
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) |
|
|
|
|
|
def run_beam_distillation( |
|
teacher_models: list[str] | None = None, |
|
enable_training: bool = False, |
|
pca_dims: int | None = None, |
|
clear_cache: bool = False, |
|
) -> dict[str, Any]: |
|
"""Run distillation on Beam and sync results.""" |
|
logger.info("βοΈ Running distillation on Beam with local sync") |
|
|
|
try: |
|
|
|
beam_function = _beam_train_models if enable_training else _beam_distill_models |
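        # .remote() executes the decorated function on Beam; the code below relies on it
        # returning the results dict once the remote run finishes.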
|
|
|
|
|
results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache) |
|
|
|
|
|
if not results: |
|
logger.error("β Beam execution failed or returned no results") |
|
return { |
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
|
"enable_training": enable_training, |
|
"successful_models": [], |
|
"all_results": {}, |
|
"total_successful": 0, |
|
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
|
"error": "Beam execution failed", |
|
} |
|
|
|
|
|
if results.get("successful_models"): |
|
logger.info("π₯ Syncing models from Beam to local directories...") |
|
|
|
for teacher_name in results["successful_models"]: |
|
|
|
base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
|
sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True) |
|
|
|
|
|
if enable_training: |
|
                    final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}_fine_tuned"  # match the naming used when training is enabled
|
sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True) |
|
else: |
|
|
|
copy_base_to_final(teacher_name, enable_training) |
|
|
|
logger.info("β
All models synced from Beam") |
|
|
|
return results |
|
|
|
except Exception as e: |
|
logger.exception("β Beam distillation failed with exception") |
|
return { |
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
|
"enable_training": enable_training, |
|
"successful_models": [], |
|
"all_results": {}, |
|
"total_successful": 0, |
|
"total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
|
"error": str(e), |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main( |
|
use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False, |
|
train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False, |
|
teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None, |
|
pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None, |
|
clear_cache: Annotated[ |
|
bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation") |
|
] = False, |
|
clear_checkpoints: Annotated[ |
|
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training") |
|
] = False, |
|
use_optimized_dataset: Annotated[ |
|
bool, |
|
typer.Option( |
|
"--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset" |
|
), |
|
] = False, |
|
dataset_path: Annotated[ |
|
str | None, |
|
typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"), |
|
] = None, |
|
) -> None: |
|
"""Unified distillation command with optional training.""" |
|
logger.info("π Starting unified Model2Vec distillation workflow") |
|
|
|
|
|
distillation_config.use_optimized_dataset = use_optimized_dataset |
|
distillation_config.custom_dataset_path = dataset_path |
|
|
|
if use_optimized_dataset and train: |
|
dataset_source = dataset_path or "code_model2vec/dataset" |
|
logger.info(f"π― Using optimized dataset from: {dataset_source}") |
|
elif train: |
|
logger.info("π― Using C4 dataset for training (following POTION approach)") |
|
|
|
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}") |
|
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}") |
|
|
|
|
|
models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
|
|
|
logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
|
for i, model in enumerate(models_to_distill, 1): |
|
logger.info(f" {i}. {model}") |
|
|
|
|
|
if clear_cache: |
|
logger.info("π§Ή Clearing cache for known problematic models...") |
|
problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
|
for model in problematic_models: |
|
if model in models_to_distill: |
|
clear_model_cache(model) |
|
|
|
|
|
if clear_checkpoints and train: |
|
logger.info("π§Ή Clearing tokenlearn checkpoints to force fresh featurization and training...") |
|
for teacher_model in models_to_distill: |
|
teacher_model.split("/")[-1].replace("-", "_") |
|
|
|
|
|
teacher_slug = teacher_model.replace("/", "_").replace("-", "_") |
|
persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug |
|
|
|
features_dir = persistent_tokenlearn_dir / "features" |
|
trained_dir = persistent_tokenlearn_dir / "trained_model" |
|
|
|
|
|
if features_dir.exists() or trained_dir.exists(): |
|
clear_tokenlearn_checkpoints(features_dir, trained_dir) |
|
logger.info(f"ποΈ Cleared persistent tokenlearn checkpoints for {teacher_model}") |
|
else: |
|
logger.info(f"βΉοΈ No tokenlearn checkpoints found for {teacher_model}") |
|
elif clear_checkpoints and not train: |
|
logger.warning("β οΈ --clear-checkpoints flag is only relevant when training is enabled (--train)") |
|
|
|
|
|
if use_beam: |
|
results = run_beam_distillation( |
|
teacher_models=models_to_distill, |
|
enable_training=train, |
|
pca_dims=pca_dims, |
|
clear_cache=clear_cache, |
|
) |
|
else: |
|
results = run_local_distillation( |
|
teacher_models=models_to_distill, |
|
enable_training=train, |
|
pca_dims=pca_dims, |
|
clear_cache=clear_cache, |
|
) |
|
|
|
|
|
if not results or not isinstance(results, dict): |
|
logger.error("β Distillation workflow failed - no valid results returned") |
|
results = { |
|
"total_successful": 0, |
|
"total_attempted": len(models_to_distill), |
|
"error": "Workflow failed", |
|
} |
|
|
|
|
|
successful_count = results.get("total_successful", 0) |
|
total_attempted = results.get("total_attempted", 0) |
|
|
|
logger.info("\nπ UNIFIED DISTILLATION WORKFLOW COMPLETED!") |
|
logger.info(f"π Successfully processed: {successful_count}/{total_attempted} models") |
|
logger.info(f"π Base models saved to: {LOCAL_BASE_DIR}") |
|
logger.info(f"π Final models saved to: {LOCAL_FINAL_DIR}") |
|
|
|
if train: |
|
logger.info("π Advanced training was enabled - models include CodeSearchNet specialization") |
|
else: |
|
logger.info("π Basic distillation only - use --train flag to enable advanced training") |
|
|
|
|
|
def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]: |
|
""" |
|
Check if a model has known compatibility issues with Model2Vec. |
|
|
|
Returns: |
|
Tuple of (is_compatible, warning_message) |
|
""" |
|
known_incompatible = { |
|
"BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute", |
|
"jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies", |
|
"Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors", |
|
} |
|
|
|
if teacher_model in known_incompatible: |
|
return False, known_incompatible[teacher_model] |
|
|
|
|
|
if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower(): |
|
return False, "BGE models with Qwen2 tokenizers may have compatibility issues" |
|
|
|
if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower(): |
|
return False, "Jina embeddings v3 models may have missing dependencies" |
|
|
|
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): |
|
return False, "Salesforce SFR embedding models may have device placement issues" |
|
|
|
return True, None |
|
|
|
|
|
def clear_model_cache(model_name: str) -> bool: |
|
"""Clear HuggingFace cache for a specific model.""" |
|
try: |
|
import shutil |
|
from pathlib import Path |
|
|
|
|
|
cache_dir = Path.home() / ".cache" / "huggingface" |
|
|
|
|
|
model_slug = model_name.replace("/", "--") |
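        # The hub cache stores repos as hub/models--<org>--<name>; custom code for
        # trust_remote_code models lives under modules/transformers_modules/<org>.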
|
|
|
|
|
transformers_cache = cache_dir / "transformers" / model_slug |
|
if transformers_cache.exists(): |
|
shutil.rmtree(transformers_cache) |
|
logger.info(f"ποΈ Cleared transformers cache for {model_name}") |
|
|
|
|
|
hub_cache = cache_dir / "hub" / f"models--{model_slug}" |
|
if hub_cache.exists(): |
|
shutil.rmtree(hub_cache) |
|
logger.info(f"ποΈ Cleared hub cache for {model_name}") |
|
|
|
|
|
modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0] |
|
if modules_cache.exists(): |
|
shutil.rmtree(modules_cache) |
|
logger.info(f"ποΈ Cleared modules cache for {model_name}") |
|
|
|
return True |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to clear cache for {model_name}: {e}") |
|
return False |
|
|
|
|
|
def try_model_workarounds(teacher_model: str) -> str | None: |
|
""" |
|
Try specific workarounds for problematic models. |
|
|
|
Returns: |
|
The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available |
|
""" |
|
if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): |
|
logger.info("π§ Salesforce SFR model detected - will use specialized distillation") |
|
return "salesforce" |
|
|
|
if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()): |
|
logger.info("π§ BAAI BGE model detected - will use specialized distillation") |
|
return "baai" |
|
|
|
return None |
|
|
|
|
|
def salesforce_model_distillation( |
|
teacher_model: str, |
|
output_dir: str, |
|
pca_dims: int | None = None, |
|
) -> Any: |
|
"""Special distillation function for Salesforce SFR models that handles device placement issues.""" |
|
if pca_dims is None: |
|
pca_dims = int(distillation_config.optimal_pca_dims) |
|
|
|
output_path = Path(output_dir) |
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
logger.info(f"π Salesforce-specific distillation: {teacher_model} β {output_dir}") |
|
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
|
|
|
start_time = time.time() |
|
|
|
try: |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
logger.info("π§ Loading model with enhanced device settings...") |
|
|
|
|
|
try: |
|
logger.info("π Attempting with to_empty() method...") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) |
|
|
|
|
|
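# Load the checkpoint onto the meta device first (no real memory is allocated yet); |
# the weights are materialized on a concrete device below via to_empty(). |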
model = AutoModel.from_pretrained( |
|
teacher_model, |
|
trust_remote_code=True, |
|
torch_dtype=torch.float16, |
|
device_map="meta", |
|
) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
|
|
model = model.to_empty(device=device) |
|
else: |
|
device = torch.device("cpu") |
|
model = model.to_empty(device=device) |
|
|
|
|
|
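# to_empty() allocates fresh (uninitialized) storage on the target device, so re-apply the desired dtype here. |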
model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32) |
|
|
|
logger.info("β
Successfully loaded with to_empty() method") |
|
|
|
except Exception as e: |
|
logger.warning(f"to_empty() method failed: {e}") |
|
|
|
|
|
logger.info("π Falling back to SentenceTransformer method...") |
|
sentence_model = load_model_with_flash_attention( |
|
teacher_model, |
|
device="cpu", |
|
) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
sentence_model = sentence_model.to("cuda") |
|
|
|
|
|
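# Unwrap the underlying HF transformer module and tokenizer from the SentenceTransformer wrapper. |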
model = sentence_model[0].auto_model |
|
tokenizer = sentence_model.tokenizer |
|
|
|
logger.info("β
Successfully loaded with SentenceTransformer method") |
|
|
|
|
|
from distiller.model2vec.distill.distillation import distill_from_model |
|
|
|
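# Run the core Model2Vec distillation directly on the in-memory model/tokenizer pair. |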
distilled_model = distill_from_model( |
|
model=model, |
|
tokenizer=tokenizer, |
|
pca_dims=int(pca_dims), |
|
apply_zipf=bool(distillation_config.apply_zipf), |
|
sif_coefficient=float(distillation_config.sif_coefficient), |
|
) |
|
|
|
logger.info("β
Core distillation completed successfully") |
|
|
|
|
|
distilled_model.save_pretrained(str(output_path)) |
|
logger.info(f"πΎ Model saved to {output_path}") |
|
|
|
|
|
logger.info(f"Model type: {type(distilled_model)}") |
|
if hasattr(distilled_model, "embedding"): |
|
logger.info(f"Embedding shape: {distilled_model.embedding.shape}") |
|
logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") |
|
|
|
total_time = time.time() - start_time |
|
logger.info(f"π Salesforce distillation completed in {total_time:.2f} seconds") |
|
|
|
|
|
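# Drop references to the teacher model and release any cached GPU memory. |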
if "sentence_model" in locals(): |
|
del sentence_model |
|
del model |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
return distilled_model |
|
|
|
except Exception: |
|
logger.exception(f"β Salesforce-specific distillation failed for {teacher_model}") |
|
return None |
|
|
|
|
|
def baai_bge_model_distillation( |
|
teacher_model: str, |
|
output_dir: str, |
|
pca_dims: int | None = None, |
|
) -> Any: |
|
"""Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues.""" |
|
if pca_dims is None: |
|
pca_dims = int(distillation_config.optimal_pca_dims) |
|
|
|
output_path = Path(output_dir) |
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
logger.info(f"π BAAI BGE-specific distillation: {teacher_model} β {output_dir}") |
|
logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
|
|
|
start_time = time.time() |
|
|
|
try: |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
logger.info("π§ Loading BAAI model with tokenizer workaround...") |
|
|
|
|
|
success = False |
|
|
|
|
|
try: |
|
logger.info("π Attempting with SentenceTransformer wrapper...") |
|
sentence_model = load_model_with_flash_attention(teacher_model) |
|
|
|
|
|
model = sentence_model[0].auto_model |
|
tokenizer = sentence_model.tokenizer |
|
|
|
|
|
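# Quick smoke test to confirm the tokenizer actually works before distilling. |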
test_encoding = tokenizer.encode("test", return_tensors="pt") |
|
logger.info("β
SentenceTransformer method successful") |
|
success = True |
|
|
|
except Exception as e: |
|
logger.warning(f"SentenceTransformer method failed: {e}") |
|
|
|
|
|
try: |
|
logger.info("π Attempting with tokenizer replacement...") |
|
from transformers import BertTokenizerFast |
|
|
|
|
|
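# Fallback path: load the raw HF model, then try tokenizers from most to least specific. |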
model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True) |
|
|
|
|
|
try: |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) |
|
except Exception: |
|
|
|
logger.info("π Falling back to BERT tokenizer...") |
|
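# Last resort: a generic BERT tokenizer (its vocabulary will not match the teacher model exactly). |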
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") |
|
|
|
logger.info("β
Tokenizer replacement method successful") |
|
success = True |
|
|
|
except Exception as e2: |
|
logger.warning(f"Tokenizer replacement method failed: {e2}") |
|
|
|
if not success: |
|
logger.error("β All BAAI model loading methods failed") |
|
return None |
|
|
|
|
|
from distiller.model2vec.distill.distillation import distill_from_model |
|
|
|
distilled_model = distill_from_model( |
|
model=model, |
|
tokenizer=tokenizer, |
|
pca_dims=int(pca_dims), |
|
apply_zipf=bool(distillation_config.apply_zipf), |
|
sif_coefficient=float(distillation_config.sif_coefficient), |
|
) |
|
|
|
logger.info("β
Core distillation completed successfully") |
|
|
|
|
|
distilled_model.save_pretrained(str(output_path)) |
|
logger.info(f"πΎ Model saved to {output_path}") |
|
|
|
|
|
logger.info(f"Model type: {type(distilled_model)}") |
|
if hasattr(distilled_model, "embedding"): |
|
logger.info(f"Embedding shape: {distilled_model.embedding.shape}") |
|
logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") |
|
|
|
total_time = time.time() - start_time |
|
logger.info(f"π BAAI BGE distillation completed in {total_time:.2f} seconds") |
|
|
|
|
|
if "sentence_model" in locals(): |
|
del sentence_model |
|
del model |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
return distilled_model |
|
|
|
except Exception: |
|
logger.exception(f"β BAAI BGE-specific distillation failed for {teacher_model}") |
|
return None |
|
|
|
|
|
def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None: |
|
"""Clear tokenlearn checkpoint markers to force re-execution of steps.""" |
|
featurization_marker = features_dir / ".featurization_complete" |
|
training_marker = trained_dir / ".training_complete" |
|
|
|
if featurization_marker.exists(): |
|
featurization_marker.unlink() |
|
logger.info(f"ποΈ Cleared featurization checkpoint: {featurization_marker}") |
|
|
|
if training_marker.exists(): |
|
training_marker.unlink() |
|
logger.info(f"ποΈ Cleared training checkpoint: {training_marker}") |
|
|
|
|
|
def verify_featurization_output(features_dir: Path) -> bool: |
|
"""Verify that featurization output files actually exist and are valid.""" |
|
if not features_dir.exists(): |
|
return False |
|
|
|
|
|
|
|
|
|
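# Featurization counts as complete if any expected artifact type (.npy, .json, .pt, .pkl) is present. |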
return any(list(features_dir.glob(file_pattern)) for file_pattern in ["*.npy", "*.json", "*.pt", "*.pkl"]) |
|
|
|
|
|
def verify_training_output(trained_dir: Path) -> bool: |
|
"""Verify that training output files actually exist and are valid.""" |
|
if not trained_dir.exists(): |
|
return False |
|
|
|
|
|
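# Model artifacts may sit at the top level of trained_dir or under the model/ or model_weighted/ subdirectories (checked below). |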
model_files = ["config.json", "model.safetensors", "modules.json", "tokenizer.json"] |
|
for model_file in model_files: |
|
if (trained_dir / model_file).exists(): |
|
return True |
|
|
|
|
|
for subdir in ["model", "model_weighted"]: |
|
subdir_path = trained_dir / subdir |
|
if subdir_path.exists(): |
|
for model_file in model_files: |
|
if (subdir_path / model_file).exists(): |
|
return True |
|
|
|
return False |
|
|
|
|
|
def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]: |
|
""" |
|
Prepare dataset for tokenlearn featurization. |
|
|
|
Returns: |
|
Tuple of (dataset_path, dataset_name, text_key) for tokenlearn |
|
""" |
|
if distillation_config.use_optimized_dataset: |
|
return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir) |
|
return _prepare_original_dataset_for_tokenlearn() |
|
|
|
|
|
def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]: |
|
"""Prepare custom optimized dataset for tokenlearn featurization.""" |
|
logger.info("π― Preparing custom optimized dataset for tokenlearn...") |
|
|
|
|
|
from .dataset import create_optimized_dataset, load_optimized_dataset |
|
|
|
|
|
custom_dataset_dir = ( |
|
Path(distillation_config.custom_dataset_path) |
|
if distillation_config.custom_dataset_path |
|
else Path("code_model2vec/dataset") |
|
) |
|
tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset" |
|
|
|
|
|
if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists(): |
|
logger.info("π Custom dataset not found - creating optimized dataset...") |
|
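# Build the dataset on demand; the total tokenlearn sample budget is split across the six CodeSearchNet languages (hence the // 6). |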
create_optimized_dataset( |
|
max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, |
|
output_dir=custom_dataset_dir, |
|
create_multiple_formats=False, |
|
) |
|
|
|
|
|
logger.info(f"π Loading custom dataset from {custom_dataset_dir}") |
|
train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train") |
|
|
|
|
|
tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
train_json_path = tokenlearn_dataset_dir / "train.json" |
|
|
|
|
|
import json |
|
|
|
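# Write one {"text": ...} object per line (JSON Lines) for the downstream featurization step. |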
with train_json_path.open("w") as f: |
|
for text in train_df["text"]: |
|
json.dump({"text": text}, f) |
|
f.write("\n") |
|
|
|
logger.info(f"β
Prepared custom dataset with {len(train_df)} samples for tokenlearn") |
|
logger.info(f"πΎ Saved JSON dataset to {train_json_path}") |
|
|
|
|
|
return str(train_json_path), None, "text" |
|
|
|
|
|
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]: |
|
"""Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach).""" |
|
logger.info("π Using C4 dataset for tokenlearn (following POTION approach)...") |
|
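# Typically resolves to something like ("allenai/c4", "en", "text"); the exact values come from distillation_config. |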
return ( |
|
str(distillation_config.tokenlearn_dataset), |
|
str(distillation_config.tokenlearn_dataset_name), |
|
str(distillation_config.tokenlearn_text_key), |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
typer.run(main) |
|
|