""" External module for creating PR-based cache synchronization. """ import os import threading from datetime import datetime from pathlib import Path import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo from optimum.neuron.utils.require_utils import requires_torch_neuronx from optimum.neuron.utils.version_utils import get_neuronxcc_version from optimum.neuron.utils.import_utils import is_neuronx_available from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3 @requires_torch_neuronx def synchronize_hub_cache_with_pr( cache_path: str | Path | None = None, cache_repo_id: str | None = None, commit_message: str | None = None, commit_description: str | None = None, token: str | None = None, non_blocking: bool = False, ): """Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request. Args: cache_path (`str | Path | None`, defaults to `None`): The path of the folder to use for synchronization. cache_repo_id (`str | None`, defaults to `None`): The id of the HuggingFace cache repository, in the form 'org|user/name'. non_blocking (`bool`, defaults to `False`): If `True`, the synchronization will be done in a non-blocking way. Yields: Status messages about the synchronization process. Returns: The URL of the created pull request or None if non_blocking=True. """ # Validate cache path if provided if cache_path is not None: cache_path = Path(cache_path) cache_path_str = cache_path.as_posix() if not cache_path.is_dir(): raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.") cache_url = CacheUrl(cache_path_str, url_type="fs") else: cache_url = None # Get default cache repo if not provided if cache_repo_id is None: cache_repo_id = get_hf_hub_cache_repo() # Create the hub cache proxy using the existing function hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id) # Check if S3 cache (not supported for PR workflow) if isinstance(hub_cache_proxy.default_cache, CompileCacheS3): raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.") def _create_pr(): """Internal function to create the PR""" try: pr_url = hub_cache_proxy.api.upload_folder( repo_id=cache_repo_id, folder_path=hub_cache_proxy.default_cache.cache_path, commit_message=commit_message, commit_description=commit_description, ignore_patterns="lock", create_pr=True, token=token ) yield f"Pull request created successfully: {pr_url}" return pr_url except Exception as e: yield f"Error: Failed to create PR for cache synchronization: {e}" raise # Handle distributed training scenario if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None: # Multi-process execution pr_url = None if xr.local_ordinal() == 0: # Only the first process creates the PR if non_blocking: def sync_thread(): try: for status in _create_pr(): yield status except Exception as e: yield f"Error: Background sync failed: {e}" thread = threading.Thread(target=sync_thread) thread.start() yield "Cache synchronization started in background thread" else: for status in _create_pr(): yield status if "Pull request created successfully:" in status: pr_url = status.split(": ", 1)[1] # Synchronize all processes xm.rendezvous("synchronize_hub_cache_pr") return pr_url if not non_blocking else None # Single process execution if non_blocking: def sync_thread(): try: for status in _create_pr(): yield status except Exception as e: yield f"Error: Background sync failed: {e}" thread = threading.Thread(target=sync_thread) thread.start() yield "Cache synchronization started in background thread" return None else: pr_url = None for status in _create_pr(): yield status if "Pull request created successfully:" in status: pr_url = status.split(": ", 1)[1] return pr_url