# neuron-export / synchronizer.py
# Author: badaoui (HF Staff) — "Update synchronizer.py", commit 2c6f123 (verified)
"""
External module for creating PR-based cache synchronization.
"""
import os
import threading
from datetime import datetime
from pathlib import Path
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy
from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo
from optimum.neuron.utils.require_utils import requires_torch_neuronx
from optimum.neuron.utils.version_utils import get_neuronxcc_version
from optimum.neuron.utils.import_utils import is_neuronx_available
from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3
@requires_torch_neuronx
def synchronize_hub_cache_with_pr(
    cache_path: str | Path | None = None,
    cache_repo_id: str | None = None,
    commit_message: str | None = None,
    commit_description: str | None = None,
    token: str | None = None,
    non_blocking: bool = False,
):
    """Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request.

    This function is a *generator*: it yields human-readable status messages as the
    synchronization progresses. Its generator return value (``StopIteration.value``)
    is the created PR URL, or ``None`` when ``non_blocking=True``.

    Args:
        cache_path (`str | Path | None`, defaults to `None`):
            The path of the folder to use for synchronization. When `None`, the
            default local compiler cache resolved by the proxy is used.
        cache_repo_id (`str | None`, defaults to `None`):
            The id of the HuggingFace cache repository, in the form 'org|user/name'.
            Defaults to the repository returned by `get_hf_hub_cache_repo()`.
        commit_message (`str | None`, defaults to `None`):
            Commit message for the Pull Request created on the Hub.
        commit_description (`str | None`, defaults to `None`):
            Commit description for the Pull Request created on the Hub.
        token (`str | None`, defaults to `None`):
            HuggingFace authentication token used for the upload.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the synchronization will be done in a non-blocking way
            (the upload runs in a background thread).

    Yields:
        Status messages about the synchronization process.

    Returns:
        The URL of the created pull request or None if non_blocking=True.

    Raises:
        ValueError: if `cache_path` does not exist, or the resolved cache is S3-backed.
    """
    # Validate cache path if provided.
    if cache_path is not None:
        cache_path = Path(cache_path)
        cache_path_str = cache_path.as_posix()
        if not cache_path.is_dir():
            raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.")
        cache_url = CacheUrl(cache_path_str, url_type="fs")
    else:
        cache_url = None
    # Fall back to the configured default hub cache repo if none was given.
    if cache_repo_id is None:
        cache_repo_id = get_hf_hub_cache_repo()
    # Create the hub cache proxy using the existing function.
    hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
    # The PR workflow only applies to Hub-backed caches, not S3 mirrors.
    if isinstance(hub_cache_proxy.default_cache, CompileCacheS3):
        raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.")

    def _upload_pr():
        """Upload the cache folder to the Hub as a Pull Request; returns the PR URL."""
        return hub_cache_proxy.api.upload_folder(
            repo_id=cache_repo_id,
            folder_path=hub_cache_proxy.default_cache.cache_path,
            commit_message=commit_message,
            commit_description=commit_description,
            ignore_patterns="lock",  # never publish compiler lock files
            create_pr=True,
            token=token,
        )

    def _create_pr():
        """Generator wrapper around `_upload_pr` yielding status messages.

        Returns the PR URL as the generator return value so callers can
        recover it with `yield from` instead of parsing the status string.
        """
        try:
            pr_url = _upload_pr()
        except Exception as e:
            yield f"Error: Failed to create PR for cache synchronization: {e}"
            raise
        yield f"Pull request created successfully: {pr_url}"
        return pr_url

    def _background_sync():
        # Plain callable for `threading.Thread`. BUG FIX: the original passed a
        # *generator function* as the thread target, so the thread only built a
        # generator object and the upload never actually ran in non-blocking mode.
        try:
            _upload_pr()
        except Exception:
            # Best-effort background sync: the caller has already moved on, so
            # there is nowhere to report the failure (matches original intent
            # of not propagating background errors).
            pass

    def _start_background_sync():
        """Kick off the upload on a background thread (shared by both execution modes)."""
        threading.Thread(target=_background_sync, name="hub-cache-pr-sync").start()

    pr_url = None
    # Distributed (torchrun) scenario: only local rank 0 performs the upload,
    # and every process rendezvous afterwards so ranks stay in lockstep.
    if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None:
        if xr.local_ordinal() == 0:
            if non_blocking:
                _start_background_sync()
                yield "Cache synchronization started in background thread"
            else:
                # `yield from` re-yields every status message and captures the
                # generator's return value (the PR URL).
                pr_url = yield from _create_pr()
        # Synchronize all processes.
        xm.rendezvous("synchronize_hub_cache_pr")
        return pr_url if not non_blocking else None
    # Single process execution.
    if non_blocking:
        _start_background_sync()
        yield "Cache synchronization started in background thread"
        return None
    pr_url = yield from _create_pr()
    return pr_url