
I mentioned way back in the introduction that I find it useful to think about LLMs first in terms of the fundamental concepts, and then in terms of the algebraic reformulations of those concepts. Until now, I’ve been focusing exclusively on the conceptual layers. In this chapter, I’ll describe how those get bundled into mathematical objects that are more efficient to compute.

There are two major parts to this:

The architecture’s conceptual shape

Before we dive into the algebraic reformulations, let’s take a look at the LLM’s architecture once more, this time focusing on the shapes of the learned parameters and activations. I’ll skip the tokenization phase, since that’s effectively a preparation step that happens before the LLM itself runs.

For most of the LLM, the activations are in the form of $n$ vectors, each of size $d$. The final output is still $n$ vectors, but each sized $v$ (the vocabulary size).

An overview of the LLM architecture, showing n vectors of size d for most of the flow, and a final output of n vectors of size v

Vectors of vectors → matrices

The basic “lifting” we’ll do is to turn vectors of vectors into matrices. This will let us turn the various “for each outer vector, do some stuff” loops that we’ve been working with into matrix multiplication (I’ll describe each of these in detail below). This doesn’t change what’s going on conceptually, but it lets us do the math on GPUs, which process it much more quickly.

All we need to do is turn each “outer” vector into a row in a matrix:

$$
\underbrace{ \begin{array}{llll} [\; 1.32 \,, & 5.91 \,, & 5.71 \,, & \dots \;] \\[0.15em] [\; 6.16 \,, & 4.81 \,, & 3.62 \,, & \dots \;] \\[0.15em] [\; 8.27 \,, & 9.53 \,, & 2.44 \,, & \dots \;] \\[0.15em] [\; \dots\,, & \dots\,, & \dots\,, & \dots \;] \\[0.50em] \end{array} }_{\vphantom{\big|}n \text{ vectors of size } d} \quad \Longrightarrow \quad \underbrace{ \begin{bmatrix} 1.32 & 5.91 & 5.71 & \dots \\ 6.16 & 4.81 & 3.62 & \dots \\ 8.27 & 9.53 & 2.44 & \dots \\ \vdots & \vdots & \vdots & \ddots \end{bmatrix} \rule[-2.75em]{0pt}{0pt} }_{\vphantom{M}n \times d \text{ matrix}}
$$

Let’s work through what that means for the calculations I’ve described in the previous chapters.

Calculating attention

Recall that we calculated attention by doing a nested loop over the input embeddings:

  1. For each input embedding $t_q$ (there are $n$ of them):

    1. Calculate the query vector $q = t_q \times W_q$. This vector has size $d$.

    2. For each input embedding $t_k$ (the same $n$ embeddings as for the query vector), calculate the attention score of $q$ against $t_k$:

      1. Calculate the key vector $k = t_k \times W_k$. This vector has size $d$.

      2. Calculate the dot product $q \cdot k$ to get the attention score (a scalar).

    3. Treat those $n$ attention scores as a vector; scale and softmax that vector to get the attention weight vector (size $n$).

    4. Calculate value vectors:

      1. For every input embedding $t_v$, calculate a value vector $v = t_v \times W_v$. There are $n$ such vectors, each of size $d$.

      2. Multiply each value vector by the corresponding attention weight (the $n$ scalars from the previous step). The result is still $n$ vectors, each of size $d$.

    5. Sum the value vectors to get the context vector. This vector has size $d$.

There are $n$ inputs (that is, $n$ iterations of the $t_q$ loop), so we ended up with $n$ context vectors, each of size $d$.
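
To make the loop concrete, here’s a minimal sketch of it in NumPy. The function and variable names are my own, and I’ve left out the causal mask — this is just the steps listed above, not a production implementation:

```python
import numpy as np

def loop_attention(T, W_q, W_k, W_v):
    """Naive nested-loop attention. T is (n, d); W_q, W_k, W_v are (d, d)."""
    n, d = T.shape
    context = np.zeros((n, d))
    for i in range(n):                      # step 1: for each input embedding t_q
        q = T[i] @ W_q                      # step 1.1: query vector, size d
        scores = np.zeros(n)
        for j in range(n):                  # step 1.2: score q against every t_k
            k = T[j] @ W_k                  # step 1.2.1: key vector, size d
            scores[j] = q @ k               # step 1.2.2: dot product -> scalar
        scores /= np.sqrt(d)                # step 1.3: scale...
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()            # ...and softmax -> attention weights
        for j in range(n):                  # step 1.4: weight the value vectors
            v = T[j] @ W_v
            context[i] += weights[j] * v    # step 1.5: sum into the context vector
    return context                          # n context vectors, each of size d
```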

Let’s see how much of this we can turn into matrix math. (Spoiler alert: almost all of it.) Instead of a nested loop that generates $n$ vectors of size $d$, we’ll use matrix math to generate an $n \times d$ matrix.

Calculating the query matrix

I’ll start with step 1.1 above. We’ll focus on one iteration of the loop — call it $i$ — and calculate the query vector $q_i$ for input $t_i$. (Remember that $t_i$ is a $d$-sized embedding vector.)

$$
\begin{align}
q_i & = t_i \times W_q \\
& = \underbrace{\begin{bmatrix} t_{i,1} & t_{i,2} & \dots \end{bmatrix}}_{1 \times d} \cdot \underbrace{W_q}_{d \times d} \\
& = \begin{bmatrix} t_{i,1} & t_{i,2} & \dots \end{bmatrix} \cdot \begin{bmatrix} w_{1,1} & w_{1,2} & \dots \\ w_{2,1} & w_{2,2} & \dots \\ \vdots & \vdots & \ddots \end{bmatrix} \\
& = \begin{bmatrix} \left( \begin{bmatrix} t_{i,1} & t_{i,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,1} \\ w_{2,1} \\ \vdots \end{bmatrix} \right) & \left( \begin{bmatrix} t_{i,1} & t_{i,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,2} \\ w_{2,2} \\ \vdots \end{bmatrix} \right) & \dots \end{bmatrix}
\end{align}
$$

If we do this for each embedding, we get a matrix that we’ll call $Q$. This matrix represents step 1.1 executed across all of the top-level iterations:

$$
\left. \begin{array}{l} \text{1. For each input embedding } t_q \\ \quad \text{1. Calculate the query vector } q = t_q \times W_q \end{array} \right\} Q_{(n \times d)}
$$

Let’s put each of those iterations into a row of a matrix:

$$
\begin{align}
Q & = \left. \begin{bmatrix} t_1 \times W_q \\ t_2 \times W_q \\ \vdots \end{bmatrix} \right\} n \text{ rows} \\[2.5em]
& = \underbrace{ \begin{bmatrix} \left( \begin{bmatrix} t_{1,1} & t_{1,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,1} \\ w_{2,1} \\ \vdots \end{bmatrix} \right) & \left( \begin{bmatrix} t_{1,1} & t_{1,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,2} \\ w_{2,2} \\ \vdots \end{bmatrix} \right) & \dots \\[2.5em] \left( \begin{bmatrix} t_{2,1} & t_{2,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,1} \\ w_{2,1} \\ \vdots \end{bmatrix} \right) & \left( \begin{bmatrix} t_{2,1} & t_{2,2} & \dots \end{bmatrix} \begin{bmatrix} w_{1,2} \\ w_{2,2} \\ \vdots \end{bmatrix} \right) & \dots \\ \vdots & \vdots & \ddots \end{bmatrix} \rule[-5.25em]{0pt}{0pt} }_{d \text{ elements}}
\end{align}
$$

This looks like matrix multiplication — and it is! Specifically, the $t_{i,j}$ elements make up a $T$ matrix whose rows are the $n$ inputs and whose columns are each input’s $d$ embedding dimensions; and the $w_{k,l}$ elements represent the $d \times d$ weight matrix.

This means we can calculate QQ with just one matrix multiplication:

$$
Q = \underbrace{TW_q}_{n \times d}
$$

This is really powerful! It means the first part of the nested loop (steps 1 → 1.1) can be reduced to a single matrix multiplication, which GPUs are extremely efficient at processing. We’ll be doing similar things for the key and value vectors, so I’d suggest taking the time to work through the above and make sure it makes sense to you.
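
If you’d like to convince yourself of the equivalence, here’s a small NumPy check (a sketch with my own variable names and random stand-in values):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 4
T = rng.standard_normal((n, d))      # n input embeddings, stacked as rows
W_q = rng.standard_normal((d, d))    # learned query weights

# One matrix multiplication...
Q = T @ W_q                          # shape (n, d)

# ...gives exactly the same rows as the per-token loop from step 1.1.
Q_loop = np.stack([T[i] @ W_q for i in range(n)])
assert np.allclose(Q, Q_loop)
```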

Calculating the attention scores matrix

Now we can move on to the raw attention scores. This corresponds to step 1.2 above.

First, let’s calculate the key matrix $K$. This is exactly the same as the query matrix $Q$, except that it uses $W_k$ instead of $W_q$. Because the progression from vectors-of-vectors to matrix is the same, I won’t spell it out in full.

$$
K = \underbrace{TW_k}_{n \times d}
$$

Next, we’ll calculate all the attention scores as a matrix. Each row will correspond to a query token, and each column will be the attention score between that query token and the corresponding key token:

$$
\begin{align}
\text{attention scores} & = \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \dots \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \dots \\ \vdots & \vdots & \ddots \end{bmatrix} \\
& = \begin{bmatrix} Q_1 \cdot \text{(key vector 1)} & Q_1 \cdot \text{(key vector 2)} & \dots \\ Q_2 \cdot \text{(key vector 1)} & Q_2 \cdot \text{(key vector 2)} & \dots \\ \vdots & \vdots & \ddots \end{bmatrix}
\end{align}
$$

Once again, this looks like matrix multiplication! The one problem is the key vectors. In that matrix multiplication for the attention scores, each *(key vector $i$)* needs to be a $d$-sized vector corresponding to a horizontal row within $K$. But if we calculated this matrix as $\text{attention scores} = QK$, then the thing that should be $d$-sized key vectors would instead be the $n$-sized vertical slices of $K$:

$$
= \begin{bmatrix} \left( Q_1 \begin{bmatrix} K_{1,1} \\ K_{2,1} \\ \vdots \end{bmatrix} \right) & \left( Q_1 \begin{bmatrix} K_{1,2} \\ K_{2,2} \\ \vdots \end{bmatrix} \right) & \dots \\[2.5em] \left( Q_2 \begin{bmatrix} K_{1,1} \\ K_{2,1} \\ \vdots \end{bmatrix} \right) & \left( Q_2 \begin{bmatrix} K_{1,2} \\ K_{2,2} \\ \vdots \end{bmatrix} \right) & \dots \\ \vdots & \vdots & \ddots \end{bmatrix}
$$

Not only is this not what we want, but the math isn’t even defined: we’d be taking the dot products of $d$-sized $Q_i$ vectors and $n$-sized $K_{\star,j}$ vectors.

What we need is to replace the vertical slicing of $K$ with horizontal slicing. To do that, we just need to transpose $K$. This turns its rows into columns — meaning that when we take vertical slices of the transposed $K^T$ matrix during multiplication, what we actually get are the rows of $K$.

Now we can just multiply $Q$ by $K^T$. For example, the first cell in this matrix would be:

$$
\begin{align}
\text{attention scores}_{(1,1)} & = Q_1 \begin{bmatrix} {K^T}_{1,1} \\ {K^T}_{2,1} \\ \vdots \end{bmatrix} \\[2.5em]
& = Q_1 \begin{bmatrix} K_{1,1} \\ K_{1,2} \\ \vdots \end{bmatrix}
\end{align}
$$

Now the math works out: we’re multiplying $Q_{n \times d}$ by ${K^T}_{d \times n}$ to get an $n \times n$ matrix, the raw attention scores:

$$
\text{attention scores} = QK^T
$$

Just to belabor the point: we’ve turned all of the nested looping in steps 1 → (1.1 - 1.2) into just a few matrix operations:

  1. $Q = TW_q$

  2. $K = TW_k$

  3. Transpose $K$ (this doesn’t even require moving any memory: it’s just a bit of metadata that tells the computer to treat index $i,j$ as $j,i$)

  4. $\text{attention scores} = QK^T$
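
In NumPy, those four operations are just a couple of lines (a self-contained sketch with random stand-in values; the names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 4
T = rng.standard_normal((n, d))     # input embeddings as rows
W_q = rng.standard_normal((d, d))
W_k = rng.standard_normal((d, d))

Q = T @ W_q                 # step 1: (n, d)
K = T @ W_k                 # step 2: (n, d)
scores = Q @ K.T            # steps 3-4: K.T is just a view, no data is moved
assert scores.shape == (n, n)
```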

Causal attention, scale, and softmax

Next, we just need to apply the causal mask, scale each element in the attention scores by dividing it by $\sqrt{d}$, and then apply softmax. This corresponds to step 1.3 above.

$$
\text{attention weights} = A = \text{softmax}\left( \frac{QK^T}{\sqrt{d}} \right)
$$

Note:

None of these operations change the dimensions of the matrix, so it’s still $n \times n$.
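
As a sketch (NumPy, with my own helper name), the mask, scale, and softmax might look like this:

```python
import numpy as np

def causal_scaled_softmax(scores, d):
    """scores: (n, n) raw attention scores; returns the (n, n) weight matrix A."""
    n = scores.shape[0]
    # Causal mask: row i may only attend to columns 0..i, so set future
    # positions to -inf before the softmax (they'll become zero weights).
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    masked = np.where(mask, -np.inf, scores / np.sqrt(d))
    # Row-wise softmax; subtracting each row's max keeps the exponentials stable.
    exp = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)
```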

Context matrix

Finally, we’ll apply our weights against the value vectors, and sum the results. This corresponds to step 1 → (1.4, 1.5) above.

First, we’ll get the value matrix $V$, similar to the above. This is step 1.4.1.

$$
V = \underbrace{TW_v}_{n \times d}
$$

Each row in this matrix is one value vector.

Before we go further, let’s step back and compute just a single context vector (that is, just a single token’s attention) using the matrices we’ve computed so far.

Just to recap, here’s what we need to do:

  1. For each input $t_q$:

    ...

    4. Calculate value vectors:

      1. For every input embedding $t_v$, calculate a value vector $v = t_v \times W_v$. There are $n$ such vectors, each of size $d$.

      2. Multiply each value vector by the corresponding attention weight (the $n$ scalars from the previous step). The result is still $n$ vectors, each of size $d$.

    5. Sum the value vectors to get the context vector. This vector has size $d$.

This means that within the context of a single query token $Q_i$, we need to:

- take row $i$ of the attention weight matrix $A$ (that row has $n$ weights, one per key token);
- multiply each value vector — that is, each row of $V$ — by the corresponding weight; and
- sum those weighted vectors into a single $d$-sized vector.

We’ll call this vector $C_i$, the context vector for the $i$-th query token. Let’s see what all the $C_i$s look like stacked as rows of a matrix:

$$
\begin{align}
\begin{bmatrix} C_1 \\ C_2 \\ \vdots \end{bmatrix} & = \begin{bmatrix} (A_{1,1}V_{1,1} + A_{1,2}V_{2,1} + \cdots) & (A_{1,1}V_{1,2} + A_{1,2}V_{2,2} + \cdots) & \cdots \\ (A_{2,1}V_{1,1} + A_{2,2}V_{2,1} + \cdots) & (A_{2,1}V_{1,2} + A_{2,2}V_{2,2} + \cdots) & \cdots \\ \vdots & \vdots & \ddots \end{bmatrix}
\end{align}
$$

Each cell $i,j$ is a sum of terms: each of row $A_i$’s columns multiplied by column $V_{\star,j}$’s rows. In other words, each cell is a dot product:

$$
= \begin{bmatrix} \begin{bmatrix} A_{1,1} & A_{1,2} & \cdots \end{bmatrix} \begin{bmatrix} V_{1,1} \\ V_{2,1} \\ \vdots \end{bmatrix} & \begin{bmatrix} A_{1,1} & A_{1,2} & \cdots \end{bmatrix} \begin{bmatrix} V_{1,2} \\ V_{2,2} \\ \vdots \end{bmatrix} & \cdots \\[2.5em] \begin{bmatrix} A_{2,1} & A_{2,2} & \cdots \end{bmatrix} \begin{bmatrix} V_{1,1} \\ V_{2,1} \\ \vdots \end{bmatrix} & \begin{bmatrix} A_{2,1} & A_{2,2} & \cdots \end{bmatrix} \begin{bmatrix} V_{1,2} \\ V_{2,2} \\ \vdots \end{bmatrix} & \cdots \\[2.5em] \vdots & \vdots & \ddots \end{bmatrix}
$$

This may look familiar: it’s just the matrix multiplication $AV$.
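
A tiny NumPy check of that claim (my own names and random stand-in values — the rows of $A$ are normalized only so they look like softmax output):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 4
A = rng.random((n, n))
A /= A.sum(axis=-1, keepdims=True)   # attention weights: each row sums to 1
V = rng.standard_normal((n, d))      # value vectors as rows

C = A @ V                            # (n, d): each row is one context vector
C_loop = np.stack([sum(A[i, j] * V[j] for j in range(n)) for i in range(n)])
assert np.allclose(C, C_loop)        # same result as the weight-and-sum loop
```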

The full attention calculation

So, attention is $AV$. If we substitute $A$ with the expression from Causal attention, scale, and softmax above, we get:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d}} \right)V
$$

This is the canonical representation of attention, and is somewhat famous within the literature of LLMs.

This means we’ve now turned all of the attention calculation — logically a multiply-nested loop — into a few matrix multiplications and a bit of parallelizable manipulation:

  1. $Q = TW_q + b_q$

  2. $K = TW_k + b_k$

  3. $V = TW_v + b_v$

  4. $\text{attention scores} = QK^T$

  5. divide these by $\sqrt{d}$

  6. apply softmax to each row to get $A$, the attention weight matrix

  7. $\text{Attention} = AV$

A GPU is going to eat this for breakfast!
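
Here’s the whole list as a single NumPy sketch — a minimal single-head version, with the causal mask from the earlier step folded in. The function and variable names are mine:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(T, W_q, W_k, W_v, b_q, b_k, b_v, causal=True):
    """T: (n, d) input embeddings. Weights are (d, d); biases are (d,)."""
    n, d = T.shape
    Q = T @ W_q + b_q                        # 1
    K = T @ W_k + b_k                        # 2
    V = T @ W_v + b_v                        # 3
    scores = Q @ K.T                         # 4 (K.T is just a transposed view)
    scores = scores / np.sqrt(d)             # 5
    if causal:                               # mask out future positions
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    A = softmax(scores, axis=-1)             # 6
    return A @ V                             # 7: the (n, d) context matrix
```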

Multi-head attention

Back in the chapter on attention, I talked about how LLMs use multiple heads within a single attention layer, each learning a different relationship. The attention layer concatenates these heads, and then uses a final projection $W_o$ to combine them.

Described as such, this would require looping over each of the heads to perform the attention function we just saw. It may not surprise you that this can be done without looping, using tensor math.

First, let’s refresh the multi-head ideas:

- The attention layer has $h$ heads.
- Each head has its own, smaller $W_q$, $W_k$, and $W_v$ weight matrices, each sized $d \times \frac{d}{h}$.
- Each head computes its own attention output, sized $n \times \frac{d}{h}$.
- The heads’ outputs are concatenated and combined with a final $d \times d$ projection, $W_o$.

To illustrate everything, I’ll pick $n = 3$, $d = 4$, and $h = 2$.

First, for each of $W_q$, $W_k$, and $W_v$, we’ll concatenate the heads’ weights to create a single $d \times d$ matrix. For example:

$$
W_q = \underbrace{ \begin{bmatrix} a & b \\ e & f \\ i & j \\ m & n \end{bmatrix} }_{\text{head 1}} \underbrace{ \begin{bmatrix} c & d \\ g & h \\ k & l \\ o & p \end{bmatrix} }_{\text{head 2}} \rightarrow \begin{bmatrix} a & b & c & d \\ e & f & g & h \\ i & j & k & l \\ m & n & o & p \end{bmatrix}
$$

(Remember, each head is $d \times \frac{d}{h}$ — so in our example, $4 \times \frac{4}{2}$, a.k.a. $4 \times 2$.) Note that this doesn’t happen at runtime, during inference: these are learned parameters, so we can lay them out this way as we build the model.

At inference, we’ll multiply the input by these matrices, just as we did in the single-head description above. So for example:

$$
Q = \underbrace{TW_q}_{(n \times d)\,(d \times d)} = \underbrace{ \begin{bmatrix} \alpha & \beta & \gamma & \delta \\ \epsilon & \zeta & \eta & \theta \\ \iota & \kappa & \lambda & \mu \end{bmatrix} }_{n \times d}
$$

Remember that in matrix multiplication, each cell in the result combines the corresponding row from the left matrix (the input, for us) and the column from the right matrix (the weights). Since our weights were split up column-wise by heads, the corresponding matrix products are, too:

$$
Q = \left[\begin{array}{cc|cc} \alpha & \beta & \gamma & \delta \\ \epsilon & \zeta & \eta & \theta \\ \iota & \kappa & \lambda & \mu \end{array}\right]
$$

At this point, we need to do actual looping — not just clever matrix math. For each of the heads, we’ll compute attention just as we did above:

$$
\text{Head Attention}(Q_h, K_h, V_h) = \text{softmax}\left( \frac{Q_h {K_h}^T}{\sqrt{d/h}} \right) V_h
$$

(Remember that the scaling factor is now $\sqrt{d/h}$ rather than $\sqrt{d}$, to account for each head’s smaller embedding dimension!) Let’s look at the shape of this head attention. We can disregard softmax and $\sqrt{d/h}$ (they don’t change the shape of vectors or matrices), in which case we get:

$$
\begin{align}
\text{Head Attention}(Q_h, K_h, V_h) & = \sout{\text{softmax}}\left( \frac{Q_h{K_h}^T}{\sout{\sqrt{d/h}}} \right)V_h \\
& = ( Q_h{K_h}^T )V_h \\
& = \left[ \left(n \times \frac{d}{h}\right) \left(n \times \frac{d}{h} \right)^T \right] \left(n \times \frac{d}{h} \right) \\
& = \left[ \left(n \times \frac{d}{h} \right) \left(\frac{d}{h} \times n \right) \right] \left(n \times \frac{d}{h} \right) \\
& = \left(n \times n \right) \left(n \times \frac{d}{h} \right) \\
& = n \times \frac{d}{h}
\end{align}
$$

So with all that, we now have $h$ head attentions, each sized $n \times \frac{d}{h}$. Now we reverse the process that we took with the weights: we take these head attentions, concatenate them by columns, and treat them as a single attention output:

$$
\text{Attention} = \left[\begin{array}{cc|cc} \nu & \xi & \omicron & \pi \\ \rho & \sigma & \tau & \upsilon \\ \phi & \chi & \psi & \omega \end{array}\right]
$$

In this figure, each “side” of the attention represents the output from one head, sized $n \times \frac{d}{h}$. The concatenated heads form a single $n \times d$ matrix.

If you recall, the last step in the multi-head process was to multiply the output by a $W_o$ matrix. This is just a $d \times d$ matrix, so there’s nothing special to do here: we just apply the matrix multiplication.

Combined QKV matrix

In the above (and back in our original chapter on attention), we treated the $W_q$, $W_k$, and $W_v$ weight matrices as three separate matrices. To calculate the query, key, and value matrices, we did:

  1. $Q = TW_q + b_q$

  2. $K = TW_k + b_k$

  3. $V = TW_v + b_v$

In practice, these are usually concatenated into one matrix, $W_{qkv}$:

$$
W_{qkv} = \begin{bmatrix} q & q & q & \cdots & k & k & k & \cdots & v & v & v \\ q & q & q & \cdots & k & k & k & \cdots & v & v & v \\ q & q & q & \cdots & k & k & k & \cdots & v & v & v \end{bmatrix}
$$

(Note that for brevity, I’m being a bit informal in my notation here: in particular, I’m writing the various $q_{i,j}$ values as just $q$, and similarly for $k$ and $v$.) We apply matrix multiplication and addition to this:

$$
TW_{qkv} = \begin{bmatrix} T_1q + b_q & \cdots & T_1k + b_k & \cdots & T_1v + b_v \\ T_2q + b_q & \cdots & T_2k + b_k & \cdots & T_2v + b_v \\ T_3q + b_q & \cdots & T_3k + b_k & \cdots & T_3v + b_v \end{bmatrix}
$$

... and then just split the matrix into three slices:

$$
Q, K, V = \begin{bmatrix} T_1q + b_q & \cdots \\ T_2q + b_q & \cdots \\ T_3q + b_q & \cdots \end{bmatrix} , \begin{bmatrix} T_1k + b_k & \cdots \\ T_2k + b_k & \cdots \\ T_3k + b_k & \cdots \end{bmatrix} , \begin{bmatrix} T_1v + b_v & \cdots \\ T_2v + b_v & \cdots \\ T_3v + b_v & \cdots \end{bmatrix}
$$

This lets us do all three matrix multiplications (for $Q$, $K$, and $V$) in a single operation. GPUs have some fixed overhead in any given matrix multiplication, so this optimization just amortizes that overhead across all three matrices.
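
A NumPy sketch of the combined multiplication and split (my own names and random stand-in values; `np.hsplit` just slices the columns into three equal chunks):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 4
T = rng.standard_normal((n, d))

# The three weight matrices laid out side by side: (d, 3d). Same for the biases.
W_qkv = rng.standard_normal((d, 3 * d))
b_qkv = rng.standard_normal(3 * d)

qkv = T @ W_qkv + b_qkv          # one (n, 3d) multiplication instead of three
Q, K, V = np.hsplit(qkv, 3)      # slice the columns back into three (n, d) matrices
```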

KV caching

In all of the above, we’ve been calculating the full $n \times d$ attention for an input of $n$ tokens. When we process the user’s initial prompt, this is great. But as we generate tokens, we can calculate attention incrementally.

For example, let’s say the user entered The quick brown fox. This prompt is 4 tokens, so our attention is $4 \times d$. We generate the next token, jumps, and then loop back for another round of inference. The naive, full-attention calculation I’ve been describing so far would require a $5 \times d$ attention, for The quick brown fox jumps. We can short-circuit much of this calculation.

Let’s take a quick review of everything we did above. To make things concrete, I’ll pick $d = 2$, and we’ll look at a 3-token input ($n = 3$).

First, we calculate $Q$, $K$, and $V$. These are all $n \times d$ matrices (the products of the $n \times d$ input and the $d \times d$ weight matrices):

$$
\begin{bmatrix} Q_{1,1} & Q_{1,2} \\ Q_{2,1} & Q_{2,2} \\ Q_{3,1} & Q_{3,2} \end{bmatrix} \quad \begin{bmatrix} K_{1,1} & K_{1,2} \\ K_{2,1} & K_{2,2} \\ K_{3,1} & K_{3,2} \end{bmatrix} \quad \begin{bmatrix} V_{1,1} & V_{1,2} \\ V_{2,1} & V_{2,2} \\ V_{3,1} & V_{3,2} \end{bmatrix}
$$

Next, we’ll calculate $A' = QK^T$, which is an $n \times n$ matrix:

$$
\begin{align}
A' = QK^T & = \begin{bmatrix} Q_{1,1} & Q_{1,2} \\ Q_{2,1} & Q_{2,2} \\ Q_{3,1} & Q_{3,2} \end{bmatrix} \begin{bmatrix} K_{1,1} & K_{2,1} & K_{3,1} \\ K_{1,2} & K_{2,2} & K_{3,2} \end{bmatrix} \\[1.5em]
& = \begin{bmatrix} Q_{1,\star} K_{1,\star} & Q_{1,\star} K_{2,\star} & Q_{1,\star} K_{3,\star} \\ Q_{2,\star} K_{1,\star} & Q_{2,\star} K_{2,\star} & Q_{2,\star} K_{3,\star} \\ Q_{3,\star} K_{1,\star} & Q_{3,\star} K_{2,\star} & Q_{3,\star} K_{3,\star} \end{bmatrix}
\end{align}
$$

(Note that I’m using informal, nonstandard notation here: $M_{1,\star}$ to represent row 1 of $M$, and $M_{\star,1}$ to represent column 1. Also, for simplicity, I’m omitting causal attention, scaling, and softmax — we don’t need them right now. That means $A'$ isn’t quite the $A$ we’ve been using above.)

Finally, we calculate $A'V$, which is $n \times d$:

$$
A'V = \begin{bmatrix} A'_{1,\star}V_{\star,1} & A'_{1,\star}V_{\star,2} \\[0.5em] A'_{2,\star}V_{\star,1} & A'_{2,\star}V_{\star,2} \\[0.5em] A'_{3,\star}V_{\star,1} & A'_{3,\star}V_{\star,2} \end{bmatrix}
$$

Remember that when our round of inference is done, we’re only going to use the last logit, which will be derived just from the last row of this attention (after passing it through various FFNs and other transformer blocks). So, let’s focus on the last row of $A'V$.

$$
(A'V)_3 = \begin{bmatrix} A'_{3,\star}V_{\star,1} & A'_{3,\star}V_{\star,2} \end{bmatrix}
$$

As a reminder, $A'_3$ is:

$$
\begin{bmatrix} Q_{3,\star} K_{1,\star} & Q_{3,\star} K_{2,\star} & Q_{3,\star} K_{3,\star} \end{bmatrix}
$$

This means that $(A'V)_3$ contains:

- the last row of $Q$ (that is, $Q_{3,\star}$) — and no other rows of $Q$;
- every row of $K$; and
- every row of $V$.

So far, this is just a rehash of everything we’ve already seen. Here’s where it gets interesting! We can make some observations:

- The rows of $K$ and $V$ depend only on the input embeddings and the (fixed) learned weights, so the rows we computed in previous rounds of inference haven’t changed — we can cache and reuse them, and only the newest token adds a new row to each.
- The only part of $Q$ we need is its last row, since that’s the only row that feeds into the last row of attention.

We do still need to build the full $K$ and $V$ matrices; we just don’t need to compute most of them, since all but the last row are cached. For $Q$, we don’t even need to build the full matrix.

Also, since we’re now only computing the last row of attention, we don’t need to account for causal attention. Remember that the attention mask only zeroed out weights for rows before the last row; the last row is unaffected, and that’s the only one we’re generating.

Putting it all together, we have essentially the same attention formula as before, but tweaked to only generate the last row:

$$
\text{Attention}(Q_n, K, V) = \text{softmax}\left( \frac{Q_n K^T}{\sqrt{d/h}} \right)V
$$

So:

- We compute the query, key, and value rows for just the new token.
- We append the new key and value rows to the cached $K$ and $V$ from previous rounds.
- We apply the attention formula above to that single query row, the full $K$, and the full $V$, producing a single $1 \times d$ row of attention.

And there we have it! We’ve calculated just the last row of attention, which will then snake through the FFN and other transformer blocks to produce a single row of logits — the next-token prediction.
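
Here’s a minimal sketch of that incremental step in NumPy. The cache is just two growing arrays; the names are mine, and I’m using plain $\sqrt{d}$ scaling as in the single-head description:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend_next_token(t_new, W_q, W_k, W_v, K_cache, V_cache):
    """t_new: (d,) embedding of the newly generated token.
    K_cache, V_cache: (n-1, d) rows computed (and cached) in earlier rounds."""
    q_new = t_new @ W_q                       # only the last row of Q is needed
    k_new = t_new @ W_k                       # the one new row of K...
    v_new = t_new @ W_v                       # ...and of V
    K = np.vstack([K_cache, k_new])           # full K, mostly from the cache
    V = np.vstack([V_cache, v_new])
    d = t_new.shape[0]
    weights = softmax((q_new @ K.T) / np.sqrt(d))  # (n,) — no causal mask needed
    context = weights @ V                          # (d,) — the last row of attention
    return context, K, V                           # K and V become next round's cache
```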

Implementation details

To do all of the matrix concatenations efficiently, we need to get into the nitty-gritty of standard matrix libraries in software. This book doesn’t cover any particular library, but they’ll all work pretty similarly.

We’ll pick up where we split the $Q$ matrix by head:

$$
Q = \left[\begin{array}{cc|cc} \alpha & \beta & \gamma & \delta \\ \epsilon & \zeta & \eta & \theta \\ \iota & \kappa & \lambda & \mu \end{array}\right]
$$

Tensor libraries will let you reinterpret one tensor as another, differently-shaped one. In our case, we’re going to reinterpret the $n \times d$ matrix (a rank-2 tensor) as an $n \times h \times \frac{d}{h}$ tensor (rank 3), which splits it up just as we’ve just visualized.

When tensor libraries perform matrix operations on higher-order tensors (rank 3 or above), they treat the leftmost dimensions as “batching dimensions” — basically, dimensions to loop over. They can offload this batching to GPUs, where it happens very efficiently.

Unfortunately, this approach doesn’t quite work for our $n \times h \times \frac{d}{h}$ tensor: we want to loop over each of the $h$ heads, not each of the $n$ rows. To solve this, we’ll first transpose the tensor to $h \times n \times \frac{d}{h}$. This doesn’t change its layout at all: it just changes how the library indexes into the tensor, and thus how it batches.

To summarize, we’ve done three things:

- multiplied the input by the concatenated weight matrix to get a single $n \times d$ result whose columns are grouped by head;
- reinterpreted that matrix as an $n \times h \times \frac{d}{h}$ tensor (rank 3); and
- transposed it to $h \times n \times \frac{d}{h}$, so that the heads become the batching dimension.

Now we just apply the attention formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d/h}} \right)V
$$

This time, $Q$, $K$, and $V$ are each those rank-3 tensors, with $h$ as the batch dimension. When they’re multiplied together, the libraries will match each batch up. For example, to multiply $QK^T$:

- For each head index (the batch dimension), the library takes that head’s $n \times \frac{d}{h}$ slice of $Q$ and the matching slice of $K$.
- It multiplies the $Q$ slice by the transposed $K$ slice, giving an $n \times n$ matrix for that head.
- The $h$ results are stacked into an $h \times n \times n$ tensor.

When you apply the $\text{softmax}$ function, you’ll explicitly tell the library which dimension to apply it against (in our case, the columns — that is, the last dimension). The division by $\sqrt{d/h}$ is just a scalar operation, so it doesn’t need any dimension or batch handling; it applies independently to each of the values.

The result of all that is an attention tensor, which is $h \times n \times \frac{d}{h}$. Now we just reverse the reshaping: we transpose this to $n \times h \times \frac{d}{h}$ and then reinterpret it as a rank-2, $n \times d$ matrix.
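
Here’s a NumPy sketch of that whole reshape-batch-reshape dance. It assumes $Q$, $K$, and $V$ have already been computed as $n \times d$ matrices whose columns are grouped by head, and it omits the causal mask for brevity; the names are mine:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, h):
    """Q, K, V: (n, d) matrices whose columns are grouped by head."""
    n, d = Q.shape
    dh = d // h
    # Reinterpret (n, d) as (n, h, d/h), then move h to the front: (h, n, d/h).
    Qh = Q.reshape(n, h, dh).transpose(1, 0, 2)
    Kh = K.reshape(n, h, dh).transpose(1, 0, 2)
    Vh = V.reshape(n, h, dh).transpose(1, 0, 2)
    # Batched matmul: for each head, (n, d/h) @ (d/h, n) -> (n, n).
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(dh)
    A = softmax(scores, axis=-1)                 # softmax over the last dimension
    heads = A @ Vh                               # (h, n, d/h)
    # Reverse the reshaping: (h, n, d/h) -> (n, h, d/h) -> (n, d).
    return heads.transpose(1, 0, 2).reshape(n, d)
```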

FFNs

As I mentioned in the previous chapter, each FFN in an LLM typically consists of the input sized $d$, one hidden layer sized $4d$, and an output layer sized $d$. The FFN’s input and output correspond to a single token embedding; this gets evaluated separately for each token, though GPUs are able to do those separate evaluations efficiently in parallel.

Let’s look at the FFN from the perspective of one layer. Remember from the chapter on FFNs that each layer has:

- $d_{in}$ inputs;
- $d_{out}$ neurons, each with its own $d_{in}$ weights and a single bias; and
- an activation function applied to each neuron’s output.

We can visualize the neuron weights as $d_{out}$ column vectors, each with $d_{in}$ elements:

$$
d_{in} \text{ weights} \left\{ \vphantom{\begin{matrix} \\ \\ \\ \\ \end{matrix}} \right. \underbrace{ \begin{bmatrix} \alpha \\ \beta \\ \gamma \\ \delta \end{bmatrix} \begin{bmatrix} \epsilon \\ \zeta \\ \eta \\ \theta \end{bmatrix} \begin{bmatrix} \iota \\ \kappa \\ \lambda \\ \mu \end{bmatrix} }_{ \substack{\text{$d_{out}$ sets of weights,} \\[.5em] \text{one per neuron}} }
$$

You may already see where this is going: we can treat this as a single $d_{in} \times d_{out}$ matrix. I’ll call this matrix $W$.

We can also treat the layer’s $d_{in}$-vector as a $1 \times d_{in}$ matrix, which I’ll call $X$. If we do, we see that the matrix multiplication $XW$ gives us the right shape:

$$
\underbrace{X}_{1 \times d_{in}} \cdot \underbrace{W}_{d_{in} \times d_{out}} = \underbrace{\text{layer}}_{\substack{1 \times d_{out} \text{ matrix} \\[.5em] \Downarrow \\ d_{out}\text{ vector}}}
$$

Furthermore, each column in the output is the right value for the pre-bias neuron activation. For every column $j$ in $XW$, its value is:

$$
\begin{bmatrix} X_{1,\,1} & X_{1,\,2} & \dots & X_{1,\,d_{in}} \end{bmatrix} \begin{bmatrix} W_{1,\,j} \\ W_{2,\,j} \\ \vdots \\ W_{d_{in},\,j} \end{bmatrix}
$$

Now we need to add the biases. There are $d_{out}$ of them, one per neuron. Instead of treating them as separate values and adding them one at a time, we’ll treat them as a single $1 \times d_{out}$ matrix, and add this to the $1 \times d_{out}$ result from $XW$. I’ll call this bias matrix $B$.

After that, we just need to apply the activation function. This does have to be applied to each value separately, but GPUs can efficiently parallelize that work.

This gives the full representation of each FFN layer:

$$
\text{Layer} = \text{Activation}( XW + B )
$$

To create the full FFN, we just apply each layer serially.

One crucial optimization we can make is to do all of the tokens’ $XW$ calculations at once. Remember that $X$ is a $1 \times d_{in}$ matrix, corresponding to a single token in the prompt. If we consider the whole prompt, this is an $n \times d_{in}$ matrix:

$$
\begin{bmatrix} X_{1,\,1} & X_{1,\,2} & \dots & X_{1,\,d_{in}} \\ X_{2,\,1} & X_{2,\,2} & \dots & X_{2,\,d_{in}} \\ \vdots & \vdots & \ddots & \vdots \\ X_{n,\,1} & X_{n,\,2} & \dots & X_{n,\,d_{in}} \end{bmatrix}
$$

If we multiply this by the $d_{in} \times d_{out}$ matrix $W$, the result will have $n$ rows, each corresponding to one row from the input $X$, and representing that row multiplied by the weight matrix $W$.

We still need to conceptually loop over each of those rows to add $B$, and then over every value to apply the activation function. GPUs can handle both of those efficiently, though.
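
Here’s a sketch of one layer, and of a full two-layer FFN applied to the whole $n \times d_{in}$ input at once (NumPy, my own names; I’m using ReLU as a stand-in — the real model uses whatever activation it was trained with):

```python
import numpy as np

def relu(x):
    # A stand-in activation; substitute the model's actual activation function.
    return np.maximum(x, 0)

def ffn_layer(X, W, B, activation=relu):
    """One layer: X is (n, d_in), W is (d_in, d_out), B is (d_out,)."""
    return activation(X @ W + B)     # B broadcasts across all n rows

def ffn(X, W1, B1, W2, B2):
    """The full FFN: apply each layer serially (d -> 4d -> d)."""
    return ffn_layer(ffn_layer(X, W1, B1), W2, B2)
```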

Normalization

Recall that for each token embedding, the normalization layer is calculated as:

$$
\text{scale} \cdot \left( \frac{\text{activations} - \text{mean}}{\sqrt{\text{variance} + \varepsilon}} \right) + \text{shift}
$$

We need to apply this to each input embedding separately. Unfortunately, here there’s nothing tricky we can do with matrix math: not only does each embedding need to be evaluated separately, but calculating the mean and variance requires per-element calculations.

Luckily, the calculations themselves are pretty simple. And, as before, GPUs can handle these efficiently and in parallel.

The $n$ tokens of embedding size $d$ do get treated as a single $n \times d$ matrix. This is partially because GPUs know how to parallelize work efficiently across rows of matrices. It’s also convenient for feeding the normalization into the attention layer, which as we’ve seen does gain from seeing the whole input as a single $n \times d$ matrix.
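
A sketch of that per-row normalization applied to the whole $n \times d$ matrix (NumPy; `scale` and `shift` stand in for the learned per-dimension parameters):

```python
import numpy as np

def layer_norm(X, scale, shift, eps=1e-5):
    """X: (n, d) activations; scale, shift: (d,) learned parameters."""
    mean = X.mean(axis=-1, keepdims=True)   # per-token mean, shape (n, 1)
    var = X.var(axis=-1, keepdims=True)     # per-token variance, shape (n, 1)
    return scale * (X - mean) / np.sqrt(var + eps) + shift
```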

Batching

Up until now, we’ve been working with one input at a time. In practice, GPUs can process multiple inputs in parallel.

This doesn’t affect the learned parameters at all, just the activations. Basically, we lift them into a tensor of one higher rank: instead of representing the input as an $n \times d$ matrix, we’ll represent it as a $b \times n \times d$ tensor, where $b$ is the number of inputs in the batch.

The rest of the math is exactly the same. At the hardware level, this will result in the same operations (including the same weights) being applied to different inputs at the same time. GPUs are highly optimized for this.
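
As a tiny NumPy sketch, the same weight matrix applies to every input in the batch; the matrix multiply simply broadcasts over the leading batch dimension (names and values are my own):

```python
import numpy as np

rng = np.random.default_rng(0)
b, n, d = 2, 3, 4
batch = rng.standard_normal((b, n, d))   # b inputs, each an (n, d) matrix
W_q = rng.standard_normal((d, d))        # one shared set of learned weights

Q = batch @ W_q                          # broadcasts over b: result is (b, n, d)
assert Q.shape == (b, n, d)
```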

The final architecture

Our LLM now has essentially the same architecture as before: the only real difference is that we’re treating the inputs not as $n$ vectors of size $d$, but as a single $n \times d$ matrix. Similarly, the output is an $n \times v$ matrix. This lets us reformulate the operations we’ve already seen as matrix operations instead of logical loops, which lets us compute them far more efficiently on GPUs.

The same architecture as above, but with matrices instead of vectors-of-vectors

This diagram elides some of the complication, especially in the attention layer (and specifically, its multi-head architecture, as described above).

That’s it! You have an LLM!

If someone were to provide you good values for all the weights throughout the architecture (and a lot of AWS credits ;-)), you’d have enough to build an LLM that would have been competitive in early 2020. You’re not about to take down OpenAI or Anthropic, but that’s still pretty neat!