0%

深度解析反向传播:全连接层梯度的两种推导方法

🎓 深度解析反向传播:全连接层梯度的两种推导方法

在深度学习中,全连接层(Linear Layer)是最基础的模块。理解其反向传播中的梯度计算是掌握整个神经网络训练流程的关键。本文将以 $Y = WX + b$ 为例,演示两种核心梯度推导方法:维度检查法(Shape Check)全微分法(Matrix Differential)


场景设定与变量定义

假设我们在计算图中的某一层进行以下运算:

变量定义与维度(Shape):

  • $X$:输入向量,维度 $(n \times 1)$。
  • $W$:权重矩阵,维度 $(m \times n)$。
  • $b$:偏置向量,维度 $(m \times 1)$。
  • $Y$:输出向量,维度 $(m \times 1)$。
  • $L$:最终的损失函数(Scalar,标量)。

已知条件(上游传回来的梯度):

我们已算出 Loss 对本层输出 $Y$ 的梯度:

  • $G$ 的维度必须和 $Y$ 一致,也是 $(m \times 1)$。

我们的目标:
求出 $L$ 关于 $W$ 的梯度 $\frac{\partial L}{\partial W}$ 和 关于 $X$ 的梯度 $\frac{\partial L}{\partial X}$(用于继续往前传)。


方法一:维度检查法(Shape Check)

核心思想: 梯度的形状必须与其对应的变量形状一致。通过线性代数知识,用已知变量拼凑出正确的形状。

1. 求 $\frac{\partial L}{\partial W}$(权重 $W$ 的梯度)

  • 目标形状: $\frac{\partial L}{\partial W}$ 的形状必须和 $W$ 一致,即 $(m \times n)$
  • 素材: $G$ (上游梯度,$(m \times 1)$) 和 $X$ (输入,$(n \times 1)$)。
  • 拼凑: 想要得到 $(m \times n)$,只能将 $(m \times 1)$ 乘以 $(1 \times n)$。
  • 结论:

2. 求 $\frac{\partial L}{\partial X}$(输入 $X$ 的梯度,传给上一层)

  • 目标形状: $\frac{\partial L}{\partial X}$ 的形状必须和 $X$ 一致,即 $(n \times 1)$
  • 素材: $G$ (上游梯度,$(m \times 1)$) 和 $W$ (权重,$(m \times n)$)。
  • 拼凑: 想要得到 $(n \times 1)$,需要将 $W$ 转置为 $(n \times m)$,再与 $G$ 相乘。
  • 结论:

点评:维度检查法速度极快,是日常代码实现和面试推导的首选。


方法二:全微分法(Matrix Differential)

核心思想: 利用标量函数 $L$ 的微分 $dL$ 与梯度的关系,通过迹(Trace)的性质进行严谨推导。这是处理复杂矩阵运算的终极武器。

核心公式(迹与梯度)

对于标量函数 $L$,其微分 $dL$ 与梯度 $\nabla A$ 的关系是:

推导过程

第一步:对正向公式求微分

第二步:建立 $dL$ 方程

我们已知 $dL = \text{tr}(G^T dY)$。将 $dY$ 代入:

利用迹的线性性质:

第三步:逐项对比求梯度(利用迹的循环性质 $\text{tr}(ABC) = \text{tr}(BCA)$)

1. 求 $\frac{\partial L}{\partial W}$: 关注 $\text{Term}_W$
将 $dW$ 挪到最后面:

对比核心公式 $dL = \text{tr}((\frac{\partial L}{\partial W})^T dW)$ 可得:

两边转置:

2. 求 $\frac{\partial L}{\partial X}$: 关注 $\text{Term}_X$

对比核心公式 $dL = \text{tr}((\frac{\partial L}{\partial X})^T dX)$ 可得:

两边转置:

3. 求 $\frac{\partial L}{\partial b}$: 关注 $\text{Term}_b$

对比可得:


总结与建议:何时使用哪种方法

方法 优势 劣势 适用场景
维度检查法 (Shape Check) 速度极快,无需微积分知识,直观。 缺乏严谨性,在复杂运算(如带 Hadamard 乘积)时容易出错。 日常编码、调试、面试、标准层(全连接/卷积)
全微分法 (Matrix Differential) 严谨可靠,可以处理任何复杂的矩阵运算,确保 100% 正确。 涉及矩阵微积分和迹的性质,需要一定的数学基础。 推导新型网络层、复杂 Loss 函数、进行学术研究

最佳实践:
在日常工作中,应熟练使用维度检查法快速推导和验证;在遇到非标准或复杂的矩阵运算时,则使用全微分法确保推导的准确性。

更多公式

类型 性质/公式 描述
迹的性质 $\text{tr}(A^T) = \text{tr}(A)$ 迹的转置不变性。
$\text{tr}(A B C) = \text{tr}(B C A) = \text{tr}(C A B)$ 迹的循环不变性 (Trace Cycling Property)。这是推导的核心工具。
$\text{tr}(A + B) = \text{tr}(A) + \text{tr}(B)$ 迹的线性性质。
微分性质 $d(A B) = dA \cdot B + A \cdot dB$ 矩阵乘积的微分规则。
$d(A^T) = (dA)^T$ 矩阵转置的微分。
常用微分公式 $d(X^n) = \sum_{i=1}^n X^{i-1} dX X^{n-i}$ 矩阵幂的微分(非循环)。
$d(\text{tr}(A)) = \text{tr}(dA)$ 迹函数的微分。
特殊公式 (针对迹) $\text{tr}(A) = A$ (当 $A$ 为标量时) 标量是它自己的迹。
$\text{tr}(X^T A) = \text{tr}(A X^T)$ 迹的循环性质应用。

二次型求导示例

二次型函数 $f(\mathbf{x}) = \mathbf{x}^T \mathbf{Ax}$ 对向量 $\mathbf{x}$ 求导为例,这是机器学习(如最小二乘法、正态分布)中最经典的推导。

我们可以通过之前提到的“迹法则(Trace Trick)”和“微分法”三步走来实现:

第一步:写出函数的微分

由于二次型 $f(\mathbf{x})$ 的结果是一个标量,标量的迹等于其自身,即 $f = \text{tr}(\mathbf{x}^T \mathbf{Ax})$。
根据矩阵乘法的微分法则 $d(\mathbf{UV}) = (d\mathbf{U})\mathbf{V} + \mathbf{U}(d\mathbf{V})$,我们对 $\mathbf{x}$ 求微分:

第二步:利用迹的性质进行变换

我们的目标是将所有的 $d\mathbf{x}$ 统一移到表达式的右侧,并包裹在 $\text{tr}(\cdot)$ 中。

  1. 第一项:$(d\mathbf{x}^T)\mathbf{Ax}$ 是一个标量,取迹并利用 $\text{tr}(\mathbf{M}^T) = \text{tr}(\mathbf{M})$:
  1. 第二项:$\mathbf{x}^T \mathbf{A}(d\mathbf{x})$ 本身就是 $\text{tr}(\mathbf{x}^T \mathbf{A} d\mathbf{x})$ 的形式。

合并两项:

第三步:提取导数

根据定义 $df = \text{tr}((\frac{\partial f}{\partial \mathbf{x}})^T d\mathbf{x})$,对比上面的等式:

两边同时转置,得到最终导数公式:

📌 结论与应用

  • 一般情况:$\nabla_{\mathbf{x}} (\mathbf{x}^T \mathbf{Ax}) = (\mathbf{A} + \mathbf{A}^T)\mathbf{x}$。
  • 当 $\mathbf{A}$ 是对称矩阵时(即 $\mathbf{A}^T = \mathbf{A}$):

这与标量求导 $\frac{d(ax^2)}{dx} = 2ax$ 在形式上非常相似,非常好记。