LLM 推理中的 KV Cache 技术深度解析

技术背景:自注意力机制 (Self-Attention)

要理解 KV Cache,首先必须理解其服务的对象:Transformer 模型中的自注意力机制。自注意力的核心思想是为输入序列中的每个 Token 计算其与其他所有 Token 的关联程度(注意力分数)。

对于每个输入的 Token 嵌入向量 x,模型会通过三个独立的权重矩阵 Wq, Wk, Wv 来生成三个向量:

注意力计算的简化公式如下:

Attention(Q, K, V) = softmax( (Q @ K.T) / sqrt(d_k) ) @ V

其中 d_k 是 Key 向量的维度。这个公式的本质是:用每个 Q 去和所有的 K 计算相似度,将得到的权重应用到对应的 V 上,最后加权求和。

无 KV Cache 的自回归推理

在自回归生成(Autoregressive Generation)中,模型逐个 Token 生成输出。在第 t 步,模型需要利用前面所有 t-1 个 Token 的信息来生成第 t 个 Token。这意味着,每一步的 K 和 V 矩阵都在变大。

计算流程 (以生成第 t 个 Token 为例)

  1. 输入: 完整的输入序列 X = [x_1, x_2, ..., x_{t-1}]。这是一个形状为 [t-1, d_model] 的张量。
  2. 计算 Q, K, V: 对全部 t-1 个输入 Token 执行矩阵乘法,重新计算所有 Q, K, V。
    无KV Cache - QKV计算
    
    Q = X @ Wq  # Shape: [t-1, d_k]
    K = X @ Wk  # Shape: [t-1, d_k]
    V = X @ Wv  # Shape: [t-1, d_v]
                        
  3. 计算注意力: 我们只关心最后一个 Token (即新的 Query) 的输出,所以我们取 Q 的最后一行 q_t
    无KV Cache - 注意力计算
    
    q_t = Q[-1, :] # Shape: [1, d_k]
    # 注意力计算
    scores = q_t @ K.T  # Shape: [1, t-1]
    probs = softmax(scores / sqrt(d_k))
    output = probs @ V # Shape: [1, d_v]
                        

核心问题:在生成第 t 个 Token 时,我们被迫重新计算了 x_1x_{t-1} 的所有 K 和 V 向量。而这些向量在之前的步骤中其实已经被计算过了,造成了巨大的计算浪费。

引入 KV Cache 后的优化推理

KV Cache 的思想非常直接:既然过去的 K 和 V 向量在后续步骤中是固定不变的,我们只需将它们缓存起来即可。

计算流程 (以生成第 t 个 Token 为例)

  1. 输入: 仅需当前最新的 Token x_{t-1}。这是一个形状为 [1, d_model] 的张量。
  2. 获取缓存: 从缓存中加载之前所有步骤计算好的 K 和 V。
    使用KV Cache - 获取缓存
    
    K_cache = [k_1, k_2, ..., k_{t-2}] # Shape: [t-2, d_k]
    V_cache = [v_1, v_2, ..., v_{t-2}] # Shape: [t-2, d_v]
                        
  3. 计算新的 q, k, v: 仅对当前输入 x_{t-1} 计算其 q, k, v。
    使用KV Cache - 计算当前步QKV
    
    q_t = x_{t-1} @ Wq
    k_t = x_{t-1} @ Wk
    v_t = x_{t-1} @ Wv
                        
  4. 更新缓存: 将新计算的 k, v 追加到缓存的 K, V 后面。
    使用KV Cache - 更新缓存
    
    K_new = concat([K_cache, k_t]) # Shape: [t-1, d_k]
    V_new = concat([V_cache, v_t]) # Shape: [t-1, d_v]
                        
  5. 计算注意力: 使用新的 q_t 和更新后的 K_new, V_new 进行计算。
    使用KV Cache - 注意力计算
    
    scores = q_t @ K_new.T
    probs = softmax(scores / sqrt(d_k))
    output = probs @ V_new
                        
  6. 保存新缓存: 将 K_newV_new 写回缓存,供下一步使用。

优化效果量化分析

KV Cache 的优化效果主要体现在降低计算复杂度和减少内存带宽压力上。

1. 计算复杂度 (FLOPs)

我们主要关注 Attention 层的矩阵乘法。设序列长度为 N,模型维度为 d

阶段 无 KV Cache 使用 KV Cache
QKV 计算 在第 t 步,需要对 t 个 Token 计算,复杂度为 O(t * d^2)。总复杂度约为 O(N^2 * d^2) 在第 t 步,只需对 1 个 Token 计算,复杂度为 O(d^2)。总复杂度为 O(N * d^2)
注意力计算 (Score @ V) 在第 t 步,qK.T 复杂度为 O(t*d)scores@V 复杂度为 O(t*d)。总复杂度约为 O(N^2 * d) 与无缓存版本相同,总复杂度约为 O(N^2 * d)
总体复杂度 O(N² * d²) O(N² * d + N * d²)

结论:KV Cache 将 QKV 计算的复杂度从平方级降低到了线性级,这是最主要的性能提升来源。

2. 内存带宽 (Memory Bandwidth)

LLM 推理通常是内存带宽密集型 (Memory-Bound) 任务。瓶颈不在于 GPU 的计算速度,而在于从 HBM (高带宽内存) 加载巨大的模型权重到计算核心的速度。

这极大地降低了推理延迟(Time per token),因为等待数据加载的时间显著缩短。

3. 缓存的成本:内存占用

KV Cache 并非没有代价,它需要占用显存。其大小可以精确计算:

Cache Size (bytes) = 2 * L * H * S * D_h * P

这个公式解释了为什么长上下文(Large Context Window)对显存要求极高,因为 Cache 的大小与序列长度 S 成正比。