0%

GPU与triton学习资料(结合flash attention的思想和实现,更有助于深入理解)

GPU权威学习资料

https://skindhu.github.io/How-To-Scale-Your-Model-CN/article-trans/gpus.html

triton权威学习资料

triton tutorials

用flash attention深入理解GPU和triton

https://github.com/hkproj/triton-flash-attention
youtube和bilibili也有视频讲解

补充讲解

Triton 基础:从 GPU 到 Triton 的一条主线理解

1️⃣ GPU 的核心执行模型(理解 Triton 的前提)

GPU 的计算模型可以概括为一句话:

GPU 会并行地运行大量“相同的小程序(kernel)”,
每个程序负责处理一小块数据。

这几个事实是理解 Triton 的根基:

  1. GPU 的并行单位不是“一次算一个大矩阵”,而是成千上万个小任务并行
  2. 每个小任务在硬件上对应一个 thread block(CUDA 里叫 CTA)
  3. 每个 thread block:

    • 有自己的寄存器 / shared memory
    • 不和别的 block 共享状态
  4. GPU 的性能瓶颈通常是:

    • 内存访问(慢)
    • 而不是浮点运算(快)

因此,高性能 GPU 程序的目标永远是:

让每个 block 把数据一次性搬进来,
在本地尽量多算几次,再写回。


2️⃣ Triton 的抽象层次:只让你写“block 级逻辑”

Triton 的设计思想是:

你只需要写“一个 block 要算什么”,
不需要写 thread / warp / shared memory。

因此:

  • 一个 @triton.jit 函数 ≈ 一个 GPU kernel
  • kernel 的一次并行实例 ≈ 一个 Triton program ≈ 一个 GPU block
  • Triton 自动:

    • 把 block 映射成 warps / threads
    • 决定寄存器 / shared memory 的使用

你写的不是“线程程序”,而是“tile 程序”


3️⃣ Grid:GPU 并行的全局结构

GPU 启动 kernel 时,需要告诉它:

要启动多少个 block?这些 block 怎么排?

在 Triton 中,这由 grid 决定:

1
kernel[grid](...)

如果:

1
grid = (X, Y, Z)

那么 GPU 会并行启动 X × Y × Z 个 Triton program。

每个 program 可以通过:

1
2
3
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)

知道自己在 grid 中的“坐标”。


4️⃣ Program 内部:用 block 索引切出自己的数据

一个 Triton program 的第一件事永远是:

根据自己的 program_id,算出“我该处理哪一块数据”。

典型模式是:

1
2
start = pid * BLOCK_SIZE
offs = start + tl.arange(0, BLOCK_SIZE)

在 GPU 语义上,这等价于:

“这个 block 负责第 pid 个 tile。”

这一步完全是索引计算,不涉及任何计算。


5️⃣ Triton 的内存访问:一次 load 一个 tile

在 GPU 上,逐元素 for-loop 访问内存是灾难性的。

Triton 的内存访问模型是:

一次 tl.load,读进一个完整的二维块。

例如:

1
2
3
tl.load(base
+ row_offsets[:, None] * stride_row
+ col_offsets[None, :] * stride_col)

这在语义上是:

  • 构造一个 二维地址矩阵
  • GPU 并行地把整块数据搬进寄存器

从 GPU 角度看:

这是一次 向量化、连续、可合并的 global memory 访问


6️⃣ 计算:block 级数学,而不是线程级指令

在 Triton 中:

1
C = tl.dot(A, B)

不是 Python 层的矩阵乘,而是:

  • 编译器识别为 block-level GEMM
  • 自动映射到:

    • Tensor Core(如果尺寸和 dtype 合适)
    • 或普通 FMA

你表达的是:

“这个 tile 做一次矩阵乘”

而不是:

“某个线程算第几个元素”。


7️⃣ Program 内的循环:为什么它是“安全的”

Triton 允许:

1
2
for i in range(num_steps):
...

这并不是展开并行,而是:

  • 单个 block 内的顺序循环
  • 常用于:

    • 遍历 K/V blocks
    • streaming reduction
    • softmax

GPU 视角是:

很多 block 并行跑,
每个 block 内部顺序推进。

这正是 FlashAttention 的计算结构。


8️⃣ 写回:每个 block 负责一块输出

Triton 的设计原则是:

一个输出元素只由一个 program 写。

因此常见模式是:

1
tl.store(ptr + offsets, block)

对应 GPU 语义:

  • 无原子操作
  • 无写冲突
  • 写回的是该 block 独占的 tile

9️⃣ 把 Triton 与 FlashAttention 串起来

FlashAttention 的核心是:

  1. 把巨大 attention 矩阵切成 tile
  2. 每个 tile:

    • load Q/K/V
    • 算 softmax
    • 累积 O / 梯度
  3. 从不存整张 attention matrix

这与 Triton 的模型完全同构

FlashAttention Triton
attention tile Triton program
streaming softmax program 内 for-loop
不存 P 用寄存器 / SRAM
Q/K/V block tl.load
matmul tl.dot

10️⃣ 最终一句话总结

Triton 是一种“以 GPU block 为基本单位”的编程模型:
你描述 block 的数学计算,
Triton 负责把它高效地跑在 GPU 上。