GPT-2 in One Function

GPT-2 is not a complicated model; in fact, we can code golf it down to a single Python function using nothing more than NumPy. Everything is inlined in that one function: the layer norms, the attention, the MLP, the token and position embeddings, and the final logits, all in about 70 lines of code if you don't count whitespace.

import numpy as np

def gpt2(inputs: list[int], params: ModelParams, n_head: int) -> np.ndarray:
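    # sum the token embeddings and the learned position embeddings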
    x = params.wte[inputs] + params.wpe[range(len(inputs))]

    seq_len = x.shape[0]
    embedding_dim = x.shape[-1]
    head_size = embedding_dim // n_head

    for block_params in params.blocks:
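        # layer norm before attention (pre-norm), eps = 1e-5 as in GPT-2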
        ln1_mean = np.mean(x, axis=-1, keepdims=True)
        ln1_variance = np.var(x, axis=-1, keepdims=True)
        ln1_normalized = (x - ln1_mean) / np.sqrt(ln1_variance + 1e-5)
        ln1_output = block_params.ln_1.g * ln1_normalized + block_params.ln_1.b

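        # one matmul produces queries, keys, and values; split into three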
        qkv_proj = ln1_output @ block_params.attn.c_attn.w + block_params.attn.c_attn.b
        q_proj, k_proj, v_proj = np.split(qkv_proj, 3, axis=-1)

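        # split the hidden dimension into n_head heads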
        q_heads = q_proj.reshape(seq_len, n_head, head_size)
        k_heads = k_proj.reshape(seq_len, n_head, head_size)
        v_heads = v_proj.reshape(seq_len, n_head, head_size)

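        # bring the head axis to the front: (n_head, seq_len, head_size)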
        q_heads_t = q_heads.transpose(1, 0, 2)
        k_heads_t = k_heads.transpose(1, 0, 2)
        v_heads_t = v_heads.transpose(1, 0, 2)

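        # scaled dot-product attention scores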
        attention_scores = (q_heads_t @ k_heads_t.transpose(0, 2, 1)) / np.sqrt(
            head_size
        )

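        # causal mask: -inf above the diagonal so each position only attends to itself and earlier positions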
        causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=x.dtype) * -np.inf, k=1)
        attention_scores = attention_scores + causal_mask

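        # numerically stable softmax over the key positions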
        exp_scores = np.exp(
            attention_scores - np.max(attention_scores, axis=-1, keepdims=True)
        )
        attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

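        # attention-weighted sum of the values for each head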
        weighted_values = attention_weights @ v_heads_t

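        # merge the heads back into a single embedding dimension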
        merged_heads = weighted_values.transpose(1, 0, 2).reshape(
            seq_len, embedding_dim
        )

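        # output projection, followed by a residual connection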
        mha_output = (
            merged_heads @ block_params.attn.c_proj.w + block_params.attn.c_proj.b
        )

        x = x + mha_output

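        # layer norm before the feed-forward network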
        ln2_mean = np.mean(x, axis=-1, keepdims=True)
        ln2_variance = np.var(x, axis=-1, keepdims=True)
        ln2_normalized = (x - ln2_mean) / np.sqrt(ln2_variance + 1e-5)
        ln2_output = block_params.ln_2.g * ln2_normalized + block_params.ln_2.b

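        # feed-forward network: expand, apply GELU (tanh approximation), project back, add residual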
        fc_output = ln2_output @ block_params.mlp.c_fc.w + block_params.mlp.c_fc.b

        gelu_output = (
            0.5
            * fc_output
            * (1 + np.tanh(np.sqrt(2 / np.pi) * (fc_output + 0.044715 * fc_output**3)))
        )

        ffn_output = gelu_output @ block_params.mlp.c_proj.w + block_params.mlp.c_proj.b

        x = x + ffn_output

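    # final layer norm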
    lnf_mean = np.mean(x, axis=-1, keepdims=True)
    lnf_variance = np.var(x, axis=-1, keepdims=True)
    lnf_normalized = (x - lnf_mean) / np.sqrt(lnf_variance + 1e-5)
    x_normalized_final = params.ln_f.g * lnf_normalized + params.ln_f.b

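    # project to vocabulary logits with the tied token embedding matrix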
    logits = x_normalized_final @ params.wte.T

    return logits
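
To actually sample text with this, you wrap the function in a decoding loop. Here is a minimal sketch of greedy decoding; this generate helper isn't part of the function above, it just follows from its signature:

def generate(inputs: list[int], params: ModelParams, n_head: int, n_tokens: int) -> list[int]:
    # repeatedly run the forward pass and append the most likely next token
    for _ in range(n_tokens):
        logits = gpt2(inputs, params, n_head)
        next_token = int(np.argmax(logits[-1]))  # logits at the last position predict the next token
        inputs = inputs + [next_token]
    return inputs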

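As for ModelParams, it's just a nested bundle of NumPy arrays. Inferring its shape from the attribute accesses above, a minimal sketch might look like this (the actual loading code lives in the helper module):

from dataclasses import dataclass

@dataclass
class LayerNormParams:
    g: np.ndarray  # scale (gamma)
    b: np.ndarray  # shift (beta)

@dataclass
class LinearParams:
    w: np.ndarray  # weight matrix
    b: np.ndarray  # bias vector

@dataclass
class AttentionParams:
    c_attn: LinearParams  # fused qkv projection
    c_proj: LinearParams  # output projection

@dataclass
class MLPParams:
    c_fc: LinearParams    # expansion layer
    c_proj: LinearParams  # projection back to embedding_dim

@dataclass
class BlockParams:
    ln_1: LayerNormParams
    attn: AttentionParams
    ln_2: LayerNormParams
    mlp: MLPParams

@dataclass
class ModelParams:
    wte: np.ndarray            # (vocab_size, embedding_dim) token embeddings
    wpe: np.ndarray            # (context_len, embedding_dim) position embeddings
    blocks: list[BlockParams]  # one per transformer layer
    ln_f: LayerNormParams      # final layer norm
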
Of course, I'm cheating a little bit here: a helper module does the real work of loading the pretrained weights into the ModelParams dataclass. But the logic for the forward pass is contained entirely within this single function. Here are the files for the full implementation: