
What's wrong in this derivation of back-propagation errors?

Asked on Cross Validated, December 27, 2021

I’m trying to find a rigorous derivation of the backpropagation algorithm, and I’ve gotten myself into something of a confusion. The confusion comes from when and why people transpose the weight matrices, and how we know when to use the Hadamard product and when to use the dot product. When these things are worked through element by element, as in a wonderful answer here, the arguments make sense. That said, there’s always something a little artificial about the derivations, and people often write "we do this to make the dimensions agree", which is of course not at all rigorous, and not really actual maths.

If I were approaching the problem without having seen the solution, I would come up with the solution below. Although I know this solution is definitely incorrect, I can’t work out why.

Beginning with
$$
a^l = \sigma(z^l) \\
z^l = w^l \cdot a^{l-1} + b^l
$$

we want to find
$$\frac{\partial C}{\partial z^l}.$$
Let’s assume we have
$$\delta^{l+1}=\frac{\partial C}{\partial z^{l+1}}.$$
Now, via the chain rule, I would find that
$$
\begin{align}
\frac{\partial C}{\partial z^l}&=\frac{\partial C}{\partial a^l}\frac{\partial a^l}{\partial z^l}\\
&=\underbrace{\frac{\partial C}{\partial z^{l+1}}}_A\underbrace{\frac{\partial z^{l+1}}{\partial a^l}}_B\underbrace{\frac{\partial a^l}{\partial z^l}}_C
\end{align}
$$

Now each of these is simple. We have that
$$\begin{align}
A&=\delta^{l+1}\\
B&=\frac{\partial}{\partial a^l}\left(w^{l+1}a^l+b^{l+1}\right)\\
&=w^{l+1}\\
C&=\frac{\partial}{\partial z^l}\sigma(z^l)\\
&=\sigma'(z^l)
\end{align}$$

So, putting these back in, I ought to get
$$
\frac{\partial C}{\partial z^l} = \delta^{l+1}\cdot w^{l+1}\cdot\sigma'(z^l)
$$

which is of course completely wrong, the correct answer being $$\delta^l=\left((w^{l+1})^T\cdot\delta^{l+1}\right)\odot\sigma'(z^l).$$

I can see that my answer couldn’t be right anyway, since it ends up as the product of two vectors. But what I can’t see is where I’ve actually gone wrong, or which step is mathematically incorrect.
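For concreteness, here is a quick NumPy shape check of the problem I mean (the layer sizes and values are just made up for illustration, and a random vector stands in for $\sigma'(z^l)$):

```python
import numpy as np

rng = np.random.default_rng(0)
n_l, n_lp1 = 4, 3                       # made-up sizes for layers l and l+1

delta_lp1 = rng.normal(size=n_lp1)      # dC/dz^{l+1}, length 3
w_lp1 = rng.normal(size=(n_lp1, n_l))   # w^{l+1}, shape (3, 4)
sp_zl = rng.normal(size=n_l)            # stand-in for sigma'(z^l), length 4

# The textbook formula gives a gradient of length 4, as it should:
delta_l = (w_lp1.T @ delta_lp1) * sp_zl
print(delta_l.shape)                    # (4,)

# My naive chain-rule product collapses to a scalar, not a gradient:
naive = delta_lp1 @ w_lp1 @ sp_zl
print(np.shape(naive))                  # ()
```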

Any help much appreciated!

One Answer

There are several ways of operationalizing "derivative" in the context of backprop / AD (although in the end, it's still the same mathematical object of course).

The most common is the "component-wise approach", where you unpack all the matrix/vector operations by writing out all the indices, and then you're left with elementary differentiation in one dimension. Finally, at the end, you try to remove all the indices and rewrite everything in "matrix/vector form". This is error-prone, lacks elegance, and often results in a lot of handwaving and confusion in that final step. Also, it gives me a headache to keep track of all the indices...

Another approach is the "matrix calculus" formalism, where the derivative of a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$ at a point $x$ is defined as an $m \times n$ matrix $J_f(x)$, known as the Jacobian matrix. Then you simply multiply all the Jacobian matrices together in your chain rule, and everything works out: no confusion about Hadamard versus inner or outer products, it's all matrix multiplication. The two difficulties with this approach are that 1. it's not reflective of how things are actually implemented (materializing entire Jacobian matrices of high-dimensional functions is prohibitively expensive in real life), and 2. it gets tricky when your inputs and/or outputs are matrices or higher-dimensional arrays as opposed to just vectors or scalars. I believe this can be resolved elegantly via "tensor calculus" formalisms, but I'm not familiar enough with that area to go into more detail.
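Here's a toy sketch of this view (the layer sizes, random values, and sigmoid nonlinearity are invented purely for illustration), writing the chain rule for one layer as an explicit product of Jacobian matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
n_l, n_lp1 = 4, 3

z_l = rng.normal(size=n_l)
w_lp1 = rng.normal(size=(n_lp1, n_l))
delta_lp1 = rng.normal(size=n_lp1)             # the gradient dC/dz^{l+1}

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
sigma_prime = lambda x: sigmoid(x) * (1.0 - sigmoid(x))

# Jacobians in numerator layout:
A = delta_lp1[None, :]                         # dC/dz^{l+1}:   1 x n_{l+1} row vector
B = w_lp1                                      # dz^{l+1}/da^l: n_{l+1} x n_l
C = np.diag(sigma_prime(z_l))                  # da^l/dz^l:     n_l x n_l, diagonal

row_jacobian = A @ B @ C                       # dC/dz^l as a 1 x n_l matrix
delta_l = row_jacobian.ravel()                 # transpose back to a gradient vector
print(delta_l.shape)                           # (4,)
```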

The most workable approach in my opinion, and what's actually used in real implementations of automatic differentiation, is the "vector-Jacobian product" (VJP) approach. It's really just a reframing of matrix calculus: instead of worrying about what the Jacobian matrix is (expensive!), you simply think in terms of how it acts on a vector. To be more precise, for a given function $f(x)$ and a vector $g$, $\text{VJP}(g,x)$ computes $J_f(x)^T g$ (if you think of a vector $v$ as representing the linear map $x \mapsto v^T x$, then the VJP computes the composition $g \circ J_f$). Importantly, this function can often be implemented without explicitly calculating $J_f(x)$. Now, in place of your chain rule, you simply pass $g$ through each node's VJP until you obtain your desired gradient.
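As a minimal sketch of what this looks like in code (the toy affine layer and sizes below are invented for illustration; `jax.vjp` is JAX's primitive for exactly this), the VJP of $a^l \mapsto w^{l+1} a^l + b^{l+1}$ can be evaluated without ever building the Jacobian:

```python
import jax
import jax.numpy as jnp

n_l, n_lp1 = 4, 3
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

a_l = jax.random.normal(k1, (n_l,))
w_lp1 = jax.random.normal(k2, (n_lp1, n_l))
b_lp1 = jax.random.normal(k3, (n_lp1,))
g = jax.random.normal(k4, (n_lp1,))            # incoming cotangent, i.e. delta^{l+1}

def layer(a):
    # z^{l+1} = w^{l+1} a^l + b^{l+1}
    return w_lp1 @ a + b_lp1

# jax.vjp returns the primal output and a closure that computes J^T g
z_lp1, vjp_fn = jax.vjp(layer, a_l)
(grad_a_l,) = vjp_fn(g)                        # equals w_lp1.T @ g, Jacobian never formed

print(jnp.allclose(grad_a_l, w_lp1.T @ g))     # True
```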


With that out of the way, I'll examine your example step by step from both the matrix calculus and VJP perspectives (the answer you linked already works it out component-wise).

$A=\delta^{l+1}$

Using our Jacobian convention (a.k.a. numerator layout), $A$ is actually a $1 \times n$ matrix, so keep that in mind. If you'd like to think of $\delta$ as a column-vector gradient (rather than a row vector), then we need to write $A = (\delta^{l+1})^T$. We'll do this, since it's pretty common (and the post you linked to also used it). From the VJP perspective, $\delta^{l+1}$ is just $g$ here.

$B=\frac{\partial}{\partial a^l}\left(w^{l+1}a^l+b^{l+1}\right) = w^{l+1}$

So far so good. I'll point out that the VJP form of this term is simply $\text{VJP}_B(g) = (w^{l+1})^T g$.

$C =\frac{\partial}{\partial z^l} \sigma(z^l) = \sigma'(z^l)$

To be technically correct, $C$ is a Jacobian matrix that is zero everywhere except on the diagonal, and the diagonal entries are filled in by $\sigma'(z^l)$; we can write this as $\text{diag}(\sigma'(z^l))$. This also illustrates the point of using VJPs: instead of computing $\text{VJP}_C(g, z) = \text{diag}(\sigma'(z^l))^T g$, a large and expensive matrix multiplication, we know this is equivalent to $\text{VJP}_C(g,z) = g \odot \sigma'(z^l)$.
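A quick numeric sketch of that equivalence (made-up values, with sigmoid standing in for $\sigma$):

```python
import numpy as np

rng = np.random.default_rng(0)
z_l = rng.normal(size=5)
g = rng.normal(size=5)

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
sigma_prime = lambda x: sigmoid(x) * (1.0 - sigmoid(x))

# Explicit Jacobian of the elementwise nonlinearity: diag(sigma'(z^l))
J_C = np.diag(sigma_prime(z_l))

vjp_as_matmul = J_C.T @ g                      # large, mostly-zero matrix multiplication
vjp_as_hadamard = g * sigma_prime(z_l)         # the cheap elementwise version

print(np.allclose(vjp_as_matmul, vjp_as_hadamard))  # True
```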

Now to put it all together:

$$\frac{\partial C}{\partial z^l} = (\delta^{l+1})^T\, w^{l+1}\, \text{diag}(\sigma'(z^l))$$

Again, remember that since this is a $1 \times n$ Jacobian (a row vector), we should take the transpose to recover the gradient: $\delta^l = \text{diag}(\sigma'(z^l))\, (w^{l+1})^T \delta^{l+1}$. Multiplication by a diagonal matrix can be replaced by a Hadamard product, so we end up with the expected $\delta^l = \sigma'(z^l) \odot \left((w^{l+1})^T \delta^{l+1}\right)$.

Of course, it's not easy for an algorithm to make this observation (that it doesn't need to do this massive matrix multiplication), which is why we prefer VJPs to blindly multiplying a bunch of Jacobian matrices together. In the VJP approach, we just compute $\delta^l = \text{VJP}_C(\text{VJP}_B(\delta^{l+1}))$, and if you unpack the functions, you'll find that the answer is the same.
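To make that concrete, here is a toy end-to-end check (my own sizes and sigmoid nonlinearity, not part of the derivation above) that the chained Jacobian product and the composed VJPs produce the same $\delta^l$:

```python
import numpy as np

rng = np.random.default_rng(0)
n_l, n_lp1 = 4, 3

z_l = rng.normal(size=n_l)
w_lp1 = rng.normal(size=(n_lp1, n_l))
delta_lp1 = rng.normal(size=n_lp1)

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
sigma_prime = lambda x: sigmoid(x) * (1.0 - sigmoid(x))

# Matrix-calculus route: multiply the Jacobians, then transpose the resulting row vector
row = delta_lp1[None, :] @ w_lp1 @ np.diag(sigma_prime(z_l))
delta_l_jacobian = row.ravel()

# VJP route: push delta^{l+1} through VJP_B then VJP_C, no Jacobian ever materialized
vjp_B = lambda g: w_lp1.T @ g                  # VJP of z^{l+1} = w^{l+1} a^l + b^{l+1}
vjp_C = lambda g: g * sigma_prime(z_l)         # VJP of a^l = sigma(z^l)
delta_l_vjp = vjp_C(vjp_B(delta_lp1))

print(np.allclose(delta_l_jacobian, delta_l_vjp))  # True
```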

Answered by shimao on December 27, 2021
