Page Status: Draft ★☆☆
Introduction
At the heart of our LLM’s training is backpropagation, or “backprop”. This is often described either in simple terms, as in Wikipedia’s introduction to the topic...
It is an efficient application of the chain rule to neural networks.
...or in complex math terms, as in the rest of that Wikipedia article. I’ll aim for something in between. In particular, this chapter starts with a quick refresher on some first-semester calculus (including what the chain rule is), and then works through how it applies to backprop.
At its core, backprop tries to answer a conceptually simple question:
We have a model with a bunch of learned parameters, each of which has some value.
We take some input for which we know the expected result, and we run that input through the model.
We compare the predicted result with the actual result, and notice that the two don’t quite match up.
Backprop then asks: for each parameter, if we hold all the other parameters constant and wiggle just that one, which way should we wiggle it so that the prediction gets closer to the expected value? (The actual mechanism is more efficient than literally wiggling each parameter and recomputing the prediction, as we’ll see below.)
To do this, we define a loss function, which takes the predicted and actual values and compares them. We then use derivatives to work out which way, and how strongly, each parameter should be wiggled. The loss function always produces a scalar (typically non-negative, so that zero represents a perfect prediction), and backprop wiggles the model’s parameters to get the loss closer to zero.
To get an understanding of how backprop works, I’ll start exceedingly simple and build up from there:
Backprop on a single-layer, scalar model
Backprop on a multi-layer, but still scalar model
Backprop on a single-layer, matrix-based model
Backprop on a multi-layer, matrix-based model
By the last of those, we’ll have a “full” understanding of backprop. After that, the only difference between what we’ve built and a real model is that the real model is bigger.
The math you’ll need
This chapter assumes you’re decently familiar with derivatives; if you’re not, it may be tough. If you’re familiar with them but just need a quick refresher, the following sections should help. If you’re already comfortable with these, feel free to jump ahead to the meat of it.
Derivatives
If we have some function $f(x)$, then its derivative $f'(x)$ is how fast $f$ changes at any given point $x$.
We can also express the derivative using what’s called Leibniz notation: $\frac{df}{dx}$. This notation makes it explicit that we’re differentiating with respect to $x$.
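To differentiate a polynomial, bring each exponent down as a factor and lower it by one:
$$\frac{d}{dx}\left(a x^n\right) = n a x^{n-1} \qquad \text{e.g.} \qquad \frac{d}{dx}\left(3x^2 + 2x\right) = 6x + 2$$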
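One quick way to sanity-check a derivative rule is to compare it against a finite-difference approximation, which literally measures how much $f$ changes over a tiny step. This is a minimal Python sketch; the helper names `numeric_derivative` and `power_rule`, and the step size `eps`, are my own illustrative choices:

```python
def numeric_derivative(f, x, eps=1e-6):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def power_rule(a, n, x):
    """Derivative of the term a * x**n: bring n down and lower the exponent by one."""
    return n * a * x ** (n - 1)


def f(x):
    return 3 * x**2 + 2 * x


x0 = 1.5
# Differentiate term by term: d/dx(3x^2) + d/dx(2x) = 6x + 2
analytic = power_rule(3, 2, x0) + power_rule(2, 1, x0)
approx = numeric_derivative(f, x0)
print(analytic, approx)  # both should be (very close to) 11.0
```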
Chain rule
The chain rule lets you differentiate a function that’s the composition of two functions; in other words, a function that takes the output of one function and passes it to another:
$$h(x) = f(g(x))$$
To compute $h$’s derivative, we:
take $f$’s derivative at $g(x)$: $f'(g(x))$
take $g$’s derivative at $x$: $g'(x)$
multiply them: $h'(x) = f'(g(x)) \cdot g'(x)$
$f'(g(x))$ is the derivative of $f$ with respect to $g$, so the Leibniz notation for that is:
$$\frac{df}{dg}$$
Using that, we can write the chain rule in a fraction-like way:
$$\frac{dh}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx}$$
Note that this isn’t actually a fraction; the Leibniz notation just illustrates (by way of analogy) how the elements “cancel out”.
We can intuit why the chain rule works by going back to $h$’s definition: $f$ evaluated at $g(x)$. So, to see how fast $h$ changes as $x$ changes, we take how fast $f$ changes as $g$ changes, and multiply it by how fast $g$ changes as $x$ changes.
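We can also check the chain rule numerically. This sketch picks arbitrary example functions ($g(x) = x^2 + 1$ on the inside, $f(u) = \sin u$ on the outside) and compares $f'(g(x)) \cdot g'(x)$ with a finite-difference estimate of $h'(x)$:

```python
import math


def numeric_derivative(fn, x, eps=1e-6):
    """Central finite-difference approximation of fn'(x)."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


def g(x):  # inner function
    return x**2 + 1


def g_prime(x):
    return 2 * x


def f(u):  # outer function
    return math.sin(u)


def f_prime(u):
    return math.cos(u)


def h(x):  # the composition h(x) = f(g(x))
    return f(g(x))


x0 = 0.7
chain_rule = f_prime(g(x0)) * g_prime(x0)  # h'(x) = f'(g(x)) * g'(x)
approx = numeric_derivative(h, x0)
print(chain_rule, approx)
```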
Partial derivatives
In the above sections, $f$ was defined in terms of a single variable, $x$. But what if there are two variables, or more?
To handle this, we use partial derivatives. The concept is simple: treat all but one of the variables as a constant, and then take the (ordinary) derivative with respect to that one remaining variable. The Leibniz notation for this is $\frac{\partial f}{\partial x}$ if $x$ is the “with-respect-to” variable. We can define as many partial derivatives as there are variables; for example, with $f(x, y) = x^2 y$:
$$\frac{\partial f}{\partial x} = 2xy \qquad \frac{\partial f}{\partial y} = x^2$$
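Numerically, a partial derivative is the same finite-difference trick, nudging just one argument while the others stay put. A minimal sketch, with a hypothetical `partial` helper and $f(x, y) = x^2 y$ as an arbitrary example:

```python
def partial(f, args, i, eps=1e-6):
    """Approximate the partial derivative of f with respect to its i-th
    argument, holding every other argument constant."""
    up = list(args)
    down = list(args)
    up[i] += eps
    down[i] -= eps
    return (f(*up) - f(*down)) / (2 * eps)


def f(x, y):
    return x**2 * y


x0, y0 = 3.0, 2.0
df_dx = partial(f, (x0, y0), 0)  # analytic: 2xy = 12
df_dy = partial(f, (x0, y0), 1)  # analytic: x^2 = 9
print(df_dx, df_dy)
```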
Backprop on a simple, scalar model
Now that we have our math refreshed, let’s get to the fun stuff! To build our intuition for how backprop works, let’s start with the simplest possible model: a scalar, linear function:
$$\hat{y} = wx + b$$
This is just a plain old line, like you learned about in middle school. We’re going to use machine learning to figure out its slope and $y$-intercept. Our training data will be a bunch of $(x, y)$ pairs:
$$(x_1, y_1),\ (x_2, y_2),\ \ldots,\ (x_N, y_N)$$
Since we’re assuming (as the model designers) that the points form a line $y = wx + b$, our job will be to figure out $w$ and $b$ from the various data points. In other words, $w$ and $b$ are the model’s learned parameters.
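Our first step is to define a loss function $L$, which measures how far a given prediction is from the true value. A common one is mean squared error (MSE), which we’ll adapt for our scalar model (with a single data point, the “mean” is just that one squared error):
$$L(\hat{y}, y) = (\hat{y} - y)^2$$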
Our simple model will:
Take an $(x, y)$ pair from the training data
Run $x$ through the model (with whatever $w$ and $b$ we currently have) to produce a prediction, $\hat{y}$
Calculate the loss $L(\hat{y}, y)$
Use the chain rule to compute the two partial derivatives, $\frac{\partial L}{\partial w}$ and $\frac{\partial L}{\partial b}$. These give us the gradients for $w$ and $b$ (I’ll explain this in just a second)
Use the gradients to nudge $w$ and $b$ towards their true values
The gradient for each learned parameter ($w$ and $b$) is the partial derivative of the loss function with respect to that parameter. In other words, it answers just the mechanical, mathematical question of “as that parameter grows, how fast does the loss grow?” Of course, we want the loss to shrink, since it represents how wrong the prediction was. So, we just nudge the parameter in the opposite direction of the gradient.
The first three steps in the list above are trivial (remember that in this example, “run through the model” is just $\hat{y} = wx + b$). Let’s focus on the fourth step, the chain rule.
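We’ll focus on $w$ first. What we want is the partial derivative of the loss with respect to $w$:
$$\frac{\partial L}{\partial w}$$
We can think of $L$ as a composed function $L(\hat{y}(w))$. That means we can use the chain rule:
$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial w}$$
Let’s start by calculating the right term, $\frac{\partial \hat{y}}{\partial w}$:
$$\frac{\partial \hat{y}}{\partial w} = \frac{\partial}{\partial w}(wx + b) = x$$
Now the left term, $\frac{\partial L}{\partial \hat{y}}$:
$$\frac{\partial L}{\partial \hat{y}} = \frac{\partial}{\partial \hat{y}}(\hat{y} - y)^2 = 2(\hat{y} - y)$$
Putting it all together:
$$\frac{\partial L}{\partial w} = 2(\hat{y} - y) \cdot x$$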
And here’s where the “efficient application” starts to kick in: during our inference phase, we already calculated $\hat{y}$. If we just store that value during the forward pass, $\frac{\partial L}{\partial w}$ becomes a trivial calculation: $\hat{y}$ comes from that stored lookup, and $x$ and $y$ were our given arguments. We call this value $w$’s gradient.
We can do the same thing to calculate $\frac{\partial L}{\partial b}$. I’ll go a bit faster, since it’s basically the same work:
$$\frac{\partial L}{\partial b} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial b} = 2(\hat{y} - y) \cdot 1$$
Notice that the left term is exactly the same as it was for $w$’s gradient.
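Here’s a minimal Python sketch of the two gradients for a single data point, checked against finite differences. The function names and sample values are illustrative assumptions, not from any library:

```python
def forward(w, b, x):
    """Forward pass of the scalar model: y_hat = w*x + b."""
    return w * x + b


def loss(y_hat, y):
    """Squared error for a single data point."""
    return (y_hat - y) ** 2


def gradients(w, b, x, y):
    y_hat = forward(w, b, x)    # stored during the forward pass
    dL_dyhat = 2 * (y_hat - y)  # left term, shared by both gradients
    dL_dw = dL_dyhat * x        # right term: d(y_hat)/dw = x
    dL_db = dL_dyhat * 1        # right term: d(y_hat)/db = 1
    return dL_dw, dL_db


w, b, x, y = 0.5, -1.0, 2.0, 3.0
eps = 1e-6
dw, db = gradients(w, b, x, y)

# Finite-difference estimates of the same two partial derivatives
num_dw = (loss(forward(w + eps, b, x), y) - loss(forward(w - eps, b, x), y)) / (2 * eps)
num_db = (loss(forward(w, b + eps, x), y) - loss(forward(w, b - eps, x), y)) / (2 * eps)
print(dw, num_dw, db, num_db)
```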
With that, we’ve calculated our two gradients, for $w$ and $b$. Now we just apply each one to its respective parameter ($w$ and $b$) to update it. As I mentioned before, we subtract the gradients, because we want to reduce the loss. Before we do that, we scale the gradients down by $\eta$, which is a learning rate:
$$w \leftarrow w - \eta \frac{\partial L}{\partial w} \qquad b \leftarrow b - \eta \frac{\partial L}{\partial b}$$
The learning rate is some small number, like 0.01, and it means that each round of learning only nudges the values towards a 0-loss, instead of lurching them there. This keeps any one data point from yanking the parameters too far; steps that are too large can make the model overshoot and oscillate around the desired value, or worse, shoot off to infinity.
That’s all there is to it! If we churn this training through a large enough data set (with a suitably small learning rate), $w$ and $b$ will converge toward the right values.
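Putting the steps together, a training loop might look like the sketch below. It assumes noise-free synthetic points on the line $y = 3x - 2$ (made up for illustration), a learning rate of 0.01, and one update per data point; `train` is a hypothetical helper:

```python
import random


def train(data, lr=0.01, epochs=200, seed=0):
    """Fit y_hat = w*x + b with one gradient step per (x, y) pair."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(data)  # copy, so shuffling doesn't touch the caller's list
    for _ in range(epochs):
        rng.shuffle(data)
        for x, y in data:
            y_hat = w * x + b           # forward pass (stored)
            dL_dyhat = 2 * (y_hat - y)  # shared left term
            w -= lr * dL_dyhat * x      # w <- w - eta * dL/dw
            b -= lr * dL_dyhat          # b <- b - eta * dL/db
    return w, b


# Noise-free points on the line y = 3x - 2, for x in [-2, 2]
points = [(x / 10, 3 * (x / 10) - 2) for x in range(-20, 21)]
w, b = train(points)
print(round(w, 3), round(b, 3))
```

Because this toy data lies exactly on a line, every update pulls toward the same answer, so $w$ and $b$ should land very close to 3 and -2. With noisy real data, the parameters would instead jitter around the best-fit values.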
Terminology: residual and local derivative
Before we go further, let’s introduce two useful names for concepts we’ve already seen. Remember that the gradients for $w$ and $b$ each used the chain rule, and in both cases their left-hand term was the same:
$$\frac{\partial L}{\partial \hat{y}} = 2(\hat{y} - y)$$
Let’s ask where these various terms come from, and do so within the framing of the layer that contains $w$ and $b$ (that is, $\hat{y} = wx + b$).
$\frac{\partial L}{\partial \hat{y}}$ comes purely from the layer below us ($L = (\hat{y} - y)^2$, where $y$ can be thought of as a constant)
$\hat{y}$ got stored during forward inference
The factor of 2 comes from the derivative of $L$, irrespective of what anything else in the model is doing.
The $x$ and $1$ each come from partial derivatives local to the $\hat{y}$ layer. Again, these depend only on that layer, irrespective of what anything else is doing.
The distinction between information coming from the layer below, and information computed at this layer, is reflected in terminology:
The residual is the left-hand term in the chain rule: the signal from the layer below
The local derivatives are the right-hand term in the chain rule: the partial derivatives applied at this layer
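We can think of this for any parameter $p$ as:
$$\frac{\partial L}{\partial p} = \delta_\ell \cdot \frac{\partial \ell}{\partial p}$$
...where:
$p$ is a parameter defined at layer $\ell$
$\delta_\ell = \frac{\partial L}{\partial \ell}$ is the residual, which comes from the layer below
$\frac{\partial \ell}{\partial p}$ is the local derivative, computed at layer $\ell$
Note that $\delta_\ell$ isn’t an equation, but an actual, concrete value. Each layer gets this value, and then uses it as-is for all of that layer’s parameters. This is a lot of what’s behind the “efficient” in “efficient application of the chain rule”.
The lowest layer, $L$, is a special case: it doesn’t have a lower layer to provide a residual, so it computes its outgoing residual directly, by figuring out its derivative and plugging in $\hat{y}$ and $y$.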
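To make the “concrete value” point tangible, here’s a sketch in which the loss layer (the special case) computes its residual directly, and the layer above just multiplies that one number by each of its local derivatives. `loss_residual` and `linear_param_gradients` are hypothetical names, and the values are arbitrary:

```python
def loss_residual(y_hat, y):
    """The lowest layer (the loss) has no incoming residual, so it computes
    one directly from its own derivative: dL/d(y_hat) = 2(y_hat - y)."""
    return 2 * (y_hat - y)


def linear_param_gradients(residual, x):
    """Gradients for y_hat = w*x + b: the residual times each local derivative."""
    local = {"w": x, "b": 1.0}  # d(y_hat)/dw = x, d(y_hat)/db = 1
    return {name: residual * d for name, d in local.items()}


w, b, x, y = 0.5, -1.0, 2.0, 3.0
y_hat = w * x + b                          # forward pass
delta = loss_residual(y_hat, y)            # one concrete number...
grads = linear_param_gradients(delta, x)   # ...reused for every parameter
print(delta, grads)
```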
Backprop on a multi-layer, scalar model
Now that we have backprop working on a single-layer model, let’s add a second layer. For now, we won’t have an activation function between the two:
$$h = w_1 x + b_1 \qquad \hat{y} = w_2 h + b_2$$
We’ll use the same loss function as before:
$$L = (\hat{y} - y)^2$$
Let’s start by keeping in mind our objectives:
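We want to figure out how much to nudge $w_1$, $b_1$, $w_2$, and $b_2$.
To do that, we need to calculate their four gradients.
Each gradient is a partial derivative: $\frac{\partial L}{\partial w_1}$, $\frac{\partial L}{\partial b_1}$, $\frac{\partial L}{\partial w_2}$, $\frac{\partial L}{\partial b_2}$.
We’ll start at the bottom of the model, the layer closest to $L$: $\hat{y}$. This means we’ll be calculating the gradients for $w_2$ and $b_2$, which are $\frac{\partial L}{\partial w_2}$ and $\frac{\partial L}{\partial b_2}$. Let’s start with $w_2$. As before, we’ll use the chain rule:
$$\frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial w_2}$$
This turns out to be exactly the same as the single-layer example above: just add a subscript to $w$ and $b$, and use the first layer’s output $h$ where we previously had $x$:
$$\frac{\partial L}{\partial w_2} = 2(\hat{y} - y) \cdot h \qquad \frac{\partial L}{\partial b_2} = 2(\hat{y} - y)$$
So far, this is all just a review of the previous two sections. Now comes the new wrinkle: calculating the gradients for the $h$ layer.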
There are two ways to approach this: by working everything out piece by piece, or by relying on the residual-based pattern we established in the previous section. I’m not sure which is more helpful, so I’ll provide both. If one doesn’t make sense, try the other!
Working it out piece by piece
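As before, we’ll start by writing out the loss function:
$$L = (\hat{y} - y)^2 = \big(w_2 (w_1 x + b_1) + b_2 - y\big)^2$$
Note that even though we’re interested in the parameters at layer $h$, the loss function is still defined in terms of $\hat{y}$. The loss function can only be defined against $\hat{y}$, because what it measures is “how wrong was the model’s ultimate prediction”; we don’t have any way of estimating how far off an intermediate value was, because our training data only has the inputs and final expected outputs.
Let’s start with $w_1$. Again, we’ll use the chain rule, but what do we want to use as the chain?
If we think about what $w_1$ most directly impacts (that is, what changes most directly as $w_1$ changes), it’s just $h$, the function that directly uses $w_1$. So, let’s use that:
$$\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial h} \cdot \frac{\partial h}{\partial w_1}$$
With that in mind, let’s take a crack at the left term: $\frac{\partial L}{\partial h}$. We can’t just use a plain polynomial derivative formula as we’ve been doing so far, because $h$ isn’t “directly” in the definition for $L$. Instead, let’s try the chain rule again:
$$\frac{\partial L}{\partial h} = \frac{\partial L}{\partial\, ?} \cdot \frac{\partial\, ?}{\partial h}$$
Again we ask, what does $h$ most directly affect? Well, it’s used in the next layer:
$$\hat{y} = w_2 h + b_2$$
...so $h$ most directly affects $\hat{y}$. Let’s fill that in:
$$\frac{\partial L}{\partial h} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial h}$$
Here’s where the “efficient application” kicks in again. We already computed $\frac{\partial L}{\partial \hat{y}} = 2(\hat{y} - y)$ in the previous step (it was the left-hand term of the chain rule), so we can just plug that in. For the right-hand term (of our $\hat{y}$ layer), we can fall back to standard derivatives:
$$\frac{\partial \hat{y}}{\partial h} = \frac{\partial}{\partial h}(w_2 h + b_2) = w_2$$
With that, we’ve calculated the left term of our gradient:
$$\frac{\partial L}{\partial h} = 2(\hat{y} - y) \cdot w_2$$
And we can then fill in the right term straightforwardly:
$$\frac{\partial h}{\partial w_1} = x \quad\Longrightarrow\quad \frac{\partial L}{\partial w_1} = 2(\hat{y} - y) \cdot w_2 \cdot x$$
The gradient for $b_1$ would work the same way. In the end, you’d get:
$$\frac{\partial L}{\partial b_1} = 2(\hat{y} - y) \cdot w_2$$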
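These formulas are easy to check numerically. The sketch below (hypothetical helper names, arbitrary parameter values) computes all four gradients of the two-layer model and compares each one against a central finite difference:

```python
def forward(w1, b1, w2, b2, x):
    h = w1 * x + b1      # first layer
    y_hat = w2 * h + b2  # second layer
    return h, y_hat


def loss(params, x, y):
    _, y_hat = forward(**params, x=x)
    return (y_hat - y) ** 2


def gradients(params, x, y):
    h, y_hat = forward(**params, x=x)
    dL_dyhat = 2 * (y_hat - y)
    # Second layer: same as the single-layer model, with h in place of x
    dL_dw2 = dL_dyhat * h
    dL_db2 = dL_dyhat
    # First layer: chain through y_hat, where d(y_hat)/dh = w2
    dL_dh = dL_dyhat * params["w2"]
    return {"w1": dL_dh * x, "b1": dL_dh, "w2": dL_dw2, "b2": dL_db2}


params = {"w1": 0.3, "b1": 0.1, "w2": -0.7, "b2": 0.2}
x, y = 1.5, 2.0
analytic = gradients(params, x, y)

# Nudge one parameter at a time, holding the others constant
eps = 1e-6
numeric = {}
for name in params:
    up = dict(params, **{name: params[name] + eps})
    down = dict(params, **{name: params[name] - eps})
    numeric[name] = (loss(up, x, y) - loss(down, x, y)) / (2 * eps)

for name in params:
    print(name, analytic[name], numeric[name])
```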
Using residuals
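If we trust our understanding of the residuals pattern in the previous section:
$$\frac{\partial L}{\partial p} = \delta_\ell \cdot \frac{\partial \ell}{\partial p}$$
...then we can get a shortcut for all of the above. Let’s give a name to the residual coming into $h$: $\delta_h = \frac{\partial L}{\partial h}$.
The $\hat{y}$ layer will calculate this for us, as an extra step after it calculates its gradients. That’s because calculating $\delta_h$ requires knowing $\hat{y}$’s definition. The calculation is cheap, where $\delta_{\hat{y}} = \frac{\partial L}{\partial \hat{y}} = 2(\hat{y} - y)$ is the residual the $\hat{y}$ layer received from the loss:
$$\delta_h = \delta_{\hat{y}} \cdot \frac{\partial \hat{y}}{\partial h} = \delta_{\hat{y}} \cdot w_2$$
The $h$ layer knows its definition in terms of $w_1$ and $b_1$, so it’s able to determine its gradients:
$$\frac{\partial L}{\partial w_1} = \delta_h \cdot x \qquad \frac{\partial L}{\partial b_1} = \delta_h \cdot 1$$
Following this through, we get a pattern for every layer $\ell$:
Take the residual $\delta_\ell$ from the previous layer (the one below it). Use it as the left-hand term of chain rule applications, with the right-hand term being this layer’s local derivative with respect to each of its parameters.
Multiply that residual by this layer’s local derivative with respect to its input ($\frac{\partial \ell}{\partial \text{input}}$), and pass the result to the layer above as its residual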
Note the general pattern: we start at the bottom of the model, and then work our way up, with each layer providing the residual for the one before it. This is the “back” in “backpropagation.”
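The pattern is short enough to write as a loop. This sketch assumes a stack of scalar layers of the form $w \cdot \text{input} + b$ with no activations, and the squared-error loss from before; `forward` and `backward` are hypothetical names, and the parameter values are arbitrary:

```python
def forward(layers, x):
    """Run x through a stack of scalar layers, each computing w * input + b.
    Also returns the input each layer saw, stored for the backward pass."""
    inputs = []
    for w, b in layers:
        inputs.append(x)
        x = w * x + b
    return x, inputs


def backward(layers, inputs, y_hat, y):
    """Walk from the loss back toward the input, one layer at a time."""
    delta = 2 * (y_hat - y)  # the loss layer computes the first residual directly
    grads = [None] * len(layers)
    for i in reversed(range(len(layers))):
        w, _ = layers[i]
        # Gradients: incoming residual times each local derivative
        grads[i] = {"w": delta * inputs[i], "b": delta * 1.0}
        # Outgoing residual: incoming residual times d(layer)/d(input) = w.
        # (For the first layer this goes to x itself and is simply unused.)
        delta = delta * w
    return grads


layers = [(0.3, 0.1), (-0.7, 0.2)]  # [(w1, b1), (w2, b2)]
x, y = 1.5, 2.0
y_hat, inputs = forward(layers, x)
grads = backward(layers, inputs, y_hat, y)
print(grads)
```

Each layer only ever touches its own stored input, its own $w$, and the single residual handed to it, which is why adding more layers doesn’t change the code.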
Adding an activation function
Now that we have our nice two-layer model, we’ll want to add an activation function:
$$h = f(w_1 x + b_1) \qquad \hat{y} = w_2 h + b_2$$
But instead of getting straight into that, I want to revisit the one-layer model. This may seem like a (possibly pedantic) side discussion, but I promise it’ll lead us to the activation-enabled model.
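Let’s rewrite our one-layer model, with the loss function, but this time expand it into individual operations:
$$z = wx \qquad \hat{y} = z + b \qquad L = (\hat{y} - y)^2$$
Just like I described in the previous section, we’ll work from the bottom of the model up: each layer will calculate the gradient for its parameter (if it has one), then multiply its incoming residual by its local derivative with respect to its input, and pass that product up as the residual for that input.
$$\begin{array}{l|l|l|l}
\text{Layer} & \text{Incoming residual} & \text{Gradient} & \text{Outgoing residual} \\ \hline
L = (\hat{y} - y)^2 & \text{(none)} & \text{(none)} & \delta_{\hat{y}} = 2(\hat{y} - y) \\
\hat{y} = z + b & \delta_{\hat{y}} & \frac{\partial L}{\partial b} = \delta_{\hat{y}} \cdot 1 & \delta_z = \delta_{\hat{y}} \cdot 1 \\
z = wx & \delta_z & \frac{\partial L}{\partial w} = \delta_z \cdot x & \delta_x = \delta_z \cdot w
\end{array}$$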
Hopefully this table drives home everything from the above. You can see that:
Each layer calculates its outgoing residual as the product of its incoming residual and its local derivatives
For the layers that have learned parameters, those layers separately calculate their gradients as the product of the incoming residual and $\frac{\partial \ell}{\partial p}$, where $p$ is the layer’s learned parameter.
Crucially, note that some layers don’t calculate any gradient, and that’s fine. They still calculate an outgoing residual, and that’s all the rest of the layers need.
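Here’s a minimal Python sketch of that bookkeeping, treating the model as a flat list of operations. Each op’s `backward` takes the incoming residual and returns a gradient (or `None`, if the op has no learned parameter) plus its outgoing residual. As an illustrative choice of mine, the loss is split into two parameter-free ops, and the backward pass is seeded with $\partial L / \partial L = 1$; the class names are hypothetical:

```python
class Mul:
    """z = w * x, with learned parameter w."""

    def __init__(self, w):
        self.w = w

    def forward(self, x):
        self.x = x  # stored for the backward pass
        return self.w * x

    def backward(self, delta):
        # (gradient for w, outgoing residual for x)
        return delta * self.x, delta * self.w


class Add:
    """y_hat = z + b, with learned parameter b."""

    def __init__(self, b):
        self.b = b

    def forward(self, z):
        return z + self.b

    def backward(self, delta):
        return delta * 1.0, delta * 1.0  # d/db = 1 and d/dz = 1


class SubtractTarget:
    """d = y_hat - y. No learned parameter, so no gradient."""

    def __init__(self, y):
        self.y = y

    def forward(self, y_hat):
        return y_hat - self.y

    def backward(self, delta):
        return None, delta * 1.0


class Square:
    """L = d**2. No learned parameter, so no gradient."""

    def forward(self, d):
        self.d = d
        return d * d

    def backward(self, delta):
        return None, delta * 2 * self.d


def backprop(ops, x):
    """Forward through every op, then backward from the loss to the input."""
    for op in ops:
        x = op.forward(x)
    loss = x
    delta = 1.0  # dL/dL: the seed residual
    grads = []
    for op in reversed(ops):
        grad, delta = op.backward(delta)
        grads.append(grad)
    grads.reverse()  # line the gradients up with ops
    return loss, grads


ops = [Mul(0.5), Add(-1.0), SubtractTarget(3.0), Square()]
loss, grads = backprop(ops, 2.0)
print(loss, grads)  # None marks the ops with no learned parameter
```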
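What would happen if we slid a function $f$ between $wx$ and $+\,b$? This would correspond to a model:
$$\hat{y} = f(wx) + b$$
Expanded into individual operations:
$$z = wx \qquad a = f(z) \qquad \hat{y} = a + b \qquad L = (\hat{y} - y)^2$$
I’ve named the $f$ layer $a$ so that most of the other layers don’t have to change; you can see that other than referring to the new residual $\delta_a$, everything else is exactly the same.
Note that in a real LLM, the activation function doesn’t go between the weight and the bias, as I’ve done here. But as we’re about to see, the whole point is that backprop doesn’t actually care where it goes. In fact, as far as backprop is concerned, there’s no such thing as a layer at all: everything is just a sequence of operations. Here’s the same table as above, this time with one more layer for $f$:
$$\begin{array}{l|l|l|l}
\text{Layer} & \text{Incoming residual} & \text{Gradient} & \text{Outgoing residual} \\ \hline
L = (\hat{y} - y)^2 & \text{(none)} & \text{(none)} & \delta_{\hat{y}} = 2(\hat{y} - y) \\
\hat{y} = a + b & \delta_{\hat{y}} & \frac{\partial L}{\partial b} = \delta_{\hat{y}} \cdot 1 & \delta_a = \delta_{\hat{y}} \cdot 1 \\
a = f(z) & \delta_a & \text{(none)} & \delta_z = \delta_a \cdot f'(z) \\
z = wx & \delta_z & \frac{\partial L}{\partial w} = \delta_z \cdot x & \delta_x = \delta_z \cdot w
\end{array}$$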
And that’s it! If $f$ is our activation function (GELU, ReLU, or anything else), nothing else about the machinery changes. It needs to know the activation function’s local derivative with respect to its input, but the implementation can just hard-code that. The same applies to any other operation in the real LLM, like softmax.
The actual equations for these derivatives can be a handful (for example, the derivative of the exact GELU, $x\,\Phi(x)$, is $\Phi(x) + \frac{x}{\sqrt{2\pi}} e^{-x^2/2}$, where $\Phi$ is the standard normal CDF), but the implementation just hard-codes them, as it does for the functions’ forward-pass definitions.
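For instance, hard-coding the derivatives of ReLU and the exact (erf-based) GELU might look like this sketch, with a finite-difference spot-check of the GELU derivative. Using 0 as ReLU’s derivative at exactly 0 is a convention I’ve assumed, and the function names are illustrative:

```python
import math


def relu(x):
    return max(0.0, x)


def relu_derivative(x):
    # ReLU isn't differentiable at 0; this sketch uses 0 there by convention
    return 1.0 if x > 0 else 0.0


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_derivative(x):
    """d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with phi the normal PDF."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def activation_backward(delta, x, derivative):
    """Outgoing residual for an activation op: incoming residual times f'(x)."""
    return delta * derivative(x)


eps = 1e-6
for x in (-2.0, -0.3, 0.4, 1.7):
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)
    print(x, gelu_derivative(x), numeric)
```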
Using vectors instead of scalars
Until now, our simple backprop model has used scalars. In an LLM, of course, everything is a matrix.
Vectors? Tensors?
For purposes of calculating derivatives in backprop, we’ll treat everything as a matrix.
We’ll treat $n$-vectors as $1 \times n$ matrices
We’ll treat higher-rank tensors as having a bunch of batching dimensions, and then two matrix dimensions at the end.
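The concepts are exactly the same as above: the only difference is that when we take the [partial] derivative of functions, those functions will have matrices as their inputs and parameters, instead of scalars. Likewise, since the loss is a scalar, its derivative with respect to a matrix (say, an $m \times n$ matrix $W$) is itself a matrix of ordinary partial derivatives, one per entry:
$$\frac{\partial L}{\partial W} = \begin{bmatrix} \frac{\partial L}{\partial W_{11}} & \cdots & \frac{\partial L}{\partial W_{1n}} \\ \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial W_{m1}} & \cdots & \frac{\partial L}{\partial W_{mn}} \end{bmatrix}$$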
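To see that such a derivative really is “just” a matrix of ordinary partial derivatives, here’s a sketch that approximates $\frac{\partial L}{\partial W}$ one entry at a time by finite differences. The function $L = \sum_{ij} W_{ij}^2$ is an arbitrary example (so the exact answer is $2W$), and `numeric_matrix_gradient` is a hypothetical helper:

```python
def numeric_matrix_gradient(f, W, eps=1e-6):
    """Approximate dL/dW for a scalar-valued f(W), one entry at a time.
    The result has the same shape as W."""
    grad = [[0.0] * len(row) for row in W]
    for i in range(len(W)):
        for j in range(len(W[i])):
            original = W[i][j]
            W[i][j] = original + eps
            up = f(W)
            W[i][j] = original - eps
            down = f(W)
            W[i][j] = original  # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad


def sum_of_squares(W):
    """A scalar function of a matrix: L = sum of W_ij^2, so dL/dW_ij = 2*W_ij."""
    return sum(v * v for row in W for v in row)


W = [[1.0, -2.0, 0.5],
     [3.0, 0.0, -1.5]]
G = numeric_matrix_gradient(sum_of_squares, W)
print(G)
```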
As with the scalar model, each layer (that is, operation) in the matrix-based model will have to do two things:
Compute its outgoing residual, which combines its incoming residual with the layer’s local derivative relative to its input.
This residual will be used by the layer above this layer as its incoming residual, so it needs to be the same shape as this layer’s input.
Compute its learned parameter’s gradient, if any.
This will eventually be subtracted from the learned parameter, so it needs to be of the same shape as that parameter.
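At the core of it, that’s all we need to know for now. Just to round it out, let’s look back at the $wx$ layer of our scalar model:
$$z = wx: \qquad \frac{\partial L}{\partial w} = \delta_z \cdot x \qquad \delta_x = \delta_z \cdot w$$
It’s going to be almost the same for the matrix version ($Z = XW$), except that transposes show up; the one worth a closer look is on the gradient:
$$\frac{\partial L}{\partial W} = X^T \delta_Z \qquad \delta_X = \delta_Z W^T$$
Let’s look at the gradient. Where did that transpose come from? Without getting too far into the weeds of matrix derivatives, let’s at least look at the shapes of the matrices.
Let’s say $W$ is an $m \times n$ matrix.
This means our layer (in the forward, inference pass) is:
$$\underbrace{Z}_{k \times n} = \underbrace{X}_{k \times m}\;\underbrace{W}_{m \times n}$$
...for some dimension $k$.
We know $\frac{\partial L}{\partial W}$ also has to be $m \times n$, as mentioned above.
We know $\delta_Z$ has the same shape as this layer’s output, which is $k \times n$.
So we have:
$$\underbrace{\frac{\partial L}{\partial W}}_{m \times n} = \underbrace{\;?\;}_{?} \cdot \underbrace{\delta_Z}_{k \times n}$$
In order for the matrix multiplication to work, the unknown bit must have shape $m \times k$. And wouldn’t you know it, that’s exactly the shape of $X^T$.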
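Here’s a plain-Python sketch of the $Z = XW$ layer’s backward pass (no NumPy, so it runs anywhere). To have a known incoming residual, it assumes a toy loss $L = \sum_{ij} Z_{ij} G_{ij}$ for a fixed matrix $G$, whose residual $\delta_Z$ is exactly $G$. The helpers `matmul`, `transpose`, and `linear_backward` are my own illustrative names:

```python
def matmul(A, B):
    """Plain-Python product of an (a x b) matrix and a (b x c) matrix."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A):
    return [list(col) for col in zip(*A)]


def shape(A):
    return (len(A), len(A[0]))


def linear_backward(X, W, delta_Z):
    """Backward pass for Z = X @ W, given the incoming residual dL/dZ."""
    grad_W = matmul(transpose(X), delta_Z)   # (m x k) @ (k x n) -> m x n, like W
    delta_X = matmul(delta_Z, transpose(W))  # (k x n) @ (n x m) -> k x m, like X
    return grad_W, delta_X


# k=2 rows of input, m=3 input features, n=4 output features
X = [[1.0, 2.0, -1.0],
     [0.5, -0.5, 3.0]]                 # k x m
W = [[0.1, -0.2, 0.3, 0.0],
     [0.4, 0.5, -0.6, 0.7],
     [-0.8, 0.9, 0.2, -0.1]]           # m x n
# The residual "handed up" from below; same shape as Z (k x n).
# It is exactly dL/dZ for the toy loss L = sum(Z * delta_Z).
delta_Z = [[1.0, -1.0, 0.5, 2.0],
           [0.0, 3.0, -2.0, 1.0]]

grad_W, delta_X = linear_backward(X, W, delta_Z)
print(shape(grad_W), shape(delta_X))


def toy_loss(X, W):
    Z = matmul(X, W)
    return sum(z * d for zr, dr in zip(Z, delta_Z) for z, d in zip(zr, dr))


# Spot-check one entry of grad_W with a finite difference
eps = 1e-6
i, j = 2, 1
W_up = [row[:] for row in W]
W_up[i][j] += eps
W_down = [row[:] for row in W]
W_down[i][j] -= eps
numeric = (toy_loss(X, W_up) - toy_loss(X, W_down)) / (2 * eps)
print(grad_W[i][j], numeric)
```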
Note that even though most of backprop works against matrices, the loss function still produces a scalar. The derivative of this function with respect to its input is a matrix of the same shape as that input. This becomes the first residual, and from there everything works as above.
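As a sketch of that first residual, assume the matrix loss is MSE averaged over every entry (an illustrative choice; this section doesn’t pin down the exact matrix loss). The loss comes out as one scalar, while its derivative with respect to the prediction has one entry per prediction, so it matches the prediction’s shape. Function names are hypothetical:

```python
def mse(Y_hat, Y):
    """Mean squared error over every entry of a matrix: returns a scalar."""
    n = sum(len(row) for row in Y)
    return sum((a - b) ** 2
               for ra, rb in zip(Y_hat, Y)
               for a, b in zip(ra, rb)) / n


def mse_residual(Y_hat, Y):
    """dL/dY_hat: the first residual, with the same shape as Y_hat."""
    n = sum(len(row) for row in Y)
    return [[2 * (a - b) / n for a, b in zip(ra, rb)]
            for ra, rb in zip(Y_hat, Y)]


Y_hat = [[0.5, 1.0], [2.0, -1.0], [0.0, 0.25]]
Y = [[1.0, 1.0], [1.5, 0.0], [0.0, 0.75]]
print(mse(Y_hat, Y))
print(mse_residual(Y_hat, Y))
```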
Expanding to computation graphs
In all of the above, our model was a single sequence of operations:
$$x \rightarrow \text{op}_1 \rightarrow \text{op}_2 \rightarrow \cdots \rightarrow L$$
But if you’ll recall, our LLM has two ways in which operations branch off and reconnect.
Residual connections (unrelated to the “residuals” in this chapter):
We take a layer $x$, perform some operations on it to produce $F(x)$, and then add the original layer to get $x + F(x)$.
Attention layer:
We take our input, multiply it by each of the $W_Q$, $W_K$, $W_V$ matrices, and then recombine the results via the attention operation.
In both cases, this branching and rejoining forms a graph, not a simple sequence. For example, residual connections look like: