from __future__ import annotations | |
from logging import getLogger | |
import torch | |
logger = getLogger(__name__) | |
def select_optimal_device(device: str | None) -> str:
    """
    Resolve the device string to use.

    An explicitly supplied device is returned untouched. When no device is given, the
    backends are probed in order of preference: CUDA first, then MPS, with CPU as the
    fallback. The automatically chosen device is logged at INFO level.

    :param device: The requested device, or None to have one picked automatically.
    :return: The device that was passed in, or the automatically selected one.
    """
    # An explicit choice always wins: no backend is probed and nothing is logged.
    if device is not None:
        return device

    # Conditional expressions short-circuit, so MPS is only queried when CUDA is unavailable.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    logger.info(f"Automatically selected device: {device}")
    return device