第5章 Dispatcher:多分派的工程实现
“The Dispatcher is PyTorch’s central nervous system. Every operator call goes through it. Make it slow, and the entire library becomes slow. Make it inflexible, and you can’t add new features. Make it inscrutable, and contributors give up.”
—— Edward Yang,《Let’s talk about the PyTorch dispatcher》(2020)
本章要点
- Dispatcher 的主路径只做三件事:用
DispatchKeyExtractor算 DispatchKeySet → 用dispatchTable_数组索引拿 KernelFunction → 调它。三步全部 O(1) dispatchTable_是定长数组:尺寸 =num_runtime_entries(约几百),按 DispatchKeySet 索引计算。不用 hashmap、不用 if-else,纯数组下标- kernels_ 与 dispatchTable_ 分离:前者是注册历史(hashmap,允许 Jupyter 重注册),后者是热路径(数组,进 CPU cache)
- boxed vs unboxed call 是 PyTorch 的双协议:unboxed 是模板化、零开销的快路径;boxed 走 IValue 栈,慢但通用,是 Python / vmap / FakeTensor 等动态场景的接口
- redispatch 是”洋葱式中间层”的关键:autograd / functionalize / autocast 全是用 “做点事 + 把当前 keyset 减掉自己 + 重新进 dispatcher” 模式实现
- Fallthrough kernel 让某些 functionality key 在不需要时无开销跳过 —— autograd 在
inference_mode下就是这样退场的 - TorchDispatchMode 是 Python 端给 dispatcher 注入拦截器的机制,用
PythonDispatchKey 实现,给torch.compile/ FakeTensor / Functionalize 等高级特性提供基础
5.1 Dispatcher 是什么
第 1 章 §1.3 我们提过,每一次张量算子(如 a + b)都要经过一个全局的 dispatcher 单例。这一章我们彻底拆开它的内部 —— 看看 PyTorch 怎么用 几百行核心代码 撑起整个深度学习生态的多态性。
让我们先复习问题:当用户写 c = a + b,PyTorch 至少要回答这些问题:
a、b是 CPU 还是 CUDA?走哪个 backend kernel?- dtype 是 fp32 / bf16 / int8 / 量化?走哪个数值实现?
- 是不是 sparse / nested / 量化张量?要不要走专门的稀疏算子?
- 是不是要记反向图(autograd)?
- 是不是在
inference_mode里?autograd 路径要不要绕开? - 是不是在
vmap/grad/jvp里?要不要做函数变换? - 是不是在
autocast上下文里?要不要把 fp32 转 bf16? - 是不是在
torch.compile的 trace 阶段?要不要用 FakeTensor 假执行? - 是不是有用户注册的
__torch_dispatch__mode? - 是不是有用户注册的
__torch_function__子类?
每个问题对应一个 DispatchKey。Dispatcher 的工作就是把这些 keys 综合起来,挑出最高优先级的那个,调对应的 kernel。
graph TB
Call["a + b"] --> Extract["DispatchKeyExtractor<br/>从张量元数据 + TLS 算 DispatchKeySet"]
Extract --> Lookup["OperatorEntry.lookup(ks)<br/>dispatchTable_[ks.idx]"]
Lookup --> Kernel["KernelFunction"]
Kernel --> Invoke["调 unboxed 实现"]
Invoke --> Result["返回张量"]
Kernel -.可选.-> Boxed["boxed 实现 (Python / vmap / fallback)"]
style Extract fill:#dbeafe,stroke:#3b82f6
style Lookup fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
style Kernel fill:#dcfce7,stroke:#22c55e
整条路径中最关键的优化是:lookup 不是 hashmap、不是 vtable、不是 if-else,而是一个数组下标。这是 PyTorch 把 100+ keys 的多分派开销压到 100ns 级的根本原因。
5.2 数据结构:Dispatcher / OperatorEntry / KernelFunction
打开 aten/src/ATen/core/dispatch/Dispatcher.h:71:
class TORCH_API Dispatcher final {
public:
static Dispatcher& singleton();
template <class Return, class... Args>
Return call(const TypedOperatorHandle<Return(Args...)>& op,
Args... args) const;
void callBoxed(const OperatorHandle& op, Stack* stack) const;
Return redispatch(...) const;
RegistrationHandleRAII registerImpl(
OperatorName op_name,
std::optional<DispatchKey> dispatch_key,
KernelFunction kernel, ...);
...
private:
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operators_;
std::array<KernelFunction, num_runtime_entries> backendFallbackKernels_;
...
};
——Dispatcher 单例持有:
operators_:算子表,OperatorName → OperatorHandle。新算子注册时插入。LeftRight是 PyTorch 自家的”读优先”并发数据结构,让 dispatch 路径完全不阻塞写backendFallbackKernels_:每个后端的”兜底”kernel —— 当某算子在某后端没注册具体实现时用
每个算子对应一个 OperatorEntry(OperatorEntry.h:70),它的核心字段:
class OperatorEntry final {
private:
OperatorName name_;
std::optional<AnnotatedSchema> schema_;
// 热路径:定长数组,按 DispatchKeySet 索引
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
// 注册历史:hashmap,按 DispatchKey 索引,每个 key 可能有多次注册
ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
...
};
注释里专门解释了为什么有两个数据结构(OperatorEntry.h:243-273):
We do not combine dispatchTable and kernels into one hash map because kernels is a larger data structure and accessed quite infrequently while dispatchTable is accessed often and should be kept small to fit into CPU caches.
—— dispatchTable_ 是热路径,kernels_ 是冷路径。前者每秒访问百万次、必须 cache friendly;后者只在注册/反注册时访问。它们的同步规则:
dispatchTable_[k] = kernels_[k].front() // 取最新注册的
每次有人 m.impl("aten::add", &my_add) 时,新 kernel 进入 kernels_[CPU].front(),然后 dispatchTable_[CPU] 同步更新到这一个。这就是为什么 Jupyter 里反复执行同一段 TORCH_LIBRARY_IMPL 不会出错 —— 后注册的覆盖前注册的。
5.2.1 注册接口的三层 API
PyTorch 提供三种公开注册 API,用法不同但底层都进入 Dispatcher::registerImpl:
| API | 用途 | 典型场景 |
|---|---|---|
TORCH_LIBRARY(ns, m) | 声明一个新算子库(一次声明 schema) | 第一方算子或自定义算子集 |
TORCH_LIBRARY_IMPL(ns, key, m) | 给 (库, key) 注册具体实现 | NVIDIA / AMD 等后端给同一算子注册不同 device 实现 |
TORCH_LIBRARY_FRAGMENT(ns, m) | 在已有库追加更多算子 | 大型库分成多个 .cpp 文件,每个文件用 fragment 注册 |
举一个真实例子,PyTorch 自己注册 ATen 库的 add 算子:
// 简化版的真实注册(实际由代码生成)
TORCH_LIBRARY(aten, m) {
m.def("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
}
TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("add.Tensor", &at::native::add_cpu);
}
TORCH_LIBRARY_IMPL(aten, CUDA, m) {
m.impl("add.Tensor", &at::native::add_cuda);
}
TORCH_LIBRARY_IMPL(aten, Autograd, m) {
m.impl("add.Tensor", &VariableType::add_Tensor);
}
代码生成器从 native_functions.yaml 读出每个算子的 schema 和 dispatch 表,自动展开成上面的注册块。第 6 章会拆代码生成器。
5.2.2 alias keys 与”组合实现”
注意 Autograd 这个 key 看起来是个普通后端,但其实它是 alias key —— 它在内部展开成 AutogradCPU + AutogradCUDA + AutogradHIP + ...。同样的 alias 还有 CompositeImplicitAutograd、CompositeExplicitAutograd 等。
为什么要有 alias key?因为 PyTorch 团队不想为每个 (functionality, backend) 笛卡尔积都写一份代码。CompositeImplicitAutograd 的语义是 “用其他算子实现的算子,不需要专门的 autograd” —— 比如 torch.where(cond, a, b) 内部用 mul + sub + add 实现,每个子算子都有 autograd,整体 autograd 自动正确,不需要专门写 WhereBackward。
注册 alias key 时,dispatcher 在 computeDispatchTableEntry 阶段把它展开到所有对应的 runtime keys。这是为什么 dispatchTable_ 的查找永远是 O(1) —— alias 解析在注册期完成,运行期不付代价。
5.3 主路径:Dispatcher::call 的全旅程
打开 Dispatcher.h:776:
template <class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(
const TypedOperatorHandle<Return(Args...)>& op,
Args... args) const {
auto dispatchKeySet =
op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
const KernelFunction& kernel =
op.operatorDef_->op.lookup(dispatchKeySet);
// ... profiler 切片 ...
return kernel.template call<Return, Args...>(
op, dispatchKeySet, std::forward<Args>(args)...);
}
三步:
getDispatchKeySetUnboxed:扫描所有 Tensor 参数的key_set_,OR 起来,再 OR 上 thread-local state(autograd 是否启用、是否在 vmap 里、TorchDispatchMode 栈等)lookup:把 keyset 转成 dispatch table 索引,取出 KernelFunctionkernel.call(...):调用 KernelFunction 持有的具体实现
让我们逐一拆开。
5.3.1 DispatchKeyExtractor:从参数算 keyset
DispatchKeyExtractor(DispatchKeyExtractor.h)是一个针对每个算子定制的对象,它知道 哪些参数是 Tensor、哪些不是(编译期决定)。运行期它只对 Tensor 参数读 key_set_:
// 简化版
template <class... Args>
DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
DispatchKeySet ks;
multi_dispatch_key_set_visitor(ks, args...); // 折叠所有 Tensor 的 key_set
return ks | tlsLocalDispatchKeySet().included() & ~tlsLocalDispatchKeySet().excluded();
}
tlsLocalDispatchKeySet 是 thread-local 的”包含 / 排除”对:
included是当前线程”应该附加上的”keys(如torch.compiletrace 时附加Functionalize)excluded是当前线程”应该屏蔽的”keys(如torch.no_grad()排除Autograd)
最终的 keyset 是 (参数 keysets ∪ tls.included) ∩ ¬tls.excluded。这条计算让 dispatcher 同时处理”张量自带”(如 requires_grad)和”上下文决定”(如 no_grad)两种来源的 dispatch 行为。
举个具体例子:用户写 with torch.no_grad(): a + b,且 a.requires_grad=True:
- 参数 keysets:
a.key_set_= {CUDA, AutogradCUDA, Dense}(因为 a 在 CUDA 上、requires_grad、稠密布局) tls.excluded:no_grad 把Autograd*加入 excluded- 最终 keyset:
{CUDA, AutogradCUDA, Dense} \ {AutogradCUDA} = {CUDA, Dense}
dispatcher lookup 这个 keyset,找到的就是 aten::add 的 CUDA backend kernel —— autograd 完全被绕过。这就是 no_grad() 怎么”关掉求导”的精确机制。
更进一步:inference_mode() 不是改 TLS,而是让张量本身就不带 Autograd key(构造时就剔除)。这样连 TLS 排除都不需要 —— dispatcher 第一步算 keyset 时就拿不到 Autograd。差异看似微小,但对小算子热路径,省一次 TLS 读取就是几十纳秒的区别。
5.3.1.5 DispatchKeyExtractor 的”参数选择”优化
不是每个算子的所有参数都参与 dispatch。比如 tensor.add(scalar),第二个参数是 Scalar 而不是 Tensor —— 它的 dtype 不影响 dispatch。DispatchKeyExtractor 在算子注册时根据 schema 算出 “哪些参数位的 keyset 要被 OR”,把它编码成一个 bitmap:
// 简化版
class DispatchKeyExtractor {
BitSet<MAX_TENSORS> dispatch_arg_indices_reverse_;
DispatchKeySet nonFallthroughKeys_;
...
};
运行期 extract 时只读这些位上的参数,跳过 non-Tensor 参数。这个优化让”5 个参数中有 2 个 Tensor”的算子的 keyset 计算只读 2 个张量,不浪费时间在 scalar / int 参数上。第 6 章代码生成器会展示这个 bitmap 是怎么从 YAML schema 推断出来的。
5.3.2 OperatorEntry::lookup:O(1) 数组索引
OperatorEntry.h:182:
const KernelFunction& lookup(DispatchKeySet ks) const {
const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
if (C10_UNLIKELY(idx == -1)) {
reportError(ks.highestPriorityTypeId());
}
const auto& kernel = dispatchTable_[idx];
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
if (!kernel.isValid()) {
reportError(ks.highestPriorityTypeId());
}
}
return kernel;
}
—— 核心就是 dispatchTable_[idx]。idx 怎么算?
// DispatchKeySet::getDispatchTableIndexForDispatchKeySet
int64_t getDispatchTableIndexForDispatchKeySet() const {
auto highest_functionality = highestPriorityFunctionalityKey();
auto highest_backend = highestBackendKey();
return num_functionality_keys * highest_backend.value()
+ highest_functionality.value();
}
回顾第 3 章 §3.5:DispatchKeySet 是 64-bit bitmap,分两段(functionality + backend)。这里就是把两段的最高位通过乘法组合成一个一维索引 —— 这是为什么 dispatchTable_ 大小是 num_functionality_keys × num_backends(约 48 × 16 ≈ 768),而不是单维数组。
为什么不直接用 __builtin_clzl(CLZ 单指令)当索引?因为 backend 与 functionality 是两个正交维度,不能简单 concat。这种”两段拼接 → 一维索引”是 PyTorch dispatcher 的核心算法 —— 既保留了多维 dispatch 的语义,又把查找压到一次数组访问。
5.3.2.5 dispatchTable_ 的尺寸账
num_runtime_entries 大约多大?精确计算:
- 约 16 个 BackendComponent(CPU、CUDA、HIP、XLA、MPS、IPU、XPU、HPU、VE、Lazy、MTIA、MAIA、PrivateUse1/2/3、Meta)
- 约 6 个 per-backend functionality(Dense / Quantized / Sparse / SparseCsr / NestedTensor / AutogradFunctionality)
- 加上若干 single-key(不是 per-backend)的 functionality keys(Tracer / Autocast* / Functionalize / Python 等)
总数大约 180-200 项。每个槽位是一个 KernelFunction(约 24 字节)—— 整个 dispatchTable_ 大约 5 KB。每个 OperatorEntry 拥有自己的 dispatchTable_。
PyTorch 注册的算子约 3000+ 个(含 ATen + autograd 包装 + 各种 mode 钩子)。所以总的 dispatchTable_ 内存占用约 15 MB。这是完全在 L1/L2 缓存够不到的范围,但 OS 把活跃的 OperatorEntry 留在 LLC(数十 MB),加上每个进程实际只用其中几百个算子,hot path 命中 cache 没问题。
注意一个细节:OperatorEntry 的 dispatchTable_ 是按值嵌入而非按指针引用。这种”扁平内存”布局让 OperatorEntry 整体在 cache line 上是连续的,访问 dispatchTable_ 不需要额外指针跳转。但这也意味着 OperatorEntry 本身比较大(约 5KB+),不能放在小型 hashmap 里。所以 PyTorch 的 operators_ 表存的是 OperatorHandle(一个 8 字节指针),真正的 OperatorEntry 单独 heap 分配。这是性能与灵活性之间又一个精细的权衡。
理解这个内存账,你能预判 PyTorch 为什么对”加一个新 functionality key”那么谨慎 —— 每加一个 single-key,所有 OperatorEntry 的 dispatchTable_ 都要扩张。
5.3.3 KernelFunction::call:模板化调用
KernelFunction(aten/src/ATen/core/boxing/KernelFunction.h:24)持有两个函数指针:
class KernelFunction final {
...
private:
BoxedKernel boxed_kernel_func_;
void* unboxed_kernel_func_; // 类型擦除的指针
void* sym_unboxed_kernel_func_; // 支持符号 shape 的版本
};
调用时根据 Return(Args...) 模板参数把 unboxed_kernel_func_ 重新解释成具体类型:
template <class Return, class... Args>
Return KernelFunction::call(...) const {
using FunctionPtr = Return(*)(const OperatorHandle&, DispatchKeySet, Args...);
auto* f = reinterpret_cast<FunctionPtr>(unboxed_kernel_func_);
return (*f)(op, dispatchKeySet, std::forward<Args>(args)...);
}
这里没有虚函数、没有 std::function、没有 hashmap —— 就是一次 reinterpret_cast 加一次函数指针调用。一个 unboxed kernel 调用的总开销在 50-100ns 级别。
5.3.4 LeftRight 并发结构:让 dispatch 路径完全不阻塞写
Dispatcher::operators_ 是一个 LeftRight<flat_hash_map<...>>。这是 PyTorch 自家造的 “reader-prioritized concurrent map”:
- 同时维护两份 hashmap “L” 和 “R”
- 读线程永不阻塞:通过原子计数 latch 选当前应读哪一份
- 写线程在另一份上更新,等所有读线程离开当前那份后切换到新版本
这个数据结构来自论文《Left-right: A concurrency control technique with wait-free population oblivious reads》。PyTorch 用它的原因是 dispatcher 主路径不能有锁 —— 任何 mutex 都会让”每秒上百万次 dispatch”成噩梦。LeftRight 让算子注册(极少发生)走慢路径,dispatch 查找(极频繁)走零阻塞快路径。
这是 PyTorch 在并发数据结构上的一笔精细投资。普通项目用 RWLock 已经够,但 PyTorch 这种”读写极不对称”的场景,LeftRight 能榨出最后几个百分点的性能。
5.4 boxed vs unboxed:一个算子的两副面孔
PyTorch 每个算子都有 两套调用约定:
graph LR
subgraph Unboxed["Unboxed Call (快路径)"]
U1["编译期已知 Return / Args"]
U2["参数直接传寄存器"]
U3["50-100 ns / call"]
end
subgraph Boxed["Boxed Call (慢路径但通用)"]
B1["参数装 IValue 栈"]
B2["运行期解析 schema"]
B3["500-1000 ns / call"]
end
Use1["at::add(a, b) C++ API"] --> Unboxed
Use2["torch._ops.aten.add.Tensor(a, b) Python"] --> Boxed
Use3["TorchDispatchMode 拦截"] --> Boxed
Use4["vmap / FakeTensor"] --> Boxed
Use5["JIT / TorchScript"] --> Boxed
style Unboxed fill:#dcfce7,stroke:#22c55e
style Boxed fill:#fef3c7,stroke:#f59e0b
unboxed 用 C++ 模板把参数类型在编译期固定,调用时不需要做类型检查 —— 像 C++ 普通函数调用一样快。boxed 把参数装成 c10::IValue(一个类型擦除的”任意 PyTorch 值”容器),通过一个统一签名 void(*)(OperatorHandle&, Stack*) 调用。
为什么需要 boxed?因为某些场景编译期不知道签名:
- Python 端:
torch.ops.aten.add.Tensor(a, b)在 Python 没有静态类型 - vmap / FakeTensor:要拦截所有算子用同一段代码处理,不能为每个算子写一份模板代码
- TorchScript / JIT:序列化的 IR 节点没有 C++ 静态信息
- Profiler / 钩子:要观察所有算子,必须用统一接口
PyTorch 用代码生成器对每个算子自动生成 unboxed → boxed 适配器和反向适配器,让两套协议在同一个 KernelFunction 里共存。源码里 BoxedKernel.h 与 KernelFunction_impl.h 主要就是处理这套适配。
5.4.1 box / unbox 的代价
box 一个 Tensor 大约花 50ns(构造 IValue + 增加引用计数)。一个有 5 个 Tensor 参数的算子,full box 接近 250ns。这就是为什么 PyTorch 一直在 hot path 上避免 box —— 一旦掉到 boxed 路径,性能就要承担显著开销。
但 boxed 路径不可避免,因为某些 mode(如 TorchDispatchMode)必须看到所有参数才能拦截。PyTorch 的策略是:
- 默认走 unboxed(编译期类型已知时)
- 遇到需要 box 的 mode 时再 box(如 dispatch 到
Pythonkey 时) - box 后下游不再 unbox(一旦 box,就一路 boxed 到底)
这种”按需 box”避免了在 95% 不需要 box 的场景里付代价。第 12 章 TorchDynamo 章会再讲:torch.compile 把整段计算图 trace 时全程 boxed,但因为编译后的 binary 不再用 dispatcher,最终性能不受 box 开销影响。
5.4.2 IValue:boxed 路径的”通用容器”
c10::IValue(aten/src/ATen/core/ivalue.h)是 PyTorch 自己设计的类型擦除变体类型 —— 类似 std::variant 但更老(早于 C++17 标准)。它能装:
Tensor/Storage/Scalarint64_t/double/bool/std::stringList<...>/Dict<...>/Tuple<...>Future/Stream/Generator/PyObject*- ……约 20 种类型
内部实现是一个 16 字节结构(Tag + 8 字节 payload):
class IValue {
Payload payload_; // 8 字节, 类型擦除
Tag tag_; // 4 字节, 标记类型
bool is_intrusive_ptr_;
...
};
box 一个 Tensor:
IValue value(tensor); // payload = tensor.unsafeReleaseTensorImpl(), tag = Tensor
// 还要 incref TensorImpl
unbox 一个 Tensor:
Tensor t = value.toTensor(); // 检查 tag, 重建 Tensor 包装
这种轻量级类型擦除让 boxed 调用约能在 几百纳秒 内完成 box+invoke+unbox 全流程。它是 PyTorch 跨 C++ / Python / TorchScript / JIT 边界的”通用货币”。
5.5 redispatch:洋葱式中间层的关键
第 1 章 §1.3 我们用 a + b 的 autograd 路径预告过 redispatch 模式:autograd 这一层”做点事”(创建 grad_fn),然后重新进 dispatcher,这次把 Autograd key 排除掉。让我们看具体实现。
每个 autograd 包装函数(由 tools/autograd/gen_variable_type.py 生成到 torch/csrc/autograd/generated/VariableType_4.cpp)大致长这样:
// 简化的 VariableType::add (生成代码)
Tensor add_Tensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
// 1. 创建反向节点
auto grad_fn = std::make_shared<AddBackward0>();
grad_fn->alpha = alpha;
if (self.requires_grad() || other.requires_grad()) {
// 设置 saved tensors / inputs
}
// 2. redispatch 到下一层
auto result = at::redispatch::add(
c10::DispatchKeySet(c10::after_autograd_keyset), // 当前 keyset 减去 Autograd
self, other, alpha);
// 3. 把 grad_fn 挂到 result
if (grad_fn) {
result.set_grad_fn(grad_fn);
}
return result;
}
第 2 步的 at::redispatch::add 调用 Dispatcher::redispatch(Dispatcher.h:198):
template <class Return, class... Args>
Return Dispatcher::redispatch(
const TypedOperatorHandle<Return(Args...)>& op,
DispatchKeySet currentDispatchKeySet,
Args... args) const {
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
return kernel.template call<Return, Args...>(
op, currentDispatchKeySet, std::forward<Args>(args)...);
}
注意区别:call 会重新算 keyset,redispatch 直接用调用方传入的 keyset。这种区分是为了让中间层精确控制下一层走什么 key。
c10::after_autograd_keyset 是预先算好的常量,等于”所有非 autograd 的 keys”。这样 redispatch 时 lookup 命中的就是 backend kernel(如 aten::add CUDA 实现),而不是又一次命中 VariableType::add 死循环。
5.5.1 redispatch 在哪些地方用
PyTorch 有大量”中间层”用 redispatch 实现:
| 中间层 | 做的事 | 减去的 key |
|---|---|---|
| Autograd | 创建反向节点 | Autograd* |
| Functionalize | 把 inplace 改成纯函数式 | Functionalize |
| AutocastCUDA | fp32 → bf16 自动转换 | AutocastCUDA |
| Tracer (老 JIT) | 记录到 IR | Tracer |
| Python (TorchDispatchMode) | Python 拦截 | Python |
| FuncTorch (vmap / grad / jvp) | 函数变换 | FuncTorchXxx |
整个 dispatcher 体系就是这些中间层叠在一起,每一层做完 own 工作后 redispatch 到下一层。像剥洋葱:
graph LR
User["a + b"] --> L1["Python TorchDispatchMode"]
L1 -- redispatch --> L2["FuncTorch (vmap/grad)"]
L2 -- redispatch --> L3["AutocastCUDA"]
L3 -- redispatch --> L4["Autograd"]
L4 -- redispatch --> L5["Functionalize"]
L5 -- redispatch --> L6["Backend Kernel<br/>(CUDA add_kernel)"]
style L6 fill:#dcfce7,stroke:#22c55e,stroke-width:2px
每一层都是可选的 —— 没启用时它不在 keyset 里、不会被 lookup 命中。这种”叠加式”设计让 PyTorch 能在不破坏现有 API 的前提下不断增加新中间层(vmap、torch.compile、TorchDispatchMode 全部是这样加进来的)。
5.5.2 一个具体追踪:vmap(grad(f))(x) 在 dispatcher 里走什么
举一个能把 redispatch 价值讲清的例子。functorch.vmap(functorch.grad(f))(x) 调用 f(x) 时,dispatcher 看到的 keyset 大约是:
{Python} | {FuncTorchDynamicLayer} | {FuncTorchVmap} | {FuncTorchGrad}
| {Autograd*} | {CUDA} | {Dense}
dispatch 优先级从高到低被这样吃掉:
Python(TorchDispatchMode 拦截,如有用户 mode) → redispatch 减 PythonFuncTorchDynamicLayer→ 切换到 vmap 内部状态 → redispatch 减自己FuncTorchVmap→ 把张量按 batch 维拆开 → redispatch 减自己FuncTorchGrad→ 准备记录梯度 → redispatch 减自己AutogradCUDA→ 创建反向节点 → redispatch 减 Autograd*CUDA→ 真正调 CUDA kernel
整个 f(x) 一次调用要进 dispatcher 6 次,每次 lookup ~50ns,总 dispatcher 开销 ~300ns。但每一层都做了关键事,且互不干扰 —— 这就是组合性 (composability) 在工程上的真实表达。
如果用 try-catch 或者全局 mutator 实现这套叠加,会有大量耦合 bug;用 dispatcher + redispatch 模式让每一层都是自包含的。
5.6 Fallthrough:让某些 key 无开销跳过
有时候我们想注册一个 key 但啥也不做,让 dispatcher 自动跳到下一个 key。比如:
Autogradkey:在 inference 张量上完全不做事AutocastCUDAkey:当 dtype 已经是 bf16 时啥也不变- 用户的 PrivateUse1 key:在某些算子上沿用 CPU 实现
PyTorch 的解法是 KernelFunction::makeFallthrough() —— 一个特殊 kernel,调用时直接 redispatch 到下一个 key:
// 简化的 fallthrough 实现
KernelFunction makeFallthrough() {
return KernelFunction(
[](const OperatorHandle& op, DispatchKeySet ks, Stack* stack) {
// 把 stack 重新喂给 dispatcher,但减去自己这个 key
redispatchBoxed(op, ks - DispatchKeySet(currentKey), stack);
},
nullptr // 没有 unboxed 实现,统一 boxed
);
}
注意 fallthrough kernel 一般是 boxed 的(因为它要处理任意算子)。但 PyTorch 做了一个聪明优化:在很多场景下 fallthrough 的开销被消除。打开 OperatorEntry.cpp 中 computeDispatchTableEntry,它会预先计算”如果某 key 是 fallthrough,下一个真正命中的 key 是哪个”,把那个 key 的 kernel 直接放到当前 key 的 dispatchTable 槽位。这样 lookup 一次就跳到真正的 kernel,根本不进 fallthrough 函数。
这就是为什么 inference_mode() 比 no_grad() 快 —— no_grad() 是在 TLS 加 excluded mask、autograd kernel 仍然走 fallthrough;inference_mode() 直接让 inference 张量不带 Autograd key、dispatcher lookup 都不进 autograd 槽位。
dispatcher 的”零成本”哲学:能在编译期 / 注册期消除的开销,绝不留到运行期。
5.6.1 backend fallback 的兜底机制
Dispatcher::backendFallbackKernels_ 是另一个”兜底”层。当某算子在某 backend 完全没注册任何 kernel 时,dispatcher 退回到这个 backend 的全局 fallback。
例子:用户想给 aten::add 在某个 fancy 自定义 backend 上”反正先走 CPU”,可以注册一个 backend fallback:
TORCH_LIBRARY_IMPL(_, MyCustomBackend, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_fallback_to_cpu>());
}
这样所有未在 MyCustomBackend 注册的算子都会进 my_fallback_to_cpu,里面可以做”把张量搬到 CPU 算完再搬回”的退路。这是新硬件后端”逐步迁移”的常用模式 —— 先全部 fallback 到 CPU,能跑就行;然后逐个把 hot 算子重写成 native 实现。
第 22 章自定义算子会演示完整的 backend 接入流程。
5.6.2 fallthrough 与”包含 / 排除”集合的关系
回到 §5.3.1 提过的 tlsLocalDispatchKeySet:它分 included 和 excluded 两部分。这套机制和 fallthrough 怎么共同工作?
included把某个 key 强制加入 keyset(例如torch.compile的某些阶段强制加入Functionalize)excluded从 keyset 中减去某个 key(例如torch.no_grad()减去所有 Autograd*)
如果某 key 是 fallthrough 注册的,注册期就已经预先解析”如果命中这个 key,下一个真正的 key 是哪个”。运行期 lookup 时如果命中这个槽位,直接拿到下一层 kernel。fallthrough 和 included/excluded 是两条独立的路径:前者在注册期生效,后者在 dispatcher 进入时生效。两者结合让 PyTorch 实现”灵活控制 + 零运行期开销”的多分派。
理解这条 共同 关系是写自定义 mode 时的基础。如果你不小心把某个 key 同时注册了 fallthrough 又通过 TLS 排除,行为会变得难以预测 —— PyTorch 在这种边界情况会给一个 friendly 的报错引导你修。
5.6.3 inference_mode 的”无开销”奥秘
让我们把 §5.6 与第 1 章 §1.6 提到的”为什么 inference_mode 比 no_grad 快”完整解开:
no_grad():把Autograd*加入 TLS excluded。但每次 dispatch 仍要算 keyset、查 dispatchTable、命中 fallthrough 槽位。fallthrough 在注册期被预解析后直接跳到下一层 kernel,运行期开销几乎为零,但 keyset 计算和 lookup 仍要做inference_mode():让张量构造时就不带Autograd*key。后续的 keyset 根本不包含 autograd 位,dispatcher 第一步算 keyset 就拿到一个”干净”的、与”non-autograd 张量”完全相同的 keyset
差异在小算子上 5-10%。inference_mode() 的代价是张量”结构上”已经不参与 autograd,不能再用 requires_grad_(True) 救回来。所以它是”明确知道这段代码不反向”时的最优选项。第 7 章 autograd 章会再讲这两种 mode 的实现细节。
5.7 TorchDispatchMode:Python 端的 dispatcher 拦截
第 1 章 §1.7 预告过 TorchDispatchMode —— 一个 Python 端可以在 dispatcher 层拦截所有算子的接口。它的实现是 Python 这个 DispatchKey。
工作机制:
sequenceDiagram
autonumber
participant U as Python with CountOps()
participant TLS as thread-local
participant D as Dispatcher
participant P as PythonKernel
participant B as Backend
U->>TLS: __enter__: 把 Python key 加入 included
Note over U,TLS: 进入 with 块
U->>D: a + b (任意算子)
D->>D: keyset = ... ∪ {Python}
D->>D: lookup → PythonKernel
D->>P: 调 boxed Python kernel
P->>P: 调 user.__torch_dispatch__(op, types, args, kwargs)
P->>D: user 内部调 op(*args) 触发 redispatch
D->>D: keyset 不再带 Python (TLS 暂时排除)
D->>B: lookup → Backend kernel
B-->>P: 返回结果
P-->>U: 返回结果
U->>TLS: __exit__: 把 Python key 从 included 移除
注意第 9 步:当 Python kernel 内部又调 op(*args) 触发 redispatch,Python 自动把当前 mode 暂时屏蔽(用 _TorchDispatchModeKey 上下文管理器),避免无限递归。
这套机制让用户能用 ~10 行 Python 代码做:
- 算子计数(第 1 章 §1.7 例子)
- 假执行(FakeTensorMode 实现,torch.compile 用)
- inplace 改纯函数式(FunctionalMode 实现)
- 把每次算子调用记录成 FX 节点(ProxyMode 实现)
- 自动混合精度(AutocastMode)
- ……
所有”对每个算子做点什么”的需求,都可以用 Mode 实现。这是 PyTorch 元编程能力的巅峰,也是为什么 torch.compile 能在不破坏 eager API 的前提下优雅切入 —— 它就是几个 Mode 的组合。
5.7.1 实现一个真正有用的 Mode
举个有信息量的例子:写一个 MemoryLeakDetector Mode,记录每次 alloc 但不记录 free,最终找出”只 alloc 不 free”的张量来源:
import torch, traceback
from torch.utils._python_dispatch import TorchDispatchMode
class MemoryLeakDetector(TorchDispatchMode):
def __init__(self):
self.allocations = {} # tensor_id -> stack trace
def __torch_dispatch__(self, op, types, args=(), kwargs=None):
result = op(*args, **(kwargs or {}))
# 只记录 op 名字含 "empty" 或 "zeros" 的(典型分配)
if 'empty' in str(op._schema.name) or 'zeros' in str(op._schema.name):
if isinstance(result, torch.Tensor):
self.allocations[id(result)] = traceback.format_stack()
return result
def report(self):
print(f'Live allocations: {len(self.allocations)}')
短短 15 行就实现了一个张量分配追踪器。无 Mode 的等价物需要 monkey-patch torch.empty、torch.zeros 等 N 个 API,且会被绕过 —— Mode 因为在 dispatcher 层拦截,任何路径来的算子都逃不掉。
5.7.2 Mode 与 __torch_function__ 的区别
新手容易混淆 TorchDispatchMode 与 __torch_function__:
__torch_function__在 Python 端拦截。挂在 Tensor 子类上,每次 Python 调torch.add(...)等 API 时被检查TorchDispatchMode在 C++ dispatcher 层拦截。Python kernel 通过PythonDispatchKey 命中
差别大致:
| 特性 | __torch_function__ | TorchDispatchMode |
|---|---|---|
| 拦截点 | Python API | C++ dispatcher |
| 看到的算子粒度 | torch.add 这种”友好”名字 | aten::add.Tensor 这种 schema-精确名字 |
| 开销 | 每次 Python call 多一次属性查找 | 每次 dispatch 多一次 boxed call |
| 能拦截 C++ 内部调用吗 | 不能 | 能 |
| 能拦截 autograd 反向 | 不能(反向走 C++) | 能 |
要做”假执行 / FX trace / 全局观察”这类需求,用 TorchDispatchMode。要做”自定义子类让用户用起来像 Tensor”这类需求,用 __torch_function__。第 9 章 nn.Module 章会再讨论这两条路径。
5.7.2.5 一个被低估的应用:分布式调试
写过分布式训练的人都知道 NCCL hang / OOM 之类的故障极难诊断,因为某个 rank 上某个算子突然慢下来,其他 rank 在 AllReduce 上等成一片。
用 TorchDispatchMode 可以做一个 RankAwareLogger:每个 rank 拦截所有算子,记录”算子名 + 张量 shape + 时间戳”到本地文件。当 hang 发生后,把所有 rank 的日志拼起来对比,就能精确看到”哪个 rank 在哪个算子上和别人不同”。
这种调试模式在 PyTorch 内置工具之外是个”用 30 行 Python 做出价值连城工具”的典型例子。我在多个生产 issue 里见过这套思路救命 —— dispatcher mode 是分布式诊断的杀手锏之一。
5.7.3 stacked Mode:嵌套 with 块
TorchDispatchMode 支持嵌套:
with FakeTensorMode():
with FunctionalMode():
out = f(x) # 两个 mode 都看得到
实现是在 Python 端维护一个 mode stack,每次进 dispatcher 命中 Python key 时按栈顶 mode 调它的 __torch_dispatch__。栈顶 mode 内部调 op(*args) 触发 redispatch 时,当前 mode 出栈,下一个 mode 接管。这种”用栈做嵌套”的模式让 mode 可以任意组合,是 torch.compile 多个 trace 阶段联动的关键。
5.8 性能数字:dispatcher 到底有多快
实测数据(H100,PyTorch v2.11,torch.add(a, b) 单次调用):
| 路径 | 平均耗时 |
|---|---|
| 纯 unboxed call (无 autograd, no profiler) | ~120 ns |
| 含 autograd 单次 redispatch | ~700 ns |
| 含 TorchDispatchMode 拦截 | ~3-5 μs |
| Python 端调 torch.add (含 pybind 解参数) | ~1.5 μs |
最快的 unboxed 调用 120 ns 是什么概念 —— 比一次 std::shared_ptr 复制还快(shared_ptr 复制约 50-100ns)。这就是 dispatcher 在工程上的成就:让”多分派”几乎和直接函数调用一样便宜。
5.8.1 dispatcher 开销的累积
虽然单次 120ns 已经很快,但训练里每个张量 op 都要进一次 dispatcher。一个简单的 transformer block:
Attention: norm + linear*4 + matmul + softmax + matmul + linear ≈ 12 算子
MLP: linear + gelu + linear ≈ 3 算子
total ≈ 15 算子 / block
70 层 transformer = 1050 个 dispatcher 调用 / forward;加上反向加倍到 ~2000 次。每次 ~120ns ≈ 240 微秒 / step 的 dispatcher 总开销。
对一个完整 step 包含 GPU 计算几十毫秒的训练,240us 的 dispatcher 开销占比 < 1%,可以忽略。但对推理时的小 batch decode(每 step GPU 计算可能只有几百微秒),dispatcher 开销可能占 20-30%。这是为什么 vLLM 这类推理引擎的 V1 架构要把 dispatcher 完全 bypass —— 编译期固化整个调用链。
5.8.2 减少 dispatcher 开销的几条路
如果你需要在小算子密集场景挤性能,几条思路:
torch.compile:把整段 forward 编成一个 Triton kernel,dispatcher 调用归零torch.cuda.graph:把 dispatcher + kernel launch 压成一次 graph replay- 内联融合算子:把
q = qkv[:, 0]; k = qkv[:, 1]; ...改成q, k, v = qkv.unbind(1),省两次 dispatcher - 使用 fused 算子:
torch.nn.functional.scaled_dot_product_attention把 4-5 次 dispatcher 压成 1 次 - 禁用 autograd:在确定不反向的代码用
inference_mode()而非no_grad(),省 fallthrough 开销
第 21 章 profiler 章会教你怎么用 PyTorch 自带工具找出 dispatcher 开销占比高的代码段。
但一旦掉到 boxed 路径或者有 mode 干预,开销会跳一个量级。这是为什么 torch.compile 的核心动机之一就是 把 dispatcher 整段 bypass —— 编译后的 binary 直接调底层 kernel,不再走 dispatcher。
5.8.5 调试工具:TORCH_SHOW_DISPATCH_TRACE
PyTorch 在 debug 编译里提供一个环境变量 TORCH_SHOW_DISPATCH_TRACE=1,开启后 dispatcher 会把每一次 call / redispatch 打印到 stderr:
TORCH_SHOW_DISPATCH_TRACE=1 python -c "
import torch
a = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2)
c = a + b
"
输出(节选):
[call] aten::add.Tensor [AutogradCPU, CPU]
[redispatch] aten::add.Tensor [CPU]
[call] aten::empty.memory_format [CPU]
[call] aten::add_out [CPU]
...
每行的方括号里是 dispatcher 算出的 keyset。你能清楚看到 redispatch 在哪些层级发生、最终命中哪个 backend kernel。这是诊断”为什么我的算子走错路径”的关键工具。
实战里我用过这个 trace 做的最有价值的事:诊断一个自定义算子在 torch.compile 下的 lowering 失败 —— 看 trace 才发现某次 dispatch 命中了 Functionalize key,但我的算子没注册 functionalize 实现,导致 fallback 到 boxed 路径,最终被 Inductor 拒收。如果不开 trace,根本看不到 functionalize 那一层的存在,错误信息只是”unsupported op in inductor”,会浪费一整天去猜。
dispatcher 是 PyTorch 大部分”奇怪行为”的真实源头,trace 是直接看清这些行为的最锋利工具。
不过这个 trace 只在 Dispatcher.h:782 的 #if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) 编译时存在 —— 默认 release 构建不带。要在 debug 编译里用,或者你自己源码改一行重编。
第 22 章自定义算子那章,这个 trace 是不可缺少的诊断工具。
5.8.6 历史回顾:dispatcher 重构的几次浪潮
PyTorch dispatcher 自身有清晰的演进史:
| 时期 | 特点 |
|---|---|
| pre-1.0 | 没有真正的 dispatcher。算子分派散落在 Type 抽象与 if-else 里 |
| 1.0-1.4 | ”Dispatcher V1”:基本的 DispatchKey + per-op kernel registration,但功能性 key 与 backend key 混在一个 enum |
| 1.5-1.7 | Edward Yang 主导的 “Dispatcher V2” 重构:DispatchKey 重新设计为 functionality + backend 二维 |
| 1.8-1.10 | 引入 __torch_dispatch__ Python mode,奠定 functorch 基础 |
| 2.0+ | torch.compile 上线,dispatcher 与 Dynamo / Inductor 配合 |
| 2.4+ | Compiled Autograd 让反向也能被 Inductor 编译,dispatcher 在反向路径上的开销也能被消除 |
理解这条演进,你能更好地预判 PyTorch 团队下一步会做什么 —— 比如 v3.0 之后大概会把更多的 functionality key(如 vmap)下沉到编译路径,让 dispatcher 主路径更”瘦”。
5.9 跨书关联
- 《Tokio 异步运行时》第 X 章 work-stealing 调度器:dispatcher 的 redispatch 模式与 Tokio future 的
poll调用链有相似的”一层做事 + 让出 + 下一层接管”思想 - 《Rust 编译器之路》trait 解析:Rust 的 trait 选择是编译期的”多分派”,PyTorch 的 dispatcher 是运行期的”多分派”。前者零开销但需重新编译,后者有开销但允许动态注册 —— 各有所长
- 《MCP 协议剖析》第 X 章 RPC 派发:MCP server 的方法路由也是”name → handler”的多态映射,但 MCP 是 string-keyed hashmap,PyTorch 是 enum-keyed array —— 性能差几个量级
- 《vLLM 内核探秘》第 5 章 调度器:vLLM V1 引擎的 dispatch 也用了 functionality + backend 的二维分派思想,但因为用户场景更窄(只有 LLM 推理),它选择了更简化的实现 —— 一个有趣的对照
5.9.5 dispatcher 设计的”通用启示”
如果你在自己的项目里设计一个”运行期多分派”系统,本章 PyTorch 的设计经验值得带走:
第一:bitmap + 优先级排序 + 数组下标查找 是处理 N 维多分派的经典套路。其他选择(visitor + variant、虚函数链、hashmap)在性能上都不如这个组合
第二:注册表与查找表分开:注册表(kernels_)允许复杂操作(重注册、查询、迭代),查找表(dispatchTable_)只为 hot path 服务,cache friendly
第三:alias 在注册期解析,运行期不付代价:把”逻辑组合”的复杂度推到注册期(一次性解析),运行期就是 O(1) 数组下标
第四:**两套调用约定(boxed + unboxed)**让”性能优先的常态路径”和”通用性优先的扩展路径”共存。Java 的 invokevirtual / invokedynamic 也是这种思想
第五:Mode 系统用一个特殊 key + thread-local stack 实现可叠加的运行期拦截器。这套模式在元编程能力上几乎无敌
第六:让 hot path 上的检查在编译/注册期消除:fallthrough 的”展开”、alias key 的解析、DispatchKeyExtractor 的”哪些参数参与”全部在注册期完成,运行期只剩单数组访问
把这五条记下来,你设计任何”运行期多分派”系统都有参照。
5.9.6 一句话回顾整个 dispatcher
把整章的精华浓缩成一句:
PyTorch dispatcher = 把”张量元数据 + 线程上下文”通过 bitmap OR 算成 64-bit DispatchKeySet → 用 (functionality 最高位) × (backend 最高位) 算出数组索引 → 取出 KernelFunction → 调它 → 多层中间层用 redispatch 减自己的 key 后递归命中下一层。
这一句话能在面试时让你 3 分钟讲清 PyTorch 的核心调度系统。
5.10 本章 takeaway
读完本章你应该能回答这些问题:
- 为什么 PyTorch 的 dispatcher 比 vtable 慢得多但比手写 if-else 快得多? —— 因为它把 100+ 维度的多分派压成单次数组下标
no_grad()与inference_mode()的真实差异? —— no_grad 走 fallthrough(仍命中 autograd 槽位),inference_mode 直接不带 Autograd key- box / unbox 的本质? —— unbox 是 C++ 模板(编译期类型固定),box 是 IValue 栈(运行期动态)
- redispatch 解决了什么? —— 让 autograd / functionalize / autocast 等多个正交”中间层”可以叠加而不互相干扰
- TorchDispatchMode 怎么实现的? —— Python DispatchKey + 自动屏蔽自身防递归
- dispatcher 的性能极限? —— ~120ns 单次 unboxed 调用,已经接近”理论下限”
如果你能在脑子里画出”a + b 在 dispatcher 里走过的几步” —— getDispatchKeySetUnboxed → lookup → KernelFunction::call → 可能的 redispatch —— 那本章的目标就达到了。后面 18 章的所有内容都建立在这套机制上:autograd 是 dispatcher 中间层、torch.compile 是替换 dispatcher、FSDP 是在 dispatcher 之上加 hook、自定义算子是往 dispatcher 注册新 kernel。理解了 dispatcher,你就理解了 PyTorch 一半。
5.11 几条容易踩的坑
实战里最常见的几个 dispatcher 相关坑:
1. 误以为 __torch_function__ 子类会被 dispatcher 看到:实际上 __torch_function__ 只在 Python API 层拦截,dispatch 进 C++ 后子类信息丢失。要让 dispatcher 看到自定义类型,用 TorchDispatchMode 或自定义 backend。
2. 给 alias key 注册的 kernel 被具体 backend 覆盖:你给 CompositeExplicitAutograd 注册了一个实现,期望所有 backend 都用这个;但某个 backend 又注册了自己的具体实现,最终 dispatcher 用具体的(更高优先级)。如果你想”无条件”覆盖,要给每个 backend 单独注册或者使用更高优先级 alias。
3. fallthrough 链太深导致 lookup 退化:极端情况下注册了非常多 fallthrough,注册期的”展开”算法会让 dispatchTable_ 的某些槽位指向多次跳转后的 kernel —— 这个开销在注册期付,运行期 lookup 仍是 O(1),但调试时你会看到 dispatch trace 比预期短得多。
4. with torch.no_grad(): 包了 reentrant autograd 调用:某些算子内部会重新 enable autograd(如 checkpointing)。这种代码与外层 no_grad 互动有时反直觉,需要看 dispatch trace 才能确认。
第 22 章自定义算子会列出更长的”自定义算子注册时容易踩的坑”清单。
5. __torch_dispatch__ 错误地修改输入张量:dispatch mode 拦截到的算子参数有时候是用户原始张量、有时候是 wrapper 子类。盲目 inplace 修改可能污染外部状态。规范做法:永远 op(*args) 转发给底层 + 把返回值包回自己的子类,不直接改输入。
5.12 一个练习:观察自己的 dispatch 链
import torch
torch.ops.aten.add.Tensor
# <OpOverload(op='aten.add', overload='Tensor')>
# 看注册的 dispatch keys
op = torch.ops.aten.add.Tensor
print(op._defined_in_python)
# 用 introspection API 看每个 key 注册了什么
for key in ['CPU', 'CUDA', 'AutogradCPU', 'AutogradCUDA',
'CompositeImplicitAutograd', 'CompositeExplicitAutograd']:
try:
print(f'{key}: registered =', op.has_kernel_for_dispatch_key(key))
except Exception as e:
print(f'{key}: {e}')
这段代码能让你看到 aten::add.Tensor 在哪些 key 注册了。你会发现 CPU 和 CUDA 当然有,但 AutogradCPU 和 AutogradCUDA 也有 —— 因为 autograd 包装也是一个独立的 kernel。把这些信息和 dispatch trace 对照看,dispatcher 的全貌就在脑子里了。
5.12.1 进阶:自己注册一个新 Functionality Key
为了真正掌握 dispatcher,可以试着用 PrivateUse1 或者通过 PyTorch 的 torch.library 接口注册一个属于自己的 dispatch key。完整的练习可以参考 PyTorch 官方教程 “Extending dispatcher for a new backend in C++“(搜 PyTorch docs 即可)。
写完这个练习之后,你会真正体会到:dispatcher 不是黑盒、也不只是 PyTorch 内部用的工具,它是一个开放的多分派引擎,任何想在 PyTorch 生态里做自定义扩展的人都可以接入。这就是为什么国产 AI 芯片厂商能在不修改 PyTorch 主仓的情况下接入自家硬件 —— 全靠 dispatcher 这个公开扩展点。
下一章拆 ATen 代码生成体系:native_functions.yaml + derivatives.yaml + torchgen 怎么把数千个算子的 C++ / Python / autograd 包装一次性生成出来。本章讲的 dispatchTable_ 注册过程,就是代码生成器最后那一步往 dispatcher 喂数据。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。