The KV Cache: Memory Usage in Transformers

Efficient NLP
21 Jul 2023 · 08:33

Summary

TLDR: The video explains why Transformer language models like GPT require so much memory, especially when generating longer text sequences. Without caching, self-attention would have to recompute the keys and values for all previous tokens every time a new token is generated, which is inefficient. The solution is to cache the key and value matrices so they don't have to be recomputed for each new token. This KV cache takes up a large portion of memory since it grows with the sequence length. There is also higher latency when processing the initial prompt, since the KV cache doesn't exist yet. Overall, the KV cache enables efficient autoregressive decoding, but it requires storing key and value matrices whose size grows linearly with sequence length.

Takeaways

  • Transformers require a lot of memory when generating long text; without caching, they would also have to recompute the key and value vectors for all previous tokens on each new token
  • The memory usage comes mostly from the key-value (KV) cache that stores the key and value vectors for previous tokens
  • The key matrix represents the previous context the model should attend to
  • The value matrix also represents the previous context and is applied as a weighted sum after the softmax
  • On each new token, only the query vector plus one new key and value are computed; the rest of the K and V matrices come from the cache
  • Using a KV cache reduces the work per new token to a constant amount instead of a total that grows quadratically with sequence length
  • For a 30B parameter model, the KV cache takes 180GB for a sequence length of 1024 at batch size 128 - 3x the model size!
  • The KV cache dominates memory usage during inference
  • There is higher latency when processing the initial prompt because no KV cache exists yet
  • The KV cache allows lower latency when generating each new token after the prompt

Q & A

  • Why do transformer language models require so much memory when generating long text?

    -As more text is generated, more key and value vectors need to be computed and cached, taking up more and more GPU memory. Without the cache, recomputing them for every new token would also require a quadratic number of matrix-vector multiplications, which is very inefficient.

  • What are the key, value and query vectors in the self-attention mechanism?

    -The query vector represents the current token, the key matrix represents the previous context tokens, and the value matrix also represents the previous context but is applied as a weighted sum after the softmax.

  • How does the KV cache help reduce memory usage and computations?

    -The KV cache stores previously computed key and value matrices so they don't need to be recomputed for every new token. Only the key and value for the new token are computed, significantly reducing computation.

  • Where in the transformer architecture is the KV cache used?

    -The KV cache is used in the self-attention layer. The cache and the current token embedding are passed in, and the new key and value vectors are computed and appended to the cache.

  • What factors contribute to the memory usage of the KV cache?

    -The factors are: a factor of 2 for the two matrices K and V, the precision (bytes per parameter), the number of layers, the embedding dimension per layer, the maximum sequence length (prompt plus generated tokens), and the batch size. A short code sketch after this Q&A section works through the formula.

  • Why does processing the initial prompt have higher latency?

    -For the initial prompt, there is no KV cache yet so the key and value matrices need to be computed for every prompt token. Subsequent tokens have lower latency since only the new token's KV is computed.

  • How large can the KV cache get for a typical transformer model?

    -For a 30 billion parameter model with a sequence length of 1024 and a batch size of 128, the KV cache can be 180GB while the model itself is 60GB, so 3x larger.

  • Why does the KV cache dominate memory usage during inference?

    -During inference, the KV cache holds the previously computed key and value matrices so they don't get recomputed. The model weights are a fixed size, but the cache keeps growing with the sequence length and batch size, so it can easily become the largest consumer of memory.

  • Does the KV cache reduce computational complexity for autoregressive decoding?

    -Yes, without the KV cache, a quadratic number of matrix-vector multiplications would be needed. The KV cache reduces this to a constant amount of work per token, independent of past sequence length.

  • Are there other memory optimizations used with transformers?

    -Yes, other optimizations include using lower-precision data types, gradient checkpointing, and knowledge distillation to compress models. But the KV cache addresses a key scalability issue.
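
For concreteness, here is a minimal Python sketch of the memory formula discussed above. The code is illustrative, not from the video; the function name and defaults are assumptions that reproduce the 30B-parameter example (2 bytes per value, 48 layers, an embedding dimension of about 7,000, sequence length 1024, batch size 128):

    def kv_cache_bytes(precision_bytes, n_layers, d_model, seq_len, batch_size):
        """Estimate KV cache size: 2 matrices (K and V), one entry per layer, token, and batch element."""
        return 2 * precision_bytes * n_layers * d_model * seq_len * batch_size

    # Example from the video: a 30B-parameter model served in 16-bit precision.
    size = kv_cache_bytes(precision_bytes=2, n_layers=48, d_model=7000,
                          seq_len=1024, batch_size=128)
    print(f"KV cache: {size / 1e9:.0f} GB")  # ~176 GB, i.e. roughly 180 GB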

Outlines

00:00

Why Transformers require so much memory

This paragraph explains why Transformer language models require a lot of memory as they generate more text. As more text is generated, more GPU memory is used up to store the key and value matrices representing the previous context; without caching, recomputing those matrices for every new token would make the total computation grow quadratically. A key-value (KV) cache is introduced to avoid this redundant work.

05:03

How the KV cache works

This paragraph provides details on how the KV cache works with the Transformer model during text generation. For each new token, the KV cache stores previous key and value matrices so they don't have to be recomputed. Only the keys and values for the new token are computed. This greatly reduces the computations per token as the sequence grows longer.
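
As an illustration of this step (not code from the video; the array shapes and names are assumptions), here is a minimal NumPy sketch of one KV-cached decoding step for a single attention head:

    import numpy as np

    def attend_with_cache(x, W_q, W_k, W_v, k_cache, v_cache):
        """One decoding step: compute q, k, v for the new token only and reuse the cached K and V."""
        q = x @ W_q                                  # query vector for the current token
        k_new = x @ W_k                              # one new key vector
        v_new = x @ W_v                              # one new value vector
        k_cache = np.vstack([k_cache, k_new])        # append a new row to the cached key matrix
        v_cache = np.vstack([v_cache, v_new])        # append a new row to the cached value matrix
        scores = k_cache @ q / np.sqrt(q.shape[-1])  # dot product between the query and every cached key
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                     # softmax over all cached positions
        out = weights @ v_cache                      # weighted sum of the cached values
        return out, k_cache, v_cache

Starting from empty caches of shape (0, d_k) and (0, d_v), calling this once per generated token computes only one new key and one new value per step while still attending to the entire previous context.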

Keywords

Transformer

Transformers are a type of neural network architecture used in natural language processing models like GPT. They are able to generate coherent text but have high memory usage when generating longer text sequences, which is a key limitation.

GPU memory

As Transformer models generate more text, they use up more GPU memory to store the internal representations and previous context. Eventually the GPU memory is exhausted, causing the program to crash.

KV cache

The key-value (KV) cache stores the key and value vectors of previous tokens so that each new token only needs to compute its own query, key, and value. This avoids redundant computation and keeps the work per token constant.

key matrix

The key matrix in Transformer attention represents previous context tokens. It interacts with the query vector via dot product to determine relevance for the current token.

value matrix

The value matrix contains representations of previous tokens. It is weighted and summed using the attention distribution to propagate relevant context.

query vector

The query vector represents the current decoded token in Transformer auto-regressive generation. It attends to the key matrix to focus on relevant context.

attention mechanism

The Transformer attention mechanism compares the query vector and key matrix via dot product. The result is normalized with a softmax into a distribution and applied as a weighted sum over the value matrix.
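
In equation form, this is the standard scaled dot-product attention, with q the query vector for the current token, K and V the cached key and value matrices, and d_k the key dimension:

    \mathrm{Attention}(q, K, V) = \mathrm{softmax}\!\left(\frac{q K^{\top}}{\sqrt{d_k}}\right) V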

decoding

During Transformer decoding, tokens are generated autoregressively one by one. The KV cache stores the keys and values of previous tokens, so each step only computes them for the latest token while still attending to the full context.

memory usage

Storing the key and value matrices for the full sequence being generated requires a large amount of memory. The KV cache trades this memory for a large reduction in redundant computation, and it dominates memory usage during inference.

latency

Processing the initial prompt has higher latency since no cache exists yet. Subsequent tokens have lower latency by only computing keys and values for the latest token.

Highlights

Transformers require a lot of memory when generating long text

OpenAI charges more for models that can handle longer context

Most memory is used by the key-value (KV) cache

The query vector represents the current token; the key and value matrices represent the previous context

As the sequence grows, existing entries of the K and V matrices do not change; only a new column and row are appended

Without caching, the keys and values for all previous tokens would be recomputed on every step

The KV cache stores previous key and value matrices in memory

Only the self-attention layer interacts with the cache

KV cache dominates memory usage during inference

First token has higher latency since no cache exists yet

Subsequent tokens have lower latency by using the cache

The KV cache memory equation depends on precision, number of layers, embedding dimension, sequence length, and batch size

A 30B parameter model needs 180GB for its KV cache at sequence length 1024 and batch size 128

The KV cache takes 3x as much memory as the model parameters

Difference in latency for first token vs subsequent tokens

Transcripts

00:00

Transformers are used almost everywhere in natural language processing. Models like GPT can write pages of coherent text, something that is really impressive. But one fundamental limitation of Transformer language models is that as you generate more and more text, you use up more and more GPU memory, and eventually you reach a point where your GPU runs out of memory, your program crashes, and you cannot generate any more text.

00:30

My name is Bai. I'm a machine learning engineer with a PhD in natural language processing, and today I will explain why this is: why do Transformer language models require so much more memory when they deal with longer text?

00:45

Actually, if you look at OpenAI's API pricing for GPT, you see something interesting: they charge you twice as much per input token to use the longer-context model. This is one of the economic consequences of the high memory usage when you need to handle large context lengths. Most of the memory usage is taken up by the KV cache, or key-value cache. In this video I explain exactly what this is and why we need it.

01:15

Before I explain the KV cache, let's quickly go over what happens in the self-attention mechanism when we generate a sentence. At the beginning of a Transformer layer, each token corresponds to an embedding vector x. The first thing that happens is that x is multiplied by three different matrices to generate the query, key, and value vectors. These three matrices, denoted W_Q, W_K, and W_V, are learned from data.

01:47

During decoding, these three are not the same size. In fact, the query q is usually a vector, but K and V are matrices. This is how I like to think about it: the query vector represents the new token in this decoder step, and since there is only one token, this is a vector instead of a matrix. The key matrix represents all the previous context that the model should attend to, and the value matrix also represents all the previous context but is applied after the softmax as a weighted sum.

02:25

During the attention mechanism, we first take a dot product between the query vector and the key matrix, then we take a softmax and apply that as a weighted sum over the value matrix. In autoregressive decoding we are generating one word at a time given all of the previous context, so the K and V matrices contain information about the entire sequence, but the query vector only contains information about the last token that we have seen.

02:56

You can think of the dot product between q and K as doing attention between the current token that we care about and all of the previous tokens at the same time. As we generate a sequence one token at a time, the K and V matrices actually don't change very much. This token corresponds to a column of the K matrix and a row of the V matrix, and the crucial thing is that once we've computed the embedding for this word, it's not going to change again, no matter how many more words we generate. But the model still has to do the heavy work of computing the key and value vectors for this word on all subsequent steps. This results in a quadratic number of matrix-vector multiplications, which is going to be really slow.

03:45

As an analogy, imagine if you were a model writing a sentence one word at a time, but for each word you write, you have to reread every word that you've written before and then use that information to generate the next word. Obviously this is extremely inefficient, and it would be much better if you could somehow remember what you wrote as you're writing it.

04:08

Now we're finally ready to explain how the KV cache works. When the model reads a new word, it generates the query vector as before, but we cache the previous values of the key and value matrices, so we no longer have to compute these vectors for the previous context. Instead, we only have to compute one new column for the key matrix and one new row for the value matrix, and then we proceed with the dot product and softmax as usual to compute the scaled dot-product attention.

04:42

By the way, if you like this video so far, please give me a thumbs up to feed the YouTube algorithm, and subscribe to my channel. Now let's talk about how the KV cache fits in with the rest of the Transformer.

04:55

Here we have the self-attention layer, and instead of passing in a whole sequence of embeddings, now we only pass in the previous KV cache and the embedding for the current token. The self-attention layer computes the new key and value vectors for the current token and appends them to the KV cache. We will then need to store these key and value matrices somewhere in the GPU's memory so that we can retrieve them later when we're working on the next token.

05:27

Notice that the only part of the model where the current token interacts with the previous tokens is the self-attention layer. In every other layer, such as the positional embedding, the layer norm, and the feed-forward neural network, there is no interaction between the current token and the previous context. So when we're using the KV cache, we only have to do a constant amount of work for each new token, and this work does not get bigger when the sequence gets longer.

05:55

Now let's look at how much memory it takes to store the KV cache. Here is the equation for the memory usage. First we have a factor of 2, because there are two matrices, K and V, that we need to store. Precision is the number of bytes per parameter; for example, in fp32 there are four bytes per parameter. Layers is the number of layers in the model, and d_model is the dimension of the embeddings in each layer. The sequence length is the length that we need to generate at the end, including all of the prompt tokens and everything that we generate. Finally, batch is the batch size. We multiply these all together to get the total memory usage of the KV cache.

06:44

Let's walk through an example involving a 30-billion-parameter model, which nowadays is considered medium-large. We have to store two matrices, K and V. Typically the precision is 2, because inference is done in 16 bits and not 32. The number of layers in this model is 48, the embedding dimension of this model is around 7,000, and let's say that we cap the max sequence length at 1024 and we use a batch size of 128. If we multiply everything together, we get that the KV cache for this model is 180 gigabytes.
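
As a sanity check on that figure, here is the memory equation from the video written out with the example numbers plugged in; the embedding dimension is only quoted as "around 7,000", so the result is approximate:

    \text{KV cache bytes} = 2 \times \text{precision} \times n_{\text{layers}} \times d_{\text{model}} \times \text{seq\_len} \times \text{batch}

    2 \times 2 \times 48 \times 7000 \times 1024 \times 128 \approx 1.76 \times 10^{11}\ \text{bytes} \approx 180\ \text{GB}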

07:25

And the model itself is 2 bytes times 30 billion parameters, which is 60 gigabytes, so you can see that the KV cache takes up three times as much memory as the model itself. This sort of ratio is pretty typical for inference scenarios, and the KV cache tends to be the dominant factor in memory usage during inference.

07:47

One more thing to be aware of is the difference in latency when processing the prompt versus subsequent tokens. When the model is given the prompt and is deciding the first token to generate, this has higher latency, because there is no KV cache yet, so it has to compute the K and V matrices for every token in the prompt. But after this has been done, each subsequent token will have lower latency, because it only has to compute K and V for one token.

08:16

That's it for the KV cache. If you have any questions, please don't hesitate to leave a comment below, and if you found this content helpful, please like and subscribe to my channel so you can get notified when I make new machine learning videos. It will help me out a lot. Goodbye.