---
license: bsd-3-clause
tags:
  - kernel
---

The latest build (`b58ed97`) may contain an accuracy issue, which is currently being addressed. Please use it with caution; a corrected build will be available soon.

# Flash Attention

Flash Attention is a fast and memory-efficient implementation of the attention mechanism, designed to work with large models and long sequences. This is a Hugging Face compliant kernel build of Flash Attention.

The original code is available at https://github.com/Dao-AILab/flash-attention.
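
The example below loads the kernel with `get_kernel` from the `kernels` library and walks through the three forward entry points (standard, variable-length, and KV-cache attention). Its dependencies are declared as inline script metadata (PEP 723), so it can be run directly with a script runner that understands that format.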

```python
# /// script
# dependencies = ["numpy", "torch", "kernels"]
# ///
import torch
from kernels import get_kernel

# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Show available functions
print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])

# 1. Standard attention
print("\n1. Standard attention:")
B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
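# mha_fwd returns a tuple; index 0 is the attention output, with the same shape as q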
out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
print(f"Output: {out.shape}")

# 2. Variable length sequences
print("\n2. Variable length sequences:")
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
# For 3 sequences with lengths [3,4,3] for q and [4,5,3] for k
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
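# cu_seqlens_* are cumulative offsets (prefix sums) marking where each packed sequence starts/ends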
out_var = flash_attn.mha_varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,
    max_seqlen_k=5,
)[0]
print(f"Output: {out_var.shape}")

# 3. KV-cache for autoregressive generation
print("\n3. KV-cache:")
cache_len, new_len = 10, 2
kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
q_new = k_new = v_new = torch.randn(
    B, new_len, H, D, device=device, dtype=torch.float16
)
seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
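# q_new attends over the cached keys/values together with the newly supplied k_new/v_new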
out_kv = flash_attn.mha_fwd_kvcache(
    q=q_new,
    kcache=kcache,
    vcache=vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens,
    is_causal=True,
)[0]
print(f"Output: {out_kv.shape}")

## Expected output

```
Fetching 3 files: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3/3 [00:00<00:00, 16384.00it/s]
Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']

1. Standard attention:
Output: torch.Size([2, 5, 4, 8])

2. Variable length sequences:
Output: torch.Size([10, 4, 8])

3. KV-cache:
Output: torch.Size([2, 2, 4, 8])
```
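
The kernel also exposes `mha_bwd` and `mha_varlen_bwd` for the backward pass; they are not exercised in this example.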