kernel
rotary / tests /utils.py
YangKai0616's picture
Add triton support(XPU)
2e98b65
raw
history blame
517 Bytes
import torch
def infer_device():
    """
    Get current device name based on available devices.

    Returns:
        "cuda" if ``torch.cuda.is_available()`` reports a device (this path
        is used for both Nvidia and AMD GPUs), otherwise "xpu" if
        ``torch.xpu.is_available()`` reports a device, otherwise None.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    # The torch.xpu namespace is missing on older PyTorch builds; look it up
    # defensively so those versions fall through to None instead of raising
    # AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return None
def supports_bfloat16():
    """
    Report whether the device chosen by ``infer_device()`` supports bfloat16.

    Returns:
        True for "xpu"; for "cuda", True only when the current device's
        compute capability is at least (8, 0) (Ampere and newer); False when
        no supported device was found.
    """
    backend = infer_device()
    if backend == "xpu":
        return True
    if backend != "cuda":
        # No recognised accelerator: treat bfloat16 as unsupported.
        return False
    # NOTE(review): on ROCm builds get_device_capability() reports an
    # AMD-derived (major, minor) pair; the >= (8, 0) threshold is the
    # original's Ampere rule — confirm it is intended for AMD as well.
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)