0%

FlashAttention v1 & v2 笔记

FlashAttention 笔记(FA1 & FA2)

本文记录 FlashAttention 的核心要点,重点用于 复习时快速回忆框架


重要参考文章

核心是理解原论文的伪代码
https://zhuanlan.zhihu.com/p/669926191
https://zhuanlan.zhihu.com/p/691067658
下面是一些要点总结

一、FlashAttention v1 的核心目标

FlashAttention v1 (2022) 的核心目标:

设计 IO-aware attention algorithm,减少 GPU memory IO

关键问题:

  • 标准 attention 需要构建 N × N attention matrix
  • GPU bottleneck 不是 FLOPs,而是 memory bandwidth
  • HBM ↔ SRAM 访问成本非常高

因此 FA1 的核心思想:

通过 tiling + online softmax 避免 materialize attention matrix


二、FlashAttention v1 的核心特点(复习框架)

看到下面这些关键词基本可以回忆出 FA1。

1 IO-aware algorithm

核心思想(maybe hallucination):

  • attention 计算瓶颈是 memory IO
  • 优化目标:
    减少 HBM ↔ SRAM 访问次数

FA1 论文重点:

1
2
3

minimize HBM reads/writes


2 不显式存储 attention matrix

传统 attention:

1
2
3
4

S = QK^T
softmax(S)

需要:

1
2
3

O(N²) memory

FA1:

1
2
3

blockwise compute attention

特点:

  • attention matrix 不落显存
  • 只在 SRAM 中短暂存在

3 Block tiling

计算方式:

1
2
3

Q block × KV block

流程:

1
2
3
4
5

for each KV block:
for each Q block:
compute partial attention

特点:

  • KV 按 block streaming
  • 避免一次性加载全部 KV

4 Online softmax

由于 attention matrix 不完整存在,需要:

1
2
3

online softmax

关键思想:

逐块维护:

1
2
3
4

running max
running sum

保证 softmax 数值稳定。


5 Memory complexity 从 O(N²) 降到 O(N)

标准 attention:

1
2
3

memory = O(N²)

FlashAttention:

1
2
3

memory = O(N)

原因:

  • 不 materialize attention matrix
  • 只保存最终 output

6 GPU kernel 特点

FA1 的 GPU kernel:

并行维度:

1
2
3

parallel over (batch, head)

一个 thread block:

1
2
3

负责一个 attention head

特点:

  • 简单并行结构
  • 但 GPU occupancy 不高

7 FA1 的主要贡献总结

FA1 解决的是:

1
2
3

attention 的 memory IO 问题

主要贡献:

  • IO-aware algorithm
  • online softmax
  • block tiling
  • O(N²) → O(N) memory
  • attention kernel 融合

三、FlashAttention v1 的局限

FA1 虽然很快,但仍有问题:

1 GPU 利用率不高

GPU 计算层级(理解这一节需要)

GPU 的执行结构大致如下:

1
2
3
4
5
6

GPU
└── SM (Streaming Multiprocessor)
└── warp (32 threads)
└── thread

关键点:

  • SM 是 GPU 的主要计算单元
  • 一个 GPU 通常有几十到上百个 SM
    例如 A100 ≈ 108 SM

线程调度关系:

1
2
3
4

thread block → 分配到某个 SM
warp → SM 上真正执行的调度单位

重要性质:

  • 一个 thread block 只会运行在一个 SM 上
  • 一个 SM 可以同时运行多个 block
  • 但如果 block 资源占用很大(shared memory / register),
    SM 可能只能放一个 block

FA1 的并行维度

FlashAttention v1 的 grid 通常是:

1
2
3

(batch, head)

也就是说:

1
2
3

1 thread block ≈ 1 attention head

block 内部线程负责:

1
2
3

一个 Q block × 所有 KV blocks

逐块计算 attention。

为什么会导致 GPU 利用率不高

假设:

1
2
3
4

batch = 2
heads = 8

那么:

1
2
3

total blocks = 16

而 GPU:

1
2
3

A100 ≈ 108 SM

结果:

1
2
3

只有 16 个 SM 有工作

剩余 SM 处于空闲。

同时 FlashAttention v1 的 kernel 使用了较多:

  • shared memory
  • registers

这可能导致:

1
2
3

1 SM 只能运行 1 个 block

因此:

1
2
3

1 SM ≈ 1 head

当 head / batch 数量较少时:

1
2
3

GPU occupancy 不足

这就是 FA1 中常说的:

并行粒度过粗,SM 利用率不高。


2 warp communication 较多

warp 是 GPU 的基本执行单位

GPU 线程以 warp 为单位执行

1
2
3

1 warp = 32 threads

SM 实际调度的不是 thread,而是:

1
2
3

warp

因此:

  • warp 内通信非常快
  • warp 之间通信成本较高

FA1 的 warp 分工方式

FlashAttention v1 的 kernel 使用的是:

1
2
3

split-K

也就是:

1
2
3

不同 warp 负责不同 KV block

示意:

1
2
3
4
5
6

warp0 → KV block 0
warp1 → KV block 1
warp2 → KV block 2
warp3 → KV block 3

每个 warp 计算:

1
2
3

Q × KV_block

得到部分 attention 结果。

为什么需要 reduce

attention 的输出是:

1
2
3

softmax(QKᵀ)V

由于 KV 被拆成多个 block:

每个 warp 只计算了 部分贡献

因此最后需要:

1
2
3

partial results → 合并

这就需要:

1
2
3

shared memory reduce

流程:

1
2
3
4
5
6
7
8

warp0 写 shared memory
warp1 写 shared memory
warp2 写 shared memory
warp3 写 shared memory

→ reduction

warp sync 的原因

在 reduction 之前需要保证:

1
2
3

所有 warp 都完成计算

因此需要:

1
2
3

warp sync

例如:

1
2
3

__syncthreads()

这些通信带来的问题

这些操作:

1
2
3
4
5

shared memory access
warp synchronization
reduction

都会带来:

1
2
3

额外 latency

而且这些操作:

1
2
3

不是 TensorCore matmul

无法充分利用 GPU 的高吞吐计算单元。

因此:

FA1 kernel 中有较多 warp communication overhead


为什么 FA2 会改善这个问题(顺带理解)

FlashAttention v2 改为:

1
2
3

split-Q

warp 分工:

1
2
3
4
5
6

warp0 → Q rows 0–15
warp1 → Q rows 16–31
warp2 → Q rows 32–47
warp3 → Q rows 48–63

每个 warp:

1
2
3

独立计算自己的 output

因此:

1
2
3

不需要跨 warp reduce

好处:

1
2
3
4

减少 shared memory
减少 warp sync

kernel 更接近:

1
2
3

GEMM-style pipeline

从而提高 GPU 利用率。


3 非 matmul 操作较多

例如:

  • normalization
  • softmax rescale

这些操作:

1
2
3

不是 TensorCore friendly

影响吞吐量。

四、FlashAttention v2 的改进

FlashAttention v2 (2023) 的核心目标:

提高 GPU FLOPs utilization,使 attention kernel 接近 GEMM 性能

FA2 不是新的 attention 算法,而是:

1
2
3

更好的 GPU work partition


五、FlashAttention v2 相比 FA1 的关键改进

1 增加 sequence 维度并行

FA1 并行维度:

1
2
3

(batch, head)

FA2:

1
2
3

(batch, head, seq_block)

效果:

  • 一个 head 可拆成多个 block
  • GPU occupancy 提高

2 改变 loop ordering

FA1:

1
2
3
4

for KV block
for Q block

FA2:

1
2
3
4

for Q block
for KV block

好处:

  • softmax row-wise 更自然
  • 数据局部性更好

3 warp partition 从 split-K 改为 split-Q

FA1:

1
2
3

warp -> KV blocks

需要:

1
2
3
4

reduction
synchronization

FA2:

1
2
3

warp -> Q rows

特点:

  • warp 输出独立
  • 不需要 reduce

4 减少 shared memory 使用

FA2 设计目标:

1
2
3

减少 warp communication

带来:

  • 更少 shared memory
  • 更少同步

5 减少非 matmul 操作

FA2 尽量让 kernel:

1
2
3

接近 GEMM pipeline

优化:

  • normalization 延迟
  • 减少 scalar ops

六、FA2 的效果

FA2 相比 FA1:

指标 FA1 FA2
GPU FLOPs 利用率 25–40% 50–73%
速度 baseline ≈2×
并行粒度 batch/head batch/head/seq

FA2 的 attention kernel:

1
2
3

性能接近 GEMM


七、为什么 2022 年没有直接提出 FA2?(下面回答未经验证)

核心原因:

FA1 和 FA2 解决的是不同层级的问题


1 FA1 解决的是算法层问题

2022 年 attention 的主要瓶颈是:

1
2
3

O(N²) memory

研究重点:

1
2
3

IO complexity

FA1 的贡献:

1
2
3

IO-aware attention algorithm


2 FA2 主要是 GPU kernel 优化

FA2 的改进本质是:

1
2
3
4
5

work partition
thread mapping
occupancy optimization

属于:

1
2
3

GPU kernel engineering

而不是新的 attention 算法。


3 FA1 本身已经是巨大突破

当时 baseline:

1
2
3

PyTorch attention

FA1:

1
2
3
4

2–4x speedup
10x memory reduction

已经是革命级改进。


4 FA2 是在大量实践后出现的优化

FA1 被广泛使用后发现:

1
2
3

GPU utilization 不高

进一步 profiling 后:

1
2
3

kernel work partition 可以优化

才提出 FA2。


八、一句话总结

FlashAttention v1:

1
2
3
4

IO-aware attention algorithm
目标:减少 memory IO

FlashAttention v2:

1
2
3
4

GPU kernel optimization
目标:提高 FLOPs utilization

两者解决的是:

1
2
3

不同层级的性能瓶颈