0%

LLM 训练/微调/推理加速与省内存技术:原理拆解 + 实战技巧(读完能真正用起来)

写在前面:为什么要“把原理讲透”

很多优化技巧你可能已经听过:FP16、INT8、LoRA、Checkpointing、FSDP、CPU Offload……
但真正落地时常见的问题是:

  • “我知道它能省显存,但为什么?省在哪一块?”
  • “为什么 INT8 有时反而更慢?”
  • “LoRA 的 r、alpha、target_modules 到底怎么选?”
  • “Checkpointing 到底 checkpoint 哪些层才划算?”
  • “FSDP 开了就炸 OOM/变慢/不收敛,怎么定位?”

这篇文章会把每项技术拆成三部分:
1) 它解决的瓶颈是什么(显存/吞吐/稳定性)
2) 它的关键原理细节(读完能自己推导/解释)
3) 实践技巧与坑点(能直接照着做)

内容基于你提供的原文要点(显存增长原因、FP16、INT8-bitsandbytes、LoRA、Gradient Checkpointing、FSDP+CPU offload 等),并在这些点上做“原理级扩展与工程化落地” 。


0. 先建立统一的显存“账本”:显存到底花在哪里?

很多人一上来就调 dtype / 开 LoRA,但没有“账本”,永远是玄学。

0.1 训练阶段显存的三大头(原文给的框架)

训练一个 Transformer,显存主要由三类组成:

  1. 模型参数(Weights)
  2. 梯度(Gradients)
  3. 优化器状态(Optimizer states)
    以 AdamW 为例,会存一阶/二阶动量(m/v),相当于额外多一大坨参数级别的内存

此外还有:

  • 激活值(Activations):前向为了反向要存中间结果(这是 Checkpointing 主要省的)
  • CUDA kernel & allocator 额外开销:原文示例提到 kernel 大约会占一部分固定显存

0.2 推理阶段显存“为什么越跑越大”

原文给了关键原因:长序列带来大量 Q/K/V,以及 解码时 KV cache 会持续累积

你可以把推理显存理解为:

  • 模型权重(固定)
  • KV Cache(随 batch_size * seq_len * num_layers 线性增长)
  • 临时 buffer(attention、matmul、softmax 的 workspace)

结论:推理时最值得优化的通常是 KV cache,而不是权重。


1) FP16 / 混合精度(Mixed Precision):为什么能省显存又能加速?

原文给了混合精度的大思路:forward/梯度用 fp16,更新用 fp32。我们把细节讲透。

1.1 原理:训练为什么不能“全程 FP16”

FP16 的动态范围小,梯度容易:

  • 下溢(underflow):变成 0
  • 上溢(overflow):变成 inf/nan

所以常见做法是:

  • 计算图(forward/backward)用 FP16/BF16 提速省显存
  • 维护一份 FP32 master weights 用于稳定更新(权重和优化器用FP32)(这就是“混合精度”的本质)

1.2 Loss Scaling:混合精度稳定性的关键机制

为了避免 FP16 梯度下溢,训练时会做:

  • 把 loss 乘一个 scale(比如 2^16)
  • 反向得到更“大的”梯度
  • 更新前再除回来(或在优化器里处理)

并且通常使用 dynamic loss scaling:一旦发现 overflow,就把 scale 降低。

1.3 FP16 vs BF16:工程上怎么选

  • BF16 动态范围接近 FP32,更稳定(但 mantissa 少)
  • FP16 在部分硬件上更快,但更依赖 loss scaling

经验:

  • A100/H100 这类:优先 BF16(稳定、少折腾)
  • 消费级显卡:FP16 + AMP 更常见

1.4 实战技巧

技巧 A:推理用 model.half() 并不总是最优
原文提到推理可以直接 model.half(),但注意:

  • 某些层(LayerNorm、softmax)在 FP16 下更易数值问题
  • 更安全的做法是:让框架自动做 autocast

技巧 B:别忘了把激活也纳入考虑
混合精度不仅减少权重/梯度内存,也会减少大量中间激活的存储。


2) INT8 / bitsandbytes 量化:为什么能省显存,但不保证更快?

原文给了两个重要事实:

  • INT8 表达范围极端(-128~127)
  • 实践经验:某些 HuggingFace INT8 推理可能明显变慢

我们解释“为什么”。

2.1 INT8 的核心矛盾:省内存 vs 额外开销

权重 INT8 的确能把权重显存降到 1/4(相对 FP32)。
但推理速度是否更快取决于两件事:

1) 算子是否真的走 int8 kernel
很多场景会出现“看似 int8,实际混合精度”的情况:

  • 激活仍是 fp16/bf16
  • 某些层退化为 fp16 matmul
  • 量化/反量化开销抵消了收益

2) 硬件对 int8 的吞吐是否足够友好
不同 GPU 对 int8 的优化程度差很多。

2.2 bitsandbytes 为什么能“尽量不掉点”

原文提到 bitsandbytes 通过两种手段降低误差:
1) vector-wise quantization
2) mixed precision decomposition

把它理解成一句话:

它不会傻乎乎把所有通道都硬量化,而是对“离群值/outlier”更谨慎,必要时局部回退到更高精度。

这类策略的典型动机是:Transformer 的权重/激活在某些通道会出现极端值,直接 int8 会严重损失精度。

2.3 实战技巧:什么时候用 INT8,什么时候不要用

适合 INT8:

  • 单卡显存紧张,首先要“跑起来”
  • 你更关心 cost,而不是极限吞吐
  • QPS 不高,但需要部署更大模型

不适合 INT8:

  • 追求极限 tokens/s(尤其是小 batch、短上下文的场景)
  • 你的服务端已经能放下 fp16/bf16 权重,此时瓶颈更多在 KV cache/attention

定位是否“真加速”的方法:

  • 用 profiler 看 matmul kernel 是否是 int8 kernel
  • 看 GPU utilization:如果低且 CPU 很忙,可能被量化/搬运开销拖慢

3) LoRA:低秩适配为什么成立?“低秩”到底是什么意思?

原文对 LoRA 的描述非常关键:微调时 update matrix 往往 sparse/低秩,于是把更新重参数化为两个低秩矩阵的乘积。

我们把数学直觉讲透。

3.1 从“全参微调”到“只学习 ΔW”

全参微调等价于学习:

  • 新权重:W' = W + ΔW

如果你直接学 ΔW,它是一个大矩阵(例如 attention 的投影层通常是 d×d),参数量巨大。

3.2 低秩假设:ΔW 其实不需要满秩

秩(rank)可以理解为:

这个矩阵能表达多少“独立方向”的变化。

LoRA 的假设是:

  • 微调数据往往只需要在少数“方向”上调整模型行为
  • 所以 ΔW 的有效秩远小于 d

于是用:

  • ΔW = A B
  • A: d×r, B: r×d
  • r 很小(比如 4/8/16/32)

参数量从 d×d 变成 2×d×r,当 r << d 时,巨大节省。

3.3 为什么 LoRA 对显存特别友好?

训练显存里最贵的是:

  • 权重梯度
  • 优化器状态(AdamW 的 m/v)

LoRA 冻结了 W,只训练 A/B:

  • 梯度规模缩小到原来的 ~(2r/d)
  • Adam 状态也同等缩小

这就是为什么 LoRA 常常是“单卡微调 7B/13B”的入场券。

3.4 实战参数怎么选(能直接抄的经验)

(1) r(rank)怎么选

  • 从 8 或 16 开始
  • 任务很复杂/风格差异大 → 32
  • 只是轻量指令对齐/格式对齐 → 4~8

(2) alpha(缩放)是什么
LoRA 通常会在前向里加一个缩放系数,让 ΔW 的幅度可控。
经验:alpha ≈ 2ralpha = r 起步都行,主要看训练是否不稳定/过拟合。

(3) target_modules 选哪些层最划算
常见选择:

  • attention 的 q_proj / k_proj / v_proj / o_proj
  • 或者只上 q/v(更省、效果有时也够用)

(4) dropout

  • 数据很少/容易过拟合:加一点(0.05~0.1)
  • 数据足够:可以 0

(5) 合并与部署
部署时你可以:

  • merge 权重(把 ΔW 融回 W,推理更简单)
  • 或保持 adapter(多任务热插拔更灵活)

4) Gradient Checkpointing:它到底省的是什么?为什么会变慢?

原文给出:在 torch/hf 里可用 gradient checkpointing。但很多人不知道它省的是哪块内存。

4.1 核心原理:省的是 Activations(激活值)

反向传播要用到前向的一些中间值。默认策略是:

  • 前向把大量激活存下来
  • 反向直接用

Checkpointing 改成:

  • 前向只保存“检查点”
  • 反向时在区间内 重新跑一遍前向 来恢复中间激活

所以它的本质是:

用额外计算换显存。

4.2 为什么会变慢:多了一次(或多次)前向

慢多少取决于:

  • checkpoint 的粒度(存得越少,重算越多)
  • 模型结构(越深越吃亏)

4.3 实战技巧:checkpoint 粒度怎么选

一个非常实用的经验:

  • 优先 checkpoint Transformer block 的内部(尤其是 attention + MLP 中间激活)
  • 但不要 checkpoint 太碎,否则 overhead 太大

如果你只是为了把 batch size 提上去:

  • 开 checkpointing 往往是最直接、最稳定的“救命开关”

5) FSDP / ZeRO:为什么能把显存“按卡数”切开?

原文提到:FSDP 类似 DeepSpeed,通过 ZeRO 等思想把参数/梯度/优化器状态分布到多卡,而不是每卡保留完整副本。

5.1 DDP 的问题:每张卡都有一份完整模型

DDP(Data Parallel)的经典模式:

  • 每卡一份完整 weights
  • 每卡一份完整 grads
  • 每卡一份完整 optimizer states

这对大模型来说显存爆炸。

5.2 ZeRO/FSDP 的核心:把“三大头”都分片

以最理想的 fully shard 为例:

  • weights 分片
  • grads 分片
  • optimizer states 分片

理论上显存可以接近按卡数线性下降。

5.3 训练时怎么还能算?—— all-gather 与 reduce-scatter

你会问:既然每卡只有一片参数,怎么做前向?

答案是运行时通信:

  • 需要某层参数时:把该层参数 all-gather 到本卡
  • 用完后:释放/丢回分片状态
  • 梯度聚合:用 reduce-scatter 直接聚合成分片梯度

5.4 实战技巧:稳定性与速度取舍

原文提到某些 issue 中 shard_grad_op 模式可能更稳定 。工程上常见现象:

  • fully_shard 更省显存,但通信更频繁,可能更慢/更复杂
  • shard_grad_op(更像 ZeRO-2)在一些组合上更稳

建议:

  • 第一次上 FSDP:先用更“温和”的分片策略跑通
  • 再逐步加深分片程度

6) CPU Offload:为什么它能“救活”显存,但也可能让你哭

原文解释了 CPU offload:在一次反传中参数动态 GPU->CPU->GPU 转移来节省显存。

6.1 原理:把显存当“缓存”,把 CPU 内存当“主存”

这和操作系统很像:

  • GPU 显存太小,放不下全模型/优化器状态
  • 把不常用的部分放到 CPU RAM
  • 需要时再搬上来

Offload 的成败取决于:

  • 互联带宽(PCIe vs NVLink)
  • 你搬运的频率与粒度
  • batch size / seq len 是否让 compute 足够“覆盖”通信

典型现象:

  • 能跑起来,但 tokens/s 断崖式下跌
  • GPU 利用率很低,因为在等数据搬运

6.3 实战建议:什么时候值得用

  • 你真的“差一点显存”,否则模型根本跑不了
  • 你愿意用速度换可行性
  • 或者你在做离线训练,不追求极致吞吐

7) 推理阶段:KV Cache、吞吐与采样参数的“工程真相”

7.1 为什么推理显存持续增长:KV Cache 的账

原文已经点明:逐 token 解码要缓存 K/V。

更具体一点:

  • 每一层都要为历史 token 存 K/V
  • cache 大小约与:layers * batch * seq_len * hidden 成正比

所以推理优化的第一原则:

你能减少的不是“权重”,而是“cache 的增长速度”。

7.2 推理速度:CPU vs GPU

原文给出经验:

  • CPU 推理约 10 token/s
  • 单卡 GPU 相对 CPU 大约 10:1

工程上要注意:

  • 小 batch 下 GPU 可能吃不满
  • 增大 batch 可提升吞吐,但 KV cache 会更快爆显存

7.3 采样参数:为什么它也影响速度与显存

原文给出调参建议(top_p、num_beams、temperature、repetition_penalty 等)。

补充一个工程视角:

  • beam search(num_beams>1)几乎必然更慢更吃显存
    因为同时维护多个候选序列、cache 也会膨胀
  • top_p / temperature 主要影响质量与分布,不会像 beam 那样显著放大计算图
  • no_repeat_ngram_size 会引入额外约束检查,也会带来一些开销

如果你要服务端极致吞吐:

  • 尽量避免 beam
  • 让 decoding 更“单路径”(greedy 或轻采样)

8) 组合拳:不同场景怎么选方案(直接可用的决策表)

8.1 你要“单卡跑起来”(显存紧)

优先级:

  1. FP16/BF16(混合精度)
  2. LoRA(微调场景)
  3. Gradient Checkpointing(训练场景)
  4. INT8(权重放不下时)
  5. CPU offload(最后底牌)

8.2 你要“更快”(吞吐/QPS)

优先级:

  1. 先别量化:确认瓶颈是否在 KV cache / attention
  2. 减少 beam、优化 batch
  3. 用更合适精度(BF16/FP16)
  4. INT8 只有在 kernel 真正走 int8 且硬件支持强时才可能更快(否则只是省显存)

8.3 你要“多卡训练更大模型”

优先级:

  1. FSDP/ZeRO(分片参数/梯度/优化器)
  2. 必要时叠加 checkpointing
  3. 实在不行再叠 CPU offload

结语:把每个技术当作“显存账本上的一行”

  • Mixed precision:减少 weights/grads/activations 的 dtype 成本
  • INT8:主要减少 weights 成本,但可能引入额外计算/搬运
  • LoRA:把训练参数从大矩阵变成小低秩矩阵,连同 optimizer states 一起缩小
  • Checkpointing:省 activations,但变慢
  • FSDP/ZeRO:把 weights/grads/optim states 按卡数切开,通信换显存
  • CPU offload:用主存换显存,带宽/延迟换可行性
  • 推理 KV cache:推理显存大头,理解它才能真正优化推理

如果你愿意,我可以继续把这篇文章升级到“更硬核”的版本:

  • 给出每种方法的 显存公式估算(按参数量、层数、seq_len、batch 计算 KV cache/activation)
  • 给出 HuggingFace/torch 的 最小可运行配置模板(FSDP、bnb int8、LoRA、checkpointing 的组合示例)
  • 增补你原文没覆盖但推理非常关键的:FlashAttention、PagedAttention、Speculative Decoding(这些才是推理加速的主战场)

你更想先补哪一块:推理加速(tokens/s) 还是 训练显存(batch size 上去)