第22章 自定义算子与 C++ 扩展
“Writing a custom op in 2024 is
@torch.library.custom_op. Forget everything you knew about Variable / autograd.Function / TORCH_LIBRARY in 2018.”—— PyTorch dev podcast,custom ops 现代教程
本章要点
- v2.4+ 推荐
torch.library.custom_op装饰器:一个 API 注册算子 + 自动接入 dispatcher / autograd / torch.compile register_fake给 FakeTensor 路径:torch.compile / FSDP 在 trace 时需要”shape 推导而不真算”register_autograd加反向规则:用类似autograd.Function.backward的写法- C++ 扩展走
TORCH_LIBRARY+ pybind11:性能敏感时手写 C++ / CUDA kernel - 完整生态接入:自定义算子能与 dispatcher / autograd / FX / Inductor / DDP / FSDP 全部协作
- 替代老 API:
autograd.Function还能用,但 torch.compile 兼容性差,新代码用 custom_op
22.1 何时需要自定义算子
PyTorch 内置 3000+ 算子,但仍有缺口:
- 新硬件指令:自家芯片有特殊指令(如 NPU 的 fused attention),想用就要写 kernel 包成 PyTorch op
- 新算子:论文里某个新激活函数、特殊归一化,PyTorch 还没收
- 性能极致:某段热路径手写 CUDA 比组合 ATen 算子快 30%+
- 第三方库集成:FlashAttention、xformers、Triton kernel 想暴露成 torch op
如何让自家 kernel 像内置算子一样工作 —— autograd 自动反向、torch.compile 能编译、profiler 能看到、FSDP 能正确处理 —— 是本章主题。
22.2 现代标配:torch.library.custom_op
v2.4+ 推荐写法:
import torch
@torch.library.custom_op("mylib::mymul", mutates_args=())
def my_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y
@my_mul.register_fake
def _(x, y):
# FakeTensor 路径: 只返回正确 shape/dtype 的 empty tensor
return torch.empty_like(x)
def my_mul_backward(ctx, grad):
x, y = ctx.saved_tensors
return grad * y, grad * x
def my_mul_setup_context(ctx, inputs, output):
x, y = inputs
ctx.save_for_backward(x, y)
my_mul.register_autograd(my_mul_backward, setup_context=my_mul_setup_context)
——这一段做了三件事:
@custom_op把my_mul注册到 dispatcher 的mylib::mymulschemaregister_fake给 FakeTensor 路径提供 shape 推导register_autograd给反向规则
之后 my_mul(x, y) 就像内置算子一样工作。
22.2.1 schema 字符串
"mylib::mymul" 是命名空间 + 算子名。mutates_args=() 表示”不修改任何输入”(如果修改了 x,要写 mutates_args=("x",))。完整 schema 由 PyTorch 从函数 type hint 自动推导:
mylib::mymul(Tensor x, Tensor y) -> Tensor
如果你的 op 改了输入张量,schema 用 Tensor(a!) x 标记 alias。这套语法第 6 章 §6.2 讲过。
22.2.2 register_fake 的角色
FakeTensor 在第 5 章 §5.7 与第 13 章 AOTAutograd 出现过。几乎所有现代 PyTorch 高级特性都依赖 fake 路径:
torch.compile用它做 graph capture- FSDP 用它做 lazy init / shape 推导
- export 用它做
torch.export(model) - meta tensor (无数据张量) 也走 fake
所以没注册 fake 函数的自定义算子在 torch.compile 下会 graph break。register_fake 不是可选 —— 现代代码必须有。
fake 函数只允许调用 shape 操作(empty_like / zeros / view / 算 shape),不能做实际数值计算。第 6 章 §6.4.2.5 警告过这条。
22.2.3 register_autograd:反向规则
register_autograd 接受两个函数:backward 和 setup_context。语义与 autograd.Function 类似,但分开成两步:
setup_context(ctx, inputs, output):保存反向需要的张量(在 forward 完成后调用)backward(ctx, *grads):算反向
PyTorch 内部把这套包成 autograd Node,与第 7 章讲的 XxxBackward0 完全等价。自定义算子的反向图与内置算子的反向图无差别,能被 autograd Engine(第 8 章)调度、被 AOTAutograd(第 13 章)capture。
22.3 Triton kernel 作为 custom_op 的实现
如果你想用 Triton 写 kernel(性能比纯 Python 高 10x+),可以让 custom_op 内部调 Triton:
import triton
import triton.language as tl
@triton.jit
def my_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * 128 + tl.arange(0, 128)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask)
y = tl.load(y_ptr + offsets, mask)
tl.store(out_ptr + offsets, x * y, mask)
@torch.library.custom_op("mylib::triton_mul", mutates_args=())
def triton_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.numel()
grid = lambda meta: (triton.cdiv(n, 128),)
my_kernel[grid](x, y, out, n)
return out
注意:torch.compile 看到 triton_mul 时会 inline 调用进生成的 fused kernel,不会再拆开它。这种”自定义 Triton kernel + custom_op”是 FlashAttention 等 SOTA 算子的标准接入方式。
22.4 C++ / CUDA 扩展
性能极敏感时手写 C++(含 CUDA)。流程:
- 写
.cpp文件,用TORCH_LIBRARY注册算子 - 写
setup.py用torch.utils.cpp_extension.CUDAExtension python setup.py install编译成 .so- Python 端
import即可
C++ 端:
// my_ops.cpp
#include <torch/extension.h>
#include <torch/library.h>
at::Tensor my_mul_cpu(const at::Tensor& x, const at::Tensor& y) {
return x * y;
}
at::Tensor my_mul_cuda(const at::Tensor& x, const at::Tensor& y) {
// 实际 CUDA kernel launch
auto out = at::empty_like(x);
my_mul_cuda_kernel<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(),
out.data_ptr<float>(), x.numel());
return out;
}
TORCH_LIBRARY(mylib, m) {
m.def("mymul(Tensor x, Tensor y) -> Tensor");
}
TORCH_LIBRARY_IMPL(mylib, CPU, m) {
m.impl("mymul", my_mul_cpu);
}
TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
m.impl("mymul", my_mul_cuda);
}
setup.py:
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
setup(
name='mylib',
ext_modules=[CUDAExtension('mylib', ['my_ops.cpp', 'my_kernel.cu'])],
cmdclass={'build_ext': BuildExtension},
)
加载后 Python 端:
import torch.ops.mylib
out = torch.ops.mylib.mymul(x, y)
C++ 扩展是国内 AI 芯片厂商接 PyTorch 的标准路径 —— 在 cpp 端用自家 SDK 写 kernel,注册到 dispatcher 的 PrivateUse1 key。
22.5 老 API:autograd.Function
老的 v1.x 写法仍然支持:
class MyMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x * y
@staticmethod
def backward(ctx, grad):
x, y = ctx.saved_tensors
return grad * y, grad * x
out = MyMul.apply(x, y)
简单直接,调试方便。但缺点:
- torch.compile 看到
apply通常 graph break:Inductor 不知道怎么编译 Python autograd.Function - 没有 schema:没法走 dispatcher,对 FSDP / FX 不友好
- 没有 fake 实现:torch.compile / export 走不通
如果你只是研究阶段快速写个 op、不上 compile:autograd.Function 够用。如果生产代码 + 想 torch.compile 加速:必须用 torch.library.custom_op。
第 7 章 §7.8.1 我们对比过两套接口,结论一致。
22.6 完整集成检查清单
写一个生产级自定义算子,要做的事:
flowchart TB
Op[custom_op 装饰器]
Op --> Fake[register_fake<br/>shape 推导]
Op --> Auto[register_autograd<br/>反向规则]
Op --> Cpu[CPU kernel<br/>register_kernel device='cpu']
Op --> Cuda[CUDA kernel<br/>register_kernel device='cuda']
Fake --> Compile[✓ torch.compile 兼容]
Cpu --> Eager[✓ eager 路径]
Cuda --> Eager
Auto --> Eng[✓ autograd Engine]
Style0[注册到 dispatcher<br/>自动获得]
Op --> Style0
Style0 --> Disp[dispatch 调度]
Style0 --> Prof[profiler 自动看到]
Style0 --> Fsdp[FSDP / DDP 兼容]
style Op fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
清单:
- ✅
@custom_op装饰器声明 - ✅
register_fake给每个 op - ✅
register_autograd如果可微 - ✅
register_kernel(..., 'cpu')+register_kernel(..., 'cuda')各自实现 - ✅ 写单元测试用
torch.library.opcheck自动验证(PyTorch 提供的算子合规性检查)
22.6.5 opcheck:自定义算子的合规性测试矩阵
torch/library.py:1632 的 torch.library.opcheck 是自定义算子的”质保检查”。它跑 5 项测试,确认 op 与 PyTorch 各子系统兼容:
import torch
from torch.library import opcheck
opcheck(my_mul, args=(x, y), test_utils=("test_schema", "test_autograd_registration",
"test_faketensor", "test_aot_dispatch_static",
"test_aot_dispatch_dynamic"))
5 项测试的具体职责:
| 测试 | 检查什么 |
|---|---|
test_schema | schema 字符串与实际实现的输入输出 dtype / shape 是否一致 |
test_autograd_registration | 注册了 autograd 后反向规则是否数值正确(用 gradcheck 比对数值梯度) |
test_faketensor | fake 函数返回的 shape / dtype 是否与真实 kernel 输出一致 |
test_aot_dispatch_static | 在 AOTAutograd(静态 shape 模式)下能否正确 trace 与编译 |
test_aot_dispatch_dynamic | 同上但 dynamic shape 模式(更严格,要求 fake 函数能处理 SymInt) |
生产级自定义算子必须 opcheck 通过。社区贡献到 PyTorch 主仓的 op PR 都被要求附 opcheck 测试。这套自动化检查避免了”自定义 op 在 eager 跑得对、torch.compile 编译错”等隐蔽 bug。
opcheck 内部用 torch._library.fake_class_registry 验证 fake 实现、用 torch.autograd.gradcheck 验证反向、用 torch._dynamo 跑 trace 验证 compile 路径。一次调用覆盖整个生态的兼容性。
22.6.6 Library 低级 API
@custom_op 是高级糖,底层是 torch.library.Library(library.py:68)。它提供更细粒度的算子注册:
from torch.library import Library
# 创建一个 library (类似 C++ 端的 TORCH_LIBRARY)
lib = Library("mylib", "DEF")
# 注册 schema (没有实现, 等下注册)
lib.define("mymul(Tensor x, Tensor y) -> Tensor")
# 给特定 dispatch key 注册实现
lib.impl("mymul", lambda x, y: x * y, "CPU")
lib.impl("mymul", my_cuda_kernel, "CUDA")
lib.impl("mymul", my_meta_kernel, "Meta") # FakeTensor 也是 Meta key
第二个参数 "DEF" 是 library 的 kind:
DEF:定义新算子(创建 schema)IMPL:给已有算子加新 dispatch key 实现FRAGMENT:在已有 library 里追加新 op(可多次)
@custom_op 装饰器内部就是构造 Library 然后调 define / impl。直接用 Library 时你能精确控制每个 dispatch key 的实现 —— 适合需要”给 PrivateUse1 注册新 backend”等高级场景。
22.6.7 PrivateUse1:国产芯片接入完整路径
PyTorch 给厂商扩展自家硬件留了 3 个 dispatch key:PrivateUse1 / PrivateUse2 / PrivateUse3(第 3 章 §3.5)。完整接入流程:
# 1. 给 PrivateUse1 起个有意义的名字
torch.utils.rename_privateuse1_backend("npu")
# 之后用户可以写 tensor.to('npu') 而非 'privateuseone'
# 2. 给 PrivateUse1 注册所有 ATen 算子的实现
@torch.library.impl("aten::add.Tensor", "PrivateUse1")
def npu_add(self, other, alpha=1):
# 调你家硬件 SDK 的 add kernel
return _npu_runtime.add(self, other, alpha)
# ... 给几百个常用算子各注册一个 impl ...
# 3. 提供 generate_methods_for_privateuse1_backend 让 tensor.npu() 等方法可用
torch.utils.generate_methods_for_privateuse1_backend()
torch/utils/backend_registration.py:20 的 rename_privateuse1_backend 把 PrivateUse1 重命名 + :362 的 generator 自动给 Tensor 添加 .npu() / .is_npu / .npu() 等方法。这套 API 让国产芯片厂商可以做出完整 PyTorch 体验而不修改主仓代码。
实际工作量:给 PyTorch 全部 3000+ 算子各写一个 backend impl 是几十人月的工程,但**torchgen/gen_backend_stubs.py(第 6 章 §6.10.5)能从一份”目标算子列表 YAML”自动生成 stub 代码**,厂商只需要填实现细节 —— 工作量降到几百算子级。
torch_npu(华为)、torch_mlu(寒武纪)、torch_xpu 等都走这条路。开源在 GitHub 能看到完整模板。
22.6.8 allow_in_graph 与 disable:torch.compile 的两个逃生口
写自定义算子时常遇到 Dynamo 不会 trace 的代码(如调了第三方 C 扩展、动态行为太复杂)。PyTorch 提供两个装饰器作为逃生口:
@torch.compiler.allow_in_graph(torch/compiler/__init__.py:72):
@torch.compiler.allow_in_graph
def my_special_function(x, y):
# Dynamo 不 trace 这个函数体
# 把整个调用当作"一个不透明 op"加入 graph
return some_external_lib.do_magic(x, y)
效果:Dynamo 看到调用 my_special_function(x, y) 时,把它当作单个不透明算子放进 FX Graph(不展开内部)。Inductor 等后端会调用原始函数,跳过编译。
@torch._dynamo.disable:
@torch._dynamo.disable
def my_complex_logic(x):
# Dynamo 看到这个调用直接 graph break, 退回 eager
if x.sum() > 0:
return some_python_heavy_logic(x)
else:
return another_branch(x)
效果:Dynamo 在调用处触发 graph break,整段函数用 eager 跑,break 之后再开始新 trace。
两者关键区别:
| 装饰器 | Dynamo 行为 | 适合场景 |
|---|---|---|
allow_in_graph | 当作不透明 op 留在 graph 里 | 函数行为是确定的 tensor 计算,但 Dynamo trace 不动(如调了某 C 扩展) |
disable | 触发 graph break,退回 eager | 函数有复杂 Python 逻辑(动态控制流 / 大量 dict 操作 / print),不希望 Dynamo 浪费时间分析 |
实际工程里:
- 写自定义 Triton kernel + register_fake:用
custom_op(§22.2),不需要这两个装饰器 - 集成第三方 C 扩展(如 FlashAttention v1 的私有 wrapper):用
allow_in_graph把它当黑盒 - 训练循环里的 logging / metric reporting 函数:用
disable让 Dynamo 不要试图分析
torch/_dynamo/decorators.py 还提供更细的开关:disallow_in_graph(强制某 op 触发 graph break)、mark_static_address(声明 tensor 地址不会变,让 CUDA Graph 能复用)等。生产代码里写自定义算子的 escape hatch,理解这套装饰器家族能让你优雅处理”compile 不动”的边角情况。
22.6.9 inplace 与多输出算子的注册
@custom_op 默认假设 op 是”纯函数”(无副作用、单输出)。两种特殊形态需要额外配置:
inplace 算子(mutate input):
@torch.library.custom_op("mylib::add_inplace_", mutates_args=("x",))
def add_inplace_(x: torch.Tensor, y: torch.Tensor) -> None:
x.add_(y)
# 没有 return: schema 是 (Tensor(a!) x, Tensor y) -> ()
mutates_args=("x",) 让 schema 里 x 标 alias Tensor(a!)。functionalize(§13.4)看到这个标记后会重写代码:把 add_inplace_(x, y) 变成 x_new = x + y; x = x_new 这种纯函数版本。这是 v2.x 让 inplace op 与 compile 共存的关键。
不写 mutates_args 但实际 mutate 输入 → 隐蔽 bug:torch.compile 假设无副作用、生成的 kernel 不会复制 x,运行时 x 被修改但 graph 看不到 → 后续算子拿到错的 x。
多输出算子:
@torch.library.custom_op("mylib::topk_with_idx", mutates_args=())
def topk_with_idx(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
values, indices = torch.topk(x, k)
return values, indices
@topk_with_idx.register_fake
def _(x, k):
new_shape = list(x.shape)
new_shape[-1] = k
return torch.empty(new_shape, dtype=x.dtype), torch.empty(new_shape, dtype=torch.int64)
返回 Tuple[Tensor, ...] 时 schema 自动是 -> (Tensor, Tensor)。fake 函数也返回 tuple。
inplace + 多输出组合:
@torch.library.custom_op("mylib::layernorm_inplace", mutates_args=("x", "running_mean"))
def layernorm_inplace(
x: torch.Tensor,
running_mean: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
# 修改 x 与 running_mean, 返回新 tensor
...
复杂场景里这套语法要小心写。schema 错了 → AOTAutograd 会在 trace 时报”functionalize 失败”。opcheck 内置 functionalize 检查能在 commit 前发现这类问题(§22.6.5)。
22.6.10 register_kernel:每个 device 单独注册
@custom_op 的函数体是 op 的默认实现(CompositeImplicitAutograd key)。如果你想为不同 device 写专门 kernel,用 register_kernel:
@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# 默认实现 (eager 路径用)
return x * y
@mymul.register_kernel("cuda")
def _(x, y):
# CUDA 专用: 调 Triton kernel
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
my_triton_kernel[grid](x, y, out, x.numel())
return out
@mymul.register_kernel("cpu")
def _(x, y):
# CPU 专用: 调 OpenMP kernel
out = torch.empty_like(x)
my_cpp_extension.cpu_mul(x, y, out)
return out
@mymul.register_kernel("xpu")
def _(x, y):
# Intel XPU 专用
return x * y # 通用 fallback
dispatcher(§5.x)根据 input device 自动路由。这套机制让”一个 op 多 backend”不需要写 dispatch 逻辑、PyTorch 帮你做。
实战例子:FlashAttention 的 PyTorch 接入:
- 默认实现:调用
F.scaled_dot_product_attention(fallback) - CUDA:调自家 CUDA kernel(Hopper / Ampere 各一份)
- CPU:调 PyTorch 通用 attention(性能差但能跑)
理解 register_kernel 让你看到自定义算子的”多后端”不需要复杂代码 —— 装饰器 + dispatcher 自动协作。
22.6.11 JIT 加载 C++ 扩展:开发期免编译
§22.4 用 setup.py 编译 C++ 扩展,每次改完要重新 build。开发期更方便的方式是 torch.utils.cpp_extension.load:
import torch.utils.cpp_extension as cpp_ext
my_ops = cpp_ext.load(
name='my_ops',
sources=['my_ops.cpp', 'my_kernel.cu'],
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3', '-arch=sm_90'],
verbose=True,
)
# 直接用
out = my_ops.mymul(x, y)
load 内部:
- 把 sources 编译成 .so(首次几十秒)
- 缓存到
~/.cache/torch_extensions/ - 后续相同 sources 命中缓存(毫秒级)
- 改了 source 自动重编
适合开发场景:写 / 改 / 测的循环里不用每次跑 setup.py install。
进阶:load_inline 让你直接传 C++ source 字符串、不用文件:
my_ops = cpp_ext.load_inline(
name='inline_ops',
cpp_sources='''
torch::Tensor add_one(torch::Tensor x) {
return x + 1;
}
''',
functions=['add_one'],
)
适合写小 demo / unit test。生产代码仍用 setup.py + .so 文件(避免每次进程启动都编译)。
实战:研究迭代算法时,load_inline + Jupyter notebook 让你能像写 Python 一样快速迭代 C++ kernel。这套工程便利极大降低了”写 C++ 扩展”的心智门槛。
22.6.12 ABI 兼容性:跨 PyTorch 版本的痛点
C++ 扩展编译出的 .so 对 PyTorch 版本敏感。原因:
- libtorch C++ ABI 不冻结:PyTorch 团队在 v2.x 多次重构内部 API
- CUDA Toolkit 版本:编译用 12.4、运行时 12.5+ OK;但 12.4 → 11.8 不行
- Compiler ABI:gcc 7 编译的 .so 在 gcc 11 系统上可能报 undefined symbol
实战遇到的 ABI 错误:
ImportError: undefined symbol: _ZN3c104impl21py_handle_tdiFEPN10pybind11_4dictE
——pybind11 内部 symbol 在 PyTorch v2.4 与 v2.6 之间改了 mangling。
解决方案:
1. 锁版本 + per-version build
# 用户安装时根据 PyTorch 版本下载对应 wheel
pip install my-extension==0.1.0+pt2.6
pip install my-extension==0.1.0+pt2.4
每个 PyTorch 主版本编一份 wheel。
2. 用 LIBTORCH_USE_GLIBCXX_ABI
# 编译时指定 ABI
TORCH_CUDA_ARCH_LIST="8.0;9.0" \
LIBTORCH_USE_GLIBCXX_ABI=1 \
python setup.py bdist_wheel
让生成的 .so 与 PyTorch 内部 ABI 对齐。
3. JIT load (§22.6.11)
绕过 ABI 问题:用户机器现场编 → 自动用当前 PyTorch 的 ABI。代价是首次启动慢。
4. AOTI 路径
把自定义算子打包进 .pt2(§15.6.21),让 AOTI runtime 加载。AOTI 内部把 ABI 抽象掉,跨版本兼容性更好。
实战:开源 PyTorch 扩展(如 FlashAttention、xformers)维护团队都把”per-PyTorch-version build matrix”放在 CI 里。生产代码部署时锁住 PyTorch + 扩展版本。这是 C++ 扩展不可避免的工程税,优先用纯 Python + Triton(§22.3)能完全避开 ABI 问题。
22.6.13 Composite Implicit Autograd:算子的 decomposition
PyTorch 内置 op 有几类 autograd 处理方式:
| Autograd Key | 含义 |
|---|---|
| Autograd | 显式注册反向规则(如 mm、linear,硬编码反向) |
| CompositeImplicitAutograd | op 内部调其他 op,autograd 自动追踪(不需要写反向) |
| CompositeExplicitAutograd | composite 但显式标 autograd-eligible |
| AutogradPrivateUse1 | 厂商自家硬件的 autograd 实现 |
自定义算子默认是 CompositeImplicitAutograd —— 函数体调其他可微算子,autograd 自动追踪。这种 op 不需要写 register_autograd:
@torch.library.custom_op("mylib::my_attention", mutates_args=())
def my_attention(q, k, v):
# 内部调 ATen 算子, 全部可微
scores = q @ k.transpose(-2, -1)
attn = scores.softmax(-1)
return attn @ v
# 不需要 register_autograd! autograd 自动通过 mm + softmax + mm 追踪
但如果用了 Triton kernel / C++ kernel,autograd 看不到内部 op,必须 register_autograd:
@torch.library.custom_op("mylib::triton_attention", mutates_args=())
def triton_attention(q, k, v):
# Triton kernel 内部 op autograd 看不到
return my_triton_kernel(q, k, v)
# 必须显式注册反向
def backward(ctx, grad_out):
q, k, v = ctx.saved_tensors
return triton_backward_kernel(q, k, v, grad_out)
理解这两套路径让你写自定义算子时知道”何时需要 register_autograd”。简单 Python composite → 不需要;Triton/C++ kernel → 必须。
PyTorch 内部很多算子是 CompositeImplicitAutograd,让 ATen 代码生成不需要为每个 op 写反向。这套设计让 PyTorch 几千算子的反向规则维护成本可控。
22.6.14 Triton autotune:让 kernel 自动找最优配置
写 Triton kernel 时关键参数(block size / num_warps / num_stages)需要为每个硬件 / shape 调优。手动调耗时,Triton 内置 autotune 自动搜索:
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_SIZE': 256}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_SIZE': 256}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_SIZE': 512}, num_warps=8, num_stages=4),
# ... 更多配置
],
key=['n'], # n 不同时重新选 config
)
@triton.jit
def my_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask)
y = tl.load(y_ptr + offsets, mask)
tl.store(out_ptr + offsets, x * y, mask)
工作机制:
- 第一次某个
n调用时,autotune 跑所有 configs、测每个的 GPU 时间 - 选最快的 config
- 缓存到
(kernel, n) → best_config映射 - 后续相同 n 直接用 best_config
cost:第一次跑慢几十 ms(要试几个 config),后续命中缓存零开销。生产代码 warmup 阶段触发 autotune、之后稳态运行。
进阶:
prune_configs_by让你写自定义函数过滤掉 illegal config(如 BLOCK_SIZE 太大超 shared memory)reset_to_zero让某些 input 在每次 autotune trial 后清零(避免累积副作用)do_bench自定义 benchmark 函数
实战:FlashAttention v2 / v3 内部用了 几十个 config × 几十种 shape 的 autotune 矩阵,让单个 kernel 在不同 GPU + 不同 shape 都接近 hardware peak。理解 autotune 让你看到现代 SOTA kernel 的工程实质:不是手写一个完美 kernel,是搜索空间 + 自动调优。
22.6.15 vmap × custom_op:批量化的自动支持
vmap(functorch / torch.func.vmap)让 op 自动批量化:
def add(x, y):
return x + y
batched_add = torch.func.vmap(add)
# batched_add 接受 [B, ...] 输入, 内部 batched 算 add
PyTorch 内置 op 的 vmap 规则已经写好。自定义 op 默认 vmap 会失败:
@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x, y):
return x * y
torch.func.vmap(mymul)(x, y)
# 报错: vmap rule not registered for mylib::mymul
需要 register_vmap:
@mymul.register_vmap
def _(info, in_dims, x, y):
# in_dims: 输入 tensor 沿哪个维度 batch
# 实现: 把 vmap 输入展开成 normal call
x_dim, y_dim = in_dims
if x_dim is not None and y_dim is None:
y = y.unsqueeze(x_dim).expand_as(x)
elif y_dim is not None and x_dim is None:
x = x.unsqueeze(y_dim).expand_as(y)
out = mymul(x, y)
out_dim = x_dim if x_dim is not None else y_dim
return out, out_dim
实战工作量:复杂 op 的 vmap rule 比 forward 还难写。简单做法:默认 register_vmap 不实现,文档说”vmap 不支持”,让用户避开 vmap。生产代码里 vmap 用户少(functorch 主要给研究用),多数自定义 op 不写 vmap rule 也能跑。
理解 vmap 的存在让你知道 PyTorch 的”自动批量化”也是抽象层 + 各 op 单独支持。custom_op 想完整融入 PyTorch 生态需要 fake / autograd / vmap / dispatch 多层注册。
22.6.16 自定义 op 注册到 Inductor lowering
torch.compile 看到自定义 op 时,默认走 fallback 路径:直接调用原 op、不与周围算子 fuse。如果你想让 Inductor 真正编译你的 op(fuse 到 Triton kernel 里),用 register_lowering:
from torch._inductor.lowering import lowerings, register_lowering
from torch._inductor.ir import Pointwise
@register_lowering(torch.ops.mylib.mymul)
def mymul_lowering(x, y):
# 返回 Inductor IR (Pointwise)
return Pointwise.create(
device=x.get_device(),
dtype=x.get_dtype(),
inner_fn=lambda idx: x.make_loader()(idx) * y.make_loader()(idx),
ranges=x.get_size(),
)
效果:torch.compile 看到 mymul(a, b) + c 时,不是”调 mymul kernel + 调 add kernel”,而是直接把 mymul 的语义编译进同一个 fused Triton kernel —— 真正的 op fusion。
适用:
- 简单算子(pointwise / reduction):写 lowering 让 Inductor 优化
- 复杂算子(attention / GEMM):保留 fallback,让 Inductor 当黑盒处理
PyTorch 内置 ATen op 都有 lowering,自定义 op 默认没有。写 lowering 是性能极致场景才做的工作 —— FlashAttention 等 SOTA op 已经够快、不需要再 fuse 进周围算子;普通 element-wise op 写 lowering 收益巨大。
理解 lowering 让你看 Inductor 不是”魔法编译器”,是 lowering registry 驱动的代码生成器。每个 op 一行 lowering 让它进入编译路径。
22.6.17 完整 FlashAttention 接入路径
把全章话题合起来看 FlashAttention 这种 SOTA op 怎么完整接入 PyTorch:
graph TB
FA[FlashAttention CUDA kernel]
FA --> CO[custom_op 装饰器<br/>schema: q, k, v -> out]
CO --> Fake[register_fake<br/>shape 推导]
CO --> Auto[register_autograd<br/>反向 = 另一个 FA backward kernel]
CO --> Cuda[register_kernel cuda<br/>调实际 CUDA kernel]
CO --> Cpu[register_kernel cpu<br/>调用 fallback SDPA]
Fake --> Compile[torch.compile 兼容]
Auto --> Backward[autograd Engine 调度反向]
Cuda --> Eager[eager 路径]
style FA fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
style Compile fill:#dcfce7
代码骨架:
@torch.library.custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# 默认实现 (CPU fallback)
return F.scaled_dot_product_attention(q, k, v)
@flash_attention.register_kernel("cuda")
def _(q, k, v):
# 调真实 CUDA kernel (FlashAttention v3)
return _flash_attn_v3.forward(q, k, v)
@flash_attention.register_fake
def _(q, k, v):
return torch.empty_like(q)
def fa_backward(ctx, grad_out):
q, k, v, out, lse = ctx.saved_tensors # lse = log-sum-exp, FA 内部产物
grad_q, grad_k, grad_v = _flash_attn_v3.backward(grad_out, q, k, v, out, lse)
return grad_q, grad_k, grad_v
def fa_setup_context(ctx, inputs, output):
q, k, v = inputs
out, lse = output_with_lse(q, k, v) # 实际场景里 forward 输出 lse
ctx.save_for_backward(q, k, v, out, lse)
flash_attention.register_autograd(fa_backward, setup_context=fa_setup_context)
# 测试
import torch
opcheck(flash_attention, args=(torch.randn(2, 8, 1024, 64, device='cuda'),) * 3)
部署后用法与内置 op 完全一致:
out = torch.ops.mylib.flash_attention(q, k, v)
# 或者 monkey-patch F.scaled_dot_product_attention 让全局透明用 FA
理解这套接入让你看 FlashAttention 不是”独立库”,是经过 PyTorch custom_op 接入的 first-class 算子。所有 PyTorch 用户能像用 mm 一样用它。custom_op 是 PyTorch 生态吸纳新 SOTA 算子的标准接口。
22.6.18 自家 AI 芯片完整接入 PyTorch 的工程
国产 AI 芯片厂商把硬件接进 PyTorch 是几十人月的系统工程。完整路径:
第 1 阶段:基础 backend
# 1. 注册 PrivateUse1 → "npu"
torch.utils.rename_privateuse1_backend("npu")
# 2. 实现 device guard / stream / event 抽象
class NPUStream(...): ...
class NPUEvent(...): ...
# 3. 注册到 PyTorch
torch._C._jit_register_npu_backend(...)
# 4. tensor.npu() 等 method
torch.utils.generate_methods_for_privateuse1_backend()
第 2 阶段:算子实现
# 给最常用的 200-500 个算子各写 NPU impl
# 用 codegen 减少手写代码
@torch.library.impl("aten::add.Tensor", "PrivateUse1")
def npu_add(self, other, alpha=1):
return _npu_runtime.add(self, other, alpha)
@torch.library.impl("aten::mm", "PrivateUse1")
def npu_mm(self, mat2):
return _npu_runtime.gemm(self, mat2)
# ... 几百个算子 ...
第 3 阶段:CommunicationBackend (NCCL 替代)
class NPUCommBackend(ProcessGroup):
def allreduce(self, tensors, opts): ...
def allgather(self, output_tensors, input_tensors, opts): ...
# ... 实现完整 c10d ProcessGroup 接口 ...
torch.distributed.Backend.register_backend("hccl", create_npu_comm)
第 4 阶段:编译栈集成
# 给 torch.compile 注册自家 backend
@torch._dynamo.register_backend
def npu_compiler(fx_graph, example_inputs):
# 调自家编译器把 fx_graph 编译成 NPU binary
return npu_compile(fx_graph, example_inputs)
# 用法
@torch.compile(backend="npu_compiler")
def model(x): ...
第 5 阶段:训练 / 推理生态
- FSDP / DDP 适配(用 hccl backend)
- AMP / bf16 支持
- safetensors / DCP 集成
- profile + Kineto 自家 backend
整个工程量级:
| 阶段 | 工程量 | 说明 |
|---|---|---|
| 基础 backend | 1-2 人月 | 设备 / 流抽象 |
| 算子实现 | 6-12 人月 | 200+ 算子 |
| 通信 backend | 1-2 人月 | 完整 c10d 接口 |
| 编译集成 | 3-6 人月 | 自家 graph compiler |
| 生态适配 | 2-4 人月 | FSDP / AMP / profile |
| 合计 | 15-25 人月 | 一个团队 5 人 3-5 个月 |
torch_npu(华为)、torch_mlu(寒武纪)、torch_xpu(Intel)都走过这条路。开源代码可以 GitHub 看完整例子。custom_op + PrivateUse1 是国内 AI 芯片厂商生态参与 PyTorch 的核心入口,不需要 fork 主仓代码。
理解这条路径让你看 PyTorch 不是 NVIDIA 专属,是真正”硬件中立”的开放生态。
22.6.19 自定义算子的演进时间线
PyTorch 自定义算子 API 的几个关键节点:
| 版本 | 主流 API | 特点 |
|---|---|---|
| v0.4 (2018) | autograd.Function | 简单但与编译栈不兼容 |
| v1.0 (2018) | + torch.utils.cpp_extension | C++ kernel 接入 |
| v1.5 (2020) | + TORCH_LIBRARY C++ 宏 | 注册到 dispatcher |
| v1.10 (2021) | + torch.library.Library Python API | 替代部分 C++ 宏 |
| v1.13 (2022) | + meta tensor / fake 概念 | 编译路径前置 |
| v2.0 (2023) | + register_fake 等 | torch.compile 兼容 |
| v2.4 (2024) | torch.library.custom_op 装饰器 | 现代标配 |
| v2.4 | + opcheck 自动测试 | 合规性检查 |
| v2.6 (2025) | + register_kernel 优化 + lowering 接口完善 | 与 Inductor 深度集成 |
| v2.10 (2025) | + 完整 functorch/vmap 集成 | 全 PyTorch 生态兼容 |
| v2.11 (2026) | API 稳定 | 生产级别成熟 |
整体趋势:
- v1.x:从
autograd.Function(仅 autograd)到TORCH_LIBRARY(完整 dispatcher) - v2.x:从分散 API 收敛到
custom_op装饰器一站式 - v2.4+:与编译栈、量化、distributed 深度集成
理解时间线让你看到自定义 op 不是一开始就这么好用 —— 经过几年迭代才达到”10 行 Python 装饰器”的体验。生产代码用最新 API(custom_op)能省最多事。
22.6.20 常见 bug 排查 cheat sheet
实战写自定义 op 遇到的报错与解法:
| 报错 | 根因 | 解决 |
|---|---|---|
RuntimeError: ... shape mismatch 在 compile 但不在 eager | fake 函数 shape 推导错 | 检查 fake 返回 shape 是否与真实 kernel 一致 |
Expected at most 0 ... got X | schema 字符串与函数签名不匹配 | type hint 改对 / schema 显式 |
mutates_args 错 | functionalize 假设无副作用、kernel 实际 mutate | 加正确 mutates_args=("x",) |
Dynamo Unsupported: ... graph break | 未注册 fake / Dynamo 看不进 op | register_fake 或 allow_in_graph |
gradcheck 失败 | 反向规则数值不对 | 用 finite-diff 一步步验证、或 torch.autograd.functional.jacobian 比对 |
inductor lowering not registered | 没注册 Inductor lowering(fallback 到 eager) | 写 register_lowering 或接受 fallback |
| ABI undefined symbol | C++ 扩展与 PyTorch 版本不匹配 | 重新编译 / 用 JIT load |
vmap rule 没注册 | functorch 不知道怎么批量化 op | register_vmap 或文档声明不支持 |
| autograd 反向时 saved_tensors 是 None | setup_context 没保存 | 在 setup_context 里调 ctx.save_for_backward |
opcheck test_aot_dispatch_dynamic fail | fake 函数没处理 SymInt 输入 | fake 里所有 shape 操作改用 SymInt-friendly API |
把这张表存到内部 wiki,新人写自定义 op 时遇到报错对照查 → 节省至少 3 天试错时间。
22.6.21 export 与自定义算子
torch.export(§12.8.28)把 model 导成 ExportedProgram,给部署用。自定义算子在 export 路径的处理:
@torch.library.custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q, k, v): ...
class MyModel(nn.Module):
def forward(self, x):
q, k, v = split(x)
return torch.ops.mylib.flash_attention(q, k, v)
# 导出
exported = torch.export.export(MyModel(), example_inputs)
# ExportedProgram 内部的 fx graph 含 mylib::flash_attention 节点
print(exported.graph)
ExportedProgram 内部用 op 的完整 fqn(mylib::flash_attention)记录,而不是 inline op body。部署时:
- AOTI:把
mylib::flash_attention编译进 .so,runtime 调原 kernel - ExecuTorch:让 op 走 delegate 到目标硬件
- ONNX:自定义 op 没标准化 → 报错(除非用 onnx custom domain)
为让自定义 op 能 export:
- 必须有
register_fake(export 用 FakeTensor 跑) - schema 要稳定(不能动态加参数)
- 不能有 graph break(complex Python logic)
实战:FlashAttention 等 SOTA op 都已 export-friendly。自家研究算子如果要部署,一开始就按 export 兼容写。v2.x 之后”导得出 vs 导不出”是判断 op 工程级别的关键 metric。
22.6.22 自定义 op 性能调优 flow
写完一个 custom_op 跑通后,通常发现”比预期慢”。调优流程:
flowchart TD
Slow[op 慢]
Slow --> P1[1. profile 看 op 在 trace 里占多少]
P1 --> Q1{是 op 内部慢, 还是 op 外部 dispatch 慢?}
Q1 -->|op 内部| Q2{kernel 是否 launch 多次?}
Q1 -->|dispatch 多| FUSE[让 op 接受更大 input<br/>减少 dispatch 次数]
Q2 -->|是| Bundle[bundle 多次小 launch 成一次大 launch]
Q2 -->|否| Q3{Tensor Core 利用率?}
Q3 -->|低| Align[shape padding 到 16 倍数<br/>+ 用 fp16/bf16]
Q3 -->|高| MB[memory bound<br/>看能否减少 read/write]
Slow --> P2[2. 看 with torch.compile 是否能 fuse]
P2 --> Lower[实现 register_lowering<br/>让 op 进入 fusion]
Slow --> P3[3. 比对竞品 baseline]
P3 --> Algo[换更优算法<br/>FA v2 → v3 → ...]
style P1 fill:#fef3c7
style P2 fill:#dcfce7
style P3 fill:#dbeafe
实战 case(自家写的 fused RMSNorm op):
第 1 轮 profile:op 占总时间 8%,但 RMSNorm 数学上只是 mean + rsqrt + scale → 应该 < 1%。 看 trace:op 内部 launch 4 个 kernel(mean / sqrt / rsqrt / scale)→ 应该 fuse 成 1 个。 修复:手写 Triton kernel 把 4 步合到一个 → 1.5%。
第 2 轮:仍比 NVIDIA TransformerEngine 的 RMSNorm 慢 30%。 profile metric:SM 占用率 70%(对方 95%)。 修复:调 BLOCK_SIZE / num_warps(autotune),找到最优配置 → 性能匹配 TE。
整套调优 1-2 天。关键是 profile 驱动——每步看数据找根因,不靠猜。
22.6.23 multi-level dispatch:算子的多层 fallback
dispatcher(§5.x)按 priority 调用算子:从最具体 device 找到最通用 fallback。custom_op 也参与这套机制。
graph TB
Call[mymul x y]
Call --> D[dispatcher]
D --> D1{x is on CUDA?}
D1 -->|是| K1[找 CUDA impl]
K1 -->|找到| RunCuda[运行 CUDA kernel]
K1 -->|没有| K2[找 CompositeImplicitAutograd]
K2 -->|找到| RunComp[运行默认实现]
K2 -->|没有| Fail[报错: 没注册]
D1 -->|是 PrivateUse1| KP[找 PrivateUse1 impl]
KP -->|找到| RunNpu[运行 NPU kernel]
KP -->|没有| K2
style RunCuda fill:#dcfce7
style RunComp fill:#fef3c7
style RunNpu fill:#dbeafe
priority 顺序(精简):
- AutogradXxx(具体 device):训练时优先
- Xxx(具体 device):CPU / CUDA / MPS / PrivateUse1
- CompositeImplicitAutograd:用其他 op 拼出来的默认实现
- CompositeExplicitAutograd:显式标记的 composite
每层都可以注册自家 impl。fallback 链让自定义 op 在缺失某 device 实现时仍能跑(虽然慢):
@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x, y):
return x * y # 默认 (CompositeImplicitAutograd)
@mymul.register_kernel("cuda")
def _(x, y):
return my_cuda_kernel(x, y) # CUDA fast path
# 没注册 cpu impl?
# CPU input 调 mymul → 找不到 CPU impl → 走默认 (composite) → x * y
CPU 用户能跑(虽然慢),CUDA 用户用快路径。优雅 fallback 让自定义 op 通用。
理解 multi-level dispatch 让你看 PyTorch 的”扩展性”——每个 op 可以为 N 个 device 注册 N 份实现,dispatcher 自动选最快的。这是单一 codebase 支持几十种硬件的工程基础。
22.6.24 SOTA op 接入示例:开源生态中的 5 个典型 case
把全章话题落到具体例子,5 个开源 SOTA op 的接入方式:
1. FlashAttention (Tri Dao)
- 路径:CUDA kernel →
flash_attnPython wrapper →custom_op注册到 PyTorch - 全套:fake / autograd / register_kernel(“cuda”) + (“cpu” fallback)
- v2.4+ PyTorch 内置 SDPA 自动用 FA v2/v3
2. xformers
- 路径:CUDA + Triton kernel → 自家 wrapper → 部分注册成 PyTorch op
- 不全用 custom_op(早于 v2.4 出现),有些走
autograd.Function - v2.x 时代逐步迁到 custom_op
3. Liger Kernel (Linkedin)
- 路径:纯 Triton kernel(fused RMSNorm / GeGLU / RoPE 等)
- 全 Python:
@triton.jit+@torch.library.custom_op - 标杆”Triton + custom_op”现代实践
4. bitsandbytes (8-bit / 4-bit ops)
- 路径:自家 CUDA kernel → C++ extension
- 部分注册成 PyTorch op,部分仍是函数式
- 走 PrivateUse1 / 自定义 dtype 路径
5. Apex (NVIDIA)
- 路径:纯 CUDA kernel + setup.py 编译 .so
- 老一代实践,许多 op 是
autograd.Function - 现代被 PyTorch 内置取代(fused LayerNorm 等已进 mainline)
观察:
- 新项目都用 Triton + custom_op:比 CUDA + setup.py 简单 10x
- 老项目逐步迁移:Apex 等老库的功能逐渐进 PyTorch 主仓
- 企业级 (NVIDIA / Meta / Google)仍写 CUDA kernel:性能极致 + 控制 ABI
理解这些案例让你看到 PyTorch 自定义 op 生态的全貌:研究项目 → Triton + Python,生产 SOTA → CUDA + 完整 custom_op,硬件厂商 → PrivateUse1 完整接入。每条路径有自己的 trade-off。
22.6.25 functorch 高阶变换:grad / jacrev / vmap 组合
functorch(v1.13+ 内置 torch.func)提供”函数变换”:把可微函数变成它的梯度、Jacobian、Hessian 等。custom_op 想被这些 transform 用,需要满足条件:
import torch
from torch.func import grad, jacrev, vmap
@torch.library.custom_op("mylib::squared", mutates_args=())
def squared(x: torch.Tensor) -> torch.Tensor:
return x ** 2
@squared.register_fake
def _(x):
return torch.empty_like(x)
def squared_backward(ctx, grad_out):
x, = ctx.saved_tensors
return 2 * x * grad_out
def squared_setup(ctx, inputs, output):
ctx.save_for_backward(inputs[0])
squared.register_autograd(squared_backward, setup_context=squared_setup)
# 现在能用 functorch transforms
gradient_fn = grad(squared)
print(gradient_fn(torch.tensor(3.0))) # 6.0 = 2 × 3
jacobian_fn = jacrev(squared)
print(jacobian_fn(torch.tensor([1.0, 2.0, 3.0]))) # diag([2, 4, 6])
工作机制:functorch 通过 dispatcher 调 register_autograd 注册的反向规则。只要 register_autograd 正确,所有 functorch transform 自动可用 —— 不需要单独 register_grad / register_jacrev。
特殊情况:
vmap(grad(f))这种组合需要 register_vmap(§22.6.15)- 二阶导数 (
grad(grad(f))) 要求反向函数自身可微 ——register_autograd的 backward 函数里调的 op 都得是可微 op,不能是 detached value - forward-mode AD (
jvp) 需要 register_jvp(实验性 API)
实战:研究项目用 functorch 多,custom_op 写正确反向就够。生产 LLM 训练几乎不用 jacrev / hessian(model 太大算不动),functorch 主要给 second-order optimizer / 物理模拟等场景。
理解 functorch 兼容性让你看 custom_op 的”完整生态接入”含义 —— 不只是 forward + backward,还要支持函数变换。
22.6.26 ABI-stable C++ 扩展:v2.6+ 实验性新路径
§22.6.12 提了 ABI 兼容性是 C++ 扩展的痛点。PyTorch v2.6+ 在 torch.csrc.stable namespace 引入 ABI-stable API:
#include <torch/csrc/stable/library.h>
// 用 stable API 而非内部 ABI
TORCH_LIBRARY(mylib, m) {
m.def("mymul(Tensor x, Tensor y) -> Tensor");
}
// stable API 不暴露内部数据结构
torch::stable::Tensor my_mul_cuda(torch::stable::Tensor x, torch::stable::Tensor y) {
return torch::stable::ops::mul(x, y);
}
TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
m.impl("mymul", my_mul_cuda);
}
保证:
- 跨 minor version 兼容:v2.6 编的 .so 在 v2.7 + 加载 OK
- 不暴露内部 type:仅 stable_tensor / stable_scalar 等
- 限制 API 集合:只能用 stable namespace 里的函数(约 200 个,覆盖常用场景)
代价:
- API 比内部 ABI 受限,复杂操作要回退到 unstable
- 性能略低 1-2%(额外 ABI 转换开销)
- 仍在实验,几个版本可能调整
适用场景:长期维护的开源 PyTorch 扩展(如 FlashAttention、xformers)—— 不用每次 PyTorch 升级都 rebuild。
短命扩展(自家研究 prototype)继续用普通 C++ 扩展即可。理解这条新路径让你看 PyTorch 团队对 “ABI 痛点”的工程响应——把痛点收编进框架本身解决,而不是让用户每家自己处理。
22.6.27 distributed 训练里的 custom op
custom_op 在分布式训练里要注意:
1. collective 算子用 functional API
# 错误: 用老 inplace API
@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group):
out = q @ k.transpose(-2, -1)
dist.all_reduce(out, group=group) # ← inplace, functionalize 会失败
return out @ v
# 正确: 用 functional collectives (§16.7.9)
import torch.distributed._functional_collectives as funcol
@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group):
out = q @ k.transpose(-2, -1)
out = funcol.all_reduce(out, "sum", group) # ← functional, compile 友好
return out @ v
2. process_group 不能直接放 schema
ProcessGroup 不是 tensor,不能作为 op 输入。变通:用 group_name (str) 在 op 内部 lookup:
@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group_name: str):
group = dist.distributed_c10d._resolve_process_group(group_name)
...
3. FSDP-2 / DTensor 协作
DTensor(§18.6.6)有 placement 概念。custom_op 默认不支持 DTensor 输入:
@register_dtensor_dispatch(torch.ops.mylib.ring_attention)
def _(q_dt, k_dt, v_dt, group_name):
# 显式处理 DTensor placement
...
实战:如果 custom_op 要在 FSDP-2 / DTensor 模型里用,必须实现 DTensor dispatch,否则 placement 信息丢失。
4. NCCL communicator caching
custom_op 内部如果调 NCCL,要确保用同一个 communicator(§16.7.5)。lookup 一次后 cache:
_comm_cache = {}
def get_comm(group_name):
if group_name not in _comm_cache:
group = dist.distributed_c10d._resolve_process_group(group_name)
_comm_cache[group_name] = init_nccl_comm(group)
return _comm_cache[group_name]
理解分布式 custom_op 的这些坑让你写”适配多卡”的自定义算子时不会踩雷。生产 LLM 训练里 custom_op 与 FSDP / TP / PP 协作是真实需求(如自家 attention 实现要兼容现有训练栈)。
22.6.28 推理引擎中的 custom_op:vLLM / SGLang 实例
LLM 推理引擎 vLLM / SGLang / TensorRT-LLM 都大量用自定义 op。具体实现观察:
vLLM 的 attention kernel:
# vllm/attention/backends/flash_attn.py
@torch.library.custom_op("vllm::flash_attn_varlen", mutates_args=())
def flash_attn_varlen(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor,
max_seqlen_q: int, max_seqlen_k: int,
) -> torch.Tensor:
return _flash_attn_v3.varlen_forward(...)
@flash_attn_varlen.register_fake
def _(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
return torch.empty_like(q)
vLLM 把所有”长度变化的 attention”包成 custom_op,让 torch.compile 能 capture,配合 piecewise CUDA Graph(§15.6.16)实现高吞吐推理。
SGLang 的 paged attention:
@torch.library.custom_op("sglang::paged_attn", mutates_args=("output",))
def paged_attn(
output: torch.Tensor, # mutates 输出 tensor (KV cache 持续累积)
query: torch.Tensor,
key_cache: torch.Tensor, value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
_sglang_kernel.paged_attention(output, query, key_cache, value_cache,
block_tables, seq_lens)
注意 mutates_args=("output",) 让 output 显式标 inplace。这是推理引擎与训练 op 不同的地方:推理时 KV cache 持续累积、必须 inplace 写入,无法走纯函数路径。
实战经验:
- 推理引擎的 op 不需要 register_autograd:推理无反向,省工作
- 必须 register_fake:CUDA Graph 与 torch.compile 都需要
- mutates_args 要写正确:KV cache mutation 必须显式标
- register_kernel(“cuda”) 调真实 CUDA kernel;CPU fallback 可选
理解推理引擎的 custom_op 用法让你看到 LLM 推理优化与 PyTorch 自定义 op 接口深度耦合。理解这套接口能让你看 vLLM / SGLang 源码不困惑,甚至自己往里加新算子。
22.6.29 算子注册的”产品哲学”
把全章合起来看,custom_op 接口的设计反映了 PyTorch 团队的几个产品决策:
1. “扩展是用户体验的一部分”
老 PyTorch(v1.x)扩展接口分散:autograd.Function、TORCH_LIBRARY、Library.impl()……每条路径覆盖一部分场景。结果:用户写自定义 op 痛苦、社区贡献 PR 质量参差不齐。
v2.4 收敛到 @torch.library.custom_op 一站式接口,把”如何扩展”变成产品的核心 UX。这是 PyTorch 从”研究框架”成熟为”工业级 ML 平台”的标志。
2. “fake / shape inference 是底座”
v2.x 把 fake 函数从可选变成”几乎必填”。这看起来增加了用户负担,实际是强制让所有 op 都能进入编译路径。否则 LLM 时代 torch.compile 会被零散的 op 不兼容拖累。
这条决策背后是产品判断:“未来所有人都会用 torch.compile”。所以提前要求 op 注册时声明 fake,保证生态顺滑迁移。
3. “Triton 取代 CUDA”
v1.x 时代写自定义 op 必经 C++ + CUDA。v2.x 推 Triton 作为首选,让 Python 工程师都能写 GPU kernel。降低门槛后社区贡献的高性能 op(Liger Kernel 等)数量爆增。
4. “PrivateUse1 给硬件中立”
不绑死 NVIDIA。提供完整 backend extension API 让国产 / 第三方芯片厂商接进来。这条决策让 PyTorch 在 NVIDIA 之外的硬件市场(华为、寒武纪、Intel Arc)保持竞争力。
理解这些产品决策让你看自定义 op 接口不只是”技术 API”,是 PyTorch 团队对”开放生态”的具体实现。每条接口设计选择背后都有商业 / 战略考量。
22.6.30 一段实战脚本:从零到生产 op
把全章的步骤合并成一个完整的实战脚本,写一个 fused “GeLU + Linear” op:
import torch
import triton
import triton.language as tl
# 第 1 步: Triton kernel
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8),
],
key=['M', 'N', 'K'],
)
@triton.jit
def fused_gelu_linear_kernel(
x_ptr, w_ptr, b_ptr, out_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# ... 实现 GeLU(X @ W + b) 的 fused kernel ...
# 略 (完整实现 30+ 行)
# 第 2 步: custom_op 包装
@torch.library.custom_op("mylib::fused_gelu_linear", mutates_args=())
def fused_gelu_linear(
x: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
) -> torch.Tensor:
M, K = x.shape
K2, N = w.shape
assert K == K2
out = torch.empty(M, N, device=x.device, dtype=x.dtype)
grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
fused_gelu_linear_kernel[grid](x, w, b, out, M, N, K)
return out
# 第 3 步: register_fake
@fused_gelu_linear.register_fake
def _(x, w, b):
M, K = x.shape
K2, N = w.shape
return torch.empty(M, N, device=x.device, dtype=x.dtype)
# 第 4 步: register_autograd
def fgl_backward(ctx, grad_out):
x, w, b = ctx.saved_tensors
# ... 实现反向 ...
grad_x = grad_out @ w.T * gelu_grad(x @ w + b)
grad_w = x.T @ (grad_out * gelu_grad(x @ w + b))
grad_b = grad_out.sum(0)
return grad_x, grad_w, grad_b
def fgl_setup(ctx, inputs, output):
ctx.save_for_backward(*inputs)
fused_gelu_linear.register_autograd(fgl_backward, setup_context=fgl_setup)
# 第 5 步: opcheck 验证
from torch.library import opcheck
x = torch.randn(64, 256, device='cuda', requires_grad=True)
w = torch.randn(256, 128, device='cuda', requires_grad=True)
b = torch.randn(128, device='cuda', requires_grad=True)
opcheck(fused_gelu_linear, args=(x, w, b)) # 通过
# 第 6 步: 集成到模型
class FusedFFN(torch.nn.Module):
def __init__(self, dim, hidden):
super().__init__()
self.w = torch.nn.Parameter(torch.randn(dim, hidden))
self.b = torch.nn.Parameter(torch.randn(hidden))
def forward(self, x):
return torch.ops.mylib.fused_gelu_linear(x, self.w, self.b)
# 第 7 步: torch.compile 验证
model = FusedFFN(256, 1024).cuda()
compiled = torch.compile(model)
out = compiled(x)
loss = out.sum()
loss.backward() # 反向自动调 fgl_backward, fused 进 inductor graph
整套约 100 行 Python(不含 Triton kernel 实现)。从研究 idea 到生产 op 一周可达:
- Day 1:写 Triton kernel + 跑通 forward
- Day 2:register_fake + register_autograd + opcheck
- Day 3:vmap / Inductor lowering(若需要)
- Day 4-5:性能调优 + autotune + benchmark
- Day 6:集成到模型 + 与 baseline 对比 accuracy
- Day 7:写 unit test + CI 集成
理解这套完整脚本让你看到”自定义 op”在 v2.x 时代不再是几人月的工程,而是一周的开发任务。门槛降低 → 创新加速 —— Triton + custom_op 让大量论文中的新算子能快速进 PyTorch 生态。
22.6.31 自定义 op 的版本兼容性策略
随着 PyTorch / 自家库迭代,自定义 op 的 schema 可能变化。生产代码必须考虑兼容性:
1. schema 演进的安全规则:
| 修改 | 是否 break 兼容 |
|---|---|
| 新增 op | 不 break(旧代码不调用就行) |
| 新增 op 的可选参数(带默认值) | 不 break |
| 新增 op 的必选参数 | break(旧代码不传新参数报错) |
| 重命名 op | break |
| 改 input dtype | break(schema 校验失败) |
| 改 output shape 推导 | 隐性 break(compile 后行为变) |
实战做法:
- 新功能加可选参数:
def my_op(x, y, *, optional_flag: bool = False) -> Tensor - deprecated 老 op,加新 op:保留
mylib::v1_op,新 ckpt 用mylib::v2_op - schema 重大变化:bumping namespace(
mylib::op→mylib_v2::op)
2. 与 PyTorch 版本的兼容:
import torch
if torch.__version__ >= "2.4":
@torch.library.custom_op("mylib::myop", mutates_args=())
def myop(...):
...
else:
# v2.4 之前的 fallback 写法
class MyOp(torch.autograd.Function):
...
或用 try/except 兜底:
try:
from torch.library import custom_op
except ImportError:
# 老 PyTorch 没有这个 API
custom_op = None
3. ckpt 兼容性:
如果 op 是 model 的一部分,model state_dict 没区别(op 的实现不在 state_dict 里)。但用户代码必须能 import 到 op——升级时确保自家 op 库一并升级。
4. 渐进式 deprecation:
import warnings
@torch.library.custom_op("mylib::old_op", mutates_args=())
def old_op(x, y):
warnings.warn(
"mylib::old_op is deprecated, use mylib::new_op instead",
DeprecationWarning, stacklevel=2,
)
return new_op(x, y)
让用户有时间迁移,几个月后正式删除。
理解这些策略让你写自定义 op 时考虑”长期维护”,不是只考虑 v1。生产 op 一旦上线就要支持多年(用户的 ckpt 还在用),向前 / 向后兼容性是必修课。
22.6.32 op 注册的内部数据结构:从 schema 到 dispatcher
把全章话题落到底层数据结构。@torch.library.custom_op("mylib::mymul", ...) 在 PyTorch 内部最终落到几张表:
graph TB
Decorator[custom_op 装饰器]
Decorator --> Lib[Library 对象 mylib<br/>Python 层 wrapper]
Lib --> CppLib[C++ Library<br/>持有 schema list]
CppLib --> Dispatcher[Dispatcher 全局表<br/>OperatorHandle]
Dispatcher --> Schema["schema string<br/>mymul(Tensor, Tensor) -> Tensor"]
Dispatcher --> Kernels{各 dispatch key 实现表}
Kernels --> CPU[CPU: lambda x,y: x*y]
Kernels --> CUDA[CUDA: triton_kernel]
Kernels --> Auto[AutogradCUDA: 自动包 backward]
Kernels --> Fake[Meta/FakeTensor: register_fake fn]
Decorator --> AutogradReg[autograd info<br/>setup_context + backward fn]
AutogradReg --> Auto
style Dispatcher fill:#fef3c7
style Kernels fill:#dcfce7
具体源码位置(v2.x):
- Python wrapper:
torch/library.py:CustomOpDef - C++ Library:
torch/csrc/api/include/torch/library.h:Library - Dispatcher:
aten/src/ATen/core/dispatch/Dispatcher.h:Dispatcher - OperatorHandle:
aten/src/ATen/core/dispatch/OperatorHandle.h
调用 mymul(x, y) 的内部路径:
- Python 调
torch.ops.mylib.mymul(x, y) - C++
OperatorHandle.callBoxed(stack) - Dispatcher 查 dispatch key set(input device + autograd state + …)
- 选最高 priority 的 kernel:典型 AutogradCUDA(如果 input requires_grad + 在 CUDA)
- AutogradCUDA kernel 是 PyTorch 自动生成的 wrapper:调 forward + 注册反向 Node
- forward 调底层 CUDA kernel(用户写的 Triton kernel)
- 反向时 autograd Engine(§8.x)调度 Node、最终调 register_autograd 注册的 backward fn
每一步都用 §5.x dispatcher 章讲过的同一套机制 —— 自定义 op 与内置 op 走完全相同的路径。这就是为什么”扩展与内置无差别”(§22.9 第一条设计启示)。理解这套数据结构让你看 PyTorch 的扩展机制不是黑盒,而是清晰的注册 + 查表系统。
22.7 几条工程经验
1. v2.4+ 用 torch.library.custom_op:替代老 TORCH_LIBRARY 宏 + Library.impl() 等手动调用
2. torch.library.opcheck(my_op, args) 是合规性测试:自动检查 fake / autograd / schema 等是否一致。生产 op 必跑
3. Triton kernel + custom_op 是写新算子的最优组合:性能、灵活性、与 compile 兼容性都好
4. mutates_args= 一定写正确:错了 functionalize 会出问题、torch.compile 编译错代码
5. 不要在 fake 函数里做实际计算:会让 torch.compile / FSDP 内存爆 / 性能崩
6. C++ 扩展跨 PyTorch 版本要重编:libtorch ABI 不保证版本兼容。每升级 PyTorch 重建 .so
7. PrivateUse1 是国产芯片接入路径:注册成新 backend 而非新算子,让所有现有算子都能跑
8. torch._dynamo.allow_in_graph 给某些函数特殊白名单:如果你的代码有 Dynamo 不识别但实际 trace-friendly 的部分,用这个绕过 graph break
9. 推理引擎用的 op 不需要 register_autograd:推理无反向,省一步工作。但 register_fake 仍必须
10. 跨 PyTorch 版本部署用 ABI-stable API(v2.6+):避免每升级 PyTorch 都重新编 .so 的工程税
11. distributed 训练里的 collective 必须用 functional API:torch.distributed._functional_collectives 替代 dist.all_reduce,不然 functionalize 会失败
12. 写 Triton kernel 必加 @triton.autotune:让 BLOCK_SIZE / num_warps 自动搜索,避免手调
22.8 跨书关联
- 第 5 章 dispatcher:自定义 op 注册的底层机制
- 第 6 章 ATen 代码生成:内置 op 是 codegen,自定义 op 是 register —— 两条路殊途同归
- 第 7 章 autograd:
register_autograd与autograd.Function.backward等价语义 - 第 12-14 章 编译器栈:fake 函数让自定义 op 进入编译路径,register_lowering 让 op 真正被 Inductor fuse 而非走 fallback
- 第 16 章 ProcessGroup:分布式训练里 custom_op 与 functional collectives 的协作
- 第 18 章 FSDP-2 / DTensor:DTensor placement 与 custom_op 的 dispatch 协作
- 第 21 章 Profiler:opcheck 与 profile 共同保证 op 正确性 + 性能符合预期
22.9 设计启示
PyTorch 自定义算子接口的核心思想:
第一:让”扩展”与”内置”无差别:自定义 op 一旦注册就和 torch.add 一样工作。所有上层特性(autograd / compile / FSDP)零修改支持
第二:fake 函数是高级特性的入场券:v2.x 之后任何 op 都得能 fake,否则被现代生态边缘化。这条变化看似增加用户负担,实际是 PyTorch 团队对”未来所有 op 都要进编译路径”的产品判断
第三:多种 device 各注册一份 kernel:PrivateUse1 给国产芯片厂商完整的扩展能力,不需要 fork PyTorch 主仓,让硬件中立性成为生态扩展的基础设施
第四:用装饰器替代宏 / Python 替代 C++:现代 API 让”写自定义 op”从需要 C++ + 宏的工程任务,降级到 10 行 Python 装饰器。这种”降低门槛 + 保留性能”的设计思想让 PyTorch 自定义 op 生态空前繁荣
第五:fake / vmap / lowering 是”完整生态接入”的多个维度:每个新维度让 op 与一类 PyTorch 高级特性兼容(compile / functorch / fusion)。理解这种”渐进接入”让你知道 op 想用得上 X 特性需要注册哪个对应 hook
第六:opcheck 把”扩展正确性”自动化:以前自家测 op 行为靠人工写测试,opcheck 自动覆盖 schema/autograd/fake/AOT 多条路径。这种”质量基础设施”的存在让社区能持续贡献高质量 op
22.10 跨章呼应:自定义 op 是这本书的”集大成”
把全章合起来看,自定义 op 几乎需要全书前面所有章节的知识:
| 写自定义 op 时用到 | 对应章节 |
|---|---|
| schema / IValue / ATen | §6 ATen 代码生成 |
| dispatcher 注册 | §5 dispatcher |
| TensorImpl / Storage | §2 Tensor 数据结构 |
| autograd Function / Engine | §7-8 autograd |
| AOTAutograd functionalize | §13 AOT Autograd |
| FakeTensor / register_fake | §5.7 + §13 |
| Inductor lowering / fusion | §14 Inductor |
| torch.compile 协作 | §12-15 编译栈 |
| AMP custom_fwd | §20.5.19 |
| FSDP / DTensor / collective | §16-18 分布式 |
| profile + opcheck | §21 Profiler |
写一个生产级 custom_op = 整本书的综合实践。这就是为什么把它放在最后一章(除 23 章哲学收束外)—— 它是检验前面知识掌握程度的”期末考试”。
新人写自定义 op 卡在哪一步,对应回去复习对应章节。这是本书的内部 cross-reference 网络的最后一环。
下一章是收官章 —— 拆 PyTorch 整体设计哲学与未来演进,把 22 章的内容串成一条主线,看从 Tensor 到 custom_op 这条 trace 上 PyTorch 团队留下了什么共通的设计原则。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。