| import argparse |
| import fcntl |
| import os |
| import random |
| import sys |
| import time |
| from typing import List |
|
|
# Upper bound, in seconds, of the randomized pause between lock-acquisition
# retries (callers sleep for SLEEP_BACKOFF * random.random()).
SLEEP_BACKOFF: float = 5.0
|
|
|
|
def main():
    """Acquire GPU locks, export the environment, and exec the target command.

    Flow:
      1. Parse the command line.
      2. With ``--print-only``, report the currently free GPUs and return
         without holding any lock.
      3. Otherwise acquire the requested GPU locks, publish their ids through
         ``CUDA_VISIBLE_DEVICES``, apply any ``--env NAME=VALUE`` entries and
         replace this process with the target command.

    Remark: Can use `lslocks` to debug

    Raises:
        ValueError: If an ``--env`` entry is not of the form ``NAME=VALUE``.
    """
    args = _parse_args()

    if args.print_only:
        _execute_print_only(args)
        return

    # Validate --env entries before waiting on locks, so a malformed entry
    # fails immediately instead of after a (possibly long) acquisition.
    env_pairs = []
    for env_var in args.env or []:
        # Split on the first "=" only: values may themselves contain "="
        # (e.g. base64-padded tokens).
        name, sep, value = env_var.partition("=")
        if not sep or not name:
            raise ValueError(f"Invalid --env entry {env_var!r}: expected NAME=VALUE")
        env_pairs.append((name, value))

    fd_locks = _try_acquire(args)

    dev_list = ",".join(str(x.gpu_id) for x in fd_locks)
    os.environ["CUDA_VISIBLE_DEVICES"] = dev_list

    # Applied after CUDA_VISIBLE_DEVICES, so an explicit --env entry with that
    # name overrides the value derived from the acquired locks.
    for name, value in env_pairs:
        os.environ[name] = value
        # NOTE(review): the value is echoed verbatim, which may expose secrets
        # (tokens, passwords) in logs.
        print(
            f"[gpu_lock_exec] Setting environment variable: {name}={value}",
            flush=True,
        )
    print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True)

    _os_execvp(args)
|
|
|
|
def _os_execvp(args):
    """Replace the current process with the user's command.

    ``args.cmd`` is the list of remaining command-line words; a leading ``--``
    separator, if present, is dropped before the exec. On success this call
    does not return.

    Raises:
        ValueError: If no command remains once the separator is stripped.
        OSError: Propagated from ``os.execvp`` when the program cannot be run.
    """
    cmd = list(args.cmd or [])
    if cmd and cmd[0] == "--":
        cmd = cmd[1:]

    if not cmd:
        # Without this guard `cmd[0]` below would fail with a bare IndexError.
        raise ValueError("ERROR: missing command to run. Use -- before command.")

    # execvp looks the program up on PATH; cmd[0] doubles as argv[0].
    os.execvp(cmd[0], cmd)
|
|
|
|
def _parse_args():
    """Parse and validate the command line (read from ``sys.argv``).

    Returns:
        argparse.Namespace with ``count``, ``devices``, ``total_gpus``,
        ``timeout``, ``env``, ``lock_path_pattern``, ``print_only`` and ``cmd``
        (the remaining words forming the command to run).

    Raises:
        ValueError: If ``--lock-path-pattern`` lacks the ``{gpu_id}``
            placeholder, or - unless ``--print-only`` is given - if the
            command is missing or neither ``--count`` nor ``--devices`` was
            supplied.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--count", type=int, default=None, help="Acquire this many GPUs (any free ones)"
    )
    p.add_argument(
        "--devices",
        type=str,
        default=None,
        help="Comma separated explicit devices to acquire (e.g. 0,1)",
    )
    p.add_argument(
        "--total-gpus", type=int, default=8, help="Total GPUs on the machine"
    )
    p.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Seconds to wait for locks before failing",
    )
    p.add_argument(
        "--env",
        type=str,
        default=None,
        nargs="*",
        help="Environment variables to set (e.g. HF_TOKEN=1234567890)",
    )
    p.add_argument(
        "--lock-path-pattern",
        type=str,
        default="/dev/shm/custom_gpu_lock_{gpu_id}.lock",
        help='Filename pattern with "{gpu_id}" placeholder',
    )
    p.add_argument(
        "--print-only",
        action="store_true",
        help="Probe free devices and print them (does NOT hold locks)",
    )
    p.add_argument(
        "cmd",
        nargs=argparse.REMAINDER,
        help="Command to exec after '--' (required unless --print-only)",
    )
    args = p.parse_args()

    if "{gpu_id}" not in args.lock_path_pattern:
        # The message names the placeholder that is actually checked for.
        raise ValueError(
            "ERROR: --lock-path-pattern must contain '{gpu_id}' placeholder."
        )

    if not args.print_only:
        if not args.cmd:
            raise ValueError("ERROR: missing command to run. Use -- before command.")
        # With neither option there is no acquisition target: the count-based
        # path would compare against None and only ever end by timing out.
        if args.count is None and not args.devices:
            raise ValueError("ERROR: specify either --count or --devices.")

    return args
|
|
|
|
def _execute_print_only(args):
    """Probe every GPU lock once and print the ids that are currently free.

    Each lock is tried without blocking and released straight away, so no
    lock is held when this function returns; the printed list is only a
    snapshot. A failure on an individual lock file is reported on stderr and
    that GPU is left out of the list.
    """
    free = []
    _ensure_lock_files(path_pattern=args.lock_path_pattern, total_gpus=args.total_gpus)
    for i in range(args.total_gpus):
        try:
            fd_lock = FdLock(args.lock_path_pattern, i)
            fd_lock.open()
            try:
                fd_lock.lock()
            except BlockingIOError:
                # The lock is currently held through another descriptor: busy.
                pass
            else:
                fcntl.flock(fd_lock.fd, fcntl.LOCK_UN)
                free.append(i)
            finally:
                # Always drop the probe descriptor, including when locking or
                # unlocking fails with an unexpected error.
                fd_lock.close()
        except Exception as e:
            print(
                f"Warning: Error while probing lock: {e}", file=sys.stderr, flush=True
            )

    print("Free GPUs:", ",".join(str(x) for x in free), flush=True)
|
|
|
|
def _try_acquire(args):
    """Acquire GPU locks as requested on the command line.

    An explicit ``--devices`` list takes precedence; otherwise any
    ``--count`` free GPUs are acquired. The result of the selected helper is
    returned unchanged.
    """
    if not args.devices:
        return _try_acquire_count(
            args.count, args.total_gpus, args.lock_path_pattern, args.timeout
        )

    requested = _parse_devices(args.devices)
    return _try_acquire_specific(requested, args.lock_path_pattern, args.timeout)
|
|
|
|
def _try_acquire_specific(devs: List[int], path_pattern: str, timeout: int):
    """Lock exactly the requested GPUs, retrying until ``timeout`` seconds elapse.

    Args:
        devs: GPU ids to lock. Repeated ids are ignored (first occurrence wins).
        path_pattern: Lock-file pattern containing ``{gpu_id}``.
        timeout: Overall time budget in seconds, shared by all devices.

    Returns:
        One locked ``FdLock`` per distinct id, in the order the ids were
        requested.

    Raises:
        ValueError: If ``devs`` is empty.
        TimeoutError: If a lock could not be obtained within ``timeout``.

    Any failure is reported on stderr and re-raised after every lock file
    opened here has been closed again.
    """
    # A second descriptor on the same file cannot take the lock while the
    # first one holds it, so a repeated id would wait on ourselves until the
    # timeout. Drop repeats, keeping the requested order.
    requested = list(dict.fromkeys(devs))
    # gpu_id -> FdLock for everything opened so far, locked or still waiting.
    opened = {}
    start = time.monotonic()
    try:
        if not requested:
            raise ValueError("No GPU devices specified")
        _ensure_lock_files(path_pattern, max(requested) + 1)
        # Take the locks in ascending id order so that two invocations asking
        # for overlapping sets in different orders cannot each hold one lock
        # while waiting for the other.
        for gpu_id in sorted(requested):
            fd_lock = FdLock(path_pattern, gpu_id=gpu_id)
            fd_lock.open()
            # Track right after opening so the failure path also closes a
            # descriptor that was still waiting for its lock.
            opened[gpu_id] = fd_lock
            while True:
                try:
                    fd_lock.lock()
                    break
                except BlockingIOError:
                    if time.monotonic() - start > timeout:
                        raise TimeoutError(f"Timeout while waiting for GPU {gpu_id}")
                    time.sleep(SLEEP_BACKOFF * random.random())
        return [opened[gpu_id] for gpu_id in requested]
    except Exception as e:
        print(
            f"Error during specific GPU acquisition: {e}", file=sys.stderr, flush=True
        )
        for fd_lock in opened.values():
            fd_lock.close()
        raise
|
|
|
|
def _try_acquire_count(count: int, total_gpus: int, path_pattern: str, timeout: int):
    """Lock any ``count`` free GPUs out of ``total_gpus``, all-or-nothing.

    Each attempt walks GPU ids ``0 .. total_gpus - 1`` and takes every lock
    that is free without blocking. If fewer than ``count`` could be taken, all
    of them are released again and the attempt is repeated after a randomized
    sleep.

    Args:
        count: Number of GPUs to lock; must satisfy ``1 <= count <= total_gpus``.
        total_gpus: Number of GPU ids to consider.
        path_pattern: Lock-file pattern containing ``{gpu_id}``.
        timeout: Time budget in seconds, checked after each failed attempt.

    Returns:
        List of ``count`` locked ``FdLock`` objects, in ascending id order.

    Raises:
        ValueError: If ``count`` is None or outside ``1 .. total_gpus``.
        TimeoutError: If ``count`` GPUs could not be locked within ``timeout``.
    """
    # A count that can never be reached would otherwise take and release every
    # free GPU over and over until the timeout; reject it up front.
    if count is None or count < 1 or count > total_gpus:
        raise ValueError(
            f"Invalid GPU count {count!r}: expected an integer between 1 and {total_gpus}"
        )

    start = time.monotonic()
    _ensure_lock_files(path_pattern, total_gpus)
    while True:
        fd_locks: List = []
        try:
            for gpu_id in range(total_gpus):
                fd_lock = FdLock(path_pattern, gpu_id=gpu_id)
                fd_lock.open()
                try:
                    fd_lock.lock()
                except BlockingIOError:
                    # Busy: skip this GPU for the current attempt.
                    fd_lock.close()
                    continue
                except Exception:
                    fd_lock.close()
                    raise

                fd_locks.append(fd_lock)
                if len(fd_locks) == count:
                    return fd_locks
        except Exception:
            # Unexpected failure mid-attempt: do not keep a partial set locked.
            for held in fd_locks:
                held.close()
            raise

        # Not enough free GPUs this round: give back what was taken so that
        # other waiters can make progress.
        gotten_gpu_ids = [x.gpu_id for x in fd_locks]
        for held in fd_locks:
            held.close()

        if time.monotonic() - start > timeout:
            raise TimeoutError(f"Timeout acquiring {count} GPUs (out of {total_gpus})")

        print(
            f"[gpu_lock_exec] try_acquire_count failed, sleep and retry (only got: {gotten_gpu_ids})",
            flush=True,
        )
        time.sleep(SLEEP_BACKOFF * random.random())
|
|
|
|
class FdLock:
    """Exclusive ``flock`` lock on the lock file of a single GPU.

    ``fd`` holds the open file object while the lock file is open and is
    ``None`` otherwise; ``gpu_id`` and ``path`` identify the lock file.
    """

    def __init__(self, path_pattern, gpu_id: int):
        """Record the GPU id and resolve its lock-file path; nothing is opened yet."""
        self.gpu_id = gpu_id
        self.path = _get_lock_path(path_pattern, self.gpu_id)
        self.fd = None

    def open(self):
        """Open the lock file in append mode and mark its descriptor inheritable."""
        assert self.fd is None
        handle = open(self.path, "a+")
        self.fd = handle
        # Inheritable, so the descriptor is not closed when the process later
        # execs the user's command.
        os.set_inheritable(handle.fileno(), True)

    def lock(self):
        """Take the exclusive lock without blocking.

        Raises:
            BlockingIOError: If the lock is currently held elsewhere.
        """
        assert self.fd is not None
        non_blocking_exclusive = fcntl.LOCK_EX | fcntl.LOCK_NB
        fcntl.flock(self.fd, non_blocking_exclusive)

    def close(self):
        """Close the lock file; a failure to close is only warned about on stderr."""
        assert self.fd is not None
        try:
            self.fd.close()
        except Exception as e:
            print(
                f"Warning: Failed to close file descriptor: {e}",
                file=sys.stderr,
                flush=True,
            )
        self.fd = None
|
|
|
|
def _ensure_lock_files(path_pattern: str, total_gpus: int):
    """Create the lock directory and one lock file per GPU id, best effort.

    Files are opened in append mode, so existing content is not truncated.
    A file that cannot be created only produces a warning on stderr.
    """
    parent_dir = os.path.dirname(path_pattern)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Lazily resolved, so each path is computed right before it is used.
    lock_paths = (_get_lock_path(path_pattern, idx) for idx in range(total_gpus))
    for p in lock_paths:
        try:
            with open(p, "a"):
                pass
        except Exception as e:
            print(
                f"Warning: Could not create lock file {p}: {e}",
                file=sys.stderr,
                flush=True,
            )
|
|
|
|
def _get_lock_path(path_pattern: str, gpu_id: int) -> str:
    """Return the lock-file path for ``gpu_id`` by filling the ``{gpu_id}`` placeholder."""
    lock_path = path_pattern.format(gpu_id=gpu_id)
    return lock_path
|
|
|
|
def _parse_devices(s: str) -> List[int]:
    """Parse a comma-separated device list such as ``"0,1"`` into integers.

    Blank entries (for example from a trailing comma) are skipped; any other
    entry that is not an integer makes ``int`` raise ValueError.
    """
    devices: List[int] = []
    for token in s.split(","):
        if token.strip() == "":
            continue
        devices.append(int(token))
    return devices
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|