# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings

import torch
import xformers.ops as xops
from xformers.ops import fmha

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

__all__ = [
    'flash_attention',
    'attention',
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.float16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes, f"dtype must be float16 or bfloat16, got {dtype}"
assert q.device.type == "cuda" and q.size(-1) <= 256
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
    # preprocess query
if q_lens is None:
q = half(q.flatten(0, 1)) # [B*Lq, Nq, C1]
q_lens = torch.full((b,), lq, dtype=torch.int32, device=q.device)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)], dim=0))
    # preprocess key and value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.full((b,), lk, dtype=torch.int32, device=k.device)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)], dim=0))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)], dim=0))
    # make sure q, k, v share the same dtype
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
if q_scale is not None:
q = q * q_scale
    # expand key/value heads to match the number of query heads (GQA/MQA);
    # repeat_interleave keeps the standard grouping convention where query
    # head i attends to key/value head i // (Nq // Nk)
    n_q_heads = q.size(1)
    n_k_heads = k.size(1)
    if n_k_heads != n_q_heads:
        assert n_q_heads % n_k_heads == 0, "Nq must be divisible by Nk"
        repeat_factor = n_q_heads // n_k_heads
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)
    # sliding-window attention is not supported by the xFormers backend;
    # fall back to full (global) attention
    window_size = (-1, -1)
    # build a block-diagonal attention bias so that sequences from different
    # batch elements cannot attend to each other
    q_lens_list = q_lens.cpu().tolist()
    k_lens_list = k_lens.cpu().tolist()
    if causal:
        # note: xFormers' block-diagonal causal mask is top-left aligned
        attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
            q_seqlen=q_lens_list, kv_seqlen=k_lens_list)
    else:
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            q_seqlen=q_lens_list, kv_seqlen=k_lens_list)
    # add a dummy batch dimension; xFormers expects [B, M, H, K] inputs
q = q.unsqueeze(0) # [1, sum_q, nh, hd]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
    # run xFormers' memory-efficient attention
    x = xops.memory_efficient_attention(
        q, k, v,
        attn_bias=attn_bias,
        p=dropout_p,
        scale=softmax_scale,
        # `deterministic` is not forwarded here; memory_efficient_attention may not support it
    )
    # drop the dummy batch dimension and restore [B, Lq, Nq, C2];
    # this reshape assumes every query sequence has the full length Lq
    x = x.squeeze(0).unflatten(0, (b, lq))
return x.to(out_dtype)
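
# Illustrative usage sketch for flash_attention(); the shapes and lengths below
# are arbitrary example values chosen for illustration, not values taken from
# the Wan models:
#
#     q = torch.randn(2, 1024, 12, 128, device='cuda', dtype=torch.bfloat16)
#     k = torch.randn(2,  512, 12, 128, device='cuda', dtype=torch.bfloat16)
#     v = torch.randn(2,  512, 12, 128, device='cuda', dtype=torch.bfloat16)
#     k_lens = torch.tensor([512, 300], dtype=torch.int32, device='cuda')
#     out = flash_attention(q, k, v, k_lens=k_lens, dtype=torch.bfloat16)
#     # out: [2, 1024, 12, 128]
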
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
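

if __name__ == "__main__":
    # Minimal smoke-test sketch; the sizes below are arbitrary example values.
    # flash_attention() asserts that q lives on a CUDA device, so the test is
    # skipped on CPU-only machines.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        b, lq, lk, heads, dim = 2, 64, 48, 8, 64
        q = torch.randn(b, lq, heads, dim, device=device, dtype=torch.bfloat16)
        k = torch.randn(b, lk, heads, dim, device=device, dtype=torch.bfloat16)
        v = torch.randn(b, lk, heads, dim, device=device, dtype=torch.bfloat16)
        out = attention(q, k, v, dtype=torch.bfloat16)
        # expected output shape: (2, 64, 8, 64), i.e. [B, Lq, Nq, C2]
        print("attention output shape:", tuple(out.shape))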