Introduction

At the heart of our LLM’s training is backpropagation, or “backprop”. This is often described either in simple terms, as in Wikipedia’s introduction to the topic...

It is an efficient application of the chain rule to neural networks.

...or in complex math terms, as in the rest of that Wikipedia article. I’ll try to hit an in-between. In particular, this chapter will start with a quick refresher on some level-1 calculus (including what the chain rule is), and then work through how it applies to backprop.

At its core, backprop tries to answer a conceptually simple question:

The model has parameters a, b, c, ... n. After it makes a prediction, backprop wiggles a to get the answer closer, then b, then c, and so on.

To do this, we define a loss function, which takes the predicted and actual value, and compares them. We then apply a bunch of math, which does this wiggling. The loss function always produces a scalar (typically non-negative, so that zero represents a perfect prediction), and backprop will wiggle the model’s parameters to get the loss closer to zero.

To get an understanding of how backprop works, I’ll start exceedingly simple and build up from there:

  1. Backprop on a single-layer, scalar model

  2. Backprop on a multi-layer, but still scalar model

  3. Backprop on a single-layer, matrix-based model

  4. Backprop on a multi-layer, matrix-based model

By the last of those, we’ll have a “full” understanding of backprop. After that, the only difference between what we’ve built and a real model is that the real model is bigger.

The math you’ll need

This chapter assumes you’re decently familiar with derivatives; if you’re not, it may be tough. If you’re familiar with them but just need a quick refresher, the following sections should help. If you’re already comfortable with these, feel free to jump ahead to the meat of it.

Derivatives

If we have some function $y = f(x)$, then its derivative $y' = f'(x)$ is how fast $y$ changes at any given point $x$.

We can also express the derivative using what’s called Leibniz notation: $\frac{dy}{dx}$. This notation makes it explicit that we’re differentiating with respect to $x$.

To differentiate a polynomial, bring each exponent down as a factor and lower it by one:

$$
\begin{array}{rccccl}
y & = & ax^n & + & bx^m & + \, \dots \\[0.3em]
& & \downarrow & & \downarrow & \\[0.3em]
y' & = & n \; ax^{n-1} & + & m \; bx^{m-1} & + \, \dots
\end{array}
$$
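As a quick sanity check, here is a minimal sketch that compares the power rule against a numerical estimate. The coefficients, exponents, and function names are made up for illustration.

```python
def poly(x, a=3.0, n=4, b=2.0, m=2):
    """y = a*x^n + b*x^m (coefficients chosen arbitrarily)."""
    return a * x**n + b * x**m

def poly_derivative(x, a=3.0, n=4, b=2.0, m=2):
    """y' = n*a*x^(n-1) + m*b*x^(m-1), by the power rule."""
    return n * a * x**(n - 1) + m * b * x**(m - 1)

def numerical_derivative(f, x, h=1e-6):
    """Central-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.5
analytic = poly_derivative(x)
numeric = numerical_derivative(poly, x)
print(analytic, numeric)  # the two values should agree closely
```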

Chain rule

The chain rule lets you deconstruct a function that’s the composition of two functions — in other words, a function that takes the output of one function and passes it to another:

$$h(x) = z( \; y(x) \; )$$

To compute $h$’s derivative, note that $h'$ is the derivative of $z$ with respect to $x$, so the Leibniz notation for it is:

$$h' \longleftrightarrow \frac{dz}{dx}$$

Using that, we can write the chain rule in a fraction-like way:

$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}$$

“The derivative of $z$ wrt $x$ is the derivative of $z$ wrt $y$ times the derivative of $y$ wrt $x$.”

Note that this isn’t actually a fraction; the Leibniz notation just illustrates (by way of analogy) how the $dy$ elements “cancel out”.

We can intuit why the chain rule works by going back to $h$’s definition: $z$ evaluated at $y(x)$. So, to see how fast $h$ changes as $x$ changes, we take how fast $y(x)$ changes as $x$ changes, and multiply it by how fast $z$ changes as $y(x)$ changes.
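To make that concrete, here is a small sketch that applies the chain rule to one composition and compares the result with a numerical estimate. The particular inner and outer functions are arbitrary choices for illustration.

```python
import math

def y(x):
    """Inner function (arbitrary example): y = 3x + 1."""
    return 3 * x + 1

def z(y_val):
    """Outer function (arbitrary example): z = sin(y)."""
    return math.sin(y_val)

def h(x):
    """The composition h(x) = z(y(x))."""
    return z(y(x))

def dz_dy(y_val):
    return math.cos(y_val)

def dy_dx(x):
    return 3.0

def h_prime(x):
    """Chain rule: dz/dx = dz/dy (evaluated at y(x)) times dy/dx."""
    return dz_dy(y(x)) * dy_dx(x)

def numerical_derivative(f, x, h_step=1e-6):
    return (f(x + h_step) - f(x - h_step)) / (2 * h_step)

x = 0.4
chain_rule_value = h_prime(x)
numeric_value = numerical_derivative(h, x)
print(chain_rule_value, numeric_value)  # should agree closely
```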

Partial derivatives

In the above sections, $y$ was defined in terms of a single variable, $x$. But what if there are two variables, or more?

$$y = f(x, u, v, \dots)$$

To handle this, we use partial derivatives. The concept is simple: treat all but one of the variables as a constant, and then take the (ordinary) derivative with respect to that one remaining variable. The Leibniz notation for this is $\frac{\partial y}{\partial x}$ if $x$ is the “with-respect-to” variable. We can define as many partial derivatives as there are variables:

$$\frac{\partial y}{\partial x} \quad,\quad \frac{\partial y}{\partial u} \quad,\quad \frac{\partial y}{\partial v} \quad,\quad \dots$$
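Here is a short sketch of the “hold everything else constant” idea, using a made-up two-variable function and a numerical check of each partial derivative.

```python
def f(x, u):
    """Arbitrary example: y = x^2 * u + 3u."""
    return x**2 * u + 3 * u

def df_dx(x, u):
    # Treat u as a constant: d/dx (x^2 * u + 3u) = 2 * x * u
    return 2 * x * u

def df_du(x, u):
    # Treat x as a constant: d/du (x^2 * u + 3u) = x^2 + 3
    return x**2 + 3

def numerical_partial(func, args, index, h=1e-6):
    """Wiggle only the argument at `index`, leaving the others fixed."""
    up, down = list(args), list(args)
    up[index] += h
    down[index] -= h
    return (func(*up) - func(*down)) / (2 * h)

point = (2.0, 5.0)
partial_x = df_dx(*point)
partial_u = df_du(*point)
print(partial_x, numerical_partial(f, point, 0))
print(partial_u, numerical_partial(f, point, 1))
```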

Backprop on a simple, scalar model

Now that we have our math refreshed, let’s get to the fun stuff! To start our intuition for how backprop works, let’s start with the simplest possible model: a scalar, linear function:

$$y = ax + b$$

This is just a plain old line, like you learned about in middle school. We’re going to use machine learning to figure out its slope and $y$-intercept. Our training data will be a bunch of $(x, y)$ pairs:

[Figure: plot of data showing points more or less along a line]

Since we’re assuming (as the model designers) that the points form a line $y = ax + b$, our job will be to figure out $a$ and $b$ from the various $(x, y)$ data points. In other words, $a$ and $b$ are the model’s learned parameters.

Our first step is to define a loss function $L$, which measures how far a given prediction is from the true value. A common one is mean squared error (MSE), which we’ll adapt for our scalar model:

$$L(x) = (y(x) - y_{true})^2$$

Our simple model will:

  1. Take an $(x, y_{true})$ pair from the training data

  2. Run $x$ through the model (with whatever $a$ and $b$ we currently have) to produce a prediction, $y_{pred}$

  3. Calculate the loss $L = (y_{pred} - y_{true})^2$

  4. Use the chain rule to compute the two partial derivatives, $\frac{\partial L}{\partial a}$ and $\frac{\partial L}{\partial b}$. These give us the gradients for $a$ and $b$ (I’ll explain this in just a second)

  5. Use the gradients to nudge $a$ and $b$ towards their true values

The gradient for each learned parameter ($a$ and $b$) is the partial derivative of the loss function with respect to that parameter. In other words, it answers just the mechanical, mathematical question of “as that parameter grows, how fast does the loss grow?” Of course, we want the loss to shrink, since it represents how wrong the prediction was. So, we just nudge the parameter in the opposite direction of the gradient.

[Figure: visual representation of the steps described above]

The first three steps in the list above are trivial (remember that in this example, “run $x$ through the model” is just $y_{pred} = ax + b$). Let’s focus on the fourth step, the chain rule.

We’ll focus on $a$ first. What we want is the partial derivative of the loss $L$ with respect to $a$:

$$\frac{\partial L}{\partial a}$$

We can think of $L$ as a composed function $L(x) = ( \, y(x) \, - y_{true} )^2$. That means we can use the chain rule:

$$\frac{\partial L}{\partial a} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial a}$$

Let’s start by calculating the right term, $\frac{\partial y}{\partial a}$:

$$
\begin{array}{c}
y(x) = ax + b \\[0.3em]
\downarrow \\[0.3em]
\frac{\partial y}{\partial a} = x
\end{array}
$$

Now the left term, $\frac{\partial L}{\partial y}$:

$$
\begin{array}{c}
L(x) = (y(x) - y_{true})^2 \\[0.3em]
\downarrow \\[0.3em]
\frac{\partial L}{\partial y} = 2(y(x) - y_{true})
\end{array}
$$

Putting it all together:

$$
\begin{aligned}
\frac{\partial L}{\partial a} & = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial a} \\[1em]
& = 2(y(x) - y_{true}) \cdot x
\end{aligned}
$$

And here’s where the “efficient application of” starts to kick in: during our inference phase (the forward pass), we already calculated $y(x) = y_{pred}$. If we just store that value during the forward pass, $\frac{\partial L}{\partial a}$ becomes a trivial calculation: $y(x)$ comes from that stored lookup, and $x$ and $y_{true}$ were our given arguments. We call this value $a$’s gradient.

We can do the same thing to calculate $\frac{\partial L}{\partial b}$. I’ll go a bit faster, since it’s basically the same work.

$$
\begin{aligned}
\frac{\partial L}{\partial b} & = \underbrace{\frac{\partial L}{\partial y}}_{\text{same as above}} \cdot \underbrace{\frac{\partial y}{\partial b}}_{\frac{\partial}{\partial b}(ax + b)} \\[1.5em]
& = 2(y(x) - y_{true}) \cdot 1
\end{aligned}
$$

Notice that the left term is exactly the same as it was for $a$’s gradient.
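The two gradients can be sketched in a few lines. This is only an illustration under made-up values for $a$, $b$, $x$, and $y_{true}$; the function names are hypothetical, and the finite-difference check is there just to confirm the chain-rule result.

```python
def forward(a, b, x):
    """The model: y_pred = a*x + b."""
    return a * x + b

def loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

def gradients(x, y_pred, y_true):
    """Reuse the stored y_pred; the left-hand term is shared by both gradients."""
    dL_dy = 2 * (y_pred - y_true)
    grad_a = dL_dy * x      # dy/da = x
    grad_b = dL_dy * 1.0    # dy/db = 1
    return grad_a, grad_b

# Made-up parameter values and a single made-up training pair.
a, b = 0.5, -1.0
x, y_true = 2.0, 4.0

y_pred = forward(a, b, x)                      # forward pass, value stored
grad_a, grad_b = gradients(x, y_pred, y_true)  # backward pass

# Numerical check: wiggle a (then b) and watch the loss.
h = 1e-6
numeric_a = (loss(forward(a + h, b, x), y_true) - loss(forward(a - h, b, x), y_true)) / (2 * h)
numeric_b = (loss(forward(a, b + h, x), y_true) - loss(forward(a, b - h, x), y_true)) / (2 * h)
print(grad_a, numeric_a)
print(grad_b, numeric_b)
```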

With that, we’ve calculated our two gradients, for $a$ and $b$. Now we just apply each one to its respective parameter ($a$ and $b$) to update them. As I mentioned before, we subtract the gradients, because we want to reduce the loss. Before we do that, we scale the gradients down by $\eta$, which is a learning rate. This is some small number, like 0.01, and it means that each round of learning only nudges the values towards a 0-loss, instead of lurching them there. This prevents over-fitting any one data point, which can cause the model to overshoot and oscillate around the desired value, or worse, shoot off to infinity.

$$
\begin{array}{l}
\eta = 0.01 \\[1.5em]
a_{updated} = a - (\eta \; a_{gradient}) \\
b_{updated} = b - (\eta \; b_{gradient})
\end{array}
$$

That’s all there is to it! If we churn this training through a large enough data set, $a$ and $b$ will eventually converge to the right values.
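Putting the five steps together, a training loop might look like the sketch below. The synthetic data, the “true” slope and intercept, the noise level, and the epoch count are all assumptions chosen for illustration.

```python
import random

def make_data(true_a, true_b, count=50, noise=0.1, seed=0):
    """Synthetic (x, y_true) pairs scattered around the line y = true_a*x + true_b."""
    rng = random.Random(seed)
    data = []
    for _ in range(count):
        x = rng.uniform(-2.0, 2.0)
        data.append((x, true_a * x + true_b + rng.gauss(0.0, noise)))
    return data

def train(data, eta=0.01, epochs=200):
    a, b = 0.0, 0.0                          # arbitrary starting parameters
    for _ in range(epochs):
        for x, y_true in data:               # step 1: take a training pair
            y_pred = a * x + b               # step 2: forward pass
            dL_dy = 2 * (y_pred - y_true)    # step 3: loss is (y_pred - y_true)^2; this is its derivative
            grad_a = dL_dy * x               # step 4: gradients via the chain rule
            grad_b = dL_dy * 1
            a -= eta * grad_a                # step 5: nudge against the gradient
            b -= eta * grad_b
    return a, b

data = make_data(true_a=2.0, true_b=-1.0)
a, b = train(data)
print(f"a = {a:.3f}, b = {b:.3f}")  # should land near 2 and -1
```

Because the data is noisy, the learned values settle near the true ones rather than exactly on them.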

Terminology: residual and local derivative

Before we go further, let’s introduce two useful names for the concepts we’ve already learned. Remember that the gradients for $a$ and $b$ each used the chain rule, and in both cases their left-hand term was the same:

$$
\begin{aligned}
\frac{\partial L}{\partial a} & = 2(y(x) - y_{true}) \cdot x \\[1em]
\frac{\partial L}{\partial b} & = 2(y(x) - y_{true}) \cdot 1
\end{aligned}
$$

Let’s ask where these various terms come from, and do so within the framing of the layer that contains $a$ and $b$ (that is, $y = ax + b$). The shared left-hand term, $2(y(x) - y_{true})$, is $\frac{\partial L}{\partial y}$: it comes from the loss, which is the layer below this one. The right-hand terms, $x$ and $1$, come from this layer’s own definition.

The distinction between information coming from the layer below, and information computed at this layer, is reflected in terminology: the former is the residual, and the latter is the local derivative.

We can think of this for any parameter $p$ as:

$$
\begin{array}{lll}
\frac{\partial L}{\partial p} & = \text{(signal from lower level)} & \cdot \; \text{(partial derivative of $p$)} \\[0.5em]
& = \text{(residual)} & \cdot \; \text{(local derivative)} \\[1em]
& = \boxed{r \cdot \frac{\partial y}{\partial p}} &
\end{array}
$$

...where $r$ is the residual, $\frac{\partial L}{\partial y}$, and $\frac{\partial y}{\partial p}$ is the local derivative of this layer’s output with respect to $p$.

Note that $r$ isn’t an equation, but an actual, concrete value. Each layer gets this value, and then uses it as-is for all of that layer’s parameters. This is a lot of what’s behind the “efficient” in “efficient application of the chain rule”.

The lowest layer, $L$, is a special case: it doesn’t have a lower layer to provide a residual, so we need to calculate it by figuring out its derivative and plugging in $y_{pred}$ and $y_{true}$.

Backprop on a multi-layer, scalar model

Now that we have backprop working on a single-layer model, let’s add a second layer. For now, we won’t have an activation function between the two:

$$
\begin{array}{l}
y_1 = a_1 x + b_1 \\
y_2 = a_2 (y_1) + b_2
\end{array}
$$

We’ll use the same loss function as before:

$$L(x) = (y_2(x) - y_{true})^2$$

Let’s start by keeping in mind our objective: a gradient for each of the four learned parameters, $a_1$, $b_1$, $a_2$, and $b_2$.

We’ll start at the bottom of the model, the layer closest to $L$: $y_2$. This means we’ll be calculating the gradients for $a_2$ and $b_2$, which are $\frac{\partial L}{\partial a_2}$ and $\frac{\partial L}{\partial b_2}$. Let’s start with $a_2$. As before, we’ll use the chain rule:

$$
\begin{array}{c}
L(x) = (y_2(x) - y_{true})^2 \\[0.3em]
\downarrow \\[0.3em]
\frac{\partial L}{\partial a_2} = \frac{\partial L}{\partial y_2} \cdot \frac{\partial y_2}{\partial a_2}
\end{array}
$$

This turns out to be almost exactly the same as the single-layer example above: add a subscript $2$ to $y$, $a$, and $b$, and remember that this layer’s input is $y_1$ rather than $x$:

$$
\begin{array}{l}
\frac{\partial L}{\partial a_2} = 2(y_2(x) - y_{true}) \cdot y_1 \\[0.3em]
\frac{\partial L}{\partial b_2} = 2(y_2(x) - y_{true}) \cdot 1
\end{array}
$$

So far, this is all just a review of the previous two sections. Now comes the new wrinkle: calculating the gradients for the $y_1$ layer.

There are two ways to approach this: by working everything out piece by piece, or by relying on the residual-based pattern we established in the previous section. I’m not sure which is more helpful, so I’ll provide both. If one doesn’t make sense, try the other!

Working it out piece by piece

As before, we’ll start by writing out the loss function:

$$L(x) = (y_2(x) - y_{true})^2$$

Note that even though we’re interested in the parameters at layer $y_1$, the loss function is still defined in terms of $y_2$. The loss function can only be defined against $y_2$, because its semantic is “how wrong was the model’s ultimate prediction”; we don’t have any way of estimating how far off an intermediate value was, because our training data only has the inputs and final expected outputs.

Let’s start with $\frac{\partial L}{\partial a_1}$. Again, we’ll use the chain rule, but what do we want to use as the chain?

$$\frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial \, ?} \cdot \frac{\partial \, ?}{\partial a_1}$$

If we think about what $a_1$ most directly impacts (that is, what changes most directly as $a_1$ changes), it’s just $y_1$, the function that directly uses $a_1$. So, let’s use that:

$$\frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial y_1} \cdot \frac{\partial y_1}{\partial a_1}$$

With that in mind, let’s take a crack at the left term: $\frac{\partial L}{\partial y_1}$. We can’t just use a plain polynomial derivative formula as we’ve been doing so far, because $y_1$ isn’t “directly” in the definition for $L$. Instead, let’s try the chain rule again:

$$\frac{\partial L}{\partial y_1} = \frac{\partial L}{\partial \, ?} \cdot \frac{\partial \, ?}{\partial y_1}$$

Again we ask, what does $y_1$ most directly affect? Well, it’s used in the next layer:

$$y_2 = a_2 (\underline{y_1}) + b_2$$

...so $y_1$ most directly affects $y_2$. Let’s fill that in:

$$\frac{\partial L}{\partial y_1} = \frac{\partial L}{\partial y_2} \cdot \frac{\partial y_2}{\partial y_1}$$

Here’s where the “efficient application” kicks in again. We already computed $\frac{\partial L}{\partial y_2}$ in the previous step (it was the left-hand term of the chain rule), so we can just plug that in. For the right-hand term, $\frac{\partial y_2}{\partial y_1}$, we can fall back to standard derivatives:

$$\frac{\partial y_2}{\partial y_1} \, = \frac{\partial}{\partial y_1} (a_2 (y_1) + b_2) \, = a_2$$

With that, we’ve calculated the left term of our $a_1$ gradient:

$$
\begin{array}{lcccc}
\frac{\partial L}{\partial a_1} & = & \frac{\partial L}{\partial y_1} & \cdot & \frac{\partial y_1}{\partial a_1} \\[2em]
& = & \frac{\partial L}{\partial y_2} \cdot \frac{\partial y_2}{\partial y_1} & \cdot & \frac{\partial y_1}{\partial a_1} \\[2em]
& = & \underbrace{\left( \frac{\partial L}{\partial y_2} \cdot a_2 \right)}_{\text{from previous layer}} & \cdot & \frac{\partial y_1}{\partial a_1}
\end{array}
$$

And we can then fill in the right term straightforwardly:

$$
\begin{aligned}
\frac{\partial L}{\partial a_1} & = \left( \frac{\partial L}{\partial y_2} \cdot a_2 \right) \cdot \underbrace{\frac{\partial y_1}{\partial a_1}}_{\frac{\partial}{\partial a_1}(a_1 x + b_1)} \\[2em]
& = \left( \frac{\partial L}{\partial y_2} \cdot a_2 \right) \cdot x
\end{aligned}
$$

The gradient for $b_1$ would work the same way. In the end, you’d get:

$$
\begin{aligned}
\frac{\partial L}{\partial b_1} & = \left( \frac{\partial L}{\partial y_2} \cdot a_2 \right) \cdot \underbrace{\frac{\partial y_1}{\partial b_1}}_{\frac{\partial}{\partial b_1}(a_1 x + b_1)} \\[2em]
& = \left( \frac{\partial L}{\partial y_2} \cdot a_2 \right) \cdot 1
\end{aligned}
$$
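The piece-by-piece derivation can be checked with a short sketch. The parameter values and the training pair are made up, and the function names are hypothetical. Note that the second layer’s input is $y_1$, so that is what multiplies the shared term in $a_2$’s gradient.

```python
def forward(params, x):
    a1, b1, a2, b2 = params
    y1 = a1 * x + b1
    y2 = a2 * y1 + b2
    return y1, y2

def loss(params, x, y_true):
    _, y2 = forward(params, x)
    return (y2 - y_true) ** 2

def backward(params, x, y1, y2, y_true):
    _a1, _b1, a2, _b2 = params
    dL_dy2 = 2 * (y2 - y_true)
    grad_a2 = dL_dy2 * y1        # dy2/da2 = y1 (layer 2's input)
    grad_b2 = dL_dy2 * 1.0       # dy2/db2 = 1
    dL_dy1 = dL_dy2 * a2         # chain through dy2/dy1 = a2
    grad_a1 = dL_dy1 * x         # dy1/da1 = x
    grad_b1 = dL_dy1 * 1.0       # dy1/db1 = 1
    return [grad_a1, grad_b1, grad_a2, grad_b2]

def numerical_gradients(params, x, y_true, h=1e-6):
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += h
        down[i] -= h
        grads.append((loss(up, x, y_true) - loss(down, x, y_true)) / (2 * h))
    return grads

params = [0.5, 1.0, -2.0, 0.5]   # a1, b1, a2, b2 (made-up values)
x, y_true = 3.0, 1.0

y1, y2 = forward(params, x)
analytic = backward(params, x, y1, y2, y_true)
numeric = numerical_gradients(params, x, y_true)
print(analytic)
print(numeric)
```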
Using residuals

If we trust our understanding of the residuals pattern in the previous section:

$$
\begin{array}{lll}
\frac{\partial L}{\partial p} & = \text{(signal from lower level)} & \cdot \; \text{(partial derivative of $p$)} \\[0.5em]
& = \text{(residual)} & \cdot \; \text{(local derivative)} \\[1em]
& = \boxed{r \cdot \frac{\partial y}{\partial p}} &
\end{array}
$$

...then we can get a shortcut for all of the above. Let’s give a name for the residual coming into $y_1$: $r_1$.

$$r_1 = \frac{\partial L}{\partial y_1}$$

The $y_2$ layer will calculate this for us, as an extra step after it calculates its gradients. That’s because calculating $\frac{\partial L}{\partial y_1}$ requires knowing how $y_2$ is defined in terms of $y_1$. The calculation is cheap (here $r_2 = \frac{\partial L}{\partial y_2}$ is the residual coming into $y_2$):

$$
\begin{array}{lll}
r_1 &=& \frac{\partial L}{\partial y_1} \\[0.5em]
& & \downarrow \textit{chain rule} \\[0.5em]
&=& \frac{\partial L}{\partial y_2} \cdot \frac{\partial y_2}{\partial y_1} \\[0.5em]
&=& r_2 \cdot \text{(local derivative of $y_2$ wrt $y_1$)}
\end{array}
$$

The $y_2$ layer knows its definition in terms of $y_1$, so it’s able to determine $\frac{\partial y_2}{\partial y_1}$:

$$
\begin{array}{c}
y_2 = a_2( \, y_1 \, ) + b_2 \\[0.3em]
\downarrow \\[0.3em]
\textit{rewrite to make $y_1$ more explicitly the variable} \\[0.3em]
\downarrow \\[0.3em]
y_2 = y_1 \, a_2 + b_2 \\[0.3em]
\downarrow \\[0.3em]
\frac{\partial y_2}{\partial y_1} = a_2
\end{array}
$$

Following this through, we get a pattern for every layer $y_n$:

  1. Take the residual from the previous layer (the one handled just before this one in the backward pass, $y_{n+1}$). Use it as the left-hand term of chain rule applications, with the right-hand term being the local derivative of each parameter $p$ in this layer.

  2. Calculate the local derivative with respect to the input ($y_{n-1}$), multiply it by the residual, and pass the result to the $y_{n-1}$ layer as its residual.

Note the general pattern: we start at the bottom of the model, and then work our way up, with each layer providing the residual for the one before it. This is the “back” in “backpropagation.”
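Here is one way the two-step pattern might look as a loop over any number of scalar layers. It is a sketch only: the `(a, b)` tuple representation of a layer and the function names are assumptions made for illustration.

```python
def forward(layers, x):
    """Run x through each (a, b) layer, remembering every layer's input."""
    inputs = []
    for a, b in layers:
        inputs.append(x)
        x = a * x + b
    return x, inputs

def backward(layers, inputs, y_pred, y_true):
    """Walk the layers bottom-up, handing a residual to each one."""
    r = 2 * (y_pred - y_true)           # the loss computes the first residual itself
    grads = [None] * len(layers)
    for n in reversed(range(len(layers))):
        a, _b = layers[n]
        # Step 1: residual times each parameter's local derivative.
        grads[n] = (r * inputs[n], r * 1.0)
        # Step 2: residual times the local derivative wrt the input,
        # passed along to the layer before this one.
        r = r * a
    return grads

layers = [(0.5, 1.0), (-2.0, 0.5)]      # made-up (a1, b1), (a2, b2)
x, y_true = 3.0, 1.0

y_pred, inputs = forward(layers, x)
grads = backward(layers, inputs, y_pred, y_true)
print(grads)  # one (grad_a, grad_b) pair per layer, in layer order
```

Each layer only ever touches its own stored input, its own parameters, and the single number `r`; it never needs to know how many layers sit below it.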

Adding an activation function

Now that we have our nice two-layer model, we’ll want to add an activation function:

$$
\begin{array}{l}
y_1 = a_1 x + b_1 \\
y_2 = \text{GELU}(y_1) \\
y_3 = a_3 (y_2) + b_3
\end{array}
$$

But instead of getting straight into that, I want to revisit the one-layer model. This will seem like (possibly pedantic) side discussion, but I promise, it’ll get at the activation-enabled model.

Let’s rewrite our one-layer model, with the loss function — but this time expand it into individual operations:

$$
\begin{array}{l}
y = ax + b \\
L = (y - y_{true})^2 \\[0.3em]
\downarrow \\[0.3em]
y_1 = ax \\
y_2 = y_1 + b \\
y_3 = y_2 - y_{true} \\
y_4 = {y_3}^2
\end{array}
$$

Just like I described in the previous section, we’ll work from the bottom of the model up: each layer will calculate the gradients for its parameter (if it has one), and then calculate its local derivative with respect to its input and pass that residual up to that input.

$$
\begin{array}{l|l|l|l|l}
\textbf{Layer} & \textbf{Local deriv.} & \textbf{Incoming} & \textbf{Outgoing} & \textbf{Gradient} \\
& \textbf{wrt input} & \textbf{residual} & \textbf{residual} & \\ \hline
y_4 = {y_3}^2 & 2 \, y_3 & \text{---} & r_4 = \underline{2 \, y_3} & \text{---} \\
y_3 = y_2 - y_{true} & 1 & r_4 & r_3 = r_4 \cdot \underline{1} & \text{---} \\
y_2 = y_1 + b & 1 & r_3 & r_2 = r_3 \cdot \underline{1} & \partial L / \partial b = r_3 \cdot 1 \\
y_1 = a \cdot x & a & r_2 & r_1 = r_2 \cdot \underline{a} & \partial L / \partial a = r_2 \cdot x \\
\end{array}
$$

Hopefully this table drives home everything from the above: each row takes its incoming residual, multiplies it by a local derivative, and hands the result to the next row.

What would happen if we slid a function $f$ between $y_1$ and $y_2$? This would correspond to a model:

$$
\begin{array}{cl}
& y = f(ax) + b \\[0.3em]
& \downarrow \\[0.3em]
& y_1 = ax \\
\bigstar & y_f = f(y_1) \\
& y_2 = y_f + b \\
& y_3 = y_2 - y_{true} \\
& y_4 = {y_3}^2
\end{array}
$$

I’ve named the layer $y_f$ so that most of the other layers don’t have to change; you can see that other than $y_1$ referring to the new residual $r_f$, everything else is exactly the same.

Note that in a real LLM, the activation function doesn’t go between the weight and the bias, as I’ve done here. But as we’re about to see, the whole point is that backprop doesn’t actually care where it goes. In fact, as far as backprop is concerned, there’s no such thing as a layer at all: everything is just a sequence of operations. Here’s the same table as above, this time with one more layer for $y_f$:

$$
\begin{array}{cl|l|l|l|l}
& \textbf{Layer} & \textbf{Local deriv.} & \textbf{Incoming} & \textbf{Outgoing} & \textbf{Gradient} \\
& & \textbf{wrt input} & \textbf{residual} & \textbf{residual} & \\ \hline
& y_4 = {y_3}^2 & 2 \, y_3 & \text{---} & r_4 = \underline{2 \, y_3} & \text{---} \\
& y_3 = y_2 - y_{true} & 1 & r_4 & r_3 = r_4 \cdot \underline{1} & \text{---} \\
& y_2 = y_f + b & 1 & r_3 & r_2 = r_3 \cdot \underline{1} & \partial L / \partial b = r_3 \cdot 1 \\
\bigstar & y_f = f(y_1) & f'(y_1) & r_2 & r_f = r_2 \cdot \underline{f'(y_1)} & \text{---} \\
& y_1 = a \cdot x & a & r_f & r_1 = r_f \cdot \underline{a} & \partial L / \partial a = r_f \cdot x \\
\end{array}
$$

And that’s it! If ff is our activation function — GELU, ReLU, or anything else — then that’s all there is to it. The machinery needs to know what the activation function’s local derivative with respect to its input is, but the implementation can just hard-code that. This also applies to any other operation in the real LLM, like softmax.
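The table translates almost line for line into code. In this sketch I’ve used $\tanh$ as a stand-in for $f$, purely because its derivative is short; the parameter values are made up, and the finite-difference check is only there to confirm the residual bookkeeping.

```python
import math

def f(v):
    return math.tanh(v)              # stand-in activation, not GELU

def f_prime(v):
    return 1 - math.tanh(v) ** 2

def forward(a, b, x, y_true):
    y1 = a * x
    yf = f(y1)
    y2 = yf + b
    y3 = y2 - y_true
    y4 = y3 ** 2
    return y1, yf, y2, y3, y4

def backward(a, x, y1, y3):
    r4 = 2 * y3                      # y4 = y3^2
    r3 = r4 * 1                      # y3 = y2 - y_true
    grad_b = r3 * 1                  # y2 = yf + b, gradient for b
    r2 = r3 * 1
    rf = r2 * f_prime(y1)            # yf = f(y1): no parameters, just pass it on
    grad_a = rf * x                  # y1 = a * x, gradient for a
    r1 = rf * a                      # residual for x; unused here, since nothing precedes y1
    return grad_a, grad_b

a, b, x, y_true = 0.7, 0.1, 1.5, 0.3   # made-up values

y1, yf, y2, y3, y4 = forward(a, b, x, y_true)
grad_a, grad_b = backward(a, x, y1, y3)

h = 1e-6
numeric_a = (forward(a + h, b, x, y_true)[-1] - forward(a - h, b, x, y_true)[-1]) / (2 * h)
numeric_b = (forward(a, b + h, x, y_true)[-1] - forward(a, b - h, x, y_true)[-1]) / (2 * h)
print(grad_a, numeric_a)
print(grad_b, numeric_b)
```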

The actual equations for these derivatives can be a handful. For example, the derivative of GELU (in its common tanh approximation) is $0.5(1 + \tanh(u)) + 0.5x \cdot \text{sech}^2(u) \cdot \sqrt{\frac{2}{\pi}}(1 + 0.134145x^2)$, where $u = \sqrt{\frac{2}{\pi}}(x + 0.044715x^3)$. But the implementation just hard-codes them, as it does for the functions’ forward-pass definition.
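As an illustration of “hard-coding” a derivative, here is a sketch of the tanh approximation of GELU alongside the derivative quoted above, with a numerical check. The function names are my own, and a real library may use a different GELU variant.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2 / math.pi)

def gelu(x):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(u))."""
    u = SQRT_2_OVER_PI * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + math.tanh(u))

def gelu_derivative(x):
    """Hard-coded local derivative of the function above."""
    u = SQRT_2_OVER_PI * (x + 0.044715 * x**3)
    sech2 = 1 - math.tanh(u) ** 2    # sech^2(u) = 1 - tanh^2(u)
    return 0.5 * (1 + math.tanh(u)) + 0.5 * x * sech2 * SQRT_2_OVER_PI * (1 + 0.134145 * x**2)

def numerical_derivative(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)

sample_points = [-2.0, -0.5, 0.0, 0.5, 2.0]
errors = [abs(gelu_derivative(x) - numerical_derivative(gelu, x)) for x in sample_points]
print(max(errors))  # should be tiny
```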

Using vectors instead of scalars

Until now, our simple backprop model has used scalars. In an LLM, of course, everything is a matrix.

The concepts are exactly the same as above: the only difference is that when we take the [partial] derivative of functions, those functions will have matrices as their inputs and parameters, instead of scalars. Likewise, the result will be a matrix:

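$$
\begin{array}{ccccc}
y_1 & = & A & \times & X \\
(n \times v) & & (n \times m) & & (m \times v)
\end{array}
$$

As with the scalar model, each layer (that is, operation) in the matrix-based model will have to do two things: calculate the gradients for its own parameters (if it has any), and calculate the residual to pass along to the layer that produced its input.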

At the core of it, that’s all we need to know for now. Just to round it out, let’s look back at the $a$-layer of our scalar model:

$$
\begin{aligned}
y(x) & = ax \\[0.3em]
& \downarrow \\[0.3em]
\text{gradient}_a & = r_{\text{in}} \cdot x \\[0.3em]
r_{\text{out}} & = r_{\text{in}} \cdot \frac{\partial}{\partial x}(ax)
\end{aligned}
$$

It’s going to be almost the same for the matrix version, with one gotcha on the gradient:

$$
\begin{aligned}
y(X) & = AX \\[0.3em]
& \downarrow \\[0.3em]
\text{gradient}_A & = r_{\text{in}} \cdot X^T \\[0.3em]
r_{\text{out}} & = \underbrace{r_{\text{in}} \text{ combined with } \frac{\partial}{\partial X}(AX)}_{\text{more on this below}}
\end{aligned}
$$

Let’s look at the gradient. Where did that transpose come from? Without getting too into the weeds of matrix derivatives, let’s at least look at the shapes of the matrices.

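The gradient for $A$ needs to have the same shape as $A$ itself, $(n \times m)$, so that we can subtract it from $A$. The incoming residual $r_{\text{in}}$ has the same shape as this layer’s output $y_1$, which is $(n \times v)$. So we have:

$$
\begin{array}{ccccc}
\text{gradient}_A & = & r_{\text{in}} & \cdot & ? \\
(n \times m) & & (n \times v) & & (? \times ?)
\end{array}
$$

In order for the matrix multiplication to work, the unknown bit must have shape $(v \times m)$. And wouldn’t you know it, that’s exactly the shape of $X^T$.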
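The shape argument can be checked directly. This sketch uses plain Python lists as matrices and, as a stand-in loss, the sum of the squared entries of $AX$ (an assumption made only so there’s a scalar to differentiate). All values and helper names are made up.

```python
def matmul(P, Q):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(M):
    return [list(col) for col in zip(*M)]

def shape(M):
    return (len(M), len(M[0]))

def loss(A, X):
    # Assumed stand-in loss: sum of squared entries of Y = A X.
    return sum(y * y for row in matmul(A, X) for y in row)

n, m, v = 2, 3, 4
A = [[0.1 * (i + j + 1) for j in range(m)] for i in range(n)]   # (n x m)
X = [[0.5 * (i - j) for j in range(v)] for i in range(m)]       # (m x v)

Y = matmul(A, X)                              # (n x v)
r_in = [[2 * y for y in row] for row in Y]    # dL/dY for this loss, (n x v)
grad_A = matmul(r_in, transpose(X))           # (n x v) times (v x m) = (n x m)

def numerical_grad_A(A, X, h=1e-6):
    """Wiggle each entry of A in turn and watch the loss."""
    grad = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            original = A[i][j]
            A[i][j] = original + h
            up = loss(A, X)
            A[i][j] = original - h
            down = loss(A, X)
            A[i][j] = original
            grad[i][j] = (up - down) / (2 * h)
    return grad

numeric = numerical_grad_A(A, X)
max_error = max(abs(grad_A[i][j] - numeric[i][j]) for i in range(n) for j in range(m))
print(shape(grad_A), max_error)
```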

Note that even though most of backprop works against matrices, the loss function still produces a scalar. The derivative of this function with respect to its input is a matrix of the same shape as that input. This becomes the first residual, and from there everything works as above.

Expanding to computation graphs

In all of the above, our model was a single sequence of operations:

$$
\begin{array}{l}
y_1 = ax \\
y_2 = y_1 + b \\
y_3 = y_2 - y_{true} \\
y_4 = {y_3}^2
\end{array}
$$

But if you’ll recall, our LLM has two ways in which operations branch off and reconnect.

In both cases, this branching and rejoining forms a graph, not a simple sequence. For example, residual connections (the architectural kind, not the backprop residuals from the sections above) look like:

[Figure: Layer 1 flows into layer 2, which flows into layer 3. Layer 1 also flows into layer 3 directly.]