from __future__ import annotations

from logging import getLogger

import torch

logger = getLogger(__name__)


def select_optimal_device(device: str | None) -> str:
    """
    Guess what your optimal device should be based on backend availability.

    If you pass a device, we just pass it through. Otherwise CUDA is preferred,
    then MPS, and finally the CPU as the fallback.

    :param device: The device to use. If this is not None you get back what you passed.
    :return: The selected device.
    """
    # An explicit choice always wins; it is neither validated nor logged.
    if device is not None:
        return device

    # Not every torch build ships the ``torch.backends.mps`` module, so look it
    # up defensively instead of raising AttributeError during auto-detection.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info("Automatically selected device: %s", device)
    return device