import torch | |
def infer_device():
    """
    Get the name of the accelerator backend available to PyTorch.

    Returns:
        "cuda" if ``torch.cuda.is_available()`` is true (ROCm builds of
        PyTorch also report through ``torch.cuda``), "xpu" if an XPU device
        is available, otherwise None.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    # ``torch.xpu`` is missing from older PyTorch releases. Look it up
    # defensively so those installs fall through to None instead of
    # raising AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return None
def supports_bfloat16():
    """
    Report whether the accelerator found by ``infer_device`` handles bfloat16.

    Returns:
        True for an XPU device; for a CUDA device, True only when
        ``torch.cuda.get_device_capability()`` is at least (8, 0);
        False when no accelerator was detected.
    """
    device_type = infer_device()
    if device_type == "xpu":
        return True
    if device_type != "cuda":
        return False
    # Compute capability 8.0 is NVIDIA Ampere; anything newer also qualifies.
    # NOTE(review): on ROCm builds the capability tuple is derived from the
    # AMD GPU architecture, so confirm this threshold is meaningful there.
    min_capability = (8, 0)
    return torch.cuda.get_device_capability() >= min_capability