35 Comments
Michael Xie:

Very well written, clear, sharp, entertaining to read, and very educational. Thank you!

Sebastian Raschka, PhD:

Thanks!

Vivek Nayyar:

One of the best articles I have read and so well written. Please continue writing more 🙏

Sebastian Raschka, PhD:

Thanks for the kind words!

Vivek Nayyar:

Please also, if possible, write about the query, key, and value tensors and the role each of them plays in deciding the next token.

Sebastian Raschka, PhD:

Thanks for the suggestion! I think this may already be covered in my previous attention mechanism article: https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention ?

RR:

Great to see you back in action!! I hope your recovery is proceeding well.

Aldred:

I tried to load the GPT-2 weights, but I needed to comment out "persistent=False" for the mask, otherwise the weights could not be loaded. What is the rationale for that setting?

self.register_buffer(
    "mask",
    torch.triu(torch.ones(context_length, context_length), diagonal=1),
    # persistent=False
)

Sebastian Raschka, PhD:

Good point. This flag determines whether these tensors are included when saving or loading the weights via the `state_dict`. I think you are loading the weights from a previous checkpoint / the book? The discrepancy would be because I didn't set this to `False` in the book. In hindsight, I think setting it to `False` is better though, because these are values that are easily recreated and don't need to be saved in the weights file.

I should probably set it to `True` here to avoid this issue though.
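For illustration, here is a minimal sketch (a toy module, not the book's code) of how the `persistent` flag of `register_buffer` affects what ends up in the `state_dict`, which is why a checkpoint saved with the default setting and a model using `persistent=False` disagree:

import torch
import torch.nn as nn

class CausalMask(nn.Module):
    def __init__(self, context_length=4, persistent=True):
        super().__init__()
        # Non-persistent buffers are excluded from the state_dict, so a
        # checkpoint saved with the default (persistent=True) contains a
        # "mask" entry that a persistent=False module does not expect.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=persistent,
        )

print(list(CausalMask(persistent=True).state_dict().keys()))   # ['mask']
print(list(CausalMask(persistent=False).state_dict().keys()))  # []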

Sebastian Raschka, PhD:

Thanks for the note. I think it's easiest to remove the `persistent=False` so the weights can be loaded without issues. (I just updated the article.)

active_sky:

The Mac is a magical piece of hardware; the KV-cache speedup on the CPU is even greater than on the GPU.

Sebastian Raschka, PhD:

Ha yes, but I think that's mainly due to the small size of the model.

Abhishek Sharma:

I have been following your resources for almost a decade now, and you have single-handedly helped me become a better engineer by leaps and bounds.

Sebastian Raschka, PhD:

Thanks for the kind comment! On the one hand it makes me feel a bit old, but on the other hand, this is very nice to hear!

Mariano Kamp:

Thank you, Sebastian. Quick question: the sliding window would mean that, alongside the cache, the context window is also truncated, right?

Sebastian Raschka, PhD:

Yes, it would truncate the context.
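To illustrate, a minimal sketch (assumed names and shapes, not the article's code) of a sliding-window KV cache: dropping the oldest cached entries is exactly what truncates the effective context.

import torch

window_size = 4                  # maximum number of cached positions
head_dim = 8
k_cache = torch.empty(0, head_dim)
v_cache = torch.empty(0, head_dim)

def append_to_cache(k_new, v_new):
    global k_cache, v_cache
    # Keep only the most recent `window_size` keys/values; older tokens can
    # no longer be attended to, i.e., the effective context is truncated.
    k_cache = torch.cat([k_cache, k_new])[-window_size:]
    v_cache = torch.cat([v_cache, v_new])[-window_size:]

for _ in range(6):               # simulate generating 6 tokens
    append_to_cache(torch.randn(1, head_dim), torch.randn(1, head_dim))

print(k_cache.shape)             # torch.Size([4, 8]) -- only the last 4 tokens remain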

kevin:

Thanks a lot!

Peter van Beek:

Thank you, great tutorial! One question I had about KV caching is why it is not also applied to the query data. It seems to me that if x is the full context in the vanilla implementation, then queries = self.W_query(x) also recomputes many query tokens repeatedly. I must be missing something, but I previously couldn't find a clear answer.

Sebastian Raschka, PhD:

That's because, for the current token, you only need its own query; only the past keys and values are reused when computing the attention scores, so there is nothing to gain from caching past queries.
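As a minimal sketch (toy shapes and identity weights, not the article's code): when generating the next token, only that token's query is needed, but it attends to the keys and values of all previous tokens, which is why only K and V are cached.

import torch

d = 8
k_cache = torch.randn(5, d)        # keys of the 5 tokens seen so far
v_cache = torch.randn(5, d)        # values of the 5 tokens seen so far

x_new = torch.randn(1, d)          # embedding of the newest token only
W_q = W_k = W_v = torch.eye(d)     # identity "weights" to keep the sketch simple

q_new = x_new @ W_q                            # one query, for the newest position
k_cache = torch.cat([k_cache, x_new @ W_k])    # append the new key
v_cache = torch.cat([v_cache, x_new @ W_v])    # append the new value

# The newest query attends to all cached keys/values; past queries never reappear.
attn_weights = torch.softmax(q_new @ k_cache.T / d**0.5, dim=-1)
context = attn_weights @ v_cache
print(context.shape)               # torch.Size([1, 8])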

Aldred:

Is it possible to modify the vanilla implementation so that it only takes in the current query, instead of recomputing queries for the whole context (i.e., queries = self.W_query(x), where x is the full context rather than only the latest token)?

Sebastian Raschka, PhD:

Yes, definitely. But this is taken care of in `generate_text_simple_cached`. I.e., after the initial prefill, it only passes the most recent token, not all tokens, so `self.W_query(x)` doesn't get the full context.
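For reference, a rough sketch of such a cached generation loop (the names `reset_kv_cache` and `use_cache` are illustrative assumptions, not necessarily the article's exact API): the full prompt is passed once during prefill, then only the newest token on each step.

import torch

def generate_cached(model, token_ids, max_new_tokens):
    model.reset_kv_cache()                        # assumed helper to clear the cache
    logits = model(token_ids, use_cache=True)     # prefill: pass the whole prompt once
    for _ in range(max_new_tokens):
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        token_ids = torch.cat([token_ids, next_id], dim=1)
        logits = model(next_id, use_cache=True)   # decode: pass only the newest token
    return token_ids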

Aldred:

Got it. To make sure I truly understand what you mean by "you don't need past queries", I modified the vanilla implementation so that it only takes the current query into account, but without a KV cache, as follows. It's hacky and highly inefficient, but it's meant for learning purposes only. I tested it, and it works exactly like the vanilla implementation. Let me know what you think.

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        self.ptr_current_pos = 0

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)

        # Calculate Q for all of the tokens on the first invocation, else only for the last token
        if self.ptr_current_pos == 0:
            queries = self.W_query(x)
        else:
            queries = self.W_query(x[:, -1, :])
            queries = queries.unsqueeze(dim=-2)
        num_tokens_q = queries.shape[1]  # num_tokens on the first call, 1 afterwards

        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens_q, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Mask on the first invocation only; afterwards the single new query may attend to all tokens
        if self.ptr_current_pos == 0:
            # Original mask truncated to the number of tokens and converted to boolean
            mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
            # Use the mask to fill attention scores
            attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens_q, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens_q, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        self.ptr_current_pos += num_tokens_q
        return context_vec

    def reset(self):
        self.ptr_current_pos = 0

Sahibpreet Singh:

One of the best implementations and explanations of the KV cache.

Sebastian Raschka, PhD:

Thanks!

Scott Gwynn:

Thank you. I have a keen interest in this, especially on Apple hardware. I am beginning to learn about the constrained generation used with their on-device foundation model. This might be good content for you as well.

Sebastian Raschka, PhD:

Interesting! I heard they announced API access to their on-device models in iOS 26 at their recent WWDC conference. Do you know if they said something about macOS 26 and opening the API up there as well?

Alpha Xiao:

Thanks for sharing! I have two questions:

1) For models running on CUDA devices, does this KV cache technique still apply? If so, does the cache use GPU or CPU memory?

2) Does the KV cache work across different inference sessions, e.g., so that the system prompt can reuse the cache?

Sebastian Raschka, PhD:

Yes, this works for CUDA as well. You can initialize the placeholder tensors directly on the CUDA device for optimal efficiency, so the cache lives in GPU memory. For this small demo there was not much of a speed benefit, but I'll add the KV cache to my Llama 3 and Qwen3 from-scratch models this weekend, so the difference will be more visible.
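A minimal sketch (illustrative shapes and helper names, not the article's code) of preallocating the KV cache directly on the CUDA device, so decoding only writes into GPU memory:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preallocate the full cache once on the target device to avoid repeated
# allocations and host-to-device copies during decoding (shapes illustrative).
batch_size, num_heads, max_seq_len, head_dim = 1, 12, 1024, 64
k_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim, device=device)
v_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim, device=device)
cache_len = 0

def update_cache(k_new, v_new):
    # k_new, v_new: (batch_size, num_heads, num_new_tokens, head_dim), already on `device`
    global cache_len
    num_new = k_new.shape[2]
    k_cache[:, :, cache_len:cache_len + num_new] = k_new
    v_cache[:, :, cache_len:cache_len + num_new] = v_new
    cache_len += num_new

update_cache(torch.randn(1, 12, 5, 64, device=device),   # prefill with 5 tokens
             torch.randn(1, 12, 5, 64, device=device))
print(k_cache[:, :, :cache_len].shape)  # torch.Size([1, 12, 5, 64])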

prasadraje:

Great article as always Sebastian, clearly spelling out the KV cache benefits. To illustrate this further: here is a compute analysis spreadsheet that illustrates the specific KV cache total values and the compute reduction due to KV cache for Deepseek, Llama, Qwen (essentially we get GEMVs instead of GEMMs in the attention and MLP blocks). https://www.linkedin.com/posts/prasadraje_i-am-making-available-for-free-this-llm-compute-activity-7326488598705242112-LlwX

Kartik Ramesh:

Wow, this is exactly what I needed! Thanks, Sebastian!

Logan Thorneloe:

Love this! Thank you for sharing. I hope your recovery is swift.

Halidu Abdulai:

Great tutorial! I have a question about how ambiguous words like "duck" would be treated when they appear multiple times with different meanings.

For example, in the sentence:

"He had to duck when the duck flew at him."

The first "duck" is a verb, and the second is a noun. Since we make use of cached key-value (KV) pairs for previously seen tokens, what happens in this case?

If we cache the first "duck" (the verb), do we simply reuse its KV pair when we encounter the second "duck"? Shouldn't their representations be different due to their distinct roles and contexts?

Sebastian Raschka, PhD:

Good question, and I see what you mean. However, this would not be an issue, since we retrieve the previous keys and values by position, not by token identity, so each occurrence of "duck" has its own cached entries.
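A tiny sketch (illustrative only, random weights) showing that the cache holds one entry per position, so the two occurrences of the same token ID already get different key vectors at the embedding level because of the positional embedding (and, in deeper layers, because of their different contexts):

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d = 50, 8
tok_emb = nn.Embedding(vocab_size, d)
pos_emb = nn.Embedding(16, d)
W_key = nn.Linear(d, d, bias=False)

token_ids = torch.tensor([3, 7, 3])            # token ID 3 ("duck") appears twice
x = tok_emb(token_ids) + pos_emb(torch.arange(3))
k_cache = W_key(x)                             # one cached key per *position*

print(torch.allclose(k_cache[0], k_cache[2]))  # False -- separate cache entries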
