0%

DoRA详解:从LoRA到Weight-Decomposed Low-Rank Adaptation

DoRA(Weight-Decomposed Low-Rank Adaptation)是对LoRA的一种重要改进方法。它的核心思想是:将权重更新分解为“方向 + 幅度”两个部分分别建模,从而提升表达能力,同时保持参数高效性。

本文将围绕论文中的关键内容进行系统介绍。


1. 3.2 Weight Decomposition Analysis:Analysis Results

论文提出一个关键视角:将权重分解为 方向(direction)+ 幅度(magnitude)

其中:

  • $V$:方向(未归一化)
  • $\frac{V}{|V|}$:单位方向
  • $m$:幅度(norm)

Analysis Results(核心结论)

论文通过实验对比 LoRA 和 Full Fine-Tuning(FT)发现:

1️⃣ 幅度(Magnitude)变化

  • FT:幅度变化显著
  • LoRA:幅度变化非常有限

👉 说明 LoRA 无法有效调整权重的scale


2️⃣ 方向(Direction)变化

  • FT:方向变化更加多样
  • LoRA:方向变化较为单一(受低秩限制)

👉 LoRA 的更新空间被限制在低秩子空间


3️⃣ 本质结论

模型能力差异 ≠ 参数量
而是来源于 更新模式(update pattern)

即:

  • FT:同时自由调整 magnitude + direction
  • LoRA:受限地调整(尤其是 magnitude)

2. 4.1 Intuition

论文给出了两个关键直觉(原文核心思想):


Intuition 1

Firstly, the magnitude and direction of weights play different roles in model adaptation.

解释:

  • magnitude($m$):控制“强度”
  • direction($V$):控制“语义/表示方向”

👉 两者作用不同,不应该混在一起学习


Intuition 2

Secondly, LoRA mainly restricts the update of direction, while magnitude is under-explored.

解释:

  • LoRA 本质是在更新:

  • 但这种方式:

    • 对 direction 有限制(低秩)
    • 对 magnitude 几乎没有单独建模能力

👉 导致表达能力不足


3. 方法(DoRA)

DoRA 的核心思想:

解耦权重:方向用LoRA学习,幅度单独学习


参数化方式

DoRA 将权重写为:

其中:

  • $W_0$:预训练权重
  • $BA$:LoRA低秩更新
  • $m’$:可学习幅度

方法拆解

组件 学习方式
direction LoRA(低秩)
magnitude full parameter

对比总结

方法 magnitude direction
FT
LoRA ❌(弱) ✅(低秩)
DoRA ✅(低秩)

优势

  • 更接近 FT 的学习行为
  • 仍保持参数高效
  • 无推理额外开销(可merge)
  • 更稳定训练

4. 简单实现 + 张量运算讲解

我们来看一个简化实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class DoRALinear(nn.Module):
def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.alpha = alpha

# Base weight V (direction)
self.W = nn.Parameter(torch.empty(out_features, in_features))
nn.init.kaiming_uniform_(self.W)

# Magnitude vector (initialized from W's norm)
self.m = nn.Parameter(self.W.norm(dim=1))

# LoRA matrices
self.A = nn.Parameter(torch.empty(rank, in_features))
self.B = nn.Parameter(torch.zeros(out_features, rank))
nn.init.kaiming_uniform_(self.A)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute effective weight: m * (V + BA) / ||V + BA||
V = self.W
BA = self.B @ self.A * (self.alpha / self.rank) if self.rank > 0 else 0
W_eff = V + BA
W_eff = self.m.view(-1, 1) * W_eff / W_eff.norm(dim=1, keepdim=True)
return x @ W_eff.T

4.1 核心计算解析

1️⃣ LoRA部分

1
BA = self.B @ self.A
  • B: (out, r)
  • A: (r, in)
  • 结果:(out, in)

2️⃣ 方向更新

👉 表示新的“方向向量”


3️⃣ 归一化(关键)

1
W_eff.norm(dim=1, keepdim=True)
  • dim=1:对每一行(每个输出通道)求norm
  • keepdim=True:保持 shape = (out, 1)

4️⃣ 广播(Broadcasting)

1
self.m.view(-1, 1) * W_eff

shape:

  • m: (out,)
  • → view(-1,1) → (out,1)
  • W_eff: (out, in)

👉 广播规则:

  • (out,1) × (out,in)
  • 自动扩展为 (out,in)

广播成立条件

两个张量可以广播,当:

  1. 从右往左维度匹配
  2. 或其中一个为1

例如:

张量A 张量B 是否可广播
(out,1) (out,in)
(1,in) (out,in)
(out,) (out,in) ❌(需reshape)

5️⃣ 最终公式


6️⃣ 前向传播

1
return x @ W_eff.T
  • x: (batch, in)
  • W_eff.T: (in, out)
  • 输出: (batch, out)

5. 其它重要理解


5.1 本质:DoRA = LoRA + Weight Normalization

DoRA 非常类似:

👉 与 WeightNorm 的思想一致


5.2 为什么有效?

👉 关键点:

  • direction:决定“表示空间”
  • magnitude:决定“激活强度”

LoRA:

  • 两者耦合 → 学习受限

DoRA:

  • 解耦 → 表达能力提升

5.3 为什么只让 LoRA 控制 direction?

因为:

  • direction 是高维结构(更重要)
  • magnitude 是标量(容易学)

👉 用有限参数优先建模“复杂部分”


5.4 推理效率

DoRA 和 LoRA 一样:

  • 可 merge:
  • inference 无额外开销

6. 总结

DoRA 的核心贡献可以一句话概括:

通过“方向(LoRA)+ 幅度(全参数)”解耦,使参数高效微调更接近全量微调的表达能力。


关键 takeaway

  • LoRA 的瓶颈:更新模式受限
  • DoRA 的突破:结构性解耦
  • 本质思想:表示分解(representation decomposition)

如果你读到这里,其实已经抓住 DoRA 的精髓了。

如果你还想更进一步,可以继续问我:

👉「DoRA 为什么在某些任务上接近 FT?」
👉「DoRA 和 QLoRA 能不能结合?」

RoPE 推导详解:从向量旋转到复数表示

Rotary Positional Embedding(RoPE)是目前大语言模型中最常用的位置编码方式之一,被 GPT、LLaMA、ChatGLM 等模型广泛使用。

相比传统 绝对位置编码(Absolute Position Encoding),RoPE 的核心特点是:

在 Attention 中自然引入相对位置(Relative Position)。

本文将从 两个角度推导 RoPE

  1. 二维向量旋转(几何视角)
  2. 复数乘法(数学视角)

同时明确写出推导过程中使用的关键公式。


1 Transformer 中的位置问题

Transformer 的 Attention 计算为:

问题是:

Attention 本身不包含位置信息。

如果输入序列顺序改变,模型无法感知。

因此需要 Positional Encoding


2 RoPE 的核心思想

RoPE 的思想非常简单:

将 embedding 的每两个维度视为一个二维向量,并根据 token 位置对其进行旋转。

假设 embedding 向量:

RoPE 会把它拆成二维对:

每一对进行 二维旋转变换


3 二维向量旋转推导

3.1 二维旋转矩阵

在二维空间中,向量旋转角度 $\theta$ 的公式为:

展开可得:

这就是 RoPE 的基本变换公式


3.2 位置依赖的旋转角度

RoPE 的旋转角度依赖于 token 位置 $p$维度频率 $\theta_k$

其中:

参数说明:

符号 含义
$p$ token 位置
$k$ 维度 index
$d$ embedding 维度

因此旋转角为:


4 RoPE 在 Attention 中的效果

Attention 计算核心是:

如果对 Query 和 Key 同时进行旋转:

其中:

则内积为:

利用旋转矩阵性质:

可得:

关键结论:

Attention 只依赖于位置差 $(i-j)$

因此 RoPE 天然编码相对位置


5 复数视角推导

上面的推导基于 二维向量
但 RoPE 论文使用 复数形式,因为表达更简洁。


5.1 向量与复数的对应

定义:

则二维向量:

可以表示为复数:


5.2 欧拉公式

复数旋转使用 欧拉公式


5.3 复数旋转

二维旋转等价于:

展开:

计算得到:

因此:

实部:

虚部:

完全等价于二维旋转矩阵


6 RoPE 的复数形式(完整推导)

上一节说明了:

二维旋转可以表示为复数乘法:

RoPE 论文就是利用这个性质,将二维向量旋转写成复数形式。


6.1 向量与复数表示

对于 embedding 的一对维度:

定义复数表示:

因此 Query 与 Key 可以写成:


6.2 RoPE 的旋转

对于位置 (i) 的 Query:

对于位置 (j) 的 Key:

其中:


6.3 Attention 内积

Attention 的核心是:

在复数形式中,对应 复数内积

为了得到实数结果,需要使用 共轭复数

其中:

  • (\overline{k}) 为复数共轭
  • (\text{Re}(\cdot)) 表示取实部

6.4 代入 RoPE

首先计算:

代入旋转:

则:

因此:

整理得到:


6.5 取实部

Attention score 为:

因此:

关键点在于:

旋转只依赖于 ( \theta_i-\theta_j )


6.6 相对位置出现

由于:

因此:

所以最终 Attention 为:


6.7 关键结论

Attention score 只依赖:

即:

相对位置

而不是绝对位置。

因此 RoPE 天然实现 relative position encoding


6.8 总结

复数形式下 RoPE 的推导可以总结为:

1️⃣ 二维旋转

2️⃣ Query / Key 旋转

3️⃣ Attention 内积

4️⃣ 位置关系

最终得到:

Attention 只依赖 token 的相对位置。


7 为什么不同维度使用不同频率

RoPE 中:

这与 Transformer 原始位置编码一致。

原因是:

不同维度代表不同频率。

类似 傅里叶基底(Fourier Basis)

  • 低维 → 低频
  • 高维 → 高频

这样模型能够同时表示:

  • 长距离关系
  • 短距离关系

8 RoPE 总结

RoPE 的核心可以总结为一句话:

将 embedding 每两个维度看作二维向量,根据 token 位置进行旋转。

关键数学工具包括:

使用到的公式

1️⃣ 二维旋转矩阵

2️⃣ 欧拉公式

3️⃣ 复数旋转

4️⃣ RoPE 频率公式


9 直观理解

RoPE 可以用一句非常直观的话理解:

Token 的表示在 embedding 空间中随着位置不断旋转。

当两个 token 计算 attention 时:

它们的相对旋转角度决定注意力强度。

因此模型可以自然建模:

  • 相对距离
  • 顺序关系

pytorch实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def apply_rotary_pos_emb(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
# x: (batch, num_heads, seq_len, head_dim)
# pos: (batch, seq_len) position indices

batch, num_heads, seq_len, head_dim = x.shape

# Compute frequencies: theta_i = 1 / (10000^(2i/d))
freqs = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))

# pos: (batch, seq_len) -> (batch, seq_len, 1)
pos = pos.unsqueeze(-1).float()

# angles: (batch, seq_len, head_dim/2)
angles = pos * freqs.unsqueeze(0).unsqueeze(0)

# cos and sin
cos = torch.cos(angles) # (batch, seq_len, head_dim/2)
sin = torch.sin(angles)

# Reshape x for rotation: split into pairs
x1 = x[..., 0::2] # (batch, num_heads, seq_len, head_dim/2)
x2 = x[..., 1::2]

# Apply rotation
cos = cos.unsqueeze(1) # (batch, 1, seq_len, head_dim/2)
sin = sin.unsqueeze(1)

x_rotated_1 = x1 * cos - x2 * sin
x_rotated_2 = x1 * sin + x2 * cos

# Interleave back
x_out = torch.stack([x_rotated_1, x_rotated_2], dim=-1).flatten(-2)

return x_out

算法

trie, LCA, topo(DFS, BFS), qsort v.s. merge sort
二分答案

Cpp < 的重载

如果要自定义比较语义,只需要重载<(不需要重载<=, = 或者>)。对于lower_bound,重载<内部实现不需要<=;你可能会问,你不告诉它=的语义,它怎么知道?其实它是通过”=” —> not “<” and not “>”来判断的,而”>”是可以通过”<”的实现通过某种技巧推出来的。
例子:

1
2
3
4
// 返回值it是迭代器
// 注意第四个参数,函数的参数得是const + &
// 函数的第一个参数表示target
auto it = lower_bound(mono_stack.begin(), mono_stack.end(), pair<int, int>{height[i], i}, [](const pair<int, int> &a, const pair<int, int> &b) {return a.first < b.first;});

升序排序/大根堆时要用<。底层逻辑是:如果 a < b 为真,则 a 排在 b 前面。


precision

在算法竞赛中,输出精度(尤其是浮点数)往往是决定 Accepted 还是 Wrong Answer 的关键细节。C++ 和 Python 在处理逻辑上有很大差异,这里为你梳理了核心用法和避坑指南。


一、 C++ 精度处理

在 C++ 中,<iomanip> 头文件是你的好伙伴。

1. 核心控制语句

  • 固定小数位数cout << fixed << setprecision(n) << val;
    这是竞赛中最常用的方式,确保小数点后正好有 $n$ 位。
  • 有效数字cout << setprecision(n) << val;(不加 fixed
    控制总的有效数字位数,通常用于科学计数法要求。
  • printf 风格printf("%.nf", val);
    简单直接,在处理大规模数据输出时性能优于 cout

2. 常见坑点

  • 四舍五入的玄学printfsetprecision 在处理“刚好在 0.5”的情况时,并不总是稳定的“四舍五入”,有时会受编译器和浮点数底层表示(IEEE 754)的影响(即“舍入到最接近的偶数”)。
    • 对策:如果对舍入要求极其严格,建议手动加上一个极小值 eps(如 $10^{-9}$),即 printf("%.2f", val + 1e-9);
  • 输出 -0.00:当计算结果是一个极小的负数(如 -0.0000001)且保留两位小数输出时,可能会得到 -0.00
    • 对策:在输出前进行判断:if (abs(val) < eps) val = 0;

二、 Python 精度处理

Python 的 float 默认就是双精度(相当于 C++ 的 double),语法更简洁。

1. 核心控制语句

  • f-string (推荐)print(f"{val:.nf}")
  • format 函数print("{:.nf}".format(val))
  • % 运算符print("%.nf" % val)

2. 常见坑点

  • round() 函数的背叛:Python 的 round(x, n) 采用的是 “四舍六入五成双”(Round half to even)。

    例如:round(2.5) 得到 2round(3.5) 得到 4。这在很多竞赛题目要求严格四舍五入时会白给。

    • 对策:使用 Decimal 模块或手动处理:int(val * 10**n + 0.5) / 10**n
  • 浮点数运算误差:Python 虽然动态类型很爽,但浮点数依然存在 $0.1 + 0.2 \neq 0.3$ 的问题。

三、 综合对比与算法竞赛建议

特性 C++ (double/long double) Python (float)
精度上限 long double 可达 80/128 位 默认 64 位,需更高精度用 decimal 模块
默认舍入 趋向四舍五入(受环境影响) 四舍六入五成双
速度 极快 较慢

💡 竞赛实战建议:

  1. 统一精度:除非内存极其紧张,否则 C++ 一律用 double,高精度需求用 long double
  2. 避免过早舍入:在所有计算完成之前,不要进行取整或保留小数操作,减少累积误差。
  3. 比较浮点数:永远不要用 if (a == b),要用 if (abs(a - b) < eps),其中 eps 通常取 $10^{-8}$ 或 $10^{-9}$。
  4. 长整型转换:如果题目要求输出整数且涉及浮点运算,建议先 +eps 再强转 int,防止因为 0.99999999 被截断成 0

树状数组,线段树与分块

在算法竞赛(ACM/ICPC, NOI, CCPC)中,树状数组 (Binary Indexed Tree)线段树 (Segment Tree)分块 (Square Root Decomposition) 是处理区间问题的“三剑客”。

以下是这三种数据结构的详细对比:

数据结构特性对比表

维度 树状数组 (BIT) 线段树 (Segment Tree) 分块 (SQRT Decomposition)
实现难度 极低。核心代码仅需几行(lowbit)。 中等。需要递归构建、PushDown/PushUp 操作。 低到中。逻辑直观,但边界处理(散块)需细心。
适用范围 较窄。主要用于点修区间查,且操作需满足可减性(如和、异或)。 极广。几乎能处理所有区间问题(最值、GCD、复杂懒标记)。 万能。不仅能处理区间,还能处理非结合律问题或配合莫队算法。
时间复杂度 $O(\log n)$。常数极小,执行速度最快。 $O(\log n)$。常数较大(递归开销)。 $O(\sqrt{n})$。常数较小,但在 $10^5$ 以上数据量级可能超时。
空间复杂度 $O(n)$。仅需一个原数组大小的额外空间。 $O(4n)$。通常需要开 4 倍空间的数组。 $O(n)$。仅需记录块信息,空间开销小。
常用程度 极高。基础区间问题的首选方案。 极高。竞赛中区间问题的“标准答案”。 。作为暴力优化的最后手段或处理黑科技题目。
扩展性 较差。虽然能改区间修改,但逻辑复杂。 极强。可扩展为动态开点、持久化(主席树)。 。对数据分布和操作类型几乎没有限制。

算法竞赛中的“避坑”与选择建议

1. 什么时候选树状数组?

  • 追求速度:当题目时限非常紧(如 $10^6$ 数据量),树状数组的常数优势巨大。
  • 求逆序对:这是树状数组最经典的用法。
  • 维度增加:在处理二维或三维空间问题时,树状数组的实现远比多维线段树简单。

2. 什么时候选线段树?

  • 区间修改 + 区间查询:这是线段树的统治区,通过 Lazy Tag(懒标记)可以高效处理。
  • 非可减性维护:比如维护区间的最大值(RMQ),由于最大值不满足可减性,树状数组处理起来很麻烦,而线段树游刃有余。
  • 复杂信息合并:如维护区间最长连续上升子序列,线段树的节点合并逻辑非常清晰。

3. 什么时候选分块?

  • “强制在线”且逻辑复杂:有些操作线段树难以通过标记下传实现(如区间内大于 $x$ 的数有多少个),此时分块配合块内排序或桶可以解决。
  • 时间换空间:如果内存限制极其严格(如 64MB),线段树的 4 倍空间可能炸内存,分块更节省。
  • 作为保底方案:在考场上如果想不出线段树的标记合并逻辑,写一个 $O(n\sqrt{n})$ 的分块通常能拿到 70%~100% 的分数。

总结建议

  • 入门阶段:先练熟树状数组,再死磕线段树。
  • 实战策略:能用树状数组不用线段树(快且好写);能用线段树不用分块(稳且不被卡常);如果题目要求太奇葩,果断上分块或莫队。

Tarjan

Tarjan 算法的核心在于通过一次 DFS 利用栈和两个关键数组 $dfn$ 和 $low$ 来挖掘图的结构特征。

1. 核心定义:$dfn$ 与 $low$

这是 Tarjan 算法的灵魂,理解了这两个数组,算法就理解了一半:

  • $dfn[u]$时间戳。节点 $u$ 在 DFS 过程中被访问的先后顺序(从 1 开始递增)。
  • $low[u]$追溯值。节点 $u$ 通过树枝边特定的反向边能够回溯到的 $dfn$ 最小的节点编号。
    • 它代表了节点 $u$ 所在的连通结构的“最高点”。

2. 三个 Tarjan 算法对比总结

A. 有向图:强连通分量 (SCC)

  • 目的:找极大的子图,使得其中任意两点可互相到达。
  • $low$ 更新时机
    1. 遇到未访问节点 $v$:递归后,$low[u] = \min(low[u], low[v])$。
    2. 遇到已访问且在栈中的节点 $v$:$low[u] = \min(low[u], dfn[v])$。
  • 判断时机
    • 当 DFS 回溯时,如果 $dfn[u] == low[u]$,说明 $u$ 是该 SCC 的“根”。此时将栈中 $u$ 及其上方的节点全部弹出,这些节点构成一个 SCC。

B. 无向图:割点与点双连通分量 (v-BCC)

  • 目的:找删去后会使原图不连通的点。
  • $low$ 更新时机
    1. 遇到未访问节点 $v$:递归后,$low[u] = \min(low[u], low[v])$。
    2. 遇到已访问且不是父节点的节点 $v$:$low[u] = \min(low[u], dfn[v])$。
  • 判断条件
    • 非根节点 $u$:存在一个子节点 $v$,使得 $low[v] \ge dfn[u]$(说明 $v$ 无法绕过 $u$ 回到更高处)。
    • 根节点 $u$:在 DFS 树上有两个或更多独立的子树。

C. 无向图:割边(桥)与边双连通分量 (e-BCC)

  • 目的:找删去后会使原图不连通的边。
  • $low$ 更新时机:与割点基本一致,但必须严格限制不通过当前输入的这条边回到父节点(处理重边时需记录边编号)。
  • 判断条件
    • 对于边 $(u, v)$,如果 $low[v] > dfn[u]$,则该边为桥(说明从 $v$ 出发无论如何也回不到 $u$ 或 $u$ 以上的点)。

3. 实现中的关键细节

细节 有向图 (SCC) 无向图 (割点/桥)
栈的作用 记录当前可能属于同一个 SCC 的节点。 割点通常不强制用栈,点双/边双才需要栈存边或点。
父节点回溯 不需要考虑父节点。 必须跳过父节点(或进来的那条边),否则 $low$ 永远等于 $dfn$。
更新逻辑 只有在栈中的点才能更新 $low$。 只要不是父节点就能更新 $low$(表示存在反向边)。
判等逻辑 $dfn == low$ 是判定 SCC 结束的标志。 $low[v] \ge dfn[u]$ 是判定 $u$ 是割点的标志。

🛠️ 避坑指南:

  1. 重边问题:在求无向图桥时,不能简单判断 v != fa,而要记录进入 $u$ 的边 ID,防止通过重边直接回到父节点导致 $low$ 更新错误。
  2. 根节点特判:求割点时,DFS 的起点(根)不能用 $low[v] \ge dfn[u]$ 判断,必须看它是否有两个以上的子树。
  3. 多连通图:如果图不连通,需要对每个未访问的点跑一遍 Tarjan。

python自定义排序

在 Python 中,自定义排序的核心逻辑是将“复杂的对象”映射为“可比较的键”。以下是算法竞赛和日常开发中最常用的三种方式:


1. 使用 key=lambda(最常用)

这是最快捷的方法,适用于大多数基础的多维列表或对象排序。

  • 基本语法list.sort(key=lambda x: x[0])
  • 多关键字排序:如果想先按第一列升序,第一列相同时按第二列降序,可以利用元组:
    1
    2
    3
    4
    5
    # 示例:学生列表 (名字, 分数, 年龄)
    # 要求:分数降序 (-x[1]),年龄升序 (x[2])
    data = [("Alice", 90, 20), ("Bob", 90, 18), ("Charlie", 85, 19)]
    data.sort(key=lambda x: (-x[1], x[2]))
    # 结果: [('Bob', 90, 18), ('Alice', 90, 20), ('Charlie', 85, 19)]

    注意:负号 - 仅适用于数值类型。如果是字符串想降序,通常需要后面提到的方法。


2. 使用 functools.cmp_to_key(最强大)

当你无法简单地通过一个 key 映射来决定顺序(例如:排序规则取决于两个元素之间的逻辑关系)时,这个方法是救星。它模拟了 C++ 中 bool cmp(a, b) 的逻辑。

  • 逻辑:返回负数表示 a 在前,正数表示 b 在前,0 表示相等。
  • 典型场景最大数问题(给定一组数字,拼成一个最大的数)。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    from functools import cmp_to_key

    def my_cmp(a, b):
    if str(a) + str(b) > str(b) + str(a):
    return -1 # a 应该排在前面
    else:
    return 1

    nums = [3, 30, 34, 5, 9]
    nums.sort(key=cmp_to_key(my_cmp))
    # 结果: [9, 5, 34, 3, 30] -> 拼接后最大

3. 重载类的比较运算符(工程常用)

如果你定义了一个类,并希望它在任何地方(sort, min, heapq)都能自动排序,可以直接在类内部定义 __lt__ (less than)。

  • 实现细节:Python 的排序只要求实现 __lt__
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    class Node:
    def __init__(self, x, y):
    self.x, self.y = x, y

    def __lt__(self, other):
    # 逻辑:x 越小越靠前;x 相等则 y 越大越靠前
    if self.x != other.x:
    return self.x < other.x
    return self.y > other.y

    nodes = [Node(1, 2), Node(1, 5), Node(0, 3)]
    nodes.sort() # 自动调用 __lt__

💡 竞赛避坑指南

  1. 稳定性 (Stability):Python 的 sort (Timsort) 是稳定的。这意味着如果两个元素 key 相同,它们的相对顺序会保持不变。你可以利用这一点进行多次排序(先排次要关键字,再排主要关键字)。
  2. 性能key 函数在每个元素上只调用一次,而 cmp_to_key 会在每次比较时调用。在大数据量下,优先使用 lambda
  3. 原地修改 vs 返回新列表
    • list.sort():原地修改,无返回值,省内存。
    • sorted(list):返回新列表,原列表不变。

在推导深度学习模型反向传播时,矩阵求导(Matrix Calculus)是最常见的数学工具之一。

很多复杂的梯度推导,本质上都依赖一些非常基础的矩阵微分公式。本文总结几个深度学习中最常见的公式,并通过 Attention 的反向传播展示这些公式如何实际使用。


一、常用矩阵求导公式

1 Frobenius 内积

含义

矩阵的 Frobenius 内积定义为:

等价写法:

为什么重要

在深度学习中:

梯度通常写成 Frobenius inner product 形式。

例如 loss 的微分可以写为

这样可以统一标量对矩阵的求导表达。


2 标量函数的矩阵微分

含义

如果

其中

  • $L$ 是标量
  • $X$ 是矩阵

那么微分展开为

用 Frobenius 内积表示就是

上式也可以自然推广到 多元变量的情况。如果损失函数依赖于多个矩阵变量,例如

那么它的全微分可以写为

也就是

这实际上就是多元微积分中全微分公式

在矩阵情形下的推广,其中内积由普通乘法推广为 Frobenius 内积。在深度学习的反向传播推导中,我们通常通过观察 $dL$ 中与某个 $dX_i$ 对应的项,直接读出对应变量的梯度 $\frac{\partial L}{\partial X_i}$。

作用

这是 反向传播链式法则的核心表达方式


3 Jacobian 线性近似

含义

如果

其中

则 Jacobian 定义为

于是微分关系为

在深度学习中的意义

在反向传播中:


4 矩阵乘法求导

含义

这是矩阵版本的 乘法求导法则(Product Rule)

对应标量:

使用场景

神经网络中最常见形式:


5 Trace 循环性质

含义

Trace 具有 循环不变性(cyclic property)

例如

但需要注意:

  • 只能循环
  • 不能改变矩阵顺序

为什么重要

在梯度推导中,经常需要将

变换为

这实际上利用了 trace 的循环性质。


二 Attention 反向传播例子

下面通过 Transformer Attention 展示这些公式如何使用。


1 Attention 前向传播

标准 Attention 计算:

其中


2 输出层梯度

假设

已知


Step 1 对 V 求导

因为

根据矩阵乘法求导

只看 $dV$ 项:

loss 微分:

代入

利用 trace 循环性质

得到


Step 2 对 P 求导

同样

loss 微分

使用 trace 循环

所以


Step 3 对 S 求导

因为

使用 Jacobian:

因此


Step 4 对 Q 和 K 求导

因为

根据矩阵乘法求导:


对 Q 求导

代入

使用 trace 循环

因此


对 K 求导

因为

得到


三 Attention 反向传播总结

最终梯度结果:


四 总结

深度学习中的复杂梯度推导,其实大量依赖几个简单规则:

核心四件套:

  1. Frobenius 内积
  2. trace 循环性质
  3. 矩阵乘法求导
  4. Jacobian

只要熟练掌握这些公式,像

  • Attention
  • Transformer
  • LayerNorm
  • BatchNorm

等模块的梯度推导都会变得非常清晰。


FlashAttention 笔记(FA1 & FA2)

本文记录 FlashAttention 的核心要点,重点用于 复习时快速回忆框架


重要参考文章

核心是理解原论文的伪代码
https://zhuanlan.zhihu.com/p/669926191
https://zhuanlan.zhihu.com/p/691067658
下面是一些要点总结

一、FlashAttention v1 的核心目标

FlashAttention v1 (2022) 的核心目标:

设计 IO-aware attention algorithm,减少 GPU memory IO

关键问题:

  • 标准 attention 需要构建 N × N attention matrix
  • GPU bottleneck 不是 FLOPs,而是 memory bandwidth
  • HBM ↔ SRAM 访问成本非常高

因此 FA1 的核心思想:

通过 tiling + online softmax 避免 materialize attention matrix


二、FlashAttention v1 的核心特点(复习框架)

看到下面这些关键词基本可以回忆出 FA1。

1 IO-aware algorithm

核心思想(maybe hallucination):

  • attention 计算瓶颈是 memory IO
  • 优化目标:
    减少 HBM ↔ SRAM 访问次数

FA1 论文重点:

1
2
3

minimize HBM reads/writes


2 不显式存储 attention matrix

传统 attention:

1
2
3
4

S = QK^T
softmax(S)

需要:

1
2
3

O(N²) memory

FA1:

1
2
3

blockwise compute attention

特点:

  • attention matrix 不落显存
  • 只在 SRAM 中短暂存在

3 Block tiling

计算方式:

1
2
3

Q block × KV block

流程:

1
2
3
4
5

for each KV block:
for each Q block:
compute partial attention

特点:

  • KV 按 block streaming
  • 避免一次性加载全部 KV

4 Online softmax

由于 attention matrix 不完整存在,需要:

1
2
3

online softmax

关键思想:

逐块维护:

1
2
3
4

running max
running sum

保证 softmax 数值稳定。


5 Memory complexity 从 O(N²) 降到 O(N)

标准 attention:

1
2
3

memory = O(N²)

FlashAttention:

1
2
3

memory = O(N)

原因:

  • 不 materialize attention matrix
  • 只保存最终 output

6 GPU kernel 特点

FA1 的 GPU kernel:

并行维度:

1
2
3

parallel over (batch, head)

一个 thread block:

1
2
3

负责一个 attention head

特点:

  • 简单并行结构
  • 但 GPU occupancy 不高

7 FA1 的主要贡献总结

FA1 解决的是:

1
2
3

attention 的 memory IO 问题

主要贡献:

  • IO-aware algorithm
  • online softmax
  • block tiling
  • O(N²) → O(N) memory
  • attention kernel 融合

三、FlashAttention v1 的局限

FA1 虽然很快,但仍有问题:

1 GPU 利用率不高

GPU 计算层级(理解这一节需要)

GPU 的执行结构大致如下:

1
2
3
4
5
6

GPU
└── SM (Streaming Multiprocessor)
└── warp (32 threads)
└── thread

关键点:

  • SM 是 GPU 的主要计算单元
  • 一个 GPU 通常有几十到上百个 SM
    例如 A100 ≈ 108 SM

线程调度关系:

1
2
3
4

thread block → 分配到某个 SM
warp → SM 上真正执行的调度单位

重要性质:

  • 一个 thread block 只会运行在一个 SM 上
  • 一个 SM 可以同时运行多个 block
  • 但如果 block 资源占用很大(shared memory / register),
    SM 可能只能放一个 block

FA1 的并行维度

FlashAttention v1 的 grid 通常是:

1
2
3

(batch, head)

也就是说:

1
2
3

1 thread block ≈ 1 attention head

block 内部线程负责:

1
2
3

一个 Q block × 所有 KV blocks

逐块计算 attention。

为什么会导致 GPU 利用率不高

假设:

1
2
3
4

batch = 2
heads = 8

那么:

1
2
3

total blocks = 16

而 GPU:

1
2
3

A100 ≈ 108 SM

结果:

1
2
3

只有 16 个 SM 有工作

剩余 SM 处于空闲。

同时 FlashAttention v1 的 kernel 使用了较多:

  • shared memory
  • registers

这可能导致:

1
2
3

1 SM 只能运行 1 个 block

因此:

1
2
3

1 SM ≈ 1 head

当 head / batch 数量较少时:

1
2
3

GPU occupancy 不足

这就是 FA1 中常说的:

并行粒度过粗,SM 利用率不高。


2 warp communication 较多

warp 是 GPU 的基本执行单位

GPU 线程以 warp 为单位执行

1
2
3

1 warp = 32 threads

SM 实际调度的不是 thread,而是:

1
2
3

warp

因此:

  • warp 内通信非常快
  • warp 之间通信成本较高

FA1 的 warp 分工方式

FlashAttention v1 的 kernel 使用的是:

1
2
3

split-K

也就是:

1
2
3

不同 warp 负责不同 KV block

示意:

1
2
3
4
5
6

warp0 → KV block 0
warp1 → KV block 1
warp2 → KV block 2
warp3 → KV block 3

每个 warp 计算:

1
2
3

Q × KV_block

得到部分 attention 结果。

为什么需要 reduce

attention 的输出是:

1
2
3

softmax(QKᵀ)V

由于 KV 被拆成多个 block:

每个 warp 只计算了 部分贡献

因此最后需要:

1
2
3

partial results → 合并

这就需要:

1
2
3

shared memory reduce

流程:

1
2
3
4
5
6
7
8

warp0 写 shared memory
warp1 写 shared memory
warp2 写 shared memory
warp3 写 shared memory

→ reduction

warp sync 的原因

在 reduction 之前需要保证:

1
2
3

所有 warp 都完成计算

因此需要:

1
2
3

warp sync

例如:

1
2
3

__syncthreads()

这些通信带来的问题

这些操作:

1
2
3
4
5

shared memory access
warp synchronization
reduction

都会带来:

1
2
3

额外 latency

而且这些操作:

1
2
3

不是 TensorCore matmul

无法充分利用 GPU 的高吞吐计算单元。

因此:

FA1 kernel 中有较多 warp communication overhead


为什么 FA2 会改善这个问题(顺带理解)

FlashAttention v2 改为:

1
2
3

split-Q

warp 分工:

1
2
3
4
5
6

warp0 → Q rows 0–15
warp1 → Q rows 16–31
warp2 → Q rows 32–47
warp3 → Q rows 48–63

每个 warp:

1
2
3

独立计算自己的 output

因此:

1
2
3

不需要跨 warp reduce

好处:

1
2
3
4

减少 shared memory
减少 warp sync

kernel 更接近:

1
2
3

GEMM-style pipeline

从而提高 GPU 利用率。


3 非 matmul 操作较多

例如:

  • normalization
  • softmax rescale

这些操作:

1
2
3

不是 TensorCore friendly

影响吞吐量。

四、FlashAttention v2 的改进

FlashAttention v2 (2023) 的核心目标:

提高 GPU FLOPs utilization,使 attention kernel 接近 GEMM 性能

FA2 不是新的 attention 算法,而是:

1
2
3

更好的 GPU work partition


五、FlashAttention v2 相比 FA1 的关键改进

1 增加 sequence 维度并行

FA1 并行维度:

1
2
3

(batch, head)

FA2:

1
2
3

(batch, head, seq_block)

效果:

  • 一个 head 可拆成多个 block
  • GPU occupancy 提高

2 改变 loop ordering

FA1:

1
2
3
4

for KV block
for Q block

FA2:

1
2
3
4

for Q block
for KV block

好处:

  • softmax row-wise 更自然
  • 数据局部性更好

3 warp partition 从 split-K 改为 split-Q

FA1:

1
2
3

warp -> KV blocks

需要:

1
2
3
4

reduction
synchronization

FA2:

1
2
3

warp -> Q rows

特点:

  • warp 输出独立
  • 不需要 reduce

4 减少 shared memory 使用

FA2 设计目标:

1
2
3

减少 warp communication

带来:

  • 更少 shared memory
  • 更少同步

5 减少非 matmul 操作

FA2 尽量让 kernel:

1
2
3

接近 GEMM pipeline

优化:

  • normalization 延迟
  • 减少 scalar ops

六、FA2 的效果

FA2 相比 FA1:

指标 FA1 FA2
GPU FLOPs 利用率 25–40% 50–73%
速度 baseline ≈2×
并行粒度 batch/head batch/head/seq

FA2 的 attention kernel:

1
2
3

性能接近 GEMM


七、为什么 2022 年没有直接提出 FA2?(下面回答未经验证)

核心原因:

FA1 和 FA2 解决的是不同层级的问题


1 FA1 解决的是算法层问题

2022 年 attention 的主要瓶颈是:

1
2
3

O(N²) memory

研究重点:

1
2
3

IO complexity

FA1 的贡献:

1
2
3

IO-aware attention algorithm


2 FA2 主要是 GPU kernel 优化

FA2 的改进本质是:

1
2
3
4
5

work partition
thread mapping
occupancy optimization

属于:

1
2
3

GPU kernel engineering

而不是新的 attention 算法。


3 FA1 本身已经是巨大突破

当时 baseline:

1
2
3

PyTorch attention

FA1:

1
2
3
4

2–4x speedup
10x memory reduction

已经是革命级改进。


4 FA2 是在大量实践后出现的优化

FA1 被广泛使用后发现:

1
2
3

GPU utilization 不高

进一步 profiling 后:

1
2
3

kernel work partition 可以优化

才提出 FA2。


八、一句话总结

FlashAttention v1:

1
2
3
4

IO-aware attention algorithm
目标:减少 memory IO

FlashAttention v2:

1
2
3
4

GPU kernel optimization
目标:提高 FLOPs utilization

两者解决的是:

1
2
3

不同层级的性能瓶颈

最近读到一篇非常好的博客:

lips.cs.princeton.edu: The Gumbel-Max Trick for Discrete Distributions

强烈推荐直接阅读原文。这篇文章非常简洁地解释了 Gumbel-Max trick。很多论文会引用这个技巧,但真正讲清楚的资料并不多。这里简单记录几个关键点,并补充一点理解。


一句话核心思想

从 softmax 分布采样,可以通过给每个 logit 加一个 Gumbel 噪声,然后取最大值来实现。

其中

重复很多次后,第 i 个类别被选中的概率正好是

也就是 softmax。


原文中的变量含义

原文定义:

于是采样规则可以写成:

这里:

  • x_i:logit / score / log-probability
  • g_i:Gumbel 噪声
  • z_i:加入随机扰动后的 score

因此整篇文章实际上是在计算:

这和计算

是同一件事,因为原文只是把

换了个记号。


为什么 z 仍然是 Gumbel

原文有一个容易忽略的点:

如果

那么

仍然服从 Gumbel 分布,只是 location 参数变成了 x:

这就是为什么原文后面可以直接把 z 当成一个 Gumbel 随机变量来处理。

一句话说,这里用到的是:

Gumbel 分布在“加一个常数”之后,还是 Gumbel 分布。


一个直觉理解

可以把 Gumbel-Max 理解为:

每个类别先有一个固定分数 x_i,再额外加上一点随机扰动 g_i,最后看谁最大。

也就是说,我们比较的是:

如果重复很多次,原本分数更高的类别更容易赢,但又不是每次都赢。最终它被选中的概率恰好就是 softmax:


从 Gumbel-Max 到 Gumbel-Softmax

Gumbel-Max 的问题在于:

不可导,因此不能直接用于神经网络训练。

一个自然的想法是:用 softmax 去近似 argmax。于是得到

当温度 tau 趋近于 0 时,这个分布会越来越接近 one-hot,也就越来越像 argmax 的结果。

这就是 Gumbel-Softmax 的来历:

  • Gumbel-Max:加 Gumbel 噪声后做 argmax
  • Gumbel-Softmax:加 Gumbel 噪声后做 softmax 近似

它的意义是:既保留了离散采样的味道,又让整个过程变得可导。

参考文章

https://spaces.ac.cn/archives/10592


recap: AdamW

AdamW(Adam with Weight Decay)是目前深度学习中最常用的优化器之一,它是对经典 Adam 算法的一个重要改进,解决了权重衰减(Weight Decay)在自适应梯度算法中失效的问题。

1. AdamW 的更新公式

AdamW 的核心思想是将权重衰减从梯度更新中分离出来。其更新步骤如下:

第一步:计算一阶和二阶矩(动量)

其中 $g_t$ 是当前步的梯度,$m_t$ 是梯度的指数移动平均(一阶矩),$v_t$ 是梯度平方的指数移动平均(二阶矩)。

第二步:偏差纠正(Bias Correction)

第三步:参数更新(关键差异点)

注意:$\lambda \theta_{t-1}$ 这一项是直接作用在参数更新上的,而不是加在梯度 $g_t$ 里。


2. 直觉解释

你可以把 AdamW 想象成一个带刹车的智能小球

  • $m_t$ (一阶矩):就像惯性。球不会因为路面一个小坑就立刻停下,而是会根据之前的速度继续滑动,这有助于跳出局部最优。
  • $v_t$ (二阶矩):就像路面的粗糙度自适应。如果某个方向坡度一直很大(梯度大),它会增加“摩擦力”让更新变稳;如果坡度很小,它会放大更新,确保能走得动。
  • Weight Decay ($\lambda$):就像向心力。它不断地把参数往原点(0)拉,防止某些参数数值爆炸,从而起到正则化、减少过拟合的作用。

3. 为什么一定要进行偏差纠正?

这是许多初学者容易忽略的地方。我们需要证明:如果不纠正,初始阶段的估算是有偏的。

数学证明

以一阶矩 $m_t$ 为例,其展开式为:

我们希望 $m_t$ 是梯度期望 $E[g_t]$ 的无偏估计。对等式两边取期望:

假设真实梯度分布相对稳定,$E[g_i] \approx E[g_t]$,则:

利用等比数列求和公式 $\sum_{i=0}^{n-1} r^i = \frac{1-r^n}{1-r}$,我们可以得到:

结论

显而易见,$E[m_t] \neq E[g_t]$,而是多了一个因子 $(1-\beta_1^t)$。

  • 在训练初期($t$ 较小):由于 $m_0$ 通常初始化为 0,$\beta_1$ 又接近 1(如 0.9),导致 $m_t$ 会严重向 0 偏移,取值远小于真实期望。
  • 随着训练进行($t \to \infty$):$\beta_1^t$ 趋于 0,偏移逐渐消失。

偏差纠正的目的就是通过除以 $(1-\beta_t^t)$,抵消掉这个初始阶段的缩小效果,使优化器从第一步起就能准确反映梯度的真实规模。


4. 常用超参数设置

在处理 LLM、Transformer 或视觉模型时,以下是业界公认的“黄金标准”设置:

参数 常用值 说明
Learning Rate ($\eta$) $10^{-3}$ ~ $10^{-4}$ 核心参数,通常配合 Linear Warmup 使用。
$\beta_1$ $0.9$ 控制一阶动量,几乎不需要改动。
$\beta_2$ $0.999$ 或 $0.95$ 控制二阶动量。在大模型训练中,有时调低至 0.95 可增加稳定性。
$\epsilon$ $10^{-8}$ 防止除零的小数,通常保持默认。
Weight Decay ($\lambda$) $0.01$ ~ $0.1$ AdamW 的精髓,负责正则化。

使用建议

  • 搭配调度器:AdamW 必须搭配 Learning Rate Scheduler(如 Cosine Decay)。
  • 解耦学习率:如果你调大了 $\lambda$,通常不需要剧烈改动 $\eta$,这是 AdamW 优于 Adam(加 L2 正则)的地方。

recap: 正交化

正交矩阵的定义以及精确对矩阵正交化的方法(gram-schmidz)

Gemini小解释

这是一个非常精妙的对比视角。你的理解触及了“局部几何”与“全局几何”的差异,但我们可以把这个直觉再推敲得更准一点:

1. 你的直觉对了一半:Adam 的“逐元素独立性”

实际上,Adam 比你想象的更“分裂”。它并没有把所有参数视为一个整体空间中的大向量,而是把每一个标量参数(每一个权重数字)都视为一个独立的维度

  • Adam 的视角:如果你的模型有 10 亿个参数,Adam 就像是在一个 10 亿维 的空间里。它给每一个维度(每一颗螺丝钉)都装了一个独立的自适应减震器(即 $v_t$ 项,二阶动量)。
  • 这种做法的问题:它完全无视了参数之间的结构关系。它不知道某 100 万个数字其实构成了一个“线性层矩阵”,它只把它们当成 100 万个互不相干的点。

2. Muon 的视角:尊重“结构”

Muon 的“聪明”之处在于它意识到:神经网络不是数字的堆砌,而是变换的组合。

  • 矩阵作为整体:Muon 认为,一个线性层的 $d \times d$ 矩阵不应该是 $d^2$ 个独立的维度,而是一个整体的变换算子
  • 空间独立:Muon 确实为每个矩阵实施正交化。它在每个层的“局部空间”内,确保该层的变换方向(奇异向量)是均匀且正交的。

3. “整体”与“局部”的降维打击

我们可以用一个形象的比喻来总结 Adam 和 Muon 的空间观:

优化器 空间观 行为比喻
Adam 极其微观:每个参数都是一个独立的轴。 像是一个巨大的合唱团,指挥要求每个人根据自己的音量(梯度历史)独立调整声音大小,不管邻居唱什么。
Muon 结构宏观:每个层(矩阵)是一个独立的空间。 指挥将合唱团按声部(矩阵)分组。要求每个声部内部音域必须横跨高低音(正交化),且组与组之间独立推进。

4. 为什么 Muon 的“局部正交”比 Adam 的“全局自适应”快?

因为 Adam 在每个坐标轴上做的缩放,实际上可能破坏矩阵的性质。
例如,Adam 可能会放大梯度中原本就很小的噪声方向,缩小重要的主成分方向,因为它试图让每个轴的更新步长都变成一样的。

Muon 在局部矩阵空间里做 $U V^T$,它保证了:

  1. 不改变能量:正交变换不改变向量的范数,梯度的能量在通过层时不会被扭曲。
  2. 最大的探索率:因为它强迫每个奇异值都为 1,这相当于强迫模型在每一个可能的特征方向上都用相同的力度去“试探”,而不是只盯着那一两个已经很明显的特征(大奇异值方向)。

总结

  • Adam:是坐标轴级别的自适应(Coordinate-wise)。它把所有参数拉平,无视矩阵结构。
  • Muon:是算子级别的自适应(Operator-wise)。它承认矩阵的几何结构,并在每个矩阵内部寻求最优的正交骨架。

既然你提到了“整体空间”,你是否想了解在分布式训练中,Muon 是如何利用这种矩阵独立性来做 Distributed Data Parallel (DDP) 下的通信优化的?(这正是 Muon 能在大型集群上跑得飞快的原因之一)。


GPT解释(超级无敌详细)

可以。下面我按这篇文章的主线,把核心公式一步一步推导出来,并在每一步都说明“为什么要这么做”。我会集中讲三条最关键的线:

  1. Muon 到底在算什么:
  2. 为什么它等于
  3. 为什么可以用 Newton–Schulz 迭代逼近它
    以及这个迭代本质上是在“把奇异值推向 1”。 (科学空间)

0. 先把目标说清楚

文章里 Muon 的更新是

所以真正关键是:给定一个矩阵 $M$,什么叫 $\mathrm{msign}(M)$。文章给出的定义是:若

是 $M$ 的 SVD,那么

其中 $r=\mathrm{rank}(M)$。直觉上,这一步是保留左右奇异向量定义的“方向结构”,但把奇异值统一处理掉,不再让不同方向的尺度差异主导更新。(科学空间)


1. 先复习 SVD,因为后面全靠它

任意矩阵 $M\in\mathbb{R}^{n\times m}$ 都可以写成

这里:

  • $U$ 的列是左奇异向量;
  • $V$ 的列是右奇异向量;
  • $\Sigma$ 是“对角形”的,主对角线上是非负奇异值

更具体地,若只看秩 $r$ 的有效部分,可以写成

其中

  • $U_r\in\mathbb{R}^{n\times r}$,
  • $\Sigma_r=\mathrm{diag}(\sigma_1,\dots,\sigma_r)$,
  • $V_r\in\mathbb{R}^{m\times r}$。

这个写法最重要的直觉是:

也就是说,$M$ 是若干个 rank-1 方向 $u_i v_i^\top$ 的叠加,而 $\sigma_i$ 控制每个方向的强弱。Muon 的核心想法就是:我想保留这些方向 $u_i,v_i$,但不想让 $\sigma_i$ 的不均衡破坏更新。 (科学空间)


2. 为什么 $\mathrm{msign}(M)=UV^\top$ 很自然

先看标量的 sign:

它做了什么?
保留符号方向,但去掉绝对值大小

矩阵版想做类似的事:
保留矩阵的“方向信息”,但去掉奇异值大小

因为标量的“大小”是 $|x|$,矩阵最自然的“大小分解”就是 SVD 里的奇异值 $\sigma_i$。所以如果

那去掉大小、只保留方向,最自然就是把 $\Sigma_r$ 里的每个正奇异值都换成 1:

于是得到

这一步的 intuition 很重要:

  • Adam 类方法是按坐标缩放;
  • Muon是按矩阵的奇异方向来规整更新;
  • 它不是“逐元素地把大数变小、小数变大”,而是“把不同奇异方向的尺度统一”。 (科学空间)

3. 推导文章中的恒等式

$\mathrm{msign}(M)=M(M^\top M)^{-1/2}$

这是全文最关键的恒等式之一。文章直接给了结果,但我们可以完整推出来。(科学空间)

第一步:从 SVD 出发

那么

因为 $U_r^\top U_r=I_r$,所以

这里 $\Sigma_r^2=\mathrm{diag}(\sigma_1^2,\dots,\sigma_r^2)$。


第二步:开平方再取逆

因为

所以它的平方根是

其逆平方根(若有零奇异值则理解为伪逆)就是

这里


第三步:右乘回去

现在计算

代入上面的表达式:

因为 $V_r^\top V_r=I_r$,所以

而这正是定义中的

所以我们得到


4. 同理推导

$\mathrm{msign}(M)=(MM^\top)^{-1/2}M$

同样从

出发,有

所以

然后左乘 $M$:

因此

这条公式的直觉是:
$\mathrm{msign}(M)$ 就像对 $M$ 做一种“矩阵归一化”。
标量里是

矩阵里对应地变成

文章也明确指出这是它作为 sign 的矩阵推广的关键理解。(科学空间)


5. 为什么对向量时它会退化成 $l_2$ 归一化

文章说,如果把向量 $m\in\mathbb{R}^n$ 看成 $n\times 1$ 矩阵,那么

我们现在直接推。(科学空间)

因为 $m$ 是列向量,所以

是一个 $1\times 1$ 矩阵,也就是标量:

于是

代回恒等式

就得到

这特别值得记住,因为它说明:

  • 标量情形:sign 是“除以绝对值”;
  • 向量情形:msign 是“除以 $l_2$ 范数”;
  • 矩阵情形:msign 是“除以矩阵意义下的尺度”,即奇异值。 (科学空间)

6. 为什么对对角矩阵时它退化成逐元素 sign

所以

因此

也就是

这正是逐元素 sign。文章借此说明:Muon 可以看成把 SignSGD/Tiger 的逐元素规整化,推广成了矩阵级规整化。(科学空间)


7. 推导“最优正交近似”公式

$\mathrm{msign}(M)=\arg\min_{O^\top O=I}|M-O|_F^2$

这也是文章的核心结论之一,而且非常有解释力:
Muon 不是随便把奇异值变成 1,而是在 Frobenius 范数下,找离 $M$ 最近的正交矩阵。 (科学空间)

我们一步一步来。

第一步:展开平方

目标函数是

利用 Frobenius 范数恒等式:

所以

Frobenius 内积满足

于是


第二步:利用 $O$ 是正交矩阵

若 $O^\top O=I$ 且 $O$ 是 $n\times n$ 正交矩阵,那么

所以目标变成

注意这里前两项与 $O$ 无关,所以最小化距离等价于最大化

这一步 intuition 很关键:
“离 $M$ 最近”这件事,最后变成了“让 $O$ 与 $M$ 的对齐程度最大”。(科学空间)


第三步:代入 SVD

利用 trace 的循环不变性:

可写成

因为 $U,V,O$ 都是正交矩阵,$Q$ 也是正交矩阵。于是


第四步:把 trace 展开

由于 $\Sigma$ 是对角的,

而正交矩阵每个对角元都满足

又因为 $\sigma_i\ge 0$,所以要让

最大,就应该让每个 $Q_{ii}$ 尽量大,也就是取

最理想的情况就是

于是

所以最优解是

这条结论给了一个非常漂亮的几何解释:

  • $M$ 可能“歪歪扭扭”、不同方向尺度不同;
  • $\mathrm{msign}(M)$ 是最接近它的“纯正交版本”;
  • Muon 其实是在拿这个“最接近的正交更新”来替代原始矩阵更新。(科学空间)

8. 现在解释它为什么像“自适应学习率”

文章指出,Muon 与 Adam 类似,也有“尺度不敏感”和“更各向同性”的性质。(科学空间)

8.1 为什么损失整体缩放不影响方向

若把损失乘上常数 $\lambda$,梯度矩阵也乘上 $\lambda$,于是动量矩阵 $M$ 也会乘上 $\lambda$。

那么

于是

因为奇异值全被“置一”了,所以尺度因子 $\lambda$ 被消掉了。

这和 Adam 用归一化来减少尺度敏感性是同一种精神,只不过 Adam 在坐标维度上做,Muon 在奇异方向上做。(科学空间)


8.2 为什么说它让更新更“各向同性”

如果

那 $\Sigma$ 里的不同奇异值,代表不同奇异方向上的拉伸强弱不同。
奇异值差异越大,矩阵越“各向异性”。

Muon 做的是把

变成

也就是把所有奇异方向的尺度统一。于是更新不再偏向某几个奇异值特别大的方向,而是在有效子空间里更均匀。这就是文章说的“更各向同性”。(科学空间)


9. 为什么可以不用 SVD,改用 Newton–Schulz 迭代

SVD 太贵,所以文章接下来从恒等式

出发,想逼近矩阵函数 $X\mapsto X^{-1/2}$。(科学空间)

核心想法是:先把标量函数

在 $t=1$ 附近做泰勒展开,然后把这个多项式“搬到矩阵上”。

文章给出:

保留到二阶后可整理成

(科学空间)

我们把这一步完整展开一下。


10. 从泰勒展开推到多项式近似

开始,逐项展开。

先算

再算

三项加起来:

把常数项合并:

把 $t$ 项合并:

保留 $t^2$ 项:

所以

于是

再左乘 $M$:

这就是文章里的近似公式。(科学空间)


11. 为什么这会导出迭代格式

文章接着说:如果 $X_t$ 已经是 $\mathrm{msign}(M)$ 的一个近似,那把同样的多项式作用到 $X_t$ 上,可能得到更好的近似,于是写成

(科学空间)

这一步的 intuition 是:

  • 真正想要的结果 $Y=\mathrm{msign}(M)$ 满足(若满列秩且 $n\ge m$,则有 $Y^\top Y=I$);
  • 所以如果当前 $X_t$ 还没做到 $X_t^\top X_t=I$,我们就用一个多项式,把它朝这个条件“推过去”。

换句话说,这个迭代不是在直接逼近矩阵元素,而是在逼近“正交性条件”。(科学空间)


12. 文章最关键的 insight:

这个迭代本质上是在迭代奇异值

这是全文最值得吃透的一步。文章把一般迭代写成

然后说明它等价于对每个奇异值单独做一个标量迭代。(科学空间)

我们完整推一次。

第一步:假设 $X_t$ 的 SVD

那么

进一步


第二步:分别代入各项

第一项:

第二项:

第三项:


第三步:合并

所以

因为 $\Sigma_t$ 是对角矩阵,对角矩阵的幂就是把每个对角元单独取幂。
如果某个奇异值是 $\sigma_t$,那么它的更新就是

这就是文章那句“问题简化成单个奇异值的迭代”的完整原因。(科学空间)


13. 为什么目标是把奇异值迭代到 1

因为如果

最后我们希望 $X_t$ 收敛到

那就意味着希望

也就是说,每个奇异值都要满足

所以整个矩阵迭代的本质,就是设计一个标量函数 $g$,让所有初始奇异值都尽快被送到 1 附近。这就是文章后面把系数选择写成优化问题的原因。(科学空间)


14. 为什么初始化要除以 Frobenius 范数

文章取

理由是:缩放不会改变 $U,V$,只会缩放奇异值;而除以 $|M|_F$ 后,初始奇异值会被压到更标准的范围内,文章表述为落在 $[0,1]$ 之间,从而更适合迭代。(科学空间)

更直观地说:

  • 我们的多项式近似是围绕 $t=1$ 展开的;
  • 所以希望 $X_t^\top X_t$ 的谱不要太离谱;
  • 先把 $M$ 缩到一个温和尺度,迭代更稳。

15. 把整件事串成一句真正“懂了”的话

Muon 的核心不是“神秘地做正交化”,而是:

  1. 把梯度/动量矩阵 $M$ 做 SVD:
  2. 只保留方向 $U,V$,把尺度 $\Sigma$ 统一成 $I$:
  3. 这等价于求离 $M$ 最近的正交矩阵;
  4. 实现时不直接做 SVD,而是通过迭代不断把奇异值 $\sigma_i$ 推向 1。

所以 Muon 的“正交化”本质上不是你在线代课上学的 Gram–Schmidt 那种“把一组向量一个个变垂直”,而是:

把整个梯度矩阵投到“正交矩阵流形”附近,得到一个最接近原矩阵、但奇异值全被标准化的更新方向。 这正是文章强调的“从向量到矩阵的本质跨越”。(科学空间)

你要是愿意,我下一条可以继续做两件事里的任意一个:

  1. 用一个 $(2\times 2)$ 的具体数值例子,把上面每一步真的算出来;
  2. 解释 这和你熟悉的 Gram–Schmidt / QR 正交化到底是什么关系、又有什么本质区别

GPU权威学习资料

https://skindhu.github.io/How-To-Scale-Your-Model-CN/article-trans/gpus.html

triton权威学习资料

triton tutorials

用flash attention深入理解GPU和triton

https://github.com/hkproj/triton-flash-attention
youtube和bilibili也有视频讲解

补充讲解

Triton 基础:从 GPU 到 Triton 的一条主线理解

1️⃣ GPU 的核心执行模型(理解 Triton 的前提)

GPU 的计算模型可以概括为一句话:

GPU 会并行地运行大量“相同的小程序(kernel)”,
每个程序负责处理一小块数据。

这几个事实是理解 Triton 的根基:

  1. GPU 的并行单位不是“一次算一个大矩阵”,而是成千上万个小任务并行
  2. 每个小任务在硬件上对应一个 thread block(CUDA 里叫 CTA)
  3. 每个 thread block:

    • 有自己的寄存器 / shared memory
    • 不和别的 block 共享状态
  4. GPU 的性能瓶颈通常是:

    • 内存访问(慢)
    • 而不是浮点运算(快)

因此,高性能 GPU 程序的目标永远是:

让每个 block 把数据一次性搬进来,
在本地尽量多算几次,再写回。


2️⃣ Triton 的抽象层次:只让你写“block 级逻辑”

Triton 的设计思想是:

你只需要写“一个 block 要算什么”,
不需要写 thread / warp / shared memory。

因此:

  • 一个 @triton.jit 函数 ≈ 一个 GPU kernel
  • kernel 的一次并行实例 ≈ 一个 Triton program ≈ 一个 GPU block
  • Triton 自动:

    • 把 block 映射成 warps / threads
    • 决定寄存器 / shared memory 的使用

你写的不是“线程程序”,而是“tile 程序”


3️⃣ Grid:GPU 并行的全局结构

GPU 启动 kernel 时,需要告诉它:

要启动多少个 block?这些 block 怎么排?

在 Triton 中,这由 grid 决定:

1
kernel[grid](...)

如果:

1
grid = (X, Y, Z)

那么 GPU 会并行启动 X × Y × Z 个 Triton program。

每个 program 可以通过:

1
2
3
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)

知道自己在 grid 中的“坐标”。


4️⃣ Program 内部:用 block 索引切出自己的数据

一个 Triton program 的第一件事永远是:

根据自己的 program_id,算出“我该处理哪一块数据”。

典型模式是:

1
2
start = pid * BLOCK_SIZE
offs = start + tl.arange(0, BLOCK_SIZE)

在 GPU 语义上,这等价于:

“这个 block 负责第 pid 个 tile。”

这一步完全是索引计算,不涉及任何计算。


5️⃣ Triton 的内存访问:一次 load 一个 tile

在 GPU 上,逐元素 for-loop 访问内存是灾难性的。

Triton 的内存访问模型是:

一次 tl.load,读进一个完整的二维块。

例如:

1
2
3
tl.load(base
+ row_offsets[:, None] * stride_row
+ col_offsets[None, :] * stride_col)

这在语义上是:

  • 构造一个 二维地址矩阵
  • GPU 并行地把整块数据搬进寄存器

从 GPU 角度看:

这是一次 向量化、连续、可合并的 global memory 访问


6️⃣ 计算:block 级数学,而不是线程级指令

在 Triton 中:

1
C = tl.dot(A, B)

不是 Python 层的矩阵乘,而是:

  • 编译器识别为 block-level GEMM
  • 自动映射到:

    • Tensor Core(如果尺寸和 dtype 合适)
    • 或普通 FMA

你表达的是:

“这个 tile 做一次矩阵乘”

而不是:

“某个线程算第几个元素”。


7️⃣ Program 内的循环:为什么它是“安全的”

Triton 允许:

1
2
for i in range(num_steps):
...

这并不是展开并行,而是:

  • 单个 block 内的顺序循环
  • 常用于:

    • 遍历 K/V blocks
    • streaming reduction
    • softmax

GPU 视角是:

很多 block 并行跑,
每个 block 内部顺序推进。

这正是 FlashAttention 的计算结构。


8️⃣ 写回:每个 block 负责一块输出

Triton 的设计原则是:

一个输出元素只由一个 program 写。

因此常见模式是:

1
tl.store(ptr + offsets, block)

对应 GPU 语义:

  • 无原子操作
  • 无写冲突
  • 写回的是该 block 独占的 tile

9️⃣ 把 Triton 与 FlashAttention 串起来

FlashAttention 的核心是:

  1. 把巨大 attention 矩阵切成 tile
  2. 每个 tile:

    • load Q/K/V
    • 算 softmax
    • 累积 O / 梯度
  3. 从不存整张 attention matrix

这与 Triton 的模型完全同构

FlashAttention Triton
attention tile Triton program
streaming softmax program 内 for-loop
不存 P 用寄存器 / SRAM
Q/K/V block tl.load
matmul tl.dot

10️⃣ 最终一句话总结

Triton 是一种“以 GPU block 为基本单位”的编程模型:
你描述 block 的数学计算,
Triton 负责把它高效地跑在 GPU 上。

本文面向:
只有简单线性代数基础(矩阵乘法、逆矩阵)
但想真正理解「位姿变换 / delta_pose 从何而来」的读者。


1. 问题从哪里来?

在机械臂、具身智能中,我们经常遇到这样的问题:

已知末端在两个时刻的绝对位姿
[
T_1,\; T_2
]

如何计算末端从 (T_1) 运动到 (T_2) 的 delta_pose

你可能见过一个“标准答案”:

[
\boxed{
T_{\Delta} = T_1^{-1} T_2
}
]

但问题是:

  • 这是定义吗?
  • 为什么可以这样乘?
  • 它和“坐标系”“点的位置”有什么关系?

这篇文章就来把这件事从地基开始讲清楚


2. 什么是“位姿”?先别急着上公式

先不用矩阵,先说人话。

一个刚体(比如机械臂末端)的位姿,只包含两件事:

  1. 它的原点在哪里?(平移)
  2. 它的坐标轴朝哪?(旋转)

3. 一个关键直觉:点的位置 = 原点 + 方向

设:

  • 世界坐标系:(W)
  • 末端坐标系:(E)
  • 某个物理点:(P)

已知:

  • 末端原点在世界中的位置:
    [
    p_E^W
    ]
  • 点 (P) 在末端坐标系下的坐标:
    [
    p^E
    ]
  • 末端坐标系相对于世界的旋转:
    [
    R_E^W
    ]

那么,点在世界坐标系中的位置一定是:

[
\boxed{
p^W = p_E^W + R_E^W \, p^E
}
]

这不是公式技巧,而是几何事实:

先走到末端原点
再沿着末端坐标轴的方向走到点 P
但方向要先旋转到世界系


4. 齐次矩阵不是魔法,只是“打包”

上面的式子有个问题:

每次都要写
[
p^W = R p^E + t
]
很麻烦,而且不好连着算。

解决办法:加一个 1

把点写成齐次坐标:

[
\bar p^E =
\begin{bmatrix}
p^E \
1
\end{bmatrix}
]

把旋转 + 平移打包成一个矩阵:

[
{}^{W}T_E =
\begin{bmatrix}
R_E^W & p_E^W \
0 & 1
\end{bmatrix}
]

算一下矩阵乘法:

[
{}^{W}T_E \bar p^E
=
\begin{bmatrix}
R_E^W p^E + p_E^W \
1
\end{bmatrix}
]

第一行正好就是上面的几何公式。

所以:

[
\boxed{
\bar p^W = {}^{W}T_E \, \bar p^E
}
]

这不是新物理定律
只是把「旋转 + 平移」写成了一次矩阵乘法


5. 为什么逆矩阵代表“反向变换”?

如果:

[
\bar p^W = T \bar p^E
]

那根据线性代数最基本的事实:

[
\bar p^E = T^{-1} \bar p^W
]

没有新含义,只是:

正向变换 → 逆矩阵
世界系 → 末端系


6. delta_pose 到底在算什么?

现在进入正题。

已知两个时刻末端的绝对位姿(都在世界系):

[
T1 = {}^{W}T{E1}, \quad
T_2 = {}^{W}T
{E_2}
]

我们想要的 delta_pose 是:

在 (E_1) 坐标系中看,(E_2) 在哪里?

即:

[
T\Delta = {}^{E_1}T{E_2}
]


7. 用“同一个点”推导 delta_pose

取一个物理点 (P),它在 (E_2) 坐标系中的坐标是 (\bar p^{E_2})。

第一步:从 (E_2) 到世界

[
\bar p^W = T_2 \bar p^{E_2}
]

第二步:从世界到 (E_1)

[
\bar p^{E_1} = T_1^{-1} \bar p^W
]

合在一起:

[
\bar p^{E_1} = T_1^{-1} T_2 \bar p^{E_2}
]

而 delta_pose 的定义正是:

[
\bar p^{E1} = T\Delta \bar p^{E_2}
]

于是只能是:

[
\boxed{
T_\Delta = T_1^{-1} T_2
}
]


8. 用 2D 平面做一个具体例子(非常重要)

设定

  • 2D 平面
  • 末端在时刻 1:
    • 位置:(1, 0)
    • 朝向:90°(y 轴向上)
  • 时刻 2:
    • 位置:(1, 1)
    • 朝向:90°

问题

世界系看:

  • 位移是:(0, 1)

但在末端系看呢?

由于末端坐标轴已经转了 90°:

  • 世界 y 方向 = 末端 x 方向

👉 所以 delta_pose 在末端系中是:

[
\Delta p = (1, 0)
]

这正是公式中:

[
\Delta p = R_1^T (p_2 - p_1)
]

在干的事。


9. 常见误区(非常容易踩)

❌ 直接用:

[
p_2 - p_1
]

这只是在世界系下的差值
而控制、学习需要的是:

末端自己“感觉”到的运动


10. 总结一句话(可以背下来)

delta_pose 不是两个位姿的“差”

它是:
先从世界回到 t1 的末端,再走到 t2 的末端

数学上就是:

[
\boxed{
T_\Delta = T_1^{-1} T_2
}
]


11. 写在最后

齐次变换矩阵不是玄学,也不是死记硬背的公式。

它只是把一句非常朴素的话:

“先旋转,再平移”

用线性代数统一、优雅地表达了出来。

如果你能理解这一点,
后面的 IK、控制、模仿学习、diffusion policy,
都会顺很多。


如果你之后想继续往下,我很推荐的顺序是:

  1. 2D → 3D 的完全类比
  2. 为什么学习算法一定喜欢用 delta_pose
  3. delta_pose 和 twist / se(3) 的关系(不硬上李代数版)

哪一篇你想先写?我可以继续帮你整理。

在大语言模型的 RLHF(Reinforcement Learning from Human Feedback) 体系中,
PPO(Proximal Policy Optimization) 一直是事实标准。

但随着模型规模增大、工程复杂度提升,PPO 的一些问题也逐渐暴露出来:

  • 训练流程复杂(Reward Model + PPO Loop)
  • 超参数敏感(KL、clip range、value loss)
  • 训练不稳定、成本高

因此,近两年社区开始大量使用 PPO 的简化或替代方案,其中最常见的就是:

  • DPO(Direct Preference Optimization)
  • GRPO(Group Relative Policy Optimization)

本文从 PPO 出发,介绍这两种常用变体的核心思想、公式直觉和适用场景。


一、PPO 在 LLM 对齐中的基本范式

在 RLHF 中,PPO 的目标可以概括为:

在不偏离参考模型太多的前提下,
最大化模型生成内容的期望 reward

其核心目标函数通常写作:

[
\mathcal{L}{\text{PPO}} =
\mathbb{E}\Big[
\min(r_t(\theta) A_t,\;
\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)
\Big]
\- \beta \,\mathrm{KL}(\pi
\theta || \pi_{\text{ref}})
]

其中:

  • Reward 由 Reward Model 给出
  • KL 用于约束策略不偏离 reference model

PPO 的现实问题

在 LLM 微调中,PPO 常见痛点包括:

  • 需要额外训练一个 reward model
  • PPO loop 本身实现复杂
  • KL / value loss 权重极其敏感
  • 容易出现 reward hacking 或模式坍塌

这直接催生了 “不用 PPO 的 PPO” —— DPO 和 GRPO。


二、DPO:把 PPO 变成一个“分类问题”

1️⃣ 核心思想

DPO(Direct Preference Optimization) 的出发点非常简单:

如果人类更喜欢 A 而不是 B,
那模型就应该给 A 更高的概率

DPO 完全绕过了 reward model 和 PPO loop
只使用 成对偏好数据(chosen / rejected)


2️⃣ DPO 的目标函数

DPO 的优化目标为:

[
\mathcal{L}{\text{DPO}} =
\- \log \sigma\Big(
\beta \big[
(\log \pi
\theta(y^+|x) - \log \pi\theta(y^-|x))
\-
(\log \pi
{\text{ref}}(y^+|x) - \log \pi_{\text{ref}}(y^-|x))
\big]
\Big)
]

直觉解释:

  • 模型不仅要更偏向 chosen
  • 而且要 比 reference model 更偏向 chosen

可以理解为一种 “相对于 reference 的偏好分类”


3️⃣ DPO 的优点与局限

优点:

  • 不需要 reward model
  • 训练稳定、实现简单
  • 类似 SFT 的训练流程(cross-entropy 风格)
  • 非常适合中小规模对齐任务

局限:

  • 只能利用 pairwise preference
  • 对数据质量高度敏感
  • 无法自然表达强度不同的 reward

4️⃣ 适用场景

  • Chat / Instruction 对齐
  • 安全对齐(harmlessness)
  • 偏好排序类任务
  • PPO 训练不稳定或成本过高时

三、GRPO:介于 PPO 与 DPO 之间的折中方案

1️⃣ 核心思想

GRPO(Group Relative Policy Optimization) 的关键假设是:

不需要一个绝对 reward,
只需要知道 同一组输出中谁更好

它通过 组内相对优势(relative advantage)
避免了 value function 和复杂的 PPO clip。


2️⃣ GRPO 的基本形式

给定同一个 prompt 下的多个采样结果:

[
{y_1, y_2, \dots, y_n}
]

使用 reward(或打分函数)计算组内标准化优势:

[
A_i = \frac{r_i - \mu(r)}{\sigma(r)}
]

然后直接优化:

[
\mathcal{L}{\text{GRPO}}
= - \mathbb{E}\big[ \log \pi
\theta(yi|x) \cdot A_i \big]
\- \beta \,\mathrm{KL}(\pi
\theta || \pi_{\text{ref}})
]


3️⃣ GRPO 的特点

  • 不需要 value model
  • Advantage 来自组内对比
  • 保留了 PPO 的 policy gradient 直觉
  • 比 DPO 更“RL”,比 PPO 更简单

4️⃣ GRPO vs DPO

维度 DPO GRPO
数据形式 成对偏好 同 prompt 多采样
Reward 隐式(偏好) 显式或打分
是否 RL 更像分类 更像 policy gradient
稳定性 非常高
表达能力

四、如何选择:PPO / DPO / GRPO?

一个实用的经验法则:

  • DPO
    👉 数据是偏好对,追求稳定、简单、快速对齐
  • GRPO
    👉 有打分函数或多样本对比,希望保留 RL 表达能力
  • PPO
    👉 需要精细控制 reward、做复杂行为塑形

在很多实际项目中:
DPO / GRPO 已经可以替代 80% 的 PPO 使用场景


五、总结

  • PPO 是 RLHF 的理论起点,但工程成本高
  • DPO 用“偏好分类”绕过 PPO
  • GRPO 用“组内相对优势”简化 PPO
  • 三者并非对立,而是 复杂度与表达能力的权衡

LLM 对齐训练,正在从“重 RL”走向“轻 RL”。


如果你正在做 LLM 微调,不妨问自己一句:
“我真的需要 PPO 吗?”