写在前面:为什么要“把原理讲透”
很多优化技巧你可能已经听过:FP16、INT8、LoRA、Checkpointing、FSDP、CPU Offload……
但真正落地时常见的问题是:
- “我知道它能省显存,但为什么?省在哪一块?”
- “为什么 INT8 有时反而更慢?”
- “LoRA 的 r、alpha、target_modules 到底怎么选?”
- “Checkpointing 到底 checkpoint 哪些层才划算?”
- “FSDP 开了就炸 OOM/变慢/不收敛,怎么定位?”
这篇文章会把每项技术拆成三部分:
1) 它解决的瓶颈是什么(显存/吞吐/稳定性)
2) 它的关键原理细节(读完能自己推导/解释)
3) 实践技巧与坑点(能直接照着做)
内容基于你提供的原文要点(显存增长原因、FP16、INT8-bitsandbytes、LoRA、Gradient Checkpointing、FSDP+CPU offload 等),并在这些点上做“原理级扩展与工程化落地” 。
0. 先建立统一的显存“账本”:显存到底花在哪里?
很多人一上来就调 dtype / 开 LoRA,但没有“账本”,永远是玄学。
0.1 训练阶段显存的三大头(原文给的框架)
训练一个 Transformer,显存主要由三类组成:
- 模型参数(Weights)
- 梯度(Gradients)
- 优化器状态(Optimizer states)
以 AdamW 为例,会存一阶/二阶动量(m/v),相当于额外多一大坨参数级别的内存
此外还有:
- 激活值(Activations):前向为了反向要存中间结果(这是 Checkpointing 主要省的)
- CUDA kernel & allocator 额外开销:原文示例提到 kernel 大约会占一部分固定显存
0.2 推理阶段显存“为什么越跑越大”
原文给了关键原因:长序列带来大量 Q/K/V,以及 解码时 KV cache 会持续累积 。
你可以把推理显存理解为:
- 模型权重(固定)
- KV Cache(随
batch_size * seq_len * num_layers线性增长) - 临时 buffer(attention、matmul、softmax 的 workspace)
结论:推理时最值得优化的通常是 KV cache,而不是权重。
1) FP16 / 混合精度(Mixed Precision):为什么能省显存又能加速?
原文给了混合精度的大思路:forward/梯度用 fp16,更新用 fp32。我们把细节讲透。
1.1 原理:训练为什么不能“全程 FP16”
FP16 的动态范围小,梯度容易:
- 下溢(underflow):变成 0
- 上溢(overflow):变成 inf/nan
所以常见做法是:
- 计算图(forward/backward)用 FP16/BF16 提速省显存
- 维护一份 FP32 master weights 用于稳定更新(权重和优化器用FP32)(这就是“混合精度”的本质)
1.2 Loss Scaling:混合精度稳定性的关键机制
为了避免 FP16 梯度下溢,训练时会做:
- 把 loss 乘一个 scale(比如 2^16)
- 反向得到更“大的”梯度
- 更新前再除回来(或在优化器里处理)
并且通常使用 dynamic loss scaling:一旦发现 overflow,就把 scale 降低。
1.3 FP16 vs BF16:工程上怎么选
- BF16 动态范围接近 FP32,更稳定(但 mantissa 少)
- FP16 在部分硬件上更快,但更依赖 loss scaling
经验:
- A100/H100 这类:优先 BF16(稳定、少折腾)
- 消费级显卡:FP16 + AMP 更常见
1.4 实战技巧
技巧 A:推理用 model.half() 并不总是最优
原文提到推理可以直接 model.half(),但注意:
- 某些层(LayerNorm、softmax)在 FP16 下更易数值问题
- 更安全的做法是:让框架自动做 autocast
技巧 B:别忘了把激活也纳入考虑
混合精度不仅减少权重/梯度内存,也会减少大量中间激活的存储。
2) INT8 / bitsandbytes 量化:为什么能省显存,但不保证更快?
原文给了两个重要事实:
- INT8 表达范围极端(-128~127)
- 实践经验:某些 HuggingFace INT8 推理可能明显变慢
我们解释“为什么”。
2.1 INT8 的核心矛盾:省内存 vs 额外开销
权重 INT8 的确能把权重显存降到 1/4(相对 FP32)。
但推理速度是否更快取决于两件事:
1) 算子是否真的走 int8 kernel
很多场景会出现“看似 int8,实际混合精度”的情况:
- 激活仍是 fp16/bf16
- 某些层退化为 fp16 matmul
- 量化/反量化开销抵消了收益
2) 硬件对 int8 的吞吐是否足够友好
不同 GPU 对 int8 的优化程度差很多。
2.2 bitsandbytes 为什么能“尽量不掉点”
原文提到 bitsandbytes 通过两种手段降低误差:
1) vector-wise quantization
2) mixed precision decomposition
把它理解成一句话:
它不会傻乎乎把所有通道都硬量化,而是对“离群值/outlier”更谨慎,必要时局部回退到更高精度。
这类策略的典型动机是:Transformer 的权重/激活在某些通道会出现极端值,直接 int8 会严重损失精度。
2.3 实战技巧:什么时候用 INT8,什么时候不要用
适合 INT8:
- 单卡显存紧张,首先要“跑起来”
- 你更关心 cost,而不是极限吞吐
- QPS 不高,但需要部署更大模型
不适合 INT8:
- 追求极限 tokens/s(尤其是小 batch、短上下文的场景)
- 你的服务端已经能放下 fp16/bf16 权重,此时瓶颈更多在 KV cache/attention
定位是否“真加速”的方法:
- 用 profiler 看 matmul kernel 是否是 int8 kernel
- 看 GPU utilization:如果低且 CPU 很忙,可能被量化/搬运开销拖慢
3) LoRA:低秩适配为什么成立?“低秩”到底是什么意思?
原文对 LoRA 的描述非常关键:微调时 update matrix 往往 sparse/低秩,于是把更新重参数化为两个低秩矩阵的乘积。
我们把数学直觉讲透。
3.1 从“全参微调”到“只学习 ΔW”
全参微调等价于学习:
- 新权重:
W' = W + ΔW
如果你直接学 ΔW,它是一个大矩阵(例如 attention 的投影层通常是 d×d),参数量巨大。
3.2 低秩假设:ΔW 其实不需要满秩
秩(rank)可以理解为:
这个矩阵能表达多少“独立方向”的变化。
LoRA 的假设是:
- 微调数据往往只需要在少数“方向”上调整模型行为
- 所以 ΔW 的有效秩远小于 d
于是用:
ΔW = A BA: d×r, B: r×d- r 很小(比如 4/8/16/32)
参数量从 d×d 变成 2×d×r,当 r << d 时,巨大节省。
3.3 为什么 LoRA 对显存特别友好?
训练显存里最贵的是:
- 权重梯度
- 优化器状态(AdamW 的 m/v)
LoRA 冻结了 W,只训练 A/B:
- 梯度规模缩小到原来的
~(2r/d) - Adam 状态也同等缩小
这就是为什么 LoRA 常常是“单卡微调 7B/13B”的入场券。
3.4 实战参数怎么选(能直接抄的经验)
(1) r(rank)怎么选
- 从 8 或 16 开始
- 任务很复杂/风格差异大 → 32
- 只是轻量指令对齐/格式对齐 → 4~8
(2) alpha(缩放)是什么
LoRA 通常会在前向里加一个缩放系数,让 ΔW 的幅度可控。
经验:alpha ≈ 2r 或 alpha = r 起步都行,主要看训练是否不稳定/过拟合。
(3) target_modules 选哪些层最划算
常见选择:
- attention 的 q_proj / k_proj / v_proj / o_proj
- 或者只上 q/v(更省、效果有时也够用)
(4) dropout
- 数据很少/容易过拟合:加一点(0.05~0.1)
- 数据足够:可以 0
(5) 合并与部署
部署时你可以:
- merge 权重(把 ΔW 融回 W,推理更简单)
- 或保持 adapter(多任务热插拔更灵活)
4) Gradient Checkpointing:它到底省的是什么?为什么会变慢?
原文给出:在 torch/hf 里可用 gradient checkpointing。但很多人不知道它省的是哪块内存。
4.1 核心原理:省的是 Activations(激活值)
反向传播要用到前向的一些中间值。默认策略是:
- 前向把大量激活存下来
- 反向直接用
Checkpointing 改成:
- 前向只保存“检查点”
- 反向时在区间内 重新跑一遍前向 来恢复中间激活
所以它的本质是:
用额外计算换显存。
4.2 为什么会变慢:多了一次(或多次)前向
慢多少取决于:
- checkpoint 的粒度(存得越少,重算越多)
- 模型结构(越深越吃亏)
4.3 实战技巧:checkpoint 粒度怎么选
一个非常实用的经验:
- 优先 checkpoint Transformer block 的内部(尤其是 attention + MLP 中间激活)
- 但不要 checkpoint 太碎,否则 overhead 太大
如果你只是为了把 batch size 提上去:
- 开 checkpointing 往往是最直接、最稳定的“救命开关”
5) FSDP / ZeRO:为什么能把显存“按卡数”切开?
原文提到:FSDP 类似 DeepSpeed,通过 ZeRO 等思想把参数/梯度/优化器状态分布到多卡,而不是每卡保留完整副本。
5.1 DDP 的问题:每张卡都有一份完整模型
DDP(Data Parallel)的经典模式:
- 每卡一份完整 weights
- 每卡一份完整 grads
- 每卡一份完整 optimizer states
这对大模型来说显存爆炸。
5.2 ZeRO/FSDP 的核心:把“三大头”都分片
以最理想的 fully shard 为例:
- weights 分片
- grads 分片
- optimizer states 分片
理论上显存可以接近按卡数线性下降。
5.3 训练时怎么还能算?—— all-gather 与 reduce-scatter
你会问:既然每卡只有一片参数,怎么做前向?
答案是运行时通信:
- 需要某层参数时:把该层参数 all-gather 到本卡
- 用完后:释放/丢回分片状态
- 梯度聚合:用 reduce-scatter 直接聚合成分片梯度
5.4 实战技巧:稳定性与速度取舍
原文提到某些 issue 中 shard_grad_op 模式可能更稳定 。工程上常见现象:
- fully_shard 更省显存,但通信更频繁,可能更慢/更复杂
- shard_grad_op(更像 ZeRO-2)在一些组合上更稳
建议:
- 第一次上 FSDP:先用更“温和”的分片策略跑通
- 再逐步加深分片程度
6) CPU Offload:为什么它能“救活”显存,但也可能让你哭
原文解释了 CPU offload:在一次反传中参数动态 GPU->CPU->GPU 转移来节省显存。
6.1 原理:把显存当“缓存”,把 CPU 内存当“主存”
这和操作系统很像:
- GPU 显存太小,放不下全模型/优化器状态
- 把不常用的部分放到 CPU RAM
- 需要时再搬上来
6.2 代价:PCIe/NVLink 带宽与延迟
Offload 的成败取决于:
- 互联带宽(PCIe vs NVLink)
- 你搬运的频率与粒度
- batch size / seq len 是否让 compute 足够“覆盖”通信
典型现象:
- 能跑起来,但 tokens/s 断崖式下跌
- GPU 利用率很低,因为在等数据搬运
6.3 实战建议:什么时候值得用
- 你真的“差一点显存”,否则模型根本跑不了
- 你愿意用速度换可行性
- 或者你在做离线训练,不追求极致吞吐
7) 推理阶段:KV Cache、吞吐与采样参数的“工程真相”
7.1 为什么推理显存持续增长:KV Cache 的账
原文已经点明:逐 token 解码要缓存 K/V。
更具体一点:
- 每一层都要为历史 token 存 K/V
- cache 大小约与:
layers * batch * seq_len * hidden成正比
所以推理优化的第一原则:
你能减少的不是“权重”,而是“cache 的增长速度”。
7.2 推理速度:CPU vs GPU
原文给出经验:
- CPU 推理约 10 token/s
- 单卡 GPU 相对 CPU 大约 10:1
工程上要注意:
- 小 batch 下 GPU 可能吃不满
- 增大 batch 可提升吞吐,但 KV cache 会更快爆显存
7.3 采样参数:为什么它也影响速度与显存
原文给出调参建议(top_p、num_beams、temperature、repetition_penalty 等)。
补充一个工程视角:
- beam search(num_beams>1)几乎必然更慢更吃显存
因为同时维护多个候选序列、cache 也会膨胀 - top_p / temperature 主要影响质量与分布,不会像 beam 那样显著放大计算图
- no_repeat_ngram_size 会引入额外约束检查,也会带来一些开销
如果你要服务端极致吞吐:
- 尽量避免 beam
- 让 decoding 更“单路径”(greedy 或轻采样)
8) 组合拳:不同场景怎么选方案(直接可用的决策表)
8.1 你要“单卡跑起来”(显存紧)
优先级:
- FP16/BF16(混合精度)
- LoRA(微调场景)
- Gradient Checkpointing(训练场景)
- INT8(权重放不下时)
- CPU offload(最后底牌)
8.2 你要“更快”(吞吐/QPS)
优先级:
- 先别量化:确认瓶颈是否在 KV cache / attention
- 减少 beam、优化 batch
- 用更合适精度(BF16/FP16)
- INT8 只有在 kernel 真正走 int8 且硬件支持强时才可能更快(否则只是省显存)
8.3 你要“多卡训练更大模型”
优先级:
- FSDP/ZeRO(分片参数/梯度/优化器)
- 必要时叠加 checkpointing
- 实在不行再叠 CPU offload
结语:把每个技术当作“显存账本上的一行”
- Mixed precision:减少 weights/grads/activations 的 dtype 成本
- INT8:主要减少 weights 成本,但可能引入额外计算/搬运
- LoRA:把训练参数从大矩阵变成小低秩矩阵,连同 optimizer states 一起缩小
- Checkpointing:省 activations,但变慢
- FSDP/ZeRO:把 weights/grads/optim states 按卡数切开,通信换显存
- CPU offload:用主存换显存,带宽/延迟换可行性
- 推理 KV cache:推理显存大头,理解它才能真正优化推理
如果你愿意,我可以继续把这篇文章升级到“更硬核”的版本:
- 给出每种方法的 显存公式估算(按参数量、层数、seq_len、batch 计算 KV cache/activation)
- 给出 HuggingFace/torch 的 最小可运行配置模板(FSDP、bnb int8、LoRA、checkpointing 的组合示例)
- 增补你原文没覆盖但推理非常关键的:FlashAttention、PagedAttention、Speculative Decoding(这些才是推理加速的主战场)
你更想先补哪一块:推理加速(tokens/s) 还是 训练显存(batch size 上去)?