
Attention resolves through DNS


The attention mechanism is one of the core ideas behind Transformers and modern language models. At first, it looks like a scary pile of symbols:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

But once you break it down, the whole thing is just three steps. Here's the mnemonic:

DNS — Dot, Normalize, Sum

D = Dot product — compute similarity scores between query and keys.

N = Normalize — apply softmax to turn scores into attention weights.

S = Sum — take the weighted sum of values.

So:

$$\text{Attention} = D \rightarrow N \rightarrow S$$

Or, written in matrices:

$$QK^T \rightarrow \text{softmax} \rightarrow AV$$

That's the skeleton. Attention resolves through DNS. Now let's build the whole thing carefully — from token embeddings all the way to the output, with interactive demos along the way so you can play with each step.


1. Start with tokens

Take a sentence:

The cat drank milk because it was thirsty.

The model first splits this sentence into tokens. For simplicity, pretend each word is one token:

The, cat, drank, milk, because, it, was, thirsty

There are 8 tokens.

A Transformer cannot directly understand words as text. It needs numbers. So each token is converted into a vector called an embedding.

An embedding is a list of numbers representing a token. For example, a tiny fake embedding for cat might be:

$$x_{\text{cat}} = [0.2,\ 1.7,\ -0.4,\ 0.9]$$

Real embeddings are much larger: 512, 768, 4096, or even more dimensions depending on the model.

If we stack all token embeddings into a matrix, we get $X$. For 8 tokens, each with 4-dimensional embeddings:

$$X \in \mathbb{R}^{8 \times 4}$$

This means $X$ is a matrix with 8 rows and 4 columns. Rows = tokens. Columns = embedding dimensions.

$$X = \begin{bmatrix} x_{\text{The}} \\ x_{\text{cat}} \\ x_{\text{drank}} \\ x_{\text{milk}} \\ x_{\text{because}} \\ x_{\text{it}} \\ x_{\text{was}} \\ x_{\text{thirsty}} \end{bmatrix}$$

Each row is a token vector.

2. What does $X \in \mathbb{R}^{8 \times 4}$ mean?

This notation is common in machine learning. Read it as:

X belongs to the set of real-number matrices with 8 rows and 4 columns.

Breaking it down:

  • $\mathbb{R}$ means real numbers: $1$, $-2$, $0.5$, $3.14$, $\sqrt{2}$.
  • $\mathbb{R}^{8 \times 4}$ means all matrices with 8 rows and 4 columns, where each entry is a real number.

In attention:

  • Rows usually represent tokens.
  • Columns usually represent vector dimensions.

So $Q \in \mathbb{R}^{3 \times 2}$ means $Q$ has 3 token rows, and each query vector has 2 dimensions.

This shape notation is not decoration. It tells you whether a matrix multiplication is legal. For example:

$$(8 \times 64)(64 \times 8) = 8 \times 8$$

works because the inner dimensions match ($64 = 64$).
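NumPy enforces the same shape rule, so you can check it directly. This sketch fills the matrices with random values; only the shapes matter here.

```python
import numpy as np

rng = np.random.default_rng(0)

Q = rng.standard_normal((8, 64))    # 8 tokens, 64-dim queries
K_T = rng.standard_normal((64, 8))  # transposed keys: 64 x 8

scores = Q @ K_T                    # inner dims match (64 = 64)
print(scores.shape)                 # (8, 8)

# When the inner dimensions do not match, the product is rejected.
try:
    Q @ rng.standard_normal((32, 8))
except ValueError as err:
    print("shape error:", err)
```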

3. Embeddings alone are not enough

A token embedding tells the model what the token roughly means:

  • cat = animal, noun, living thing, singular, etc.
  • milk = liquid, noun, drinkable, etc.
  • it = pronoun

But there is a problem.

A basic embedding by itself does not tell the model where the word appears in the sentence. Compare:

Dog bites man.

and:

Man bites dog.

Same words. Different order. Very different meaning.

Without word order, the model would have a mess. It would know the words dog, bites, man but not who bites whom. That is why Transformers need positional information.

4. Positional encoding: telling the model word order

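Attention by itself is permutation-blind. Strictly speaking, it is permutation-equivariant: if you shuffle the input tokens, the outputs are shuffled in exactly the same way, and nothing in the computation records the original order. So we add positional encoding.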
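The claim can be checked numerically. This minimal sketch uses random weights and no positional encoding. Reordering the tokens only reorders the attention outputs; the model has no way to tell the two orders apart.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))   # row-wise, numerically stable
    return e / e.sum(axis=-1, keepdims=True)

def attention(Z, W_Q, W_K, W_V):
    Q, K, V = Z @ W_Q, Z @ W_K, Z @ W_V
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return A @ V

rng = np.random.default_rng(1)
X = rng.standard_normal((8, 4))                     # 8 tokens, no position info
W_Q, W_K, W_V = (rng.standard_normal((4, 4)) for _ in range(3))

perm = rng.permutation(8)                           # shuffle the token order
out = attention(X, W_Q, W_K, W_V)
out_shuffled = attention(X[perm], W_Q, W_K, W_V)

# Shuffled inputs give the same outputs, shuffled the same way.
print(np.allclose(out[perm], out_shuffled))         # True
```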

Let $X$ be the token embedding matrix and $P$ be the positional encoding matrix. Then:

$$Z = X + P$$

Here:

  • $X$ = token meaning
  • $P$ = token position
  • $Z$ = token meaning plus position

For example:

$$z_{\text{cat}} = x_{\text{cat}} + p_2$$

This means "cat plus I am at position 2." And:

$$z_{\text{it}} = x_{\text{it}} + p_6$$

This means "it plus I am at position 6."

So after positional encoding, the model does not just know cat — it knows something closer to cat at position 2. And not just it, but it at position 6. This matters enormously for grammar, reference, syntax, and meaning.
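The article leaves $P$ abstract. One common choice is the fixed sinusoidal encoding from the original Transformer paper; many newer models use learned or rotary position information instead. The sketch below assumes that sinusoidal scheme and uses random toy embeddings. Positions are 0-indexed in code, so cat is row 1 here but "position 2" in the text.

```python
import numpy as np

def sinusoidal_positions(n_tokens, d_model):
    """Fixed sinusoidal encoding: even dims use sin, odd dims use cos (d_model even)."""
    pos = np.arange(n_tokens)[:, None]                # (n, 1)
    i = np.arange(0, d_model, 2)[None, :]             # (1, d/2) even dimension indices
    angles = pos / np.power(10000.0, i / d_model)     # (n, d/2)
    P = np.zeros((n_tokens, d_model))
    P[:, 0::2] = np.sin(angles)
    P[:, 1::2] = np.cos(angles)
    return P

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 4))     # toy embeddings for the 8 tokens
P = sinusoidal_positions(8, 4)
Z = X + P                           # meaning plus position
print(Z.shape)                      # (8, 4)
```

Each position gets a distinct pattern, and adding it to $X$ makes the same word at two different positions produce two different vectors.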

5. Why positional encoding helps in pronoun resolution

Consider again:

The cat drank milk because it was thirsty.

The word it could refer to cat or milk. Both are nouns. Both came before it. But cat is more plausible because a cat can be thirsty.

However, meaning alone is not enough. The model also needs order and structure:

The(1) cat(2) drank(3) milk(4) because(5) it(6) was(7) thirsty(8)

Positional information helps the model learn patterns like:

  • Pronouns often refer to earlier nouns.
  • Nearby nouns may matter.
  • Subjects often matter.
  • Words after "because" explain a cause.
  • Adjectives often describe nearby nouns/pronouns.

Without position, the model would struggle to know which word came first, which is subject-like, which is object-like, and which phrase belongs with which. Positional encoding gives attention the sentence's geometry.

Bluntly: attention tells the model who is relevant; positional encoding tells the model where everyone is standing in line.

6. From Z to Q, K, and V

Once we have position-aware token vectors $Z$, the model creates three new matrices:

$$Q = Z W_Q, \quad K = Z W_K, \quad V = Z W_V$$

These are:

  • $Q$: Query
  • $K$: Key
  • $V$: Value

The matrices $W_Q$, $W_K$, and $W_V$ are learned during training. They are called projection matrices.

7. What does "projecting embeddings" mean?

Projection means: multiply a vector by a learned matrix to create a new version of the vector.

For example, suppose one token vector is $x = [2, 3]$ and the model has learned this matrix:

$$W_Q = \begin{bmatrix} 1 & 0 \\ 0 & 2 \end{bmatrix}$$

Then:

$$q = x W_Q = [2, 3] \begin{bmatrix} 1 & 0 \\ 0 & 2 \end{bmatrix} = [2, 6]$$

So the original vector $[2, 3]$ was transformed into $[2, 6]$. That is projection.
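The same arithmetic in NumPy, treating $x$ as a row vector as the text does:

```python
import numpy as np

x = np.array([2.0, 3.0])
W_Q = np.array([[1.0, 0.0],
                [0.0, 2.0]])
q = x @ W_Q          # row vector times matrix
print(q)             # [2. 6.]
```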

In Transformers, we project the same input $Z$ three different ways:

$$Q = Z W_Q, \quad K = Z W_K, \quad V = Z W_V$$

The model creates three different learned views of the same tokens.
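Here is a minimal sketch of the three projections with toy sizes. Random matrices stand in for the learned $W_Q$, $W_K$, $W_V$. The point is that one input $Z$ yields three differently shaped views of the same tokens, one row per token in each.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 8, 4, 3

Z = rng.standard_normal((n_tokens, d_model))    # position-aware tokens
# Learned in a real model; random stand-ins here.
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_k))

Q, K, V = Z @ W_Q, Z @ W_K, Z @ W_V
print(Q.shape, K.shape, V.shape)                # (8, 3) (8, 3) (8, 3)
```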

8. Why divide embeddings into Q, K, and V?

This is one of the most important ideas.

A token embedding is general-purpose. It contains lots of mixed information: meaning, grammar, position, token identity, semantic role, contextual hints.

But attention needs three different jobs:

  1. What am I looking for?
  2. How should others find me?
  3. What information should I provide if selected?

That is why attention uses:

| Vector | Meaning | Job |
|---|---|---|
| $Q$ | Query | What am I looking for? |
| $K$ | Key | How should I be found? |
| $V$ | Value | What information do I provide? |

For the word it, $q_{\text{it}}$ might represent: "I am a pronoun looking for a referent."

For the word cat, $k_{\text{cat}}$ might represent: "I am a singular living noun and a possible referent."

For the word milk, $k_{\text{milk}}$ might represent: "I am a singular noun, object, liquid."

Then it compares its query with the keys of other tokens.

9. Why exactly three: Q, K, and V?

Because attention is basically key-value retrieval.

Think of a dictionary, database, or search engine. You have:

query → compare against keys → retrieve values

Example database:

  • query: user_id = 42
  • key: 42
  • value: full user record

You do not retrieve the key itself. You use the key to find the value.

Attention works the same way:

$$Q \rightarrow \text{compare with } K \rightarrow \text{retrieve/mix } V$$

  • $Q$ is the search query.
  • $K$ is the searchable address/index.
  • $V$ is the content/payload.

This is why three is the natural minimum. You need one thing that searches, one thing that gets searched, and one thing that gets returned.
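The difference between a database lookup and attention can be shown side by side. In this sketch, the dictionary and the two toy keys and values are made up for illustration. A hard lookup needs an exact key match and returns one value. The soft version scores the query against every key, normalizes the scores, and returns a blend of all the values.

```python
import numpy as np

# Hard lookup: exact key match, one value comes back.
records = {42: "full user record"}
print(records[42])

# Soft lookup (attention-style): score every key, normalize, blend the values.
def soft_lookup(q, K, V):
    scores = K @ q                       # one score per key
    w = np.exp(scores - scores.max())
    w /= w.sum()                         # softmax weights, sum to 1
    return w @ V                         # weighted mix of values

K = np.array([[1.0, 0.0], [0.0, 1.0]])   # two toy keys
V = np.array([[10.0, 0.0], [0.0, 10.0]]) # their values
q = np.array([3.0, 0.0])                 # aligned with the first key
out = soft_lookup(q, K, V)
print(out)                               # mostly the first value, a little of the second
```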

Could we use only two? Technically yes, but the result is weaker. If $K = V$, then the same vector must be used both for matching and for content. That creates a bottleneck: the features useful for being found are not always the features useful for what should be passed along.

Example:

The trophy does not fit in the suitcase because it is too large.

To resolve it, the model needs to decide whether it refers to trophy or suitcase. The key features may include: noun, singular, candidate referent, position. But the value information may include: object size, containment role, semantic plausibility. Matching and payload are related, but not identical. That is why $K$ and $V$ are separated.

Could there be 4 or 5 vectors? In principle, yes. But Q/K/V is the clean core structure. Multi-head attention already creates multiple Q/K/V sets in parallel, so the model gets many specialized attention mechanisms without adding random extra roles.

10. Q and K produce attention scores

Once we have Q, K, and V, attention starts with a dot product:

$$QK^T$$

This compares every query with every key.

Suppose we have 3 tokens: cat, drank, milk. And each token has a 2D query/key vector. Let:

$$Q = \begin{bmatrix} 1 & 0 \\ 1 & 1 \\ 0 & 1 \end{bmatrix}, \quad K = \begin{bmatrix} 1 & 0 \\ 1 & 1 \\ 0 & 1 \end{bmatrix}$$

Then:

$$K^T = \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix}$$

Now:

$$QK^T = \begin{bmatrix} 1 & 0 \\ 1 & 1 \\ 0 & 1 \end{bmatrix} \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 1 & 0 \\ 1 & 2 & 1 \\ 0 & 1 & 1 \end{bmatrix}$$

This gives a score matrix. Rows are query tokens. Columns are key tokens.

| | cat | drank | milk |
|---|---|---|---|
| cat | 1 | 1 | 0 |
| drank | 1 | 2 | 1 |
| milk | 0 | 1 | 1 |

The row for drank is $[1, 2, 1]$. In this toy setup, that means drank matches cat with score 1, itself with score 2, and milk with score 1.
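The same score matrix, computed with NumPy from the toy $Q$ and $K$ above:

```python
import numpy as np

# Toy 2-D queries and keys for cat, drank, milk (one row per token).
Q = np.array([[1, 0],
              [1, 1],
              [0, 1]])
K = np.array([[1, 0],
              [1, 1],
              [0, 1]])

S = Q @ K.T     # rows = query tokens, columns = key tokens
print(S)
# [[1 1 0]
#  [1 2 1]
#  [0 1 1]]
```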

11. Why dot product?

A dot product measures alignment/similarity between vectors. For two vectors $q_i$ and $k_j$, the score $q_i \cdot k_j$ is larger when they point in a similar direction and smaller when they are less aligned. It also grows with the vectors' lengths.

So in attention, $q_i \cdot k_j$ asks: "How well does token $i$'s query match token $j$'s key?"

For the pronoun example:

  • $q_{\text{it}} \cdot k_{\text{cat}}$ asks: "How compatible is 'it' with 'cat' as something to attend to?"
  • $q_{\text{it}} \cdot k_{\text{milk}}$ asks: "How compatible is 'it' with 'milk' as something to attend to?"
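This sketch computes the dot product and the angle for two example vectors, Q = [2.0, 1.0] and K = [1.5, 1.5]. It then checks the perpendicular and opposite cases.

```python
import numpy as np

q = np.array([2.0, 1.0])
k = np.array([1.5, 1.5])

dot = q @ k                                          # 2.0*1.5 + 1.0*1.5
cos = dot / (np.linalg.norm(q) * np.linalg.norm(k))  # cosine of the angle
angle = np.degrees(np.arccos(cos))
print(f"dot = {dot:.2f}, angle = {angle:.0f} deg")   # dot = 4.50, angle = 18 deg

# Perpendicular vectors score 0; opposite vectors score negative.
print(q @ np.array([-1.0, 2.0]), q @ -q)             # 0.0 -5.0
```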

[Interactive demo: Dot product playground. When Q and K point the same way, the dot product is large; when they are perpendicular, it is zero; when they are opposite, it is negative. Example state: Q = [2.0, 1.0], K = [1.5, 1.5], Q · K = 2.0 × 1.5 + 1.0 × 1.5 = 4.50, angle 18°, strongly aligned (high attention).]

12. Why divide by $\sqrt{d_k}$?

The attention formula is not just $QK^T$. It is:

$$\frac{QK^T}{\sqrt{d_k}}$$

where $d_k$ is the dimension of the key/query vectors.

Why divide? Because dot products tend to grow with the number of dimensions. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k$ has variance $d_k$, so its typical magnitude grows like $\sqrt{d_k}$. If $d_k$ is large, the dot-product scores can become huge, and then softmax becomes too sharp.

Example:

  • softmax([1, 2, 3]) gives a distribution where the biggest value wins, but not insanely.
  • softmax([10, 20, 30]) becomes almost all weight on the largest value.

That is bad for training because most positions get nearly zero gradient. The model becomes overconfident too early.

So we scale by $\sqrt{d_k}$ to keep the score magnitudes controlled. Think of it as preventing softmax from becoming a drama queen.
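Both effects can be seen numerically. In the first part, softmax over [1, 2, 3] is moderate, while softmax over [10, 20, 30] puts almost all its weight on one entry. In the second part, random query/key vectors with standard-normal entries (a simplifying assumption; trained vectors are not independent noise) produce dot products whose spread grows with $d_k$. Dividing by $\sqrt{d_k}$ keeps that spread near 1.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - np.max(s))    # shifting by the max avoids overflow, same result
    return e / e.sum()

mild = softmax(np.array([1.0, 2.0, 3.0]))      # largest score wins, but not totally
sharp = softmax(np.array([10.0, 20.0, 30.0]))  # almost all weight on the last entry
print(mild.round(3), sharp.round(5))

rng = np.random.default_rng(0)
scaled_std = {}
for d_k in (4, 64, 1024):
    q = rng.standard_normal((2000, d_k))       # 2000 random unit-variance queries
    k = rng.standard_normal((2000, d_k))       # and keys
    raw = (q * k).sum(axis=1)                  # raw dot products: spread grows like sqrt(d_k)
    scaled_std[d_k] = (raw / np.sqrt(d_k)).std()
    print(d_k, round(raw.std(), 1), round(scaled_std[d_k], 2))
```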

13. Softmax: turning scores into weights

After dot product and scaling, we apply softmax:

$$A = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$$

Softmax converts raw scores into normalized weights. The weights have two important properties:

  1. Each weight is positive.
  2. Each row sums to 1.

So a row becomes a distribution over tokens.

For example, suppose the scores for it are:

| | cat | drank | milk |
|---|---|---|---|
| it | 5.2 | 0.7 | 1.8 |

After softmax:

| | cat | drank | milk |
|---|---|---|---|
| it | 0.96 | 0.01 | 0.03 |

This means it attends 96% to cat, 1% to drank, and 3% to milk. Now we have the attention matrix $A$, which tells us where each token looks.
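A quick check of the softmax numbers above. This treats 5.2, 0.7, and 1.8 as scores that have already been scaled.

```python
import numpy as np

scores = np.array([5.2, 0.7, 1.8])       # it vs cat, drank, milk
weights = np.exp(scores - scores.max())  # stable softmax
weights /= weights.sum()
print(weights.round(2))                  # [0.96 0.01 0.03]
```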

[Interactive demo: Softmax sharpness. Adjust the raw scores, toggle ÷√d_k, and raise d_k to see why scaling matters. Example state: raw scores cat 5.2, drank 0.7, milk 1.8, because 0.3, it 0.1, scaled with d_k = 8 (√d_k ≈ 2.83), give weights cat 54.2%, drank 11.0%, milk 16.3%, because 9.6%, it 8.9% (entropy 1.311). The distribution is spread out, so attention is uncertain.]

14. The attention matrix A

If the sentence has 8 tokens, then:

$$A \in \mathbb{R}^{8 \times 8}$$

Why? Because every token can attend to every token. Rows = the token doing the looking. Columns = the token being looked at.

So with tokens The, cat, drank, milk, because, it, was, thirsty, the row for it might be:

| | The | cat | drank | milk | because | it | was | thirsty |
|---|---|---|---|---|---|---|---|---|
| it | 0.00 | 0.91 | 0.01 | 0.04 | 0.01 | 0.01 | 0.01 | 0.01 |

This would mean the token it mostly attends to cat.

In reality, actual attention patterns are spread across many heads and layers, so one row does not always map cleanly to a human interpretation. But as a learning model, this is the right intuition.

Attention matrix (rows = query, the token doing the looking; columns = key, the token being looked at; softmax-normalized weights):

| | The | cat | drank | milk | because | it | was | thirsty |
|---|---|---|---|---|---|---|---|---|
| The | 6% | 64% | 6% | 5% | 5% | 5% | 5% | 5% |
| cat | 12% | 52% | 16% | 5% | 4% | 3% | 4% | 4% |
| drank | 1% | 53% | 7% | 32% | 1% | 2% | 2% | 2% |
| milk | 3% | 4% | 49% | 30% | 4% | 3% | 3% | 4% |
| because | 7% | 9% | 33% | 9% | 15% | 7% | 8% | 11% |
| it | 1% | 84% | 2% | 3% | 2% | 3% | 2% | 4% |
| was | 3% | 5% | 4% | 4% | 4% | 31% | 6% | 43% |
| thirsty | 2% | 34% | 3% | 3% | 3% | 42% | 8% | 6% |
In the it row, most of the weight goes to cat. Note how pronouns and verbs pull from their semantic anchors.
Note: these weights are hand-tuned for illustration. A trained model's attention is distributed across many heads and layers, though an individual head can show a pattern of this kind.

15. Now the final step: weighted sum

After we get attention weights $A$, we use them to mix the value vectors $V$:

$$O = AV$$

This is the S in DNS: Dot → Normalize → Sum.

Here:

  • $A$ tells us how much to take from each token.
  • $V$ contains what each token can provide.
  • $O$ is the resulting context-aware output.

For a single token $i$:

$$o_i = \sum_j A_{ij} v_j$$

Read this as: the output vector for token $i$ is the weighted sum of all value vectors $v_j$, weighted by how much token $i$ attends to token $j$.

This is the heart of attention.

16. Example of O = AV

Suppose for it, attention weights are:

$$A_{\text{it}} = [0.96,\ 0.01,\ 0.03]$$

corresponding to cat, drank, milk. Now suppose:

$$V_{\text{cat}} = [10, 2], \quad V_{\text{drank}} = [1, 8], \quad V_{\text{milk}} = [3, 5]$$

Stacked:

$$V = \begin{bmatrix} 10 & 2 \\ 1 & 8 \\ 3 & 5 \end{bmatrix}$$

Then:

$$
\begin{aligned}
O_{\text{it}} &= A_{\text{it}} V = [0.96,\ 0.01,\ 0.03] \begin{bmatrix} 10 & 2 \\ 1 & 8 \\ 3 & 5 \end{bmatrix} \\
&= 0.96\,[10, 2] + 0.01\,[1, 8] + 0.03\,[3, 5] \\
&= [9.6,\ 1.92] + [0.01,\ 0.08] + [0.09,\ 0.15] = [9.70,\ 2.15]
\end{aligned}
$$

So $O_{\text{it}} = [9.70,\ 2.15]$.

Because cat had weight $0.96$, the output for it is mostly based on $V_{\text{cat}}$. That means it now carries information from cat. This is how attention makes a token context-aware.
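The worked example in NumPy. It computes the output both as the explicit sum $\sum_j A_{ij} v_j$ and as a vector-matrix product, to show that the two forms agree.

```python
import numpy as np

A_it = np.array([0.96, 0.01, 0.03])      # weights on cat, drank, milk
V = np.array([[10.0, 2.0],               # V_cat
              [ 1.0, 8.0],               # V_drank
              [ 3.0, 5.0]])              # V_milk

# Explicit weighted sum: o_i = sum_j A_ij * v_j
o_loop = sum(a * v for a, v in zip(A_it, V))
# The same thing as a vector-matrix product
o_mat = A_it @ V

print(o_loop, o_mat)                     # both approximately [9.70, 2.15]
```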

[Interactive demo: Weighted sum, O = AV. Sliders set the attention weights, and the output is blended from the value vectors; pushing one weight high makes the output mirror that token's values. Example state: cat 88.3% × [10, 2, 5, 1], drank 4.4% × [1, 8, 2, 6], milk 7.3% × [3, 5, 8, 2] gives O = Σ wᵢ · Vᵢ = [9.10, 2.48, 5.09, 1.29].]

17. Full matrix shape of O = AV

Suppose $A \in \mathbb{R}^{8 \times 8}$ and $V \in \mathbb{R}^{8 \times 64}$. Then:

$$O = AV$$

Shape: $(8 \times 8)(8 \times 64) = 8 \times 64$. So:

$$O \in \mathbb{R}^{8 \times 64}$$

Meaning: 8 input tokens, 8 output token vectors, each output vector with 64 dimensions. Every token gets updated. The output for it is one row. The output for cat is another row. Each output row is a weighted mixture of the value vectors.

18. What attention actually accomplishes

Before attention, each token is mostly isolated:

  • it = pronoun
  • cat = animal noun
  • milk = liquid noun

After attention:

  • it = pronoun + information from cat
  • drank = verb + information from cat and milk
  • thirsty = adjective + information from it/cat

So attention turns each token into a context-aware token. That is the big idea.

The vector for it no longer just means "it." It now means something closer to "it, referring to cat in this sentence." The vector for drank may encode "drank, with subject cat and object milk." The vector for thirsty may encode "thirsty, describing the referent of it, likely cat."

Attention allows information to move between tokens.

19. Why not just average all tokens?

You might ask: why not just average all the word embeddings? Because not all tokens matter equally.

For understanding it, cat matters much more than drank or milk. A simple average would do:

$$\frac{1}{3} V_{\text{cat}} + \frac{1}{3} V_{\text{drank}} + \frac{1}{3} V_{\text{milk}}$$

That treats every token equally. Attention does:

$$0.96\, V_{\text{cat}} + 0.01\, V_{\text{drank}} + 0.03\, V_{\text{milk}}$$

That is much smarter. The model dynamically decides what matters depending on the current token and the context.

20. Why not just pick the highest token?

Another question: why use a weighted sum instead of just choosing the best token? Because language often requires combining information from several places.

Example:

The old cat near the window drank milk because it was thirsty.

To understand it, the model may need: cat (referent), old (description), near the window (context), drank (action), thirsty (state/reason). A hard choice would lose useful information.

Attention uses soft selection: mostly cat, some old, some drank, some nearby context, little or none from irrelevant tokens. This soft blending is powerful and differentiable, which means the model can learn it through gradient descent.
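The three strategies from sections 19 and 20 can be compared on the toy value vectors from section 16: a plain average, attention's soft weighting, and a hard pick of the top token. This is an illustration only; the weights are the hypothetical ones used earlier.

```python
import numpy as np

V = np.array([[10.0, 2.0],    # V_cat
              [ 1.0, 8.0],    # V_drank
              [ 3.0, 5.0]])   # V_milk
attn = np.array([0.96, 0.01, 0.03])

uniform = np.full(3, 1 / 3) @ V   # plain average: every token counts equally
soft = attn @ V                   # attention: mostly cat, a little of the rest
hard = V[np.argmax(attn)]         # hard pick: only cat, everything else dropped

print(uniform.round(2), soft.round(2), hard)
```

The average drifts toward drank and milk. The hard pick discards everything but cat. It also relies on `argmax`, which has no useful gradient, whereas the soft weights can be trained by gradient descent.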

21. How Q, K, V relate to the pronoun example

In the sentence:

The cat drank milk because it was thirsty.

When processing it:

  • $q_{\text{it}}$ acts like: "I am looking for what this pronoun refers to."
  • $k_{\text{cat}}$ might say: "I am a living noun, singular, earlier in the sentence."
  • $k_{\text{milk}}$ might say: "I am a liquid noun, singular, earlier in the sentence."

Then dot products produce scores:

$$q_{\text{it}} \cdot k_{\text{cat}} = 5.2, \quad q_{\text{it}} \cdot k_{\text{milk}} = 1.8$$

Since cat is a better match, it gets a higher score. Softmax turns those scores into weights: cat: 0.96, milk: 0.03. Then:

$$o_{\text{it}} = 0.96\, V_{\text{cat}} + 0.03\, V_{\text{milk}} + \cdots$$

So the output representation for it becomes mostly cat-informed. That is how attention can help resolve references.

22. Q and K are routing, V is payload

This sentence is worth remembering:

Q and K decide routing. V carries payload.

Routing means deciding where information should come from. Payload means the actual information being moved.

  • $Q$: What am I looking for?
  • $K$: Am I relevant to that search?
  • $V$: Here is the information I provide if selected.

So $QK^T$ decides routing. $AV$ moves payload.

Analogy: $Q$ = search query, $K$ = search index, $V$ = search result content. You do not return the index; you use the index to find the content. Same with attention: you do not use $K$ as the final content. You use $V$.

23. The full attention formula again

Now the full formula should feel less scary:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Break it into parts:

  1. Dot product: $QK^T$ compares every query with every key.
  2. Scaling: $\frac{QK^T}{\sqrt{d_k}}$ keeps scores numerically stable.
  3. Normalization: $\text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$ turns scores into attention weights. Call this $A$.
  4. Weighted sum: $AV$ uses attention weights to mix value vectors. Call this $O$.

So $O = AV$ where $A = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$. Therefore:

$$O = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

That is the whole attention operation.
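The formula translates almost line for line into code. Here is a minimal single-head sketch in NumPy; the function names are our own, not a library API. The comments map each line to D, N, and S.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax (the N in DNS)."""
    shifted = scores - scores.max(axis=-1, keepdims=True)  # stability shift; result unchanged
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # D: dot products, then scaling
    A = softmax(scores)               # N: normalize each row
    return A @ V, A                   # S: weighted sum of values

rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 64))
K = rng.standard_normal((8, 64))
V = rng.standard_normal((8, 64))
O, A = attention(Q, K, V)
print(O.shape, A.shape)               # (8, 64) (8, 8)
```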

24. Full pipeline from sentence to attention output

Here is the entire process:

Step 1: Tokenize the sentence: The, cat, drank, milk, because, it, was, thirsty.

Step 2: Convert tokens to embeddings, $X$. Each token becomes a vector.

Step 3: Add positional encoding, $Z = X + P$. Now each vector carries both token meaning and position.

Step 4: Project into Q, K, V:

$$Q = Z W_Q, \quad K = Z W_K, \quad V = Z W_V$$

Now each token has a query, key, and value version.

Step 5: Compute similarity scores, $S = QK^T$. Each token compares itself to every token.

Step 6: Scale the scores, $S_{\text{scaled}} = \frac{S}{\sqrt{d_k}}$. This prevents softmax from becoming too sharp.

Step 7: Apply softmax, $A = \text{softmax}(S_{\text{scaled}})$. Scores become attention weights.

Step 8: Take the weighted sum of values, $O = AV$. Each token pulls information from other tokens.
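The eight steps fit in one short script. Several parts of this sketch are assumptions for illustration only: word-level tokenization, random embeddings and projection weights in place of learned ones, sinusoidal positional encoding as the choice of $P$, and small toy sizes. Because the weights are random, the attention pattern it produces means nothing; only the flow of shapes is illustrated.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def sinusoidal_positions(n, d):
    pos = np.arange(n)[:, None]
    angles = pos / np.power(10000.0, np.arange(0, d, 2)[None, :] / d)
    P = np.zeros((n, d))
    P[:, 0::2], P[:, 1::2] = np.sin(angles), np.cos(angles)
    return P

rng = np.random.default_rng(0)
d_model, d_k = 16, 8                                   # toy sizes

# Step 1: tokenize (one token per word, as in the text)
tokens = "The cat drank milk because it was thirsty".split()
vocab = {w: i for i, w in enumerate(dict.fromkeys(tokens))}
ids = np.array([vocab[w] for w in tokens])

# Step 2: embeddings (a random lookup table stands in for a learned one)
E = rng.standard_normal((len(vocab), d_model))
X = E[ids]

# Step 3: add positional encoding
Z = X + sinusoidal_positions(len(tokens), d_model)

# Step 4: project into Q, K, V (random stand-ins for learned weights)
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
Q, K, V = Z @ W_Q, Z @ W_K, Z @ W_V

S = Q @ K.T                   # Step 5: similarity scores
S_scaled = S / np.sqrt(d_k)   # Step 6: scale
A = softmax(S_scaled)         # Step 7: attention weights
O = A @ V                     # Step 8: weighted sum of values

print(X.shape, A.shape, O.shape)   # (8, 16) (8, 8) (8, 8)
```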

Here is the entire pipeline running on deterministic matrices for a 9-token sentence (the example sentence with very added before thirsty), with $d_{\text{model}} = 768$ and $d_k = 64$:

[Interactive demo: Attention pipeline (end-to-end). Single-head attention on 9 tokens (The, cat, drank, milk, because, it, was, very, thirsty) with d_model = 768 and d_k = 64, stepping through every operation from input embeddings to output. Step 1, token embeddings: each token starts as a 768-dim vector, and stacked together they form $X \in \mathbb{R}^{9 \times 768}$, one row per token.]

25. Multi-head attention

Real Transformers usually do not use just one attention operation. They use multiple attention heads.

Each head has its own $W_Q$, $W_K$, $W_V$, so each head can learn a different way to attend. One head might track pronoun → noun. Another might track verb → subject. Another might track verb → object. Another might track adjective → noun.

Each head computes:

$$O_h = \text{softmax}\!\left(\frac{Q_h K_h^T}{\sqrt{d_k}}\right) V_h$$

Then the outputs of all $n_h$ heads are concatenated and passed through a learned output projection $W_O$:

$$\text{MultiHead} = \text{Concat}(O_1, O_2, \ldots, O_{n_h})\, W_O$$

This lets the model capture different relationships in parallel.

A single attention head is one lens. Multi-head attention is a committee of lenses.
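A minimal multi-head sketch follows. Its assumptions: random weights, a per-head $d_k = d_{\text{model}} / n_h$ (a common choice, not a requirement), and a plain loop over heads for readability. Real implementations batch the heads into one tensor operation.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def multi_head_attention(Z, W_Q, W_K, W_V, W_O):
    """W_Q, W_K, W_V: (n_heads, d_model, d_k); W_O: (n_heads * d_k, d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):          # one lens per head
        Q, K, V = Z @ Wq, Z @ Wk, Z @ Wv
        A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
        heads.append(A @ V)                        # O_h: (n_tokens, d_k)
    return np.concatenate(heads, axis=-1) @ W_O    # Concat(O_1, ..., O_{n_h}) W_O

rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 8, 16, 4
d_k = d_model // n_heads                           # split d_model across heads
Z = rng.standard_normal((n_tokens, d_model))
W_Q, W_K, W_V = (rng.standard_normal((n_heads, d_model, d_k)) for _ in range(3))
W_O = rng.standard_normal((n_heads * d_k, d_model))

out = multi_head_attention(Z, W_Q, W_K, W_V, W_O)
print(out.shape)   # (8, 16)
```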

26. Attention variants: MHA, MQA, GQA, MLA

So far we've described the original multi-head attention from the "Attention Is All You Need" paper. That works beautifully for training, but for inference it has a problem: the KV cache.

The KV cache problem

During autoregressive decoding, each new token has to attend to every previous token. Rather than recomputing $K$ and $V$ for every previous token on every step, we cache them. That cache is the KV cache.

KV cache size per token, per layer, counted in stored numbers:

$$\text{cache per token per layer} = 2 \cdot n_h \cdot d_h$$

The factor of 2 is because we store both $K$ and $V$. Here $n_h$ is the number of heads and $d_h$ is the per-head dimension. Across all $l$ layers:

$$\text{cache per token} = 2 \cdot n_h \cdot d_h \cdot l$$

Take a model with Llama-2-70B's dimensions ($n_h = 64$, $d_h = 128$, $l = 80$) but full MHA, and store values in 2-byte fp16/bf16. That's about 2.6 MB per token of context. At 32K context, you're holding about 80 GB (precisely 80 GiB) of KV cache per request. At long context, this cache is often the dominant memory cost of inference.

The variants below exist to shrink this number.

MHA — Multi-Head Attention (original)

Every head has its own $W_Q$, $W_K$, $W_V$. Maximum expressivity, maximum cache.

$$\text{KV cache per token} = 2 \cdot n_h \cdot d_h \cdot l$$

MQA — Multi-Query Attention

One shared $K$ and $V$ across all heads. Each head still has its own $W_Q$, so queries differ, but all heads read from the same key/value pool.

$$\text{KV cache per token} = 2 \cdot d_h \cdot l$$

That's an $n_h\times$ reduction. For Llama-2-70B's numbers, the per-token cache drops from ~2.6 MB to ~40 KB. Massive. But quality takes a real hit because the model loses head-level diversity in what it can "look at."

GQA — Grouped-Query Attention

The compromise. Group the heads together; each group shares one $K$ and $V$. If you have $n_h = 64$ heads and $n_g = 8$ groups, you have 8 KV pairs instead of 64.

$$\text{KV cache per token} = 2 \cdot n_g \cdot d_h \cdot l$$

Llama-2-70B uses GQA with 8 groups: ~330 KB per token at 2 bytes per value, exactly $8\times$ smaller than MHA, with quality very close to MHA. This is what most production models use today (Llama, Mistral, Qwen, etc.).
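Here is a minimal sketch of how GQA shares cached keys and values. Only $n_g$ K/V heads are stored, and each is repeated to serve $n_h / n_g$ consecutive query heads. The toy sizes and random tensors are assumptions. Setting $n_g = 1$ gives MQA, and $n_g = n_h$ gives MHA.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (n_h, n, d_h); K, V: (n_g, n, d_h), with n_h divisible by n_g."""
    n_h, n_g = Q.shape[0], K.shape[0]
    # Each cached K/V head serves n_h // n_g consecutive query heads.
    K = np.repeat(K, n_h // n_g, axis=0)
    V = np.repeat(V, n_h // n_g, axis=0)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1]))
    return A @ V                           # (n_h, n, d_h)

rng = np.random.default_rng(0)
n_h, n_g, n, d_h = 8, 2, 5, 4
Q = rng.standard_normal((n_h, n, d_h))
K = rng.standard_normal((n_g, n, d_h))     # only n_g K/V heads are cached
V = rng.standard_normal((n_g, n, d_h))
out = grouped_query_attention(Q, K, V)
print(out.shape)                           # (8, 5, 4)
```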

MLA — Multi-Head Latent Attention (DeepSeek)

The newest of the four, introduced in DeepSeek-V2. Rather than shrinking $K$ and $V$ directly, MLA caches a low-rank latent and reconstructs $K$ and $V$ on the fly.

The trick:

$$L_{KV} = X W_{DKV} \quad \text{(low-rank latent, e.g. dim 576)}$$

Then at attention time:

$$K = L_{KV} W_{UK}, \quad V = L_{KV} W_{UV}$$

Only $L_{KV}$ is cached, not $K$ and $V$.

$$\text{KV cache per token} = d_l \cdot l$$

where $d_l$ is the latent dimension (e.g. 576 vs $n_h \cdot d_h = 16{,}384$ for the same model). The result is a cache far smaller than MHA's and in the same range as MQA's (somewhat larger in absolute bytes). DeepSeek reports quality competitive with or better than MHA. The up-projection matrices $W_{UK}$ and $W_{UV}$ can be absorbed into the query and output projections at inference time, so the full $K$ and $V$ never need to be materialized.
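This is a heavily simplified sketch of the MLA idea, with toy dimensions and random matrices standing in for learned ones. It omits details of the real design, such as the per-head split and the separate positional key component. It caches only the latent $L_{KV}$, rebuilds $K$ and $V$ from it, and checks the absorption identity $QK^T = (Q W_{UK}^T) L_{KV}^T$, which lets scores be computed without building $K$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_latent = 6, 64, 8
n_h, d_h = 4, 16                                     # n_h * d_h = 64 values for K (and V)

X = rng.standard_normal((n_tokens, d_model))
W_DKV = rng.standard_normal((d_model, d_latent))     # down-projection
W_UK = rng.standard_normal((d_latent, n_h * d_h))    # up-projection to keys
W_UV = rng.standard_normal((d_latent, n_h * d_h))    # up-projection to values

L_KV = X @ W_DKV            # (n_tokens, d_latent): the only thing cached
K = L_KV @ W_UK             # reconstructed on the fly
V = L_KV @ W_UV

mha_cache = 2 * n_h * d_h   # numbers cached per token per layer, standard MHA
mla_cache = d_latent        # numbers cached per token per layer, this sketch
print(K.shape, V.shape, mha_cache, mla_cache)        # (6, 64) (6, 64) 128 8

# Absorption (treating all heads as one): Q K^T == (Q W_UK^T) L_KV^T
W_Q = rng.standard_normal((d_model, n_h * d_h))
Q = X @ W_Q
print(np.allclose(Q @ K.T, (Q @ W_UK.T) @ L_KV.T))   # True
```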

Side-by-side

| Variant | What's shared | KV cache per token | Relative size | Quality |
|---|---|---|---|---|
| MHA | nothing | $2 \cdot n_h \cdot d_h \cdot l$ | $1\times$ (baseline) | Strong baseline, approximately matched by others |
| MQA | $K, V$ across all heads | $2 \cdot d_h \cdot l$ | $\frac{1}{n_h}\times$ | Noticeable drop |
| GQA | $K, V$ within groups | $2 \cdot n_g \cdot d_h \cdot l$ | $\frac{n_g}{n_h}\times$ | Near MHA |
| MLA | compressed latent $L_{KV}$ | $d_l \cdot l$ | $\frac{d_l}{2 \cdot n_h \cdot d_h}\times$ | Reported to match or beat MHA |

For a concrete model (2 bytes per cached value):

| Variant | KV cache per token |
|---|---|
| MHA (Llama-2-70B dimensions) | ~2.6 MB |
| GQA (8 groups) | ~330 KB |
| MQA | ~40 KB |
| MLA (DeepSeek-V2 scale) | ~70 KB |
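The figures above can be reproduced with a small calculator, assuming 2-byte (fp16/bf16) cache entries. The Llama-2-70B dimensions come from the text. The 60-layer count for DeepSeek-V2 is an assumption about that model's published configuration, used only to reproduce the MLA row. Sizes are printed in decimal KB.

```python
def kv_cache_bytes_per_token(values_per_layer, n_layers, bytes_per_value=2):
    """Bytes of KV cache per token of context (2 bytes = fp16/bf16)."""
    return values_per_layer * n_layers * bytes_per_value

n_h, d_h, l = 64, 128, 80      # Llama-2-70B dimensions from the text
n_g = 8                        # KV groups for GQA
d_l, l_mla = 576, 60           # MLA latent width; 60 layers is an assumption

sizes = {
    "MHA": kv_cache_bytes_per_token(2 * n_h * d_h, l),
    "GQA": kv_cache_bytes_per_token(2 * n_g * d_h, l),
    "MQA": kv_cache_bytes_per_token(2 * d_h, l),
    "MLA": kv_cache_bytes_per_token(d_l, l_mla),
}
for name, b in sizes.items():
    print(f"{name}: {b / 1e3:,.0f} KB per token")

# Full MHA cache for one 32K-token request:
print(f"{sizes['MHA'] * 32 * 1024 / 2**30:.0f} GiB")   # 80 GiB
```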

MLA isn't quite as small as MQA in absolute bytes — but the quality jump puts it on a different curve entirely. MQA gets you small cache and worse quality; MLA gets you small cache and good quality.

What changed and what stayed the same

In all four variants, the core operation is still D-N-S:

$$O = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

What changes is how KK and VV are produced and stored:

  • MHA: stored directly, one per head.
  • MQA: stored directly, one set shared across heads.
  • GQA: stored directly, one set per group.
  • MLA: stored as a low-rank latent, reconstructed on the fly.

The DNS recipe doesn't change. The plumbing around it does — and that plumbing is what determines whether you can serve a 128K-context request without melting a GPU.

27. Important caveat: attention is not literally English reasoning

When we say "the query for it asks which noun it refers to," that is a human-friendly interpretation. Inside the model, there is no English sentence saying "find the referent of this pronoun." There are only vectors and learned weights.

But through training, the model learns vector patterns that often behave like this. So the language we use is metaphorical but useful.

A more precise version would be:

The learned query representation for it tends to assign higher compatibility scores to key representations of tokens that are useful for predicting or representing the pronoun's role in context.

That sentence is more accurate, but also a tiny academic sleep dart. The intuitive explanation is better for learning.

28. The core intuition in one story

Imagine each token is a person in a room. Each person has three cards:

  • Query card: What I am looking for.
  • Key card: What kind of information I have.
  • Value card: The actual information I can share.

The token it walks into the room and checks everyone's key card. It compares its query card with each key card. It sees:

  • cat: strong match
  • milk: weak match
  • drank: very weak match

Then it assigns weights: cat: 0.96, milk: 0.03, drank: 0.01.

Then it collects information from their value cards:

$$0.96\, V_{\text{cat}} + 0.03\, V_{\text{milk}} + 0.01\, V_{\text{drank}}$$

Now it has a new representation that mostly contains information from cat.

That is attention.

29. Final cheat sheet

| Concept | Description | Notation |
|---|---|---|
| Embedding | A vector representing a token | $x_{\text{cat}}$ |
| Embedding matrix | All token vectors stacked | $X \in \mathbb{R}^{n \times d}$ |
| Positional encoding | Adds word-order information | $Z = X + P$ |
| Projection | Learned transformation into another vector space | $Q = Z W_Q$, $K = Z W_K$, $V = Z W_V$ |
| Query | What this token is looking for | $Q$ |
| Key | How this token can be found | $K$ |
| Value | What this token provides if selected | $V$ |
| Dot product scores | Compare queries with keys | $QK^T$ |
| Scaling | Keep scores stable | $\frac{QK^T}{\sqrt{d_k}}$ |
| Softmax | Turn scores into weights | $A = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$ |
| Weighted sum | Mix values using attention weights | $O = AV$ |
| Full attention | The whole operation | $O = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$ |

30. Final mental model

A Transformer begins with token embeddings. It adds position so word order is known. It creates Q, K, and V so each token can search, be searched, and provide information. It uses dot products to score relevance. It uses softmax to normalize those scores into attention weights. It uses $O = AV$ to mix information from value vectors.

The output is a new representation for every token, now informed by the relevant tokens around it.

The cleanest version:

Attention lets each token ask: "Who matters to me?" Then it pulls information from those tokens and updates itself.

Or using the mnemonic:

DNS: Dot product, Normalize, Sum.

And the killer one-liner:

Q and K decide where to look. V is what gets copied. Positional encoding tells the model where every word is standing.