import torch
from einops import rearrange

try:
    from flash_attn_interface import flash_attn_varlen_func
    print("Using FlashAttention v3.")
except ImportError:
    print("FlashAttention v3 not found, falling back to v2.")
    from flash_attn import flash_attn_varlen_func

from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input


def get_cu_seqlens(text_mask: torch.Tensor, img_len: int):
    """
    Compute cumulative sequence lengths (cu_seqlens) for FlashAttention.

    Args:
        text_mask (torch.Tensor): Boolean mask of shape (batch_size, text_seq_len).
        img_len (int): Length of the image sequence.

    Returns:
        cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths for each segment.
        max_len (int): Maximum sequence length (text + image).
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    # Each batch element contributes two segments: its valid (text + image) tokens and the
    # remaining padded positions, hence 2 * batch_size boundaries plus the leading zero.
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)

    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s        # end of the valid (text + image) tokens
        s2 = (i + 1) * max_len      # end of the padded region
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens, max_len


def flash_attn_v3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_s: int,
    causal: bool = False,
    deterministic: bool = False,
):
    """
    FlashAttention v3 wrapper.

    Args:
        q, k, v (torch.Tensor): Query, key, value tensors of shape (batch, seq, nheads, head_dim).
        cu_seqlens (torch.Tensor): Cumulative sequence lengths.
        max_s (int): Maximum sequence length.
        causal (bool): Whether to apply causal masking.
        deterministic (bool): Whether to use deterministic computation.

    Returns:
        torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
    """
    batch_size, seqlen = q.shape[:2]

    # Flatten (batch, seq) into a single token dimension for the varlen kernel.
    q = q.reshape(-1, *q.shape[2:])
    k = k.reshape(-1, *k.shape[2:])
    v = v.reshape(-1, *v.shape[2:])

    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        causal=causal,
        deterministic=deterministic,
    )
    output = output.view(batch_size, seqlen, *output.shape[-2:])
    return output


def flash_attn_no_pad(
    qkv: torch.Tensor,
    key_padding_mask: torch.Tensor,
    causal: bool = False,
    dropout_p: float = 0.0,
    softmax_scale=None,
    deterministic: bool = False,
):
    """
    FlashAttention for packed QKV input without padding.

    Args:
        qkv (torch.Tensor): Input tensor of shape (batch, seq, 3, nheads, head_dim).
        key_padding_mask (torch.Tensor): Boolean mask of shape (batch, seq).
        causal (bool): Whether to apply causal masking.
        dropout_p (float): Dropout probability.
        softmax_scale (float, optional): Softmax scaling factor.
        deterministic (bool): Whether to use deterministic computation.

    Returns:
        torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
    """
    batch_size, seqlen, _, nheads, head_dim = qkv.shape
    x = rearrange(qkv, "b s three h d -> b s (three h d)")

    # Unpad input for FlashAttention; keep only the first four return values
    # (drop `used_seqlens_in_batch`) for compatibility across flash-attn versions.
    x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)

    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        deterministic=deterministic,
    )
    if isinstance(output_unpad, tuple):
        output_unpad = output_unpad[0]

    # Pad output back to the original (batch, seq, nheads, head_dim) shape.
    output = pad_input(
        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
        indices,
        batch_size,
        seqlen,
    )
    output = rearrange(output, "b s (h d) -> b s h d", h=nheads)
    return output
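

# Minimal usage sketch: assumes a CUDA device and an installed flash-attn build
# (the kernels require fp16/bf16 inputs); shapes and values below are illustrative only.
if __name__ == "__main__":
    device = "cuda"
    batch_size, text_seq_len, img_len = 2, 16, 32
    nheads, head_dim = 8, 64

    # Boolean text mask: sample 0 keeps 10 text tokens, sample 1 keeps 12.
    text_mask = torch.zeros(batch_size, text_seq_len, dtype=torch.bool, device=device)
    text_mask[0, :10] = True
    text_mask[1, :12] = True

    # Segment boundaries for the combined text + image sequence.
    cu_seqlens, max_len = get_cu_seqlens(text_mask, img_len)
    print(cu_seqlens.tolist(), max_len)  # [0, 42, 48, 92, 96] 48

    # Packed-QKV attention over the text tokens, skipping padded positions.
    qkv = torch.randn(
        batch_size, text_seq_len, 3, nheads, head_dim,
        dtype=torch.float16, device=device,
    )
    out = flash_attn_no_pad(qkv, text_mask)
    print(out.shape)  # torch.Size([2, 16, 8, 64])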