Spaces:
Paused
Paused
""" | |
Synchronize the local neuronx compiler cache with the optimum-neuron hub cache by opening a Pull Request.
""" | |
import os | |
import threading | |
from datetime import datetime | |
from pathlib import Path | |
import torch_xla.core.xla_model as xm | |
import torch_xla.runtime as xr | |
from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy | |
from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo | |
from optimum.neuron.utils.require_utils import requires_torch_neuronx | |
from optimum.neuron.utils.version_utils import get_neuronxcc_version | |
from optimum.neuron.utils.import_utils import is_neuronx_available | |
from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3 | |
def synchronize_hub_cache_with_pr( | |
cache_path: str | Path | None = None, | |
cache_repo_id: str | None = None, | |
commit_message: str | None = None, | |
commit_description: str | None = None, | |
token: str | None = None, | |
non_blocking: bool = False, | |
): | |
"""Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request. | |
Args: | |
cache_path (`str | Path | None`, defaults to `None`): | |
The path of the folder to use for synchronization. | |
cache_repo_id (`str | None`, defaults to `None`): | |
The id of the HuggingFace cache repository, in the form 'org|user/name'. | |
non_blocking (`bool`, defaults to `False`): | |
If `True`, the synchronization will be done in a non-blocking way. | |
Yields: | |
Status messages about the synchronization process. | |
Returns: | |
The URL of the created pull request or None if non_blocking=True. | |
""" | |
# Validate cache path if provided | |
if cache_path is not None: | |
cache_path = Path(cache_path) | |
cache_path_str = cache_path.as_posix() | |
if not cache_path.is_dir(): | |
raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.") | |
cache_url = CacheUrl(cache_path_str, url_type="fs") | |
else: | |
cache_url = None | |
# Get default cache repo if not provided | |
if cache_repo_id is None: | |
cache_repo_id = get_hf_hub_cache_repo() | |
# Create the hub cache proxy using the existing function | |
hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id) | |
# Check if S3 cache (not supported for PR workflow) | |
if isinstance(hub_cache_proxy.default_cache, CompileCacheS3): | |
raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.") | |
def _create_pr(): | |
"""Internal function to create the PR""" | |
try: | |
pr_url = hub_cache_proxy.api.upload_folder( | |
repo_id=cache_repo_id, | |
folder_path=hub_cache_proxy.default_cache.cache_path, | |
commit_message=commit_message, | |
commit_description=commit_description, | |
ignore_patterns="lock", | |
create_pr=True, | |
token=token | |
) | |
yield f"Pull request created successfully: {pr_url}" | |
return pr_url | |
except Exception as e: | |
yield f"Error: Failed to create PR for cache synchronization: {e}" | |
raise | |
# Handle distributed training scenario | |
if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None: | |
# Multi-process execution | |
pr_url = None | |
if xr.local_ordinal() == 0: | |
# Only the first process creates the PR | |
if non_blocking: | |
def sync_thread(): | |
try: | |
for status in _create_pr(): | |
yield status | |
except Exception as e: | |
yield f"Error: Background sync failed: {e}" | |
thread = threading.Thread(target=sync_thread) | |
thread.start() | |
yield "Cache synchronization started in background thread" | |
else: | |
for status in _create_pr(): | |
yield status | |
if "Pull request created successfully:" in status: | |
pr_url = status.split(": ", 1)[1] | |
# Synchronize all processes | |
xm.rendezvous("synchronize_hub_cache_pr") | |
return pr_url if not non_blocking else None | |
# Single process execution | |
if non_blocking: | |
def sync_thread(): | |
try: | |
for status in _create_pr(): | |
yield status | |
except Exception as e: | |
yield f"Error: Background sync failed: {e}" | |
thread = threading.Thread(target=sync_thread) | |
thread.start() | |
yield "Cache synchronization started in background thread" | |
return None | |
else: | |
pr_url = None | |
for status in _create_pr(): | |
yield status | |
if "Pull request created successfully:" in status: | |
pr_url = status.split(": ", 1)[1] | |
return pr_url |