第2章 MLA 进阶:head_dim 512 与 grouped O 投影
“Each head should be a small but complete world.” —— 引自 V4 团队某次开源 talk 的开场
V3 的 MLA 把”小世界”压到 latent space 里;V4 的 MLA 把”小世界”做大了,每个 head 都自带 512 维。
2.1 引子:V3 的 MLA 走到了什么尽头
V2 和 V3 的 MLA(Multi-head Latent Attention)有一个非常优雅的设计:把 KV 压到一个低秩的潜在空间(kv_lora_rank=512),存的是 latent,用的时候再升维。这套设计在 128K context 以内非常成功,KV cache 体积几乎砍到原始 MHA 的 7%。
但当 context 从 128K 推到 1M 时,V3 的 MLA 暴露出三个工程伤疤:
伤疤一:升维矩阵 wkv_b 在 prefill 阶段的算力消耗超线性增长
V3 在 attention 计算前要把 latent KV 升回 head_dim,这一步是 kv_lora_rank → n_heads × head_dim 的矩阵乘。1M context 下,仅这一步就要跑 1048576 × 512 × 128 × 192 ≈ 1.3e16 FLOPs / layer。乘以 61 层和 batch_size,prefill 时 wkv_b 占了总 FLOPs 的相当一部分。
伤疤二:KV cache 的”低秩压缩”假设在长上下文上变弱
短上下文里,512 维 latent 存得下绝大多数 KV 信息。但长上下文里,序列经历的”语义切换”次数大幅增加(一篇 1M token 的论文集可能横跨几十个领域),低秩假设的成立基础动摇——开始出现”latent 表达不足”导致的细节丢失。
伤疤三:升维矩阵的 grad 与 act memory 在训练时不可忽略
V3 训练时,wkv_b 的 backward 需要 latent 的 forward activation 和 head_dim 的 grad,二者乘积是个相当大的中间张量。在 FP8 训练里,这部分激活的精度损失也在累积。
V4 的工程师们做了一个几乎完全相反的选择:抛弃 latent KV,每个 head 自己长成 512 维。
flowchart LR
subgraph V3MLA["V3 MLA"]
direction LR
XV3[x] --> WkvA[wkv_a] --> Latent["KV latent<br/>kv_lora_rank=512"] --> WkvB[wkv_b] --> Hkv["K, V 升回<br/>n_heads × 192"]
end
subgraph V4MLA["V4 MLA"]
direction LR
XV4[x] --> Wkv[wkv] --> KV["KV 直接到<br/>head_dim=512"]
end
V3 是”压缩 → 升维 → 用”,V4 是”直接到位 → 用”。看起来 V4 用了更大的 KV,但因为加了滑窗 + Compressor + Indexer 稀疏选取,整体 KV cache 反而比 V3 还小(对照 §1.9·补·补 的 KV 公式)。
这章拆 V4 的 MLA 部分(不含 Compressor / Indexer / sparse_attn 那三个独立模块——分别在第 3、4、5 章展开),重点回答四个问题:
- 为什么 head_dim 从 192 跳到 512?
- grouped O 投影(o_groups=16)到底解决了什么?
- attn_sink 这个新增参数是干嘛的?
- Q 路径上的二次归一化(
q *= rsqrt(...))为什么必须?
2.2 V4 的 Attention 类全景
先把 V4 的 Attention 类(inference/model.py)的 __init__ 完整摆出来——
class Attention(nn.Module):
"""Multi-head Latent Attention (MLA) with sliding window + optional KV compression."""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.layer_id = layer_id
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.o_lora_rank = args.o_lora_rank
self.head_dim = args.head_dim
self.rope_head_dim = args.rope_head_dim
self.nope_head_dim = args.head_dim - args.rope_head_dim
self.n_groups = args.o_groups
self.n_local_groups = self.n_groups // world_size
self.window_size = args.window_size
self.compress_ratio = args.compress_ratios[layer_id]
self.eps = args.norm_eps
self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
# Q path
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
# KV path
self.wkv = Linear(self.dim, self.head_dim)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
# O path (grouped low-rank)
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups,
self.n_groups * args.o_lora_rank,
dtype=torch.bfloat16)
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
self.softmax_scale = self.head_dim ** -0.5
# Compressor / Indexer (per-layer)
if self.compress_ratio:
self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
if self.compress_ratio == 4:
self.indexer = Indexer(args, self.compress_ratio)
else:
self.indexer = None
kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim))
把这 30 多行拆成五块:
- Q 路径:
wq_a → q_norm → wq_b(与 V3 相同) - KV 路径:
wkv → kv_norm(V3 是 wkv_a + kv_norm + wkv_b 三段,V4 砍成一段) - O 路径:
wo_a → wo_b(V4 新增 grouped low-rank) - 稀疏路径:
compressor + indexer(V4 新增,第 3、4 章) - 运行时状态:
attn_sink + kv_cache + freqs_cis
本章聚焦前三块。
2.3 head_dim=512 的几何与代数
V4 把 head_dim 从 V3 的 192(其中 nope=128, rope=64)拉到 512(nope=448, rope=64)。这是一次”维度大改造”——但它的代数动机非常清晰:
2.3.1 KV 不再被压缩到 latent,head_dim 必须扛起”语义带宽”
V3 时代,每个 head 的语义负载主要由 latent space 承担,所以 head_dim 可以小(192)。V4 抛弃 latent 后,每个 head 必须独自承载完整的 K/V 语义——head_dim 必须显著扩大。
512 这个具体数字的来源大概是:
- V3 的 kv_lora_rank=512——V4 的 head_dim 取相同数字,意味着”V3 时代用整段 latent 承载的语义,现在被 V4 的每个 head 单独承载”
- 512 = 4 × 128——便于 FP8 的 128 块状量化
- 512 是一个”大但不离谱”的数——在保持 head 数 128 不变的情况下,单层 KV 维度从 V3 的 128×192 变成 V4 的 1×512(注意 V4 的 num_key_value_heads=1)
2.3.2 num_key_value_heads=1:所有 Q head 共享一组 KV
V4 的 config.json 里有一个容易被忽略的字段:num_key_value_heads=1。这意味着 V4 的 KV 头数只有 1——128 个 Q head 共享同一组 K/V。
这是 GQA 思路推到极致的版本:
- MHA:128 Q head + 128 KV head
- GQA-8:128 Q head + 8 KV head
- MQA:128 Q head + 1 KV head
V4 在结构上接近 MQA,但 head_dim 大到 512——每个 token 的 KV 总占用是 1 × 512 = 512。对照 V3 的 GQA + latent:每 token KV 总占用是 128 × 192 / kv_lora_rank ≈ 512(latent 形式)。两者总占用接近,但 V4 不再需要”升维矩阵”。
2.3.3 head_dim 拆分:rope=64 + nope=448
V4 的 head_dim 内部按 V3 一样拆成两段:
- rope 部分(最后 64 维):参与 RoPE 旋转,承载位置信息
- nope 部分(前 448 维):不参与 RoPE,承载语义信息
源码里这样体现(Attention.forward):
apply_rotary_emb(q[..., -rd:], freqs_cis) # 只对最后 rd=64 维做 RoPE
apply_rotary_emb(kv[..., -rd:], freqs_cis)
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) # 只对前 448 维做 FP8 量化
这种”语义 / 位置分离”的拆法在 V2 就引入了,V4 沿用——位置信息不参与 FP8 量化(保 BF16 精度),语义信息可以容忍量化损失。
2.4 grouped O 投影:V4 的”O LoRA” 创新
V3 的 O 投影是单矩阵 wo: [n_heads × v_dim → dim]。V4 改成 grouped low-rank:
# V4
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups,
self.n_groups * args.o_lora_rank,
dtype=torch.bfloat16)
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
让我们一步一步算这个矩阵的形状变换:
- 原始 attention 输出:
o: [B, S, n_heads × head_dim]=[B, S, 128 × 512] = [B, S, 65536] - 按 16 组拆分:每组
n_heads × head_dim / n_groups = 128 × 512 / 16 = 4096维 wo_a:[group_dim=4096 → 16 × o_lora_rank=16384](每组独立做低秩投影)wo_b:[16384 → dim=7168](合并 16 组并投回 hidden)
对比 V3 的单矩阵:
- V3:
wo: [128 × 128 → 7168](参数量 16384 × 7168 ≈ 117 M) - V4:
wo_a + wo_b:参数量4096 × 16384 + 16384 × 7168 ≈ 184 M
V4 的参数量反而更多?是的——但 V4 的”代数表达力”也更强:
- V3 的 wo 是一个全连接:每个输出维度都看所有 head 的所有维度
- V4 的 wo_a 是 grouped:每组只看本组内的 head 维度(local mixing),然后 wo_b 再做”组间合并”(global mixing)
这种”局部 + 全局”的两段式结构在表达力上严格强于单矩阵 + 同等参数量——这是 grouped LoRA 论文的主要论点。V4 把它工业化用到了 attention 的 O 投影上。
2.4.1 为什么是 16 组而不是 8 / 32 / 64
n_groups=16 这个数字的工程权衡:
- 太少(< 8):组内还是太大,没有从分组中拿到表达力红利
- 太多(> 32):组间合并的 wo_b 矩阵太大(与 group 数线性增长),失去参数效率
- 16 = n_heads / 8:每组刚好对应 8 个 head,是”组内能做多 head 局部混合”的最小可用配置
n_groups=16 + o_lora_rank=1024 的组合,让 V4 的 O 投影的总参数量约为单矩阵版本的 1.6 倍,但表达力提升远超 1.6 倍——这是个非常划算的工程交易。
2.4.2 forward 里的 einsum
源码里 O 投影的 forward 部分:
o = o.view(bsz, seqlen, self.n_local_groups, -1)
wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
x = self.wo_b(o.flatten(2))
注意 wo_a.weight.view(n_local_groups, o_lora_rank, -1)——把 wo_a 的权重按组 reshape,让每组拥有独立的低秩投影矩阵。einsum 实现了”每个 group 独立做 [group_dim → o_lora_rank]“的并行计算。
最后 o.flatten(2) 把 16 组的 o_lora_rank 拼起来变成 [B, S, 16 × 1024],再过 wo_b 投回 hidden。
2.4·补 grouped O 投影的几何直觉
如果觉得 grouped LoRA 的代数形式不直观,我们可以从几何角度给一个直觉解释。
考虑一个简化场景:假设你有 8 个 head 的输出向量,每个 head 16 维,你要把它们投回 32 维的 hidden。
单矩阵方案(V3):
o_total: [128 维] = [8 head × 16 维]
W: [128 → 32],参数量 = 4096
每个 hidden 输出维度都看完所有 8 个 head 的所有 16 维。
grouped LoRA 方案(V4 简化版,2 组 × 4 head/组):
o_grouped: [2 组 × 64 维],每组包含 4 个 head 共 64 维
W_a: [64 → 16] × 2 组(每组独立),共参数量 = 64 × 16 × 2 = 2048
W_b: [32 → 32],参数量 = 32 × 32 = 1024
总参数 = 3072
参数量更少,但表达力更结构化:
- 第一阶段(W_a):每个 group 内部做”局部混合”——4 个 head 的信息被压成 16 维
- 第二阶段(W_b):在 group 间做”全局混合”——2 个 group 的信息被混回 32 维
这种”先局部、再全局”的两段结构,比单矩阵的”一步全连接”更符合 attention head 的实际语义结构:相似 head(比如同一个语法关系的 head)天然会聚成一组,组内可以共享更多投影方向。
V4 的 n_groups=16 + o_lora_rank=1024 + n_heads=128 这套配置,相当于:
- 8 个 head 一组(相邻 head 倾向于学到相似的 attention pattern)
- 每组独立做低秩投影到 1024 维
- 16 组的 1024 维拼成 16384 维,再映射回 7168 维 hidden
这种结构既省了参数(相对单矩阵的 65536 × 7168 ≈ 470M),又给 attention head 的”自然分组”留了空间。
flowchart LR
subgraph V3O["V3 单矩阵 O"]
direction TB
OA1["o: 128×128"] --> WO["W_o: 16384→7168"]
WO --> XV3["x: 7168"]
end
subgraph V4O["V4 grouped O"]
direction TB
OA2["o: 128×512<br/>= 65536"]
OA2 --> Reshape["reshape 16 组 × 4096"]
Reshape -->|每组独立| Wa["wo_a: 4096→1024<br/>×16 组"]
Wa -->|拼接| Cat["16×1024<br/>= 16384"]
Cat --> Wb["wo_b: 16384→7168"]
Wb --> XV4["x: 7168"]
end
2.5 attn_sink:稀疏注意力的”兜底”
V4 在 Attention.__init__ 里有这一行:
self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
每个 head 一个 float32 的可学习标量。这个参数在 sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) 调用里被传进去——它扮演的角色是”当稀疏选取找不到合适的 KV 时,给 attention 一个安全兜底”。
具体的数学:
- 标准 softmax attention:
P = softmax(Q K^T / √d) - 加 sink:
P_sink = softmax([Q K^T / √d, sink_bias])——把 sink 当作一个虚拟的 KV,与真实 KV 一起参与 softmax 归一化
这种”attention sink”的概念最早来自 Streaming LLM 论文,原本是为长上下文 streaming 时的”开头 token 异常吸收注意力”提出的。V4 把它用在稀疏 attention 上——当 Indexer 没有选到任何”合适的”KV 时,attn_sink 接住了该 head 的注意力质量,避免数值崩溃。
这个参数的存在是 V4 团队”工程纪律”的一个具体体现——他们没有依赖”稀疏选取永远能找到合适的 KV”这种乐观假设,而是给数值稳定性做了一个 fallback。
2.5·补 attn_sink 的训练动力学
attn_sink 这个参数在训练里的角色是什么?为什么它必须是可学习的 float32 标量而不是 0 或硬编码常数?
数值角度
在标准的 softmax attention 里,给每个 KV 位置一个 logit s_i,softmax 输出 exp(s_i) / sum_j exp(s_j)。如果所有 s_i 都是负数(极端情况:稀疏选取选错了,剩下的 KV 都和 query 不相关),那 softmax 仍然会强行把 weight 分到这些”不相关”位置上——这是 softmax 的质量守恒性。
加 attn_sink 后,softmax 变成 exp(s_i) / (sum_j exp(s_j) + exp(sink))。当所有 s_i 都很负、而 sink 是正数时,sink 项会主导分母——意味着所有真实位置的 weight 都被压低,等效于”这一步 attention 选择了不输出”。
这是稀疏注意力数值稳定性的关键——给模型一个”我可以什么都不选”的安全出口。
学习角度
attn_sink 必须可学习,因为不同 layer / 不同 head 对”不输出”的偏好是不同的:
- 早期层(更接近 raw token):sink 应该偏低,因为每个位置都”应该有事可做”
- 中期层(语义抽象):sink 应该中等,给”想跳过这一层 attention” 的可能性
- 晚期层(表征整合):sink 应该偏低,因为最终的 logits 必须依赖具体 KV
V4 的 attn_sink 是 [n_local_heads] 维度的,意味着每个 head 独立学习自己的 sink——这给了模型最大的灵活性。
为什么是 float32
源码里 self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) ——明确指定 float32。原因是 sink 在 softmax 里参与的是与 s_i / √d 同量级的数值比较,而 s_i / √d 在长上下文下可能跨越很大的动态范围(log 概率差异可能达到 30-40)。BF16 的有效精度(约 7 bit mantissa)在这个量级下就不够了——必须用 float32 保证比较的稳定。
2.5·补·补 attn_sink 与 “softmax + 1” 的关系
attn_sink 概念在学术界有几个孪生兄弟,把它们摆在一起更容易看清 V4 的选择来自哪里:
Streaming LLM 的 sink token:在序列开头保留 4 个 token 不参与滑窗淘汰,让它们持续吸收”溢出注意力”。这是位置层面的 sink。
softmax + 1(Off-by-One Errors 提出):把 softmax 的分母多加一个 exp(0) = 1,等价于给一个虚拟的”什么都不做”位置。这是数值层面的 sink。
learnable sink scalar(V4 的方案):每 head 一个可学习 float32 标量,作为虚拟 KV 的 logit 值参与 softmax。这是参数层面的 sink。
V4 选 learnable scalar 而非 softmax+1 的原因是层间差异——不同 layer / 不同 head 对”该不该输出”的偏好不同。一个固定的 +1 在所有 head 上等价,而 learnable scalar 让每个 head 自己决定 sink 的强度。
V4 选 scalar 而非 sink token 的原因是位置无关——稀疏 attention 选取的是动态的 KV 位置集合,没有”固定开头几个 token”的概念。scalar 不依赖位置,更适合稀疏架构。
这种”工程纪律”——选最简单但同时最灵活的形式——在 V4 源码里反复出现。
2.6 Q 路径的二次归一化
Attention.forward 里有这样一行(容易被忽略):
q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
第一行是常规的 wq_b 投影 + reshape。第二行做了一个针对 q 的 inline RMSNorm——这个归一化在 q_norm(RMSNorm(q_lora_rank),作用在 q 投影前)之外额外做的一次。
为什么需要二次归一?有两个原因:
- 稀疏 attention 对 Q/K 数值范围敏感:稀疏选取依赖 Q · K 的 dot product 排序——如果 Q 的 norm 在不同 head 间差异很大,softmax 会被某些 head 主导,稀疏选取会偏。这个 inline rsqrt 把每个 head 的 q 拉回到单位球面附近,让所有 head 在 score 比较时是公平的。
- 配合 attn_sink 的数值稳定:sink 是 float32 标量,q 也需要保持稳定的数值范围才能与 sink 在同一 softmax 里比较——inline 归一化保证了这一点。
这是一个”小代码、大影响”的细节——不写它,稀疏注意力的精度可能会出现非预期的层间差异。
2.7 KV 路径的 RoPE 与 FP8 量化分离
V4 在 KV 路径上的处理也很讲究——
kv = self.wkv(x)
kv = self.kv_norm(kv)
apply_rotary_emb(kv[..., -rd:], freqs_cis) # 只对最后 64 维做 RoPE
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) # 只对前 448 维做 FP8 量化
注意 act_quant 的第二个参数是 64——这是 quantization 的块大小。前 448 维被分成 7 个 64-维 block,每个 block 一个 FP8 scale。
**为什么 RoPE 部分不做 FP8?**因为 RoPE 的旋转矩阵是浮点旋转,FP8 的精度对它太低——会引入 phase error。所以 V4 把 RoPE 部分留在 BF16,只对语义部分做 FP8。
这种”位置 / 语义分离 → 位置走 BF16、语义走 FP8”的设计,是 V4 整个混合精度策略的一个缩影。第 12 章会从全局角度解释。
2.7·补 V3 与 V4 的 KV 数学对比
把 V3 和 V4 的 KV 处理逐步骤摆出来,可以更清楚地看到 V4 抛弃 latent 之后的代数节省:
V3 KV 路径(per-layer, per-token):
1. 投影:x ∈ R^7168 --wkv_a--> latent ∈ R^576 (kv_lora_rank=512 + qk_rope=64)
FLOPs: 7168 × 576 = 4.13M
2. 归一化:latent --kv_norm--> latent_normed ∈ R^576
FLOPs: O(576)
3. 升维:latent --wkv_b--> k_v ∈ R^(n_heads × (qk_nope + v_dim)) = R^(128 × 320) = R^40960
FLOPs: 512 × 40960 = 20.97M
4. 应用 RoPE:rope 部分独立处理
FLOPs: O(64) per head
5. attention 计算:q · k 内积、softmax、加权求和
FLOPs: n_heads × seqlen × head_dim
总计 KV 侧 FLOPs(不含 attention 本身):约 25M / token / layer
KV cache 存储:latent 形式 → 576 bytes (BF16) per token per layer
V4 KV 路径(per-layer, per-token, ratio=4 层):
1. 投影:x ∈ R^7168 --wkv--> kv ∈ R^512
FLOPs: 7168 × 512 = 3.67M
2. 归一化:kv --kv_norm--> kv_normed ∈ R^512
FLOPs: O(512)
3. (无升维步骤)
4. 应用 RoPE:rope 部分独立处理(最后 64 维)
FLOPs: O(64)
5. FP8 量化:前 448 维做 act_quant
FLOPs: O(448)
6. (Compressor 路径,每 4 token 一次):把 4 个 token 的 KV 压成 1 组
FLOPs: 摊销下来约 0.3 × (3.67M / 4) = 0.27M
7. attention 计算:滑窗 + Indexer top-1024
FLOPs: n_heads × (window_size + 1024) × head_dim ← 不再线性增长
总计 KV 侧 FLOPs(不含 attention 本身):约 4M / token / layer
KV cache 存储:每 ratio 个 token 存一组 → 512 / 4 = 128 bytes/token/layer
V4 在 KV 侧的 FLOPs 大约只有 V3 的 1/6,KV cache 存储约 V3 的 1/4。两者乘起来再加上 attention 本身的 O(n × topk) 成本(V3 是 O(n²)),整体砍到 V3.2 的 27% 就不再神秘。
2.8 与 GQA / MQA 的横向对比
把 V4 的 attention 与同代的 GQA / MQA 摆在一起:
| 维度 | MHA | GQA-8 | MQA | V3 MLA | V4 MLA |
|---|---|---|---|---|---|
| Q head 数 | 128 | 128 | 128 | 128 | 128 |
| KV head 数 | 128 | 8 | 1 | latent (单组) | 1 |
| head_dim | 128 | 128 | 128 | 192 (nope+rope) | 512 (nope+rope) |
| KV / token | 128 × 128 × 2 | 8 × 128 × 2 | 1 × 128 × 2 | 512 (latent) | 1 × 512 |
| 升维需求 | 无 | 无 | 无 | wkv_b 升维 | 无 |
| 是否参数共享 | - | KV 跨 head 共享 | KV 全 head 共享 | latent 共享 | KV 全 head 共享 |
| 工程兼容性 | FlashAttn | FlashAttn | FlashAttn | 需特化 kernel | 需 sparse_attn |
| 长上下文成本 | 极高 | 中 | 低 | 中(升维代价) | 低 + 稀疏 |
| O 投影 | 单矩阵 | 单矩阵 | 单矩阵 | 单矩阵 | grouped LoRA |
V4 的 attention 相当于:MQA 的 KV 节省 + MLA 的语义带宽 + grouped O 的表达力——在三条路上同时拿。其代价是需要一个特化的 sparse_attn kernel,这个 kernel 的存在让 V4 不能直接用 FlashAttention 跑——必须用 FlashMLA。第 5 章会展开。
2.8·补 prefill 与 decode 的两条 codepath
V4 的 Attention.forward 在 prefill(首次 forward 整段 prompt)与 decode(自回归一次只送一个 token)两个阶段走的是完全不同的 codepath。这一点在源码 Attention.forward 里通过 if start_pos == 0: 的分支显式写出。
2.8·补.1 prefill 阶段(start_pos == 0)
- q / kv 一次性算完整段:q 形状
[B, S, n_heads, 512],kv 形状[B, S, 512] - KV cache 一次性写入:滑窗 KV 写入
kv_cache[:, :seqlen](如 seqlen ≤ window_size)或环形写入(seqlen > window_size 时只保留最后一窗) - Compressor 一次性压缩:把整段 KV 按 ratio 压缩成
seqlen / ratio组 - Indexer 一次性算 score:score
[B, S, n_heads, end_pos/ratio]一次性 topk - sparse_attn 一次性算完整段输出:通过 cat 后的
topk_idxs索引到 KV,做稀疏 attention
prefill 的特点是所有计算都是矩阵化的,能充分利用 GPU 的 TensorCore;瓶颈通常在 KV 投影和 sparse_attn kernel。
2.8·补.2 decode 阶段(start_pos > 0)
- q / kv 只算 1 个 token:q
[B, 1, n_heads, 512],kv[B, 1, 512] - KV cache 增量写入:
kv_cache[:, start_pos % window_size] = kv.squeeze(1)——按环形 buffer 覆盖最早的 token - Compressor 增量更新:把新 KV 累积到
kv_statebuffer,每 ratio 步触发一次”压缩落盘” - Indexer 增量打分:query 一行、对全部已存压缩 KV 打分、topk
- sparse_attn 增量算 1 个 token 的输出:query 单 token,KV 是窗 + 压缩区合集
decode 的特点是每步只算 1 token,但要查所有历史 KV——瓶颈在内存带宽。V4 的”窗 + 压缩”配合 Indexer 的 top-1024,使得 decode 时只需查 ~1024 + 128 个 KV 位置而非全部历史,这是 V4 在长上下文 decode 阶段保持 throughput 的关键。
2.8·补.3 两条 codepath 的协调
V4 的 Compressor 和 Indexer 在两条 codepath 上用同一个权重矩阵——一组训练好的 wkv / wgate / weights_proj 必须同时支持”批量压缩整段”和”增量压缩单 token”。源码里的 if start_pos == 0 分支就是为了让这两种行为产出数学等价的结果。
flowchart TB
subgraph Prefill["Prefill (start_pos == 0)"]
P1["seqlen=S 个 token 一起算"]
P2["KV cache 一次性写入"]
P3["Compressor 批量压 S/ratio 组"]
P4["Indexer 批量打分 + topk"]
P5["sparse_attn 算整段输出"]
end
subgraph Decode["Decode (start_pos > 0)"]
D1["seqlen=1 个 token"]
D2["KV cache 环形增量"]
D3["Compressor 增量缓存,<br/>每 ratio 步压一组"]
D4["Indexer 增量打分"]
D5["sparse_attn 算 1 个 token"]
end
Prefill -.数学等价.-> Decode
这种”一份权重、两条 codepath、数学等价”是 V4 工程纪律的另一个体现——它意味着任何一个 layer 的内部计算都被精心设计成”prefill / decode 不出现数值漂移”。
2.9 一段 V3 → V4 的迁移练习
如果你已经熟悉 V3 的 MLA 实现,下面这个练习能让你”用最小修改”把 V3 的 attention 改造成 V4 风格的雏形:
# 假设你有一个 V3 attention 的简化版本
class V3Attention(nn.Module):
def __init__(self, dim, n_heads, kv_lora_rank, head_dim):
...
self.wkv_a = Linear(dim, kv_lora_rank)
self.kv_norm = RMSNorm(kv_lora_rank)
self.wkv_b = Linear(kv_lora_rank, n_heads * head_dim * 2)
self.wo = Linear(n_heads * head_dim, dim)
# 改造为 V4 雏形(不含 Compressor/Indexer/sparse_attn)
class V4AttentionMinimal(nn.Module):
def __init__(self, dim, n_heads, head_dim_v4=512, n_groups=16, o_lora_rank=1024):
...
# 1. 砍掉 wkv_a / wkv_b,只留单矩阵 wkv,输出维度直接到 head_dim_v4
self.wkv = Linear(dim, head_dim_v4)
self.kv_norm = RMSNorm(head_dim_v4)
# 2. 把 wo 替换为 grouped LoRA
self.wo_a = Linear(n_heads * head_dim_v4 // n_groups, n_groups * o_lora_rank)
self.wo_b = Linear(n_groups * o_lora_rank, dim)
# 3. 加 attn_sink
self.attn_sink = nn.Parameter(torch.zeros(n_heads))
def forward(self, x):
...
kv = self.kv_norm(self.wkv(x)) # [B, S, head_dim_v4]
# KV 不再升维——所有 head 共享这一组 KV
# ... attention ...
o = ... # [B, S, n_heads, head_dim_v4]
# grouped O
o = o.view(B, S, n_groups, -1)
o = torch.einsum("bsgd,grd->bsgr", o, self.wo_a.weight.view(n_groups, o_lora_rank, -1))
x = self.wo_b(o.flatten(2))
return x
这个雏形只缺三件事:滑窗 KV cache、Compressor、Indexer + sparse_attn——分别是接下来三章的内容。
2.9·补 V4 attention 实现的三个工程坑
读 V4 的 Attention.forward 源码,至少有三个细节如果没注意会在自己的实现里翻车。这些都是任何人对照源码实现时都会撞到的——不需要内部资料。
坑一:unflatten 的维度顺序
V4 用了多次 q.unflatten(-1, (self.n_local_heads, self.head_dim))。这一步把最后一维拆成 (n_heads, head_dim)——但默认是先 n_heads 后 head_dim。如果你写成 unflatten(-1, (head_dim, n_heads)),整个 attention 的语义就反了。
V4 的源码在 unflatten 之后立刻接 apply_rotary_emb(q[..., -rd:], freqs_cis)——这步只对最后 64 维做 RoPE。如果 unflatten 维度顺序写错,RoPE 会作用到错误的维度上,模型 forward 不会报错但输出全乱。
坑二:滑窗的环形写入索引
V4 的滑窗 KV cache 在 prefill 时的写入逻辑:
if seqlen <= win:
self.kv_cache[:bsz, :seqlen] = kv
else:
cutoff = seqlen % win
self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
这段代码处理的是”prefill 完一段 prompt 后,window_size=128 的环形 buffer 状态”。如果直接 self.kv_cache[:, :win] = kv[:, -win:],当 seqlen > win 时确实保留了最后一窗的 KV,但环形索引是错的——decode 阶段写入第 (start_pos % win) 个位置时会覆盖错位置。
V4 的处理是:把最后一窗 KV 按”prefill 之后 next decode position 应该从哪开始”切成两段,分别写入 cache 的两段。这样 decode 阶段 kv_cache[:, start_pos % win] = new_kv 写入的位置就是正确的。
坑三:Indexer 的 offset 参数
V4 的 Indexer.forward(x, qr, start_pos, offset) 的第四个参数 offset 不太显眼,但极其关键:
# Attention.forward 中
if self.compress_ratio:
offset = kv.size(1) if start_pos == 0 else win
if self.indexer is not None:
compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
offset 是稀疏 KV 索引相对于 kv_cache 起始位置的偏移。在 prefill 时 offset = kv.size(1)(所有滑窗 KV 占据的位置);在 decode 时 offset = win(滑窗已固定占满前 win 个位置)。如果 offset 算错,sparse_attn 索引到的 KV 位置就是错的——读到滑窗 KV 当压缩 KV 用,输出就会乱。
这三个坑都是索引 / 顺序 / 偏移类的细节,没有任何源码注释专门说明,但它们是 V4 attention 正确工作的隐性合约。
2.10 动手实验:手算一个 V4 head_dim 的 KV cache 占用
# 给定参数(V4 Pro)
n_layers = 61
window_size = 128
head_dim = 512
# 假设所有层都是 ratio=4 的稀疏层(最 KV-密集的情况)
# n_tok = 1M
n_tok = 1_048_576
ratio = 4
kv_per_layer = (window_size + n_tok // ratio) * head_dim
total_kv_bytes = n_layers * kv_per_layer * 2 # BF16 = 2 bytes
print(f"V4 KV cache (1M, all-ratio-4): {total_kv_bytes / 1e9:.2f} GB")
# 改成混合 ratio(V4 Pro 实际配置:约 30 层 ratio=4, 30 层 ratio=128, 1 层 ratio=0)
sum_kv = 0
for ratio in [128, 128] + [4 if i % 2 == 0 else 128 for i in range(58)] + [0]:
if ratio == 0:
sum_kv += n_tok * head_dim # 不压缩
else:
sum_kv += (window_size + n_tok // ratio) * head_dim
print(f"V4 KV cache (1M, mixed-ratio): {sum_kv * 2 / 1e9:.2f} GB")
跑完会发现”全 ratio=4”配置在 1M 下要 16+ GB,而 V4 实际混合 ratio 配置在 1M 下只要 2 GB 量级。per-layer 非均匀压缩是 V4 长上下文成本可控的真正秘诀。
2.10·补 与 vLLM 对接时的 KV cache 形状错位
V4 在 vLLM 等推理引擎里的对接难度,主要来自一个事实:V4 的 KV cache 形状不是 vLLM 默认的 PagedAttention 形状。
vLLM 默认的 KV cache 是 [num_blocks, block_size, num_kv_heads, head_dim]——按 block_size(典型 16)分块、按 KV head 数分组。这种形状假设:
- KV head 数固定且较小(典型 1 / 4 / 8)
- 每个 KV head 维度固定(典型 64 / 128)
- KV 不需要分层处理
V4 的 KV cache 形状是 [max_batch_size, window_size + max_seq_len // ratio, head_dim]——按”滑窗 + 压缩 KV”的两段拼接。这种形状假设:
- 每层 KV head 数都是 1(MQA-style)
- head_dim 极大(512)
- KV 内部分两段:滑窗段 + 压缩段
- ratio 逐层不同
把 V4 跑进 vLLM,至少要做三件事:
- 新增一个 V4 专属的 KV cache 类型(不能复用 PagedAttention 的 KVCache class)
- 在 BlockManager 里给”滑窗段”和”压缩段”分别分配 block——一个用环形覆盖,一个按 ratio 增长
- 接入 sparse_attn kernel——FlashMLA 的 V4 分支需要单独编译进 vLLM 的 wheel
第 19 章会展开这条工程接缝,并给出 vLLM 主仓库 V4 适配 PR 的详细解读。
2.10·补·补 attention 算子内部的精度链路
V4 在 attention 内部的精度切换非常密集。以一个典型的 ratio=4 层为例,逐 op 看精度:
x: BF16 [B, S, 7168]
--wq_a (FP8)-->
q_lora: BF16 [B, S, 1536] (FP8 GEMM 在内部用 FP32 累加,输出回 BF16)
--q_norm (FP32)-->
q_normed: BF16 [B, S, 1536]
--wq_b (FP8)-->
q: BF16 [B, S, n_heads, 512]
--inline rsqrt (FP32)-->
q_normalized: BF16 [B, S, n_heads, 512]
--apply_rotary_emb (FP32 复指数 + BF16 输出)-->
q_with_rope: BF16
x ----wkv (FP8)--> kv: BF16 [B, S, 512]
--kv_norm (FP32)--> kv_normed: BF16
--apply_rotary_emb--> kv_rope: BF16
--act_quant (FP8 e4m3 + ue8m0 scale)--> kv_quantized: FP8 [B, S, 448] + scale [B, S, 7]
--(rope 部分留 BF16)
q · kv → score: FP32 [B, S, n_heads, K] ← FP8 GEMM 内部 FP32 累加
softmax(score + sink): FP32
weighted sum → o: BF16 [B, S, n_heads, 512]
apply_rotary_emb(o[..., -64:], inverse=True): BF16
o.view + einsum(wo_a, FP8) → o_lora: BF16 [B, S, 16, 1024]
o.flatten + wo_b (FP8) → x_out: BF16 [B, S, 7168]
关键观察:
- GEMM 输入是 FP8,输出是 BF16——这是 DeepGEMM 的”输入降精度、累加用 FP32、输出升 BF16” 模式
- norm 操作内部用 FP32——RMSNorm 的 mean / rsqrt 必须 FP32 才稳定
- softmax 和 rotary 都用 FP32——避免概率分布的尾部精度损失
- 激活值默认 BF16,只有进入 GEMM 时才被即时量化到 FP8
这套精度链路的设计哲学是:“算的时候用 FP32 + FP8、存的时候用 BF16 + FP8”——计算密度高的 op 走低精度,激活值的中间存储走 BF16 保稳定。这与 V3 的精度链路有相似之处,但 V4 因为引入了 sparse_attn 和 grouped O,链路更复杂。第 12-14 章会全面展开。
2.10·延展 V4 attention 的长尾性能与可观测性
V4 的 attention 实现因为引入稀疏选取,性能特征与 dense attention 有本质差异——在生产中需要新的可观测性指标。
长尾现象一:稀疏命中率
Indexer 选 top-1024 的 KV 位置。如果命中的 KV 实际语义上不相关,attention 输出会偏向 attn_sink(接近”不输出”)。可观测指标:
- 每层的 attn_sink 在 softmax 中占的比例
- 当 attn_sink 占比突然升高时,说明 Indexer 选取出现”系统性失误”
- 监控 attention output 的 norm 分布:如果某些 token 的输出 norm 显著低于群体均值,可能是稀疏选取失败的信号
长尾现象二:滑窗与压缩段的”接缝处”
每层的 KV cache 有两段——滑窗(最近 128 token)和压缩段(远距离)。当一个 token 从滑窗”溢出”进入压缩段,它的可见性会突然变化(从 dense 变成稀疏选取)。这个”接缝”在某些任务上会导致回答质量的不连续——例如总结任务里,模型可能”忽略” 刚刚被踢出滑窗的关键信息。
可观测指标:监控”刚溢出滑窗的 token”的 attention 权重,如果显著小于”还在滑窗里的 token”,说明 Compressor 在该层对这些 token 的语义压缩不足。
长尾现象三:per-layer ratio 的”熔点”
V4 的 compress_ratios 是预定义的,运行时不变。但不同 prompt 的 KV 分布差异很大——某些 prompt 在 ratio=4 的层就能保留所有关键信息,某些 prompt 在 ratio=128 的层会丢掉重要细节。
工程上没法运行时改 ratio(这会让权重不匹配),但可以通过 system prompt 工程间接调整:让 prompt 显式重复关键信息,确保它们出现在多个滑窗位置,绕过压缩段的语义损失。这是 V4 时代的 prompt engineering 新维度。
推荐的生产监控面板
如果你部署 V4 到生产,至少应该追踪以下 6 个指标:
- attn_sink_ratio(每层):sink 占 softmax 概率的比例
- window_to_compress_attention_ratio(每层):滑窗段获得的 attention 占比 vs 压缩段
- indexer_topk_overlap(每层):相邻 token 的 topk 选取重合度——重合度高说明稀疏选取稳定
- kv_compression_loss_proxy:用一个小的 reconstruction head 度量 Compressor 输出与原 KV 的差异
- mtp_consistency:MTP head 的预测与主 head 的一致率——两者都偏离表示 attention 出错
- end_to_end_perplexity_per_layer:每层去掉时的 perplexity 增量——找出”关键层”
这些指标在 V3 dense attention 时代不需要,但 V4 时代是生产部署的必备。
2.11 延伸阅读
- DeepSeek-V2 论文(arXiv:2405.04434):MLA 的源头
- DeepSeek-V3 技术报告(arXiv:2412.19437):V3 MLA 的全部细节
- Streaming LLM(arXiv:2309.17453):attention sink 的源头
- GQA 论文(arXiv:2305.13245):grouped query attention 的提出
- MQA 论文(arXiv:1911.02150):multi-query attention 的提出
- 本书《vLLM 推理内核深度解析》第 5 章 KV cache 管理——理解 PagedAttention 与 V4 的 KV cache 对接
2.12 本章小结
- V4 抛弃 latent KV,把 head_dim 拉到 512——每个 head 自带完整的 KV 语义带宽
num_key_value_heads=1让 V4 在结构上接近 MQA,但 head_dim 大到 512,单 token KV 占用与 V3 接近,少了 wkv_b 升维- grouped O 投影(n_groups=16, o_lora_rank=1024) 是 V4 的代数创新——以 1.6 倍参数换更强的局部 + 全局两段式表达力
- attn_sink 是稀疏注意力的”兜底参数”——保证当 Indexer 找不到合适 KV 时数值不崩
- Q 路径的 inline RMSNorm 是稀疏 attention 的精度保证,不可省
- KV 路径的”位置 / 语义分离 → BF16 / FP8 分离”是 V4 整个混合精度策略的缩影
第 3 章我们进入 V4 注意力革命的第二站:Compressor——逐层独立配置的 KV 压缩模块,是怎么把 1M token 的 KV 压到几 GB 的。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。