"""
External module for creating PR-based cache synchronization.
"""

import logging
import os
import threading
from pathlib import Path

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy
from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo
from optimum.neuron.utils.require_utils import requires_torch_neuronx
from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3

logger = logging.getLogger(__name__)

@requires_torch_neuronx
def synchronize_hub_cache_with_pr(
    cache_path: str | Path | None = None,
    cache_repo_id: str | None = None,
    commit_message: str | None = None,
    commit_description: str | None = None,
    token: str | None = None,
    non_blocking: bool = False,
):
    """Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request.

    Args:
        cache_path (`str | Path | None`, defaults to `None`):
            The path of the folder to use for synchronization.
        cache_repo_id (`str | None`, defaults to `None`):
            The id of the HuggingFace cache repository, in the form 'org|user/name'.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the synchronization will be done in a non-blocking way.
    
    Yields:
        Status messages about the synchronization process.
        
    Returns:
        The URL of the created pull request or None if non_blocking=True.
    """
    # Validate cache path if provided
    if cache_path is not None:
        cache_path = Path(cache_path)
        cache_path_str = cache_path.as_posix()
        if not cache_path.is_dir():
            raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.")
        cache_url = CacheUrl(cache_path_str, url_type="fs")
    else:
        cache_url = None
    
    # Get default cache repo if not provided
    if cache_repo_id is None:
        cache_repo_id = get_hf_hub_cache_repo()
    
    # Create the hub cache proxy using the existing function
    hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
    
    # Check if S3 cache (not supported for PR workflow)
    # The PR workflow requires a Hub repository; S3-backed caches are not supported
    if isinstance(hub_cache_proxy.default_cache, CompileCacheS3):
        raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.")

    def _create_pr():
        """Upload the local cache folder as a pull request, yielding status messages."""
        try:
            commit_info = hub_cache_proxy.api.upload_folder(
                repo_id=cache_repo_id,
                folder_path=hub_cache_proxy.default_cache.cache_path,
                commit_message=commit_message,
                commit_description=commit_description,
                ignore_patterns="lock",
                create_pr=True,
                token=token,
            )
            # upload_folder returns a CommitInfo object; with create_pr=True the PR
            # URL is exposed on its `pr_url` attribute (the object itself stringifies
            # to the commit URL, not the PR URL).
            pr_url = commit_info.pr_url
            yield f"Pull request created successfully: {pr_url}"
            return pr_url
        except Exception as e:
            yield f"Error: Failed to create PR for cache synchronization: {e}"
            raise
    
    def _start_background_sync():
        """Run the PR creation in a background thread.

        A generator function cannot be used directly as a `Thread` target (calling it
        only creates a generator object without executing its body), so this wrapper
        drains the generator and logs the status messages instead of yielding them.
        """
        def sync_thread():
            try:
                for status in _create_pr():
                    logger.info(status)
            except Exception as e:
                logger.error(f"Background cache synchronization failed: {e}")

        threading.Thread(target=sync_thread).start()

    def _extract_pr_url(status: str) -> str | None:
        """Recover the PR URL from a success status message, if present."""
        prefix = "Pull request created successfully: "
        return status[len(prefix):] if status.startswith(prefix) else None

    # Handle distributed training scenario
    if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None:
        pr_url = None
        # Only the first local process creates the PR
        if xr.local_ordinal() == 0:
            if non_blocking:
                _start_background_sync()
                yield "Cache synchronization started in background thread"
            else:
                for status in _create_pr():
                    yield status
                    pr_url = _extract_pr_url(status) or pr_url

        # Wait until every process reaches this point before returning
        xm.rendezvous("synchronize_hub_cache_pr")

        return pr_url

    # Single process execution
    if non_blocking:
        _start_background_sync()
        yield "Cache synchronization started in background thread"
        return None

    pr_url = None
    for status in _create_pr():
        yield status
        pr_url = _extract_pr_url(status) or pr_url
    return pr_url
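

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's API above). It assumes a
# Neuron environment with a populated compiler cache; "my-org/neuron-cache" is
# a placeholder repository id. Because the function is a generator, its return
# value (the PR URL) travels in StopIteration, so the generator is advanced
# manually with next() instead of a for loop.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen = synchronize_hub_cache_with_pr(
        cache_repo_id="my-org/neuron-cache",  # placeholder, use your own repo
        commit_message="Add new compilation artifacts",
    )
    pr_url = None
    while True:
        try:
            print(next(gen))
        except StopIteration as stop:
            pr_url = stop.value  # the PR URL, or None when non_blocking=True
            break
    print(f"Created PR: {pr_url}")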