Page Status: Final draft ★★★
A working example
Now that we’ve worked through how the LLM works conceptually, as well as the reformulations that make it efficient, let’s see what all of this actually looks like in code.
There’s a fully working implementation of an LLM at yshavit/llm-notes://simpllm (which is in the same repository that hosts this book’s source). That repository has a few modules:
- `simpllm-core`: A complete, from-scratch implementation of an LLM. This includes the tokenizer, inference, tensor math, and logit sampling. The only external dependencies are `rand` (for randomization in the logit sampler) and `rayon` (for parallelization).
- `simpllm`: An executable that takes `simpllm-core` and wraps it in a nice TUI.
- `fasterllm`: An executable that replaces the tensor math in `simpllm-core` with a real, production implementation.
- A few other helper modules.
Of these, simpllm-core is the most interesting. Let’s take a look at some of the highlights.
Tokenization
If you recall, the three steps for tokenization are:
1. Convert the input text to UTF-8 bytes
2. Merge the bytes using configured merge-pairs
3. Look up the resulting sequences to find their token IDs
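As a rough illustration of all three steps together, here is a minimal, self-contained sketch. It is hypothetical and simplified, not `simpllm-core`'s tokenizer: merge rules are assumed to be pairs of byte sequences listed in priority order, the vocabulary is assumed to be a `HashMap` from byte sequence to token ID, and step 2 uses the standard BPE approach of always merging the adjacent pair whose rule has the highest priority (the earliest in the list).
```rust
use std::collections::HashMap;

/// Hypothetical, simplified BPE tokenizer (not simpllm-core's API).
/// `merges` is in priority order (index 0 is applied first);
/// `vocab` maps every final byte sequence to its token ID.
fn tokenize(text: &str, merges: &[(Vec<u8>, Vec<u8>)], vocab: &HashMap<Vec<u8>, u32>) -> Vec<u32> {
    // Step 1: convert to UTF-8 bytes, starting with one single-byte word per byte.
    let mut words: Vec<Vec<u8>> = text.bytes().map(|b| vec![b]).collect();

    // Step 2: repeatedly merge the adjacent pair whose rule has the highest
    // priority (lowest index), until no adjacent pair matches any rule.
    loop {
        let mut best: Option<(usize, usize)> = None; // (rule index, word position)
        for i in 0..words.len().saturating_sub(1) {
            let rank = merges
                .iter()
                .position(|(left, right)| *left == words[i] && *right == words[i + 1]);
            if let Some(rank) = rank {
                if best.map_or(true, |(best_rank, _)| rank < best_rank) {
                    best = Some((rank, i));
                }
            }
        }
        let Some((_, i)) = best else { break };
        let right = words.remove(i + 1);
        words[i].extend_from_slice(&right);
    }

    // Step 3: look up each merged byte sequence. This panics on a missing entry;
    // a real vocabulary contains every single byte and every merge result.
    words.iter().map(|w| vocab[w]).collect()
}
```
For example, with the rules `a`+`b` followed by `ab`+`c`, the input `abcab` is merged into the two words `abc` and `ab`.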
Steps 1 and 3 are trivial, so let’s take a look at step 2. It’s not too bad!
```rust
let mut merge_rules = self.merge_rules.iter();
while let Some(merge_rule) = merge_rules.next() {
    let mut look_starting_at_idx = 0;
    loop {
        let segment = &encoded[look_starting_at_idx..];
        let Some(idx_within_segment) = Self::find_match(segment, merge_rule) else {
            // Break out of the `loop`: we're done with this merge rule.
            // We'll pick up at the next iteration of `while let Some(merge_rule)`.
            break;
        };
        let index_within_full = idx_within_segment + look_starting_at_idx;
        // Merge the words by removing the next word and appending it to this one.
        let _ = encoded.remove(index_within_full + 1);
        encoded[index_within_full].extend(merge_rule.1.clone());
        // Search for more matches of the merge_rule within the input, starting at the
        // index we just merged.
        look_starting_at_idx = index_within_full;
        // Reset the merge rules, since the newly merged word may now form a pair
        // that matches one of the rules we've already passed.
        // Note that this won't affect the current loop; it'll just affect the next
        // iteration of `while let Some(..) = merge_rules.next()`.
        merge_rules = self.merge_rules.iter();
    }
}
```
Attention
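Before looking at the real implementation, it can help to see the core computation stripped of multiple heads, KV-cache bookkeeping, and the tensor library. The following is a minimal single-head sketch over plain `Vec`s, not `simpllm-core`'s code: `causal_attention` is a hypothetical helper, and it assumes (as after KV caching) that the keys and values cover all N positions while the queries cover only the newest n. Rather than adding a fixed n x n mask, it enforces causality by skipping any key that comes after a query's position.
```rust
/// Hypothetical single-head causal attention over plain row-major `Vec`s.
/// `q` holds the n newest positions; `k` and `v` hold all N positions
/// (cached ones first), so N >= n. Returns one output row per query.
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert!(!q.is_empty() && k.len() >= q.len() && k.len() == v.len());
    // Number of cached positions that come before the first query.
    let offset = k.len() - q.len();
    // Scale scores by 1/sqrt(d_k), where d_k is the per-head dimension.
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    let mut out = Vec::with_capacity(q.len());
    for (i, qi) in q.iter().enumerate() {
        // Causal mask: query i (absolute position offset + i) only sees keys
        // 0..=offset + i; skipping later keys is equivalent to a -inf score.
        let visible = offset + i;
        let scores: Vec<f32> = (0..=visible)
            .map(|j| qi.iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Softmax, subtracting the max score for numerical stability.
        let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        // Output row: softmax-weighted sum of the visible value rows.
        let mut row = vec![0.0f32; v[0].len()];
        for (e, vj) in exps.iter().zip(v) {
            for (r, x) in row.iter_mut().zip(vj) {
                *r += e / total * x;
            }
        }
        out.push(row);
    }
    out
}
```
With a single query and a non-empty cache (the decode case), the query sits at the last position, so every key is visible; this is why no causal mask is needed after prefill.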
Attention is a bit more involved because of KV caching and details like scaling, softmax, and the combined QKV weights. Even so, it’s not too bad:
```rust
// Calculate QKV at once:
// - `combined` is [n x 3d]; splitting it along dim 1 gives Q, K, and V, each [n x d].
// - KV caching applies, so n = 1 after the prefill phase.
let combined = input.matmul(self.w_qkv.weights()).add(self.w_qkv.bias());
let [queries, keys, values] = combined.split::<3>(1);
// Add the keys and values to the KV cache.
let (keys, values) = cache.extend(AttentionCache {
    keys: Some(keys),
    values: Some(values),
});
// K and V are now [N x d], where N is the total sequence size,
// including cached positions.
// Reshape each to [N x h x d/h], then transpose to [h x N x d/h].
let [keys, values] = [keys, values].map(|kv| {
    let full_n = kv.shape()[0];
    kv.clone().reshape([full_n, h, d / h]).transposed(0, 1)
});
// Similarly, reshape and transpose Q. Note that queries are [n x d], not
// [N x d], where n is just this input's size (not including the cache).
let queries = queries.reshape([input_n, h, d / h]).transposed(0, 1);
// Attention scores: QK^T, which is [h x n x N].
let mut a = queries.matmul(&keys.transposed(1, 2));
// Causal attention during prefill (but not decode); then scaling and softmax
if input_n > 1 {
    let causal_mask = B::lower_triangle(input_n);
    a = a.add(&causal_mask);
}
let dim_per_head = d / h;
a.multiply_scalar(1.0 / (dim_per_head as f32).sqrt());
a = a.softmax();
// Attention = AV; then transpose from [h x n x d/h] to [n x h x d/h]
// and reshape to [n x d].
let attn = a.matmul(&values).transposed(0, 1).reshape([input_n, d]);
// Finally, apply W_o's weights and bias
attn.matmul(self.w_o.weights()).add(self.w_o.bias())
```
FFN
The FFNs in a typical LLM have two layers, but it’s not hard to write a generic FFN that takes an arbitrary number of layers:
```rust
let mut transforms = self.layers_transforms.iter().peekable();
while let Some(transform) = transforms.next() {
    // Apply the transformation's weights and bias
    let mut output = input.matmul(transform.weights()).add(transform.bias());
    // If this isn't the last transformation, also apply the activation function.
    if transforms.peek().is_some() {
        output = output.gelu();
    }
    input = output;
}
```
Note that this FFN uses GELU for the activation function, not the ReLU that I described previously. GELU is similar to ReLU, but smooth: instead of a sharp corner at 0, it curves gently, and small negative inputs produce small negative outputs instead of being clamped to 0. That smoothness gives better-behaved gradients during training, which tends to help training and, in turn, the quality of the resulting model.
ReLU and GELU; the image was generated by Claude, but I’ve eyeballed it and it looks right. :-)
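To make the layer loop and the ReLU/GELU difference concrete without a tensor library, here is a standalone sketch over plain `Vec<f32>` rows. `Layer`, `ffn`, `relu`, and `gelu` are hypothetical names, not `simpllm-core`'s API; the activation is passed in as a function so the two can be swapped, and `gelu` uses the common tanh approximation discussed next. Like the loop above, it applies the activation after every layer except the last.
```rust
use std::f32::consts::PI;

/// Hypothetical dense layer: `weights` is [d_in x d_out], `bias` is [d_out].
struct Layer {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Layer {
    /// Computes `input · weights + bias` for a single input row.
    fn apply(&self, input: &[f32]) -> Vec<f32> {
        let mut out = self.bias.clone();
        for (x, row) in input.iter().zip(&self.weights) {
            for (o, w) in out.iter_mut().zip(row) {
                *o += x * w;
            }
        }
        out
    }
}

fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// GELU via the common tanh approximation.
fn gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

/// Runs `input` through each layer in turn, applying `activation` element-wise
/// after every layer except the last one.
fn ffn(layers: &[Layer], mut input: Vec<f32>, activation: fn(f32) -> f32) -> Vec<f32> {
    for (i, layer) in layers.iter().enumerate() {
        let mut output = layer.apply(&input);
        if i + 1 < layers.len() {
            for v in output.iter_mut() {
                *v = activation(*v);
            }
        }
        input = output;
    }
    input
}
```
With a two-layer FFN whose hidden value is -1, ReLU zeroes that value out entirely, while GELU passes roughly -0.16 through to the next layer.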
Computing GELU exactly requires the standard normal CDF (GELU(x) = x·Φ(x)), so LLMs, including my implementation, often use a cheaper tanh-based approximation:
```rust
// tanh approximation of GELU for an `x: f32`; `PI` is `std::f32::consts::PI`
0.5 * x * (1. + f32::tanh((2. / PI).sqrt() * (x + 0.044715 * x.powi(3))))
```
The rest of the owl
After attention and FFN, everything else is either glue or relatively simple stuff like normalization or logit sampling.
I won’t go into those here. If you’re interested, I invite you to take a look at the project itself.