"""
Common utilities for the distiller package.

This module provides shared functionality used across multiple components
including model discovery, result management, and initialization helpers.
"""

import json
import logging
from pathlib import Path
from types import TracebackType
from typing import Any

from .beam_utils import (
    BeamCheckpointManager,
    BeamEvaluationManager,
    BeamModelManager,
    BeamVolumeManager,
    create_beam_utilities,
)
from .config import VolumeConfig, get_safe_model_name, get_volume_config, setup_logging

logger = logging.getLogger(__name__)


# =============================================================================
# BEAM UTILITIES MANAGEMENT
# =============================================================================


class BeamContext:
    """Context manager for Beam utilities with consistent initialization."""

    def __init__(self, workflow: str, volume_config: VolumeConfig | None = None) -> None:
        """
        Initialize Beam context.

        Args:
            workflow: Workflow type (distill, evaluate, benchmark, etc.)
            volume_config: Optional custom volume config; defaults to get_volume_config() if not provided
        """
        self.workflow = workflow
        self.volume_config = volume_config or get_volume_config()
        self.volume_manager: BeamVolumeManager | None = None
        self.checkpoint_manager: BeamCheckpointManager | None = None
        self.model_manager: BeamModelManager | None = None
        self.evaluation_manager: BeamEvaluationManager | None = None

    def __enter__(self) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
        """Enter context and initialize utilities."""
        logger.info(f"🚀 Initializing Beam utilities for {self.workflow}")
        logger.info(f"📁 Volume: {self.volume_config.name} at {self.volume_config.mount_path}")
        self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager = (
            create_beam_utilities(self.volume_config.name, self.volume_config.mount_path)
        )
        return self.volume_manager, self.checkpoint_manager, self.model_manager, self.evaluation_manager

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context with cleanup if needed."""
        if exc_type:
            logger.error(f"❌ Error in Beam context for {self.workflow}: {exc_val}")
        else:
            logger.info(f"✅ Beam context for {self.workflow} completed successfully")
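

# Example usage of BeamContext (illustrative sketch; the "evaluate" workflow name
# mirrors the docstring above, and the model name is a placeholder):
#
#     with BeamContext("evaluate") as (volumes, checkpoints, models, evaluations):
#         cached = evaluations.load_evaluation_results("my-model")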


def get_beam_utilities() -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
    """
    Get Beam utilities using the default volume configuration.

    Returns:
        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
    """
    volume_config = get_volume_config()
    return create_beam_utilities(volume_config.name, volume_config.mount_path)


# =============================================================================
# MODEL DISCOVERY
# =============================================================================


def discover_simplified_models(base_path: str | Path = ".") -> list[str]:
    """
    Discover simplified distillation models in the specified directory.

    Args:
        base_path: Base path to search for models

    Returns:
        List of model paths sorted alphabetically
    """
    base = Path(base_path)

    # Look for models in common locations
    search_patterns = [
        "code_model2vec/final/**/",
        "final/**/",
        "code_model2vec_*/",
        "*/config.json",
        "*.safetensors",
    ]

    discovered_models = []
    for pattern in search_patterns:
        matches = list(base.glob(pattern))
        for match in matches:
            if match.is_dir():
                # Check if it's a valid model directory
                if (match / "config.json").exists() or (match / "model.safetensors").exists():
                    discovered_models.append(str(match))
            elif match.name == "config.json":
                # Add parent directory if config.json found
                discovered_models.append(str(match.parent))
            elif match.suffix == ".safetensors":
                # Add parent directory if a safetensors weight file is found
                discovered_models.append(str(match.parent))

    # Remove duplicates and sort
    unique_models = sorted(set(discovered_models))

    logger.info(f"🔍 Discovered {len(unique_models)} models in {base_path}")
    for model in unique_models:
        logger.info(f"  📁 {model}")

    return unique_models
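

# Example (illustrative; results depend on which model directories exist under the
# current working directory, and the paths shown are placeholders):
#
#     local_models = discover_simplified_models()
#     # e.g. ["code_model2vec/final/model_a", "final/model_b"]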


def validate_model_path(model_path: str | Path, volume_manager: BeamVolumeManager | None = None) -> str | None:
    """
    Validate and resolve a model path, checking the local filesystem and Beam volumes.

    Args:
        model_path: Path to model (can be a local path or a HuggingFace model name)
        volume_manager: Optional volume manager for Beam volume checks

    Returns:
        Resolved model path or None if not found
    """
    path = Path(model_path)

    # Check if it's a HuggingFace model name (e.g. "org/model"): contains a slash,
    # does not exist locally, and is not an absolute path
    if "/" in str(model_path) and not path.exists() and not str(model_path).startswith("/"):
        logger.info(f"📥 Treating as HuggingFace model: {model_path}")
        return str(model_path)

    # Check local filesystem
    if path.exists():
        logger.info(f"✅ Found local model: {model_path}")
        return str(path)

    # Check Beam volume if available
    if volume_manager:
        volume_path = Path(volume_manager.mount_path) / path.name
        if volume_path.exists():
            logger.info(f"✅ Found model in Beam volume: {volume_path}")
            return str(volume_path)

        # Check volume root
        root_path = Path(volume_manager.mount_path)
        if (root_path / "config.json").exists():
            logger.info(f"✅ Found model in Beam volume root: {root_path}")
            return str(root_path)

    logger.warning(f"⚠️ Model not found: {model_path}")
    return None
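

# Example (illustrative; both model references are placeholders, and volume_manager
# would typically come from BeamContext or get_beam_utilities):
#
#     hf_model = validate_model_path("some-org/some-model")        # resolved as a HuggingFace name
#     vol_model = validate_model_path("my_model", volume_manager)  # checked locally, then in the Beam volume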


# =============================================================================
# RESULT MANAGEMENT
# =============================================================================


def save_results_with_backup(
    results: dict[str, Any],
    primary_path: str | Path,
    model_name: str,
    result_type: str = "evaluation",
    volume_manager: BeamVolumeManager | None = None,
    evaluation_manager: BeamEvaluationManager | None = None,
) -> bool:
    """
    Save results with multiple backup strategies.

    Args:
        results: Results dictionary to save
        primary_path: Primary save location
        model_name: Model name for filename generation
        result_type: Type of results (evaluation, benchmark, etc.)
        volume_manager: Optional volume manager for Beam storage
        evaluation_manager: Optional evaluation manager for specialized storage

    Returns:
        True if saved successfully to at least one location
    """
    success_count = 0
    safe_name = get_safe_model_name(model_name)
    filename = f"{result_type}_{safe_name}.json"

    # Save to primary location
    try:
        primary = Path(primary_path)
        primary.mkdir(parents=True, exist_ok=True)
        filepath = primary / filename
        with filepath.open("w") as f:
            json.dump(results, f, indent=2, default=str)
        logger.info(f"💾 Saved {result_type} results to: {filepath}")
        success_count += 1
    except Exception as e:
        logger.warning(f"⚠️ Failed to save to primary location: {e}")

    # Save to Beam volume if available
    if volume_manager:
        try:
            volume_path = Path(volume_manager.mount_path) / f"{result_type}_results"
            volume_path.mkdir(parents=True, exist_ok=True)
            filepath = volume_path / filename
            with filepath.open("w") as f:
                json.dump(results, f, indent=2, default=str)
            logger.info(f"💾 Saved {result_type} results to Beam volume: {filepath}")
            success_count += 1
        except Exception as e:
            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")

    # Save via evaluation manager if available and appropriate
    if evaluation_manager and result_type == "evaluation":
        try:
            success = evaluation_manager.save_evaluation_results(model_name, results)
            if success:
                logger.info(f"💾 Saved via evaluation manager for {model_name}")
                success_count += 1
        except Exception as e:
            logger.warning(f"⚠️ Failed to save via evaluation manager: {e}")

    return success_count > 0
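

# Example (illustrative sketch; the metrics dict, output directory, and model name
# are placeholders):
#
#     saved = save_results_with_backup(
#         results={"accuracy": 0.91},
#         primary_path="results/evaluation",
#         model_name="my_model",
#         result_type="evaluation",
#     )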


def load_existing_results(
    model_name: str,
    result_type: str = "evaluation",
    search_paths: list[str | Path] | None = None,
    volume_manager: BeamVolumeManager | None = None,
    evaluation_manager: BeamEvaluationManager | None = None,
) -> dict[str, Any] | None:
    """
    Load existing results from multiple possible locations.

    Args:
        model_name: Model name to search for
        result_type: Type of results to load
        search_paths: Additional paths to search
        volume_manager: Optional volume manager
        evaluation_manager: Optional evaluation manager

    Returns:
        Results dictionary if found, None otherwise
    """
    safe_name = get_safe_model_name(model_name)
    filename = f"{result_type}_{safe_name}.json"

    # Search in provided paths
    if search_paths:
        for search_path in search_paths:
            filepath = Path(search_path) / filename
            if filepath.exists():
                try:
                    with filepath.open("r") as f:
                        results = json.load(f)
                    logger.info(f"📂 Loaded existing {result_type} results from: {filepath}")
                    return results
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load from {filepath}: {e}")

    # Search in Beam volume
    if volume_manager:
        volume_path = Path(volume_manager.mount_path) / f"{result_type}_results" / filename
        if volume_path.exists():
            try:
                with volume_path.open("r") as f:
                    results = json.load(f)
                logger.info(f"📂 Loaded existing {result_type} results from Beam volume: {volume_path}")
                return results
            except Exception as e:
                logger.warning(f"⚠️ Failed to load from Beam volume: {e}")

    # Try evaluation manager
    if evaluation_manager and result_type == "evaluation":
        try:
            results = evaluation_manager.load_evaluation_results(model_name)
            if results:
                logger.info(f"📂 Loaded existing {result_type} results via evaluation manager")
                return results
        except Exception as e:
            logger.warning(f"⚠️ Failed to load via evaluation manager: {e}")

    logger.info(f"ℹ️ No existing {result_type} results found for {model_name}")
    return None
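

# Example (illustrative; the search path and model name are placeholders):
#
#     cached = load_existing_results(
#         "my_model",
#         result_type="evaluation",
#         search_paths=["results/evaluation"],
#     )
#     if cached is None:
#         ...  # nothing cached yet; run the evaluation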


# =============================================================================
# WORKFLOW HELPERS
# =============================================================================


def print_workflow_summary(
    workflow_name: str,
    total_items: int,
    processed_items: int,
    skipped_items: int,
    execution_time: float | None = None,
) -> None:
    """Print a standardized workflow summary."""
    logger.info(f"\n✅ {workflow_name} complete!")
    logger.info(f"📊 Total items: {total_items}")
    logger.info(f"✨ Newly processed: {processed_items}")
    logger.info(f"⏭️ Skipped (already done): {skipped_items}")
    if execution_time is not None:
        logger.info(f"⏱️ Execution time: {execution_time:.2f} seconds")


def check_existing_results(
    items: list[str],
    result_type: str,
    search_paths: list[str | Path] | None = None,
    volume_manager: BeamVolumeManager | None = None,
) -> tuple[list[str], list[str]]:
    """
    Check which items already have results and which need processing.

    Args:
        items: List of items (model names, etc.) to check
        result_type: Type of results to check for
        search_paths: Paths to search for existing results
        volume_manager: Optional volume manager

    Returns:
        Tuple of (items_to_process, items_to_skip)
    """
    to_process = []
    to_skip = []

    for item in items:
        existing = load_existing_results(item, result_type, search_paths, volume_manager)
        if existing:
            to_skip.append(item)
        else:
            to_process.append(item)

    return to_process, to_skip
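

# Example (illustrative; the model names and search path are placeholders):
#
#     pending, done = check_existing_results(
#         items=["model_a", "model_b"],
#         result_type="evaluation",
#         search_paths=["results/evaluation"],
#     )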


# =============================================================================
# INITIALIZATION
# =============================================================================


def initialize_distiller_logging(level: int = logging.INFO) -> None:
    """Initialize logging for the distiller package."""
    setup_logging(level)
    logger.info("🚀 Distiller package initialized")


# Ensure logging is set up when the module is imported
initialize_distiller_logging()