第15章 Tensor / Expert 并行:ColumnParallel 与 RowParallel
“A 1.6T model lives across machines. Every forward is a coordinated dance.” —— V3 工程师内部分享
V4 的并行策略不是事后补丁,而是写在每个 Linear 类里的”出生属性”。
15.1 引子:1.6T 模型的部署算式
V4 Pro 总参 1.6T。即便用 FP4 + FP8 混合精度(平均约 0.7 字节 / 参数),权重大小 ~1.1 TB。
单卡 H100 80GB 显然装不下。即使是 H200 141GB 也不行。V4 必须用分布式部署——典型方案:
- 8 卡 NVLink 节点(H100 / H200):模型权重切到 8 卡,每卡约 140 GB
- 16 卡多节点(2 × 8 卡 IB 互联):权重切到 16 卡,每卡约 70 GB
- 32 卡 / 64 卡:通常给极大 batch 或低延迟需求
无论哪种部署,模型必须在张量层面切分——不能简单复制。V4 用两个核心机制:
- Tensor Parallel (TP):把每层 Linear 的权重沿 row / column 切到不同卡
- Expert Parallel (EP):把 384 个 routed expert 分配到不同卡,每卡持有一部分
V4 源码 (inference/model.py) 通过 world_size、rank 这对全局变量,把切分逻辑直接写在每个并行类里。本章拆这 4 个类。
15.2 全局并行状态
V4 在文件顶部声明:
world_size = 1
rank = 0
block_size = 128
这三个变量在 Transformer.__init__ 被覆盖:
def __init__(self, args: ModelArgs):
global world_size, rank, default_dtype, scale_fmt, scale_dtype
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
...
V4 用 PyTorch 的 torch.distributed 做通信——world_size = 总 GPU 数、rank = 当前 GPU 编号。这两个值在每个进程里是一致的——通过 torchrun 或 mpirun 启动时由 launcher 设置。
V4 让这两个变量是全局可变,而不是每个 module 持有自己的引用——这种”全局可变变量”的设计违反传统软件工程纪律,但在 V4 这种”模块嵌套深 + 配置一次贯穿全程”的场景下能大幅简化代码。
15.3 ColumnParallelLinear:输出维度切分
class ColumnParallelLinear(Linear):
"""Shards output dim across TP ranks. No all-reduce needed on output."""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert out_features % world_size == 0
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return linear(x, self.weight, self.bias)
ColumnParallelLinear 的语义:把 [in, out] 矩阵沿 out 维度切成 world_size 份。每个 rank 持有 [in, out / world_size] 的局部矩阵。
forward 时:每个 rank 独立计算局部输出 [B, S, out / world_size]——不需要通信。
输出维度被切了,怎么办?要么:
- 后续操作能直接处理切分后的输出(如紧接 RowParallelLinear,正好把切分维度作为输入维度)
- 或者外部显式 all_gather 重组
V4 的典型用法:q、k、v、wq_b 等都是 ColumnParallelLinear——输出的 head 维度被天然切分到不同 rank,后续 attention 计算完后通过 RowParallelLinear 自然合并。
15.4 RowParallelLinear:输入维度切分
class RowParallelLinear(Linear):
"""Shards input dim across TP ranks. All-reduce on output to sum partial results."""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert in_features % world_size == 0
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = linear(x, self.weight, None)
if world_size > 1:
y = y.float()
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y.type_as(x)
RowParallelLinear 的语义:把 [in, out] 矩阵沿 in 维度切。每个 rank 持有 [in / world_size, out] 的局部矩阵。
forward 时:每个 rank 用局部 in 输入算局部输出,输出形状 [B, S, out](完整 out 维度),但是数值上只是”部分和”。用 all_reduce 把所有 rank 的部分和加起来,得到完整输出。
注意 V4 的实现:先转 float32 再 all_reduce、加 bias、最后转回原 dtype。这是为了:
- all_reduce 的精度敏感——不同 rank 的部分和量级可能不同,FP32 累加更稳
- bias 在 reduce 之后加,避免被 reduce 重复 N 倍
ColumnParallel + RowParallel 的经典组合:
ColumnParallelLinear(in=D, out=H × heads) # 输出按 head 维度切
↓ (输出: [B, S, head/world × D_h], 每个 rank 各持一部分 head)
ColumnParallel attention 内部计算
↓
RowParallelLinear(in=H × heads, out=D) # 输入按 head 切,输出 reduce
↓ all_reduce 后得到完整 [B, S, D]
这套组合在 V4 的 Attention 类里直接可见——wq_b 是 ColumnParallel,wo_a 是 ColumnParallel,wo_b 是 RowParallel。
15.4·补 ColumnParallel + RowParallel 经典组合的张量流
V4 attention 内部的并行流转最典型地体现了这套组合:
flowchart LR
X["x: [B,S,7168]<br/>每 rank 完整副本"] --> ColLin["ColumnParallel<br/>wq_b: 切 out_features"]
ColLin --> Q["q: [B,S, n_heads/world × head_dim]<br/>每 rank 不同 head 切片"]
Q --> Compute["attention 计算<br/>(rank 内独立)"]
Compute --> O["o: [B,S, n_heads/world × head_dim]<br/>仍是 head 切片"]
O --> RowLin["RowParallel<br/>wo_b: 切 in_features"]
RowLin --> Partial["partial_output: [B,S,7168]<br/>每 rank 仅算了部分和"]
Partial --> AllReduce[("all_reduce<br/>跨 rank 累加")]
AllReduce --> Final["完整 x: [B,S,7168]<br/>每 rank 都有相同结果"]
classDef parallel fill:#312e81,stroke:#a78bfa,color:#ede9fe
classDef sync fill:#7c2d12,stroke:#fb923c,color:#ffedd5
class ColLin,RowLin parallel
class AllReduce sync
整段 attention 只有最后一次 all_reduce 通信——通信量 O(B×S×D),远小于 attention 内部的 O(B×S×n_heads×head_dim) 计算量。这就是”TP 在 attention 上几乎免费”的工程账。
15.5 ParallelEmbedding:vocab 维度切分
class ParallelEmbedding(nn.Module):
def __init__(self, vocab_size: int, dim: int):
super().__init__()
...
assert vocab_size % world_size == 0
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
ParallelEmbedding 的语义:把 vocab 切到不同 rank——rank 0 持有 token id [0, V/8),rank 1 持有 [V/8, 2V/8) …
forward 的关键技巧:
- 每个 rank 算”自己负责的 token id”的 embedding——不属于本 rank 的 token id 被 mask 成 0
dist.all_reduce把所有 rank 的部分 embedding 累加——因为每个 token id 只有一个 rank 真正算了,all_reduce 实际等于”取那个 rank 的输出”
这种”mask + all_reduce”的写法看起来浪费——为什么不直接用 dist.all_to_all 或者点对点通信?答案是 all_reduce 的硬件支持最成熟、延迟最低——比 all_to_all 快得多。在 V4 这种 vocab=129280、单 token embedding 只有 7168 维的场景下,all_reduce 几乎是零开销。
15.6 ParallelHead:lm_head 的 vocab gather
class ParallelHead(nn.Module):
def __init__(self, vocab_size: int, dim: int, ...):
super().__init__()
...
self.part_vocab_size = (vocab_size // world_size)
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
def get_logits(self, x):
return F.linear(x[:, -1].float(), self.weight)
def forward(self, x, hc_fn, hc_scale, hc_base, norm):
x = self.hc_head(x, hc_fn, hc_scale, hc_base)
logits = self.get_logits(norm(x))
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
ParallelHead 与 ParallelEmbedding 对偶:vocab 切到不同 rank、每个 rank 算”自己负责的部分 logits”、最后 all_gather 把所有 rank 的 logits 拼起来。
注意几个细节:
细节 1:只算最后一个 token 的 logits
return F.linear(x[:, -1].float(), self.weight)
x[:, -1] 取每条序列的最后一个 token——LLM 推理通常只关心下一个 token 的 logits。这避免了为序列中间 token 算 logits 的浪费。
细节 2:weight 是 float32
V4 的 lm_head 权重保 FP32(不像 attention / FFN 走 FP4 / FP8)。这是因为 logits 直接决定 next-token 概率,精度损失会被反复采样放大——必须保 FP32。
细节 3:all_gather 而非 all_reduce
vocab 在不同 rank 上是互不相交的——每个 rank 算的 logits 是 vocab 不同切片的真实结果,不是部分和。all_gather 把这些”真实切片”拼起来;如果用 all_reduce,会把不同切片错误相加。
15.7 Expert Parallel:384 expert 怎么切到 8 卡
V4 的 MoE 类里:
class MoE(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
...
assert args.n_routed_experts % world_size == 0
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
...
self.experts = nn.ModuleList([
Expert(...) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)
])
self.shared_experts = Expert(...) # 每 rank 都有完整 shared expert
8 卡 TP 部署 V4 Pro:384 / 8 = 48 个 expert / rank。每 rank 的 nn.ModuleList 里:
- 自己持有的 48 个位置是真实 Expert 实例
- 其他 336 个位置是 None
forward 时:
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx], weights[idx, top, None])
if world_size > 1:
dist.all_reduce(y)
y += self.shared_experts(x)
每 rank 只跑自己持有的 48 个 expert,输出贡献到 y。最后 all_reduce 把所有 rank 的部分输出累加。
通信成本:每层 MoE 一次 all_reduce on [B*S, D]。对于 V4 Pro 在 1M context + batch=8 下:
- y 大小:8 × 1048576 × 7168 × 2 bytes (BF16) = 120 GB / step
- 显然这不可能——all_reduce 不可能在每层 layer 上传 120 GB
实际上 V4 在生产部署时不会让 1M 序列经过所有 61 层 MoE——会用 prefill / decode 分阶段,且 MoE 的 all_reduce 在 decode 时只走 1 个 token(D / 8 卡 ≈ 数 KB)。
第 16 章会深入这部分通信优化——DeepEP 给 MoE 提供了比 NCCL all_reduce 更快的”专用通信库”。
15.8 一段 Attention 的并行流转
把 V4 的 Attention 在 8 卡 TP 下的并行流转走一遍:
flowchart TB X["x: [B, S, 7168]<br/>每 rank 都有完整副本"] --> ColLin1["ColumnParallelLinear (wq_b)<br/>每 rank 持 [1536, 16 heads × 512]"] ColLin1 --> Q["q: [B, S, 16 local heads, 512]<br/>每 rank 不同 head"] X --> Lin["Linear (wkv)<br/>每 rank 完整副本"] Lin --> KV["kv: [B, S, 512]<br/>每 rank 完整副本 (MQA)"] Q --> SparseAttn["sparse_attn<br/>每 rank 只算自己的 head"] KV --> SparseAttn SparseAttn --> O["o: [B, S, 16 local heads, 512]"] O --> ColLin2["wo_a (ColumnParallel)<br/>每 rank 持 [n_heads × 512 / 16, 16 × 1024]"] ColLin2 --> OLora["o_lora: [B, S, 16 / 8 groups, 1024]"] OLora --> RowLin["wo_b (RowParallel)<br/>每 rank 持 [16 × 1024 / 8, 7168]"] RowLin --> Reduce["all_reduce"] Reduce --> XOut["x_out: [B, S, 7168]<br/>每 rank 都有完整结果"]
整个 attention 的通信只有最后一次 all_reduce on [B, S, 7168]——其他都是 rank 内独立计算。这是 TP 在 attention 上的高效之处:通信量 O(B × S × D),远小于参数量 O(D²)。
15.9 与 Megatron-LM 的对比
V4 的并行类与 Megatron-LM 的 ColumnParallelLinear / RowParallelLinear 有同源思路(Megatron-LM 是这套方案的工业化先驱):
| 维度 | Megatron-LM | V4 inference/model.py |
|---|---|---|
| ColumnParallel | 完整支持 | 完整支持 |
| RowParallel | 完整支持 | 完整支持 |
| Embedding | VocabParallelEmbedding | ParallelEmbedding (同名不同实现) |
| LM Head | VocabParallelOutputLayer | ParallelHead (含 hc_head 处理) |
| Sequence Parallel | 支持 | 不支持(V4 用稀疏 attention 不需要) |
| Pipeline Parallel | 支持 | 不支持(推理代码无 PP) |
| 代码量 | 数千行 | ~50 行 |
V4 的并行实现非常简洁——只覆盖推理需要的部分。训练时的 Pipeline Parallel、Sequence Parallel 等更复杂机制由训练框架(不公开的内部代码)处理,与 inference/model.py 解耦。
这种”训练 / 推理代码分开”的设计让公开的推理代码极其简洁——读者不需要被训练栈的复杂性淹没。
15.10 通信开销估算
把 V4 Pro 在 8 卡 TP / 16 卡 TP+EP 下的通信开销估算一下(每 token decode):
| 通信操作 | 张量大小 | 频率 | 总通信量 / token |
|---|---|---|---|
| Attention all_reduce | [1, 7168] BF16 = 14 KB | 每层 1 次 × 61 层 | 854 KB |
| MoE expert all_reduce | [1, 7168] BF16 = 14 KB | 每层 1 次 × 61 层 | 854 KB |
| LM Head all_gather | [vocab/8] FP32 = 16 KB × 8 | 1 次 | 130 KB |
| Embedding all_reduce | [1, 7168] BF16 = 14 KB | 1 次 | 14 KB |
| 总计 | - | - | ~1.85 MB / token |
NVLink 带宽约 600 GB/s(H100 之间),所以单 token 通信耗时约 3 μs——对 50 ms / token 的 decode 几乎可以忽略。
但在 prefill 阶段,序列长度 S 进入公式——通信量乘以 S。1M context 的 prefill 一次通信量约 2 GB——这时通信成本变得显著。
第 16 章 DeepEP 主要针对 prefill 的 all-to-all 优化。
15.11 动手实验:跑通最小 TP 推理
# 启动 2 卡 TP(在单机上模拟)
# torchrun --nproc-per-node=2 --master-port=29500 minimal_tp.py
import os
import torch
import torch.distributed as dist
import torch.nn as nn
dist.init_process_group(backend='nccl')
world_size = dist.get_world_size()
rank = dist.get_rank()
device = torch.device(f'cuda:{rank}')
class ColParallel(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
assert out_dim % world_size == 0
self.weight = nn.Parameter(torch.randn(out_dim // world_size, in_dim, device=device))
def forward(self, x):
return x @ self.weight.T
class RowParallel(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
assert in_dim % world_size == 0
self.weight = nn.Parameter(torch.randn(out_dim, in_dim // world_size, device=device))
def forward(self, x):
y = x @ self.weight.T
dist.all_reduce(y)
return y
# 2 卡上测试
m1 = ColParallel(128, 256)
m2 = RowParallel(256, 128)
x = torch.randn(4, 128, device=device)
y = m2(m1(x))
print(f"Rank {rank}: y shape = {y.shape}, mean = {y.mean().item():.4f}")
dist.destroy_process_group()
跑这段代码,会在 2 卡上看到一致的 y mean——证明 ColumnParallel + RowParallel 的通信正确组合得到了与单卡等价的结果。
15.11·补 V4 并行策略的”代码极简” 哲学
把 V4 的并行实现与 Megatron-LM / DeepSpeed 等成熟训练框架对比,最显著的差异是代码量——V4 的 4 个并行类合计约 50 行核心代码,Megatron-LM 同等功能要数千行。这种极简来自几个设计选择:
选择 1:把训练复杂度外置
Megatron-LM 是训练框架,必须处理优化器状态分片、梯度累积、Pipeline Parallel、混合精度训练等复杂场景。V4 的 inference/model.py 是推理代码——所有训练复杂度被外置到不公开的训练栈,只保留推理需要的最小并行抽象。
选择 2:放弃通用 API
Megatron-LM 的 ColumnParallelLinear 支持丰富配置——gather_output / async_tensor_model_parallel_allreduce / sequence_parallel 等十几个开关。V4 的版本只有一种行为——按 out_features 切分、forward 不通信。这种”放弃通用性”让代码极简、性能最优。
选择 3:全局变量代替依赖注入
world_size / rank 是文件级全局变量——不是每个 module 持有自己的 ProcessGroup 引用。这违反了”显式依赖” 的传统软件工程纪律,但避免了在 200 处方法签名里都加 group 参数。
选择 4:MQA-style 让 KV 不切
V4 的 num_key_value_heads=1 让 KV 在 TP 多卡间不切——每 rank 持有完整 KV cache。这绕过了”KV 切分时的复杂索引数学”——代码大幅简化。代价是显存冗余 N 倍,但 V4 的 KV cache 已经被 Compressor 压到很小(每序列 ~8 GB),8 卡冗余只占总显存的小份额。
这些选择让 V4 的并行代码”看一眼就懂”——这种可读性对开源项目极重要。读者不需要花一周读代码才能理解 V4 怎么工作,几小时就够了。
15.11·补·补 部署 V4 时的并行策略选择
把 V4 部署到生产时,TP / EP / DP 的配比是关键工程决策。给一个决策树:
决策点 1:序列长度
- < 32K:可以用更激进的 TP(16+)让单序列延迟最低
- 32K - 256K:8 卡 TP 是甜区——平衡延迟与吞吐
- 256K - 1M:必须 8 卡 TP + 充足 KV 显存——可能需要把 KV cache 切到主机内存
决策点 2:并发量(batch size)
- 低并发(batch < 4):TP 比 EP 重要——让单个 sequence 跑得最快
- 中并发(batch 4-16):TP=8 + DP(多副本)——每个副本服务一部分用户
- 高并发(batch > 16):TP=8 + EP=16/32(跨节点)—— 单一大模型实例服务大量并发
决策点 3:硬件拓扑
- 单节点 8 卡:TP=8、不需要 EP(expert 全部在 NVLink 内)
- 双节点 16 卡:TP=8 + EP=2(每节点 192 expert,跨节点 EP=2)
- 多节点 32+ 卡:TP=8 + EP 跨节点 + DP 多副本
决策点 4:延迟 vs 吞吐 trade-off
- 延迟优先:增加 TP(更多并行算 single token)
- 吞吐优先:增加 batch + 保持 TP 适中
实际部署中,最常见的配置是 单节点 8 卡 TP=8——这是 V4 在 H100 上的最佳”性价比”配置。多节点部署仅在”高并发 + 延迟可接受”的场景下才有 ROI。
15.11·延展 V4 的并行抽象与 vLLM PagedAttention 的协同
V4 的并行类(ColumnParallel / RowParallel / ParallelEmbedding / ParallelHead)与 vLLM 的 PagedAttention 在工程层有微妙的协同关系。
协同点 1:KV cache 不切
V4 的 num_key_value_heads=1 让 KV 在所有 TP rank 上是完整副本。这与 vLLM PagedAttention 的”每 rank 持有全部 block 的本地副本”语义一致——每 rank 都能直接访问完整 KV,不需要跨 rank 通信。
协同点 2:Q 切分到 head
V4 的 wq_b 是 ColumnParallel——Q head 被切分到不同 rank。vLLM 的 PagedAttention 接收”本 rank 的 Q heads”,与”全副本的 KV” 做 attention 计算,输出 ColumnParallel 形式的 attention 输出。
协同点 3:O 投影做 reduce
V4 的 wo_b 是 RowParallel——attention 输出经过它后做 all_reduce 得到完整 hidden。vLLM 的 attention backend 在这一步可能直接调用 RowParallelLinear 的 forward——无缝衔接。
协同点 4:MoE 的 expert parallel
V4 的 384 expert 切到不同 rank。vLLM 适配时需要让 ModelRunner 知道每 rank 持有哪些 expert——这与 V4 的 experts_start_idx / experts_end_idx 一致。
这些协同点意味着:vLLM 的 V4 适配 PR 的”模型并行”部分非常薄——主要是把 V4 的并行类注册到 vLLM 的 ParallelState 系统。绝大多数代码可以从 V4 的 inference/model.py 直接复制——这是 V4 设计的工程红利。
具体的协同细节会在《vLLM 推理内核深度解析》第 14 章 “Tensor 并行” 后续更新中展开。
15.11·拓展 V4 并行实现的”边界条件”清单
V4 并行实现里有几个重要的边界条件——必须满足才能正确工作。把它们整理成一份”启动前检查清单”:
边界 1:out_features % world_size == 0
ColumnParallelLinear 要求输出维度能被 world_size 整除。V4 的 out_features 都是 128 倍数(如 128 head × 512 head_dim = 65536),可以整除 8 / 16 / 32 等常见 world_size。但如果你 fine-tune 时改了 head 数(比如减到 96),可能违反这个约束——必须挑能整除的 world_size。
边界 2:in_features % world_size == 0
RowParallelLinear 同理。V4 的 in_features 也都是 128 倍数,对常见 world_size 都满足。
边界 3:vocab_size % world_size == 0
ParallelEmbedding / ParallelHead 要求 vocab 能被 world_size 整除。V4 vocab=129280,可以整除 8(=16160 per rank)但不能整除 16(129280 / 16 = 8080)——意味着 16 卡 TP 部署需要扩展 vocab 或换其他切分方式。
边界 4:n_routed_experts % world_size == 0
MoE 要求 expert 数能被 world_size 整除。V4 的 384 expert 整除 8(48 per rank)、16(24 per rank)、32(12 per rank)——常见配置都满足。
边界 5:world_size == 1 时不通信
V4 的并行类有 if world_size > 1: 守卫——单卡模式下完全跳过通信。这让你可以单卡跑 V4 (small variant) 做调试 + 多卡跑 V4 Pro 做生产,同一份代码兼容两种模式。
边界 6:dist 必须先 init
world_size = dist.get_world_size() if dist.is_initialized() else 1——V4 的 Transformer.init 检查 dist 是否初始化。如果你忘了 dist.init_process_group(...),V4 会默默退化到 single-rank 模式,可能导致部署不正确。
把这 6 个边界做成 deployment 的 pre-check,可以在 V4 部署的第一天避免大部分配置错误。
15.12 延伸阅读
- Megatron-LM 论文(arXiv:1909.08053):TP 的工业化先驱
- DeepSpeed Ulysses(arXiv:2309.14509):sequence parallelism
- Tensor Parallelism in Distributed Training(NVIDIA 文档):TP 的硬件视角
- 本书第 16 章:DeepEP——MoE 的专用 all-to-all 通信库
- 本书《vLLM 推理内核深度解析》第 14 章:vLLM 中的 TP 实现
15.12·补 V4 与 ZeRO 优化器分片的协同
V4 的并行类是”Tensor Parallel + Expert Parallel” 的组合。生产训练通常还会叠加 ZeRO 优化器分片——把优化器状态切到不同 rank。把这套叠加的工程细节梳理一下。
ZeRO-1:仅切分优化器状态(如 Muon 的 momentum buffer)。每 rank 持有所有 weight 但只持有部分优化器状态。
ZeRO-2:切分优化器状态 + 梯度。反向传播时梯度先在本 rank 计算,再 reduce_scatter 到对应 rank。
ZeRO-3:切分优化器状态 + 梯度 + weight。每 rank 只持有部分 weight—— forward 时需要 all_gather 把 weight 拼起来。
V4 与 ZeRO 的叠加规则:
- TP 切分维度(如 head 维度)与 ZeRO 切分维度(如 optimizer state 的 layer 维度)必须正交——避免冲突
- TP rank 之间共享 ZeRO rank ID—— 比如 8 卡 TP + 4 节点,每节点 8 卡是同一个 ZeRO rank
- Expert Parallel 的 expert 不参与 ZeRO 切分——expert weight 已经分布到不同 rank,再切就乱了
V4 的训练大概率用 ZeRO-1(最保守)或 ZeRO-2——ZeRO-3 在 1.6T 模型上 weight all_gather 的通信开销过大,不划算。
这部分配置大多数公司不会自己实现——直接用 DeepSpeed 或 FSDP 的现成支持。但理解原理让你能 debug 配置错误。
15.12·补·补 V4 与 Sequence Parallel / Pipeline Parallel 的关系
V4 的并行类只覆盖 Tensor Parallel + Expert Parallel。但生产训练通常还有 Sequence Parallel 和 Pipeline Parallel——把它们与 V4 的关系说清楚。
Sequence Parallel (SP):
把 sequence 维度切到不同 rank。每 rank 处理 sequence 的一段——主要在 activation memory 上节省(不需要每 rank 存完整 sequence)。
V4 的并行类没有原生 SP 支持——inference/model.py 不切 sequence。但 V4 训练时大概率有 SP(不公开训练栈)——只是 inference 用不到。
为什么 inference 不用 SP:
inference 的瓶颈是 KV cache + GEMM,不是 activation。SP 节省 activation 的好处在 inference 上没用——而 SP 的通信开销反而拖慢推理。所以 V4 inference 路径不带 SP。
Pipeline Parallel (PP):
把不同 layer 切到不同 rank——前半模型在 rank 0、后半在 rank 1。每 token 顺序经过 rank。
V4 的 inference/model.py 没有 PP——所有 layer 都在同一 rank(每 rank 持有全部 layer 的 TP 部分)。这是 inference 路径的设计选择。
为什么 inference 不用 PP:
PP 引入”流水线泡沫”(pipeline bubble)——某些 rank 在等其他 rank。短上下文 + 大 batch 下泡沫小,PP 可行;但 V4 的目标是长上下文 + 适中 batch,泡沫会成为延迟瓶颈。
训练 vs 推理的并行差异:
| 维度 | 训练 | 推理(V4 inference) |
|---|---|---|
| TP | ✅ | ✅ |
| EP | ✅ | ✅ |
| SP | ✅ | ❌(不需要) |
| PP | ✅(大模型必须) | ❌(延迟敏感) |
| ZeRO | ✅ | ❌(推理无优化器状态) |
| FSDP | ✅(部分) | ❌ |
理解这种差异让你正确选并行配置——训练时该用 PP 就用,推理时不要照搬训练配置。
15.13 本章小结
- V4 用 4 个并行类:ColumnParallel / RowParallel / ParallelEmbedding / ParallelHead
- ColumnParallel 切 out 维度,无 reduce;RowParallel 切 in 维度,需 all_reduce
- ParallelEmbedding / ParallelHead 切 vocab 维度,分别用 mask+all_reduce / all_gather
- Expert Parallel 用稀疏 ModuleList + None 占位 + per-rank 循环的”代码简洁”实现
- Attention 的 ColumnParallel + RowParallel 经典组合让通信量降到 O(B × S × D)
- 与 Megatron-LM 同源,但 V4 推理代码极简(~50 行)——训练复杂度解耦到不公开训练栈
第 16 章:DeepEP——V4 给 MoE all-to-all 量身定制的通信库。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。