"""
External module for creating PR-based cache synchronization.
"""
import logging
import os
import threading
from pathlib import Path

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3

from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy
from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo
from optimum.neuron.utils.require_utils import requires_torch_neuronx

logger = logging.getLogger(__name__)

@requires_torch_neuronx
def synchronize_hub_cache_with_pr(
cache_path: str | Path | None = None,
cache_repo_id: str | None = None,
commit_message: str | None = None,
commit_description: str | None = None,
token: str | None = None,
non_blocking: bool = False,
):
"""Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request.
Args:
cache_path (`str | Path | None`, defaults to `None`):
The path of the folder to use for synchronization.
cache_repo_id (`str | None`, defaults to `None`):
The id of the HuggingFace cache repository, in the form 'org|user/name'.
non_blocking (`bool`, defaults to `False`):
If `True`, the synchronization will be done in a non-blocking way.
Yields:
Status messages about the synchronization process.
Returns:
The URL of the created pull request or None if non_blocking=True.
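
    Example:
        A minimal, illustrative blocking call; `my-org/optimum-neuron-cache` is a placeholder repository id:

            for status in synchronize_hub_cache_with_pr(
                cache_repo_id="my-org/optimum-neuron-cache",
                commit_message="Synchronize local neuronx compiler cache",
            ):
                print(status)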
"""
# Validate cache path if provided
if cache_path is not None:
cache_path = Path(cache_path)
cache_path_str = cache_path.as_posix()
if not cache_path.is_dir():
raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.")
cache_url = CacheUrl(cache_path_str, url_type="fs")
else:
cache_url = None
# Get default cache repo if not provided
if cache_repo_id is None:
cache_repo_id = get_hf_hub_cache_repo()
# Create the hub cache proxy using the existing function
hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
# Check if S3 cache (not supported for PR workflow)
if isinstance(hub_cache_proxy.default_cache, CompileCacheS3):
raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.")
def _create_pr():
"""Internal function to create the PR"""
try:
pr_url = hub_cache_proxy.api.upload_folder(
repo_id=cache_repo_id,
folder_path=hub_cache_proxy.default_cache.cache_path,
commit_message=commit_message,
commit_description=commit_description,
ignore_patterns="lock",
create_pr=True,
token=token
)
yield f"Pull request created successfully: {pr_url}"
return pr_url
except Exception as e:
yield f"Error: Failed to create PR for cache synchronization: {e}"
raise
# Handle distributed training scenario
if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None:
# Multi-process execution
pr_url = None
if xr.local_ordinal() == 0:
# Only the first process creates the PR
            if non_blocking:
                def sync_thread():
                    # Run the PR creation in the background and log progress, since
                    # statuses yielded inside a thread cannot reach the caller.
                    try:
                        for status in _create_pr():
                            logger.info(status)
                    except Exception as e:
                        logger.error(f"Background sync failed: {e}")

                thread = threading.Thread(target=sync_thread)
                thread.start()
                yield "Cache synchronization started in background thread"
            else:
                # Forward status messages and capture the PR URL returned by _create_pr
                pr_url = yield from _create_pr()
# Synchronize all processes
xm.rendezvous("synchronize_hub_cache_pr")
return pr_url if not non_blocking else None
# Single process execution
    if non_blocking:
        def sync_thread():
            # Run the PR creation in the background and log progress, since
            # statuses yielded inside a thread cannot reach the caller.
            try:
                for status in _create_pr():
                    logger.info(status)
            except Exception as e:
                logger.error(f"Background sync failed: {e}")

        thread = threading.Thread(target=sync_thread)
        thread.start()
        yield "Cache synchronization started in background thread"
        return None
    else:
        # Forward status messages and capture the PR URL returned by _create_pr
        pr_url = yield from _create_pr()
        return pr_url
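

# Illustrative usage sketch, not part of the module's API: the repository id below is a
# placeholder. It shows one way to consume the generator and recover the pull request URL
# from its return value, which Python delivers through `StopIteration.value`.
if __name__ == "__main__":
    sync = synchronize_hub_cache_with_pr(
        cache_repo_id="my-org/optimum-neuron-cache",  # hypothetical repository id
        commit_message="Synchronize local neuronx compiler cache",
    )
    try:
        while True:
            print(next(sync))
    except StopIteration as stop:
        # The generator's return value (the PR URL, or None) is attached to StopIteration.
        print(f"Pull request URL: {stop.value}")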