第2章 MLA 进阶:head_dim 512 与 grouped O 投影

“Each head should be a small but complete world.” —— 引自 V4 团队某次开源 talk 的开场

V3 的 MLA 把”小世界”压到 latent space 里;V4 的 MLA 把”小世界”做大了,每个 head 都自带 512 维。


2.1 引子:V3 的 MLA 走到了什么尽头

V2 和 V3 的 MLA(Multi-head Latent Attention)有一个非常优雅的设计:把 KV 压到一个低秩的潜在空间(kv_lora_rank=512),存的是 latent,用的时候再升维。这套设计在 128K context 以内非常成功,KV cache 体积几乎砍到原始 MHA 的 7%。

但当 context 从 128K 推到 1M 时,V3 的 MLA 暴露出三个工程伤疤:

伤疤一:升维矩阵 wkv_b 在 prefill 阶段的算力消耗超线性增长

V3 在 attention 计算前要把 latent KV 升回 head_dim,这一步是 kv_lora_rank → n_heads × head_dim 的矩阵乘。1M context 下,仅这一步就要跑 1048576 × 512 × 128 × 192 ≈ 1.3e16 FLOPs / layer。乘以 61 层和 batch_size,prefill 时 wkv_b 占了总 FLOPs 的相当一部分。

伤疤二:KV cache 的”低秩压缩”假设在长上下文上变弱

短上下文里,512 维 latent 存得下绝大多数 KV 信息。但长上下文里,序列经历的”语义切换”次数大幅增加(一篇 1M token 的论文集可能横跨几十个领域),低秩假设的成立基础动摇——开始出现”latent 表达不足”导致的细节丢失。

伤疤三:升维矩阵的 grad 与 act memory 在训练时不可忽略

V3 训练时,wkv_b 的 backward 需要 latent 的 forward activation 和 head_dim 的 grad,二者乘积是个相当大的中间张量。在 FP8 训练里,这部分激活的精度损失也在累积。

V4 的工程师们做了一个几乎完全相反的选择:抛弃 latent KV,每个 head 自己长成 512 维

flowchart LR
  subgraph V3MLA["V3 MLA"]
    direction LR
    XV3[x] --> WkvA[wkv_a] --> Latent["KV latent<br/>kv_lora_rank=512"] --> WkvB[wkv_b] --> Hkv["K, V 升回<br/>n_heads × 192"]
  end
  subgraph V4MLA["V4 MLA"]
    direction LR
    XV4[x] --> Wkv[wkv] --> KV["KV 直接到<br/>head_dim=512"]
  end

V3 是”压缩 → 升维 → 用”,V4 是”直接到位 → 用”。看起来 V4 用了更大的 KV,但因为加了滑窗 + Compressor + Indexer 稀疏选取,整体 KV cache 反而比 V3 还小(对照 §1.9·补·补 的 KV 公式)。

这章拆 V4 的 MLA 部分(不含 Compressor / Indexer / sparse_attn 那三个独立模块——分别在第 3、4、5 章展开),重点回答四个问题:

  1. 为什么 head_dim 从 192 跳到 512?
  2. grouped O 投影(o_groups=16)到底解决了什么?
  3. attn_sink 这个新增参数是干嘛的?
  4. Q 路径上的二次归一化(q *= rsqrt(...))为什么必须?

2.2 V4 的 Attention 类全景

先把 V4 的 Attention 类(inference/model.py)的 __init__ 完整摆出来——

class Attention(nn.Module):
    """Multi-head Latent Attention (MLA) with sliding window + optional KV compression."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.o_lora_rank = args.o_lora_rank
        self.head_dim = args.head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = args.head_dim - args.rope_head_dim
        self.n_groups = args.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = args.window_size
        self.compress_ratio = args.compress_ratios[layer_id]
        self.eps = args.norm_eps

        self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))

        # Q path
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)

        # KV path
        self.wkv = Linear(self.dim, self.head_dim)
        self.kv_norm = RMSNorm(self.head_dim, self.eps)

        # O path (grouped low-rank)
        self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups,
                                          self.n_groups * args.o_lora_rank,
                                          dtype=torch.bfloat16)
        self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)

        self.softmax_scale = self.head_dim ** -0.5

        # Compressor / Indexer (per-layer)
        if self.compress_ratio:
            self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
            if self.compress_ratio == 4:
                self.indexer = Indexer(args, self.compress_ratio)
            else:
                self.indexer = None

        kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim))

把这 30 多行拆成五块:

  • Q 路径wq_a → q_norm → wq_b(与 V3 相同)
  • KV 路径wkv → kv_normV3 是 wkv_a + kv_norm + wkv_b 三段,V4 砍成一段
  • O 路径wo_a → wo_bV4 新增 grouped low-rank
  • 稀疏路径compressor + indexerV4 新增,第 3、4 章)
  • 运行时状态attn_sink + kv_cache + freqs_cis

本章聚焦前三块。


2.3 head_dim=512 的几何与代数

V4 把 head_dim 从 V3 的 192(其中 nope=128, rope=64)拉到 512(nope=448, rope=64)。这是一次”维度大改造”——但它的代数动机非常清晰:

2.3.1 KV 不再被压缩到 latent,head_dim 必须扛起”语义带宽”

V3 时代,每个 head 的语义负载主要由 latent space 承担,所以 head_dim 可以小(192)。V4 抛弃 latent 后,每个 head 必须独自承载完整的 K/V 语义——head_dim 必须显著扩大。

512 这个具体数字的来源大概是:

  • V3 的 kv_lora_rank=512——V4 的 head_dim 取相同数字,意味着”V3 时代用整段 latent 承载的语义,现在被 V4 的每个 head 单独承载”
  • 512 = 4 × 128——便于 FP8 的 128 块状量化
  • 512 是一个”大但不离谱”的数——在保持 head 数 128 不变的情况下,单层 KV 维度从 V3 的 128×192 变成 V4 的 1×512(注意 V4 的 num_key_value_heads=1)

2.3.2 num_key_value_heads=1:所有 Q head 共享一组 KV

V4 的 config.json 里有一个容易被忽略的字段:num_key_value_heads=1。这意味着 V4 的 KV 头数只有 1——128 个 Q head 共享同一组 K/V

这是 GQA 思路推到极致的版本:

  • MHA:128 Q head + 128 KV head
  • GQA-8:128 Q head + 8 KV head
  • MQA:128 Q head + 1 KV head

V4 在结构上接近 MQA,但 head_dim 大到 512——每个 token 的 KV 总占用是 1 × 512 = 512。对照 V3 的 GQA + latent:每 token KV 总占用是 128 × 192 / kv_lora_rank ≈ 512(latent 形式)。两者总占用接近,但 V4 不再需要”升维矩阵”。

2.3.3 head_dim 拆分:rope=64 + nope=448

V4 的 head_dim 内部按 V3 一样拆成两段:

  • rope 部分(最后 64 维):参与 RoPE 旋转,承载位置信息
  • nope 部分(前 448 维):不参与 RoPE,承载语义信息

源码里这样体现(Attention.forward):

apply_rotary_emb(q[..., -rd:], freqs_cis)           # 只对最后 rd=64 维做 RoPE
apply_rotary_emb(kv[..., -rd:], freqs_cis)
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)  # 只对前 448 维做 FP8 量化

这种”语义 / 位置分离”的拆法在 V2 就引入了,V4 沿用——位置信息不参与 FP8 量化(保 BF16 精度),语义信息可以容忍量化损失。


2.4 grouped O 投影:V4 的”O LoRA” 创新

V3 的 O 投影是单矩阵 wo: [n_heads × v_dim → dim]。V4 改成 grouped low-rank

# V4
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups,
                                  self.n_groups * args.o_lora_rank,
                                  dtype=torch.bfloat16)
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)

让我们一步一步算这个矩阵的形状变换:

  • 原始 attention 输出:o: [B, S, n_heads × head_dim] = [B, S, 128 × 512] = [B, S, 65536]
  • 按 16 组拆分:每组 n_heads × head_dim / n_groups = 128 × 512 / 16 = 4096
  • wo_a[group_dim=4096 → 16 × o_lora_rank=16384](每组独立做低秩投影)
  • wo_b[16384 → dim=7168](合并 16 组并投回 hidden)

对比 V3 的单矩阵:

  • V3:wo: [128 × 128 → 7168](参数量 16384 × 7168 ≈ 117 M)
  • V4:wo_a + wo_b:参数量 4096 × 16384 + 16384 × 7168 ≈ 184 M

V4 的参数量反而更多?是的——但 V4 的”代数表达力”也更强:

  • V3 的 wo 是一个全连接:每个输出维度都看所有 head 的所有维度
  • V4 的 wo_a 是 grouped:每组只看本组内的 head 维度(local mixing),然后 wo_b 再做”组间合并”(global mixing)

这种”局部 + 全局”的两段式结构在表达力上严格强于单矩阵 + 同等参数量——这是 grouped LoRA 论文的主要论点。V4 把它工业化用到了 attention 的 O 投影上。

2.4.1 为什么是 16 组而不是 8 / 32 / 64

n_groups=16 这个数字的工程权衡:

  • 太少(< 8):组内还是太大,没有从分组中拿到表达力红利
  • 太多(> 32):组间合并的 wo_b 矩阵太大(与 group 数线性增长),失去参数效率
  • 16 = n_heads / 8:每组刚好对应 8 个 head,是”组内能做多 head 局部混合”的最小可用配置

n_groups=16 + o_lora_rank=1024 的组合,让 V4 的 O 投影的总参数量约为单矩阵版本的 1.6 倍,但表达力提升远超 1.6 倍——这是个非常划算的工程交易。

2.4.2 forward 里的 einsum

源码里 O 投影的 forward 部分:

o = o.view(bsz, seqlen, self.n_local_groups, -1)
wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
x = self.wo_b(o.flatten(2))

注意 wo_a.weight.view(n_local_groups, o_lora_rank, -1)——把 wo_a 的权重按组 reshape,让每组拥有独立的低秩投影矩阵。einsum 实现了”每个 group 独立做 [group_dim → o_lora_rank]“的并行计算。

最后 o.flatten(2) 把 16 组的 o_lora_rank 拼起来变成 [B, S, 16 × 1024],再过 wo_b 投回 hidden。


2.4·补 grouped O 投影的几何直觉

如果觉得 grouped LoRA 的代数形式不直观,我们可以从几何角度给一个直觉解释。

考虑一个简化场景:假设你有 8 个 head 的输出向量,每个 head 16 维,你要把它们投回 32 维的 hidden。

单矩阵方案(V3)

o_total: [128 维] = [8 head × 16 维]
W: [128 → 32],参数量 = 4096

每个 hidden 输出维度都看完所有 8 个 head 的所有 16 维。

grouped LoRA 方案(V4 简化版,2 组 × 4 head/组)

o_grouped: [2 组 × 64 维],每组包含 4 个 head 共 64 维
W_a: [64 → 16] × 2 组(每组独立),共参数量 = 64 × 16 × 2 = 2048
W_b: [32 → 32],参数量 = 32 × 32 = 1024
总参数 = 3072

参数量更少,但表达力更结构化:

  • 第一阶段(W_a):每个 group 内部做”局部混合”——4 个 head 的信息被压成 16 维
  • 第二阶段(W_b):在 group 间做”全局混合”——2 个 group 的信息被混回 32 维

这种”先局部、再全局”的两段结构,比单矩阵的”一步全连接”更符合 attention head 的实际语义结构:相似 head(比如同一个语法关系的 head)天然会聚成一组,组内可以共享更多投影方向。

V4 的 n_groups=16 + o_lora_rank=1024 + n_heads=128 这套配置,相当于:

  • 8 个 head 一组(相邻 head 倾向于学到相似的 attention pattern)
  • 每组独立做低秩投影到 1024 维
  • 16 组的 1024 维拼成 16384 维,再映射回 7168 维 hidden

这种结构既省了参数(相对单矩阵的 65536 × 7168 ≈ 470M),又给 attention head 的”自然分组”留了空间。

flowchart LR
  subgraph V3O["V3 单矩阵 O"]
    direction TB
    OA1["o: 128×128"] --> WO["W_o: 16384→7168"]
    WO --> XV3["x: 7168"]
  end
  subgraph V4O["V4 grouped O"]
    direction TB
    OA2["o: 128×512<br/>= 65536"]
    OA2 --> Reshape["reshape 16 组 × 4096"]
    Reshape -->|每组独立| Wa["wo_a: 4096→1024<br/>×16 组"]
    Wa -->|拼接| Cat["16×1024<br/>= 16384"]
    Cat --> Wb["wo_b: 16384→7168"]
    Wb --> XV4["x: 7168"]
  end

2.5 attn_sink:稀疏注意力的”兜底”

V4 在 Attention.__init__ 里有这一行:

self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))

每个 head 一个 float32 的可学习标量。这个参数在 sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) 调用里被传进去——它扮演的角色是”当稀疏选取找不到合适的 KV 时,给 attention 一个安全兜底”。

具体的数学:

  • 标准 softmax attention:P = softmax(Q K^T / √d)
  • 加 sink:P_sink = softmax([Q K^T / √d, sink_bias])——把 sink 当作一个虚拟的 KV,与真实 KV 一起参与 softmax 归一化

这种”attention sink”的概念最早来自 Streaming LLM 论文,原本是为长上下文 streaming 时的”开头 token 异常吸收注意力”提出的。V4 把它用在稀疏 attention 上——当 Indexer 没有选到任何”合适的”KV 时,attn_sink 接住了该 head 的注意力质量,避免数值崩溃。

这个参数的存在是 V4 团队”工程纪律”的一个具体体现——他们没有依赖”稀疏选取永远能找到合适的 KV”这种乐观假设,而是给数值稳定性做了一个 fallback。


2.5·补 attn_sink 的训练动力学

attn_sink 这个参数在训练里的角色是什么?为什么它必须是可学习的 float32 标量而不是 0 或硬编码常数?

数值角度

在标准的 softmax attention 里,给每个 KV 位置一个 logit s_i,softmax 输出 exp(s_i) / sum_j exp(s_j)。如果所有 s_i 都是负数(极端情况:稀疏选取选错了,剩下的 KV 都和 query 不相关),那 softmax 仍然会强行把 weight 分到这些”不相关”位置上——这是 softmax 的质量守恒性

加 attn_sink 后,softmax 变成 exp(s_i) / (sum_j exp(s_j) + exp(sink))。当所有 s_i 都很负、而 sink 是正数时,sink 项会主导分母——意味着所有真实位置的 weight 都被压低,等效于”这一步 attention 选择了不输出”

这是稀疏注意力数值稳定性的关键——给模型一个”我可以什么都不选”的安全出口。

学习角度

attn_sink 必须可学习,因为不同 layer / 不同 head 对”不输出”的偏好是不同的:

  • 早期层(更接近 raw token):sink 应该偏低,因为每个位置都”应该有事可做”
  • 中期层(语义抽象):sink 应该中等,给”想跳过这一层 attention” 的可能性
  • 晚期层(表征整合):sink 应该偏低,因为最终的 logits 必须依赖具体 KV

V4 的 attn_sink 是 [n_local_heads] 维度的,意味着每个 head 独立学习自己的 sink——这给了模型最大的灵活性。

为什么是 float32

源码里 self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) ——明确指定 float32。原因是 sink 在 softmax 里参与的是与 s_i / √d 同量级的数值比较,而 s_i / √d 在长上下文下可能跨越很大的动态范围(log 概率差异可能达到 30-40)。BF16 的有效精度(约 7 bit mantissa)在这个量级下就不够了——必须用 float32 保证比较的稳定。


2.5·补·补 attn_sink 与 “softmax + 1” 的关系

attn_sink 概念在学术界有几个孪生兄弟,把它们摆在一起更容易看清 V4 的选择来自哪里:

Streaming LLM 的 sink token:在序列开头保留 4 个 token 不参与滑窗淘汰,让它们持续吸收”溢出注意力”。这是位置层面的 sink。

softmax + 1Off-by-One Errors 提出):把 softmax 的分母多加一个 exp(0) = 1,等价于给一个虚拟的”什么都不做”位置。这是数值层面的 sink。

learnable sink scalar(V4 的方案):每 head 一个可学习 float32 标量,作为虚拟 KV 的 logit 值参与 softmax。这是参数层面的 sink。

V4 选 learnable scalar 而非 softmax+1 的原因是层间差异——不同 layer / 不同 head 对”该不该输出”的偏好不同。一个固定的 +1 在所有 head 上等价,而 learnable scalar 让每个 head 自己决定 sink 的强度。

V4 选 scalar 而非 sink token 的原因是位置无关——稀疏 attention 选取的是动态的 KV 位置集合,没有”固定开头几个 token”的概念。scalar 不依赖位置,更适合稀疏架构。

这种”工程纪律”——选最简单但同时最灵活的形式——在 V4 源码里反复出现。


2.6 Q 路径的二次归一化

Attention.forward 里有这样一行(容易被忽略):

q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)

第一行是常规的 wq_b 投影 + reshape。第二行做了一个针对 q 的 inline RMSNorm——这个归一化在 q_norm(RMSNorm(q_lora_rank),作用在 q 投影前)之外额外做的一次

为什么需要二次归一?有两个原因:

  1. 稀疏 attention 对 Q/K 数值范围敏感:稀疏选取依赖 Q · K 的 dot product 排序——如果 Q 的 norm 在不同 head 间差异很大,softmax 会被某些 head 主导,稀疏选取会偏。这个 inline rsqrt 把每个 head 的 q 拉回到单位球面附近,让所有 head 在 score 比较时是公平的。
  2. 配合 attn_sink 的数值稳定:sink 是 float32 标量,q 也需要保持稳定的数值范围才能与 sink 在同一 softmax 里比较——inline 归一化保证了这一点。

这是一个”小代码、大影响”的细节——不写它,稀疏注意力的精度可能会出现非预期的层间差异。


2.7 KV 路径的 RoPE 与 FP8 量化分离

V4 在 KV 路径上的处理也很讲究——

kv = self.wkv(x)
kv = self.kv_norm(kv)
apply_rotary_emb(kv[..., -rd:], freqs_cis)              # 只对最后 64 维做 RoPE
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)  # 只对前 448 维做 FP8 量化

注意 act_quant 的第二个参数是 64——这是 quantization 的块大小。前 448 维被分成 7 个 64-维 block,每个 block 一个 FP8 scale。

**为什么 RoPE 部分不做 FP8?**因为 RoPE 的旋转矩阵是浮点旋转,FP8 的精度对它太低——会引入 phase error。所以 V4 把 RoPE 部分留在 BF16,只对语义部分做 FP8。

这种”位置 / 语义分离 → 位置走 BF16、语义走 FP8”的设计,是 V4 整个混合精度策略的一个缩影。第 12 章会从全局角度解释。


2.7·补 V3 与 V4 的 KV 数学对比

把 V3 和 V4 的 KV 处理逐步骤摆出来,可以更清楚地看到 V4 抛弃 latent 之后的代数节省:

V3 KV 路径(per-layer, per-token)

1. 投影:x ∈ R^7168  --wkv_a-->  latent ∈ R^576 (kv_lora_rank=512 + qk_rope=64)
   FLOPs: 7168 × 576 = 4.13M

2. 归一化:latent --kv_norm-->  latent_normed ∈ R^576
   FLOPs: O(576)

3. 升维:latent --wkv_b-->  k_v ∈ R^(n_heads × (qk_nope + v_dim)) = R^(128 × 320) = R^40960
   FLOPs: 512 × 40960 = 20.97M

4. 应用 RoPE:rope 部分独立处理
   FLOPs: O(64) per head

5. attention 计算:q · k 内积、softmax、加权求和
   FLOPs: n_heads × seqlen × head_dim

总计 KV 侧 FLOPs(不含 attention 本身):约 25M / token / layer
KV cache 存储:latent 形式 → 576 bytes (BF16) per token per layer

V4 KV 路径(per-layer, per-token, ratio=4 层)

1. 投影:x ∈ R^7168  --wkv-->  kv ∈ R^512
   FLOPs: 7168 × 512 = 3.67M

2. 归一化:kv --kv_norm-->  kv_normed ∈ R^512
   FLOPs: O(512)

3. (无升维步骤)

4. 应用 RoPE:rope 部分独立处理(最后 64 维)
   FLOPs: O(64)

5. FP8 量化:前 448 维做 act_quant
   FLOPs: O(448)

6. (Compressor 路径,每 4 token 一次):把 4 个 token 的 KV 压成 1 组
   FLOPs: 摊销下来约 0.3 × (3.67M / 4) = 0.27M

7. attention 计算:滑窗 + Indexer top-1024
   FLOPs: n_heads × (window_size + 1024) × head_dim   ← 不再线性增长

总计 KV 侧 FLOPs(不含 attention 本身):约 4M / token / layer
KV cache 存储:每 ratio 个 token 存一组 → 512 / 4 = 128 bytes/token/layer

V4 在 KV 侧的 FLOPs 大约只有 V3 的 1/6,KV cache 存储约 V3 的 1/4。两者乘起来再加上 attention 本身的 O(n × topk) 成本(V3 是 O(n²)),整体砍到 V3.2 的 27% 就不再神秘。


2.8 与 GQA / MQA 的横向对比

把 V4 的 attention 与同代的 GQA / MQA 摆在一起:

维度MHAGQA-8MQAV3 MLAV4 MLA
Q head 数128128128128128
KV head 数12881latent (单组)1
head_dim128128128192 (nope+rope)512 (nope+rope)
KV / token128 × 128 × 28 × 128 × 21 × 128 × 2512 (latent)1 × 512
升维需求wkv_b 升维
是否参数共享-KV 跨 head 共享KV 全 head 共享latent 共享KV 全 head 共享
工程兼容性FlashAttnFlashAttnFlashAttn需特化 kernel需 sparse_attn
长上下文成本极高中(升维代价)低 + 稀疏
O 投影单矩阵单矩阵单矩阵单矩阵grouped LoRA

V4 的 attention 相当于:MQA 的 KV 节省 + MLA 的语义带宽 + grouped O 的表达力——在三条路上同时拿。其代价是需要一个特化的 sparse_attn kernel,这个 kernel 的存在让 V4 不能直接用 FlashAttention 跑——必须用 FlashMLA。第 5 章会展开。


2.8·补 prefill 与 decode 的两条 codepath

V4 的 Attention.forward 在 prefill(首次 forward 整段 prompt)与 decode(自回归一次只送一个 token)两个阶段走的是完全不同的 codepath。这一点在源码 Attention.forward 里通过 if start_pos == 0: 的分支显式写出。

2.8·补.1 prefill 阶段(start_pos == 0

  • q / kv 一次性算完整段:q 形状 [B, S, n_heads, 512],kv 形状 [B, S, 512]
  • KV cache 一次性写入:滑窗 KV 写入 kv_cache[:, :seqlen](如 seqlen ≤ window_size)或环形写入(seqlen > window_size 时只保留最后一窗)
  • Compressor 一次性压缩:把整段 KV 按 ratio 压缩成 seqlen / ratio
  • Indexer 一次性算 score:score [B, S, n_heads, end_pos/ratio] 一次性 topk
  • sparse_attn 一次性算完整段输出:通过 cat 后的 topk_idxs 索引到 KV,做稀疏 attention

prefill 的特点是所有计算都是矩阵化的,能充分利用 GPU 的 TensorCore;瓶颈通常在 KV 投影和 sparse_attn kernel。

2.8·补.2 decode 阶段(start_pos > 0

  • q / kv 只算 1 个 token:q [B, 1, n_heads, 512],kv [B, 1, 512]
  • KV cache 增量写入kv_cache[:, start_pos % window_size] = kv.squeeze(1)——按环形 buffer 覆盖最早的 token
  • Compressor 增量更新:把新 KV 累积到 kv_state buffer,每 ratio 步触发一次”压缩落盘”
  • Indexer 增量打分:query 一行、对全部已存压缩 KV 打分、topk
  • sparse_attn 增量算 1 个 token 的输出:query 单 token,KV 是窗 + 压缩区合集

decode 的特点是每步只算 1 token,但要查所有历史 KV——瓶颈在内存带宽。V4 的”窗 + 压缩”配合 Indexer 的 top-1024,使得 decode 时只需查 ~1024 + 128 个 KV 位置而非全部历史,这是 V4 在长上下文 decode 阶段保持 throughput 的关键。

2.8·补.3 两条 codepath 的协调

V4 的 Compressor 和 Indexer 在两条 codepath 上用同一个权重矩阵——一组训练好的 wkv / wgate / weights_proj 必须同时支持”批量压缩整段”和”增量压缩单 token”。源码里的 if start_pos == 0 分支就是为了让这两种行为产出数学等价的结果

flowchart TB
  subgraph Prefill["Prefill (start_pos == 0)"]
    P1["seqlen=S 个 token 一起算"]
    P2["KV cache 一次性写入"]
    P3["Compressor 批量压 S/ratio 组"]
    P4["Indexer 批量打分 + topk"]
    P5["sparse_attn 算整段输出"]
  end
  subgraph Decode["Decode (start_pos > 0)"]
    D1["seqlen=1 个 token"]
    D2["KV cache 环形增量"]
    D3["Compressor 增量缓存,<br/>每 ratio 步压一组"]
    D4["Indexer 增量打分"]
    D5["sparse_attn 算 1 个 token"]
  end
  Prefill -.数学等价.-> Decode

这种”一份权重、两条 codepath、数学等价”是 V4 工程纪律的另一个体现——它意味着任何一个 layer 的内部计算都被精心设计成”prefill / decode 不出现数值漂移”。


2.9 一段 V3 → V4 的迁移练习

如果你已经熟悉 V3 的 MLA 实现,下面这个练习能让你”用最小修改”把 V3 的 attention 改造成 V4 风格的雏形:

# 假设你有一个 V3 attention 的简化版本
class V3Attention(nn.Module):
    def __init__(self, dim, n_heads, kv_lora_rank, head_dim):
        ...
        self.wkv_a = Linear(dim, kv_lora_rank)
        self.kv_norm = RMSNorm(kv_lora_rank)
        self.wkv_b = Linear(kv_lora_rank, n_heads * head_dim * 2)
        self.wo = Linear(n_heads * head_dim, dim)

# 改造为 V4 雏形(不含 Compressor/Indexer/sparse_attn)
class V4AttentionMinimal(nn.Module):
    def __init__(self, dim, n_heads, head_dim_v4=512, n_groups=16, o_lora_rank=1024):
        ...
        # 1. 砍掉 wkv_a / wkv_b,只留单矩阵 wkv,输出维度直接到 head_dim_v4
        self.wkv = Linear(dim, head_dim_v4)
        self.kv_norm = RMSNorm(head_dim_v4)
        # 2. 把 wo 替换为 grouped LoRA
        self.wo_a = Linear(n_heads * head_dim_v4 // n_groups, n_groups * o_lora_rank)
        self.wo_b = Linear(n_groups * o_lora_rank, dim)
        # 3. 加 attn_sink
        self.attn_sink = nn.Parameter(torch.zeros(n_heads))

    def forward(self, x):
        ...
        kv = self.kv_norm(self.wkv(x))           # [B, S, head_dim_v4]
        # KV 不再升维——所有 head 共享这一组 KV
        # ... attention ...
        o = ...                                  # [B, S, n_heads, head_dim_v4]
        # grouped O
        o = o.view(B, S, n_groups, -1)
        o = torch.einsum("bsgd,grd->bsgr", o, self.wo_a.weight.view(n_groups, o_lora_rank, -1))
        x = self.wo_b(o.flatten(2))
        return x

这个雏形只缺三件事:滑窗 KV cache、Compressor、Indexer + sparse_attn——分别是接下来三章的内容。


2.9·补 V4 attention 实现的三个工程坑

读 V4 的 Attention.forward 源码,至少有三个细节如果没注意会在自己的实现里翻车。这些都是任何人对照源码实现时都会撞到的——不需要内部资料。

坑一:unflatten 的维度顺序

V4 用了多次 q.unflatten(-1, (self.n_local_heads, self.head_dim))。这一步把最后一维拆成 (n_heads, head_dim)——但默认是先 n_heads 后 head_dim。如果你写成 unflatten(-1, (head_dim, n_heads)),整个 attention 的语义就反了。

V4 的源码在 unflatten 之后立刻接 apply_rotary_emb(q[..., -rd:], freqs_cis)——这步只对最后 64 维做 RoPE。如果 unflatten 维度顺序写错,RoPE 会作用到错误的维度上,模型 forward 不会报错但输出全乱。

坑二:滑窗的环形写入索引

V4 的滑窗 KV cache 在 prefill 时的写入逻辑:

if seqlen <= win:
    self.kv_cache[:bsz, :seqlen] = kv
else:
    cutoff = seqlen % win
    self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)

这段代码处理的是”prefill 完一段 prompt 后,window_size=128 的环形 buffer 状态”。如果直接 self.kv_cache[:, :win] = kv[:, -win:],当 seqlen > win 时确实保留了最后一窗的 KV,但环形索引是错的——decode 阶段写入第 (start_pos % win) 个位置时会覆盖错位置。

V4 的处理是:把最后一窗 KV 按”prefill 之后 next decode position 应该从哪开始”切成两段,分别写入 cache 的两段。这样 decode 阶段 kv_cache[:, start_pos % win] = new_kv 写入的位置就是正确的。

坑三:Indexer 的 offset 参数

V4 的 Indexer.forward(x, qr, start_pos, offset) 的第四个参数 offset 不太显眼,但极其关键:

# Attention.forward 中
if self.compress_ratio:
    offset = kv.size(1) if start_pos == 0 else win
    if self.indexer is not None:
        compress_topk_idxs = self.indexer(x, qr, start_pos, offset)

offset 是稀疏 KV 索引相对于 kv_cache 起始位置的偏移。在 prefill 时 offset = kv.size(1)(所有滑窗 KV 占据的位置);在 decode 时 offset = win(滑窗已固定占满前 win 个位置)。如果 offset 算错,sparse_attn 索引到的 KV 位置就是错的——读到滑窗 KV 当压缩 KV 用,输出就会乱。

这三个坑都是索引 / 顺序 / 偏移类的细节,没有任何源码注释专门说明,但它们是 V4 attention 正确工作的隐性合约。


2.10 动手实验:手算一个 V4 head_dim 的 KV cache 占用

# 给定参数(V4 Pro)
n_layers = 61
window_size = 128
head_dim = 512

# 假设所有层都是 ratio=4 的稀疏层(最 KV-密集的情况)
# n_tok = 1M
n_tok = 1_048_576
ratio = 4
kv_per_layer = (window_size + n_tok // ratio) * head_dim
total_kv_bytes = n_layers * kv_per_layer * 2  # BF16 = 2 bytes
print(f"V4 KV cache (1M, all-ratio-4): {total_kv_bytes / 1e9:.2f} GB")

# 改成混合 ratio(V4 Pro 实际配置:约 30 层 ratio=4, 30 层 ratio=128, 1 层 ratio=0)
sum_kv = 0
for ratio in [128, 128] + [4 if i % 2 == 0 else 128 for i in range(58)] + [0]:
    if ratio == 0:
        sum_kv += n_tok * head_dim   # 不压缩
    else:
        sum_kv += (window_size + n_tok // ratio) * head_dim
print(f"V4 KV cache (1M, mixed-ratio): {sum_kv * 2 / 1e9:.2f} GB")

跑完会发现”全 ratio=4”配置在 1M 下要 16+ GB,而 V4 实际混合 ratio 配置在 1M 下只要 2 GB 量级。per-layer 非均匀压缩是 V4 长上下文成本可控的真正秘诀。


2.10·补 与 vLLM 对接时的 KV cache 形状错位

V4 在 vLLM 等推理引擎里的对接难度,主要来自一个事实:V4 的 KV cache 形状不是 vLLM 默认的 PagedAttention 形状

vLLM 默认的 KV cache 是 [num_blocks, block_size, num_kv_heads, head_dim]——按 block_size(典型 16)分块、按 KV head 数分组。这种形状假设:

  • KV head 数固定且较小(典型 1 / 4 / 8)
  • 每个 KV head 维度固定(典型 64 / 128)
  • KV 不需要分层处理

V4 的 KV cache 形状是 [max_batch_size, window_size + max_seq_len // ratio, head_dim]——按”滑窗 + 压缩 KV”的两段拼接。这种形状假设:

  • 每层 KV head 数都是 1(MQA-style)
  • head_dim 极大(512)
  • KV 内部分两段:滑窗段 + 压缩段
  • ratio 逐层不同

把 V4 跑进 vLLM,至少要做三件事:

  1. 新增一个 V4 专属的 KV cache 类型(不能复用 PagedAttention 的 KVCache class)
  2. 在 BlockManager 里给”滑窗段”和”压缩段”分别分配 block——一个用环形覆盖,一个按 ratio 增长
  3. 接入 sparse_attn kernel——FlashMLA 的 V4 分支需要单独编译进 vLLM 的 wheel

第 19 章会展开这条工程接缝,并给出 vLLM 主仓库 V4 适配 PR 的详细解读。


2.10·补·补 attention 算子内部的精度链路

V4 在 attention 内部的精度切换非常密集。以一个典型的 ratio=4 层为例,逐 op 看精度:

x: BF16 [B, S, 7168]
  --wq_a (FP8)--> 
  q_lora: BF16 [B, S, 1536] (FP8 GEMM 在内部用 FP32 累加,输出回 BF16)
  --q_norm (FP32)-->
  q_normed: BF16 [B, S, 1536]
  --wq_b (FP8)-->
  q: BF16 [B, S, n_heads, 512]
  --inline rsqrt (FP32)-->
  q_normalized: BF16 [B, S, n_heads, 512]
  --apply_rotary_emb (FP32 复指数 + BF16 输出)-->
  q_with_rope: BF16

x ----wkv (FP8)--> kv: BF16 [B, S, 512]
       --kv_norm (FP32)--> kv_normed: BF16
       --apply_rotary_emb--> kv_rope: BF16
       --act_quant (FP8 e4m3 + ue8m0 scale)--> kv_quantized: FP8 [B, S, 448] + scale [B, S, 7]
       --(rope 部分留 BF16)

q · kv → score: FP32 [B, S, n_heads, K]   ← FP8 GEMM 内部 FP32 累加
softmax(score + sink): FP32
weighted sum → o: BF16 [B, S, n_heads, 512]

apply_rotary_emb(o[..., -64:], inverse=True): BF16

o.view + einsum(wo_a, FP8) → o_lora: BF16 [B, S, 16, 1024]
o.flatten + wo_b (FP8) → x_out: BF16 [B, S, 7168]

关键观察:

  • GEMM 输入是 FP8,输出是 BF16——这是 DeepGEMM 的”输入降精度、累加用 FP32、输出升 BF16” 模式
  • norm 操作内部用 FP32——RMSNorm 的 mean / rsqrt 必须 FP32 才稳定
  • softmax 和 rotary 都用 FP32——避免概率分布的尾部精度损失
  • 激活值默认 BF16,只有进入 GEMM 时才被即时量化到 FP8

这套精度链路的设计哲学是:“算的时候用 FP32 + FP8、存的时候用 BF16 + FP8”——计算密度高的 op 走低精度,激活值的中间存储走 BF16 保稳定。这与 V3 的精度链路有相似之处,但 V4 因为引入了 sparse_attn 和 grouped O,链路更复杂。第 12-14 章会全面展开。


2.10·延展 V4 attention 的长尾性能与可观测性

V4 的 attention 实现因为引入稀疏选取,性能特征与 dense attention 有本质差异——在生产中需要新的可观测性指标。

长尾现象一:稀疏命中率

Indexer 选 top-1024 的 KV 位置。如果命中的 KV 实际语义上不相关,attention 输出会偏向 attn_sink(接近”不输出”)。可观测指标:

  • 每层的 attn_sink 在 softmax 中占的比例
  • 当 attn_sink 占比突然升高时,说明 Indexer 选取出现”系统性失误”
  • 监控 attention output 的 norm 分布:如果某些 token 的输出 norm 显著低于群体均值,可能是稀疏选取失败的信号

长尾现象二:滑窗与压缩段的”接缝处”

每层的 KV cache 有两段——滑窗(最近 128 token)和压缩段(远距离)。当一个 token 从滑窗”溢出”进入压缩段,它的可见性会突然变化(从 dense 变成稀疏选取)。这个”接缝”在某些任务上会导致回答质量的不连续——例如总结任务里,模型可能”忽略” 刚刚被踢出滑窗的关键信息。

可观测指标:监控”刚溢出滑窗的 token”的 attention 权重,如果显著小于”还在滑窗里的 token”,说明 Compressor 在该层对这些 token 的语义压缩不足。

长尾现象三:per-layer ratio 的”熔点”

V4 的 compress_ratios 是预定义的,运行时不变。但不同 prompt 的 KV 分布差异很大——某些 prompt 在 ratio=4 的层就能保留所有关键信息,某些 prompt 在 ratio=128 的层会丢掉重要细节。

工程上没法运行时改 ratio(这会让权重不匹配),但可以通过 system prompt 工程间接调整:让 prompt 显式重复关键信息,确保它们出现在多个滑窗位置,绕过压缩段的语义损失。这是 V4 时代的 prompt engineering 新维度。

推荐的生产监控面板

如果你部署 V4 到生产,至少应该追踪以下 6 个指标:

  1. attn_sink_ratio(每层):sink 占 softmax 概率的比例
  2. window_to_compress_attention_ratio(每层):滑窗段获得的 attention 占比 vs 压缩段
  3. indexer_topk_overlap(每层):相邻 token 的 topk 选取重合度——重合度高说明稀疏选取稳定
  4. kv_compression_loss_proxy:用一个小的 reconstruction head 度量 Compressor 输出与原 KV 的差异
  5. mtp_consistency:MTP head 的预测与主 head 的一致率——两者都偏离表示 attention 出错
  6. end_to_end_perplexity_per_layer:每层去掉时的 perplexity 增量——找出”关键层”

这些指标在 V3 dense attention 时代不需要,但 V4 时代是生产部署的必备。


2.11 延伸阅读


2.12 本章小结

  • V4 抛弃 latent KV,把 head_dim 拉到 512——每个 head 自带完整的 KV 语义带宽
  • num_key_value_heads=1 让 V4 在结构上接近 MQA,但 head_dim 大到 512,单 token KV 占用与 V3 接近,少了 wkv_b 升维
  • grouped O 投影(n_groups=16, o_lora_rank=1024) 是 V4 的代数创新——以 1.6 倍参数换更强的局部 + 全局两段式表达力
  • attn_sink 是稀疏注意力的”兜底参数”——保证当 Indexer 找不到合适 KV 时数值不崩
  • Q 路径的 inline RMSNorm 是稀疏 attention 的精度保证,不可省
  • KV 路径的”位置 / 语义分离 → BF16 / FP8 分离”是 V4 整个混合精度策略的缩影

第 3 章我们进入 V4 注意力革命的第二站:Compressor——逐层独立配置的 KV 压缩模块,是怎么把 1M token 的 KV 压到几 GB 的。

评论 0