GPU权威学习资料
https://skindhu.github.io/How-To-Scale-Your-Model-CN/article-trans/gpus.html
triton权威学习资料
triton tutorials
用flash attention深入理解GPU和triton
https://github.com/hkproj/triton-flash-attention
youtube和bilibili也有视频讲解
补充讲解
Triton 基础:从 GPU 到 Triton 的一条主线理解
1️⃣ GPU 的核心执行模型(理解 Triton 的前提)
GPU 的计算模型可以概括为一句话:
GPU 会并行地运行大量“相同的小程序(kernel)”,
每个程序负责处理一小块数据。
这几个事实是理解 Triton 的根基:
- GPU 的并行单位不是“一次算一个大矩阵”,而是成千上万个小任务并行。
- 每个小任务在硬件上对应一个 thread block(CUDA 里叫 CTA)。
每个 thread block:
- 有自己的寄存器 / shared memory
- 不和别的 block 共享状态
GPU 的性能瓶颈通常是:
- 内存访问(慢)
- 而不是浮点运算(快)
因此,高性能 GPU 程序的目标永远是:
让每个 block 把数据一次性搬进来,
在本地尽量多算几次,再写回。
2️⃣ Triton 的抽象层次:只让你写“block 级逻辑”
Triton 的设计思想是:
你只需要写“一个 block 要算什么”,
不需要写 thread / warp / shared memory。
因此:
- 一个
@triton.jit函数 ≈ 一个 GPU kernel - kernel 的一次并行实例 ≈ 一个 Triton program ≈ 一个 GPU block
Triton 自动:
- 把 block 映射成 warps / threads
- 决定寄存器 / shared memory 的使用
你写的不是“线程程序”,而是“tile 程序”。
3️⃣ Grid:GPU 并行的全局结构
GPU 启动 kernel 时,需要告诉它:
要启动多少个 block?这些 block 怎么排?
在 Triton 中,这由 grid 决定:
1 | kernel[grid](...) |
如果:
1 | grid = (X, Y, Z) |
那么 GPU 会并行启动 X × Y × Z 个 Triton program。
每个 program 可以通过:
1 | pid0 = tl.program_id(0) |
知道自己在 grid 中的“坐标”。
4️⃣ Program 内部:用 block 索引切出自己的数据
一个 Triton program 的第一件事永远是:
根据自己的
program_id,算出“我该处理哪一块数据”。
典型模式是:
1 | start = pid * BLOCK_SIZE |
在 GPU 语义上,这等价于:
“这个 block 负责第
pid个 tile。”
这一步完全是索引计算,不涉及任何计算。
5️⃣ Triton 的内存访问:一次 load 一个 tile
在 GPU 上,逐元素 for-loop 访问内存是灾难性的。
Triton 的内存访问模型是:
一次
tl.load,读进一个完整的二维块。
例如:
1 | tl.load(base |
这在语义上是:
- 构造一个 二维地址矩阵
- GPU 并行地把整块数据搬进寄存器
从 GPU 角度看:
这是一次 向量化、连续、可合并的 global memory 访问。
6️⃣ 计算:block 级数学,而不是线程级指令
在 Triton 中:
1 | C = tl.dot(A, B) |
不是 Python 层的矩阵乘,而是:
- 编译器识别为 block-level GEMM
自动映射到:
- Tensor Core(如果尺寸和 dtype 合适)
- 或普通 FMA
你表达的是:
“这个 tile 做一次矩阵乘”
而不是:
“某个线程算第几个元素”。
7️⃣ Program 内的循环:为什么它是“安全的”
Triton 允许:
1 | for i in range(num_steps): |
这并不是展开并行,而是:
- 单个 block 内的顺序循环
常用于:
- 遍历 K/V blocks
- streaming reduction
- softmax
GPU 视角是:
很多 block 并行跑,
每个 block 内部顺序推进。
这正是 FlashAttention 的计算结构。
8️⃣ 写回:每个 block 负责一块输出
Triton 的设计原则是:
一个输出元素只由一个 program 写。
因此常见模式是:
1 | tl.store(ptr + offsets, block) |
对应 GPU 语义:
- 无原子操作
- 无写冲突
- 写回的是该 block 独占的 tile
9️⃣ 把 Triton 与 FlashAttention 串起来
FlashAttention 的核心是:
- 把巨大 attention 矩阵切成 tile
每个 tile:
- load Q/K/V
- 算 softmax
- 累积 O / 梯度
- 从不存整张 attention matrix
这与 Triton 的模型完全同构:
| FlashAttention | Triton |
|---|---|
| attention tile | Triton program |
| streaming softmax | program 内 for-loop |
| 不存 P | 用寄存器 / SRAM |
| Q/K/V block | tl.load |
| matmul | tl.dot |
10️⃣ 最终一句话总结
Triton 是一种“以 GPU block 为基本单位”的编程模型:
你描述 block 的数学计算,
Triton 负责把它高效地跑在 GPU 上。