import torch
from einops import rearrange

try:
    from flash_attn_interface import flash_attn_varlen_func

    print("Using FlashAttention v3.")
except ImportError:
    print("FlashAttention v3 not found, falling back to v2.")
    from flash_attn import flash_attn_varlen_func

from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input


def get_cu_seqlens(text_mask: torch.Tensor, img_len: int):
    """
    Compute cumulative sequence lengths (cu_seqlens) for FlashAttention.

    Args:
        text_mask (torch.Tensor): Boolean mask of shape (batch_size, text_seq_len);
            True marks valid text tokens.
        img_len (int): Number of image tokens appended to each text sequence.

    Returns:
        cu_seqlens (torch.Tensor): int32 tensor of shape (2 * batch_size + 1,). Each
            sample contributes two segments: the valid (text + image) tokens,
            then the padding tail up to max_len.
        max_len (int): Padded per-sample sequence length (text_seq_len + img_len).
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)
    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s  # end of the valid (text + image) segment
        s2 = (i + 1) * max_len  # end of the padding segment for this sample
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens, max_len
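

# Worked example (illustrative values, not part of the original module): with
# text_seq_len = 4 and img_len = 3, a mask [[1, 1, 0, 0], [1, 1, 1, 0]] gives
# text lengths (2, 3); each sample becomes a valid segment of length
# text_len + img_len followed by a padding segment, up to max_len = 7:
#
#     text_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=torch.bool)
#     cu_seqlens, max_len = get_cu_seqlens(text_mask, img_len=3)
#     # cu_seqlens -> tensor([0, 5, 7, 13, 14], dtype=torch.int32)
#     # max_len    -> 7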


def flash_attn_v3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_s: int,
    causal: bool = False,
    deterministic: bool = False,
):
    """
    FlashAttention v3 wrapper.

    Args:
        q, k, v (torch.Tensor): Query, key, value tensors of shape (batch, seq, nheads, head_dim).
        cu_seqlens (torch.Tensor): Cumulative sequence lengths, shared by q and k
            (self-attention over the packed layout).
        max_s (int): Maximum segment length.
        causal (bool): Whether to apply causal masking.
        deterministic (bool): Deterministic computation.

    Returns:
        torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
    """
    batch_size, seqlen = q.shape[:2]
    # Flatten (batch, seq, ...) into the packed (total_tokens, ...) layout
    # expected by the varlen kernel; cu_seqlens marks the segment boundaries.
    q = q.reshape(-1, *q.shape[2:])
    k = k.reshape(-1, *k.shape[2:])
    v = v.reshape(-1, *v.shape[2:])
    output = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max_s, max_s, causal=causal, deterministic=deterministic
    )
    # Some FlashAttention v3 builds return (out, softmax_lse); keep only the output.
    if isinstance(output, tuple):
        output = output[0]
    output = output.view(batch_size, seqlen, *output.shape[-2:])
    return output
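

# Usage sketch (assumed shapes; not part of the original module): q, k, v are in
# the padded (batch, max_len, nheads, head_dim) layout described by
# get_cu_seqlens above, so padding tokens attend only within their own segment:
#
#     cu_seqlens, max_len = get_cu_seqlens(text_mask, img_len)
#     out = flash_attn_v3(q, k, v, cu_seqlens, max_len)
#     # out.shape == (batch, max_len, nheads, head_dim)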


def flash_attn_no_pad(
    qkv: torch.Tensor,
    key_padding_mask: torch.Tensor,
    causal: bool = False,
    dropout_p: float = 0.0,
    softmax_scale=None,
    deterministic: bool = False,
):
    """
    FlashAttention for packed QKV input without padding.

    Args:
        qkv (torch.Tensor): Input tensor of shape (batch, seq, 3, nheads, head_dim).
        key_padding_mask (torch.Tensor): Boolean mask of shape (batch, seq);
            True for tokens to keep, False for padding.
        causal (bool): Whether to apply causal masking.
        dropout_p (float): Dropout probability.
        softmax_scale (float, optional): Softmax scaling factor.
        deterministic (bool): Deterministic computation.

    Returns:
        torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
    """
    batch_size, seqlen, _, nheads, head_dim = qkv.shape
    x = rearrange(qkv, "b s three h d -> b s (three h d)")

    # Unpad input for FlashAttention, drop `used_seqlens_in_batch` for version compatibility
    x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)

    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        deterministic=deterministic,
    )
    if isinstance(output_unpad, tuple):
        output_unpad = output_unpad[0]

    # Pad output back to original shape
    output = pad_input(
        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
        indices,
        batch_size,
        seqlen,
    )
    output = rearrange(output, "b s (h d) -> b s h d", h=nheads)
    return output
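

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; shapes are arbitrary).
    # FlashAttention kernels require a CUDA device and fp16/bf16 inputs,
    # so this only runs when a GPU is available.
    if torch.cuda.is_available():
        batch, seq, nheads, head_dim = 2, 16, 4, 64
        qkv = torch.randn(
            batch, seq, 3, nheads, head_dim, device="cuda", dtype=torch.float16
        )
        mask = torch.ones(batch, seq, dtype=torch.bool, device="cuda")
        mask[1, -4:] = False  # treat the last 4 tokens of sample 1 as padding
        out = flash_attn_no_pad(qkv, mask)
        print("flash_attn_no_pad output:", tuple(out.shape))  # (2, 16, 4, 64)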