dw是loss对w求偏导,而不是out对w。db同理。 主要针对CS231n作业1-3 layers.py
x_row = x.reshape(x.shape[0],-1)
dx_row = dout.dot(w.T)
dx = dx_row.reshape(x.shape)
dw = x_row.T.dot(dout)
db = np.sum(dout,axis=0)直观理解
我们可以把反向传播想象成一个**“责任分配”的过程。前向传播算出了一个结果 out,然后我们得到了一个“误差信号” dout**,这个信号告诉我们 out 的每个部分“错”了多少。现在,我们要把这个“错误”的责任,分配给导致这个结果的 w 和 b。
1. 如何计算 dw (给权重 w 分配责任)
核心思想:一个权重 w 的责任大小,取决于它在“犯错”时有多“活跃”。
这份“活跃度”取决于两件事:
- 它所连接的“输入信号”有多强? (
x_row)- 如果一个输入信号
x_row本身就是 0,那连接它的权重w就算再大,对结果也没贡献,自然不应该承担责任。所以,dw应该和x_row成正比。
- 如果一个输入信号
- 它所影响的“输出错误”有多大? (
dout)- 如果一个权重
w连接到的输出out最终被证明是“完美”的(即dout中对应的部分是 0),那说明这个权重干得不错,也不需要承担责任。所以,dw也应该和dout成正比。
- 如果一个权重
结论:dw 的责任 = 输入信号的强度 × 输出错误的程度
现在我们把它变成矩阵操作:
- “输入信号”是
x_row(形状N, D)。 - “输出错误”是
dout(形状N, M)。 - 我们需要得到的
dw形状是(D, M)。
怎么通过 (N, D) 和 (N, M) 凑出 (D, M) 呢? 唯一的组合方式就是:x_row 转置后,再乘以 dout。
dw = x_row.T .dot(dout) (D, M) = (D, N) . (N, M)
直观记忆:在前向传播 out = x_row @ w 中,x_row 是“左乘” w 的。那么在反向传播时,轮到 x_row 来“修正”w 了,它就要站到 dout 的左边,并且需要转置一下来匹配维度。
2. 如何计算 db (给偏置 b 分配责任)
核心思想:偏置 b 是一个“全局加成”,它不偏袒任何一个输入,而是平等地影响了每一个样本的输出。
在前向传播 out = ... + b 时,b 被广播(可以想象成被复制了 N 次),加到了每一个样本上。
那么在反向传播分配责任时,我们就需要把每一个样本的“输出错误” dout 公平地收集起来,全部加到 b 的头上。
dout(形状N, M) 包含了N个样本各自的输出错误。- 我们需要把这
N个样本的错误沿着批次(N)的维度全部加起来。
db = np.sum(dout, axis=0)
这个操作会把一个 (N, M) 的矩阵压缩成一个 (M,) 的向量,这正好是 b 的形状。
直观记忆:b 在前向传播时是“广播相加”,那么在反向传播时就是“求和”。这是一个非常对称、好记的操作。
总结
| 参数 | 责任分配原则 | 计算公式 (numpy 风格) | 记忆技巧 |
|---|---|---|---|
| dw | 责任 = 输入信号强度 × 输出错误程度 | dw = x_row.T.dot(dout) | 前向左乘,反向时也站左边 (需要转置)。 |
| db | 责任 = 所有样本的输出错误之和 | db = np.sum(dout, axis=0) | 前向广播相加,反向时就要求和。 |
好的,我们来走一个折中的路线:使用链式法则,但保持在矩阵和向量的宏观层面,不深入到单个元素。这能帮助我们理解公式的结构。
结合一点数学
背景:链式法则的宏观视角
我们的目标是计算损失 L 对 w 和 b 的偏导数。链式法则告诉我们,要计算一个变量的梯度,你需要:
- 找到这个变量对它的直接输出的局部偏导数。
- 将这个局部偏导数乘以最终损失对这个直接输出的偏导数(也就是从后面传回来的梯度)。
已知条件:
- 前向传播公式:
out = x_row @ w + b - 传回来的梯度:
dout = ∂L/∂out(损失L对out的梯度,形状为(N, M))
1. 计算 dw (∂L/∂w)
第一步:应用链式法则
损失 L 通过 out 间接依赖于 w。所以梯度路径是 L → out → w。
∂L/∂w = (∂L/∂out) * (∂out/∂w)
我们已经知道 ∂L/∂out就是 dout,所以问题变成了:
dw = dout * (∂out/∂w)
第二步:计算局部偏导数 ∂out/∂w
我们关注 out = x_row @ w 这部分。这是一个关于 w 的线性变换。
- 在普通的微积分里,如果
y = a*x,那么dy/dx = a。 - 在矩阵微积分里,情况类似但要考虑矩阵乘法的顺序。当函数是
out = A @ w的形式时,out对w的偏导数和A(即x_row) 有关。
一个宏观的规则是:当求导的变量 w 在矩阵乘法的右边时,其左边的矩阵 x_row 在求导后需要转置。
所以,∂out/∂w 在行为上可以看作是 x_row.T。
第三步:组合并检查维度
现在我们把 dout 和 x_row.T 组合起来。这一步不是简单的乘法,而是需要通过矩阵乘法,将 dout 这个“误差信号”应用到 x_row.T 这个“局部关系”上。
- 我们有
dout,形状为(N, M)。 - 我们有
x_row.T,形状为(D, N)。 - 我们想要的
dw,形状必须和w一样,是(D, M)。
如何通过 (N, M) 和 (D, N) 得到 (D, M)? 唯一的矩阵乘法顺序是:
dw = x_row.T @ dout (D, M) = (D, N) @ (N, M)
这个维度的匹配验证了我们的直觉:梯度 dw 是通过将转置的输入 x_row.T 左乘上传回来的梯度 dout 计算得到的。
2. 计算 db (∂L/∂b)
第一步:应用链式法则
梯度路径是 L → out → b。
∂L/∂b = (∂L/∂out) * (∂out/∂b) db = dout * (∂out/∂b)
第二步:计算局部偏导数 ∂out/∂b
我们关注 out = ... + b 这部分。
- 在前向传播时,
b(形状(1, M)) 被广播到x_row @ w(形状(N, M)) 的每一行上。可以想象成b被一个全为1的、形状为(N, 1)的列向量I左乘,然后相加 (I @ b)。 out = x_row @ w + I @ bout对b的偏导数∂out/∂b,其行为就类似于这个全1的列向量I。它表示b的每一个元素都以1的权重影响了每一行的输出。
第三步:组合
现在我们需要将 dout 这个“误差信号”应用到这个“全1”的局部关系上。
这意味着 dout 中的每一行(代表每个样本的误差)都需要被计入 db 的计算中。将来自所有 N 个样本的梯度贡献加起来,最直接的操作就是求和。
db = np.sum(dout, axis=0)
dout 的形状是 (N, M),沿着 axis=0 (批次维度) 求和后,得到 (M,),这正好是 b 的形状。这个求和操作完美地体现了“收集所有广播出去的路径传回来的梯度”这一过程。
宏观总结
| 参数 | 数学路径 (L → out → var) | 局部导数 ∂out/∂var 的宏观行为 | 最终公式 (@ 代表矩阵乘法) |
|---|---|---|---|
| dw | (∂L/∂out) * (∂out/∂w) | 与 x_row 的转置 x_row.T 相关 | dw = x_row.T @ dout |
| db | (∂L/∂out) * (∂out/∂b) | 与一个全1的矩阵相关(广播操作) | db = np.sum(dout, axis=0) |