File size: 7,038 Bytes
29658b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import argparse
import fcntl
import os
import random
import sys
import time
from typing import List

# Upper bound, in seconds, on the randomized sleep between lock retries
# (the acquisition loops sleep for SLEEP_BACKOFF * random.random()).
SLEEP_BACKOFF = 5.0


def main():
    """
    Acquire GPU locks, export the environment and exec the wrapped command.

    With --print-only the free GPUs are probed and printed, and nothing is
    executed. Otherwise the acquired GPU ids are exported through
    CUDA_VISIBLE_DEVICES, every --env NAME=VALUE entry is applied on top of
    that, and the process is replaced by the command via exec.

    Raises:
        ValueError: if an --env entry is not of the form NAME=VALUE.

    Remark: Can use `lslocks` to debug
    """
    args = _parse_args()

    if args.print_only:
        _execute_print_only(args)
        return

    # Validate --env before waiting on locks so malformed input fails fast.
    # partition() splits on the first "=" only, so values may contain "=".
    env_overrides = []
    for env_var in args.env or []:
        name, sep, value = env_var.partition("=")
        if not sep or not name:
            raise ValueError(
                f"Invalid --env entry {env_var!r}; expected NAME=VALUE"
            )
        env_overrides.append((name, value))

    # Keep fd_locks referenced until exec: the locks live on these open files.
    fd_locks = _try_acquire(args)

    dev_list = ",".join(str(x.gpu_id) for x in fd_locks)
    os.environ["CUDA_VISIBLE_DEVICES"] = dev_list

    # Applied after CUDA_VISIBLE_DEVICES, so an explicit --env entry wins.
    for name, value in env_overrides:
        os.environ[name] = value
        # NOTE(review): this echoes the value verbatim; consider masking it
        # if --env is used to pass secrets such as tokens.
        print(
            f"[gpu_lock_exec] Setting environment variable: {name}={value}",
            flush=True,
        )
    print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True)

    _os_execvp(args)


def _os_execvp(args):
    cmd = args.cmd
    if cmd[0] == "--":
        cmd = cmd[1:]

    # propagate the environment variables
    os.execvp(cmd[0], cmd)


def _parse_args():
    p = argparse.ArgumentParser()
    p.add_argument(
        "--count", type=int, default=None, help="Acquire this many GPUs (any free ones)"
    )
    p.add_argument(
        "--devices",
        type=str,
        default=None,
        help="Comma separated explicit devices to acquire (e.g. 0,1)",
    )
    p.add_argument(
        "--total-gpus", type=int, default=8, help="Total GPUs on the machine"
    )
    p.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Seconds to wait for locks before failing",
    )
    p.add_argument(
        "--env",
        type=str,
        default=None,
        nargs="*",
        help="Environment variables to set (e.g. HF_TOKEN=1234567890)",
    )
    p.add_argument(
        "--lock-path-pattern",
        type=str,
        default="/dev/shm/custom_gpu_lock_{gpu_id}.lock",
        help='Filename pattern with "{gpu_id}" placeholder',
    )
    p.add_argument(
        "--print-only",
        action="store_true",
        help="Probe free devices and print them (does NOT hold locks)",
    )
    p.add_argument(
        "cmd",
        nargs=argparse.REMAINDER,
        help="Command to exec after '--' (required unless --print-only)",
    )
    args = p.parse_args()

    if "{gpu_id}" not in args.lock_path_pattern:
        raise Exception("ERROR: --lock-path-pattern must contain '{i}' placeholder.")

    if not args.cmd and not args.print_only:
        raise Exception("ERROR: missing command to run. Use -- before command.")

    return args


def _execute_print_only(args):
    """Probe every GPU lock once and print the ids that are currently free.

    Each lock is taken non-blockingly and released right away, so nothing is
    held when this function returns. A GPU whose lock cannot be taken is
    treated as busy; any other error is reported on stderr and that GPU is
    skipped.

    Args:
        args: parsed arguments; reads ``lock_path_pattern`` and ``total_gpus``.
    """
    free = []
    _ensure_lock_files(path_pattern=args.lock_path_pattern, total_gpus=args.total_gpus)
    for i in range(args.total_gpus):
        try:
            fd_lock = FdLock(args.lock_path_pattern, i)
            fd_lock.open()
            try:
                fd_lock.lock()
            except BlockingIOError:
                # Held by another process: not free.
                pass
            else:
                fcntl.flock(fd_lock.fd, fcntl.LOCK_UN)
                free.append(i)
            finally:
                # Always release the probe handle, even on unexpected errors.
                fd_lock.close()
        except Exception as e:
            print(
                f"Warning: Error while probing lock: {e}", file=sys.stderr, flush=True
            )

    print("Free GPUs:", ",".join(str(x) for x in free), flush=True)


def _try_acquire(args):
    """Acquire GPU locks according to the parsed arguments.

    When ``args.devices`` is set, exactly those devices are locked; otherwise
    ``args.count`` GPUs are picked among ``args.total_gpus``.

    Returns:
        The list of held lock objects produced by the selected strategy.
    """
    if not args.devices:
        return _try_acquire_count(
            args.count, args.total_gpus, args.lock_path_pattern, args.timeout
        )

    requested = _parse_devices(args.devices)
    return _try_acquire_specific(requested, args.lock_path_pattern, args.timeout)


def _try_acquire_specific(devs: List[int], path_pattern: str, timeout: int):
    fd_locks = []
    start = time.time()
    try:
        _ensure_lock_files(path_pattern, max(devs) + 1)
        for gpu_id in devs:
            fd_lock = FdLock(path_pattern, gpu_id=gpu_id)
            fd_lock.open()
            while True:
                try:
                    fd_lock.lock()
                    break
                except BlockingIOError:
                    if time.time() - start > timeout:
                        raise TimeoutError(f"Timeout while waiting for GPU {gpu_id}")
                    time.sleep(SLEEP_BACKOFF * random.random())
            fd_locks.append(fd_lock)
        return fd_locks
    except Exception as e:
        print(
            f"Error during specific GPU acquisition: {e}", file=sys.stderr, flush=True
        )
        for fd_lock in fd_locks:
            fd_lock.close()
        raise


def _try_acquire_count(count: int, total_gpus: int, path_pattern: str, timeout: int):
    start = time.time()
    _ensure_lock_files(path_pattern, total_gpus)
    while True:
        fd_locks: List = []
        for gpu_id in range(total_gpus):
            fd_lock = FdLock(path_pattern, gpu_id=gpu_id)
            fd_lock.open()
            try:
                fd_lock.lock()
            except BlockingIOError:
                fd_lock.close()
                continue

            fd_locks.append(fd_lock)
            if len(fd_locks) == count:
                return fd_locks

        gotten_gpu_ids = [x.gpu_id for x in fd_locks]
        for fd_lock in fd_locks:
            fd_lock.close()
        del fd_lock

        if time.time() - start > timeout:
            raise TimeoutError(f"Timeout acquiring {count} GPUs (out of {total_gpus})")

        print(
            f"[gpu_lock_exec] try_acquire_count failed, sleep and retry (only got: {gotten_gpu_ids})",
            flush=True,
        )
        time.sleep(SLEEP_BACKOFF * random.random())


class FdLock:
    """Exclusive, non-blocking flock on the lock file of a single GPU.

    Attributes set by ``__init__``: ``gpu_id``, ``path`` (the resolved lock
    file path) and ``fd`` (the open file object, or None while closed).
    """

    def __init__(self, path_pattern, gpu_id: int):
        """Resolve the lock file path for ``gpu_id``; nothing is opened yet."""
        self.fd = None
        self.gpu_id = gpu_id
        self.path = _get_lock_path(path_pattern, gpu_id)

    def open(self):
        """Open the lock file in append mode and mark its descriptor inheritable."""
        assert self.fd is None
        handle = open(self.path, "a+")
        self.fd = handle
        # Inheritable so the descriptor (and the lock on it) is meant to
        # survive into the program started via execvp.
        os.set_inheritable(handle.fileno(), True)

    def lock(self):
        """Take the exclusive lock without blocking.

        ``fcntl.flock`` raises BlockingIOError when the lock is already held
        through another open file.
        """
        assert self.fd is not None
        flags = fcntl.LOCK_EX | fcntl.LOCK_NB
        fcntl.flock(self.fd, flags)

    def close(self):
        """Close the lock file; a failure to close is only reported on stderr."""
        assert self.fd is not None
        handle = self.fd
        try:
            handle.close()
        except Exception as err:
            print(
                f"Warning: Failed to close file descriptor: {err}",
                file=sys.stderr,
                flush=True,
            )
        self.fd = None


def _ensure_lock_files(path_pattern: str, total_gpus: int):
    """Make sure a lock file exists for every GPU id in ``range(total_gpus)``.

    The directory part of ``path_pattern`` is created when there is one. A
    file that cannot be created is reported on stderr and skipped.
    """
    parent_dir = os.path.dirname(path_pattern)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    for idx in range(total_gpus):
        lock_path = _get_lock_path(path_pattern, idx)
        try:
            # Append mode creates the file if missing and never truncates it.
            with open(lock_path, "a"):
                pass
        except Exception as err:
            print(
                f"Warning: Could not create lock file {lock_path}: {err}",
                file=sys.stderr,
                flush=True,
            )

def _get_lock_path(path_pattern: str, gpu_id: int) -> str:
    return path_pattern.format(gpu_id=gpu_id)


def _parse_devices(s: str) -> List[int]:
    return [int(x) for x in s.split(",") if x.strip() != ""]


# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()