
What and why is self-attention?

In the previous chapter, I described how to turn input text into a list of vectors. In the next section, we’ll be using those vectors in a feedforward network, which will make various inferences on them. But first, we’re going to use a process called self-attention to determine how each token draws information from the tokens around it.

Self-attention sits between tokenization and the feedforward network

When I described the token embeddings in the previous chapter, I mentioned that they’re combined with position embeddings to produce the final input embedding. This lets us differentiate between “have” as the first token in a sentence and “have” as the third token. This is a decent first step, but it’s not enough: we want to know that it means something different in “we’ll always have” as compared to “Houston, we have”.

In other words, we want to learn what “have” means in the context of the specific sentence we see it in, factoring in the tokens around it. In the lingo of LLMs, we want to know how “have” attends to each of those other tokens. This attention is the crucial innovation that GPT-style LLMs introduced over previous ML models.

Since the attention layer sits between the tokenization/embedding component and the feedforward network, I find it useful to be explicit about its inputs and outputs: its input is the sequence of input embeddings produced by the previous chapter's tokenization and embedding step, and its output is a same-length sequence of context vectors, which the feedforward network will consume.

In the rest of this chapter, I’ll explain what this attention is concretely, and how we compute it. I’ll start by building an intuition and motivation around what we’re building, and then go into the details of how it works. Finally, the last part of this chapter will introduce some important real-world refinements.

We’re going to be making extensive use of matrix math in this chapter. Make sure you remember how that works, and in particular the shapes of the matrices when they’re multiplied. It’s covered in the earlier chapter on matrix math.

Building a high-level intuition of what we need

What if we had infinite compute power?

As I mentioned above, what we really want to answer is: “for every input token, what information does it draw from every other token?” That question has nuance, and as you’ll recall from the earlier overview, nuance means vectors — or matrices.

If we have $n$ input embeddings, and each can attend to every other (including itself), we can visualize attention as an $n \times n$ grid:

5 by 5 attention weight grid with "Houston we have a problem" as both rows and columns. Each cell shows how the row word attends to the column word.

So, what’s in each of these cells?

Each cell tells us what input token A draws from input token B. Since that translates an input $d$-vector (the input embedding) into an output $\delta$-vector (the attention output), we can use a $d \times \delta$ matrix:

$$\underbrace{\text{Input}}_{1 \times d} \cdot \underbrace{\text{Transformation}}_{d \times \delta} = \underbrace{\text{Output}}_{1 \times \delta}$$
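
As a quick sanity check of those shapes, here's a minimal numpy sketch (the sizes of $d$ and $\delta$ here are made-up toy values, not anything a real model would use):

```python
import numpy as np

d, delta = 8, 4                   # toy hyperparameters; real models use hundreds or thousands
x = np.random.randn(1, d)         # one input embedding, shape (1, d)
W = np.random.randn(d, delta)     # a d-by-delta transformation
out = x @ W                       # (1, d) @ (d, delta)
print(out.shape)                  # (1, 4): a delta-sized vector
```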

That means each element in the grid has to be a $d \times \delta$ matrix:

n-by-n grid, where each cell is a d-by-delta matrix

The problem is that this is an $n \times n \times d \times \delta$ tensor, which would be far too large to reasonably store and train on. Worse yet, it grows with the square of the input size! Ideally, we’d like to have something that only grows as a function of the $d$ and $\delta$ hyperparameters.

Focusing our attention

Instead of trying to learn everything about the relationships between inputs, we’ll have the attention layer focus on just one or two kinds of relationships. For example, an attention layer may focus on understanding how parentheticals fit into a sentence, or learning about subject-verb agreement.

Let’s take subject-verb agreement as an example. Of course, not all tokens are nouns or verbs. To learn subject-verb agreement, the attention mechanism needs to first focus on subject-verb pairs, and mostly ignore the others. Otherwise, the model will train on noise, or even contradictory information. For example, the suffix “-s” usually marks singular verbs but plural nouns, and conjunctions don’t have the concept of pluralization at all.

Once it finds the relevant token pairs, the attention mechanism needs to extract the information. For example, a layer learning subject-verb agreement needs to extract whether the subject is singular or plural.

In practice, this information is almost never neatly packaged into a single dimension in the token’s embedding; it’s spread out and entangled across several dimensions. So, the attention layer needs to learn how to extract and recombine those distributed properties into a useful representation.

With all that, now we have a more tractable problem than just “find the relationships between tokens”. We need to:

  1. Learn which token pairings matter to the specific relationship that this layer is learning.

  2. Learn how to extract and combine the relevant parts of the token embeddings to produce the right output.

Let’s take a look at what structures could let us answer these two questions.

Breaking down the problem

As I covered above, instead of asking generally “how does input A attend to input B”, we’ll approximate that question by asking two simpler ones:

  1. For each input A, how much does it care about input B?

  2. How should input B express its information to A?

Note that our two questions involve three usages of input tokens:

  1. Input A asks the question (“how much do I care about this other token?”).

  2. Input B is scored against that question.

  3. Input B expresses its information.

The LLM needs to learn something about each of these usages, so we’ll use a learned transformation for each one. We’ll call these transformations $W_q$, $W_k$, and $W_v$ (you’ll see why in just a moment).

Let’s think about what shapes these transformations should be.

The “how much” question is a scalar (think, “on a scale from 1 to 100, how much does...?”). We have two input tokens in this question, so if we turned each into a vector of equal length, we could calculate their dot product to get that scalar. Let’s translate them into $\delta$-sized vectors, which we can do via a $d \times \delta$ matrix:

$$\underbrace{\text{Input}}_{1 \times d} \cdot \underbrace{\text{Transformation}}_{d \times \delta} = \underbrace{\text{Output}}_{1 \times \delta}$$

Since we have two separate transformations — one for input A and one for input B — we’ll define two such matrices:

$$\begin{align} W_q & \Rightarrow d \times \delta \\ W_k & \Rightarrow d \times \delta \end{align}$$

Note that while the “how much” score is a single number, the transformation matrices that produce it encode quite a bit of nuance: each has $d \times \delta$ learned parameters!

The answer to “how should it express its information” is different — we’re not asking for a score, but for the actual content. We’ll represent this output as a $\delta$-dimensional vector (which will determine our attention output’s dimension). Again, this means we need a $d \times \delta$ matrix to transform the input $d$-vector to a $\delta$-vector:

$$W_v \Rightarrow d \times \delta$$

Crucially, because this approach is just an approximation of the $n \times n \times d \times \delta$ tensor, we don’t need a separate set of $W_q$ / $W_k$ / $W_v$ for each cell in the $n \times n$ grid. Instead, we have just three weight matrices for the whole attention layer: one $W_q$, one $W_k$, and one $W_v$.
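
To get a feel for the savings, here's a back-of-the-envelope comparison. The specific values of $n$, $d$, and $\delta$ below are hypothetical, just to make the scale concrete:

```python
n, d, delta = 1024, 768, 768       # hypothetical input length and embedding sizes

naive = n * n * d * delta          # the n x n x d x delta tensor
attention = 3 * d * delta          # one W_q, one W_k, one W_v

print(f"naive tensor:    {naive:,} parameters")      # 618,475,290,624 (~618 billion)
print(f"attention layer: {attention:,} parameters")  # 1,769,472 (~1.8 million)
```

And unlike the naive tensor, the attention layer's parameter count doesn't change at all as the input grows.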

The “query / key / value” terminology comes from an analogy to database lookups: the query is what you’re searching for, the keys are what each entry is matched against, and the values are the content you actually retrieve.

Computing attention weights

The next sections will describe the mechanics of calculating attention. If the above doesn’t make sense, it may be useful to move on for now, and then re-read it once you understand how the weight matrices actually get used.

Overview

For each token within the input, we’ll focus on that token and do a bunch of calculations centered on it. We’ll call that token the query token: it’s the token for which we want to ask (that is, “query”) how it attends to each token in the input.

With that query token in mind, we’ll look at each token in the input, treating each one in turn as a key token.

  1. First, calculate the query token’s query vector, using $W_q$ and a bias called $b_q$ (more on this below).

  2. Then calculate attention scores, one per key token. These are scalars that tell us how much the query token should care about each key token. To calculate these scores:

    1. Use $W_k$ to calculate a key vector for each key. We’ll also add a bias, called $b_k$, just as we did for the query vector.

    2. Take the dot product of the query and key vectors to get the attention score for each key.

  3. Next, normalize these attention scores into attention weights (still one scalar per key token).

  4. Next, compute value vectors for each key, weighted by their respective attention weights:

    1. Use $W_v$ and a $b_v$ bias to transform each key into a $\delta$-vector.

    2. Multiply each of those $\delta$-vectors by its respective attention weight to compute the weighted value vectors, again one per key.

  5. Finally, sum the weighted values to get the context vector, which is the output for this query token. Since this is the sum of $\delta$-vectors, it is also a $\delta$-vector.

visual representation of the overall flow described above

Again, all of that work is just for a single query token. We’ll repeat it for each token in the input to produce $n$ context vectors of size $\delta$. This is our attention layer’s output.
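
Before we walk through the specifics, here's a minimal numpy sketch of those five steps for a single query token. The function and variable names are my own, and real implementations vectorize this across all query tokens at once (and add the refinements covered later in this chapter), but the arithmetic is the same:

```python
import numpy as np

def context_for_query(query_emb, all_embs, W_q, b_q, W_k, b_k, W_v, b_v):
    """Compute the context vector for one query token.

    query_emb: (d,)   the query token's input embedding
    all_embs:  (n, d) every input embedding (the key tokens)
    W_*:       (d, delta) weight matrices; b_*: (delta,) biases
    """
    delta = W_q.shape[1]

    # 1. The query token's query vector.
    q = query_emb @ W_q + b_q              # (delta,)

    # 2. Key vectors, then raw attention scores (one scalar per key token).
    K = all_embs @ W_k + b_k               # (n, delta)
    scores = K @ q                         # (n,) dot products

    # 3. Normalize the scores into attention weights (scaled softmax).
    scores = scores / np.sqrt(delta)
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()      # (n,), all in [0, 1], sums to 1

    # 4. Value vectors, each scaled by its attention weight.
    V = all_embs @ W_v + b_v               # (n, delta)
    weighted = weights[:, None] * V        # (n, delta)

    # 5. Sum the weighted values to get the context vector.
    return weighted.sum(axis=0)            # (delta,)
```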

Let’s walk through the specifics.

$W_q$ and $b_q$ → query vector

This just transforms the query token’s embedding by the weight matrix $W_q$, and adds the bias $b_q$.

That’s it!

Why the bias? We can think of the query vector as analogous to a linear function:

$$\text{query vector} = XW_q + b_q \Longleftrightarrow ax + b$$

Just as you’d need the $b$ coefficient to match data to a linear function in standard Cartesian math, you need $b_q$ to match the input to the query vector.

Query vector, $W_k$ and $b_k$ → attention scores

This step happens for each key token (that is, each embedding in the input sequence).

First, we’ll calculate the key vector for each key token. This is a $\delta$-sized vector. Similar to how we calculated the query vector, this is just $\text{key embedding} \cdot W_k + b_k$.

Now we have two $\delta$-sized vectors: the query (from the previous step) and the key. We take their dot product to combine them into a scalar.

We call this dot product the raw attention score for this key.

Attention scores → attention weights

At this point, we have $n$ raw attention scores, each corresponding to a key token.

These scores can be all over the place — positive, negative, and at vastly different scales — so we normalize them into a probability distribution. This distribution is an $n$-vector called the attention weights. Its values are all between 0 and 1, and they sum to 1.

Normalizing the attention scores into attention weights improves the learning process by giving the attention a smooth, differentiable form and keeping the scales of the values stable.

This normalization happens in two steps:

  1. First, we divide each attention score by $\sqrt{\delta}$ (the square root of the output embedding size).

  2. Then, we apply a function called softmax, which takes a vector of scalars and normalizes them to a probability distribution.

I’ll explain these backwards: first softmax, then the scaling.

Softmax is a function that converts a vector of numbers into a probability distribution. You don’t actually need to know its definition; what is important is that it’s sensitive to the scale of its inputs: the larger the scale, the more softmax magnifies differences in the resulting probabilities.

To keep softmax from becoming too extreme, we first divide the attention scores by $\sqrt{\delta}$. This factor comes from statistics. Remember that the raw attention score is a dot product that’s the sum of $\delta$ terms, one per dimension. These terms are roughly independent, so the standard deviation of their sum grows as $\sqrt{\delta}$ (this is standard statistics, which we don’t need to get into the details of here). By dividing by $\sqrt{\delta}$, we keep the typical magnitude of attention scores consistent regardless of $\delta$. This ensures that softmax operates in a reasonable range, and doesn’t get thrown off by large scales.

In other words, as $\delta$ grows, so does the dot products’ spread: their standard deviation grows by a factor of $\sqrt{\delta}$, and left unchecked this would cause softmax to lose nuance between values that are actually fairly close. Dividing by $\sqrt{\delta}$ lets softmax keep that nuance. Note that I wrote above that the terms are roughly independent, but of course they’re not actually independent: the whole point of training is to find patterns in them. Still, the $\sqrt{\delta}$ scaling has empirically been found to work, so that’s what people use.

This “scaling plus softmax” is called, appropriately enough, scaled dot-product attention. When it’s applied to the raw attention scores we calculated earlier, the result is the normalized attention weights.
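
Here's a small numpy illustration of both points, using simulated random vectors rather than real model values: the raw scores' spread grows roughly as $\sqrt{\delta}$, and without the scaling, softmax collapses toward a near-one-hot distribution.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)

for delta in (16, 256, 4096):
    q = rng.standard_normal(delta)          # a random "query" vector
    K = rng.standard_normal((8, delta))     # 8 random "key" vectors
    scores = K @ q                          # raw attention scores

    print(f"delta={delta:5d}  score std ~ {scores.std():6.1f}  (sqrt(delta) = {np.sqrt(delta):.1f})")
    print("  softmax, unscaled:", np.round(softmax(scores), 3))
    print("  softmax, scaled:  ", np.round(softmax(scores / np.sqrt(delta)), 3))
```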

$W_v$ and $b_v$ → weighted value vectors

All of the work until now has been to calculate the attention weights, which are an $n$-sized vector of scalars that answer the first component of attention: “for each input A, how much does it care about input B?” Now we’ll answer the second component: how should input B express its information?

We’ll start with familiar ground, by turning our $d$-sized key embeddings into $\delta$-sized vectors by multiplying them by a weight matrix. This time we’ll use the $W_v$ weight matrix and $b_v$ bias, and the result is a value vector. As with the key vector, we have one such value vector per input token.

From here, we calculate intermediate “weighted values” by multiplying each value vector by its corresponding attention weight. For example, let’s say input 4’s attention weight is 0.27 and its value vector is [6.2, 1.4, 7.9].

In this case, the weighted value vector for input 4 is:

$$\begin{align} & 0.27 \cdot [6.2, 1.4, 7.9] \\ = \; & [(0.27 \cdot 6.2), (0.27 \cdot 1.4), (0.27 \cdot 7.9)] \\ = \; & [1.67, 0.378, 2.13] \end{align}$$

We do this calculation for each key embedding to get $n$ weighted value vectors.

Weighted value vectors → context vector

At this point, we have $n$ weighted values. Each is a $\delta$-sized vector that represents its respective key embedding, scaled by how much the query token cares about it and projected to represent what that key means in the context of the relationship this attention layer is learning.

We simply sum those vectors to get the context vector, which is also a $\delta$-sized vector. This represents the attention layer’s output for this query token.

Repeat to get the full layer output

Recall that all of this happened from the perspective of a single input, which we called the query token. For this token, we:

  1. Calculated a query vector ($\text{query token} \cdot W_q$).

  2. Dot-producted that query vector against every input’s key vector ($\text{key token} \cdot W_k$) to get, for each input, a raw attention score.

  3. Normalized those attention scores into probabilities, one per input; we called these the attention weights.

  4. Applied each of those probabilities to a corresponding value vector ($\text{key token} \cdot W_v$) to get weighted value vectors, one per input.

  5. Summed up those weighted value vectors to get a single context vector.

All of this gives us the $\delta$-dimensional context vector for that one query token.

We then repeat this for each of the $n$ inputs, treating each as the query token in turn. The result is our attention layer’s output: the full attention output matrix, or just attention output for short. This has one context vector for each input, so it’s $n$ vectors of size $\delta$.

attention weights combine with values to form the context vector
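
In practice, nobody loops over query tokens one at a time: the whole thing collapses into a few matrix multiplications that handle all $n$ queries at once. Here's a sketch of a single attention layer in numpy, with my own variable names and without the causal mask introduced in the next section:

```python
import numpy as np

def self_attention(X, W_q, b_q, W_k, b_k, W_v, b_v):
    """Single-head self-attention over the whole input.

    X: (n, d) input embeddings. Returns the (n, delta) attention output.
    """
    delta = W_q.shape[1]

    Q = X @ W_q + b_q                      # (n, delta) query vectors
    K = X @ W_k + b_k                      # (n, delta) key vectors
    V = X @ W_v + b_v                      # (n, delta) value vectors

    scores = Q @ K.T / np.sqrt(delta)      # (n, n) scaled attention scores

    # Row-wise softmax: each row becomes that query token's attention weights.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    return weights @ V                     # (n, delta): one context vector per input
```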

Causal attention mask

In all of the above, we’ve been calculating the full $n \times n$ attention grid, as we saw above:

the full n times n attention grid

The problem is that at inference, we’re going to be predicting one token at a time. This means that when we predict a token, we won’t yet know the tokens after it — and thus can’t know how it should attend to them:

The same n times n grid as before, but with the top-right crossed out to show that we don't know those pairs.

To account for this, we’ll fill the top-right portion of the attention scores with $-\infty$, right before applying softmax. When we apply softmax to each row, the $-\infty$s will turn to 0, and they’ll be disregarded as we normalize the rest of the values.

This triangle of $-\infty$s is called the causal attention mask.
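
Concretely, the mask is just an upper triangle of $-\infty$ added to the $n \times n$ score grid before the row-wise softmax. A minimal numpy sketch, with made-up scores:

```python
import numpy as np

n = 5
scores = np.random.randn(n, n)             # pretend these are the scaled attention scores

# Causal mask: -inf strictly above the diagonal (no attending to later tokens).
mask = np.triu(np.full((n, n), -np.inf), k=1)
masked = scores + mask

# Row-wise softmax: the -inf entries become exactly 0 after exponentiation.
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 2))                # the top-right triangle is all zeros
```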

Real-world improvements

The above covers the fundamental aspects of how self-attention works, but there are several crucial ways that it’s augmented in real-world LLMs. Don’t worry: the hardest part is behind us. Still, it’s important to know about these if you want to understand how real LLMs work.

Multi-head attention and $W_o$

When I wrote above that there’s only one each of $W_q$, $W_k$, and $W_v$, that was a bit of a simplification. Everything I’ve described above — the weight matrices, vectors, etc. — forms a unit called an attention head.

The problem is that a single attention head can get somewhat myopic, focusing primarily on just one aspect of the input tokens. For example, a head may end up focusing just on semantic interactions between tokens, or just on their grammatical relationships. (The actual relationships it learns are more abstract than that, but I’m “translating” the properties it learns into more intuitive relationships).

To solve this, LLMs actually use multiple heads, each with their own $W_q$ / $W_k$ / $W_v$ matrices and biases. Each one of these heads acts independently, finding its own relationship to learn.

In this multi-head arrangement, each head’s output has $\frac{\delta}{h}$ dimensions, where $\delta$ is the attention layer output’s dimensionality (as we’ve been using it all along) and $h$ is the number of heads. For example, if we want the attention output to have 720 dimensions, and we want 12 heads (these are both hyperparameters the model designer picks), each head would have dimensionality 60. This then determines how big each head’s weight matrices are: each will be $d \times \frac{\delta}{h}$.

Each head’s output is $n$ rows of size $\frac{\delta}{h}$, and we can think of these as $n \times \frac{\delta}{h}$ matrices. We then concatenate them to get our desired shape, an $n \times \delta$ matrix.

The only gotcha in this process is that when we scale down the attention scores, we now have to scale them down by $\sqrt{\delta / h}$ instead of $\sqrt{\delta}$. Remember that the reason we did that scaling was to account for the dot products growing as the vectors’ dimension grows; but within each head, that dimension is now $\frac{\delta}{h}$.

You may be thinking that it seems odd to just concatenate matrices that don’t necessarily have much to do with each other, and the borders of which are essentially “jumps” between differently-learned relationships. How would the layers that consume this matrix know how to make sense of them and combine them into a single, coherent input?

To solve that problem, multi-head models introduce one more matrix, $W_o$ (for “output”), along with its bias. This is a $\delta \times \delta$ learned matrix that encodes how to combine all the heads into a single, appropriately blended result.

$$\underbrace{\text{concatenated heads}}_{n \times \delta} \cdot \underbrace{W_o}_{\delta \times \delta} + b_o = \underbrace{\text{layer output}}_{n \times \delta}$$
In multi-head attention, each head produces its own output, and the W_o matrix combines them into a single output for the layer
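
Putting the refinements together, here's a hedged numpy sketch of multi-head attention, using the example numbers from above ($\delta = 720$ split across $h = 12$ heads of dimension 60). The per-head dictionary layout is just one way to organize the weights; real implementations typically fuse them into larger tensors for speed:

```python
import numpy as np

def multi_head_attention(X, heads, W_o, b_o):
    """X: (n, d) input embeddings.
    heads: list of h dicts, each with W_q / W_k / W_v of shape (d, delta // h)
           and biases b_q / b_k / b_v of shape (delta // h,).
    W_o: (delta, delta), b_o: (delta,). Returns the (n, delta) layer output.
    """
    n = X.shape[0]
    causal_mask = np.triu(np.full((n, n), -np.inf), k=1)
    outputs = []

    for head in heads:
        head_dim = head["W_q"].shape[1]            # delta / h, e.g. 720 / 12 = 60
        Q = X @ head["W_q"] + head["b_q"]
        K = X @ head["W_k"] + head["b_k"]
        V = X @ head["W_v"] + head["b_v"]

        scores = Q @ K.T / np.sqrt(head_dim)       # scale by the head's own dimension
        scores = scores + causal_mask
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        outputs.append(weights @ V)                # (n, delta / h)

    concat = np.concatenate(outputs, axis=-1)      # (n, delta): the concatenated heads
    return concat @ W_o + b_o                      # (n, delta): blended layer output
```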

Multiple layers

In all of the above, we’ve been talking about “the” self-attention layer, as if there’s only one. In practice, an LLM will have many attention layers.

In the next section, I’ll describe the LLM’s feedforward network, which makes inferences about the attention output matrix we’ve been developing in this chapter. The attention layer and feedforward network together form a transformer block. Modern LLMs stack several of these blocks together, with each block’s output feeding into the next’s attention.

I’ll describe this in more detail in Putting it all together. For now, just know that the description of “the” attention feeding into “the” feedforward network is a simplification.

RoPE

As I mentioned in the previous chapter, modern LLMs don’t add positional encoding to the input embeddings. Instead, they use something called RoPE, which gets applied in the attention layer.

For now, I’ll just mention that this exists. I’ll describe it more in Beyond the toy LLM.

Dropout

There’s one more aspect of attention, which is something called dropout.

This isn’t part of attention’s fundamental architecture: it’s only applied during training. As such, I’m going to put it off until the later chapters on training.

If you haven’t heard about dropout yet, you can forget I mentioned it. I only bring it up in case you’re also using another resource (another book, or asking an LLM questions) and it mentions dropout. It’s often taught as part of attention, but I think it’s best held off until we give training its full treatment.

“The context is full”

If you’ve used LLMs, you may have heard about “the context” as an almost mythical thing to be kept safe. The context can’t get too full; you can’t let it get too confused with bad prompts or intermediate results; some parts of it belong to the tooling and some belong to you.

If you read about “context vectors” above and wondered if these are related: good news, they are! In fact, you now have enough to build a solid understanding of what this all-important context is.

In short, “the context is full” means that the input is as long as the LLM will allow. This is primarily driven by two factors:

  1. The attention weights form an $n \times n$ grid, so the compute and memory needed grow with the square of the input length.

  2. The model is only trained on sequences up to a certain length, so it never learns how to handle inputs longer than that.

LLM designers must balance the cost of training on long sequences against the usefulness of longer context windows.

Note that the learned weight matrices ($W_q$, $W_k$, $W_v$, and $W_o$) are not what limits context length. These matrices have fixed sizes based purely on the model’s hyperparameters ($d$ and $\delta$), not on input length.

Next up

As I’ve mentioned already, attention is the first part of an LLM’s transformer block. In the next chapters, I’ll explain the second part of the transformer block — the feedforward network — and then how all the pieces fit together to form a full LLM.