A working example

Now that we’ve worked through the conceptual workings of the LLM as well as the reformulations to make it efficient, let’s see what all of this actually looks like in code.

There’s a fully working implementation of an LLM at yshavit/llm-notes://simpllm (which is in the same repository that hosts this book’s source). That repository has a few modules; of these, simpllm-core is the most interesting. Let’s take a look at some of the highlights.

Tokenization

If you recall, the three steps for tokenization are:

  1. Convert the input text to UTF-8 bytes

  2. Merge the bytes using configured merge-pairs

  3. Look up the resulting sequences to find their token IDs

Steps 1 and 3 are trivial, so let’s take a look at step 2. It’s not too bad!

let mut merge_rules = self.merge_rules.iter();
while let Some(merge_rule) = merge_rules.next() {
    let mut look_starting_at_idx = 0;
    loop {
        let segment = &encoded[look_starting_at_idx..];
        let Some(idx_within_segment) = Self::find_match(segment, merge_rule) else {
            // Break out of the `loop`: we're done with this merge rule.
            // We'll pick up at the next iteration of `while let Some(merge_rule)`.
            break;
        };
        let index_within_full = idx_within_segment + look_starting_at_idx;

        // Merge the words by removing the next word and adding it to this one.
        let _ = encoded.remove(index_within_full + 1);
        encoded[index_within_full].extend(merge_rule.1.clone());

        // Search for more matches of the merge_rule within the input, starting at the
        // index we just merged.
        look_starting_at_idx = index_within_full;

        // Reset the merge rules, since the newly merged string may have been one of
        // the rules we've already passed.
        // Note that this won't affect the current loop; it'll just affect the next
        // iteration of `while let Some(..) = merge_rules.next()`.
        merge_rules = self.merge_rules.iter();
    }
}
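For completeness, steps 1 and 3 might look something like the sketch below. This is my own illustration, not simpllm's actual code; the names (`to_byte_groups`, `look_up_ids`, `vocab`) and the `Vec<Vec<u8>>` representation of byte groups are assumptions made for the example.

```rust
use std::collections::HashMap;

/// Step 1: convert the input text to UTF-8 bytes, each byte starting
/// as its own one-byte group (merging happens in step 2).
/// Illustrative sketch only; names are not from the simpllm codebase.
fn to_byte_groups(text: &str) -> Vec<Vec<u8>> {
    text.bytes().map(|b| vec![b]).collect()
}

/// Step 3: look up each (merged) byte sequence to find its token ID.
/// Returns None if any sequence isn't in the vocabulary.
fn look_up_ids(encoded: &[Vec<u8>], vocab: &HashMap<Vec<u8>, u32>) -> Option<Vec<u32>> {
    encoded.iter().map(|seq| vocab.get(seq).copied()).collect()
}

fn main() {
    let mut vocab = HashMap::new();
    vocab.insert(b"h".to_vec(), 0);
    vocab.insert(b"i".to_vec(), 1);

    // With no merge rules applied, "hi" stays as two one-byte groups.
    let encoded = to_byte_groups("hi");
    let ids = look_up_ids(&encoded, &vocab);
    assert_eq!(ids, Some(vec![0, 1]));
    println!("{:?}", ids);
}
```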

Attention

Attention is a bit involved because of KV caching and the various details like scaling, softmax, and combined QKV weights; but even so, it’s not too bad:

// Calculate QKV at once:
// - Each is size [h x n x d/h]
// - KV caching applies, so n = 1 after the prefill phase.
let combined = input.matmul(self.w_qkv.weights()).add(self.w_qkv.bias());
let [queries, keys, values] = combined.split::<3>(1);

// Add the keys and values to the KV cache.
let (keys, values) = cache.extend(AttentionCache {
    keys: Some(keys),
    values: Some(values),
});

// K and V are the [N x d] from above, where N is the total sequence size,
// including cached.
// Reshape each to [N x h x d/h], then transpose to [h x N x d/h]
let [keys, values] = [keys, values].map(|kv| {
    let full_n = kv.shape()[0];
    kv.clone().reshape([full_n, h, d / h]).transposed(0, 1)
});
// Similarly, reshape and transpose Q. Note that queries are [n x d], not
// [N x d], where n is just this input's size (not including the cache).
let queries = queries.reshape([input_n, h, d / h]).transposed(0, 1);

// Attention scores: QK^T
let mut a = queries.matmul(&keys.transposed(1, 2));

// Causal attention during prefill (but not decode); then scaling and softmax
if input_n > 1 {
    let causal_mask = B::lower_triangle(input_n);
    a = a.add(&causal_mask);
}
let dim_per_head = d / h;
a.multiply_scalar(1.0 / (dim_per_head as f32).sqrt());
a = a.softmax();

// Attention = AV; then transpose from [h x n x d/h] to [n x h x d/h]
// and reshape to [n x d].
let attn = a.matmul(&values).transposed(0, 1).reshape([input_n, d]);

// Finally, apply W_o's weights and bias
attn.matmul(self.w_o.weights()).add(self.w_o.bias())
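The scaling-and-softmax step above hides one detail worth seeing concretely: softmax is usually computed with a max-subtraction trick so that `exp()` can't overflow. Here's a minimal sketch over a plain `Vec<f32>` row of attention scores; this is my own illustration of the technique, not how simpllm's tensor type implements it.

```rust
/// Numerically stable softmax over one row of attention scores.
/// Illustrative sketch, not simpllm's tensor implementation.
fn softmax(row: &[f32]) -> Vec<f32> {
    // Subtracting the max doesn't change the result (it cancels in the
    // ratio) but keeps exp() from overflowing for large scores.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Scale by 1/sqrt(d/h) before the softmax, as in the attention code above.
    let dim_per_head = 4.0_f32;
    let scores = [2.0_f32, 1.0, 0.0];
    let scaled: Vec<f32> = scores.iter().map(|&s| s / dim_per_head.sqrt()).collect();

    let probs = softmax(&scaled);
    // The probabilities sum to 1, and larger scores get larger weights.
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    assert!(probs[0] > probs[1] && probs[1] > probs[2]);
    println!("{:?}", probs);
}
```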

FFN

The FFNs in an LLM conventionally have two layers, but it’s not hard to write a generic FFN that takes an arbitrary number of layers:

let mut transforms = self.layers_transforms.iter().peekable();
while let Some(transform) = transforms.next() {
    // Apply the transformation's weights and bias
    let mut output = input.matmul(transform.weights()).add(transform.bias());

    // If this isn't the last transformation, also apply the activation function.
    if transforms.peek().is_some() {
        output = output.gelu();
    }
    input = output;
}
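For the standard two-layer case, the loop above reduces to `gelu(x·W1 + b1)·W2 + b2`. Here's a tiny sketch of that special case on plain slices, to make the shapes concrete; it's illustrative only (simpllm uses its own tensor type, and these weight values are arbitrary).

```rust
/// One linear layer on plain vectors: output[j] = sum_i x[i]*w[i][j] + b[j].
/// w is laid out [in x out]. Illustrative sketch, not simpllm's code.
fn linear(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    (0..b.len())
        .map(|j| x.iter().zip(w).map(|(xi, wi)| xi * wi[j]).sum::<f32>() + b[j])
        .collect()
}

/// The tanh approximation of GELU (see below).
fn gelu(x: f32) -> f32 {
    use std::f32::consts::PI;
    0.5 * x * (1. + f32::tanh((2. / PI).sqrt() * (x + 0.044715 * x.powi(3))))
}

/// The two-layer FFN the generic loop reduces to: the activation is applied
/// after the first layer but not the last.
fn ffn(x: &[f32], w1: &[Vec<f32>], b1: &[f32], w2: &[Vec<f32>], b2: &[f32]) -> Vec<f32> {
    let hidden: Vec<f32> = linear(x, w1, b1).into_iter().map(gelu).collect();
    linear(&hidden, w2, b2)
}

fn main() {
    // d=2 input, hidden width 3, back out to d=2. Weights are arbitrary.
    let w1 = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];
    let b1 = vec![0.0, 0.0, 0.0];
    let w2 = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let b2 = vec![0.1, 0.1];

    let out = ffn(&[0.5, -0.5], &w1, &b1, &w2, &b2);
    assert_eq!(out.len(), 2);
    println!("{:?}", out);
}
```

In real models the hidden width is typically several times larger than `d` (4× is a common choice), which is where most of an LLM's parameters live.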

Note that this FFN uses GELU for the activation function, not the ReLU that I described previously. GELU is similar to ReLU, but smoother; removing the sharp corner at zero smooths out neuron activations and their gradients, which helps training.

ReLU and GELU superimposed

ReLU and GELU; image was generated by Claude, but I’ve eyeballed it and it looks right. :-)

GELU is often approximated in LLMs, including my implementation:

0.5 * x * (1. + f32::tanh((2. / PI).sqrt() * (x + 0.044715 * x.powi(3))))
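Wrapped into a function, you can sanity-check the approximation's ReLU-like behavior directly. A quick sketch (the `gelu`/`relu` wrappers here are mine, just for comparison):

```rust
use std::f32::consts::PI;

/// The tanh approximation of GELU, exactly as in the expression above.
fn gelu(x: f32) -> f32 {
    0.5 * x * (1. + f32::tanh((2. / PI).sqrt() * (x + 0.044715 * x.powi(3))))
}

fn relu(x: f32) -> f32 {
    x.max(0.0)
}

fn main() {
    // Like ReLU: zero at 0, and approximately the identity for large inputs.
    assert_eq!(gelu(0.0), 0.0);
    assert!((gelu(5.0) - relu(5.0)).abs() < 1e-3);

    // Unlike ReLU: small negative inputs produce a small negative output
    // (the smooth "dip" visible in the plot above) rather than exactly 0.
    assert!(gelu(-0.5) < 0.0 && gelu(-0.5) > -0.2);
    println!("gelu(-0.5) = {}", gelu(-0.5));
}
```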

The rest of the owl

After attention and FFN, everything else is either glue or relatively simple stuff like normalization or logit sampling.

I won’t go into those here. If you’re interested, I invite you to take a look at the project itself.