FlashAttention 笔记(FA1 & FA2)
本文记录 FlashAttention 的核心要点,重点用于 复习时快速回忆框架。
重要参考文章
核心是理解原论文的伪代码
https://zhuanlan.zhihu.com/p/669926191
https://zhuanlan.zhihu.com/p/691067658
下面是一些要点总结
一、FlashAttention v1 的核心目标
FlashAttention v1 (2022) 的核心目标:
设计 IO-aware attention algorithm,减少 GPU memory IO
关键问题:
- 标准 attention 需要构建
N × Nattention matrix - GPU bottleneck 不是 FLOPs,而是 memory bandwidth
- HBM ↔ SRAM 访问成本非常高
因此 FA1 的核心思想:
通过 tiling + online softmax 避免 materialize attention matrix
二、FlashAttention v1 的核心特点(复习框架)
看到下面这些关键词基本可以回忆出 FA1。
1 IO-aware algorithm
核心思想(maybe hallucination):
- attention 计算瓶颈是 memory IO
- 优化目标:
减少 HBM ↔ SRAM 访问次数
FA1 论文重点:
1 |
|
2 不显式存储 attention matrix
传统 attention:
1 |
|
需要:
1 |
|
FA1:
1 |
|
特点:
- attention matrix 不落显存
- 只在 SRAM 中短暂存在
3 Block tiling
计算方式:
1 |
|
流程:
1 |
|
特点:
- KV 按 block streaming
- 避免一次性加载全部 KV
4 Online softmax
由于 attention matrix 不完整存在,需要:
1 |
|
关键思想:
逐块维护:
1 |
|
保证 softmax 数值稳定。
5 Memory complexity 从 O(N²) 降到 O(N)
标准 attention:
1 |
|
FlashAttention:
1 |
|
原因:
- 不 materialize attention matrix
- 只保存最终 output
6 GPU kernel 特点
FA1 的 GPU kernel:
并行维度:
1 |
|
一个 thread block:
1 |
|
特点:
- 简单并行结构
- 但 GPU occupancy 不高
7 FA1 的主要贡献总结
FA1 解决的是:
1 |
|
主要贡献:
- IO-aware algorithm
- online softmax
- block tiling
- O(N²) → O(N) memory
- attention kernel 融合
三、FlashAttention v1 的局限
FA1 虽然很快,但仍有问题:
1 GPU 利用率不高
GPU 计算层级(理解这一节需要)
GPU 的执行结构大致如下:
1 |
|
关键点:
- SM 是 GPU 的主要计算单元
- 一个 GPU 通常有几十到上百个 SM
例如 A100 ≈ 108 SM
线程调度关系:
1 |
|
重要性质:
- 一个 thread block 只会运行在一个 SM 上
- 一个 SM 可以同时运行多个 block
- 但如果 block 资源占用很大(shared memory / register),
SM 可能只能放一个 block
FA1 的并行维度
FlashAttention v1 的 grid 通常是:
1 |
|
也就是说:
1 |
|
block 内部线程负责:
1 |
|
逐块计算 attention。
为什么会导致 GPU 利用率不高
假设:
1 |
|
那么:
1 |
|
而 GPU:
1 |
|
结果:
1 |
|
剩余 SM 处于空闲。
同时 FlashAttention v1 的 kernel 使用了较多:
- shared memory
- registers
这可能导致:
1 |
|
因此:
1 |
|
当 head / batch 数量较少时:
1 |
|
这就是 FA1 中常说的:
并行粒度过粗,SM 利用率不高。
2 warp communication 较多
warp 是 GPU 的基本执行单位
GPU 线程以 warp 为单位执行:
1 |
|
SM 实际调度的不是 thread,而是:
1 |
|
因此:
- warp 内通信非常快
- warp 之间通信成本较高
FA1 的 warp 分工方式
FlashAttention v1 的 kernel 使用的是:
1 |
|
也就是:
1 |
|
示意:
1 |
|
每个 warp 计算:
1 |
|
得到部分 attention 结果。
为什么需要 reduce
attention 的输出是:
1 |
|
由于 KV 被拆成多个 block:
每个 warp 只计算了 部分贡献。
因此最后需要:
1 |
|
这就需要:
1 |
|
流程:
1 |
|
warp sync 的原因
在 reduction 之前需要保证:
1 |
|
因此需要:
1 |
|
例如:
1 |
|
这些通信带来的问题
这些操作:
1 |
|
都会带来:
1 |
|
而且这些操作:
1 |
|
无法充分利用 GPU 的高吞吐计算单元。
因此:
FA1 kernel 中有较多 warp communication overhead。
为什么 FA2 会改善这个问题(顺带理解)
FlashAttention v2 改为:
1 |
|
warp 分工:
1 |
|
每个 warp:
1 |
|
因此:
1 |
|
好处:
1 |
|
kernel 更接近:
1 |
|
从而提高 GPU 利用率。
3 非 matmul 操作较多
例如:
- normalization
- softmax rescale
这些操作:
1 |
|
影响吞吐量。
四、FlashAttention v2 的改进
FlashAttention v2 (2023) 的核心目标:
提高 GPU FLOPs utilization,使 attention kernel 接近 GEMM 性能
FA2 不是新的 attention 算法,而是:
1 |
|
五、FlashAttention v2 相比 FA1 的关键改进
1 增加 sequence 维度并行
FA1 并行维度:
1 |
|
FA2:
1 |
|
效果:
- 一个 head 可拆成多个 block
- GPU occupancy 提高
2 改变 loop ordering
FA1:
1 |
|
FA2:
1 |
|
好处:
- softmax row-wise 更自然
- 数据局部性更好
3 warp partition 从 split-K 改为 split-Q
FA1:
1 |
|
需要:
1 |
|
FA2:
1 |
|
特点:
- warp 输出独立
- 不需要 reduce
4 减少 shared memory 使用
FA2 设计目标:
1 |
|
带来:
- 更少 shared memory
- 更少同步
5 减少非 matmul 操作
FA2 尽量让 kernel:
1 |
|
优化:
- normalization 延迟
- 减少 scalar ops
六、FA2 的效果
FA2 相比 FA1:
| 指标 | FA1 | FA2 |
|---|---|---|
| GPU FLOPs 利用率 | 25–40% | 50–73% |
| 速度 | baseline | ≈2× |
| 并行粒度 | batch/head | batch/head/seq |
FA2 的 attention kernel:
1 |
|
七、为什么 2022 年没有直接提出 FA2?(下面回答未经验证)
核心原因:
FA1 和 FA2 解决的是不同层级的问题
1 FA1 解决的是算法层问题
2022 年 attention 的主要瓶颈是:
1 |
|
研究重点:
1 |
|
FA1 的贡献:
1 |
|
2 FA2 主要是 GPU kernel 优化
FA2 的改进本质是:
1 |
|
属于:
1 |
|
而不是新的 attention 算法。
3 FA1 本身已经是巨大突破
当时 baseline:
1 |
|
FA1:
1 |
|
已经是革命级改进。
4 FA2 是在大量实践后出现的优化
FA1 被广泛使用后发现:
1 |
|
进一步 profiling 后:
1 |
|
才提出 FA2。
八、一句话总结
FlashAttention v1:
1 |
|
FlashAttention v2:
1 |
|
两者解决的是:
1 |
|