The KV Cache: Memory Usage in Transformers
Summary
TLDR: The video explains why Transformer language models like GPT require so much memory, especially when generating longer text sequences. To generate each new token, self-attention needs the keys and values of all previous tokens, and recomputing them at every step would be wasteful. The solution is to cache the key and value matrices so they don't have to be recomputed for each new token. This KV cache takes up a large portion of memory since it grows with the sequence length. There is also higher latency when processing the initial prompt, since the KV cache doesn't exist yet. Overall, the KV cache enables efficient autoregressive decoding, but it requires storing matrices that grow with the sequence length and can exceed the size of the model itself.
Takeaways
- 😱 Transformers require a lot of memory when generating long text because, without caching, the keys and values of all previous tokens would be recomputed for each new token
- 🤓 The memory usage comes mostly from the key-value (KV) cache that stores the key and value vectors of previous tokens
- 🔍 The key matrix represents previous context the model should attend to
- 💡 The value matrix represents previous context applied as a weighted sum after softmax
- ✏️ On each new token, only the query vector is computed while K and V matrices are cached
- 🚀 Using a KV cache avoids the quadratic recomputation, keeping the work per new token roughly constant as the sequence grows
- 📊 For a 30B parameter model, the KV cache takes 180GB at a sequence length of 1,024 and a batch size of 128 - 3x the model size!
- ⚡ The KV cache dominates memory usage during inference
- 🐢 There is higher latency when processing the initial prompt without a KV cache
- 👍 The KV cache allows lower latency when generating each new token after the prompt
Q & A
Why do transformer language models require so much memory when generating long text?
-As more text is generated, more key and value vectors need to be computed and cached, taking up more and more GPU memory. Without a cache, recomputing them for every new token would make the total computation grow quadratically with the sequence length.
What are the key, value and query vectors in the self-attention mechanism?
-The query vector represents the current token, the key matrix represents the previous context tokens, and the value matrix also represents the previous contexts but is applied as a weighted sum after softmax.
How does the KV cache help reduce memory usage and computations?
-The KV cache stores previously computed key and value matrices so they don't need to be recomputed for every new token. Only the key and value for the new token are computed, significantly reducing computation.
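For intuition, here is a minimal sketch of what such a cache might look like in code; the layout, names, and sizes are illustrative assumptions rather than anything shown in the video:

```python
import numpy as np

# Illustrative sizes, not from the video
n_layers, d_model, t = 4, 64, 10   # t = tokens processed so far

# One (K, V) pair per layer; each grows by one row per new token
kv_cache = [
    {"K": np.zeros((t, d_model)), "V": np.zeros((t, d_model))}
    for _ in range(n_layers)
]

# When token t+1 arrives, only its own key/value row is computed and appended;
# the existing t rows are reused as-is.
```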
Where in the transformer architecture is the KV cache used?
-The KV cache is used in the self-attention layer. The cache and the current token embedding are passed in, and the new key and value vectors are computed and appended to the cache.
What factors contribute to the memory usage of the KV cache?
-The factors are: 2 matrices K and V, precision (bytes per parameter), number of layers, embedding dimension per layer, maximum sequence length including prompt, and batch size.
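Written out, those factors simply multiply together; a sketch with assumed variable names:

```python
def kv_cache_bytes(precision_bytes, n_layers, d_model, seq_len, batch_size):
    # 2 accounts for the two cached matrices, K and V
    return 2 * precision_bytes * n_layers * d_model * seq_len * batch_size
```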
Why does processing the initial prompt have higher latency?
-For the initial prompt, there is no KV cache yet so the key and value matrices need to be computed for every prompt token. Subsequent tokens have lower latency since only the new token's KV is computed.
How large can the KV cache get for a typical transformer model?
-For a 30 billion parameter model, the KV cache can reach 180GB while the model itself is 60GB in 16-bit precision, so the cache is 3x the size of the model.
Why does the KV cache dominate memory usage during inference?
-During inference, the KV cache holds the previously generated key and value matrices so they don't get recomputed. As the sequence grows longer, this cache keeps growing while the model size stays fixed, so the cache ends up dominating memory usage.
Does the KV cache reduce computational complexity for autoregressive decoding?
-Yes, without the KV cache, a quadratic number of matrix-vector multiplications would be needed. The KV cache reduces this to a constant amount of work per token, independent of past sequence length.
Are there other memory optimizations used with transformers?
-Yes, other optimizations include using lower-precision data types, gradient checkpointing, and knowledge distillation to compress models. But the KV cache addresses a key scalability issue.
Outlines
🤔Why Transformers require so much memory
This paragraph explains why Transformer language models require a lot of memory as they generate more text. As more text is generated, more GPU memory is used to store the key and value matrices representing the previous context, and without caching the model would have to recompute them at every step, leading to quadratic growth in computation. A key-value (KV) cache is introduced to avoid this recomputation.
😮How the KV cache works
This paragraph provides details on how the KV cache works with the Transformer model during text generation. For each new token, the KV cache stores previous key and value matrices so they don't have to be recomputed. Only the keys and values for the new token are computed. This greatly reduces the computations per token as the sequence grows longer.
Keywords
💡Transformer
💡GPU memory
💡KV cache
💡key matrix
💡value matrix
💡query vector
💡attention mechanism
💡decoding
💡memory usage
💡latency
Highlights
Transformers require a lot of memory when generating long text
OpenAI charges more for models that can handle longer context
Most memory is used by the key-value (KV) cache
The query, key and value vectors represent the current token and previous context
As the sequence grows, the K and V matrices stay mostly the same
Without caching, the keys and values of all previous tokens would be recomputed for each new token
The KV cache stores previous key and value matrices in memory
Only the self-attention layer interacts with the cache
KV cache dominates memory usage during inference
First token has higher latency since no cache exists yet
Subsequent tokens have lower latency by using the cache
KV cache equation shows memory depends on precision, layers, etc.
A 30B parameter model can need 180GB for its KV cache
The KV cache can take 3x as much memory as the model parameters
Difference in latency for first token vs subsequent tokens
Transcripts
Transformers are used almost everywhere
in natural language processing models
like GPT can write pages of coherent
text something that is really impressive
but one fundamental limitation of
Transformer language models is that as
you generate more and more text you will
use up more and more GPU memory and
eventually you reach a point where your
GPU runs out of memory your program
crashes and you cannot generate any more
text
my name is Bai I'm a machine learning
engineer and a PhD in natural language
processing and today I will explain why
is this why is it that Transformer
language models require so much more
memory when they deal with longer text
actually if you look at open ai's API
pricing for GPT you see something
interesting they charge you twice as
much per input token to use the longer
context model
this is one of The Economic Consequences
of the high memory usage when you need
to handle large context lengths
most of the memory usage is taken up by
the KV cache or the key value cache in
this video I explain exactly what this
is and why we need it
before I explain the KV cache let's
quickly go over what happens in the
self-attention mechanism when we
generate a sentence
at the beginning of a transformer layer
each token corresponds to an embedding
Vector X
the first thing that happens is X is
multiplied by three different matrices
to generate the query key and value
vectors these three matrices denoted by
WQ WK and WV are learned from data
during decoding these three are not the
same size in fact the query q is usually
a vector but the K and V are matrices
this is how I like to think about it the
query Vector represents the new token in
this decoder step and since there is
only one token this is a vector instead
of a matrix
the key Matrix represents all the
previous contexts that the model should
attend to and finally the value Matrix
also represents all the previous context
but is applied after softmax as a
weighted sum
during the attention mechanism we
first take a dot product between the
query vector and the key Matrix then we
take a softmax and apply that as a
weighted sum over the value Matrix in
autoregressive decoding we are
generating one word at a time given all
of the previous context so the K and V
matrices contain information about the
entire sequence but the query Vector
only contains information about the last
token that we have seen
you can think of the dot product between
q and K as doing attention between the
current token that we care about and all
of the previous tokens at the same time
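To make the shapes concrete, here is a small NumPy sketch of that dot product and softmax-weighted sum; the dimensions are made up, and both K and V are laid out with one row per past token (the video describes keys as columns of K, which is the same information transposed):

```python
import numpy as np

d, t = 64, 10                 # embedding size and number of previous tokens (made up)
q = np.random.randn(d)        # query: a single vector for the current token
K = np.random.randn(t, d)     # keys: one row per previous token
V = np.random.randn(t, d)     # values: one row per previous token

scores = K @ q / np.sqrt(d)              # attention of the current token over all past tokens
weights = np.exp(scores - scores.max())  # softmax over the t previous tokens
weights /= weights.sum()
out = weights @ V                        # weighted sum of the value rows
```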
as we generate a sequence one token at a
time the K and V matrices actually don't
change very much this token corresponds
to a column of the K Matrix and a row of
the V Matrix and The crucial thing is
that once we've computed the embedding
for this word it's not going to change
again no matter how many more words we
generate but the model still has to do
the heavy work of computing the key and
the value vectors for this word on all
subsequent steps this results in a
quadratic number of Matrix Vector
multiplications which is going to be
really slow
as an analogy imagine if you are a model
writing a sentence one word at a time
but each word you write you have to read
every word that you've written before
and then use that information to
generate the next word obviously this is
extremely inefficient and it will be
much better if you could somehow
remember what you wrote as you're
writing it
now we're finally ready to explain how
the KV cache works when the model reads
a new word it generates the query Vector
as before but we cached the previous
values for the key and value matrices so
we no longer have to compute these
vectors for the previous context instead
we only have to compute one new column
for the key Matrix and one new row for
the value Matrix and then we proceed
with the dot product and softmax as
usual to compute the scaled dot product
attention
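A rough sketch of one such decode step, assuming made-up projection matrices and a row-per-token cache layout; this is an illustration of the idea, not the video's code:

```python
import numpy as np

def decode_step(x, cache, W_q, W_k, W_v):
    """Self-attention for one new token x, reusing and extending the KV cache."""
    q = x @ W_q                                     # query for the current token only
    cache["K"] = np.vstack([cache["K"], x @ W_k])   # one new key row appended
    cache["V"] = np.vstack([cache["V"], x @ W_v])   # one new value row appended
    scores = cache["K"] @ q / np.sqrt(len(q))       # scaled dot product with all cached keys
    w = np.exp(scores - scores.max())
    w /= w.sum()                                    # softmax
    return w @ cache["V"]                           # weighted sum over the cached values

# usage: start with an empty cache and feed one token embedding at a time
d = 64
cache = {"K": np.empty((0, d)), "V": np.empty((0, d))}
W_q, W_k, W_v = (np.random.randn(d, d) for _ in range(3))
out = decode_step(np.random.randn(d), cache, W_q, W_k, W_v)
```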
by the way if you like this video so far
please give me a thumbs up to feed the
YouTube algorithm And subscribe to my
channel now let's talk about how the KV
cache fits in with the rest of the
Transformer
here we have the self-attention layer
and instead of passing in a whole
sequence of embeddings now we only pass
in the previous KV cache and the
embedding for the current token the
self-attention layer computes the new
key and value vectors for the current
token and appends them to the KV cache
we will then need to store these key and
value matrices somewhere in the gpu's
memory so that we can retrieve them
later when we're working on the next
token notice that the only part of the
model where the current token interacts
with the previous token is the
self-attention layer in every other
layer such as the positional embedding
the layer norm and the feed-forward neural
network there is no interaction between
the current token and the previous
context so when we're using the KV cache
we only have to do a constant amount of
work for each new token and this work
does not get bigger when the sequence
gets longer now let's look at how much
memory it takes to store the KV cache
here is the equation for the memory
usage first we have two because there
are two matrices K and V that we need to
store Precision is the number of bytes
per parameter for example in fp32 there
are four bytes per parameter
and layers is the number of layers in
the model
d_model is the dimension of the
embeddings in each layer
the sequence length is the length that
we need to generate at the end including
all of the prompt tokens and everything
that we generate
and finally batch is the batch size
we multiply these all together to get
the total memory usage of the KV cache
let's walk through an example involving
a 30 billion parameter model which
nowadays is considered like medium large
we have a factor of two for the matrices K
and V
typically the precision is 2 because
inference is done in 16 bits and not 32.
the number of layers in this model is 48
and the dimension size of this model is
around 7000 and let's say that we capped
the max sequence length at 1024 and
we use a batch size of 128.
if we multiply everything together we
get that the KV cache for this model is
180 gigabytes
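Plugging these numbers into the memory formula reproduces that figure; 7168 is assumed for the embedding dimension since the video only says "around 7000":

```python
# 2 matrices (K and V) x 2 bytes (fp16) x 48 layers x 7168 dims x 1024 tokens x batch 128
kv_bytes = 2 * 2 * 48 * 7168 * 1024 * 128
print(kv_bytes / 1e9)   # ~180 GB
```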
and the model itself is 2 times 30
billion which is 60 gigabytes so you can
see that the KV cache takes up three
times as much memory as the model itself
and this sort of ratio is pretty typical
for inference scenarios and KV cache
tends to be the dominant factor in
memory usage during inference one more
thing to be aware of is the difference
in latency in processing the prompt
versus subsequent tokens
when the model is given the prompt and
is deciding the first token to generate
this has higher latency because there is
no KV cache yet so it has to compute the
K and V matrices for every token in the
prompt but after this has been done each
subsequent token will have lower latency
because it only has to compute K and V
for one token
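Continuing the illustrative sketch from above, the prefill step might look like this; the cost of building the cache grows with the prompt length, while each later token only adds one row:

```python
import numpy as np

def prefill(prompt_embeddings, cache, W_k, W_v):
    """Build the initial KV cache from all prompt tokens; this is the slow first step."""
    for x in prompt_embeddings:                       # cost grows with prompt length
        cache["K"] = np.vstack([cache["K"], x @ W_k])
        cache["V"] = np.vstack([cache["V"], x @ W_v])
    # after prefill, each generated token only needs its own key/value row
```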
that's it for the KV cache if you have
any questions please don't hesitate to
leave a comment below and if you found
this content helpful please like And
subscribe to my channel so you can get
notified when I make new machine
learning videos it will help me out a
lot goodbye