File size: 517 Bytes
2e98b65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
def infer_device():
    """
    Get current device name based on available devices.

    Returns:
        "cuda" if ``torch.cuda.is_available()`` reports a device (this also
        covers AMD GPUs on ROCm builds, which PyTorch exposes through the
        CUDA API), otherwise "xpu" if ``torch.xpu.is_available()`` reports a
        device, otherwise None.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    # Not every PyTorch build ships the ``torch.xpu`` submodule; look it up
    # defensively so such builds fall through to None instead of raising
    # AttributeError.
    xpu_backend = getattr(torch, "xpu", None)
    if xpu_backend is not None and xpu_backend.is_available():
        return "xpu"
    return None
def supports_bfloat16():
    """
    Report whether bfloat16 is treated as supported on the detected device.

    The device is chosen by ``infer_device()``:
      * "xpu"  -> always True.
      * "cuda" -> True when the current device's compute capability is at
        least (8, 0), i.e. Ampere or newer on Nvidia hardware.
      * anything else (no accelerator) -> False.

    NOTE(review): on ROCm builds ``get_device_capability()`` reports the AMD
    architecture version, so the Ampere threshold may not map cleanly there;
    confirm for the AMD targets you care about.
    """
    backend = infer_device()
    if backend == "xpu":
        return True
    if backend != "cuda":
        return False
    # (major, minor) tuples compare lexicographically, so (8, 0) and later pass.
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)