第18章 FSDP:ZeRO 风格的参数分片
“FSDP changes the question from ‘how do I parallelize this model’ to ‘how much GPU memory can I trade for communication’.”
—— PyTorch FSDP design doc
本章要点
- DDP 在 70B 模型时崩溃:每张卡需要 280GB(params) + 280GB(grad) + 560GB(Adam state)= 1120GB,单 H100 才 80GB
- FSDP 灵感来自 DeepSpeed ZeRO:把 params / grads / optimizer state 都切到 N 张卡,每张卡只持有 1/N
- 核心机制:forward 时 AllGather 把分片重建成完整 param、用完后立即 reshard;backward 时同理 + 最后 ReduceScatter 同步梯度
- 5 种 ShardingStrategy:
NO_SHARD(=DDP)、SHARD_GRAD_OP(ZeRO-2)、FULL_SHARD(ZeRO-3)、HYBRID_SHARD、_HYBRID_SHARD_ZERO2 - FSDP-1 vs FSDP-2:v2.4+ 推出新接口
torch.distributed.fsdp.fully_shard,模块化 + 与 torch.compile 兼容性好得多 - prefetch 是性能命脉:在 layer N forward 时 prefetch layer N+1 的 unshard,隐藏通信延迟
18.1 为什么 DDP 不够用
第 17 章 DDP 假设”每张卡放得下完整模型”。70B 模型:
- params:70B × 4B(fp32) = 280 GB
- grads:280 GB
- Adam optimizer state (m + v):560 GB
- 加上 activations 几十 GB
- 合计 1120+ GB
H100 80GB / GB200 192GB —— 单卡都装不下。这是 DDP 在大模型时代的天花板。
FSDP(Fully Sharded Data Parallel)的解法:把这堆显存按 rank 数 N 均匀切分到 N 张卡。每张卡只持有 1/N 份数据,需要时通过通信临时凑齐完整张量。
8 卡 FSDP:
每张卡 params: 35 GB
每张卡 grads: 35 GB
每张卡 optimizer: 70 GB
合计 ~140 GB / 卡 ← 80GB 仍然装不下
16 卡: ~70 GB / 卡 ← 接近但还要 activation
32 卡: ~35 GB / 卡 ← 舒服了
70B 模型至少 32 张卡才能跑稳。FSDP 的”分片粒度”是这套机制的核心。
18.2 5 种 ShardingStrategy
graph TB
DDP["DDP / NO_SHARD<br/>每卡完整 params + grads + optimizer"]
Z2["ZeRO-2 / SHARD_GRAD_OP<br/>params 完整,但 grads + optimizer 切片"]
Z3["ZeRO-3 / FULL_SHARD<br/>params + grads + optimizer 全切片<br/>显存最省"]
HZ["HYBRID_SHARD<br/>同节点内 ZeRO-3,跨节点 DDP<br/>跨节点带宽不足时的折中"]
DDP -.显存大.-> Z2
Z2 -.通信增多.-> Z3
Z3 -.通信占比可能太高.-> HZ
style Z3 fill:#dcfce7,stroke:#22c55e,stroke-width:2px
style HZ fill:#fef3c7,stroke:#f59e0b
通信代价:
| Strategy | params 通信 | grads 通信 | optimizer 通信 | 显存 / 卡 |
|---|---|---|---|---|
| NO_SHARD (DDP) | 0 | AllReduce | 0 | 完整 |
| SHARD_GRAD_OP | 0 | ReduceScatter | 0 | 1/N grads + opt |
| FULL_SHARD | 2× AllGather (fw + bw) | ReduceScatter | 0(params 已分片) | 全部 1/N |
| HYBRID_SHARD | 节点内 AllGather | 节点内 RS + 跨节点 AR | - | 1/N (节点内) |
FULL_SHARD 显存最省、通信最多。生产代码里 70B+ 通常用 FULL_SHARD 或 HYBRID_SHARD。
18.3 unshard / reshard 的精确时机
FSDP 把参数按 N 切片后,每张卡平时只持有 1/N。但算子需要完整 param 才能算(Linear 需要完整 weight 矩阵做 GEMM)。所以 forward 前要 AllGather 把分片凑成完整 param,算完立即 reshard(释放完整副本,回到 1/N):
sequenceDiagram
autonumber
participant Layer as Layer N forward
participant Comm as AllGather
participant GPU as GPU 显存
Layer->>Comm: 发起 AllGather (异步)
Note over GPU: 此时 GPU 上 layer N 的 param 是分片
Comm->>GPU: AllGather 完成 → 完整 param 在显存
Layer->>GPU: 跑 layer N 的 GEMM
GPU->>GPU: reshard: 立即释放完整 param
Note over GPU: 回到 1/N 状态, 等下次 forward
backward 同理:算每层 grad 前再 AllGather 一次 param,算完 reshard,最后用 ReduceScatter 把这层的 grad 切片同步:
Layer N backward:
1. AllGather param (因为反向也要用 weight 算 grad_input)
2. 计算 grad_input + grad_param
3. ReduceScatter grad_param: 8 卡各自只保留自己的 1/N 切片
4. 释放完整 param
ReduceScatter 是 AllReduce 的”半成品” —— 每个 rank 只拿到 reduce 结果的 1/N 切片,不再 AllGather。这刚好对应 FSDP 的需求:每个 rank 只更新自己持有的 1/N 参数。
18.4 prefetch:隐藏通信延迟
朴素实现里:
fw layer 1 = AllGather + GEMM + reshard (串行)
fw layer 2 = AllGather + GEMM + reshard
...
每层都等 AllGather 完成才开始算 —— 通信延迟全暴露。prefetch 优化:
fw layer 1: AllGather_1 → GEMM_1 (期间 prefetch AllGather_2)
fw layer 2: GEMM_2 (用已经 prefetch 完的 param) → prefetch AllGather_3
...
每层 GEMM 进行时,下一层的 AllGather 在另一个 stream 上跑。理想情况下通信完全 overlap 到计算里,FSDP 性能接近 DDP(如果带宽足够)。
实现上 FSDP 用 dedicated CUDA stream(_unshard_stream、_pre_unshard_stream 在 _runtime_utils.py:263-269)跑 collective,与训练主 stream 并发。
prefetch 深度由 forward_prefetch=True/False 与 backward 的 backward_prefetch=BACKWARD_PRE/POST 控制。BACKWARD_PRE 在前一层 backward 开始前 prefetch;BACKWARD_POST 在前一层 backward 完成后 prefetch(更安全但 overlap 少)。
18.5 FSDP-1 vs FSDP-2
PyTorch v2.4+ 推出了新一代 FSDP:torch.distributed.fsdp.fully_shard(也叫 FSDP-2,源码在 torch/distributed/fsdp/_fully_shard/)。它解决了 FSDP-1 的几个痛点:
1. 模块化:FSDP-1 把整个模型 wrap 成一个 FullyShardedDataParallel(model),所有逻辑放一个大类(2167 行)。FSDP-2 用 fully_shard(submodule) API 给每个 submodule 单独 wrap,更精细控制:
# FSDP-2
from torch.distributed.fsdp import fully_shard
for layer in model.layers:
fully_shard(layer) # 每个 transformer 层单独 shard
fully_shard(model) # root module
2. 与 torch.compile 兼容:FSDP-1 有不少 nn.Module.__setattr__ 黑魔法,让 Dynamo trace 时 graph break 严重。FSDP-2 重新设计了参数管理,能完整被 Inductor 编译。
3. DTensor 后端:FSDP-2 内部用 DTensor(distributed tensor)抽象,让张量分片成为 first-class concept。
生产代码里 v2.4+ 推荐用 FSDP-2。FSDP-1 仍然支持但被标记为”老 API”,新功能(如 fully_shard 与 mesh 接口)只在 FSDP-2 上加。
18.6 mixed precision:FSDP 的另一大优化
FSDP 提供专门的 MixedPrecision 配置:
from torch.distributed.fsdp import MixedPrecision
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16, # AllGather 时用 bf16 通信
reduce_dtype=torch.float32, # ReduceScatter 时用 fp32 (避免梯度精度损失)
buffer_dtype=torch.bfloat16,
)
model = FSDP(model, mixed_precision=mp_policy)
为什么这个比朴素 model.bfloat16() 好?因为 FSDP 控制每种张量的 dtype 独立:
- params 在 bf16(省一半 AllGather 带宽)
- gradients 在 fp32 reduce(保证数值稳定)
- master weights 仍是 fp32(让 Adam 更新精确)
这套”分别控制”是大模型训练 fp32+bf16 混合精度的标配。第 20 章量化与混合精度会展开。
18.6.5 DeviceMesh:拓扑的 first-class 表达
FSDP-2 之前,分布式训练靠 ProcessGroup 表达”哪些 rank 参与同一通信”。混合并行(DP + TP + PP)时要手工管理多个 group,容易出错。
torch.distributed.device_mesh(1553 行)引入 DeviceMesh 抽象:把 N 个 rank 排成 K 维网格,每维有自己的 ProcessGroup。
from torch.distributed.device_mesh import init_device_mesh
# 32 卡 = 4 节点 × 8 卡
# 节点内 8 卡 ZeRO-3, 跨节点 4 个 replica DDP
mesh_2d = init_device_mesh(
device_type="cuda",
mesh_shape=(4, 8),
mesh_dim_names=("replica", "shard"),
)
# 取出某一维的 ProcessGroup
shard_pg = mesh_2d["shard"].get_group() # 节点内 8 卡 group
replica_pg = mesh_2d["replica"].get_group() # 跨节点 4 group
每个 rank 在 mesh 里有 K 个坐标(如 (replica=2, shard=5))。在某一维上的 collective(如 shard.all_gather)只在那一维的 group 内做,不涉及其他维。
DeviceMesh 让 HSDP / 3D parallel(DP + TP + PP)的实现从”手动管 N 个 group”变成”声明一个 mesh 然后取维度”。这是 FSDP-2 / DTensor 的基础。
18.6.6 DTensor:分片张量的 first-class 类型
FSDP-2 内部不直接持有 plain Tensor,而是 DTensor(distributed tensor)。一个 DTensor 由三部分组成:
from torch.distributed._tensor import DTensor, Shard, Replicate
dtensor = DTensor.from_local(
local_tensor, # 本 rank 持有的切片
device_mesh, # 它在哪个 mesh 上
placements=[Shard(0)] # 在 mesh 维度上怎么分布: Shard / Replicate / Partial
)
三种 Placement:
Shard(dim):按某 dim 切到 mesh 上,每 rank 持有 1/NReplicate():每 rank 都有完整副本Partial:每 rank 持有”部分和”,下次访问前需 reduce
DTensor 算子(dtensor + dtensor)会自动选最优通信策略:如果两边都是 Shard(0),相加无需通信;如果一个 Shard(0) 一个 Replicate,自动 AllGather + 加。这套自动 dispatch 让用户不用手写 collective。
FSDP-2 把每个 nn.Parameter 包成 DTensor(local_shard, mesh, [Shard(0)]),forward 时调用 to_replicate() 触发 AllGather 凑成完整 param、用完转回 Shard。FSDP-2 的 unshard / reshard 逻辑就是 DTensor 的 placement 转换,比 FSDP-1 的手写更通用、与 torch.compile 兼容好。
18.6.7 HSDP 的具体拓扑
HYBRID_SHARD 的关键是用 2D mesh:节点内 shard、跨节点 replicate:
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("inter", "intra"))
# 配 FSDP-2: 节点内 (intra) shard, 跨节点 (inter) replicate
fully_shard(model, mesh=mesh)
通信模式:
- forward 的 AllGather 只在 intra 维(节点内),走 NVLink(带宽 ~600 GB/s)
- backward 的 ReduceScatter 也在 intra 维
- backward 完成后额外做一次 AllReduce 在 inter 维(跨节点)同步梯度
这条策略让”高带宽节点内通信”承担分片代价、“低带宽跨节点”只做一次梯度同步。32 卡 H100 集群上 HSDP 比 FULL_SHARD 通常快 20-40%,前提是节点内带宽远大于跨节点。
18.6.8 _unshard 的 stream 编排
打开 _runtime_utils.py:277 的 _unshard:
def _unshard(state, handle, unshard_stream, pre_unshard_stream):
with state._device_handle.stream(pre_unshard_stream):
# 1. 在 pre-unshard stream 上准备 buffer (alloc 完整 param 大小的 tensor)
pad_for_unshard(handle)
with state._device_handle.stream(unshard_stream):
# 2. 在 unshard stream 上发起 AllGather
# event 让 unshard stream 等 pre-unshard 完成
unshard_stream.wait_stream(pre_unshard_stream)
all_gather_into_tensor(...)
两个 stream 的分工:
pre_unshard_stream:跑 alloc / pad 等”准备工作”unshard_stream:跑 AllGather 本身
为什么要两个 stream?因为如果 alloc 与 AllGather 都在主 stream 跑,会阻塞下一层 forward。两个独立 stream 让它们与主 stream 完全 overlap。
wait_stream 是 CUDA 的 stream-to-stream 同步原语:让 unshard stream 等 pre-unshard stream 上之前发的 op 完成。这种”多 stream 协同”是 FSDP 性能的关键 —— 配合 prefetch(§18.4),让通信完全隐藏到计算里。
18.6.9 activation_checkpoint × FSDP 的协作
第 7 章 §7.5.3 讲过 activation_checkpoint:前向不保存中间激活,反向时重新 forward 一遍取回。这套机制配 FSDP 用时多一层复杂度。
考虑 FSDP+checkpoint 的反向流程:
- 正常 forward(FSDP layer N):
- AllGather 凑完整 param → forward 计算(不存 activation)→ reshard 释放完整 param
- 反向走到 layer N:
- 需要重新 forward 取 activation → 但 param 已经被 reshard!
- 必须 再次 AllGather 凑完整 param → 重 forward → 拿到 activation → backward
- 完成后又 reshard
sequenceDiagram
autonumber
participant FW as Forward layer N
participant Mem as 显存
participant BW as Backward layer N
FW->>Mem: AllGather param (1)
FW->>FW: 跑 forward, 不存 activation
FW->>Mem: reshard param (释放完整副本)
Note over BW: 反向到 layer N
BW->>Mem: AllGather param (2) ← 重 forward 还要再 unshard 一次!
BW->>BW: 重 forward 拿到 activation
BW->>BW: backward 算梯度
BW->>Mem: reshard param
BW->>Mem: ReduceScatter grad
每个 checkpointed layer 在反向触发 2 次 AllGather(一次为 backward 本身、一次为重 forward)。这是 FSDP+checkpoint 比纯 FSDP 通信量上升的根本。
torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:112 的 CheckpointWrapper 是 PyTorch 给 FSDP 的官方 checkpoint 接口。它的关键设计:
apply_activation_checkpointing(model, ...)(:239) 递归给每个匹配check_fn的 module 套上CheckpointWrapperuse_reentrant=False是新版默认(v2.0+):用saved_tensors_hooks(第 7 章 §7.5.3)实现,与 FSDP / torch.compile 兼容性好- 老的
use_reentrant=True用torch.autograd.Function(第 7 章 §7.8),与 FSDP 反向 hook 互动有 corner case,新代码避开
实战决策:开 FSDP 后是否再加 activation_checkpoint?看显存。70B 训练即使用 FULL_SHARD,activation 显存(forward 中间激活)仍然几十 GB。activation_checkpoint 能再省 70%+ activation 显存,代价是反向多 2x AllGather。如果跨节点带宽足够,这笔账划算 —— 大模型训练几乎都开。
apply_activation_checkpointing 的 check_fn 通常设成”是 transformer block 就 wrap”:
from functools import partial
apply_activation_checkpointing(
model,
check_fn=lambda m: isinstance(m, TransformerBlock),
)
这种”按 block 粒度 checkpoint”是 LLM 训练的标准做法。block 内部不 checkpoint —— 否则 attention / mlp 各自重 forward 通信量爆炸。
18.6.10 FlatParameter (FSDP-1):把多个 param 平铺成一个大张量
FSDP-1 内部不直接管理用户的 nn.Parameter,而是把同一 wrap unit 内的多个 param 平铺合并成一个 FlatParameter:
# 用户 module 有 3 个 param
linear.weight shape [768, 768] fp32
linear.bias shape [768] fp32
norm.weight shape [768] fp32
# FSDP-1 合并 (flatten + concat)
flat_param = FlatParameter(torch.cat([
linear.weight.flatten(), # 589824 元素
linear.bias.flatten(), # 768
norm.weight.flatten(), # 768
])) # 共 591360 元素
_flat_param.py:202 的 FlatParameter 是 nn.Parameter 子类,元数据里记着每个原始 param 的 offset / shape / dtype。访问 linear.weight 时通过 view 从 FlatParameter 取出对应区段。
为什么要 flatten?
- AllGather 能一次拉所有 param(vs 分别拉每个 param 的 N 次 collective)
- shard 时只切一次(按 1/N 切 FlatParameter,不是切几十个 param)
- memory 连续让 GPU memory bandwidth 利用率高
代价是 view 重建复杂、与 torch.compile 兼容性差。FSDP-2 抛弃 FlatParameter,每个 param 用 DTensor 直接 shard —— 与 compile 兼容性大幅提升,但 collective 数量变多(靠 group AllGather 等优化补偿)。
理解 FlatParameter 让你看到 FSDP-1 的 ckpt 文件里那些”奇怪 key”(如 _fsdp_wrapped_module._flat_param_0)时不困惑 —— 那是 flat 后的合并张量。
18.6.11 auto_wrap_policy:决定 sharding 粒度
fully_shard(model, auto_wrap_policy=...) 让 FSDP 自动决定哪些子 module 各自成为一个 wrap unit。粒度选择:
- 太粗(整个 model 一个 unit):unshard 一次拉全部参数,显存峰值与不分片相同 —— FSDP 退化成 DDP
- 太细(每个 Linear 一个 unit):每层 unshard 一次 collective,通信开销爆炸
- 合适粒度:每个 transformer block 一个 unit,平衡 collective 数量与显存峰值
wrap.py:178 的 ModuleWrapPolicy 是常用 policy:按 module 类型选择 wrap unit。
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
policy = ModuleWrapPolicy({TransformerBlock}) # 每个 TransformerBlock 一个 unit
fully_shard(model, auto_wrap_policy=policy)
其他 policy:
size_based_auto_wrap_policy:按参数量阈值(如 >100M 一个 unit)transformer_auto_wrap_policy:transformer 友好的版本- 自定义 callable policy
LLM 训练几乎都是 ModuleWrapPolicy({TransformerBlock}) —— 与 activation_checkpoint 的 wrap 粒度对齐,让两者协作最优。
18.6.12 lazy init 与 meta device:训练前不分配显存
70B 模型直接 model = LlamaModel(config) 会立即在 GPU 上分配 280GB params —— 单卡装不下、初始化崩。
torch.device("meta") 是个”假 device”,张量只有 shape / dtype 元信息、没有实际数据。FSDP-2 配 meta device 流程:
with torch.device("meta"):
model = LlamaModel(config) # 不分配显存
fully_shard(model, mesh=mesh, ...) # FSDP 用 meta module 构造 sharded param
model.to_empty(device='cuda') # 把 meta param 替换成真实显存 (1/N 大小)
init_model_weights(model) # 用户实现的 init 函数 (按 rank 自己 init 那 1/N)
整个流程全程没有”完整 280GB”在任何 rank 显存里出现。FSDP-2 的 lazy init 让 70B+ 模型能在 80GB 单卡上启动训练 —— 之前是不可能的。
to_empty() 是个特殊 API:把 meta param 替换成同 shape 的真实张量但不初始化。后续用户调 init_model_weights 在每个 rank 上各自 init 自己持有的那 1/N 数据。这种”分片 init”避免了”先 init 完整 model 再切”的中间显存峰值。
18.6.13 CPU Offload:参数 / optimizer state 卸到 CPU
显存极紧张时(如单卡训练超大模型),FSDP 提供 cpu_offload:
from torch.distributed.fsdp import CPUOffload
fully_shard(model, cpu_offload=CPUOffload(offload_params=True))
机制:
- params 平时存在 CPU RAM(host memory)
- forward 时 CPU → GPU 拷贝 + AllGather + GPU 跑 forward + reshard 回 GPU + GPU → CPU 拷贝释放 GPU 副本
- backward 同理
代价是每 forward 多两次 H2D / D2H 拷贝(PCIe 受限,几十毫秒)。整个训练吞吐通常降 50-70%,但能让”放不下的模型放下”。
更激进的 optim_state_offload(FSDP-2):optimizer state(exp_avg / exp_avg_sq)也卸到 CPU。step 时 CPU → GPU 取 + 计算 + 写回。再省一份显存(约参数量 × 2 字节,70B 大约 560GB)。
实战:能用多卡分摊就用多卡(HSDP),cpu_offload 是”硬件不够时的最后兜底”。生产 70B 训练几乎从不用 cpu_offload —— 性能损失太大、不如多租几张卡。
18.6.14 state_dict_type:完整 vs 分片
FSDP 提供两种 state_dict 视图:
StateDictType.FULL_STATE_DICT(api.py:293):rank 0 收集所有 rank 的分片,组装成完整 state_dict(与未 shard 的 model 视图等价)StateDictType.SHARDED_STATE_DICT(api.py:340):每 rank 输出自己持有的 1/N 切片
from torch.distributed.fsdp import StateDictType
# 完整 state_dict (rank 0 持有完整, 其他 rank 空)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
sd = model.state_dict()
if rank == 0:
torch.save(sd, "ckpt.pt")
完整 state_dict 简单但 rank 0 要装下完整模型(70B = 280GB),实战不可行。
# 分片 state_dict (每 rank 写自己那 1/N)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
sd = model.state_dict()
dcp.save(sd, checkpoint_id="ckpt_dir")
分片 state_dict 与第 19 章 §19.6.7 的 DCP 配合,每 rank 并行写自己那份 —— 这是 70B+ 训练 ckpt 的标配。SHARDED_STATE_DICT 写出的不是 ddp model 的视图,而是”DTensor 字典”,每个 entry 是 DTensor(local_shard, mesh, placement)。DCP 知道怎么序列化 DTensor + 加载时按当前 mesh 重建。
18.6.15 summon_full_params:临时凑齐完整参数
某些操作(如 model surgery、debug print 完整 weight)需要完整参数。FSDP.summon_full_params context manager 让 FSDP 临时 unshard 整个 group:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with FSDP.summon_full_params(model, recurse=True):
# 此时所有 param 都是完整的 (每 rank 都有完整副本)
print(model.linear.weight.shape) # [hidden, hidden] 完整尺寸
实现机制:进入 with 块时触发对所有 wrap unit 的 AllGather → 完整副本临时驻留显存 → 退出时立即 reshard。显存峰值临时翻 N 倍,所以 70B 模型 8 卡 unshard 后单卡瞬间需要 280GB —— 不可行。生产 summon_full_params 用于小 unit / 测试。
writeback=True 让退出时把完整 param 写回到分片(用于 model surgery 后保存改动)。offload_to_cpu=True 让完整副本驻留 CPU 而非 GPU(省 GPU 但慢)。
18.6.16 FSDP × clip_grad_norm
第 17 章 §17.8.31 提过 DDP 下 clip_grad_norm_ 直接用就行(grad 已经全局平均)。FSDP 不行:每 rank 持有 1/N 的 grad,本地 norm 不是全局 norm。
FSDP-2 提供专用接口:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
FSDP.clip_grad_norm_(model, max_norm=1.0)
内部实现:
- 每 rank 算 local grad norm²(即 sum of squares)
- AllReduce sum 得全局 norm²
- 算 sqrt 得全局 norm
- 每 rank 用全局 norm 缩放自己的 grad 切片
只多一次 AllReduce 标量(极小开销),数学上与 single-rank clip 等价。FSDP-1 / FSDP-2 都提供这个 API,名字略有差异(FSDP-2 用顶层 torch.nn.utils.clip_grad_norm_ 自动识别 DTensor 走分布式路径)。
18.6.17 FSDP × torch.compile:v2.4+ 的兼容路径
FSDP-1 与 torch.compile 兼容性差(FlatParameter view + 各种 hook 让 Dynamo 频繁 graph break)。FSDP-2 重新设计后与 compile 几乎完美兼容:
fully_shard(model, mesh=mesh)
compiled_model = torch.compile(model)
实现关键:
- DTensor 算子在 dispatcher 层有完整注册(第 5 章),Dynamo 能 trace
- AllGather / ReduceScatter 算子用 functional collectives(第 16 章 §16.7.9),与 functionalize 兼容
- FSDP-2 的 hook 用
register_post_accumulate_grad_hook,与 compiled autograd 兼容
实测 70B FSDP-2 + torch.compile 比 FSDP-1 + compile 快 1.5-2x —— FSDP-1 的兼容代价是巨大的。所以新代码强烈用 FSDP-2,老代码尽快迁移。
18.6.18 BACKWARD_PRE vs BACKWARD_POST prefetch
§18.4 提过 prefetch 有两种模式:
BackwardPrefetch.BACKWARD_PRE(推荐):layer N+1 backward 开始前 prefetch layer N。最大化 overlap、安全性高BackwardPrefetch.BACKWARD_POST:layer N+1 backward 结束后才 prefetch layer N。overlap 少但显存峰值更低
BACKWARD_PRE 是默认,几乎所有场景最优。BACKWARD_POST 仅在显存极紧张时考虑(少 overlap 但少一份 unsharded param 同时驻留)。
forward_prefetch=True 是另一档:forward 时也做 prefetch(默认关)。开启后 forward 速度提升 5-15%,代价是显存峰值轻微上升。生产 LLM 训练通常开启。
18.6.19 ZeRO-2 (SHARD_GRAD_OP) 的具体节省机制
§18.2 表里讲 ZeRO-2 是”params 完整、grads + optimizer 切片”。具体怎么实现?
- forward:因为 params 完整(每 rank 都有),forward 不需要 AllGather params —— 直接跑 → 显存与 DDP 相同
- backward:每 rank 算自己 batch 的本地 grad → ReduceScatter(不是 AllReduce!)让每 rank 拿到全局 grad 的 1/N 切片
- optimizer step:每 rank 只更新自己持有的 1/N 参数(用 1/N grad + 1/N optimizer state)
- step 后 AllGather 更新后的参数:让所有 rank 重新拿到完整 params,准备下一次 forward
这套流程让 grads 与 optimizer state 各只占 1/N(与 ZeRO-3 同),但 params 完整(vs ZeRO-3 的 1/N)。ZeRO-2 显存比 DDP 省 ~67%(grads + optimizer state 各 1/N,2 / 3 显存),但保留完整 params 的灵活性。
适用场景:模型刚好能装下 params 但 optimizer state(Adam 是 params × 2)装不下。比如 13B 模型 fp32 params 52GB(8 卡能装下)但 Adam state 104GB(装不下)—— ZeRO-2 完美。
18.6.20 NO_SHARD:FSDP 退化成 DDP
ShardingStrategy.NO_SHARD 让 FSDP 不分片任何东西,等价于 DDP:
fully_shard(model, sharding_strategy=ShardingStrategy.NO_SHARD)
# 等价于
DDP(model)
这个看似无意义的选项有真实工程价值:
- 统一接口:训练框架可以”用 FSDP 一套接口表达 DDP / ZeRO-2 / ZeRO-3 / HSDP”,配置切换简单
- fallback 路径:HSDP 的跨节点维度本质是 NO_SHARD(跨节点不切,节点内才切)
- debugging:怀疑 FSDP 引入的 bug 时切到 NO_SHARD 看是否消失
NO_SHARD 让 FSDP 的接口涵盖 DDP 全部功能,FSDP-2 在 v3.0+ 可能成为统一的多卡训练 API。
18.6.21 use_orig_params=True:兼容老 optimizer
FSDP-1 用 FlatParameter(§18.6.10)会让用户原本的 nn.Parameter 引用失效 —— optimizer 创建时拿到的是 user param,FSDP wrap 后这些 param 引用指向已经”被合并”的位置。
use_orig_params=True 让 FSDP-1 保留原始 param 引用:用户构造 optimizer 时传 model.parameters(),FSDP 内部建立 orig_param ↔ FlatParameter 切片的映射。optimizer step 时按 orig_param 更新,FSDP 自动写回 FlatParameter。
这条选项让”现有 optimizer 代码不改”就能上 FSDP。代价是内部多一层映射开销(小)。生产代码强烈建议开启。FSDP-2 默认行为就这样(每个 param 独立 DTensor,根本不需要这个选项)。
18.6.22 FSDP × 3D parallel (DP + TP + PP)
70B+ 训练经常需要 3D parallel:
- Data Parallel (DP):FSDP 在 DP 维度上分片
- Tensor Parallel (TP):把单层 weight 切到多卡(如 attention QKV 切 4 路)
- Pipeline Parallel (PP):把多层切到多卡,流水线执行
DeviceMesh(§18.6.5)让 3D 配置变简单:
mesh = init_device_mesh("cuda", (PP, DP, TP), mesh_dim_names=("pp", "dp", "tp"))
# FSDP 在 dp 维度上 shard
fully_shard(model, mesh=mesh["dp"])
# TP 在 tp 维度上 shard
parallelize_module(model, mesh["tp"], TPParallelStyle())
# PP 用 mesh["pp"] 配 PipelineStage
3D parallel 让 1024 卡训练 405B 模型(如 Llama-3 405B)成为可能。每个维度的 group 用 functional collectives(第 16 章 §16.7.9)通信,互不干扰。
实战参数选择:8 卡节点内 TP=8(NVLink 高带宽)+ 4 节点 DP(FSDP)+ 4 PP(跨更多节点)= 128 卡训练。具体取决于模型与硬件拓扑。
18.6.23 FSDP × LoRA:高效微调
LoRA(Low-Rank Adaptation)只训练注入的低秩矩阵 A、B,冻结原始 weight。FSDP 与 LoRA 的协作:
# 冻结 base model 参数
for p in base_model.parameters():
p.requires_grad = False
# 注入 LoRA adapter (有 grad)
inject_lora(base_model, rank=8)
# FSDP wrap (只 shard 有 grad 的部分?)
fully_shard(base_model, ...)
关键问题:FSDP 默认 shard 所有 params,包括 frozen 的 base weights。但 frozen weights 不需要 grad / optimizer state,shard 浪费空间吗?
实际不浪费 —— frozen weights 只是 inference 时需要完整 unshard、不进 backward / optimizer,显存账上反而占更少(没有 grad / optimizer state copy)。FSDP 在内部识别 requires_grad=False 的 param、跳过它们的 grad 处理。
LoRA + FSDP 是 fine-tuning 70B 模型的标配 —— base model 冻结后只训几十兆 LoRA 参数,FSDP 让 base model 分片到多卡能装下。
18.6.24 FSDP × gradient accumulation
显存不够调 batch size 时常用 gradient accumulation:跑 N 个 micro batch 累积梯度、再一次 step。
FSDP 下的微妙点:每 micro batch 默认会触发 ReduceScatter 同步梯度。N 次 micro batch = N 次 collective —— 浪费。
fully_shard(..., reshard_after_forward=False) + with model.no_sync(): 上下文:
with model.no_sync():
for i in range(accumulation_steps - 1):
loss = model(batch[i])
loss.backward() # 不触发 ReduceScatter, 只本地累积
# 最后一个 micro batch 正常触发 ReduceScatter
loss = model(batch[-1])
loss.backward()
optimizer.step()
no_sync 让 N-1 次 micro batch 跳过 collective,只在最后一次同步累积总梯度。通信量从 N 次降到 1 次,gradient accumulation 几乎免费。
DDP 也有同名 no_sync API,思想一致。这是 PyTorch 多卡训练的标准 gradient accumulation 模式。
18.6.25 FSDP × 量化训练 / 推理
第 20 章讲过量化。FSDP 与量化协作有几个工程点:
训练时 fp8:FSDP MixedPrecision(param_dtype=torch.float8_e4m3fn) 让 AllGather 通信用 fp8,比 bf16 再省一半带宽。但 fp8 数值范围窄,需要 per-tensor scale,FSDP-2 与 TransformerEngine(NVIDIA 库)配合才能正确处理 scale。
推理时 INT8:FSDP shard 的 model 直接量化遇到障碍 —— 量化要看完整张量做 calibration。一般做法:训练用 FSDP / 量化前把 model 收到单卡 / 量化后再加载(不分片或者用 PT2E 静态图)。
LLM 推理几乎不用 FSDP(vLLM / SGLang 用 TP + KV cache 管理替代)。FSDP 主要是训练时的工具,推理时通常切换到其他并行策略。
18.6.26 FSDP root module 的特殊性
FSDP 中的”root module”(最外层 fully_shard 的 module)有几个特殊职责:
- 管理整个 FSDP unit 树:root 持有所有子 unit 的引用,调度它们的 unshard / reshard 时机
- 管理 stream:root 创建
_unshard_stream/_pre_unshard_stream等,子 unit 共用 - 触发 root_pre_forward:第一次 forward 时 root 做整体初始化(首次 AllGather、CUDA stream sync 等)
- forward 完成后调 reshard:root 决定 root unit 自身的 reshard 时机(最后一个 reshard 在整 forward 完成后)
fully_shard(model)(不指定 mesh)默认让最外层 model 成为 root。如果用 fully_shard(layer) 给每层都 wrap,最外层依然是 root,每层是 sub-unit。root 的 lifecycle 决定整体训练流程,理解这点能解释为什么”在 root 之外的代码看到的 param 是分片的、root forward 内部看到完整的”。
18.6.27 FSDP × torch.export 与部署
FSDP 是训练工具,部署时通常不带 FSDP 直接 export。部署流程:
- 训练完成后用
FSDP.summon_full_params(model)或者state_dict_type=FULL_STATE_DICT收集完整权重 - 在单卡(或推理用的并行配置)重建 model(不带 FSDP wrap)
- 加载完整权重
torch.export(model, ...)+ AOTI(第 15 章 §15.6.7)
这条流程让训练时的 FSDP 与部署时的 AOTI 完全解耦。FSDP 不存在于部署 binary 里 —— 它是纯粹的训练时工具。
实战:HuggingFace Transformers 的 FSDP 训练流程结尾通常调 unwrap_model(model) 取出原始 nn.Module、再保存 state_dict。这是与 FSDP 解耦的标准做法。
18.6.28 FSDP-2 与 DTensor 的层次关系
FSDP-2 内部把每个 nn.Parameter 替换成 DTensor(§18.6.6)。这意味着:
- 用户视角:
model.linear.weight是 DTensor(local_shard 是 1/N) - forward 视角:DTensor 的算子自动触发必要的 placement 转换(如从 Shard(0) 转 Replicate 触发 AllGather)
- backward 视角:DTensor 的反向规则自动产生对应的反向 collective
这套机制让 FSDP-2 不需要写很多手动的 unshard / reshard 代码 —— DTensor 自己处理 placement 转换。FSDP-2 的核心代码(_fully_shard 目录)只有几千行,远比 FSDP-1 的 fully_sharded_data_parallel.py + _runtime_utils.py 共 4000+ 行少。
DTensor 思想的胜利:把分片表示为类型而非协议。Tensor 的 placement 是 Tensor 类型的一部分,编译器 / 运行时都能利用这个信息做优化。这是 PyTorch 分布式训练演进的下一代方向。
18.6.29 完整训练 step 时间分解
70B Llama 训练单 step 时间(H100,32 卡 4 节点 HSDP,bf16):
| 阶段 | 占比 | 时长 |
| forward (compute) | 30% | 1500ms |
| ├─ AllGather params | 15% | (overlap 在 forward 计算里)
| └─ forward 实际计算 | 15% |
| backward (compute) | 45% | 2250ms |
| ├─ AllGather params | 12% | (overlap)
| ├─ backward 实际计算 | 25% |
| └─ ReduceScatter grads | 8% | (overlap)
| 跨节点 AllReduce (HSDP only) | 10% | 500ms |
| optimizer.step | 5% | 250ms |
| 单 step 总时间 | 100% | ~5000ms|
对比 DDP 70B 单 step ~2000ms(§17.10.6),FSDP 慢 ~2.5x —— 代价是显存从需要 1120GB 降到每卡 35-40GB。这就是 FSDP 的工程哲学:用通信换显存。
如果显存够(中等模型),用 DDP 性能更好;如果装不下,FSDP 是唯一选择。混合 HSDP 是中间值 —— 跨节点不分(DDP 风格)、节点内分(FSDP 风格),平衡通信与显存。
18.6.29.5 内存账:FSDP vs DDP 的精确对比
7B 模型 fp32 训练详细内存账(每卡):
| 项 | DDP | FSDP-2 ZeRO-3 (8 卡) |
| params | 28 GB | 3.5 GB |
| grads | 28 GB | 3.5 GB |
| optimizer (Adam m+v) | 56 GB | 7 GB |
| activations | 8 GB | 8 GB (相同) |
| 临时 buffer | 4 GB | 28 GB (unsharded peak)|
| 单卡总 |124 GB | 50 GB |
注意 FSDP-2 的”临时 buffer”列:unsharded params 在 forward 那一瞬间需要完整 28GB(虽然只 1 ms 后就 reshard)。这就是为什么 FSDP-2 的实际显存峰值不是简单的 1/N。
70B 模型同样表(32 卡 ZeRO-3):
DDP: params 280 + grads 280 + opt 560 + activ 几十 + buffer = 1100+ GB → 装不下
FSDP-2: params 8.75 + grads 8.75 + opt 17.5 + activ 几十 + buffer 280 (unsharded) = ~330 GB
FSDP-2 在 70B 上让”单卡装下成为可能”。但 buffer 的 unsharded peak 仍是 280GB —— 这告诉我们 unit 粒度不能太大(每 unit unshard 后的临时显存是关键约束)。
18.6.29.7 实测带宽 vs 计算的临界点
FSDP 性能取决于”通信能不能 overlap 进计算”。临界点:
- 节点内 NVLink (~600 GB/s):fp32 5GB AllGather 约 8 ms,足够 overlap 进十几 layer 的 backward 计算
- 节点间 IB (~400 GB/s 双向):5GB AllGather 约 12 ms,仍能 overlap
- 节点间 100 Gbps ethernet (12 GB/s):5GB AllGather 约 400 ms,完全无法 overlap —— 这是为什么 ethernet 集群用 HSDP 而非纯 FSDP
判断你的硬件是否适合纯 FULL_SHARD:
- 算每 backward step 的 GPU 计算时间(用 profiler 看)
- 算每个 wrap unit 的 AllGather 通信时间(unit_param_size / 带宽)
- 比较两者,AllGather > 计算时间就要切 HSDP 或减小 unit 粒度
实战:H100 + NVLink,70B 训练每 unit 约 2GB params、forward 50ms / unit。AllGather 时间 ~3ms 远小于 50ms compute → 完美 overlap。
18.6.30 FSDP × torch.compile 的协作细节
FSDP-2 + compile 的具体内部协作:
fully_shard(model, mesh=mesh)
compiled = torch.compile(model)
out = compiled(x)
发生的事:
- 第一次调用:Dynamo trace 看到
model(x),遇到 DTensor 张量 - DTensor 算子 dispatch:每个 op 在 dispatcher 层有 DTensor 的特殊实现,知道如何处理 placement
- AOTAutograd functionalize:DTensor 的 mutation(unshard / reshard)被识别 + functionalize
- min-cut partition:考虑 collective 的 cost(AllGather 比普通 op 贵)做 fusion 决策
- Inductor codegen:生成的 Triton kernel 直接调用 NCCL collective 函数,与 compute fuse
最终编译产物里 collective 与 compute 在同一个 kernel 链路,CPU 几乎不参与。这与 FSDP-1 + compile(频繁 graph break)相比有本质提升。
实测 70B FSDP-2 + compile 比纯 FSDP-2 快 1.3-1.5x。第 14 章 §14.9.5 的 transformer 加速比就建立在这条路径上。
18.6.31 FSDP × Pipeline Parallel
Pipeline Parallel (PP) 把模型按层切到多卡,与 FSDP 在不同维度上分。3D parallel 中两者协同:
mesh = init_device_mesh("cuda", (PP, DP), mesh_dim_names=("pp", "dp"))
# 每 pipeline stage 内用 FSDP
for stage_id in range(PP):
if local_pp_rank == stage_id:
layers = model.layers[stage_id*L:(stage_id+1)*L]
fully_shard(layers, mesh=mesh["dp"])
# Pipeline schedule 处理跨 stage 通信
pipe = PipelineStage(layers, ...)
每个 PP stage 内部用 FSDP shard 自己那部分 layers。stage 之间用 P2P send/recv 传 activation。这种”stage 内 FSDP、stage 间 PP”是 405B 训练的标准配置。
实战调优:PP 的”bubble”(pipeline 启动 / 结束的空闲时间)与 FSDP 的 collective 时间互动复杂,需要 profiler(第 21 章)实测每个 stage 的 timeline 优化。
18.6.31.5 FSDP × evaluation:临时切到完整模型
训练循环里偶尔要做 eval(计算 val loss / metrics)。FSDP 模型默认是分片状态,eval 时需要完整 forward。两条路:
A. 保持 FSDP 状态做 eval(推荐):
model.eval()
with torch.no_grad():
for batch in val_loader:
out = model(batch) # FSDP 仍 unshard / reshard, 与训练相同路径
eval 期间 FSDP 仍触发 AllGather + reshard,但因为 no_grad、不做 reshard backward、不需要 ReduceScatter grads。开销是单纯的 forward unshard。
B. summon_full_params 一次性 unshard:
with FSDP.summon_full_params(model, recurse=True):
model.eval()
with torch.no_grad():
for batch in val_loader:
out = model(batch)
完整模型驻留显存,eval 期间不触发 collective —— 但单卡显存压力翻 N 倍。只在 eval batch 多 + 模型小时用,70B 训练绝对不能用(OOM)。
实战推荐 A 路线:FSDP eval 速度足够,省心。
18.6.32 FSDP 错误诊断速查表
| 症状 | 可能原因 | 诊断 |
|---|---|---|
| 反向卡在 AllGather | 某 rank 进度不一致 / NCCL hang | 开 TORCH_NCCL_ASYNC_ERROR_HANDLING=1 看哪个 rank 先报错 |
| OOM 在 forward 第一层 | wrap policy 太粗(整 model 一个 unit) | 改用 ModuleWrapPolicy({TransformerBlock}) |
| 训练慢比 DDP 慢一倍 | HSDP 跨节点带宽不够 | 改 HYBRID_SHARD_ZERO2 或者升级网络 |
RuntimeError: Tensors of type DTensor must have device_mesh attribute | model surgery 后 wrap 出错 | 检查是否在 fully_shard 之后改了 model |
| ckpt save 时 rank 0 OOM | 用了 FULL_STATE_DICT | 切到 SHARDED_STATE_DICT + DCP |
| 加载 HF safetensors 后训练异常 | weight 没正确切片到 DTensor | 用 distributed_state_dict_helper 帮助加载 |
| forward 与 eager 数值不一致 | precision 配置错 | 检查 MixedPrecision policy 的 reduce_dtype |
把这套表打印贴工位,FSDP 调试效率能提升 5x 以上。
18.6.33 FSDP 的 unshard 时机决策
FSDP-2 内部的 unshard 调度算法:
flowchart TB
Start[forward 开始]
Start --> Stage{当前 unit?}
Stage -->|root unit| RootUnshard[unshard root + 立即 prefetch unit 1]
Stage -->|sub unit| SubCheck{已被 prefetch?}
SubCheck -->|是| Use[直接用]
SubCheck -->|否| LazyUnshard[紧急 unshard]
Use --> Compute[执行 forward]
Compute --> Reshard[reshard 当前 unit]
Reshard --> Prefetch[启动 prefetch 下下个 unit]
Prefetch --> Stage
style RootUnshard fill:#fef3c7
style LazyUnshard fill:#fee2e2
核心规则:当前 unit 计算时,下一个 unit 的 AllGather 应该已经在飞。forward_prefetch=True 让这套机制自动工作。
如果实际跑出来 prefetch 没赶上(latency-bound 场景),会触发 lazy unshard —— 当前 unit 等 AllGather 完成,整个 timeline 出现 gap。这时 profiler 的 distributed view 能直接看到。优化:调整 wrap policy 让 unit 更小(每个 unit AllGather 时间更短、prefetch 更容易及时)。
18.6.33.5 grad accumulation 的 FSDP 显存账
§18.6.24 提了 no_sync 跳过中间 ReduceScatter。但还有显存账值得展开:
正常 FSDP grad accumulation 流程:
micro_step 1: forward (unshard / reshard) + backward (有 ReduceScatter)
→ grad 保留在 1/N 切片
micro_step 2: 同上, grad 累积到 1/N 切片
...
micro_step N: 同上 + optimizer.step
每 micro step 都触发 ReduceScatter,累积 N 次通信。no_sync 模式:
with model.no_sync(): # 内部不触发 ReduceScatter
micro_step 1: forward + backward (grad 保留为完整, N 倍显存)
micro_step 2: grad 累积, 仍是完整
...
micro_step N-1: 同上
# 退出 no_sync
micro_step N: forward + backward (这次触发 ReduceScatter, grad 切到 1/N)
optimizer.step()
关键差异:no_sync 期间 grad 是完整的(不分片),显存 N 倍上升。如果 N=8 + grad 大小 280GB → 单 rank 持有 2240GB(不可能装下)。
所以 FSDP 下用 no_sync 必须 配 cpu_offload 或者 ckpt 卸载 grad 才能避免显存爆炸。生产代码里 FSDP grad accumulation 通常不用 no_sync,接受 N 次通信开销 —— 这是 FSDP 与 DDP 工程取舍的不同。
18.6.34 FSDP 与其他并行框架对比
| 框架 | 核心思想 | 对比 |
|---|---|---|
| PyTorch FSDP-2 | DTensor + DeviceMesh | 与 PyTorch 生态深度集成、与 compile 兼容 |
| DeepSpeed ZeRO | Shard params/grads/state | 更早实现,是 FSDP 的灵感来源;与 PyTorch 集成靠 wrapper |
| Megatron-LM | TP + PP(无 DP shard) | NVIDIA 主推,专攻 TP/PP,DP 用 DDP |
| ColossalAI | 多策略统一抽象 | 国产框架,支持多种并行策略 |
| OneFlow | 全局视角自动并行 | 算子级自动决定并行策略 |
PyTorch FSDP-2 的工程优势是与 torch.compile / torch.export / DTensor 共生 —— 不是独立框架而是 PyTorch 生态的一等公民。其他框架要么是 wrapper(依赖 PyTorch 但侵入式扩展),要么是独立运行时(迁移成本高)。
实战选择:
- 用 PyTorch 训练 → 用 FSDP-2
- 已有 DeepSpeed 代码 → 继续用 DeepSpeed(迁移成本高、收益不一定值得)
- 极致 TP/PP 优化 → Megatron-LM
- 国产芯片支持 → ColossalAI / 厂商定制
18.6.35 FSDP × HF Transformers / Lightning 实战
HuggingFace Trainer 与 PyTorch Lightning 都支持 FSDP,但配置方式不同:
HuggingFace TrainingArguments:
training_args = TrainingArguments(
fsdp="full_shard auto_wrap",
fsdp_config={
"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
"min_num_params": 0,
},
)
PyTorch Lightning Strategy:
trainer = pl.Trainer(strategy="fsdp", devices=8, ...)
# 或者更细粒度
trainer = pl.Trainer(strategy=FSDPStrategy(
auto_wrap_policy={LlamaDecoderLayer},
sharding_strategy="FULL_SHARD",
))
两个框架内部都调 fully_shard —— 只是包装好让用户不用手写 wrap_policy。但理解原始 API 让你能在 Trainer / Lightning 配置出错时定位是哪一层包装的问题。
国内 Llama-Factory、Firefly 等微调框架也内置 FSDP 支持,配置类似 HF Trainer。
18.6.35.5 FSDP 与 elastic 训练的协作
第 17 章 §17.8.25 / §17.9.7 讲过 elastic(torchrun + max_restarts)。FSDP 与 elastic 协作时有几个细节:
- 重启时 mesh 配置必须一致:如果重启时用了不同 world_size,必须用 DCP 加载(自动 reshard 到新 mesh)
- ckpt 频率要够高:失败重启等于丢失从最近 ckpt 到现在的所有计算。70B 训练通常每 1000 步存一次(约 30 分钟)
- DCP 的写入要 robust:跨节点写文件可能因为某 rank 网络抖动失败,DCP 要支持”部分 rank 重写”
实战:torchrun + DCP + 每 1000 步 ckpt 是 LLM 训练的标配。失败时整 job 重启、加载最近 ckpt 继续训。一周训练 70B 模型期间故障 5-10 次是常态,没有 elastic 自动恢复就要全人工介入。
18.6.36 实战训练流程的完整推荐
70B 模型从零训练的标准 FSDP-2 流程:
# 1. 设置环境
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
# 2. 初始化 process group
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
# 3. 创建 DeviceMesh (32 卡 4 节点 HSDP)
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("inter", "intra"))
# 4. meta device 构造模型
with torch.device("meta"):
model = LlamaModel(config)
# 5. activation_checkpoint 每层
apply_activation_checkpointing(
model,
check_fn=lambda m: isinstance(m, LlamaDecoderLayer),
)
# 6. FSDP-2 wrap (HSDP)
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
)
for layer in model.layers:
fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)
# 7. lazy init weights
model.to_empty(device='cuda')
init_model_weights(model)
# 8. compile
model = torch.compile(model)
# 9. optimizer 与 scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
scheduler = ...
# 10. 训练循环
for batch in loader:
optimizer.zero_grad()
loss = model(batch).loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # FSDP 自动处理 DTensor
optimizer.step()
scheduler.step()
整套流程把全书前面章节的内容串起来:DeviceMesh(§18.6.5)+ FSDP-2(§18.5)+ MixedPrecision(§18.6 + 第 20 章)+ activation_checkpoint(§18.6.9)+ HSDP(§18.6.7)+ torch.compile(第 12-15 章)+ fused optimizer(第 10 章)+ DTensor grad clip(§18.6.16)。
这条流程是当今 70B+ LLM 训练的事实标准。理解整章后回到这段代码,每行配置都对应某个具体的工程决策。
18.6.37 FSDP 调试:常用日志与工具
调 FSDP 训练的几个常用工具:
TORCH_DISTRIBUTED_DEBUG=DETAIL:开启后所有 collective 在 NCCL kernel 前后打印 tag + tensor 信息,能精确定位”哪个 rank 卡在哪个 collective”。
TORCH_LOGS=fsdp:打印 FSDP 的 unshard / reshard / prefetch 决策,看是否符合预期。
torch.distributed.fsdp.FullyShardedDataParallel.print_runtime_summary(model):训练几个 step 后打印每个 unit 的统计(unshard 次数、平均时间、prefetch 命中率等)。是 FSDP 自带的”轻量 profiler”。
chrome trace + distributed view(第 21 章 §21.9):终极调试工具。能看到每 rank 的 collective timeline + 各 rank 等待时间。
实战调试流程:先开 TORCH_LOGS=fsdp 看决策对不对、再用 chrome trace 看实际 timeline。多数 FSDP 性能问题(如 unshard 没及时 prefetch)能在这两层定位。
18.6.38 FSDP 设计上的几个隐形约束
FSDP 不是万能的。几个不能突破的工程约束:
1. 每个 unit 的 unshard 必须能装下完整 params:FSDP 的最大单 unit param 大小不能超过单 rank GPU 显存。70B 单层 1B 参数 → 4GB(fp32)单 rank 装得下。但 1T 模型单层 100B → 400GB 装不下,FSDP 必须配合 TP 才能跑。
2. 同一 unit 的所有 param 必须同 dtype:FlatParameter(FSDP-1)合并要求 dtype 一致;FSDP-2 用 DTensor 没这个限制但一个 unit 同 dtype 仍是最优。
3. unit 边界不能跨越复杂的 module 控制流:if-branch / loop 内的 module 不适合自己成为 wrap unit,因为不是每次 forward 都被调到,prefetch 决策困难。
4. FSDP-2 的 mesh 一旦决定就不能改:训练中途换 mesh 配置(如从 8 卡变 16 卡)需要先 ckpt 落盘 → 用新 mesh 重 init → DCP load 时自动 reshard。不能”在线变”。
理解这些约束让你设计训练架构时心里有数。“一切都用 FSDP” 不是答案 —— 极限场景下还要 TP / PP / 自定义并行。
18.7 几条工程经验
1. ShardingStrategy 选择:30B 以下用 SHARD_GRAD_OP(ZeRO-2,省 1/2 显存);70B 以上用 FULL_SHARD;跨节点带宽不足用 HYBRID_SHARD
2. 用 v2.4+ 的 FSDP-2 (fully_shard):除非你已经有大量 FSDP-1 代码,否则新代码直接用 FSDP-2
3. forward_prefetch=True + backward_prefetch=BACKWARD_PRE 默认开启:是 overlap 的核心
4. wrapping policy 设到合适粒度:每个 transformer block 单独 shard 通常最优。粒度太细(每个 Linear 都 shard)通信开销过大;太粗(整个模型一个 shard)overlap 不好
5. use_orig_params=True:让 FSDP 不破坏 param 引用,原 optimizer 能直接复用。FSDP-1 的兼容选项,FSDP-2 默认就这样
6. cpu_offload=CPUOffload(offload_params=True):参数卸到 CPU,省更多显存,代价是 H2D 拷贝拖慢训练。仅在显存极紧张时用
7. checkpoint 与 FSDP:FSDP 的 ckpt 用 torch.distributed.checkpoint(DCP)存分布式格式。每 rank 只存自己那 1/N,避免 rank 0 写 1TB 文件的瓶颈
8. activation_checkpoint 与 FSDP 叠加用:FSDP 省 params/grads/optimizer 显存,checkpoint 省 activation 显存。两者正交、可叠加
18.8 实战决策路径
flowchart TD
Start[要训练大模型]
Start --> Size{模型 / GPU?}
Size -->|装得下| DDP[用 DDP]
Size -->|装不下| Bandwidth{跨节点带宽?}
Bandwidth -->|InfiniBand 高带宽| Full[FULL_SHARD]
Bandwidth -->|普通 ethernet| Hybrid[HYBRID_SHARD<br/>节点内 shard 跨节点 DDP]
Full --> Mp[配 MixedPrecision]
Hybrid --> Mp
Mp --> Ckpt[启用 activation_checkpoint]
Ckpt --> Compile[用 torch.compile 提速]
style Full fill:#dcfce7,stroke:#22c55e
style Hybrid fill:#fef3c7,stroke:#f59e0b
70B Llama 训练的典型配置:32 张 H100、HYBRID_SHARD(同 8 卡节点 ZeRO-3、跨 4 节点 DDP)、bf16 通信 + fp32 reduce、每个 transformer block 单独 fully_shard、activation_checkpoint 每 4 层一次。
18.8.5 FSDP 与 PyTorch 演进路线
FSDP 在 PyTorch 演进路线上的位置:
- v1.10 (2021):FSDP-1 实验性引入,灵感来自 DeepSpeed ZeRO
- v1.12 (2022):FSDP-1 成为 stable,开始被生产使用
- v2.0 (2023):FSDP-1 与 torch.compile 集成尝试,但兼容性差
- v2.2 (2024 初):FSDP-2 (
fully_shard) prototype 推出 - v2.4 (2024 中):FSDP-2 成为推荐 API,FSDP-1 进入维护模式
- v2.11 (2026):FSDP-2 与 DTensor / DeviceMesh 深度整合,成为新一代分布式训练标准
FSDP-1 大概率会在 v3.x 时代被完全弃用 —— FSDP-2 在性能、灵活性、工具兼容性上全面胜出。但因为 FSDP-1 在生产训练里使用广泛,PyTorch 团队承诺至少维护到 v3.0+。
国内训练框架(如华为 MindSpore 的并行模式、字节 ByteCheckpoint 等)也大量借鉴 FSDP / ZeRO 思想。理解 FSDP 不只是理解 PyTorch 一个工具,是理解整个大模型训练时代的工程基础。
18.8.6 完整决策树:用 DDP 还是 FSDP-2?
flowchart TB
Start[要训练大模型]
Start --> SizeCheck{单 rank 装得下完整 params + grads + optimizer state?}
SizeCheck -->|是| DDP[DDP - 性能最优]
SizeCheck -->|否, 装不下 optimizer state| ZERO2[FSDP SHARD_GRAD_OP - ZeRO-2]
SizeCheck -->|否, 连 params 也装不下| ZERO3[FSDP FULL_SHARD - ZeRO-3]
ZERO2 --> Bandwidth{跨节点带宽足够?}
ZERO3 --> Bandwidth
Bandwidth -->|是, NVLink/IB| Flat[平面 mesh: 单维度 shard]
Bandwidth -->|否, ethernet| Hybrid[HSDP 2D mesh: 节点内 shard 跨节点 replicate]
Flat --> Compile{要 torch.compile?}
Hybrid --> Compile
Compile -->|是| FSDP2[FSDP-2 fully_shard API]
Compile -->|否, 用老代码| FSDP1["FSDP-1 (维护模式)"]
style DDP fill:#dcfce7
style FSDP2 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
style FSDP1 fill:#fee2e2
这条决策树覆盖 95% 的大模型训练场景。剩下 5% 是极端场景(如 1T+ 模型必须配 TP/PP)需要专门设计。
18.9 跨书关联
- 第 16 章 ProcessGroup:FSDP 的 AllGather / ReduceScatter 都通过 ProcessGroupNCCL
- 第 4 章 Caching Allocator:FSDP 的 unshard 触发临时分配大量显存,对 allocator 压力大,常需
expandable_segments=True - 第 7 章 §7.5.3 saved_tensors_hooks:FSDP 的 activation checkpoint 与这套 hook 配合
- 第 13 章 AOTAutograd:FSDP-2 的 collective 也走 AOTAutograd,让通信 op 被 Inductor 编译
18.9.5 FSDP × Engine:与第 8 章的协作
第 17 章 §17.10.7 讲过 DDP × Engine。FSDP 的协作更复杂:
- forward 时插入 unshard hook:FSDP 通过
register_forward_pre_hook(第 9 章 §9.8)在 forward 前触发 unshard - forward 后插入 reshard hook:
register_forward_hook触发 reshard - backward 时通过 autograd.Function 重新触发 unshard:与 §13 AOTAutograd 类似的机制
- backward 完成的 ReduceScatter 通过 grad accumulator post hook 触发(与 DDP §17.8.10 类似但用 ReduceScatter 而非 AllReduce)
整套机制让 FSDP 把分片调度寄生在 PyTorch 的 hook 体系上。Engine 完全不知道 FSDP 存在,只是按 DAG 调度反向 —— hook 触发的 collective 是 Engine 视角的”普通副作用”。
这种”分布式策略寄生在框架 hook 之上”的设计是 PyTorch 分布式训练能持续演进的根本(DDP / FSDP / 用户自定义并行都共享 hook 接口)。第 23 章设计哲学会再回到这条线索。
18.9.6 整章信息密度的小结
读完 Ch 18 你应该能:
- 决策:DDP / ZeRO-2 / ZeRO-3 / HSDP 各自适用场景(§18.8.6 决策树)
- 配置:选 wrap_policy 粒度、配 MixedPrecision、决定是否开 cpu_offload
- 理解:unshard / reshard 时机、prefetch 怎么 overlap、min-cut activation_checkpoint 怎么决定
- 调试:错误诊断速查表 + chrome trace 看 timeline
- 接生态:HF Trainer / Lightning / DCP / torch.compile 怎么协作
- 预判演进:FSDP-1 → FSDP-2 → 未来与 DTensor 进一步融合
70B Llama 训练的实战配置(§18.6.36)把整章串起来 —— 每行配置都对应某节的工程决策。这是当今大模型训练的事实标准,理解它就理解了”为什么 LLM 训练长这样、不能简化成更朴素的形式”。
18.10 设计启示
FSDP 的核心思想:
第一:显存与计算 / 通信可以互换:FSDP 把显存压力转成通信压力,让”训练大模型”从”硬件极限”变成”调优问题”
第二:分片 + 临时凑齐是分布式数据结构的通用模式:DTensor / 分布式哈希表 / 分片数据库都用这套思路。每节点常态只持有部分数据,需要时凑齐
第三:prefetch 是 overlap 通信的标配:任何”计算 - 通信 - 计算”链路都要考虑 prefetch,让通信延迟隐藏到计算时间里
第四:模块级 wrap 比模型级 wrap 灵活:FSDP-2 的 fully_shard(submodule) 思想可以借鉴到所有”框架增强 module”的场景
下一章拆序列化:torch.save / torch.load + safetensors + Distributed Checkpoint,看大模型训练的 ckpt 怎么管理。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。