wise-water's picture
init commit
13aa528
import torch
def determine_accelerator() -> str:
    """
    Determine the accelerator to be used based on the environment.

    Backends are checked in priority order: CUDA first, then Apple's MPS
    (Metal Performance Shaders), and finally CPU as the fallback.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # Check for CUDA availability
    if torch.cuda.is_available():
        return "cuda"
    # Check for MPS (Metal Performance Shaders) availability on macOS.
    # ``torch.backends.mps`` is absent on older PyTorch releases, so look it
    # up defensively instead of letting an AttributeError escape.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    # Default to CPU if no accelerators are available
    return "cpu"