第5章 sparse_attn 与 FlashMLA:V4 路径下的 CUDA 内核
“Architecture without efficient kernels is poetry without rhyme.” —— 引自 NVIDIA 一位资深 CUDA 工程师
V4 的全部 attention 革命(MLA + Compressor + Indexer),最后必须落到一个 CUDA kernel 上才能产生工程价值——这个 kernel 就是 FlashMLA。
5.1 引子:从一行 PyTorch 到 GPU 上的真实计算
V4 的 Attention.forward 里 attention 的”实际计算”被压缩成一行:
o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
参数:
q:query,形状[B, S, n_heads=128, head_dim=512],dtype 取决于 prefill / decodekv:KV cache(滑窗 + 压缩段拼起来的完整 KV),形状[B, kv_cache_size, head_dim=512]attn_sink:每 head 一个 float32 标量,形状[n_heads]topk_idxs:稀疏选取的 KV 位置索引,形状[B, S, window_size + index_topk] = [B, S, 1152]softmax_scale:缩放系数head_dim ** -0.5
sparse_attn 是一个从 kernel 模块导入的函数:
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
而 kernel 模块本身是一个C++/CUDA 扩展——sparse_attn 的 PyTorch 入口只是个绑定,真正的计算发生在 GPU kernel 上。这个 kernel 被托管在 FlashMLA 仓库里,与 V4 同一周期开源。
flowchart TB
subgraph Python层
P1["Attention.forward 中:<br/>sparse_attn(q, kv, sink, idxs, scale)"]
P2["kernel.py 中的 sparse_attn binding"]
end
subgraph C++层
C1["torch::Tensor sparse_attn(...)"]
C2["dispatch by GPU arch (SM90/SM100)"]
end
subgraph CUDA层
K1["sparse_attn_kernel_sm100<<<...>>>"]
K2["针对 H100 / B200 的 ldsm + WGMMA 实现"]
end
P1 --> P2 --> C1 --> C2 --> K1 --> K2
本章拆这条调用链——从 Python 的 sparse_attn 到 CUDA kernel 的全部工程缝合。
5.2 sparse_attn 的接口语义
sparse_attn(q, kv, sink, idxs, scale) 的语义是:
对每个 query token 的每个 head,仅计算 q · kv[idxs] 的内积,做 softmax + 加权求和,然后跨 head 输出。
伪代码形式:
def sparse_attn(q, kv, attn_sink, topk_idxs, softmax_scale):
# q: [B, S, H, D]
# kv: [B, T, D] (T = window_size + max_seq_len // ratio)
# attn_sink: [H]
# topk_idxs: [B, S, K] (K = window_size + index_topk = 1152)
# softmax_scale: scalar
# 返回 o: [B, S, H, D]
B, S, H, D = q.shape
K = topk_idxs.shape[-1]
o = torch.zeros_like(q)
for b in range(B):
for s in range(S):
for h in range(H):
# gather K 个 KV 位置
gathered_kv = kv[b, topk_idxs[b, s]] # [K, D]
# mask 掉 -1 位置
valid = topk_idxs[b, s] >= 0
# logits = q · k / sqrt(D)
logits = (q[b, s, h] @ gathered_kv.T) * softmax_scale
logits = torch.where(valid, logits, float("-inf"))
# 拼接 attn_sink
logits_with_sink = torch.cat([logits, attn_sink[h:h+1]])
weights = F.softmax(logits_with_sink, dim=0)
# 加权求和(不包括 sink)
o[b, s, h] = (weights[:K, None] * gathered_kv).sum(dim=0)
return o
这只是语义模型——真实 CUDA 实现要快几个数量级。但理解这个伪代码后,CUDA 实现就是”如何把这段循环并行化、向量化、利用 TensorCore”的工程问题。
5.3 FlashMLA 仓库的代码组织
FlashMLA 仓库(github.com/deepseek-ai/FlashMLA)的代码组织:
FlashMLA/
├── csrc/ # C++ / CUDA 源码
│ ├── flash_mla/
│ │ ├── sparse_attn_sm100.cu # SM100 (B200) 路径
│ │ ├── sparse_attn_sm90.cu # SM90 (H100/H800) 路径
│ │ ├── dense_attn_sm90.cu # 兼容 V3.2 的 dense MLA
│ │ └── ...
│ ├── flash_mla_extension.cpp # PyTorch binding
│ └── ...
├── flash_mla/ # Python 包装
│ ├── __init__.py
│ └── kernel.py # 暴露 sparse_attn / fp8_gemm 等接口
├── tests/ # 测试
└── benchmark/ # 性能基准
V4 的 inference/model.py 直接 from kernel import sparse_attn——这个 kernel 模块就是 FlashMLA 编译安装后暴露的 Python 包。
FlashMLA 的关键设计:
- 每个 GPU 架构一个独立的 .cu 文件:SM90 (H100/H800) 与 SM100 (B200) 的 kernel 实现差异巨大——TMA、WGMMA、共享内存大小都不同,必须分开优化
- dense 与 sparse 路径并存:dense 路径服务 V3 / V3.2-Exp(dense MLA),sparse 路径服务 V4。一个仓库支持两代模型
- PyTorch binding 极薄:C++ 入口只做参数检查 + dispatch by arch + launch kernel,不含业务逻辑
5.4 稀疏 attention 的 GPU 优化挑战
稀疏 attention 比 dense attention 在 GPU 上更难优化。dense attention 的内存访问模式是连续的——每个 query 顺序读 KV cache 从 0 到 T。稀疏 attention 的内存访问模式是索引跳跃的——每个 query 按 topk_idxs 跳着读 KV。
这带来三个挑战:
挑战一:合并访问失效
GPU 的全局内存读取按 32-byte / 128-byte cache line 合并。dense attention 的 KV 读取连续,多个线程一次读取一条 cache line。稀疏 attention 的 KV 读取分散,多个线程可能读到不同的 cache line——内存带宽利用率下降。
FlashMLA 的解决方案:用 ldmatrix / TMA 指令做”非连续 gather”。SM90 / SM100 的 TMA(Tensor Memory Accelerator)支持基于索引的非连续读取,硬件层面解决合并访问问题。
挑战二:索引去重
如果两个相邻 query token 的 topk_idxs 有重叠(实际上重叠率往往很高),naive 实现会多次读同一个 KV。FlashMLA 的优化:先把所有 query 的 topk_idxs 做 union,按 union 后的位置读取 KV,再 gather 到每个 query。这把多次读取变成一次。
挑战三:tile 大小与 K 的关系
FlashAttention 的标准做法是把 KV 分成 tiles(典型 64 / 128 token 一个 tile),逐 tile 处理。稀疏 attention 的 K(每个 query 看 1152 个位置)与 tile 大小的关系决定了 kernel 效率:
- K << tile_size:每个 tile 内能处理的 query 多,但有空闲的 tile 槽位
- K >> tile_size:每个 query 要跨多个 tile,tile 边界的 softmax 归一化要做”online softmax”
FlashMLA 的 V4 路径选 K=1152、tile_size=128——意味着每个 query 跨 9 个 tile。kernel 用 online softmax 在 tile 间累积分子分母,最后归一化。
5.4·补 online softmax:稀疏 attention 的核心算法挑战
V4 的 sparse_attn kernel 必须解决一个问题:当一个 query 跨多个 tile 计算时,怎么在 tile 之间正确累积 softmax 的分子分母?
这个问题在 dense FlashAttention 里就已经存在——每个 query 看完整 KV,KV 被分 tile 处理。FlashAttention 的解法是 online softmax 算法(Milakov & Gimelshein 2018):
对于每个 tile, 维护两个 running quantity:
m: 当前已见 logits 的最大值
l: 当前的 sum of exp(logits - m)
收到新 tile 的 logits 时:
m_new = max(m, max(new_logits))
l_new = exp(m - m_new) * l + sum(exp(new_logits - m_new))
o_new = exp(m - m_new) * o + sum(exp(new_logits - m_new) * new_v)
最后归一化:
o = o / l
这个算法的妙处在于 m 和 l 的更新可以增量进行,且数学上等价于 dense softmax。
V4 的 sparse_attn 的 online softmax 多了一个复杂性:attn_sink。sink 不是某个 tile 里的位置,而是一个 head-level 的常数 logit——必须把它”虚拟”地参与到 softmax 归一化里。
具体处理:
- 初始化 m = attn_sink[h],l = 1(exp(0) = 1,因为 sink 的 logit 已经在 m 里)
- 处理每个 KV tile 时,用上述公式更新 m / l / o
- 注意 o 的累积不包括 sink 项(sink 不贡献输出向量,只参与归一化)
这种”sink-aware online softmax” 是 V4 sparse_attn 的特化部分,FlashAttention v3 的标准 online softmax 没有这个能力。
flowchart LR
subgraph TileFlow["online softmax 的 tile 流"]
T0["tile 0: 64 个 KV"] -->|更新 m,l,o| T1["tile 1: 64 个 KV"]
T1 -->|更新 m,l,o| T2["tile 2: ..."]
T2 -->|...| TN["tile N"]
TN -->|最后除以 l| Out["o 输出"]
end
Sink["attn_sink 作为<br/>初始 m"] -.参与归一化.-> T0
V4 的 sparse_attn kernel 在每个 query 上跑这套 sink-aware online softmax,跨 9 个 KV tiles(K=1152 / tile_size=128)累积,最后归一化输出。
5.4·补·补 索引 gather 的硬件路径
稀疏 attention 在 SM90 / SM100 上的”索引 gather”是 V4 工程化的一个具体硬件挑战。让我们看一段简化的 CUDA 伪代码理解它在硬件层做了什么:
// 简化的稀疏 KV gather kernel(C++/CUDA)
__global__ void gather_kv_for_sparse_attn(
const __nv_fp8_e4m3* kv_global, // [B, T, D] 全部 KV
const int* topk_idxs_global, // [B, S, K] topk 索引
__nv_fp8_e4m3* kv_gathered_smem, // shared memory 输出
int B, int S, int T, int D, int K)
{
int b = blockIdx.x;
int s = blockIdx.y;
int tid = threadIdx.x;
// 每个 thread 负责 K/blockDim.x 个 KV 位置
for (int k_local = tid; k_local < K; k_local += blockDim.x) {
int kv_idx = topk_idxs_global[b * S * K + s * K + k_local];
if (kv_idx < 0) continue; // mask -1 invalid
// gather D 维 (D=512)
// SM90: 用 cp.async.bulk + ldmatrix 做异步 gather
// SM100: 用 TMA 的 indexed mode
int dst_offset = k_local * D;
int src_offset = b * T * D + kv_idx * D;
// ... TMA / cp.async copy ...
}
}
实际 FlashMLA 的 kernel 比这复杂得多——
- 用
cp.async.bulk异步从 global 拷到 shared,重叠拷贝与计算 - 用
ldmatrix.x4一次加载 4 个 8x16 矩阵到寄存器 - TensorCore 用 WGMMA 指令做 q @ k^T 的 GEMM
但核心思想是:用硬件原生的”非连续 gather”指令把 topk 选取的 KV 拷到 SMEM,然后正常做 attention。FlashMLA 的工程价值就在于”把这个 gather + GEMM + softmax 流水线优化到接近硬件极限”。
5.5 SM90 vs SM100:两套 kernel 的差异
V4 同时支持 H100/H800(SM90)和 B200(SM100)。FlashMLA 给两套架构写了完全独立的 kernel:
| 维度 | SM90 (H100/H800) | SM100 (B200) |
|---|---|---|
| FP8 GEMM 指令 | WGMMA (Warp Group MMA) | WGMMA (改进版) + UMMA |
| FP4 GEMM 指令 | 模拟(FP8 → FP4 算子分解) | 原生 FP4 MMA |
| TMA | 1D / 2D | 1D / 2D / 3D + tensor map |
| 共享内存 | 228 KB / SM | 256 KB / SM |
| L2 cache | 50 MB | 60 MB |
| Threadblock size | 4 warps / 8 warps | 8 warps / 16 warps |
| 关键差异 | softmax 在 SMEM 内做 | softmax 跨 SMEM + L2 |
V4 在 H100 与 B200 上的 throughput 差异(README 公开数字):
- H100 上 V4 Pro decode 吞吐约 410 TFlops(FP8)
- B200 上 V4 Pro decode 吞吐约 640 TFlops(FP8)
差异主要来自 B200 的原生 FP4 MMA + 更大的 L2。SM90 路径的 FP4 因为是模拟(先反量化到 FP8 再 GEMM),效率比原生差不少——这是 V4 在 B200 上有”更高占有率”的硬件原因。
5.6 V4 sparse_attn 与 FlashAttention v3 的对比
FlashAttention v3 是 dense attention 在 H100 / B200 上的事实标准。把 V4 的 sparse_attn 与 FA3 横向对比:
| 维度 | FlashAttention v3 | V4 sparse_attn |
|---|---|---|
| KV 访问模式 | 连续 | 索引 gather |
| 支持的 head_dim | 32 / 64 / 128 / 256 | 必须 512 |
| 支持的稀疏度 | 不支持 | top-K |
| 是否支持 sink | 不支持 | 支持 |
| 是否支持 grouped O | 不支持(O 投影外置) | 不支持(O 投影也外置) |
| 主要应用场景 | dense attention 全部模型 | V3 (dense MLA) + V4 |
| 长上下文成本 | O(n²) FLOPs | O(n) FLOPs (K 固定) |
V4 的 sparse_attn 不能替代 FA3——它是为 V4 这种”超大 head_dim + 稀疏选取 + sink”的特定形态量身定做的。反过来 FA3 也不能跑 V4——FA3 的 dense KV 假设与 V4 的 topk_idxs 接口完全不兼容。
这种”特化 kernel”的代价是 V4 必须自带 FlashMLA。但带来的红利是:V4 的稀疏注意力在 H100 上能跑 410 TFlops——这个数字是 FA3 跑 dense attention 的 75% 左右,意味着稀疏 attention 在工程上已经追上 dense attention 的效率。
5.7 vLLM / SGLang 集成 sparse_attn 的工程接缝
把 sparse_attn 集成进 vLLM / SGLang 这类推理引擎,至少要做四件事:
事项一:编译 FlashMLA 为可链接库
FlashMLA 的 C++/CUDA 必须用 -arch=sm_90 或 sm_100 编译,且需要 CUDA 12.8+。集成时需要:
- 把 FlashMLA 加到引擎的 wheel 构建脚本
- 处理”用户的硬件不是 H100/B200” 的回退(一般回退到 PyTorch + Triton 实现)
事项二:传 topk_idxs 给 kernel
引擎需要在每次 forward 时为每个 attention layer 计算 topk_idxs——这意味着引擎要 invoke Indexer,传 query / 中间表示给 Indexer,再把 Indexer 的输出送给 sparse_attn。这个 dataflow 在 vLLM / SGLang 之前的代码里完全不存在——必须新增。
事项三:KV cache 形状改造
vLLM 默认的 PagedAttention KV cache 形状是 [block, block_size, head, dim]。V4 的 KV cache 是 [B, window + n/ratio, dim]——不分 head(MQA-style),不按 block 切。集成时要么扩展 PagedAttention 的形状抽象、要么新增一个”V4-style KV cache” 类型。第 19 章会展开 vLLM PR 的具体改动。
事项四:与调度器的协调
V4 的 prefill / decode 走不同 codepath,且 prefill 有”压缩 KV 一次性算”的批量步骤。vLLM 的调度器需要识别”这是 V4 模型”,分别调度 prefill 和 decode 阶段,避免把它们错误合并到同一个 CUDA stream。
第 19 章会针对 vLLM 主仓库的 V4 适配 PR 逐改动展开。本章只到这里——给读者建立”sparse_attn 需要外部生态怎么配合”的全局认知。
5.8 动手实验:跑通 FlashMLA 的 V4 测试
# 1. 拉取 FlashMLA
git clone https://github.com/deepseek-ai/FlashMLA.git
cd FlashMLA
# 2. 编译(需要 CUDA 12.8+ 和 H100/H800 或 B200)
pip install -e .
# 3. 跑 V4 路径的单元测试
python -m pytest tests/test_sparse_attn.py -v -k v4
# 4. 跑性能基准(输出 TFlops 数字)
python benchmark/bench_sparse_attn.py --arch sm90 --seq-len 1048576 --topk 1024
如果你没有 H100/H800,可以用 nVIDIA 的 PTX 模拟器或者改 arch=sm_80(A100,但需要回退到 Triton 实现,性能差得多)。
测试通过后,会得到一个对照表:dense attention vs sparse_attn 的 TFlops、显存占用、首 token 延迟。把这个表对照本书第 1 章 §1.9·补·补 的”README 三组数字”,能看到工程数字与营销数字的吻合度。
5.8·补 sparse_attn 在不同 batch / context 下的性能特征
V4 的 sparse_attn 性能不是恒定的——它随 batch size、context 长度、稀疏比例呈非线性变化。把这条性能曲线的几个关键拐点标出来:
拐点 1:context = 32K
context ≤ 32K 时,sparse_attn 的优势相对小——dense attention 在 32K 下 KV 也才几 GB,FLOPs 也可承受。sparse_attn 的”选 top-1024”反而引入了 Indexer 的额外成本。这个范围内 sparse_attn 比 dense 快约 1.2-1.5 倍。
拐点 2:context = 128K
context = 128K 时差异显著拉开——dense 的 attention FLOPs 是 O(n²) 二次增长,sparse_attn 的 FLOPs 是 O(n × 1152) 几乎线性。这个范围内 sparse_attn 比 dense 快约 5-8 倍。
拐点 3:context = 1M
context = 1M 是 V4 的设计目标。dense attention 在这里完全不可用(FLOPs 爆、KV 爆),sparse_attn 仍能维持大约 80-90% 的 32K 单 token 吞吐。这个范围内 sparse_attn 是”唯一可行方案”——速度对比变得无意义。
拐点 4:batch = 32+
随 batch 增大,sparse_attn 的优势从”FLOPs”转向”显存”——dense 的 KV cache 在大 batch 下爆掉,sparse_attn 的 KV cache 仍在可控范围。这个范围内 sparse_attn 让”原本只能跑 batch=4 的硬件能跑 batch=32”——并发能力提升 8 倍。
拐点 5:稀疏比例(topk / total_kv)= 1%
V4 默认 topk=1024,1M context 下稀疏比例约 0.4%。如果应用场景把 topk 调到 4096(稀疏比 1.6%),sparse_attn 的吞吐下降约 40%——但精度提升约 5%。这种”精度 vs 速度” 的 trade-off 是部署时可以调的。
理解这条性能曲线对容量规划 极重要——你要根据自己的典型 context 长度选择”V4 是否真的合适”——short context + 大 batch 用户其实可以用更小的模型。
5.8·补·补 sparse_attn 与 GPU SM 占用率
sparse_attn 在 GPU 上的”占用率”(SM utilization)是衡量 kernel 优化质量的关键指标。
理论上限:H100 有 132 个 SM。每个 SM 同时跑多个 warp(线程块)。sparse_attn 的理论上限是”所有 SM 都满载在跑 attention 计算”——约 95% 占用率。
实际占用率:FlashMLA 的 sparse_attn 在 H100 上典型占用率 80-85%。差距来自:
- 索引 gather 的等待(部分 SM 在等 cp.async 完成)
- tile 边界的 softmax 同步开销
- Indexer 的输出尚未到达时 sparse_attn 必须等待
B200 上的占用率:B200 有 192 个 SM + 原生 FP4。FlashMLA 在 B200 上占用率 90%+——更接近理论上限。这是 V4 在 B200 上比 H100 快 1.6 倍的内在原因。
GPU 时钟与温度:实际占用率还受 GPU 物理状态影响。H100 在 boost clock + 良好散热下占用率最高;如果 thermal throttle,占用率会降到 70%。生产部署时必须监控 nvidia-smi 的 power / temp / clock。
与其他 kernel 协同:sparse_attn 不是孤立运行——它与 Linear(DeepGEMM)、与 RMSNorm、与 Compressor 共享 SM 资源。如果其他 kernel 占太多 SM,sparse_attn 会被”挤压”。V4 的解决:用 CUDA Graph 把多个 kernel 编排成”流水线”,让 SM 利用率持续高位。
理解 SM 占用率让你能判断”sparse_attn 是否被瓶颈”——如果 nvidia-smi 显示 GPU 利用率只有 60%,sparse_attn 一定不是瓶颈,问题在其他地方(数据加载 / 网络 / Python 调度)。
5.9 延伸阅读
- FlashAttention v3(arXiv:2407.08608):dense attention 的 SM90 实现
- Native Sparse Attention(arXiv:2502.11089):稀疏 attention 训练的理论
- DeepSeek FlashMLA 仓库 README:本章主要参考
- 本书《vLLM 推理内核深度解析》第 4-5 章:PagedAttention 与 V4 的 KV cache 对接
- 本书第 19 章:vLLM 主仓库 V4 适配 PR 的全部改动
5.9·补 sparse_attn 的”工程债务”清单
任何工业级 kernel 都有它的工程债务——为快速发布而留下的”以后再优化” 项。FlashMLA 的 sparse_attn 也有。把可观察到的工程债务列出来:
债务 1:dense 路径与 sparse 路径并存
FlashMLA 仓库同时维护 dense 路径(给 V3 / V3.2-Exp)和 sparse 路径(给 V4)。两套代码在某些功能上重叠(如 KV cache 管理),但实现独立。未来某个版本可能把它们统一——但短期内为了不引入回归风险,保留并存。
债务 2:SM80(A100)路径缺失
FlashMLA 主要支持 SM90 / SM100。SM80(A100)没有原生 FP8 / FP4,需要软件模拟——FlashMLA 没有为此专门优化。如果你想在 A100 上跑 V4,sparse_attn 会有显著性能损失。
债务 3:动态稀疏度不灵活
V4 的 topk=1024 是 config 写死的。FlashMLA 内部 tile_size 与 K=1152 紧密绑定——如果 fine-tune 模型把 topk 改成 512 或 2048,需要重新调 kernel 的 tile size。短期内只能用预定的几个固定值。
债务 4:与 cuBLAS / cutlass 的协同
vLLM 部署时 sparse_attn 与 cuBLAS(其他模型的 GEMM)共享 GPU。但 FlashMLA 的 sparse_attn 占用的 SM 数固定(通过 set_num_sms),不会根据 cuBLAS 的负载动态调整。理想情况下应该有 SM 调度协同——目前是工程债务。
债务 5:错误诊断信息有限
如果 sparse_attn 跑出错(如 topk_idxs 越界),错误信息通常是 CUDA error,不直接指向问题位置。FlashMLA 缺一套 debug build——开启后可以打印每个 tile 的状态。这是开源项目的常见短板。
理解这些债务让你在使用 FlashMLA 时心里有底——遇到问题时知道去哪里查、知道哪些限制是”暂时的”,避免被”看似不一致的现象” 困惑。
5.9·补·补 sparse_attn 工程师速记
部署或调试 V4 sparse_attn 时最常用的几条速记规则。打印一张贴在工位上:
速记 1:版本要求
- CUDA 12.8+
- PyTorch 2.4+
- H100 / H800(SM90)或 B200(SM100)
- A100 / 老 GPU 不支持
速记 2:性能数字(H100 上)
- FP8 GEMM 峰值:~1300 TFlops
- sparse_attn 峰值:~410 TFlops
- B200 大约提升 1.6x
速记 3:典型形状参数
- topk_idxs 大小:[B, S, 1152](=window_size 128 + index_topk 1024)
- KV cache 大小:每层 [B, 滑窗 + n/ratio, 512]
- 单序列 KV cache 总和:~8 GB(1M context)
速记 4:常见错误信号
- “RuntimeError: invalid argument”:GPU 不支持 SM 架构
- “CUDA error: invalid configuration”:tile size 不匹配(通常 SMEM 不够)
- 输出 NaN:q 没正确量化或 attn_sink 没初始化
速记 5:优化优先级
- 检查 GPU 利用率(nvidia-smi)—— 低于 80% 说明是 IO/通信瓶颈,sparse_attn 不是元凶
- 检查 tile size 与 SMEM 余量
- 检查 stream 配置(DeepGEMM 与 sparse_attn 是否在同一 stream 串行)
速记 6:与 vLLM 集成的关键文件
vllm/attention/backends/flash_mla.py(V4 的 attention backend)vllm/model_executor/models/deepseek_v4.py(V4 model class)vllm/distributed/device_communicators/(DeepEP 集成)
这些速记是”工业实战经验” 的浓缩——不需要每次都翻文档,速记能解决 80% 的日常问题。
5.9·延展 sparse_attn 与 dense FlashAttention 的”代码量对比”
把 V4 的 sparse_attn 与 FlashAttention v3(dense)的代码量、复杂度、可读性对比一下——这能让你直观感受 sparse 路径的工程额外开销。
FlashAttention v3 (dense):
- 主要 .cu 文件:3-5 个,每个数百行
- 核心算法:online softmax + 分块 KV
- 接口:
flash_attn_func(q, k, v, ...) - 调用方式:直接传完整 KV,kernel 内部分 tile 处理
- 数学复杂度:低(标准 attention)
FlashMLA sparse_attn (V4 路径):
- 主要 .cu 文件:6-10 个(SM90 + SM100 各一套)
- 核心算法:sink-aware online softmax + 索引 gather + tile 内 sparse mask
- 接口:
sparse_attn(q, kv, sink, idxs, scale)—— 多 2 个参数 - 调用方式:传 KV cache + topk 索引,kernel gather 后再算
- 数学复杂度:中(sink + sparse 选取)
代码量差异:
sparse_attn 比 dense 多约 2-3x 代码量——主要在索引 gather、sink 处理、tile 边界对齐上。
可读性差异:
dense FlashAttention v3 已经是 CUDA kernel 中的”复杂代码”——理解它需要懂 WGMMA、TMA、async copy。sparse_attn 在此基础上再加一层”非连续访问”的复杂度,理解成本约高 50%。
性能差异:
dense FlashAttention v3 在 H100 上达到 ~600 TFlops(接近 FP8 GEMM 峰值)。sparse_attn 在 H100 上达到 ~410 TFlops——比 dense 低 30%,但因为只算 K=1152 个位置(dense 算所有),在长 context 下整体仍然快 5-10x。
这种”单位 FLOPs 慢、但总 FLOPs 少” 的工程权衡正是 sparse 路径的本质——用更少的 FLOPs 做出与 dense 相当的输出。
5.9·拓展 sparse_attn 与”动态批处理(continuous batching)“的协同
vLLM 的核心调度优化是”continuous batching”——多个请求的 prompt / decode 共享 GPU 跑在一个 batch 里。V4 的 sparse_attn 与这套调度有几个微妙的协同点。
协同点 1:每请求独立的 topk_idxs
continuous batching 把多个请求的 query 拼到一起跑 attention。V4 的 sparse_attn 接受 [B, S, K] 的 topk_idxs——每条 sequence 有独立的稀疏选择。这与 dense attention 兼容(dense 用 padding mask)。
协同点 2:长短 prompt 共存
某些请求是 64K context、某些是 1K——一起 batch 时长 prompt 的 KV cache 占大头,但 sparse_attn 的计算仍可控(K=1152 固定)。这是 V4 在 mixed workload 下的红利。
协同点 3:prefill / decode 的混合 batching
vLLM 的 chunked prefill 让 prefill 与 decode 在同一 batch 内跑。V4 的 sparse_attn 需要分别处理两种 phase——prefill 走批量 codepath、decode 走增量 codepath。kernel 需要支持”同一 batch 内同时跑两种 codepath”。
协同点 4:prefix caching 与稀疏 KV
vLLM 的 prefix caching 复用相同前缀的 KV——多个请求共享 KV cache 的某部分。V4 的稀疏 KV 让这种复用更复杂——压缩 KV 段可以共享,滑窗段每请求独立。详见第 19 章关于 SGLang RadixAttention 的讨论。
协同点 5:抢占(preemption)
vLLM 的调度器会抢占 / 恢复请求——把”被抢占请求的 KV cache” 暂存到 CPU 后再恢复。V4 的稀疏 KV 也支持这种 swap——但 swap 单位变成”滑窗段 + 压缩段”两块,不再是单一 KV block。
理解这些协同点让你在 vLLM 中正确部署 V4——不会因为”V4 行为与 dense 模型不同”而踩坑。
5.10 本章小结
- V4 的
sparse_attn在 PyTorch 是一行调用,背后是 FlashMLA 仓库里数千行 CUDA 代码 - 稀疏 attention 在 GPU 上比 dense attention 难优化——索引跳跃、tile 跨界、softmax 跨 SMEM 都是工程挑战
- FlashMLA 的 V4 路径用 ldmatrix / TMA + online softmax 解决了主要挑战,在 H100 上能跑 410 TFlops
- B200 (SM100) 因为有原生 FP4 MMA,比 H100 快约 1.6 倍——V4 在 B200 上的红利明显
- 集成 sparse_attn 进 vLLM / SGLang 等引擎需要四件事:编 FlashMLA、传 topk_idxs、改 KV cache、调度器协调
第 6 章我们离开 attention 内部,来到 V4 长上下文工程的另一支柱:YaRN RoPE——它怎么把 65K 训练上下文外推到 1M。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。