
LLM 最近很热门,但了解这些模型背后的工作原理总是很有意思的。可能有些人还不了解,LLM 自 2017 年著名的论文《注意力机制就是一切》(Attention is all you need)发表以来就一直在发展。但早期的基于 Transformer 的模型由于内部数学运算繁重,需要大量的内存,因此存在不少缺陷。
随着 LLM 生成的文本越来越多,GPU 内存消耗也会越来越高。当达到一定程度时,GPU 会出现内存溢出(Out of Memory)问题,导致整个程序崩溃,LLM 也无法继续生成文本。键值缓存(Key-Value Cacheing)是一种可以缓解这个问题的技术。它本质上是记住之前步骤中的重要信息。模型无需从头开始重新计算所有内容,而是重用已计算的内容,从而大大提高文本生成速度和效率。这项技术已被应用于多个模型,例如 Mistral、Llama 2 和 Llama 3 模型。
那么,让我们来了解一下为什么 KV 缓存对 LLM 如此重要。
注意力机制
我们通常会生成三个不同的权重矩阵(W_q、W_k、W_v)来生成 Q、K 和 V 向量。这些权重矩阵源自数据。

我们可以将 Q 视为一个向量,将 K 和 V 视为二维矩阵。这是因为 K 和 V 分别存储每个先前词元的向量,堆叠起来就形成了矩阵。
Q 向量表示解码器步骤中输入的新词元。
同时,K 矩阵表示新词元可以查询的所有先前词元的信息或“键”,以确定其相关性。
每当输入一个新的查询向量时,它都会将其自身与所有键向量进行比较。它本质上是找出哪些先前词元最为重要。
这种相关性通过加权平均值来表示,即软注意力得分。
V 矩阵表示每个先前词元的内容/含义。注意力得分计算完成后,会用于对 V 向量进行加权求和。最终的上下文输出将作为模型后续步骤的依据。
简单来说,K 矩阵决定关注哪些内容,而 V 矩阵则决定从中提取哪些信息。
预填充和解码过程中的注意力机制

Source: Doubleworld
问题是什么?
在典型的因果 Transformer 模型中,由于采用了自回归解码,我们每次生成一个词,前提是我们拥有所有先前的上下文信息。随着我们不断生成新的词元,K 矩阵和 V 矩阵会不断更新。一旦计算出该词元的嵌入向量,其对应的词元值就不会再改变。但是,模型需要在每个步骤中对该词元对应的 K 矩阵和 V 矩阵进行大量的计算。这会导致矩阵乘法运算次数呈平方级增长。这是一个非常繁重且耗时的任务。
为什么我们需要键值缓存?
让我们通过一个实际例子来理解——考试准备场景
我们一周后就要期末考试了,需要在7天内学习20个模块。我们从第一个模块开始,可以把它看作是第一个词元。学完之后,我们开始学习第二个模块,并自然而然地尝试将它与之前学过的内容联系起来。这个新模块(第二个词元)与之前模块的交互,类似于序列中新词元与之前词元的交互。
这种关系对于理解至关重要。当我们继续学习更多模块时,我们仍然记得之前模块的内容,我们的记忆就像一个缓存。因此,在学习一个新模块时,我们只需要思考它如何与之前的模块联系起来,而无需每次都从头开始学习所有内容。
这就是键值缓存的核心思想。键值对从一开始就存储在内存中,每当一个新的词元到来时,我们只需要计算它与已存储上下文的交互。

键值缓存的作用
正如您所见,键值缓存节省了大量的计算资源,因为所有之前的键值对(K 和 V)都存储在内存中。因此,当一个新的词元到达时,模型只需要计算这个新词元如何与已存储的词元交互。这避免了从头开始重新计算所有内容。
这样一来,模型的瓶颈就从计算密集型转变为内存密集型,显存成为主要的限制因素。在传统的多头自注意力机制中,每个新词元都需要进行完整的张量运算,其中 Q、K 和 V 都与输入相乘,导致计算复杂度为二次方。但使用键值缓存,我们只需对存储的 K 和 V 矩阵进行一次行(向量-矩阵)乘法运算即可获得注意力分数。这大大减少了计算量,使得该过程更多地受限于内存带宽,而不是原始的 GPU 计算。
键值缓存
键值缓存的核心思想很简单:我们能否避免对模型在推理过程中已经处理过的词元重复计算?

Source: YouTube (Tensordroid)
在上图中,我们可以看到红色区域是冗余的,因为我们只关心 Q-K(T) 中下一行的计算。
这里,K 对应于键矩阵,V 对应于值矩阵。
在因果转换器中计算下一个词元时,我们只需要最近的词元来预测下一个词元。这是下一个词预测的基础。由于我们只查询第 n 个词元来生成第 n+1 个词元,因此我们不需要旧的 Q 值,所以不会存储它们。但在标准的多头注意力机制中,由于我们没有缓存键和值,所有词元的 Q、K 和 V 都会被重复计算。

Source: YouTube (Tensordroid)
在上图中,K 和 V 代表了之前所有词元累积的上下文信息。在 KV 缓存出现之前,模型在每次解码步骤中都会重新计算从头到尾所有词元的 K 和 V 矩阵。这也意味着需要反复重新计算所有软注意力分数和最终注意力输出,计算量非常大。然而,有了 KV 缓存,我们存储了序列长度减 1 之前的所有 K 向量,并对从 1 到 n-1 的 V 矩阵执行相同的操作。
由于我们只需要查询最新的词元,因此不再需要重新计算之前的词元。第 n 个词元已经包含了生成第 (n+1) 个词元所需的所有信息。

KV缓存:核心思想
KV 缓存的核心思想由此开始……

每当一个新的词元到达时,我们计算其新的 K 向量,并将其作为附加列添加到现有的 K 矩阵中,形成 K’。类似地,我们计算最新标记的 V 向量,并将其作为新行添加到 V’ 中。这是一个轻量级的张量操作。
示例
如果这让你感到困惑,让我们用一个简单的示例来解释一下。

如果没有 KV 缓存,模型每次都会重新计算整个 K 和 V 矩阵。这会导致一次完整的二次方规模的注意力计算,以生成软注意力分数,然后将其乘以整个 V 矩阵。这在内存和计算方面都非常耗费资源。

有了 KV 缓存,我们只使用最新标记的查询向量 (q_new)。将 q_new 与更新后的 K 矩阵相乘,该矩阵包含 K_prev(已缓存)和 K_new(刚刚计算)。同样的逻辑也适用于 V_prev,它会获得一个新的 V_new 行。这通过将大型矩阵乘法转换为小得多的向量乘法,大大减少了计算量。生成的注意力分数被添加到注意力矩阵中,并在 Transformer 流程中正常使用。
然后,这第四个注意力值(在本例中)用于计算下一个标记。使用键值缓存(KV 缓存)时,我们不会存储或重新计算之前的 Q 值,但会存储所有之前的 K 向量。新的 Q 值乘以 Kᵀ 得到 Q·Kᵀ 向量(软注意力分数)。这些分数乘以 V 矩阵,生成存储上下文的加权平均值,该平均值即为用于预测的第四个标记的输出。
注意力仅依赖于前面的标记。此操作依赖于 GPU 的显存(VRAM),因此我们需要存储这些 K 和 V 矩阵。因此,实际上,这就是我们内存容量的上限。
在 Llama 2 等研究论文中也多次提到了 KV 头,因为我们知道 Transformer 模型有多个层,每个层有多个头。每个头都会应用 KV 缓存,因此得名 KV 头。
键值缓存的挑战
键值缓存面临的挑战包括:
- GPU 利用率低:尽管 GPU 性能强大,但键值缓存通常仅使用 20-40% 的 GPU 显存,这主要是由于内存分配方式无法充分利用。
- 连续内存需求:键值缓存块必须存储在连续的内存中,这使得内存分配更加严格,并导致许多小的空闲空间无法使用。
即使 GPU 非常擅长处理快速迭代操作,但使用键值缓存时,GPU 内存也无法得到充分利用。问题在于内存的划分和预留方式,即如何用于生成令牌。
内部碎片
内部碎片是指分配的内存超过实际使用的内存量。例如,如果一个模型支持的最大序列长度为 4096 个标记,则必须预先分配所有 4096 个位置的空间。但实际上,大多数迭代可能只处理 200-500 个标记。因此,即使预留了 4096 个标记的内存,也只有一小部分被实际使用。剩余的空间仍然空置。这种浪费的空间就是内部碎片。
外部碎片
外部碎片是指内存可用,但并非以连续的块形式存在。不同的请求开始和结束的时间不同。这会在内存中留下太小的“空洞”,无法有效地重用。即使 GPU 有剩余的显存,也无法容纳新的键值缓存块,因为空闲内存不连续。这会导致内存利用率低,即使理论上内存是存在的。
内存预留问题
内存预留问题本质上是预先规划出错导致的。我们预留内存时,假设模型会始终生成最长的序列。但实际上,生成过程可能很早就停止(例如,EOS 标记或达到提前停止标准)。这意味着预留的键值缓存内存中很大一部分从未被使用,却仍然无法供其他请求使用。

Source: Efficient Memory Management for Large Language Model Serving with PagedAttention
分页注意力机制
分页注意力机制(vLLM 中使用的机制)借鉴了操作系统管理内存的思路,解决了内存碎片问题。就像操作系统进行分页一样,vLLM 将键值缓存分割成固定大小的小块,而不是预留一个连续的大块。
- 逻辑块:每个请求在逻辑上看到的都是一个干净、连续的键值内存序列。它认为其标记是按顺序存储的,就像普通的注意力机制一样。
- 物理块:实际上,这些“逻辑块”分散在 GPU 显存 (VRAM) 中许多不连续的物理内存页上。vLLM 使用块表(类似于操作系统中的页表)来映射逻辑位置和物理位置。

Source: Efficient Memory Management for Large Language Model Serving with PagedAttention
因此,PagedAttention 不需要一个 4096 格的大型连续缓冲区,而是将标记存储在小块(类似于页)中,并在内存中灵活地排列它们。
这意味着:
- 不再有内部碎片
- 无需连续的 VRAM
- 可以立即重用空闲块
- GPU 利用率从约 20-40% 跃升至约 96%
但我们不会在这里深入探讨这个概念。也许会在以后的博客文章中介绍。
键值缓存的内存消耗
要了解特定模型的键值缓存消耗多少内存,我们需要了解以下四个变量:
- num_layers(Transformer 模型中的层数)
- num_heads(多头自注意力层中每个头的键值头数量)
- head_dim(每个键值头的维度)
- precision_in_bytes(假设精度为 16 位,即 2 字节)(FP16 => 16 位 = 2 字节)
假设我们的用例使用以下值:num_layers = 32,num_heads = 32,head_dim = 128,batch_size = 1(实际部署通常会使用更大的批处理大小)。
现在,让我们计算每个 token 的键值缓存。
每个 token 的 KV 缓存大小 = 2 * (层数) * (头数 * 头维度) * 精度(字节) * 批次大小
为什么是 2?
因为每个 token 需要存储两个矩阵——K 矩阵和 V 矩阵。
KV cache per token = 2 * 32 * (32 * 128) * 2 * 1= 524288 B= 0.5 MB
我们需要 0.5 MB 的空间来存储每个 token 在所有层和头部的键值对信息。真是令人震惊!
以 Llama 2 模型为例,我们知道,如果只发送一个请求,序列长度为 4096。如果我们每次请求都提前发送,并且使用键值缓存,那么我们需要存储所有这些信息。
Total KV cache per request = seq_len * KV cache per token= 4096 * 524288= 2147483648 / (1024 * 1024 * 1024)= 2 GB
所以基本上,每个请求需要 2 GB 内存。KV 缓存用于处理单个请求。
小结
在本文章中,我们讨论了 KV 缓存的底层工作原理及其用途。通过讨论,我们了解了为什么它对于现代 LLM 来说是一项如此重要的优化。我们探讨了 Transformer 最初如何应对高计算量和内存需求,以及 KV 缓存如何通过存储先前计算的键值向量来显著减少冗余工作。
然后,我们研究了 KV 缓存的实际问题,例如 GPU 利用率低、对连续内存的要求,以及诸如内存预留、内部碎片和外部碎片等问题如何导致大量的显存浪费。
我们还简要讨论了分页注意力机制 (vLLM) 如何利用操作系统风格的分页逻辑来解决这些问题。它允许 KV 缓存以小的、可重用的块而不是大的、固定的块来增长,从而显著提高 GPU 内存利用率。
最后,我们分析了 KV 缓存内存消耗的计算。我们观察到,在像 Llama-2 这样的模型上,单个请求很容易就会占用近 2 GB 的显存用于键值存储。真是惊人。
如果您想让我更详细地讲解分页注意力机制、块表以及 vLLM 快速遍历的原因……请在下方评论区留言,我会在下一篇博客中尝试解答。下次见!


评论留言