🎓 深度解析反向传播:全连接层梯度的两种推导方法
在深度学习中,全连接层(Linear Layer)是最基础的模块。理解其反向传播中的梯度计算是掌握整个神经网络训练流程的关键。本文将以 $Y = WX + b$ 为例,演示两种核心梯度推导方法:维度检查法(Shape Check)和全微分法(Matrix Differential)。
场景设定与变量定义
假设我们在计算图中的某一层进行以下运算:
变量定义与维度(Shape):
- $X$:输入向量,维度 $(n \times 1)$。
- $W$:权重矩阵,维度 $(m \times n)$。
- $b$:偏置向量,维度 $(m \times 1)$。
- $Y$:输出向量,维度 $(m \times 1)$。
- $L$:最终的损失函数(Scalar,标量)。
已知条件(上游传回来的梯度):
我们已算出 Loss 对本层输出 $Y$ 的梯度:
- $G$ 的维度必须和 $Y$ 一致,也是 $(m \times 1)$。
我们的目标:
求出 $L$ 关于 $W$ 的梯度 $\frac{\partial L}{\partial W}$ 和 关于 $X$ 的梯度 $\frac{\partial L}{\partial X}$(用于继续往前传)。
方法一:维度检查法(Shape Check)
核心思想: 梯度的形状必须与其对应的变量形状一致。通过线性代数知识,用已知变量拼凑出正确的形状。
1. 求 $\frac{\partial L}{\partial W}$(权重 $W$ 的梯度)
- 目标形状: $\frac{\partial L}{\partial W}$ 的形状必须和 $W$ 一致,即 $(m \times n)$。
- 素材: $G$ (上游梯度,$(m \times 1)$) 和 $X$ (输入,$(n \times 1)$)。
- 拼凑: 想要得到 $(m \times n)$,只能将 $(m \times 1)$ 乘以 $(1 \times n)$。
- 结论:
2. 求 $\frac{\partial L}{\partial X}$(输入 $X$ 的梯度,传给上一层)
- 目标形状: $\frac{\partial L}{\partial X}$ 的形状必须和 $X$ 一致,即 $(n \times 1)$。
- 素材: $G$ (上游梯度,$(m \times 1)$) 和 $W$ (权重,$(m \times n)$)。
- 拼凑: 想要得到 $(n \times 1)$,需要将 $W$ 转置为 $(n \times m)$,再与 $G$ 相乘。
- 结论:
点评:维度检查法速度极快,是日常代码实现和面试推导的首选。
方法二:全微分法(Matrix Differential)
核心思想: 利用标量函数 $L$ 的微分 $dL$ 与梯度的关系,通过迹(Trace)的性质进行严谨推导。这是处理复杂矩阵运算的终极武器。
核心公式(迹与梯度)
对于标量函数 $L$,其微分 $dL$ 与梯度 $\nabla A$ 的关系是:
推导过程
第一步:对正向公式求微分
第二步:建立 $dL$ 方程
我们已知 $dL = \text{tr}(G^T dY)$。将 $dY$ 代入:
利用迹的线性性质:
第三步:逐项对比求梯度(利用迹的循环性质 $\text{tr}(ABC) = \text{tr}(BCA)$)
1. 求 $\frac{\partial L}{\partial W}$: 关注 $\text{Term}_W$
将 $dW$ 挪到最后面:
对比核心公式 $dL = \text{tr}((\frac{\partial L}{\partial W})^T dW)$ 可得:
两边转置:
2. 求 $\frac{\partial L}{\partial X}$: 关注 $\text{Term}_X$
对比核心公式 $dL = \text{tr}((\frac{\partial L}{\partial X})^T dX)$ 可得:
两边转置:
3. 求 $\frac{\partial L}{\partial b}$: 关注 $\text{Term}_b$
对比可得:
总结与建议:何时使用哪种方法
| 方法 | 优势 | 劣势 | 适用场景 |
|---|---|---|---|
| 维度检查法 (Shape Check) | 速度极快,无需微积分知识,直观。 | 缺乏严谨性,在复杂运算(如带 Hadamard 乘积)时容易出错。 | 日常编码、调试、面试、标准层(全连接/卷积)。 |
| 全微分法 (Matrix Differential) | 严谨可靠,可以处理任何复杂的矩阵运算,确保 100% 正确。 | 涉及矩阵微积分和迹的性质,需要一定的数学基础。 | 推导新型网络层、复杂 Loss 函数、进行学术研究。 |
最佳实践:
在日常工作中,应熟练使用维度检查法快速推导和验证;在遇到非标准或复杂的矩阵运算时,则使用全微分法确保推导的准确性。
更多公式
| 类型 | 性质/公式 | 描述 |
|---|---|---|
| 迹的性质 | $\text{tr}(A^T) = \text{tr}(A)$ | 迹的转置不变性。 |
| $\text{tr}(A B C) = \text{tr}(B C A) = \text{tr}(C A B)$ | 迹的循环不变性 (Trace Cycling Property)。这是推导的核心工具。 | |
| $\text{tr}(A + B) = \text{tr}(A) + \text{tr}(B)$ | 迹的线性性质。 | |
| 微分性质 | $d(A B) = dA \cdot B + A \cdot dB$ | 矩阵乘积的微分规则。 |
| $d(A^T) = (dA)^T$ | 矩阵转置的微分。 | |
| 常用微分公式 | $d(X^n) = \sum_{i=1}^n X^{i-1} dX X^{n-i}$ | 矩阵幂的微分(非循环)。 |
| $d(\text{tr}(A)) = \text{tr}(dA)$ | 迹函数的微分。 | |
| 特殊公式 (针对迹) | $\text{tr}(A) = A$ (当 $A$ 为标量时) | 标量是它自己的迹。 |
| $\text{tr}(X^T A) = \text{tr}(A X^T)$ | 迹的循环性质应用。 |
二次型求导示例
以二次型函数 $f(\mathbf{x}) = \mathbf{x}^T \mathbf{Ax}$ 对向量 $\mathbf{x}$ 求导为例,这是机器学习(如最小二乘法、正态分布)中最经典的推导。
我们可以通过之前提到的“迹法则(Trace Trick)”和“微分法”三步走来实现:
第一步:写出函数的微分
由于二次型 $f(\mathbf{x})$ 的结果是一个标量,标量的迹等于其自身,即 $f = \text{tr}(\mathbf{x}^T \mathbf{Ax})$。
根据矩阵乘法的微分法则 $d(\mathbf{UV}) = (d\mathbf{U})\mathbf{V} + \mathbf{U}(d\mathbf{V})$,我们对 $\mathbf{x}$ 求微分:
第二步:利用迹的性质进行变换
我们的目标是将所有的 $d\mathbf{x}$ 统一移到表达式的右侧,并包裹在 $\text{tr}(\cdot)$ 中。
- 第一项:$(d\mathbf{x}^T)\mathbf{Ax}$ 是一个标量,取迹并利用 $\text{tr}(\mathbf{M}^T) = \text{tr}(\mathbf{M})$:
- 第二项:$\mathbf{x}^T \mathbf{A}(d\mathbf{x})$ 本身就是 $\text{tr}(\mathbf{x}^T \mathbf{A} d\mathbf{x})$ 的形式。
合并两项:
第三步:提取导数
根据定义 $df = \text{tr}((\frac{\partial f}{\partial \mathbf{x}})^T d\mathbf{x})$,对比上面的等式:
两边同时转置,得到最终导数公式:
📌 结论与应用
- 一般情况:$\nabla_{\mathbf{x}} (\mathbf{x}^T \mathbf{Ax}) = (\mathbf{A} + \mathbf{A}^T)\mathbf{x}$。
- 当 $\mathbf{A}$ 是对称矩阵时(即 $\mathbf{A}^T = \mathbf{A}$):
这与标量求导 $\frac{d(ax^2)}{dx} = 2ax$ 在形式上非常相似,非常好记。