Page Status: Stub ☆☆☆
Introduction to training
Training considerations
Causal masking
This improvement only applies during training.
Remember that our LLM will ultimately be used to auto-complete prompts. That means that at the point where it’s making predictions, it won’t have access to words after the input: they haven’t been written yet!
In any machine learning model, it’s important that the model is trained under the same conditions it will face during inference. This means that the attention weight for any word after the query token should always be 0. For example, given the training input “Houston, we have a problem”, if our query token is “have”, it shouldn’t attend to “a” or “problem” at all; when it comes time for inference, it won’t have access to them.
To do this, we just need to zero out the attention weights for all tokens after the query token. In our attention weight matrix, each row belongs to a corresponding query token (first row for the first token, etc.), so zeroing out all tokens after each row’s query token means zeroing out everything above the diagonal: the upper-right triangle of the matrix. We’ll then need to renormalize the remaining values in each row, so that they still form the full probability distribution we care about.
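As a rough PyTorch sketch of this post-softmax approach (the 5×5 size and variable names are illustrative, not from the article):

```python
# A minimal sketch: zero the weights above the diagonal, then renormalize rows.
import torch

attn_weights = torch.softmax(torch.randn(5, 5), dim=-1)   # stand-in attention weights

causal_mask = torch.tril(torch.ones(5, 5))                 # 1s on/below the diagonal, 0s above
masked = attn_weights * causal_mask                        # zero out the upper-right triangle
masked = masked / masked.sum(dim=-1, keepdim=True)         # renormalize each row to sum to 1

print(masked.sum(dim=-1))                                  # all rows sum to 1 again
```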
In practice, we can do this more easily by applying the mask a bit earlier. Rather than setting the appropriate attention weights to 0 and then renormalizing, we can set the corresponding attention scores (before softmax) to −∞. Softmax handles −∞ by (a) transforming it to 0 and (b) ignoring it when calculating the other values. This is exactly the result we want, so applying this causal masking to the attention scores instead of the weights lets us skip the post-mask renormalization.
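Here’s the same idea sketched with the mask applied to the scores instead, again assuming PyTorch and illustrative names:

```python
# A minimal sketch: set future positions to -inf before softmax; softmax then
# gives them weight 0 and normalizes the remaining positions automatically.
import torch

scores = torch.randn(5, 5)                                  # stand-in attention scores
future = torch.triu(torch.ones(5, 5), diagonal=1).bool()    # True above the diagonal
masked_scores = scores.masked_fill(future, float("-inf"))   # -inf for tokens after the query
attn_weights = torch.softmax(masked_scores, dim=-1)         # rows already sum to 1

print(attn_weights)
```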
Dropout
Like causal masking, this improvement only applies during training.
The problem this improvement solves is one of over-fitting: learning parameters that are too tightly bound to the data we train on, and thus don’t generalize well. Since the ultimate goal of our LLM is to generate new, and ideally unique, text, over-fitting is a real danger. We don’t want “To be” to always complete as Hamlet’s soliloquy.
Dropout solves this by reducing the model’s dependency on specific attention patterns during training.
The approach is simple: randomly zero out some of the attention weights during each training step. This essentially deactivates those attentions, causing the model to lean more on others. Each training round picks a different set of weights to drop, so over time, all of the weights get trained without ever over-depending on any particular one.
The only gotcha is that we still want the weights to be properly scaled. In particular, we want each weight’s expected value to be unchanged by dropout. To accomplish this, we multiply the surviving weights by a compensating factor of 1 / (1 − dropout rate).
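PyTorch’s nn.Dropout does exactly this pairing of random zeroing and rescaling; a minimal sketch (the 20% rate and the stand-in weights are illustrative):

```python
# nn.Dropout zeroes each element with probability p and multiplies the
# survivors by 1 / (1 - p). It only does so in training mode; in eval mode
# it passes values through unchanged.
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.2)
attn_weights = torch.softmax(torch.randn(5, 5), dim=-1)

dropout.train()                       # training: drop and rescale
print(dropout(attn_weights))

dropout.eval()                        # inference: no-op
assert torch.equal(dropout(attn_weights), attn_weights)
```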
Expected value
This is a statistical term that basically means: for a randomized value, what would its average be over an infinite number of repetitions? For example, if you were to roll a 6-sided die forever, the average of those rolls would converge to 3.5.
If you randomly zero half of the rolls but double the others, these two effects cancel each other out. Even though individual rolls can now have values that the original couldn’t (like 0 or 10), the expected value remains 3.5.
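A quick simulation (not from the article) makes this concrete:

```python
# Illustrative only: zero roughly half of a million die rolls, double the
# survivors, and the average stays near 3.5.
import torch

rolls = torch.randint(1, 7, (1_000_000,)).float()   # fair 6-sided die
keep = (torch.rand(1_000_000) > 0.5).float()        # each roll kept or dropped independently
adjusted = rolls * keep * 2.0                       # survivors doubled

print(rolls.mean().item(), adjusted.mean().item())  # both ~3.5
```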
For example, let’s say we set dropout to 50% (this is a hyperparameter that’s set during training; it’s typically closer to the 10% - 20% range in the real world). This means:
Each element has a 50% chance of being dropped
Each surviving element is doubled
Note that:
After the dropping and compensating, each row no longer adds up to 1. The third row, for example, adds up to 1.74! This is fine: what’s important is that each weight’s expected value stays the same whether we do or don’t use dropout.
We’re not dropping exactly half of the elements in any particular row, or even in the matrix as a whole. Instead, each element independently gets dropped or not. In the example above, only one row had exactly half its weights dropped, and overall we dropped 9 elements instead of 8. In practice, there are enough training rounds, and the matrices are large enough, that the randomness averages out; the sketch below shows one such randomized run.
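Since the original example matrix isn’t reproduced here, this sketch just runs one randomized 50% dropout pass over a small stand-in matrix; the specific row sums and drop counts will differ from the numbers quoted above, which is exactly the point:

```python
# Illustrative run of 50% dropout with compensation on a stand-in 4x4
# attention-weight matrix (values and counts will vary run to run).
import torch

attn_weights = torch.softmax(torch.randn(4, 4), dim=-1)

keep = (torch.rand(4, 4) > 0.5).float()   # each element dropped independently
dropped = attn_weights * keep * 2.0       # survivors doubled (1 / (1 - 0.5))

print(dropped.sum(dim=-1))                # rows no longer sum to exactly 1
print(int((keep == 0).sum()))             # dropped count hovers around 8, but varies
```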