
Introduction to training

Training considerations

Causal masking

This improvement only applies during training.

Remember that our LLM will ultimately be used to auto-complete prompts. That means that at the point where it’s making predictions, it won’t have access to words after the input: they haven’t been written yet!

At "Houston we have", we don't yet know "a problem"

In any machine learning model, it’s important that the model trains the same way it’ll be used during inference. This means that the attention weight for any word after the query token should always be 0. For example, given the training input “Houston, we have a problem”, if our query token is “have”, it shouldn’t attend to “a” or “problem” at all; when it comes time for inference, it won’t have access to them.

To do this, we just need to zero out all tokens after the query token. In our attention weight matrix, each row belongs to a corresponding query token (first row for the first token, etc.), so to zero out all tokens after each row’s query token, we just need to zero out the upper-right triangle of the matrix. We’ll then need to renormalize the remaining values, to reflect that they represent the full probability distribution that we care about:

\begin{align}
\begin{bmatrix} 0.38 & 0.32 & 0.14 & 0.16 \\ 0.09 & 0.37 & 0.24 & 0.30 \\ 0.49 & 0.09 & 0.04 & 0.38 \\ 0.51 & 0.25 & 0.06 & 0.18 \end{bmatrix} & \rightarrow \begin{bmatrix} 0.38 & \textcolor{gray}{0} & \textcolor{gray}{0} & \textcolor{gray}{0} \\ 0.09 & 0.37 & \textcolor{gray}{0} & \textcolor{gray}{0} \\ 0.49 & 0.09 & 0.04 & \textcolor{gray}{0} \\ 0.51 & 0.25 & 0.06 & 0.18 \end{bmatrix} \\
& \rightarrow \begin{bmatrix} 1.00 & \textcolor{gray}{0} & \textcolor{gray}{0} & \textcolor{gray}{0} \\ 0.20 & 0.80 & \textcolor{gray}{0} & \textcolor{gray}{0} \\ 0.79 & 0.15 & 0.06 & \textcolor{gray}{0} \\ 0.51 & 0.25 & 0.06 & 0.18 \end{bmatrix}
\end{align}
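Here’s a minimal sketch of that zero-and-renormalize step, assuming PyTorch; the tensor values match the example matrix above, and the variable names are just for illustration:

```python
import torch

# The example attention weight matrix from above (rows = query tokens).
attn_weights = torch.tensor([
    [0.38, 0.32, 0.14, 0.16],
    [0.09, 0.37, 0.24, 0.30],
    [0.49, 0.09, 0.04, 0.38],
    [0.51, 0.25, 0.06, 0.18],
])

# Keep the lower triangle (each query token and everything before it);
# zero out the upper-right triangle (everything after it).
mask = torch.tril(torch.ones_like(attn_weights))
masked = attn_weights * mask

# Renormalize each row so the surviving weights again sum to 1.
causal_weights = masked / masked.sum(dim=-1, keepdim=True)
```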

In practice, we can do this more easily by applying the mask a bit earlier. Rather than setting the appropriate attention weights to 0 and then renormalizing, we can set the attention scores (before softmax) to -\infty. Softmax maps -\infty to 0 (since e^{-\infty} = 0), and because that term contributes nothing to the normalizing sum, the remaining weights are normalized over only the visible tokens. This is exactly the result we want, so applying this causal masking to the attention scores instead of the weights lets us skip the post-mask renormalization.
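The same idea applied to the scores, again as a PyTorch sketch; here `scores` is a random stand-in for whatever scaled query-key products your attention layer actually produces:

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # stand-in for the real (scaled) attention scores

# True marks positions after each row's query token (the upper-right triangle).
future_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Set those positions to -inf; softmax turns them into 0 and
# normalizes each row over only the visible tokens.
scores = scores.masked_fill(future_mask, float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)
```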

Dropout

Like causal masking, this improvement only applies during training.

The problem this improvement solves is over-fitting: learning parameters that are too tightly bound to the data we train on, and thus don’t generalize well. Since the ultimate goal of our LLM is to generate new and ideally unique text, over-fitting is a real danger. We don’t want “To be” to always complete as Hamlet’s soliloquy.

Dropout solves this by reducing the model’s dependency on specific attention patterns during training.

The approach is simple: randomly zero out some of the attention weights during each training step. This essentially deactivates those attentions, causing the model to lean more on others. Each training round picks a different set of weights to drop, so over time, all of the weights get trained without ever over-depending on any particular one.

The only gotcha is that we still want the weights to be properly scaled. In particular, we want each weight’s expected value to be unchanged by dropout. To accomplish this, we scale the surviving weights up by 1 / (1 − dropout rate).

For example, let’s say we set dropout to 50% (this is a hyperparameter that’s set during training; in real models it’s typically closer to the 10% - 20% range). Each weight then has a 50% chance of being zeroed, and the survivors are scaled by 1 / (1 − 0.5) = 2:

\begin{bmatrix} 0.38 & 0.32 & 0.14 & 0.16 \\ 0.09 & 0.37 & 0.24 & 0.30 \\ 0.49 & 0.09 & 0.04 & 0.38 \\ 0.51 & 0.25 & 0.06 & 0.18 \end{bmatrix} \rightarrow \begin{bmatrix} \textcolor{blue}{0.76} & \textcolor{gray}{0} & \textcolor{gray}{0} & \textcolor{gray}{0} \\ \textcolor{gray}{0} & \textcolor{blue}{0.74} & \textcolor{blue}{0.48} & \textcolor{blue}{0.60} \\ \textcolor{blue}{0.98} & \textcolor{gray}{0} & \textcolor{gray}{0} & \textcolor{blue}{0.76} \\ \textcolor{gray}{0} & \textcolor{gray}{0} & \textcolor{blue}{0.12} & \textcolor{gray}{0} \end{bmatrix}
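A sketch of that step, assuming PyTorch and reusing the example matrix; since which weights get dropped is random, any given run will differ from the figure above:

```python
import torch

dropout_rate = 0.5  # matches the example; real models usually use ~0.1-0.2

attn_weights = torch.tensor([
    [0.38, 0.32, 0.14, 0.16],
    [0.09, 0.37, 0.24, 0.30],
    [0.49, 0.09, 0.04, 0.38],
    [0.51, 0.25, 0.06, 0.18],
])

# Keep each weight with probability 1 - dropout_rate, zeroing the rest.
keep = (torch.rand_like(attn_weights) > dropout_rate).float()

# Scale the survivors by 1 / (1 - dropout_rate) so each weight's
# expected value is unchanged.
dropped = attn_weights * keep / (1.0 - dropout_rate)
```

In practice you’d simply apply torch.nn.Dropout(dropout_rate) to the attention weights: it performs exactly this zeroing and rescaling, and is automatically disabled outside of training mode.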

Note that: