第10章 优化器与梯度更新
“An optimizer is just a function: state, params, grads → new params, new state. Everything else is plumbing.”
—— PyTorch optim docstring
本章要点
Optimizer基类是个简单的”参数容器 + step 方法”:核心数据结构是param_groups(list of dict)和state(per-param 状态)。基类不算梯度更新,子类(SGD / Adam / AdamW)实现step()- 三档性能模式:
single_tensor(朴素 for 循环)、multi_tensor / foreach(一次操作所有参数,用torch._foreach_*批量算子)、fused(一个 CUDA kernel 算所有参数 + 状态更新) capturable=True让 optimizer 兼容 CUDA Graph:不依赖 Python 端 step counter,所有状态保留在 GPU 张量里zero_grad(set_to_none=True)是默认行为:把param.grad设为 None 而非 inplace 写零,省一次 zero kernel 调用- param_groups 是分组学习率的核心:每个 group 可以有自己的 lr / weight_decay 等。learning rate scheduler 只是在 group 上调 lr 字段
- state_dict 与 ckpt:optimizer 的状态(如 Adam 的 momentum / variance)也能保存加载,但要留意 param_groups 的 ID 与新模型的 ID 对齐
10.1 一行代码引发的疑问
每次训练循环都长这样:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch).sum()
loss.backward()
optimizer.step()
optimizer.step() 这一行做的事说起来简单 —— “用梯度更新参数”。但仔细想要回答:
- optimizer 怎么知道要更新哪些参数?
- Adam 的 momentum / variance 状态存在哪?
- step 内部怎么并行处理几百万个参数?
- 为什么 70B 模型训练时 step 这一步能在毫秒级完成?
答案在 torch/optim/optimizer.py(1193 行)和 torch/optim/adam.py(991 行)等具体优化器实现里。本章拆它的设计。
10.1.1 优化器在训练中的位置
把第 7-9 章的内容串起来,optimizer 是训练循环的最后一环:
graph LR
F["forward<br/>model(batch) → loss"] --> B["loss.backward<br/>反向图执行<br/>填 param.grad"]
B --> O["optimizer.step<br/>用 grad 更新 param"]
O --> Z["optimizer.zero_grad<br/>清梯度"]
Z --> F
style O fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
optimizer 拿到的输入是已经填好的 param.grad(第 7 章 §7.6 AccumulateGrad 的产物),它把 grad 转成 param 的更新。optimizer 完全不参与 forward 与 backward —— 它就是个 “param 与 grad → param” 的纯函数(加上 momentum / variance 等状态)。这种”职责清晰”的拆分让 optimizer 可以独立演进、独立测试,与模型无关。
这种”无状态契约”让 PyTorch 内置的几十种 optimizer(SGD/Adam/AdamW/Adagrad/Adadelta/RMSprop/LBFGS/RAdam/NAdam/…)都能用同一个 optim.step() 接口工作。第三方扩展(如 transformers 的 AdamWFp8、Apex 的 FusedLAMB)也都遵循这个契约 —— 不需要改 PyTorch 主仓就能集成。
10.2 Optimizer 基类:参数与状态的容器
torch/optim/optimizer.py:342 是基类:
class Optimizer:
def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:
self.defaults = defaults # 全局默认参数 (lr, eps...)
self.state: defaultdict[Tensor, Any] = defaultdict(dict)
self.param_groups: list[dict[str, Any]] = []
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}] # 用户传 list 时包装成 group
for param_group in param_groups:
self.add_param_group(cast(dict, param_group))
# 6 种 hooks
self._optimizer_step_pre_hooks = OrderedDict()
self._optimizer_step_post_hooks = OrderedDict()
...
核心数据结构两个:
param_groups是list[dict]。每个 dict 至少有'params'key(参数列表)+ defaults 里的所有字段(lr / weight_decay 等)。支持每个 group 独立配置(如 backbone 用低 lr,head 用高 lr)state是defaultdict[Tensor, dict]。每个参数张量映射到一个状态 dict(Adam 里包含step/exp_avg/exp_avg_sq)。lazy 初始化:第一次 step 时才创建状态
graph TB
Opt[Optimizer 实例]
Opt --> PG["param_groups: list[dict]<br/>──────────<br/>group 0: {'params': [...], 'lr': 0.001, 'weight_decay': 0.01}<br/>group 1: {'params': [...], 'lr': 0.0001}"]
Opt --> S["state: defaultdict[Tensor, dict]<br/>──────────<br/>p1 → {'step': 100, 'exp_avg': tensor, 'exp_avg_sq': tensor}<br/>p2 → {'step': 100, 'exp_avg': tensor, 'exp_avg_sq': tensor}"]
Opt --> D["defaults: dict<br/>{'lr': 0.001, 'beta1': 0.9, ...}"]
Opt --> H["6 种 hooks<br/>(step pre/post, sd pre/post, load_sd pre/post)"]
style PG fill:#dbeafe,stroke:#3b82f6
style S fill:#fef3c7,stroke:#f59e0b
10.2.1 为什么 param_groups 是 list[dict] 而不是简单 list
新手以为 optimizer 就是”一组参数 + 一个 lr”。实际生产代码常见写法是分组:
backbone_params = list(model.backbone.parameters())
head_params = list(model.head.parameters())
optimizer = torch.optim.AdamW([
{'params': backbone_params, 'lr': 1e-5, 'weight_decay': 0.01},
{'params': head_params, 'lr': 1e-3, 'weight_decay': 0.0},
])
backbone 通常已经预训练,用低 lr 微调;head 是新加的,用高 lr 训练。这种”按 group 分配 hyperparameter”在迁移学习几乎是标配。param_groups 的设计就是为这个场景。
每个 dict 里没指定的字段从 defaults 继承。所以你可以只写差异 —— PyTorch 自动补全。
10.2.2 为什么 state 用 defaultdict[Tensor, dict]
state 是 per-param 状态:每个参数 p 对应一个 state[p] dict。Adam 的 state[p] 包含:
{
'step': torch.tensor(100), # 步数 (Tensor 形式以支持 capturable)
'exp_avg': tensor(shape=p.shape), # 一阶动量 m
'exp_avg_sq': tensor(shape=p.shape), # 二阶动量 v
}
第一次 step 时 PyTorch 给每个参数初始化这些状态(exp_avg / exp_avg_sq 全零)。后续每步更新它们。
为什么用 defaultdict 而非普通 dict?因为 lazy 初始化:访问 state[p] 时如果没有,自动返回空 dict。这避免了在 __init__ 里就为每个参数预分配状态(那时还不知道用户会不会真的训练)。
10.2.3 state 的内存账
state 占的内存惊人。Adam 每个 param 要 exp_avg + exp_avg_sq 两份与 param 同 shape 同 dtype 的张量。对 70B 模型:
- params: 70B × fp32 = 280 GB(或 fp16 = 140 GB)
- gradients: 70B × fp32 = 280 GB
- exp_avg: 70B × fp32 = 280 GB
- exp_avg_sq: 70B × fp32 = 280 GB
合计 1120 GB!没有任何单卡能装下。这是为什么 70B 训练必须用 FSDP/ZeRO 把这套 state 切到多卡。第 18 章 FSDP 章会展开 ZeRO-1/2/3 三档怎么分别分片”params / grads / optimizer state”。
理解 state 的内存占用,你就理解了为什么 SGD 在大模型时代有时被重新选择 —— SGD 的 momentum_buffer 只是参数量的 1 倍,不是 Adam 的 2 倍。在显存极致紧张的场景,“用 SGD 换显存”是一个实际选项。
10.2.3.5 add_param_group 的动态扩展
add_param_group(param_group) 让你在 optimizer 创建后动态加新参数组:
optimizer = torch.optim.AdamW(model.backbone.parameters(), lr=1e-5)
# 后来加了新 head, 想给它高 lr
optimizer.add_param_group({'params': model.head.parameters(), 'lr': 1e-3})
这种”事后追加”在 fine-tuning 场景常见 —— 加载预训练 backbone 后再加自己的 head。Optimizer 内部维护 param_groups 的整数索引(state 用 ID 而非对象),所以新加的 group 不会破坏已有 state。
新加 group 时所有字段从 defaults 继承(除非用户在 dict 里显式指定)。这种”局部覆盖 + 全局兜底”的 dict 合并模式让用户可以只写差异、不重复指定相同字段,与 React Props 的 spread 操作思想一致。
但要注意:add_param_group 的参数必须不在已有任何 group 里,否则 PyTorch 会报错(避免参数被重复更新)。这种”参数唯一性”检查让 optimizer 不会出现”双更”灾难。
10.2.4 differentiable optimizer
Optimizer.__init__ 里 defaults 经常出现 differentiable: False。这个标志让 optimizer 自身的 step 操作可以被 autograd 跟踪 —— 用于 meta-learning 这种”对 optimizer 求二阶导”的场景。
正常训练 differentiable=False,所有 inplace 更新都在 with torch.no_grad(): 里跑(避免反向图爆炸)。Meta-learning 时 differentiable=True,step 内部允许构建反向图,让 outer loop 能对 inner step 求导。这种”为研究场景留扩展点”的设计让 PyTorch 在学术界比 TF 更受欢迎。
10.3 step() 的全流程
step() 是子类实现的,但基类用 _patch_step_function 包装了它,加入 hook 调用:
def _patch_step_function(self):
cls = self.__class__
self._zero_grad_profile_name = "Optimizer.zero_grad#" + cls.__name__ + ".zero_grad"
if cls.step is not Optimizer.step:
cls.step = profile_hook_step(cls.step) # 包装 step 加入 hooks + profile
调 optimizer.step() 时实际跑的是包装版:
sequenceDiagram
autonumber
participant U as 用户: optimizer.step()
participant W as profile_hook_step (包装)
participant Pre as step_pre_hooks
participant S as 真实 step (子类实现)
participant Post as step_post_hooks
U->>W: step()
W->>Pre: 跑所有 pre-hook
Pre-->>W: 可能修改 args
W->>S: 进入子类的 step (Adam.step / SGD.step)
Note over S: 1. 遍历 param_groups<br/>2. 对每个 param 取出 grad<br/>3. 取出 state, lazy init<br/>4. 调用 _single/_multi/_fused 函数<br/>5. 写入 param 与 state
S-->>W: 完成
W->>Post: 跑所有 post-hook
W-->>U: 返回
10.3.1 step 接受 closure 参数
Optimizer.step(closure=None) 的 closure 参数让某些 optimizer(如 LBFGS)可以多次重新算 loss:
def closure():
optimizer.zero_grad()
loss = compute_loss(...)
loss.backward()
return loss
optimizer.step(closure)
LBFGS 等”二阶”优化器需要在 line search 中多次评估同一点的 loss / grad,closure 提供了”如何重新算 loss”的回调。普通 SGD/Adam 不需要这个,把 closure 设为 None 即可。
理解 closure 的存在能让你看 PyTorch 源码时不困惑:每个 optimizer 的 step 签名都接受 closure,但 99% 的优化器代码里都是 if closure is not None: ... else: ... 简单分支。这是 PyTorch 给 LBFGS 一类特殊 optimizer 留的扩展点。
10.3.2 6 种 hooks 的工程价值
回顾 §10.2 的 hook 列表(step pre/post、state_dict pre/post、load_state_dict pre/post),这 6 种 hook 每种都对应特定生态使用场景:
- step_pre_hook:DDP 用它在 step 前确认所有梯度已经 AllReduce 完成
- step_post_hook:参数更新后做 EMA(exponential moving average)、log 参数 norm 等
- state_dict_pre_hook:保存前做转换(如把 fp32 转 fp16 缩小 ckpt)
- state_dict_post_hook:保存后注入额外 metadata
- load_state_dict_pre_hook:加载前做版本迁移、key rename
- load_state_dict_post_hook:加载后做后处理(如重置 step counter)
理解 hook 的设计意图,你写训练框架时能用 hook 优雅扩展,而不是 monkey-patch optimizer 类。这种”打开扩展点”的工程文化让 PyTorch 生态既稳定又能持续演进。
10.4 SGD:最简优化器
torch/optim/sgd.py 的核心 _single_tensor_sgd:
def _single_tensor_sgd(params, grads, momentum_buffer_list, *,
weight_decay, momentum, lr, dampening, nesterov, ...):
for i, param in enumerate(params):
grad = grads[i]
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if momentum != 0:
buf = momentum_buffer_list[i]
if buf is None:
buf = torch.clone(grad).detach()
momentum_buffer_list[i] = buf
else:
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
param.add_(grad, alpha=-lr) # 更新: param = param - lr * grad
——SGD 的全部数学就这一段:
- 应用 weight_decay:
grad += weight_decay * param - 应用 momentum:
buf = momentum * buf + (1-dampening) * grad - Nesterov(可选):再用 momentum buf 修正
- inplace 更新:
param -= lr * grad
add_ 等 inplace 操作在 with torch.no_grad(): 上下文里跑,避免参数更新被 autograd 记录。
注意 SGD 的状态只有 momentum_buffer(如果 momentum != 0),所以 SGD 的 state 比 Adam 简单一倍。这是为什么大模型训练有时切回 SGD —— 省一半显存(Adam 的 m + v 是参数量的 2 倍,SGD 的 buf 只是 1 倍)。
更深一层:SGD with momentum 的数学行为与 Adam 在某些条件下接近 —— Adam 早期收敛快但后期容易陷入”过拟合 + 噪声大”,SGD 收敛慢但泛化好。一些 SOTA 训练(如 ResNet 经典训练 + LAMB 大模型训练)混合用:先 Adam warm up 再切 SGD。PyTorch 让这种”切 optimizer”非常容易:
opt_adam = torch.optim.AdamW(params, ...)
# train phase 1 with adam
opt_sgd = torch.optim.SGD(params, ...)
# train phase 2 with sgd
两套 optimizer state 互不影响,因为各自维护自己的 state dict。理解这种”optimizer 独立 state”是深度学习训练流水线设计的基础。
10.5 Adam:状态优化器的代表
torch/optim/adam.py:347 的 _single_tensor_adam(精简版):
def _single_tensor_adam(params, grads, exp_avgs, exp_avg_sqs, ..., state_steps, *,
beta1, beta2, lr, weight_decay, eps, ...):
for i, param in enumerate(params):
grad = grads[i]
exp_avg = exp_avgs[i] # 一阶动量
exp_avg_sq = exp_avg_sqs[i] # 二阶动量
step_t = state_steps[i]
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
step_t += 1 # 步数 +1
exp_avg.lerp_(grad, 1 - beta1) # m = beta1*m + (1-beta1)*g
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) # v = beta2*v + (1-beta2)*g^2
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
denom = (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(eps)
param.addcdiv_(exp_avg, denom, value=-lr/bias_correction1)
# 等价 param -= lr * (m_hat / (sqrt(v_hat) + eps))
——Adam 的核心更新公式实现。注意:
- 大量 inplace 操作(
exp_avg.lerp_、exp_avg_sq.mul_):避免分配新张量 - bias correction 修正初期偏估计:开始几步 m / v 接近 0,要除以
1 - beta^step修正 addcdiv_是融合算子(a += value * b / c),一次做加 + 除 + scale,比拆开三步快
每步对每个参数都跑一次这套逻辑。对 70B 模型,参数量 7e10,单步 Adam 要在 GPU 上做几十亿次乘加。这就是为什么需要后面的优化模式。
10.5.0.3 与 SGD 的 Adam 字段对比
直观对比:
| 字段 | SGD | Adam |
|---|---|---|
| step counter | 不需要 | 需要 (用于 bias correction) |
| momentum buffer | momentum_buffer (1x param) | exp_avg (1x param) |
| 二阶动量 | 无 | exp_avg_sq (1x param) |
| 显存倍率 | 1x | 2x |
| 超参数 | lr, momentum (1-2 个) | lr, beta1, beta2, eps, weight_decay (5 个) |
SGD 的简洁让它适合”显存极致紧张”或”超参数极致敏感”的场景。Adam 的丰富 state 让它在初始化早期收敛更快、对 lr 选择更宽容。两者各有优劣,今天大模型训练用 AdamW 居多但 SGD 在某些 vision 任务(如 ImageNet 经典训练)仍是 SOTA。
10.5.0.4 fused 算子的 Apex / Apollo 渊源
NVIDIA 早期在 Apex 库里提供过 FusedAdam、FusedLAMB 等 CUDA 优化算子。这些算子在 v1.x 时代是 Apex 独家,PyTorch 主仓没有 fused 实现。后来 PyTorch v1.13+ 把 Apex 的 fused Adam 思想吸收进主仓,作为 torch.optim.Adam(fused=True) 的内置选项。
这条历史让你理解:今天的 fused optimizer 不是 PyTorch 团队从零写的,是站在 Apex 几年优化经验的肩膀上。Apex 仍然存在并提供更激进的 fused 算子(如 LAMB、FusedSGD with Nesterov),但生产代码越来越多直接用 PyTorch 主仓的内置 fused —— 因为不依赖外部库更省心。
10.5.0.5 bias_correction 的数学解释
为什么 Adam 需要 bias correction?因为 exp_avg 和 exp_avg_sq 初始为零,前几步的指数滑动平均严重偏低估真实统计量。考虑只跑一步:
m_1 = beta1 * 0 + (1 - beta1) * g_1 = (1 - beta1) * g_1 = 0.1 * g_1 # 假设 beta1 = 0.9
m_1 只有 g_1 的 10%!如果直接用 m_1 做更新,前几步会非常迟缓。bias_correction1 = 1 - beta1^t 在 t=1 时是 0.1,把 m_1 除以这个值就还原成 g_1,消除偏置。
经过几十步后 beta1^t 趋近 0,bias_correction 趋近 1,修正效果消失。所以 bias correction 主要影响训练初期的 ~50 步。这也是为什么有些论文报告”训练初期不稳定” —— bias correction 让 Adam 早期 lr 实际上不稳定。
理解这条数学原理,你能合理解释一些训练曲线现象(如训练 100 步后 loss 突然下降是 bias correction 完成的瞬间)。
10.5.0.6 关于 inplace 操作与 autograd 的协调
注意 Adam 内部的 exp_avg.lerp_(...)、param.addcdiv_(...) 都是 inplace。如果不在 with torch.no_grad(): 上下文里跑,这些 inplace 操作会触发 autograd 的 inplace 检查(第 7 章 §7.5.1),抛”one of the variables … has been modified”。
profile_hook_step 的包装内部就是用 @torch.no_grad() 装饰,让 step 的所有内部更新都不进反向图。这是为什么 step 函数体里可以放心地用 inplace 操作 —— 调用环境已经给关掉 autograd 了。
如果你写自定义 optimizer 忘了 @torch.no_grad() 装饰 step,跑一两步就会报错。这个错误在 PyTorch 错误信息里足够明显,但新手不熟悉时会困惑半天。
10.5.1 AdamW vs Adam:weight decay 的微小差别
AdamW 与 Adam 99% 代码相同,唯一区别是 weight_decay 应用时机:
# Adam: weight_decay 加到 grad 里 (污染 momentum)
grad = grad + weight_decay * param
# AdamW: weight_decay 直接从 param 里减 (不污染 momentum)
param = param - lr * weight_decay * param
数学上不等价 —— Adam 的 weight_decay 在指数滑动平均里被”遗忘”,效果不如 AdamW 的”硬性衰减”。论文 Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2017) 证明了 AdamW 在大量任务上比 Adam 收敛更好。
今天大模型训练几乎全部用 AdamW(Llama / GPT / Mixtral 等都是)。Adam 在某些 fine-tuning 场景仍有用,但作为默认选择已经被淘汰。
10.6 三档性能模式:foreach / fused / capturable
考虑大模型场景:70B 参数 = ~7000 个 Linear 层 × 各自 weight + bias 张量 ≈ 几万个独立张量。如果每个张量都跑一次 single_tensor 循环里的 5-6 个算子,dispatcher 调用就有几十万次 —— 第 5 章 §5.8.1 我们算过单次 ~120ns,几十万次累计 30+ ms。
PyTorch 提供三档优化路径:
graph LR
Single["single_tensor (朴素)<br/>for p in params: update(p)<br/>~30ms / step (70B)"]
Foreach["multi_tensor / foreach<br/>torch._foreach_add_(params, grads, ...)<br/>~10ms / step"]
Fused["fused (CUDA kernel 融合)<br/>一个 kernel 处理所有 param + state<br/>~3ms / step"]
Single -->|每 op 调一次 dispatcher| ForeachSlow[N 次 dispatcher]
Foreach -->|一次 dispatcher 跑所有张量| ForeachFast[1 次 dispatcher + 内部 loop]
Fused -->|一次 kernel launch 完成所有数学| FusedFast[1 次 kernel launch]
style Single fill:#fee2e2,stroke:#ef4444
style Foreach fill:#fef3c7,stroke:#f59e0b
style Fused fill:#dcfce7,stroke:#22c55e
10.6.1 foreach (multi_tensor) 模式
torch._foreach_* 是一组特殊算子,一次操作多个张量:
# 朴素: N 次 dispatcher
for p, g in zip(params, grads):
p.add_(g, alpha=-lr)
# foreach: 一次 dispatcher
torch._foreach_add_(params, grads, alpha=-lr)
_foreach_add_ 在内部跑一个 fused CUDA kernel(或 SIMD CPU kernel),对张量列表做批量操作。dispatcher 开销除以 N。
_multi_tensor_adam(torch/optim/adam.py:554)就是把 _single_tensor_adam 里所有 inplace 操作换成 foreach:
# 简化版
torch._foreach_add_(device_state_steps, 1) # 所有 step += 1
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) # 所有 m = lerp(m, g)
torch._foreach_mul_(device_exp_avg_sqs, beta2) # 所有 v *= beta2
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1-beta2)
...
实测在 70B 模型训练上,foreach 比 single_tensor 快 3-5 倍。这是 PyTorch 给”参数特别多”工作负载的标准优化。从 v1.13 开始 Adam / AdamW / SGD 等默认就是 foreach 模式。
_foreach_* 算子的实现(aten/src/ATen/native/cuda/ForeachBinaryOpList.cu 等)是 PyTorch 在 CUDA 端做的另一项 codegen —— 自动为常见 binary / unary 操作生成 multi-tensor 版本。每个 _foreach 内部用一个 fused CUDA kernel 一次处理几百个张量,把 dispatcher overhead 与 launch overhead 都摊到一次 launch 上。
注意 foreach 有个隐藏约束:所有张量必须在同一 device、同一 dtype。Optimizer 的 multi_tensor 实现在跑 _foreach_* 之前会按 (device, dtype) 分组,每组单独 foreach。这种分组细节让 mixed-precision 训练(如 bfloat16 weight + fp32 master copy)能正确工作。
10.6.2 fused 模式:把整个 step 压成一个 CUDA kernel
fused 模式更激进 —— 把整个 Adam step(几十个 foreach 操作)压成 一个 CUDA kernel:
# fused 的伪代码 (实际是 C++/CUDA 实现)
torch._fused_adam_(
params, grads,
exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
state_steps,
lr=lr, beta1=beta1, beta2=beta2, weight_decay=weight_decay, eps=eps,
)
实测 fused Adam 比 foreach Adam 又快 2-3 倍,主要节省的是:
- 多次 kernel launch 的 overhead(1 次 vs 几十次)
- 中间张量的内存往返(fused 内部都在 register / shared memory 算)
启用方式:optimizer = torch.optim.Adam(params, fused=True)。但 fused 有限制:
- 只支持 CUDA / XPU 等加速器(CPU 没 fused 实现)
- 不支持某些边角情况(如 sparse grad)
- 数值上与 foreach 略有差异(因 reduction 顺序不同)
生产代码里默认开 fused 是大模型训练的常规优化。一行代码改动能省 3-5% 训练时间。
10.6.3 capturable:兼容 CUDA Graph
capturable=True 是另一档优化。它解决一个特殊问题:CUDA Graph 要求 step 函数完全在 GPU 上完成,不依赖 Python 端任何 if/else。
朴素 Adam 实现里有这样的代码:
step += 1 # Python int 加法
bias_correction1 = 1 - beta1 ** step # Python 计算 bias correction
这种 “Python 端控制流” 让 step 函数无法被 CUDA Graph 捕获。capturable=True 把 step 改成 GPU 张量 + 用 GPU 算子算 bias correction:
state_steps = torch.tensor(0, device=p.device) # GPU 张量
# step += 1 改成 _foreach_add_(state_steps, 1)
# bias_correction = torch.pow(beta1, step) 在 GPU 上算
代价是少量 GPU 算子开销。收益是 整个 step 能进 CUDA Graph,配合 torch.compile 可以让训练 hot loop 几乎零 CPU 开销。
第 15 章 torch.compile 端到端章会展开 capturable 在编译路径里的作用。
10.6.4.54 一个看起来微小但重要的细节:_default_to_fused_or_foreach
torch/optim/_optim_utils.py 里有个 _default_to_fused_or_foreach 函数,每个 optimizer 在 __init__ 里调它决定默认模式:
fused, foreach = _default_to_fused_or_foreach(params, differentiable, ...)
逻辑大致:
- 如果所有 param 都在 CUDA / XPU 上 + 有 fused 实现 → 默认 fused=True
- 否则有 foreach 实现 → 默认 foreach=True
- 否则 fall back single_tensor
这种”自动选最快模式”让用户不用关心三档模式 —— Adam(params, lr=1e-3) 默认就是最快模式。如果你想强制走 single_tensor(如调试时),传 foreach=False, fused=False 即可。
10.6.4.55 _foreach 算子的内部实现
打开 aten/src/ATen/native/cuda/ForeachBinaryOpList.cu 这类文件,会看到 _foreach_add_ 等算子的 CUDA 实现:
// 极简版 _foreach_add_ 内部 (简化)
template <typename scalar_t>
__global__ void multi_tensor_apply_kernel(
void** params, void** grads, int* sizes, scalar_t alpha) {
int t = blockIdx.x; // 第几个张量
scalar_t* p = (scalar_t*) params[t];
scalar_t* g = (scalar_t*) grads[t];
int n = sizes[t];
int i = threadIdx.x;
while (i < n) {
p[i] += alpha * g[i];
i += blockDim.x;
}
}
每个 CUDA block 处理一个张量,多个 block 并发处理多个张量。这种”一次 launch、多张量并发”的 kernel 设计让 foreach 在 GPU 上几乎零开销 —— 只要张量数量不超过 gridDim.x 的限制(CUDA 是 2^31 -1,远超实际需要),整套操作就是一次 launch。
理解 _foreach 算子的实现,你能预判其上限:当张量极小(每个张量元素数 < 32),launch 开销开始占主导,这种 micro-batching 场景反而比 single tensor 慢。但生产训练里参数张量都至少几千元素以上,foreach 永远赢。
10.6.4.6 多 device 并发的微妙性
multi_tensor 模式有个隐藏前提:所有张量都在同一 device 才能 fused。如果你的模型一部分在 cuda:0、一部分在 cuda:1(如 model parallel 场景),foreach 内部要按 device 分组分别跑:
每个 device 跑一次 foreach。这种”按 device 分组 + 各自 foreach”让 multi-device 模型也能享受 fused 加速。但每多一个 device 就多一次 foreach launch overhead,所以model parallel 时 optimizer step 比单 device 慢一点(通常 5-10%)。这个开销在 FSDP 训练里很常见,第 18 章会展开。
10.6.4.5 性能数字:三档模式实测
具体一些数字(H100,Llama-7B 训练,per step):
| 模式 | step 耗时 | dispatcher 调用次数 | kernel launch 次数 |
|---|---|---|---|
| single_tensor | ~80 ms | ~10000 | ~10000 |
| foreach (multi_tensor) | ~25 ms | ~30 | ~30 |
| fused | ~8 ms | ~3 | ~3 |
| fused + capturable + CUDA Graph | ~2 ms | 0 (graph replay) | 0 (一次 graph launch) |
差距 40 倍。在长时间训练里这能省下几十小时。
理解这个数字让你判断”我的训练吞吐瓶颈在哪” —— 如果你的训练每 step 真实计算 100ms 但 optimizer 占 30ms,升级到 fused 立刻让训练加速 25%。这种”低风险高回报”的优化在生产代码里极易获得,前提是你知道这套机制存在。
10.6.5 三档模式的选择决策树
实战里怎么选?决策树:
flowchart TD
Start[选 optimizer 模式]
Start --> CPU{CPU 还是 GPU?}
CPU -->|CPU only| Single[single_tensor]
CPU -->|GPU| HasFused{支持 fused?<br/>Adam/AdamW/SGD?}
HasFused -->|是| Fused[fused=True<br/>最快]
HasFused -->|否, 如自定义 optimizer| Foreach[foreach 模式]
Foreach --> CG{要 CUDA Graph?}
CG -->|是| Capturable[capturable=True]
CG -->|否| ForeachOK[默认 foreach 已够]
style Fused fill:#dcfce7,stroke:#22c55e
简化版:
- 70B 模型训练 → AdamW(fused=True, capturable=True) 默认
- 普通模型训练 → AdamW(fused=True) 即可,不需要 capturable
- CPU 训练 → 默认 single_tensor (没有其他选择)
- 自定义优化器(如 LAMB)→ 实现 foreach 版本,享受加速
10.6.5 一个真实场景:torch.compile + capturable optimizer
torch.compile 在 v2.0+ 默认能编译整个 forward + backward,但 optimizer.step() 默认不被编译(因为它有 Python 控制流)。v2.4+ 引入了 compile(optimizer.step) 接口,结合 capturable=True 让整个 step 也变成编译产物:
optimizer = torch.optim.AdamW(params, fused=True, capturable=True)
@torch.compile
def opt_step():
optimizer.step()
optimizer.zero_grad()
# 在训练循环里
for batch in loader:
loss = model(batch).sum()
loss.backward()
opt_step()
这种”编译 optimizer step” 是大模型训练 squeezing performance 的最后一刀。配合 CUDA Graph,整个 forward + backward + step 可以在 GPU 上完全异步执行,CPU 端几乎没有任何 dispatcher 开销。
第 15 章 torch.compile 端到端章会展开这条编译路径。理解 capturable 与 fused 的 motivation,你就能解释为什么大模型训练越来越多采用 “fused + capturable + compiled” 这套组合 —— 不是单纯追求性能,而是要让 optimizer.step 与编译器栈兼容。
10.6.5.5 fused / capturable 与 dynamic shape
值得注意的边界情况:fused / capturable 模式与 动态参数集(运行时增减参数)不兼容。如果你的训练里有”某些层在某些 batch 跳过”或者”模型大小动态变化”的逻辑,foreach 路径会因为参数数量变化而 invalidate 内部缓存。
PyTorch 在这种动态场景下会自动 fall back 到 single_tensor 模式,但用户不一定会收到警告。生产代码强烈建议固定模型架构,少用动态层 —— 不仅是 optimizer 性能,也是 torch.compile / FSDP 等高级特性的基础假设。
10.7 zero_grad:清梯度的两种语义
optimizer.py:1028:
def zero_grad(self, set_to_none: bool = True) -> None:
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if set_to_none:
p.grad = None
else:
p.grad.zero_() # inplace 写零
set_to_none=True 是默认行为(v1.7+):直接把 p.grad 设 None。下次 backward 时 AccumulateGrad 看到 grad 是 None,会创建新张量;如果不是 None,会 inplace add。
set_to_none=False 是老行为:跑一次 zero kernel inplace 把 grad 清零。
差异:
| 维度 | set_to_none=True | set_to_none=False |
|---|---|---|
| zero kernel 调用 | 0 | N (每个 param 一次) |
| 显存 | 第一次 backward 后才分配 grad | grad 一直分配 |
| 性能 | 更快 | 更慢 |
| 与某些第三方代码兼容 | 可能不兼容(期望 grad 永远不是 None) | 兼容 |
70B 模型训练每 step 节省约 5-10ms,长跑训练加起来是几十小时。第 7 章 §7.6.0 我们已经讨论过,这里不重复。
注意一个容易被忽略的语义细节:set_to_none=True 让下一次 backward 第一次给 param.grad 赋值时,PyTorch 会新分配一个张量而不是 inplace 加。也就是说:
- set_to_none=True:每个 batch 的 grad 是新对象
- set_to_none=False:所有 batch 共享同一个 grad 对象(inplace clear)
如果你的代码持有 g = param.grad 引用、期望它在下一个 batch 仍指向有效梯度,set_to_none=True 会让你拿到旧值(指针没变但内容已经被新 backward 替换)。这种隐性差异在用 hook 检查梯度的代码里偶尔出现。生产代码里不要 cache param.grad 引用跨 batch 用。
10.8 state_dict 与 ckpt
optimizer 也有 state_dict() / load_state_dict(),跟 nn.Module 类似但简单。一个 Adam optimizer 的 state_dict 大约:
{
'state': {
0: {'step': 100, 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)},
1: {'step': 100, 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)},
...
},
'param_groups': [
{'lr': 0.001, 'beta1': 0.9, ..., 'params': [0, 1, 2, ...]}
],
}
注意 state 的 key 不是 Tensor 对象(不能序列化),而是整数 ID。param_groups 里的 params 列表也是整数 ID。这套 ID 在 save 时由 optimizer 内部 _param_groups_to_save 维护一致。
具体的 ID 编号规则:optimizer 在 save 时按 param_groups 的顺序、每个 group 内 params 的顺序,给每个 param 分配从 0 开始的递增整数。所以 ID 0 是第一个 group 的第一个 param、ID 1 是第二个,依此类推。这种简单递增编号让 ckpt 在 resume 时只要参数顺序不变就能正确匹配。
如果你想跨”不同模型架构”加载 optimizer state(如 fine-tuning 时模型加了新层),就得用名字匹配 + 手动重建 state。HuggingFace 的 save_pretrained / from_pretrained 内部就有这套基于参数名的 state 迁移逻辑,但 PyTorch 内置 optimizer 没有这层抽象。
加载时 PyTorch 重新匹配新 optimizer 的参数与 ckpt 里的 ID:
new_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
new_optimizer.load_state_dict(saved_state_dict)
# 每个 param 的 state 自动恢复
这个 ID 匹配机制有个微妙坑:新模型的参数顺序必须与保存时一致。如果你重新构造 model 时改了 layer 顺序,ID 对不上,加载会出错。HuggingFace 等框架用 named_parameters() 的名字匹配规避这个坑,但 PyTorch 内置 optimizer 仍用 ID。
10.8.1 optimizer ckpt 的体积
很多新手发现 torch.save(optimizer.state_dict()) 文件比 torch.save(model.state_dict()) 还大。这不是 bug —— Adam 每个 param 的 state(exp_avg + exp_avg_sq)大小与 param 一致,两个 state 张量加起来是 param 体积的 2 倍。所以 Adam optimizer ckpt ≈ model ckpt × 2。
这就是为什么生产代码里通常只保存 model state_dict 不保存 optimizer,除非要 resume training:
# 仅保存模型 (用于推理 / 部署 / 模型分享)
torch.save(model.state_dict(), 'model.pt')
# 保存完整训练状态 (用于 resume)
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
}, 'checkpoint.pt')
Llama-70B 训练的 ckpt 文件约 1 TB(model 280 GB + optimizer 560 GB + 其他)。这是为什么大模型训练的 ckpt 存储是个工程挑战。
10.9 Learning Rate Scheduler:在 group 上调 lr
torch.optim.lr_scheduler 提供一组 scheduler:StepLR、ExponentialLR、CosineAnnealingLR、OneCycleLR、ReduceLROnPlateau 等。它们的实现都很简单 —— 改 param_group[‘lr’] 字段:
class StepLR:
def __init__(self, optimizer, step_size, gamma):
self.optimizer = optimizer
self.step_size = step_size
self.gamma = gamma
self.last_epoch = 0
def step(self):
self.last_epoch += 1
if self.last_epoch % self.step_size == 0:
for group in self.optimizer.param_groups:
group['lr'] *= self.gamma
这套设计让 scheduler 与 optimizer 解耦 —— scheduler 不需要知道 optimizer 内部,只需要修改 param_groups。这种”在共享 dict 上协作”是 Python 库间互操作的优雅模式。
注意调用顺序:
optimizer.step() # 先用当前 lr 更新参数
scheduler.step() # 再调整 lr 给下一次用
老 PyTorch 版本(v1.0 前)顺序相反,会让最后一个 epoch 的 lr 不被使用。如今的 v2.x 已经统一是这个顺序。
10.9.1 chained / sequential schedulers
实际训练经常组合多个 scheduler(如 warmup + cosine decay):
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=1000)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100000)
scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup, cosine])
ChainedScheduler 让多个 scheduler 串联,每个 scheduler 各自在 param_groups 上累积修改 lr。这套组合让”先线性 warmup 1000 步、再余弦退火”的复杂 schedule 用一两行代码搞定。
SequentialLR 是另一种组合:根据当前 step 数选不同 scheduler。比如前 10000 步 warmup,后 90000 步 cosine:
scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup, cosine],
milestones=[10000])
理解这两种组合,你可以构造任意复杂的 lr schedule,不需要自己写 scheduler 子类。
10.9.5 LR Scheduler 与 optimizer 的隐式契约
scheduler 和 optimizer 之间有一个约定俗成的契约:scheduler 假设 optimizer 的 param_groups 里有 ‘lr’ 字段。所有 PyTorch 内置 scheduler 都直接读写 group['lr']。
如果你写自定义 optimizer 没有 ‘lr’ 字段(比如某个论文里的 “scale-invariant” optimizer),用 PyTorch scheduler 就要崩。修法:
class MyOptimizer(Optimizer):
def __init__(self, params, my_step_size=0.1):
# 把 my_step_size 别名为 'lr', 让 scheduler 能改它
defaults = dict(lr=my_step_size)
super().__init__(params, defaults)
这种”用约定俗成的字段名”是 Python 库间互操作的常见模式 —— 没有显式接口,但大家都知道某些字段名是 contract。
10.9.6 ReduceLROnPlateau:另一种 scheduler
ReduceLROnPlateau 与其他 scheduler 不同 —— 它的 step 接受一个 metric 参数:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
for epoch in range(100):
train(...)
val_loss = validate(...)
scheduler.step(val_loss) # 注意传入 val_loss
它会监控 val_loss,连续 10 个 epoch 没下降时把 lr 减半。这种”基于性能动态调整”在某些训练场景比固定 schedule 更好用。
接口上 ReduceLROnPlateau.step(metric) 与其他 scheduler.step() 不一致,是 PyTorch 历史遗产之一。如果你写 wrapper 库,要兼容这两种 step 签名。第 21 章 profiler 章会展开训练监控流水线的设计。
10.9.7 一些 optimizer 与训练流水线的关键交互
1. DDP 与 optimizer step:DDP 在 backward 后做 AllReduce 同步梯度,完成后才能调 optimizer.step()。如果你 step 在 AllReduce 完成前调,每个 rank 的更新会基于本地梯度而非全局平均梯度,训练完全失效。DDP 通过 with model.no_sync(): 提供”梯度累积”模式让用户控制何时同步,但默认每 backward 都同步。第 17 章会展开。
2. 混合精度训练:fp16 训练时 grad 容易 underflow 成 0,导致 optimizer.step() 实际不更新参数。GradScaler.step(optimizer) 内部检查 grad 是否 inf/nan,如果有就跳过这次更新(unscale 失败)。这套机制让 fp16 训练可以”动态调整 loss scale”维持训练稳定。第 20 章 量化与混合精度会展开。
3. ZeRO/FSDP 与 optimizer 分片:FSDP 把参数分片到多 rank,每个 rank 只持有自己的 1/N 参数与对应 1/N optimizer state。step 时每个 rank 只更新自己那份,再通过 AllGather 拼回。这种”分片 optimizer state”让 70B 训练显存压力降到 1/N。第 18 章会展开。
理解 optimizer 与这三个系统(DDP / GradScaler / FSDP)的协作,你就理解了大模型训练的工程全景。optimizer 看起来是单点,实际是整个训练流水线的串联点。
10.10 跨书关联
- 《vLLM 内核探秘》第 6 章 模型加载:vLLM 推理时需要加载训练好的 weights,但不需要 optimizer state。这是为什么推理 ckpt 通常只是 model state_dict,不含 optimizer state
- 《Tokio 异步运行时》第 X 章 共享状态:optimizer 与 scheduler 通过共享 param_groups dict 协作,与 Tokio Task 共享配置信息的模式相通
- 《MCP 协议剖析》第 X 章 工具状态:optimizer 的 state per-param 与 MCP server 的 per-conversation state 思路类似 —— 都是”按某个 key 维护独立状态”
- 《vLLM 内核探秘》第 7 章 模型加载:vLLM 推理时只加载 model state_dict,不需要 optimizer state,文件体积是训练 ckpt 的 1/3
- 《Serde 元编程》派生:很多自定义 optimizer 实现(如 LAMB、Sophia)需要序列化新增字段,这与 Serde 派生的”按字段定义序列化”思想一致
10.11 几条工程经验
实战 optimizer 相关:
1. 大模型训练默认 fused=True:v2.x 几乎所有 GPU 训练都该开。一行参数省 3-5% 训练时间
2. foreach 与 fused 互斥:fused=True 启用时 foreach 自动失效。两者不要同时设
3. param_groups 用具名 dict 更可读:{'name': 'backbone', 'params': [...], 'lr': 1e-5} 加 name 字段,调试时打印能立刻看出哪个 group
4. resume training 时:load_state_dict(opt_state) 后 lr 会被恢复成 ckpt 时的值,不是你新设的 default。如果想重置 lr,要在 load 后手动改 for g in opt.param_groups: g['lr'] = new_lr
5. 学习率与 batch size 同步缩放:经典的 linear scaling rule(lr 与 batch size 成正比)大多数 SGD-类训练有效,但对 Adam-类 lr 不需要严格 linear scaling,根据论文 Don’t Decay the Learning Rate, Increase the Batch Size 仍要慎用
6. AdamW 的 weight_decay 与 grad clipping 顺序:先 clip 再 step,weight_decay 在 step 内自动应用。如果手动写了 param -= weight_decay * param * lr,要小心不要重复 decay
7. CUDA Graph 训练务必 capturable=True:忘了这条会让 graph capture 失败或行为不一致
8. mixed precision 训练用 GradScaler 配合:fp16 训练时梯度容易下溢成 0,需要 torch.cuda.amp.GradScaler 把 loss 放大、反向后再缩小。这个 wrapper 内部也跟 optimizer 协作(在 step 前检查梯度是否 inf/nan)
9. 给 optimizer 加 hook 比 wrap 优于继承:要给 optimizer 加自定义行为(如 grad 监控、lr 自适应),用 register_step_pre_hook / register_step_post_hook 而非继承重写。这种 hook 模式让你的扩展不会破坏 internal state 的演进
10. 不要在训练循环中 optimizer = MyOpt(...):每次 __init__ 都重新分配 state,几百万参数张量瞬间分配会让 caching allocator 颠簸。optimizer 在训练前创建一次即可
10.12 自定义 optimizer 实战
把整章串起来,写一个 toy SGD-with-Nesterov optimizer 演示完整套路:
class MySGD(torch.optim.Optimizer):
def __init__(self, params, lr=0.01, momentum=0.9):
defaults = dict(lr=lr, momentum=momentum)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group['lr']
momentum = group['momentum']
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
if 'momentum_buffer' not in state:
state['momentum_buffer'] = torch.zeros_like(p)
buf = state['momentum_buffer']
buf.mul_(momentum).add_(p.grad)
p.add_(buf, alpha=-lr)
return loss
20 行就实现了带 momentum 的 SGD。所有 PyTorch 内置 optimizer 都这个套路:
__init__定义 defaults,调super().__init__step()用@torch.no_grad()装饰- 遍历 param_groups → 遍历 params → lazy 初始化 state → 计算更新
接 closure 参数是为了支持 LBFGS 等需要”算 loss 多次”的优化器。普通 optimizer 忽略它即可。
理解这套模板,你可以快速实现各种学术论文里的新 optimizer(Sophia、Adan、Lion 等都是这个套路实现)。
如果你想让自定义 optimizer 也享受 foreach 加速,要再实现一个 _multi_tensor_my_sgd 函数,把所有 inplace 操作换成 torch._foreach_* 版本,然后在 step 里根据 foreach flag 选择路径。这是 PyTorch 内置 optimizer 的标准实现方式,在 torch/optim/sgd.py / adam.py 等文件里能看到完整模板。
写一个生产级 optimizer 大约 500 行代码(含 single_tensor + multi_tensor + 各种边界情况处理),对学术 prototype 来说有点重,所以很多论文的 reference implementation 只写 single_tensor 版本。如果你的论文 optimizer 被广泛采用,社区会有人贡献 multi_tensor 与 fused 版本。这是 PyTorch 优化器生态的演进路径。
10.12.5 一些常被忽略的 PyTorch optimizer 实现细节
把整章看完,几个不常被提及但在源码里反复出现的工程细节:
1. _to_scalar(x) 的处理:很多 optimizer 内部把 lr / beta 等 float 参数包装成 Tensor(为了支持 capturable 模式)。_to_scalar 在内部判断 “这个参数到底是 float 还是 0-d Tensor”,统一成可直接用的形式。这是 PyTorch 在”标量 vs 张量统一接口”上的工程胶水
2. differentiable=False 的 fast path:每个 step 函数开头检查 if differentiable: ... else: ...。non-differentiable 路径走纯 inplace + no_grad;differentiable 路径走带 autograd 的版本。后者慢几倍但允许 meta-learning
3. grad_scale / found_inf 参数:fused optimizer 内置接收 GradScaler 的 _grad_scale 与 _found_inf 张量,让 unscale 与 inf check 也在 fused kernel 内完成。这套接口让 mixed precision + fused 几乎零额外开销
4. complex 张量支持:has_complex 标志让 optimizer 在复数参数(罕见但支持)场景走专用代码路径。绝大多数代码走默认 real-only 路径
5. JIT script 的兼容性 hack:源码里大量 if torch.jit.is_scripting(): 分支处理 TorchScript 限制(如不能传 Tensor 当 lr)。这些是历史遗产,今天 TorchScript 已经被 torch.compile 替代但兼容代码仍在
理解这些细节,你看 _single_tensor_adam / _multi_tensor_adam 的几百行代码就不会被那么多 if 分支搞晕 —— 它们都对应”为某个特殊场景留的口子”。
10.13 几条 optimizer 设计的”通用启示”
把 optimizer 思想抽象到任何”状态化算子”系统:
第一:defaults + groups 双层配置:让用户既能”一刀切”(defaults)又能”按组定制”(param_groups)。这套模式在数据库连接池、HTTP 客户端、日志器等场景同样适用
第二:lazy 初始化 state:开始不知道用户怎么用,等第一次访问时按需创建。defaultdict 是 Python 实现这种 pattern 的最佳工具
第三:ID 而非对象作为持久化 key:state_dict 的 key 是整数 ID 而非 Tensor 对象,避免序列化复杂对象。这种”用整数代理对象”的模式在 ORM、缓存系统都常见
第四:性能 / 通用性多档:single_tensor → foreach → fused 三档对应”通用慢”到”专门快”。允许用户按场景选择,而不是”一种实现包打天下”
第五:hooks 让框架可扩展:optimizer step 前后开放 hook,让 ZeRO、混合精度等高级特性能 hook 进来。这是与 nn.Module hook 思想一致的扩展点设计
第六:inplace 数学操作 + no_grad 上下文是底层数值更新的标配:避免分配中间张量、避免污染反向图,是任何”状态化更新”系统的最佳模式。游戏引擎的物理更新、数据库的索引更新都用这套思想
第七:显存不只是模型 — optimizer state 占同样级:开始一个新训练任务时除了算 model 显存,必须算上 optimizer state(Adam 是 2x 参数量)+ gradients(1x)+ activations(depends on batch size)。这个”四件套”显存账是大模型训练的基础数学
第八:closure 接口是给少数高级 optimizer 留的:99% 优化器不用 closure,但接口里始终预留它。这种”为少数边缘场景留扩展点”是稳定 API 设计的常见折衷
10.14 一段反思:optimizer 的”小而美”哲学
最后回顾一下 PyTorch optimizer 设计哲学。它的核心是让”如何更新参数”成为一个可替换的小模块:
- 接口窄:
step()+zero_grad()+state_dict()几乎是全部公开 API - 状态简单:param_groups (list of dict) + state (defaultdict)
- 与其他系统解耦:optimizer 不知道 model 是什么,只看到 params。model 不知道 optimizer 在做什么,只期待 step 后参数被更新
这套”窄接口 + 简单状态”让 optimizer 是 PyTorch 整个训练栈中最稳定的部分。Module、autograd、dispatcher 都重构过几次,optimizer 的接口从 v0.x 到 v2.11 几乎没变 —— 用户写的 optimizer 代码 6 年前就能跑、6 年后还能跑。
这种”内部反复演进、接口长期稳定”的 trade-off 是 PyTorch 工程文化的体现。理解这条哲学,你看任何”长寿命接口设计”问题时都能更清楚地选择 —— 哪些放出去、哪些藏起来、哪些预留扩展点。
下一章拆 DataLoader:多进程数据流水线、pin_memory、collate_fn 怎么协作。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。