dw是loss对w求偏导,而不是out对w。db同理。 主要针对CS231n作业1-3 layers.py

x_row = x.reshape(x.shape[0],-1) 
dx_row = dout.dot(w.T)
dx = dx_row.reshape(x.shape)
 
dw = x_row.T.dot(dout)
db = np.sum(dout,axis=0)

计算权重梯度的推导过程 · GitHub Copilot

直观理解

我们可以把反向传播想象成一个**“责任分配”的过程。前向传播算出了一个结果 out,然后我们得到了一个“误差信号” dout**,这个信号告诉我们 out 的每个部分“错”了多少。现在,我们要把这个“错误”的责任,分配给导致这个结果的 w 和 b


1. 如何计算 dw (给权重 w 分配责任)

核心思想:一个权重 w 的责任大小,取决于它在“犯错”时有多“活跃”。

这份“活跃度”取决于两件事:

  1. 它所连接的“输入信号”有多强? (x_row)
    • 如果一个输入信号 x_row 本身就是 0,那连接它的权重 w 就算再大,对结果也没贡献,自然不应该承担责任。所以,dw 应该和 x_row 成正比。
  2. 它所影响的“输出错误”有多大? (dout)
    • 如果一个权重 w 连接到的输出 out 最终被证明是“完美”的(即 dout 中对应的部分是 0),那说明这个权重干得不错,也不需要承担责任。所以,dw 也应该和 dout 成正比。

结论:dw 的责任 = 输入信号的强度 × 输出错误的程度

现在我们把它变成矩阵操作:

  • “输入信号”是 x_row (形状 N, D)。
  • “输出错误”是 dout (形状 N, M)。
  • 我们需要得到的 dw 形状是 (D, M)

怎么通过 (N, D) 和 (N, M) 凑出 (D, M) 呢? 唯一的组合方式就是:x_row 转置后,再乘以 dout

dw = x_row.T .dot(dout) (D, M) = (D, N) . (N, M)

直观记忆:在前向传播 out = x_row @ w 中,x_row 是“左乘” w 的。那么在反向传播时,轮到 x_row 来“修正”w 了,它就要站到 dout 的左边,并且需要转置一下来匹配维度。


2. 如何计算 db (给偏置 b 分配责任)

核心思想:偏置 b 是一个“全局加成”,它不偏袒任何一个输入,而是平等地影响了每一个样本的输出。

在前向传播 out = ... + b 时,b 被广播(可以想象成被复制了 N 次),加到了每一个样本上。

那么在反向传播分配责任时,我们就需要把每一个样本的“输出错误” dout 公平地收集起来,全部加到 b 的头上

  • dout (形状 N, M) 包含了 N 个样本各自的输出错误。
  • 我们需要把这 N 个样本的错误沿着批次(N)的维度全部加起来。

db = np.sum(dout, axis=0)

这个操作会把一个 (N, M) 的矩阵压缩成一个 (M,) 的向量,这正好是 b 的形状。

直观记忆b 在前向传播时是“广播相加”,那么在反向传播时就是“求和”。这是一个非常对称、好记的操作。

总结

参数责任分配原则计算公式 (numpy 风格)记忆技巧
dw责任 = 输入信号强度 × 输出错误程度dw = x_row.T.dot(dout)前向左乘,反向时也站左边 (需要转置)。
db责任 = 所有样本的输出错误之和db = np.sum(dout, axis=0)前向广播相加,反向时就要求和。

好的,我们来走一个折中的路线:使用链式法则,但保持在矩阵和向量的宏观层面,不深入到单个元素。这能帮助我们理解公式的结构。

结合一点数学

背景:链式法则的宏观视角

我们的目标是计算损失 L 对 w 和 b 的偏导数。链式法则告诉我们,要计算一个变量的梯度,你需要:

  1. 找到这个变量对它的直接输出的局部偏导数。
  2. 将这个局部偏导数乘以最终损失对这个直接输出的偏导数(也就是从后面传回来的梯度)。

已知条件

  • 前向传播公式out = x_row @ w + b
  • 传回来的梯度dout = ∂L/∂out (损失 L 对 out 的梯度,形状为 (N, M))

1. 计算 dw (∂L/∂w)

第一步:应用链式法则

损失 L 通过 out 间接依赖于 w。所以梯度路径是 L → out → w

∂L/∂w = (∂L/∂out) * (∂out/∂w)

我们已经知道 ∂L/∂out就是 dout,所以问题变成了:

dw = dout * (∂out/∂w)

第二步:计算局部偏导数 ∂out/∂w

我们关注 out = x_row @ w 这部分。这是一个关于 w 的线性变换。

  • 在普通的微积分里,如果 y = a*x,那么 dy/dx = a
  • 在矩阵微积分里,情况类似但要考虑矩阵乘法的顺序。当函数是 out = A @ w 的形式时,out 对 w 的偏导数和 A (即 x_row) 有关。

一个宏观的规则是:当求导的变量 w 在矩阵乘法的右边时,其左边的矩阵 x_row 在求导后需要转置

所以,∂out/∂w 在行为上可以看作是 x_row.T

第三步:组合并检查维度

现在我们把 dout 和 x_row.T 组合起来。这一步不是简单的乘法,而是需要通过矩阵乘法,将 dout 这个“误差信号”应用到 x_row.T 这个“局部关系”上。

  • 我们有 dout,形状为 (N, M)
  • 我们有 x_row.T,形状为 (D, N)
  • 我们想要的 dw,形状必须和 w 一样,是 (D, M)

如何通过 (N, M) 和 (D, N) 得到 (D, M)? 唯一的矩阵乘法顺序是:

dw = x_row.T @ dout (D, M) = (D, N) @ (N, M)

这个维度的匹配验证了我们的直觉:梯度 dw 是通过将转置的输入 x_row.T 左乘上传回来的梯度 dout 计算得到的。


2. 计算 db (∂L/∂b)

第一步:应用链式法则

梯度路径是 L → out → b

∂L/∂b = (∂L/∂out) * (∂out/∂b) db = dout * (∂out/∂b)

第二步:计算局部偏导数 ∂out/∂b

我们关注 out = ... + b 这部分。

  • 在前向传播时,b (形状 (1, M)) 被广播到 x_row @ w (形状 (N, M)) 的每一行上。可以想象成 b 被一个全为1的、形状为 (N, 1) 的列向量 I 左乘,然后相加 (I @ b)。
  • out = x_row @ w + I @ b
  • out 对 b 的偏导数 ∂out/∂b,其行为就类似于这个全1的列向量 I。它表示 b 的每一个元素都以 1 的权重影响了每一行的输出。

第三步:组合

现在我们需要将 dout 这个“误差信号”应用到这个“全1”的局部关系上。

这意味着 dout 中的每一行(代表每个样本的误差)都需要被计入 db 的计算中。将来自所有 N 个样本的梯度贡献加起来,最直接的操作就是求和

db = np.sum(dout, axis=0)

dout 的形状是 (N, M),沿着 axis=0 (批次维度) 求和后,得到 (M,),这正好是 b 的形状。这个求和操作完美地体现了“收集所有广播出去的路径传回来的梯度”这一过程。

宏观总结

参数数学路径 (L → out → var)局部导数 ∂out/∂var 的宏观行为最终公式 (@ 代表矩阵乘法)
dw(∂L/∂out) * (∂out/∂w)与 x_row 的转置 x_row.T 相关dw = x_row.T @ dout
db(∂L/∂out) * (∂out/∂b)与一个全1的矩阵相关(广播操作)db = np.sum(dout, axis=0)