第11章 DataLoader 与多进程数据流水线

“If your GPU is idle waiting for data, your DataLoader is the slowest thing in your training. Master it or lose half your throughput.”

—— PyTorch DataLoader 文档(节录改编)

本章要点

  • DataLoader 是 4 件套Dataset(怎么读一条数据)、Sampler(决定读哪些)、collate_fn(怎么把多条合成 batch)、DataLoader(编排整个流水线)
  • num_workers > 0 启动 worker 子进程:每个 worker 跑一个独立 Python 解释器,通过 multiprocessing.Queue 与主进程通信
  • 两条 Queue 协作index_queue(主进程→worker,发送”取第几个 sample”指令)、data_queue(worker→主进程,返回数据)
  • pin_memory=True 启动一个独立的 pin memory 线程:从 worker 收到的 CPU 数据搬到锁页内存,让后续 H2D 拷贝能 async + 跳 page fault
  • persistent_workers=True 让 worker 进程跨 epoch 复用:避免每个 epoch 重启 worker 的几秒~几十秒开销
  • 每个 worker 的 RNG seed = base_seed + worker_id:保证多 worker 之间数据不重复但可复现
  • GPU 训练里 DataLoader 不能成为瓶颈num_workers 应让 worker 池吞吐 ≥ GPU 训练吞吐

11.1 一行 for 循环背后的系统

每个 PyTorch 用户都写过:

loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

for batch in loader:
    out = model(batch.cuda(non_blocking=True))
    loss = compute_loss(out)
    loss.backward()
    optimizer.step()

for batch in loader 这一行,背后是 4 个 worker 子进程 + 1 个 pin_memory 线程 + 2 条 multiprocessing queue 协同工作的流水线。如果 dataset 是从磁盘读图(IO bound)或者要做 augment(CPU bound),这套机器把 数据准备并行化,让 GPU 不用等。

本章拆这套流水线。源码主要在 torch/utils/data/dataloader.py(1709 行)和 torch/utils/data/_utils/

11.2 4 件套数据抽象

PyTorch 数据加载的核心抽象:

graph LR
    Sam[Sampler<br/>决定读哪些索引]
    Ds["Dataset<br/>__getitem__(idx) → 一条样本"]
    Co["collate_fn<br/>list[sample] → batch"]
    Dl[DataLoader<br/>编排所有]

    Sam --> Dl
    Ds --> Dl
    Co --> Dl

    style Dl fill:#fef3c7,stroke:#f59e0b,stroke-width:2px

逐个介绍:

11.2.1 Dataset

class Dataset:
    def __getitem__(self, index): ...
    def __len__(self): ...

Map-style Dataset:实现 __getitem____len__。典型用例:图片数据集(每个 idx 对应一个文件)。

class ImageDataset(Dataset):
    def __init__(self, root, transform):
        self.files = sorted(os.listdir(root))
        self.transform = transform

    def __getitem__(self, idx):
        img = Image.open(self.files[idx])
        return self.transform(img)

    def __len__(self):
        return len(self.files)

Iterable-style DatasetIterableDataset):实现 __iter__。用于流式数据(如 Kafka stream、巨量 webdataset)。它的特殊性是没有 len,长度要么不知道、要么靠 worker 自己处理。

11.2.2 Sampler

Sampler 决定”按什么顺序遍历 dataset”。最常用的几种:

  • SequentialSampler:按顺序 0, 1, 2, …, N-1
  • RandomSampler:每个 epoch 随机打乱
  • WeightedRandomSampler:按权重采样(用于不平衡数据集)
  • BatchSampler:把 sample 索引打包成 batch
  • DistributedSampler:多卡训练时按 rank 分片

DataLoader(shuffle=True) 实际就是用 RandomSampler 包装。DistributedSampler 是 DDP 训练的关键 —— 它保证 rank 0 拿到 [0, N, 2N, ...]、rank 1 拿到 [1, N+1, 2N+1, ...],不重复不遗漏。第 17 章 DDP 章会详细讲。

11.2.3 collate_fn

collate_fn 把 list of samples 合成一个 batch。默认 default_collate 处理常见类型:

  • 张量 list → torch.stack
  • numpy list → torch.as_tensor + stack
  • list of tuple → tuple of stacked tensors
  • list of dict → dict of stacked tensors
  • str / int / float → list(不堆叠,作为 batch metadata)

但变长数据(如 NLP 句子)要写自定义 collate_fn 做 padding:

def my_collate(batch):
    # batch: list of (input_ids, label)
    inputs = [b[0] for b in batch]
    labels = torch.tensor([b[1] for b in batch])
    inputs_padded = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    return inputs_padded, labels

collate_fn 的执行位置很关键 —— 它在 worker 进程里跑(不在主进程)。这意味着 padding / 转 tensor 等 CPU 工作能并行,主进程只接收已经合好的 batch。

11.2.4 DataLoader

DataLoader 是这 4 件套的”总编排器”:

class DataLoader(Generic[_T_co]):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        persistent_workers=False,
        prefetch_factor=None,
        ...
    ):
        ...

源码在 torch/utils/data/dataloader.py:142。它根据 num_workers 选择两套实现:_SingleProcessDataLoaderIter(同步)或 _MultiProcessingDataLoaderIter(异步)。

11.3 单进程模式:最简实现

num_workers=0 时 DataLoader 在主进程跑:

# 简化版
class _SingleProcessDataLoaderIter:
    def __next__(self):
        index = next(self._sampler_iter)        # 取下一个 batch 索引
        data = self._dataset_fetcher.fetch(index)  # 直接调 dataset.__getitem__ + collate
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data

——拿索引 → 读数据 → 合 batch → pin → return。主进程顺序跑,下一个 batch 等当前 batch 训练完才开始读。

适用场景:

  • Dataset 是内存里的张量,读取是 O(1)
  • 调试时想避免多进程导致 stack trace 丢失
  • 单卡 + 小数据集 + GPU 比 CPU 快得多

但绝大多数训练场景下 num_workers=0 会让 GPU 大部分时间空等。这就是多进程模式的存在意义。

11.4 多进程模式:核心架构

num_workers > 0 时 DataLoader 启动一组 worker 子进程,主进程通过 queue 与它们通信:

graph TB
    Main[主进程<br/>训练循环]
    Sam[Sampler]
    PMT[pin_memory thread<br/>主进程内]

    subgraph Workers["worker 子进程 (4 个)"]
        W0[Worker 0<br/>_worker_loop]
        W1[Worker 1<br/>_worker_loop]
        W2[Worker 2<br/>_worker_loop]
        W3[Worker 3<br/>_worker_loop]
    end

    IQ[index_queue<br/>主→worker]
    DQ[data_queue<br/>worker→pin_thread]
    OQ[输出 queue<br/>pin_thread→主]

    Main -->|"取下一个 batch 索引"| Sam
    Sam -->|索引| IQ
    IQ --> W0
    IQ --> W1
    IQ --> W2
    IQ --> W3

    W0 -->|读取 + collate| DQ
    W1 --> DQ
    W2 --> DQ
    W3 --> DQ

    DQ --> PMT
    PMT -->|pin 后| OQ
    OQ -->|"next() 返回"| Main

    style Main fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
    style PMT fill:#dbeafe,stroke:#3b82f6
    style Workers fill:#dcfce7,stroke:#22c55e

工作流:

  1. 启动时主进程 fork N 个 worker,每个跑 _worker_loop
  2. 主进程预先往 index_queue push 几条 batch 索引(prefetch_factor 决定预取多少)
  3. worker pop 索引、调 dataset.__getitem__ 读取、collate、把结果 put 到 data_queue
  4. 主进程从 output queue(如果开 pin_memory,是 pin_thread 的输出 queue)pop 出 batch
  5. 每 pop 一个,主进程 push 一个新索引到 index_queue 维持 pipeline

整套机制让 worker 与训练真正并发 —— GPU 跑 step N 时 worker 在准备 step N+1、N+2 的数据。

11.5 worker 主循环

打开 torch/utils/data/_utils/worker.py:228_worker_loop(精简版):

def _worker_loop(dataset_kind, dataset, index_queue, data_queue, ...):
    # 1. 设置 signal handlers 防止 segfault 进程死锁
    signal_handling._set_worker_signal_handlers()

    # 2. 单线程模式 (avoid CPU thrashing)
    torch.set_num_threads(1)

    # 3. 设置 RNG seed
    seed = base_seed + worker_id
    random.seed(seed)
    torch.manual_seed(seed)
    if HAS_NUMPY:
        np.random.seed(np_seed)

    # 4. 暴露 worker info 给用户回调
    global _worker_info
    _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset)

    # 5. 创建 fetcher (Dataset 包装器, 处理 collate)
    fetcher = _DatasetKind.create_fetcher(...)

    # 6. 主循环
    while watchdog.is_alive():
        r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        if r is None:
            break    # 收到关闭信号
        idx, index = r
        try:
            data = fetcher.fetch(index)        # 读取 + collate
        except Exception:
            data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
        data_queue.put((idx, data))            # 发回主进程

几个细节:

11.5.1 torch.set_num_threads(1) 防 CPU 抢占

每个 worker 进程默认会让 PyTorch C++ 后端开多线程跑算子(如 torch.einsum),但 num_workers=4 加上每个 worker 4 个 thread = 16 thread 抢 CPU 核 → 严重 thrashing。torch.set_num_threads(1) 让每个 worker 单线程,4 个 worker 加起来正好用 4 核。

如果你的 dataset transform 用了大量 PyTorch 算子(如 vision transform),这条限制能避免 30%-50% 性能损失。

11.5.2 RNG seed:reproducible 但不重复

worker 0 的 seed = base_seed + 0、worker 1 = base_seed + 1,etc.。这保证:

  • 不同 worker 的随机性不同(augment 不会重复)
  • 每次运行一致(用同一个 base_seed 重跑能复现)
  • 跨 epoch 不重复:第 7 章我们看到 base_seed 每次 reset 时由 PyTorch 自动从主进程的 RNG 推导

如果你的 dataset 用 numpy.random 而不是 torch.rand,要小心。numpy seed 也由 PyTorch 自动设置(worker.py:262),但用户自定义的全局变量 RNG 不受管。生产代码里建议优先用 torch.rand 与 dataset 内部的 torch.Generator 实例。

11.5.3 ExceptionWrapper:跨进程异常传递

worker 抛异常时不能直接 raise(多进程间异常无法跨进程传播),用 ExceptionWrapper 把异常 + traceback 序列化成普通对象,发回主进程后由主进程 reraise。这种”序列化异常”模式让 worker 的报错在主进程能看到完整 stack trace。

11.6 pin_memory:异步 H2D 的关键

pin_memory=True 启动一个主进程内的线程(不是子进程!),把 worker 返回的 CPU tensor 搬到 pinned (page-locked) memory:

# torch/utils/data/_utils/pin_memory.py:18
def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
    torch.set_num_threads(1)
    torch.accelerator.set_device_index(device_id)

    while not done_event.is_set():
        r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        idx, data = r
        if not isinstance(data, ExceptionWrapper):
            data = pin_memory(data, device)
        out_queue.put((idx, data))

pin_memory(data, device) 递归遍历 data(dict / list / tensor)调 tensor.pin_memory()。pinned memory 的好处:

  • H2D 拷贝可以 async(CPU 不用等 GPU)
  • CUDA driver 直接 DMA,不走 page fault
  • 配合 tensor.cuda(non_blocking=True) 拿到完整异步性能

为什么用线程而不是子进程?因为 pin_memory() 调用要访问 CUDA context —— 子进程要自己 init CUDA 太重。线程在主进程里直接用主进程的 CUDA context,简单高效。

实测:开启 pin_memory=True + tensor.cuda(non_blocking=True) 配合,H2D 阶段通常能完全 overlap 到训练里,节省 5-15% 训练时间。

11.7 persistent_workers:跨 epoch 复用 worker

默认每个 epoch 结束 DataLoader 关闭所有 worker,下个 epoch 重启。重启代价:

  • fork 进程几百毫秒到几秒
  • 每个 worker 重新 torch.manual_seed、加载 transform 模块、初始化 dataset
  • 总 epoch 切换时间可能几秒到几十秒

如果 dataset.init 里要读大量元数据(如 ImageNet 1.28M 图片清单),这个重启成本就尤其高。

persistent_workers=True 让 worker 跨 epoch 不退出,每个 epoch 重发新 sampler 状态即可:

loader = DataLoader(dataset, num_workers=4, persistent_workers=True)
for epoch in range(100):
    for batch in loader:    # 第 2 个 epoch 起 worker 不重启
        ...

代价是 worker 持续占着内存,dataset 里的 cached 数据(如预读到内存的 ImageNet)一直占用 RAM。生产代码里几乎所有训练循环都该开 persistent_workers,除非 worker 占内存极大。

11.8 prefetch_factor:预取深度

prefetch_factor 默认 2,意思是每个 worker 预取 2 个 batch 在 queue 里。总 prefetch = num_workers * prefetch_factor

为什么要预取?因为 worker 跑 dataset.__getitem__ 有一定延迟(如读图、解码、augment)。如果训练消费数据的速度比 worker 快,就会有”GPU 等数据”的 stall。预取让 queue 里始终有几个 batch 待取,吸收 worker 间的速率波动。

调优建议:

  • 默认 2 适合大多数场景
  • IO 抖动大(如从 S3 读)→ 提升到 4-8
  • 内存紧张 → 降到 1,但要确保 worker 数足够多

11.9 BatchSampler:用 batch 索引省 sampler 开销

默认 DataLoader 内部把 Sampler + batch_size 组合成 BatchSampler。BatchSampler 一次返回 batch_size 个索引:

# 默认行为
sampler = RandomSampler(dataset)            # 给单个索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last)  # 给 batch 索引

for indices in batch_sampler:    # indices 是 list of 32 个索引
    samples = [dataset[i] for i in indices]
    batch = collate_fn(samples)

用户可以传自定义 batch_sampler 替代默认。常见用法:动态 batch size(如按序列长度排序后凑满 token 预算的 batch):

class DynamicBatchSampler:
    """按 token 数构造 batch, 长序列少配, 短序列多配, 让每个 batch 总 token 数接近"""
    def __iter__(self):
        ...

这种”按 token 而非按 sample 数”在 NLP 训练里(特别是变长序列)能显著提升 GPU 利用率。

11.9.5 fork vs spawn:跨平台陷阱

multiprocessing 启动子进程有两种模式:

  • fork(Linux 默认):子进程继承父进程内存。Dataset 对象不需要 picklable,速度快
  • spawn(macOS / Windows 默认 from Python 3.8+):子进程从零启动 + pickle 传递所有参数。Dataset 必须 picklable(不能含 lambda / 闭包 / unpicklable 对象)

后果:在 Mac 调通的 DataLoader 在 Linux 上 hang,或者反过来 —— 都是真实碰到的工程坑。

# 强制指定启动方式
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)

CUDA 训练强制要求 spawn。原因:fork 出来的子进程继承父进程的 CUDA context 但无法正常使用,会抛 “Cannot re-initialize CUDA in forked subprocess” 错误。Linux 上跑 GPU 训练时如果你显式 mp.set_start_method('fork') 一定会崩。

PyTorch 2.x 开始 DataLoader 在 CUDA 场景自动选 spawn,避免新手踩坑。但理解这条差异让你看到 _pickling not supported 报错时不会困惑 —— 那是 spawn 在序列化 Dataset 时失败了。

11.9.6 worker shutdown 的精细同步

worker.py 顶部有一段 200+ 行注释 “NOTE [Data Loader Multiprocessing Shutdown Logic]“,描述了关闭流程。简化版:

  1. 主进程退出迭代 → 发 None 到每个 worker 的 index_queue(关闭信号)
  2. 主进程 set done_event
  3. worker 在每次 index_queue.get() 超时(MP_STATUS_CHECK_INTERVAL)时检查 done_event_workers_status
  4. worker 收到 Nonedone_event.is_set() → break 主循环 → 进程退出
  5. 主进程 join 所有 worker(等子进程退出)
  6. 关闭 queue + signal handlers

为什么这么复杂?因为有几种异常退出场景要处理:

  • 主进程被 Ctrl+C 杀掉(worker 必须在父进程死亡时也死,否则成僵尸)
  • worker 自己崩了(main 必须感知到、报错而不是 hang)
  • IterableDataset 提前 StopIteration(必须告诉 main “我没数据了,别再发任务给我”)

signal_handling._set_worker_signal_handlers()(worker.py:253)注册 SIGBUS / SIGSEGV handler,让 worker 崩溃时主进程能立即感知。ManagerWatchdog(line 314)持续 ping 父进程,父进程死亡时 worker 自杀。

如果你写训练代码遇到 “ctrl-c 后 python 进程 hang 不退”,几乎都是 DataLoader shutdown 流程出了问题。persistent_workers=True 引入的复杂度让这套同步逻辑更微妙。

11.9.7 IterableDataset 的 worker 分片

Dataset(map-style)有 __len__,DataLoader 能给每个 worker 分配不同索引。但 IterableDataset 没有 __len__每个 worker 默认会重复跑同一份数据流 —— 4 个 worker = 4 份重复 batch!

正确做法:dataset 内部按 worker_id 分片

class StreamDataset(IterableDataset):
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            files = self.files                    # 单进程, 全部
        else:
            n = worker_info.num_workers
            i = worker_info.id
            files = self.files[i::n]              # worker i 取第 i, i+n, i+2n, ... 份
        for f in files:
            yield from read_file(f)

torch.utils.data.get_worker_info() 在主进程返回 None,在 worker 里返回 WorkerInfo(id, num_workers, seed, dataset)。这是 IterableDataset 多 worker 训练的关键 API。

第 17 章 DDP 章会再叠一层:DistributedSampler 在 rank 间分片 + 上面这段在 worker 间分片,两层分片合起来让 N rank × M worker = N×M 份数据并行加载,互不重复。

11.9.8 DistributedSampler:多卡训练的数据切片

第 17 章 DDP 章会把这玩意儿用到底,这里先讲它的内部机制。torch/utils/data/distributed.py:17DistributedSampler(v2.11 实测):

class DistributedSampler(Sampler):
    def __init__(self, dataset, num_replicas=None, rank=None,
                 shuffle=True, seed=0, drop_last=False):
        ...

    def __iter__(self):
        # 1. 生成全局索引序列
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)    # 每个 epoch 不同 seed
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # 2. padding 让总长度能被 num_replicas 整除
        if not self.drop_last:
            padding_size = self.total_size - len(indices)
            indices += indices[:padding_size]
        else:
            indices = indices[:self.total_size]

        # 3. 切片:rank i 拿 [i, i+N, i+2N, ...]
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)

几个工程细节:

  • 每个 epoch 必须调 sampler.set_epoch(epoch):否则每个 epoch shuffle 顺序一样、训练 loss 曲线会异常。这是 DDP 训练新手常见 bug
  • drop_last vs padding:drop_last 把不够整除的 sample 扔掉、padding 把它们重复一份补齐。padding 让 epoch 内所有 rank batch 数一致(DDP 要求),但有重复样本
  • seed = 用户 seed + epoch:所有 rank 用同一个 base seed → 所有 rank 算出的全局 indices 一致 → 切片后 rank i 拿到第 i, i+N, … 个样本,不重复不遗漏

这套机制让 N 个 rank 的训练等同于单 rank 训练 N×batch_size,loss 曲线、收敛速度都与单卡一致(前提:用 sync BN、scale lr)。是 DDP 训练正确性的基础。

11.9.9 forkserver:fork 与 spawn 之间的折中

§11.9.5 提到 fork(快但 CUDA 不安全)vs spawn(慢但安全)。还有第三种 forkserver

mp.set_start_method('forkserver')

机制:启动一个轻量的 server 进程,每次要 fork 子进程时由 server 来 fork(而不是从主进程)。好处:

  • server 在 CUDA / 大量库 import 之前 fork → 子进程不继承 CUDA context
  • 比 spawn 快(仍然是 fork,不是从零启动)
  • 比 fork 安全(不继承主进程复杂状态)

代价:dataset 仍要 picklable(与 spawn 一样),不能含 lambda。Linux 上用 forkserver 启动 DataLoader 比 spawn 快 2-3x。

实测对比(启动 8 个 worker 的时间,Linux):

启动方式时间
fork50ms
forkserver100ms
spawn600ms

生产 GPU 训练推荐:优先 forkserver,回退 spawn。Mac / Windows 没 forkserver,自动用 spawn。

11.9.10 DataPipes 与 DataLoader2:新一代实验

PyTorch 在 v1.12 引入 DataPipes(位于 torchdata 仓库),思路是把数据加载用函数式管道表达:

from torchdata.datapipes.iter import FileLister, FileOpener, Mapper

dp = FileLister("/data/imagenet")
dp = dp.shuffle(buffer_size=10000)
dp = dp.sharding_filter()                  # 按 worker / rank 分片
dp = FileOpener(dp, mode="b")
dp = dp.map(decode_image)                   # 每条样本应用 decode_image
dp = dp.batch(batch_size=32)
dp = dp.collate()

每个 DataPipe 都是 lazy iterator。组合起来比 class MyDataset(IterableDataset) 更可读、复用度更高。

DataLoader2 是新版加载器(v1.13 实验、目前在 torchdata 维护),把 reading service(如 prefetching、distributed sharding)从 dataset 抽离:

loader = DataLoader2(
    datapipe=dp,
    reading_service=MultiProcessingReadingService(num_workers=4)
)

reading service 负责 worker 池、prefetching;datapipe 负责数据流水。职责分离让用户能更灵活组合(如换成 DistributedReadingService 一行切到 DDP 模式)。

实战现状(2026 年):DataLoader(v1)仍是主流;DataLoader2 + DataPipes 在 torchdata 仓库,社区采用度较低、API 还在迭代。除非你已经熟悉、否则生产代码继续用 v1 即可。理解新一代设计能帮你预判 PyTorch 数据加载未来的走向。

11.9.11 timeout:worker 健康检查

DataLoader(timeout=N)next(iter(loader)) 等 N 秒拿不到数据就抛 RuntimeError,避免训练 hang 在数据加载上。机制:

  • 主进程从 data_queue.get(timeout=N) 等待
  • 超时后检查每个 worker 是否还活着(worker._popen.poll()
  • 如果有 worker 死了 → raise RuntimeError("DataLoader worker (pid xxx) is killed by signal: ...")
  • 如果都活着但仍超时 → raise RuntimeError("DataLoader timed out after {timeout} seconds")

默认 timeout=0 表示无限等待。生产代码建议设个保险值(如 60 秒),避免 NFS 卡死、worker silently 卡住等场景让训练 hang 几小时。

MP_STATUS_CHECK_INTERVAL(worker.py:32,默认 5 秒)是 worker 内部 index_queue.get 的超时 —— worker 每 5 秒检查一次”是否被叫退出”,让 shutdown 能在 < 5 秒内完成。

11.9.12 跨进程 tensor 传输:共享内存绕开 pickle

worker 把 batch 发回主进程,不是 pickle 整个 tensor 再 send。torch/multiprocessing/reductions.py 注册了 tensor 的特殊序列化逻辑:

graph LR
    W[worker 进程<br/>collate 出 batch tensor]
    W --> SH[把 tensor 数据放到<br/>shared memory file]
    W --> META[把 metadata<br/>shape/dtype/storage_handle<br/>送到 queue]
    META --> Q[multiprocessing.Queue]
    Q --> M[主进程拿到 metadata]
    M --> RB[根据 storage_handle 重建<br/>指向同一块 shm 的 tensor]
    SH -.同一块物理内存.-> RB

    style SH fill:#fef3c7,stroke:#f59e0b
    style RB fill:#dbeafe,stroke:#3b82f6

机制:

  1. worker 创建 tensor 时自动用 shared memory storage(torch.multiprocessingset_sharing_strategy('file_system')'file_descriptor'
  2. 传输:只把 storage 的 file handle / fd 通过 queue 送过去(几十字节)
  3. 主进程:用 file handle 打开同一块 shm,重建 tensor —— 零拷贝

为什么不直接 pickle?因为 ImageNet batch 几百 MB,pickle + 序列化 + 反序列化要几百毫秒,远慢于 GPU 计算。共享内存让传输时间降到微秒级。

代价:每次创建的 shm file 都要在数据消费后释放。torch.multiprocessing 用引用计数自动清理;但 Ctrl+C 杀进程偶尔留下”/dev/shm/torch_xxxx” 文件 —— Linux 系统 reboot 才清。生产监控可以加一条 “/dev/shm 占用 > 50%” 告警。

11.9.13 Fetcher:dataset 与 sampler 的中间层

§11.5 主循环里的 fetcher.fetch(index) 是什么?源码 torch/utils/data/_utils/fetch.py

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)


class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset_iter, ...):
        self.dataset_iter = dataset_iter

    def fetch(self, possibly_batched_index):
        # IterableDataset 没有 __getitem__, 只能 next()
        data = []
        for _ in range(len(possibly_batched_index)):
            try:
                data.append(next(self.dataset_iter))
            except StopIteration:
                self.ended = True
                break
        return self.collate_fn(data) if data else None

两种 fetcher 反映 Map vs Iterable 的本质差异:

特性_MapDatasetFetcher_IterableDatasetFetcher
index 来源主进程 sampler 推送dataset 自己迭代
是否能随机访问是(__getitem__否(仅 next
BatchSampler 作用把单个 idx 打成 batch仅决定 batch 大小
StopIteration不会发生数据耗尽时触发

理解 Fetcher 让你看到 DataLoader 一套 worker 主循环代码能同时处理两类 dataset 的工程巧思 —— 通过抽象 fetch 接口屏蔽差异。

11.9.14 Sample 数据类型谱与 default_collate

default_collate 处理 batch 合并的核心逻辑(torch/utils/data/_utils/collate.py:294):

def default_collate(batch):
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    elif elem_type.__module__ == "numpy" and ...:
        return torch.as_tensor(np.stack(batch))
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch                           # 字符串保持原样
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        return [default_collate(samples) for samples in zip(*batch)]
    raise TypeError(default_collate_err_msg_format.format(elem_type))

设计要点:

  • 递归处理嵌套结构{"image": tensor, "label": int, "meta": {"path": str}}{"image": stacked_tensor, "label": int_tensor, "meta": {"path": [strs]}}
  • 保持 namedtuple 类型:自定义 namedtuple 输入仍输出同类型,不退化成 tuple
  • str 不堆叠:字符串作为 batch metadata,保留 list

如果你的 dataset 返回了 default_collate 不认识的类型(如自定义 dataclass),就会触发 TypeError。要么写自定义 collate_fn、要么把数据转成 dict / tuple / tensortorch.utils.data.default_collate 本身可以被 monkey-patch 但不推荐 —— 影响其他代码的行为。

11.9.15 _try_put_index:prefetching 的具体实现

§11.8 讲 prefetch_factor。具体怎么实现?看 dataloader.py_MultiProcessingDataLoaderIter._try_put_index

def _try_put_index(self):
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
    try:
        index = self._next_index()             # sampler 给的下一个 batch idx
    except StopIteration:
        return

    for _ in range(self._num_workers):
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        return    # 所有 worker 都死了, 不再发任务

    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    self._send_idx += 1

关键逻辑:

  • _tasks_outstanding 跟踪有多少任务在 worker 那里 pending。上限 = prefetch_factor * num_workers
  • 轮转分配:用 _worker_queue_idx_cycle 把任务平均分给所有 worker
  • 跳过死掉的 worker:检查 _workers_status[i]
  • 每次 next() 拿到 batch 后立刻 put 一个新 index:维持 outstanding 任务数稳定

启动时主进程会预先 putprefetch_factor * num_workers 个任务。后续每消费一个 batch 就 put 一个新任务,让 queue 始终满 —— 这是 sustained throughput 的关键。

这个机制的极致:prefetch_factor=4, num_workers=8 时主进程预 put 32 个任务,worker 池始终有 32 个待执行 batch 在 pipeline 里。GPU 训练慢一点 worker 就能打满;GPU 训练快了 worker 就成瓶颈 —— outstanding 数是 GPU/CPU 速率匹配的关键参数

11.9.16 Pin memory 的页锁定原理

§11.6 提到 pinned memory 让 H2D 可以 async + 跳 page fault。底层是怎么实现的?

正常 malloc 出的内存是可被换页(pageable):操作系统可能把这块内存换到 swap 区。CUDA driver 做 H2D 拷贝时:

  1. driver 不能直接 DMA 到可换页内存(DMA 控制器不知道虚拟地址,OS 也可能在拷贝时换页)
  2. driver 必须先把这块内存”锁”起来(防止换页)→ 拷贝 → 解锁
  3. 这个”锁 / 解锁” 路径是 同步阻塞,CPU 必须等

tensor.pin_memory() 调用 cudaMallocHost / cudaHostRegister 直接申请永远不会被换页的内存:

  • driver 知道这块内存物理地址固定 → 可以 setup DMA 后直接返回
  • CPU 立刻继续做别的事,DMA 在后台跑
  • 这就是为什么”pin_memory + non_blocking=True”才能 async

代价:pinned memory 是 OS 级稀缺资源(每台机器有上限,几 GB)。开太多 pin_memory 会失败、或拖累系统其他进程。DataLoader 默认只对 batch tensor pin,不对整个 dataset pin。

11.9.17 worker 内存增长:fork CoW 失效问题

Linux fork 用 Copy-on-Write:fork 时父子进程共享物理内存页,只在写入时复制。理论上 worker 启动几乎不占新内存。

有一个坑:CPython 的引用计数会写入对象的 ob_refcnt。worker 即使只是”读”一个 Python 对象(如遍历一个 list),也会修改它的 refcnt → 触发 CoW → 该内存页整个被复制。

实战后果:dataset 里 cache 了大量 Python 对象(如 1.28M 个文件路径 list),worker 启动后内存看起来正常,跑一段时间后 worker 内存爆炸 —— 全部从父进程 CoW 过来了。

解决方法:

  1. 把 cache 数据用 numpy / tensor 存(refcnt 操作不会蔓延到数据本身)
  2. 用 spawn 启动:避免 CoW 误伤(虽然启动慢)
  3. gc.freeze()(Python 3.7+):让 GC 把启动时的对象标记为永生,refcnt 不再变化

PyTorch 文档明确建议大型 dataset 用 numpy 存 metadata、避免 list of dict 这种结构。这是把”操作系统 + Python 实现”的细节落到使用建议的典型例子。

11.9.18 训练 resume:sampler 状态的保存与恢复

长跑训练(几天几周)必须支持中断恢复。光保存 model + optimizer state 不够 —— dataloader 进度也要保存,否则恢复时从头开始遍历 dataset。

PyTorch v2.0+ 给 Sampler / DataLoader 加了 state_dict() / load_state_dict()

# 训练时定期 checkpoint
checkpoint = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
    "sampler": sampler.state_dict(),     # 含已迭代到的位置
    "epoch": current_epoch,
    "step": current_step,
}
torch.save(checkpoint, "ckpt.pt")

# 恢复时
sampler.load_state_dict(checkpoint["sampler"])
loader = DataLoader(dataset, sampler=sampler, ...)
# 从 sampler 中断的位置继续

具体实现:RandomSampler.state_dict 保存 generator 的状态 + 当前 yield 到的 idx。load_state_dict 把 generator 恢复 + 跳过已 yield 的 idx(用 __iter__ 内的 index 计数器)。

更复杂的:多进程 worker 的状态。worker 子进程也在迭代 dataset,恢复时怎么把”worker 0 已经处理到 idx 537”这种信息记下?v2.4 引入的 StatefulDataLoader(在 torchdata)解决这个:每个 worker 周期性把自己的 progress 报告回主进程,主进程汇总成 dataloader 状态。

实战:所有 > 1 小时的训练任务都该实现 dataloader resume。否则每次中断都让训练吐出几小时进度。这是 v2.x 时代生产训练流水线的必备组件。

11.9.19 一次 step 的端到端时序

把所有机制合起来看 一次 for batch in loader: 实际发生了什么

sequenceDiagram
    participant T as 主进程训练循环
    participant Q as output queue
    participant P as pin_memory thread
    participant W as worker 子进程
    participant DS as Dataset

    T->>+Q: next() 拿 batch N
    Q-->>T: 已有 batch N (pin 完成)
    T->>T: model(batch).backward()<br/>optim.step()
    Note over T: GPU 训练 step N 进行中

    par 同时, worker 在准备 batch N+1
        W->>+DS: dataset[idx]<br/>(磁盘读 + decode + augment)
        DS-->>-W: tensor
        W->>P: data_queue.put((idx, tensor))
        P->>P: tensor.pin_memory()
        P->>Q: output_queue.put((idx, pinned))
    and 主进程做 prefetch
        T->>W: index_queue.put(idx N+M)
    end

    T->>Q: 训练完, next() 拿 batch N+1
    Q-->>T: 已有 batch N+1

整套机制让 worker 数据准备 + pin memory + 主进程 GPU 训练 三阶段流水线。理想状态下三个阶段都在不同硬件上并发:worker 用 CPU、pin thread 用 PCIe DMA、训练用 GPU。GPU 利用率能稳定在 95%+。

这张时序图是 §11.10 工程经验的根本依据 —— 任何一个阶段慢了都会让 pipeline stall:

  • worker 慢 → output queue 空 → 主进程 stall
  • pin thread 慢(很少见)→ 数据卡在中间
  • 主进程慢 → output queue 满 → worker 阻塞在 put

诊断瓶颈的标准方法:用 torch.profiler 看主进程的 “data fetching” 耗时占比、用 py-spy 看每个 worker 的 CPU 利用率。两边数据合起来精准定位。

11.9.20 DataLoader × 混合精度训练

第 20 章会讲 mixed precision(FP16/BF16)。DataLoader 这层有什么协作?

直觉做法是 worker 直接返回 fp16 tensor 节省传输带宽(pin 内存量减半)。但实战不推荐

  • transform 在 fp16 下精度损失大(特别是图像归一化)
  • pin → cuda 后再 cast 到 fp16 比 worker 内 cast 更可控
  • worker 跑 transform 时 fp16 的算子反而比 fp32 慢(CPU 没 fp16 加速)

推荐做法:worker 输出 fp32 tensor,主进程在 cuda 后用 tensor.to(torch.float16) 转。前者占数据传输带宽多 2x,但通常被 PCIe 隐藏(H100 PCIe 带宽 64 GB/s,远大于训练实际数据吞吐)。

例外场景:A100 + ImageNet + 大 batch。这种场景 PCIe 真的能成瓶颈,可以让 worker 输出 uint8(不是 fp16),主进程用 cuda kernel 做 normalize + cast 到 fp16。NVIDIA DALI 就用这个方案。

11.9.21 variable batch size 与 max_tokens

NLP 训练常见诉求:按 token 数构造 batch(变长序列)而非按 sample 数。让短句多凑、长句少凑、每 batch 总 token 数稳定。

实现方式:自定义 BatchSampler:

class TokenBudgetSampler(Sampler):
    def __init__(self, lengths, max_tokens=4096):
        self.lengths = lengths              # 每个 sample 的 token 数
        self.max_tokens = max_tokens

    def __iter__(self):
        # 按长度排序后, 滚动凑满 max_tokens
        sorted_idx = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batch = []
        max_len_in_batch = 0
        for idx in sorted_idx:
            new_max = max(max_len_in_batch, self.lengths[idx])
            new_total = new_max * (len(batch) + 1)
            if new_total > self.max_tokens and batch:
                yield batch
                batch = [idx]
                max_len_in_batch = self.lengths[idx]
            else:
                batch.append(idx)
                max_len_in_batch = new_max
        if batch:
            yield batch

这套 sampler 让 GPU 利用率显著提升(相比固定 batch_size,能多 30%+ 吞吐)。fairseq、flash-attention、DeepSpeed 等训练库都内置类似实现。理解 PyTorch 的 BatchSampler 接口让你看到这些上层框架是怎么实现的 —— 没有黑魔法,就是自定义 sampler + collate

11.9.22 常见报错信息逐条拆解

生产 DataLoader 报错谱系(按出现频率):

1. RuntimeError: DataLoader worker (pid xxx) is killed by signal: Killed

worker 被 OOM killer 干掉。通常是 dataset / transform 内存爆。检查:

  • 单个 sample 是否占内存过大(如 4K 图片解码到 numpy)
  • 是否在 __getitem__ 里 cache 了不释放的对象
  • num_workers 是否过多(每个 worker 有独立内存)

2. RuntimeError: DataLoader worker (pid xxx) exited unexpectedly

worker 崩溃但没有信号(如 segfault 但被 catch)。通常是 C 扩展 bug:

  • OpenCV cv2.imread 偶发 segfault(特定图片格式)
  • numba JIT 在 worker 里编译失败
  • HDF5 跨进程读取的线程问题

排查:用 num_workers=0 跑相同代码看是否报同样错;或单独 import 触发 segfault 的库到主进程隔离测试。

3. RuntimeError: Cannot re-initialize CUDA in forked subprocess

§11.9.5 提过:fork 启动 + CUDA 不兼容。解法:mp.set_start_method('spawn')'forkserver'

4. _pickle.PicklingError: Can't pickle <class 'X.<lambda>'>

spawn 启动需要 dataset / transform / sampler picklable。把 lambda 改成命名函数:

# 不行
loader = DataLoader(ds, collate_fn=lambda batch: ...)

# 改为
def my_collate(batch): ...
loader = DataLoader(ds, collate_fn=my_collate)

5. RuntimeError: received 0 items of ancdata

Linux 上多进程 + 大量 file descriptor 共享,触发 ulimit 限制。解法:ulimit -n 65536 提高上限,或 torch.multiprocessing.set_sharing_strategy('file_system') 切到 file system 共享(不用 fd)。

6. Stop iteration 异常吞噬

IterableDataset 在 worker 里抛 StopIteration 后,PyTorch 会确保所有 worker 都耗尽数据再返回 main。如果你的 dataset 内部 __iter__ 提前 return / break 而没 raise StopIteration,可能让 dataloader 永远等不到结束 —— 训练 hang。修复:始终用 yield 而非显式 return。

逐条理解这些错误能让你 5 分钟定位 90% 的 DataLoader 问题。生产代码遇到上面任意一个先按这个清单排查。

11.9.23 N×M 分片:DDP × 多 worker 的合作

第 17 章 DDP 章会展开,这里先给”分片合作”的全景:

graph TB
    DS[全局 dataset<br/>1.28M sample]
    DS --> R0[rank 0 分片<br/>320K sample]
    DS --> R1[rank 1 分片<br/>320K sample]
    DS --> R2[rank 2 分片<br/>320K sample]
    DS --> R3[rank 3 分片<br/>320K sample]

    R0 --> W00[worker 0-0]
    R0 --> W01[worker 0-1]
    R0 --> W02[worker 0-2]
    R0 --> W03[worker 0-3]
    R1 --> W10[...]

    style DS fill:#fef3c7,stroke:#f59e0b

两层分片机制

  • 第一层 DistributedSampler:rank i 拿 [i, i+N, i+2N, ...] 索引(§11.9.8)
  • 第二层 worker_info:rank i 内的 worker j 处理 [j, j+M, j+2M, ...] 索引

合起来:4 rank × 4 worker = 16 个进程并行处理全局 dataset,互不重复。

特殊情况:IterableDataset + DDP。第一层不能直接套 DistributedSampler(IterableDataset 没 __getitem__),要在 dataset 内部按 (rank, num_workers) 双重分片:

class StreamDataset(IterableDataset):
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        rank = dist.get_rank() if dist.is_initialized() else 0
        world = dist.get_world_size() if dist.is_initialized() else 1
        n_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0

        # 全局有 world * n_workers 个数据消费者
        global_id = rank * n_workers + worker_id
        global_n = world * n_workers
        files = self.files[global_id::global_n]
        ...

这是流式数据训练(webdataset、kafka 流)必写的分片模板。理解它让你跑 DDP 时数据不会重复或遗漏。

11.9.24 ManagerWatchdog:父子进程死亡感知

worker.py:30/62ManagerWatchdog 类(Linux/Windows 各一份实现,运行时实例化在 line 314):

class ManagerWatchdog:
    def __init__(self):
        self.manager_pid = os.getppid()
        self.manager_dead = False

    def is_alive(self):
        if not self.manager_dead:
            self.manager_dead = os.getppid() != self.manager_pid
        return not self.manager_dead

机制:worker 启动时记下父进程 PID(manager_pid)。每次 worker 主循环迭代检查 os.getppid(),如果不等于启动时记的 PID → 父进程已死、worker 自己也该退出。

为什么这样可以检测父死亡?Linux 下父进程死后子进程的”父”会变成 init(PID 1)—— getppid() 返回值改变就是父死亡信号。比 signal handler 更鲁棒(信号可能被 mask、可能竞争)。

Linux 版本还用 prctl(PR_SET_PDEATHSIG, SIGKILL):父死亡时内核自动发 SIGKILL 给子进程。两套机制叠加 = worker 不可能成为僵尸。

理解这个让你在 Ctrl+C 杀训练后看到 worker 都干净退出 —— 不是巧合,是工程设计。生产部署训练任务 OOM kill 主进程时,所有 worker 也立刻死,不留泄漏。

11.9.25 DataLoader 演进时间线

PyTorch DataLoader 的迭代轨迹:

版本关键变化
v0.3 (2018)最初的 multi-process DataLoader
v0.4 (2018)pin_memory + num_workers 趋于稳定
v1.0 (2018)API 冻结,开始大规模生产使用
v1.2 (2019)IterableDataset 引入
v1.7 (2020)persistent_workers 加入
v1.10 (2021)prefetch_factor 暴露成参数
v1.12 (2022)DataPipes / DataLoader2 实验性引入(torchdata)
v2.0 (2023)sampler state_dict 支持,配合 long-running training
v2.4 (2024)StatefulDataLoader (torchdata) — 多 worker resume
v2.6 (2025)shared memory 的 file_descriptor 模式优化
v2.11 (2026)API 稳定,生态成熟

整体趋势:

  • v1.x 阶段:把 multi-process 的鲁棒性磨好(错误处理、shutdown、CoW、CUDA 兼容)
  • v2.x 阶段:增强长跑训练(resume、stateful、distributed shard)+ 函数式数据流(DataPipes)

理解这条时间线让你看 PyTorch issue 列表里 “为什么 X feature 直到 v1.10 才加” 不困惑。每个特性都是真实生产场景反馈打磨出来的。

11.9.26 DataLoader vs DALI vs FFCV

NVIDIA DALI 与 FFCV 是 ImageNet 训练里常被提到的”替代品”。它们与 PyTorch DataLoader 的核心区别:

维度PyTorch DataLoaderNVIDIA DALIFFCV
transform 跑在哪CPU(worker 进程)GPU(CUDA kernel)CPU(精心优化的 SIMD)
数据格式任意(用户写 Dataset)jpg / mp4 / TFRecord自家 .beton 格式(预处理)
集成成本中(需重写 pipeline)中(需先转 .beton)
加速幅度基线1.5-3x2-4x
灵活性最高低(需预处理)

何时选谁

  • 多数训练:PyTorch DataLoader(足够好、生态最广)
  • ImageNet / 视频训练 + GPU 富裕:DALI(GPU 解码 + augment 让 CPU 不再是瓶颈)
  • 极致 throughput + 数据格式可控:FFCV(预处理过的二进制让 IO 接近 PCIe 极限)

PyTorch DataLoader 的核心优势是通用。生产 80% 训练用它就够,剩下 20% 极致性能场景才考虑切。理解这个生态位让你不会盲目迁移到 DALI 反而踩坑(DALI 不支持自定义 Python 算子,数据 augment 受限)。

11.9.27 一段总结:DataLoader 是个并行编排引擎

把全章看到的拼起来:DataLoader 表面是”数据加载工具”,骨子里是个并行编排引擎

  • worker pool(进程级并发)
  • pin_memory thread(线程级 + DMA 异步)
  • index queue / output queue(生产消费解耦)
  • sampler 状态机(next index 决定)
  • ManagerWatchdog(错误恢复)
  • ExceptionWrapper(异常跨进程传播)
  • shared memory storage(零拷贝传输)

PyTorch 把这些工程零件封装在一个 1700 行的类里,对用户暴露的就是 for batch in loader。理解每个零件的存在意义、能让你在这个抽象层之外做高级优化(如自己写 streaming pipeline)也不会丢核心思想。

11.9.28 DataLoader 性能 profiling 实战

诊断 DataLoader 是否成为瓶颈的标准方法,按精度从粗到细:

第 1 步:nvidia-smi 看 GPU 利用率

watch -n 0.5 nvidia-smi --query-gpu=utilization.gpu --format=csv

如果数字反复跳 0% / 100% / 0% / 100%(sawtooth pattern)→ GPU 在等数据。

第 2 步:看主进程是不是 CPU 瓶颈

top -p $(pgrep -f train.py | head -1)

主进程 CPU > 90% 说明它自己处理数据成了瓶颈(如做大量 .cuda() 调用)。

第 3 步:torch.profiler 量化 data loading 耗时

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=5),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
) as prof:
    for batch in loader:
        with torch.profiler.record_function("training_step"):
            train_step(batch)
        prof.step()

打开 TensorBoard / Perfetto 看 trace,“Iter#0 next”(拿 batch)耗时与 “training_step” 耗时对比:

  • 拿 batch 耗时 > 训练耗时 50% → DataLoader 瓶颈
  • 拿 batch 耗时 < 5% → DataLoader 不是瓶颈,找别处

第 4 步:py-spy 看 worker 在干啥

py-spy dump --pid <worker_pid>

看 worker stack:

  • 卡在 read() / decode_jpeg() → IO 或 decode 瓶颈
  • 卡在 pickle.loads → 数据格式问题(用了大量 Python 对象)
  • 卡在 numpy 函数 → CPU augment 慢

第 5 步:profile transform

from time import perf_counter

def timed_getitem(self, idx):
    t0 = perf_counter()
    sample = self.original_getitem(idx)
    print(f"getitem took {(perf_counter()-t0)*1000:.1f}ms")
    return sample

逐层定位到具体瓶颈,再针对性优化(换库、加 num_workers、预处理到磁盘等)。

理解这套 profiling 流程让你5 分钟定位瓶颈,避免盲目调 num_workers 凭运气。这是生产 GPU 训练成本控制的关键技能 —— H100 一天租金几百美元,DataLoader 让 GPU 闲一半就是直接烧钱。

11.9.29 transform 库性能基准

ImageNet 训练里一张图的标准 pipeline:“JPEG 解码 → resize → crop → normalize → tensor”。不同库实现 5 步的耗时差距巨大:

单图耗时适用场景
PIL(默认)12 ms兼容好,速度差
PIL-SIMD4 msPIL 的 SIMD 优化版,drop-in 替代
OpenCV (cv2)5 ms比 PIL 快 + 功能更广
Albumentations3 ms用 cv2 + 高度优化 augment 链
NVIDIA DALI1 ms(GPU)GPU 解码 + augment
torchvision.transforms.v24 msv2 用 cv2/PIL 后端,比 v1 快 30%
FFCV< 1 ms数据预处理过的 .beton 格式

升级路径建议(成本递增):

  1. torchvision.transformstorchvision.transforms.v2:零成本,自动加速 30%
  2. PIL → PIL-SIMDpip install pillow-simd 替代 PIL,零代码改动
  3. torchvisionAlbumentations:要重写 transform 链,但 augment 选择更广
  4. CPU augment → DALI:要重写整条 pipeline,但 GPU 训练时收益最大
  5. 运行时 augment → FFCV 预处理:需预先做一次 dataset 转换,长跑训练摊销成本

实战经验:H100 + ImageNet 训练里,默认 torchvision PIL 让 GPU 利用率只有 60%;切到 Albumentations 直接到 92%。这是单点优化收益最大的一个动作 —— 比换 GPU 便宜得多。

11.9.30 远程存储训练:NFS / S3 / 对象存储的特殊问题

云上训练 dataset 常放远程存储(NFS、Lustre、S3、GCS)。DataLoader 这层的注意点:

NFS / Lustre

  • worker 起来时每个进程都打开同一文件会触发”client cache invalidation storm” → 训练前几秒卡顿
  • 解决:用 os.O_DIRECT 跳过 client cache,或预热(训练前用 cat dataset/* > /dev/null 让 server 端加载)
  • 文件元信息查询慢(os.path.getsize 等),把 dataset 元数据预先存成 numpy / parquet 避免逐文件 stat

S3 / GCS / OSS(对象存储)

  • 单个 GET 请求几十到几百 ms 延迟,远高于 顺序读
  • num_workers 必须开大(16-32)+ prefetch_factor 8-16 才能 hide 延迟
  • s3fsaiobotocore 的并发下载,单 batch 内多文件并行 GET
  • 大数据集应该预 shard 成 webdataset / TFRecord 格式(每个文件几百 MB,少 GET 多吞吐)

实战数据(同一份 ImageNet 训练):

存储num_workers=8 吞吐调优后吞吐
本地 NVMe SSD1500 img/s2000 img/s
NFS(本机房)800 img/s1500 img/s
S3(同 region)200 img/s1200 img/s

这表说明:远程存储不是 DataLoader 的死穴,但需要正确配置(多 worker + prefetch + webdataset shard)才能逼近本地性能。

11.9.31 一道真实排查题

最后举一个真实碰到的案例:DDP 训练 Llama,4 机 32 卡,第一个 epoch 只有 50% GPU 利用率,第二个 epoch 起 95%。怎么解释?

诊断思路:第一与后续 epoch 的差异。能想到的原因:

  • persistent_workers=False:每个 epoch worker 重启 → 但用户说开了 persistent,排除
  • OS file cache 预热:第一 epoch 数据全从磁盘读 → 后续从 OS page cache 读。这是最常见原因
  • CUDA kernel 预热:第一次见到的 kernel 形状要 JIT compile / autotune → 第一 batch 慢。但只影响开头,不解释整个 epoch

最终原因:OS file cache 预热。该 dataset 总共 50GB,机器 RAM 256GB,第一 epoch 跑完后整个 dataset 已被 page cache 缓存 → 后续 epoch 全 in-memory 速度。

修复方法:训练前 warmupfind /data -type f -exec cat {} > /dev/null \; 让 OS 提前缓存),或者降低 dataset 大小让缓存命中率高。这是远程 / 大数据集训练的隐藏 quirk —— 看一眼 free -m 的 cached 列就能确认。

理解 DataLoader 的工作机制 + OS 文件系统行为,能让你解释训练曲线的细节、做出正确干预。这是一切性能调优的起点。

11.10 几条工程经验

实战 DataLoader 相关:

1. num_workers 经验值:通常设为 CPU 核心数的 1/2 到 1(如 16 核机器设 8-16)。再多就 thrashing

2. 监控 GPU 利用率:如果训练时 GPU 利用率 < 90%,多半是 DataLoader 跟不上。增加 num_workers 或 prefetch_factor

3. 第一次 epoch 慢:worker 初次启动 + 第一次 sample 触发 page fault。persistent_workers=True 后续 epoch 就快

4. pin_memory + non_blocking=True 必须配对:单开 pin_memory 但 cuda 调用没用 non_blocking,等同于没开

5. transform 应该在 worker 里跑:把 augment 写到 dataset 的 __getitem__ 里,而非主循环里。主循环只做 GPU 计算

6. NumPy in worker 的 seed 坑:fork 启动时 numpy 全局 RNG 在所有 worker 里相同(Python 复制了 main 的状态)。PyTorch 在 worker 启动时会 reseed numpy,但如果你在 worker 里 import 了某个用 np.random 的库,要确保它在 worker 启动之后 import

7. torch.utils.data.get_worker_info():在 dataset 内可以查当前 worker id / num_workers,用于”按 worker 分片”(IterableDataset 必备)

8. 子进程 dataset 状态独立:worker 是 fork 出来的,每个 worker 有独立的 dataset 副本。如果你想在 dataset 里 cache 某些计算结果(如解析 json),每个 worker 各自缓存一份,不共享

11.11 一个典型优化案例:DataLoader 跟不上 GPU

社区里反复出现的最经典 DataLoader 瓶颈案例(PyTorch forum / GitHub issue 反复讨论的问题模式):

症状:H100 训练 ImageNet, GPU 利用率只有 40%
诊断:
  1. 用 nvidia-smi -l 1 看 GPU 利用率波动: 100% → 0% → 100% → 0% (sawtooth)
  2. 这种 sawtooth 是经典 "GPU 等数据" 模式
  3. 看 nvtop 主进程 CPU 利用率几乎 100% (主进程瓶颈)
原因诊断:
  - num_workers=4, 每个 worker 单线程 (set_num_threads(1)) 跑 transform
  - transform 里有耗时的 PIL.Image.open + JPEG 解码
  - 4 个 worker 不够,prefetch 队列经常空
解决方案:
  1. num_workers 从 4 → 16 (机器有 32 核)
  2. 用 PIL-SIMD 替代默认 PIL (JPEG 解码 5x 加速)
  3. transform 改用 Albumentations + cv2 (比 torchvision transforms 快 3x)
结果: GPU 利用率 40% → 92%

这种”逐层定位 + 多管齐下”是 DataLoader 性能调优的标准流程。理解整章讲的”worker 池 + queue + pin_memory thread”机制,能让你精准识别瓶颈所在。

11.12 跨书关联

  • 《vLLM 内核探秘》第 17 章 API 服务器:vLLM 推理时的 batched request 处理与 DataLoader 的 batch 思想一致 —— 把多个独立请求合成一个 GPU 友好的 batch
  • 《Tokio 异步运行时》第 X 章 channel 与 Future:DataLoader 的 multiprocessing.Queue 是同步阻塞 channel,思想与 Tokio 的 mpsc::channel 相通 —— 都是 producer-consumer 模式
  • 《Rust 编译器之路》第 X 章 增量编译:编译器的 dependency graph build 也用 worker pool + queue 处理独立任务

11.13 几条 DataLoader 设计的”通用启示”

把 DataLoader 思想抽象到任何”生产 / 消费数据流水线”系统:

第一生产与消费独立 + queue 解耦:让两者不用互相等待。Queue 长度(prefetch_factor)调节缓冲深度

第二worker pool 而非每次 fork:进程启动有几百毫秒成本,长跑任务必须 pool。persistent_workers 是这条原则的体现

第三异常通过序列化跨进程传递:直接 throw 在多进程里没用,必须 wrap 成可序列化对象。ExceptionWrapper 是模板

第四用 thread 处理需要主进程上下文的工作:pin_memory 用 thread 不用 process,因为它要访问 CUDA context

第五RNG seed 每个 worker 不同 + 可复现base_seed + worker_id 是黄金组合 —— 既保证不重复又保证可复现

第六torch.set_num_threads(1) 防多 worker × 多 thread 抢核:任何 worker pool 都要决定每个 worker 自己开多少 thread。设错了性能崩盘

把这六条记住,写自己的多进程数据流水线(如 ETL 系统、爬虫、批处理任务)能避开几乎所有大坑。

下一章拆 TorchDynamo —— PyTorch 2.0 编译器栈的入口,看它怎么用 PEP 523 帧拦截 + 字节码分析把动态 Python 代码捕获成 FX Graph。

评论 0