要理解 KV Cache,首先必须理解其服务的对象:Transformer 模型中的自注意力机制。自注意力的核心思想是为输入序列中的每个 Token 计算其与其他所有 Token 的关联程度(注意力分数)。
对于每个输入的 Token 嵌入向量 x,模型会通过三个独立的权重矩阵 Wq, Wk, Wv 来生成三个向量:
注意力计算的简化公式如下:
其中 d_k 是 Key 向量的维度。这个公式的本质是:用每个 Q 去和所有的 K 计算相似度,将得到的权重应用到对应的 V 上,最后加权求和。
在自回归生成(Autoregressive Generation)中,模型逐个 Token 生成输出。在第 t 步,模型需要利用前面所有 t-1 个 Token 的信息来生成第 t 个 Token。这意味着,每一步的 K 和 V 矩阵都在变大。
X = [x_1, x_2, ..., x_{t-1}]。这是一个形状为 [t-1, d_model] 的张量。t-1 个输入 Token 执行矩阵乘法,重新计算所有 Q, K, V。
Q = X @ Wq # Shape: [t-1, d_k]
K = X @ Wk # Shape: [t-1, d_k]
V = X @ Wv # Shape: [t-1, d_v]
Q 的最后一行 q_t。
q_t = Q[-1, :] # Shape: [1, d_k]
# 注意力计算
scores = q_t @ K.T # Shape: [1, t-1]
probs = softmax(scores / sqrt(d_k))
output = probs @ V # Shape: [1, d_v]
核心问题:在生成第 t 个 Token 时,我们被迫重新计算了 x_1 到 x_{t-1} 的所有 K 和 V 向量。而这些向量在之前的步骤中其实已经被计算过了,造成了巨大的计算浪费。
KV Cache 的思想非常直接:既然过去的 K 和 V 向量在后续步骤中是固定不变的,我们只需将它们缓存起来即可。
x_{t-1}。这是一个形状为 [1, d_model] 的张量。
K_cache = [k_1, k_2, ..., k_{t-2}] # Shape: [t-2, d_k]
V_cache = [v_1, v_2, ..., v_{t-2}] # Shape: [t-2, d_v]
x_{t-1} 计算其 q, k, v。
q_t = x_{t-1} @ Wq
k_t = x_{t-1} @ Wk
v_t = x_{t-1} @ Wv
K_new = concat([K_cache, k_t]) # Shape: [t-1, d_k]
V_new = concat([V_cache, v_t]) # Shape: [t-1, d_v]
scores = q_t @ K_new.T
probs = softmax(scores / sqrt(d_k))
output = probs @ V_new
K_new 和 V_new 写回缓存,供下一步使用。KV Cache 的优化效果主要体现在降低计算复杂度和减少内存带宽压力上。
我们主要关注 Attention 层的矩阵乘法。设序列长度为 N,模型维度为 d。
| 阶段 | 无 KV Cache | 使用 KV Cache |
|---|---|---|
| QKV 计算 | 在第 t 步,需要对 t 个 Token 计算,复杂度为 O(t * d^2)。总复杂度约为 O(N^2 * d^2)。 |
在第 t 步,只需对 1 个 Token 计算,复杂度为 O(d^2)。总复杂度为 O(N * d^2)。 |
| 注意力计算 (Score @ V) | 在第 t 步,qK.T 复杂度为 O(t*d),scores@V 复杂度为 O(t*d)。总复杂度约为 O(N^2 * d)。 |
与无缓存版本相同,总复杂度约为 O(N^2 * d)。 |
| 总体复杂度 | O(N² * d²) | O(N² * d + N * d²) |
结论:KV Cache 将 QKV 计算的复杂度从平方级降低到了线性级,这是最主要的性能提升来源。
LLM 推理通常是内存带宽密集型 (Memory-Bound) 任务。瓶颈不在于 GPU 的计算速度,而在于从 HBM (高带宽内存) 加载巨大的模型权重到计算核心的速度。
t 步,需要加载 Wq, Wk, Wv 权重,并与 t 个 Token 的 embedding 进行计算。权重矩阵被反复从内存中读取。t 步,只需加载权重与 1 个 Token 的 embedding 进行计算。访问主权重矩阵的次数大大减少。虽然需要读写 KV Cache 本身,但 Cache 的体积远小于模型总权重,且访问模式更友好。这极大地降低了推理延迟(Time per token),因为等待数据加载的时间显著缩短。
KV Cache 并非没有代价,它需要占用显存。其大小可以精确计算:
这个公式解释了为什么长上下文(Large Context Window)对显存要求极高,因为 Cache 的大小与序列长度 S 成正比。