from contextlib import contextmanager

import torch

# Optional accelerated attention implementations: each import falls back to
# None when the library is not installed, so callers can check availability.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func as flash_varlen_qkvpacked_attn
except ImportError:
    flash_varlen_qkvpacked_attn = None

try:
    from sageattention import sageattn as sage_attn
except ImportError:
    sage_attn = None

try:
    from comfy.ldm.modules.attention import comfy_optimized_attention as comfy_attn
except ImportError:
    comfy_attn = None
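
# A minimal sketch, assuming only the fallback-to-None pattern above: report
# which optional attention libraries imported successfully. The helper name
# is a hypothetical addition, not part of the original file.
def available_attention_backends():
    """Map each optional backend name to whether its import succeeded."""
    return {
        "flash_attn_varlen_qkvpacked": flash_varlen_qkvpacked_attn is not None,
        "sageattention": sage_attn is not None,
        "comfy_optimized_attention": comfy_attn is not None,
    }
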
from torch.nn.attention import SDPBackend, sdpa_kernel

# Choose SDPA backends from the compute capability of the first CUDA device.
props = torch.cuda.get_device_properties(0)
backends = []
if props.major < 7:
    # Pre-Volta GPUs lack the fused kernels; allow the math (reference) backend.
    backends.append(SDPBackend.MATH)
if props.major >= 9:
    # Hopper (SM90) and newer: prefer the cuDNN fused attention kernel.
    backends.append(SDPBackend.CUDNN_ATTENTION)
else:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)


@contextmanager
def sdpa_attn_ctx():
    """Restrict scaled_dot_product_attention to the backends selected above."""
    with sdpa_kernel(backends):
        yield
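
# A minimal usage sketch, assuming a CUDA device is present: inside
# sdpa_attn_ctx(), scaled_dot_product_attention is limited to the backends
# selected above. The shapes and dtype here are illustrative assumptions.
if __name__ == "__main__":
    import torch.nn.functional as F

    # (batch, heads, sequence, head_dim) half-precision tensors on the GPU.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    with sdpa_attn_ctx():
        out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 128, 64])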