File size: 207 Bytes
f56ede2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
import torch

def setup_device(pipe, *, cpu_offload=True):
    """Place *pipe* on the best available device and return that device's name.

    When CUDA is available the pipeline is, by default, put under model CPU
    offloading via ``pipe.enable_model_cpu_offload()``. With offloading,
    components stay on the CPU and are moved to the GPU one at a time as they
    run. The pipeline is deliberately NOT also moved with ``pipe.to("cuda")``
    afterwards. That move would load every component onto the GPU at once
    and discard the memory savings; diffusers logs a warning for that
    combination.

    Args:
        pipe: Pipeline object exposing ``to(device)`` and
            ``enable_model_cpu_offload()``. Presumably a diffusers
            ``DiffusionPipeline``; NOTE(review): confirm against callers.
        cpu_offload: Keyword-only. When CUDA is available, use model CPU
            offloading (the default, matching the previous behaviour). If
            False, move the whole pipeline onto the GPU with ``pipe.to``
            instead. Has no effect on CPU-only machines.

    Returns:
        ``"cuda"`` if ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and cpu_offload:
        # Offloading manages per-component placement itself; following it
        # with pipe.to("cuda") would defeat it, so no explicit move here.
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    return device