File size: 424 Bytes
13aa528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch


def determine_accelerator() -> str:
    """
    Return the name of the preferred torch device type for this machine.

    Backends are probed in priority order: CUDA first, then Apple's MPS
    (Metal Performance Shaders) backend, and finally CPU as the fallback.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``; each is a valid argument
        to ``torch.device``.
    """
    # Prefer CUDA whenever a usable NVIDIA GPU is present.
    if torch.cuda.is_available():
        return "cuda"

    # The MPS backend module only exists in newer PyTorch releases; older
    # builds have no ``torch.backends.mps`` attribute at all. Look it up
    # defensively so those builds fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    # No accelerator detected.
    return "cpu"