Skip to content

第7章 任务调度与并行执行

7.1 引言

上一章我们剖析了 Pregel 执行循环的宏观架构——tick()after_tick() 和 BSP 超步模型。但在每个超步内部,还有一个同样复杂的世界:多个任务如何被并行调度?任务失败时如何重试?缓存如何避免重复计算?PUSH 任务和 PULL 任务在运行时有何不同?

本章将深入 LangGraph 的任务执行层,涉及以下核心组件:

  • PregelExecutableTasktypes.py)—— 可执行任务的数据结构
  • PregelRunnerpregel/_runner.py)—— 任务调度器,管理并行执行和结果收集
  • BackgroundExecutor / AsyncBackgroundExecutorpregel/_executor.py)—— 线程池和 asyncio 并行原语
  • run_with_retry / arun_with_retrypregel/_retry.py)—— 重试逻辑
  • 缓存匹配机制 —— cache_policyCacheKey 的协作

这些组件共同实现了一个高效的并行执行框架,在保证正确性的前提下最大化吞吐量。

本章要点

  1. PregelExecutableTask 是任务执行的最小单元,包含输入、处理器、写入缓冲、配置等全部信息
  2. PregelRunner 通过 FuturesDict 管理并发任务,支持"任一失败则全部停止"的语义
  3. PULL 任务由 Channel 版本变更触发,输入从 Channel 读取;PUSH 任务由 Send API 创建,输入由调用者指定
  4. BackgroundExecutor 使用线程池实现同步并行,AsyncBackgroundExecutor 使用 asyncio 任务实现异步并行
  5. 重试策略支持指数退避、抖动、最大重试次数,以及按异常类型匹配的多策略组合
  6. 缓存策略通过 CacheKey 关联节点身份和输入哈希,支持 TTL 过期

7.2 PregelExecutableTask:任务的全貌

PregelExecutableTask 定义在 types.py 中,是一个不可变的 dataclass:

python
@dataclass(frozen=True)
class PregelExecutableTask:
    name: str                          # 节点名称
    input: Any                         # 任务输入
    proc: Runnable                     # 可执行处理器(bound + writers 的组合)
    writes: deque[tuple[str, Any]]     # 写入缓冲区
    config: RunnableConfig             # 完整的运行配置
    triggers: Sequence[str]            # 触发此任务的 Channel 列表
    retry_policy: Sequence[RetryPolicy] # 重试策略
    cache_key: CacheKey | None         # 缓存键(如果启用了缓存)
    id: str                            # 全局唯一的任务 ID
    path: tuple[str | int | tuple, ...] # 任务路径(用于排序和标识)
    writers: Sequence[Runnable] = ()    # 写入器引用
    subgraphs: Sequence[PregelProtocol] = ()  # 子图引用

虽然标记为 frozen=True(不可变),但 writes 字段是一个 deque——它的引用不可变,但内容可变。这个设计使得任务执行过程中可以向 writes 追加数据,同时防止意外替换整个 writes 对象。

7.2.1 任务 ID 的生成

任务 ID 是通过确定性哈希函数生成的,确保同一个 Checkpoint 状态下,相同的任务总是获得相同的 ID:

python
# 对于 PULL 任务
task_id = task_id_func(
    checkpoint_id_bytes,    # Checkpoint ID 的字节表示
    checkpoint_ns,          # 命名空间(如 "parent|agent")
    str(step),              # 步数
    name,                   # 节点名称
    PULL,                   # 任务类型
    *triggers,              # 触发 Channel
)

# 对于 PUSH 任务(Send API)
task_id = task_id_func(
    checkpoint_id_bytes,
    checkpoint_ns,
    str(step),
    name,
    PUSH,
    task_path_str(parent_path),  # 父任务路径
    str(idx),                    # 在父任务写入中的索引
)

LangGraph 1.1.6 支持两种哈希函数:xxhash(v2 Checkpoint 格式,更快)和 uuid5(v1 格式,兼容旧版)。确定性的 ID 是 Checkpoint 恢复的关键——恢复后重新计算的任务 ID 与保存的 pending writes 中的 task ID 必须匹配,这样 _match_writes 才能正确地将已保存的写入结果关联到重建的任务。

7.2.2 proc 的构成

PregelExecutableTask.proc 是一个 RunnableSeq,它将用户逻辑和写入器串联:

执行 task.proc.invoke(task.input, task.config) 时:

  1. 首先调用用户函数,传入从 Channel 读取的状态
  2. 用户函数返回状态更新(如 {"count": 5}
  3. 第一个 ChannelWrite 将更新转化为 Channel 写入元组,通过 CONFIG_KEY_SEND 发送
  4. 第二个 ChannelWrite(如果有边)将路由信号写入目标节点的触发 Channel

7.2.3 config 中注入的关键函数

每个任务的 config 中注入了几个关键回调,使得任务执行过程中能与 PregelLoop 交互:

python
config = patch_config(
    config,
    configurable={
        CONFIG_KEY_TASK_ID: task_id,
        CONFIG_KEY_SEND: writes.extend,     # 写入收集器
        CONFIG_KEY_READ: partial(            # 状态读取器
            local_read, scratchpad, channels, managed,
            PregelTaskWrites(path, name, writes, triggers),
        ),
        CONFIG_KEY_CHECKPOINTER: checkpointer,
        CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
        CONFIG_KEY_SCRATCHPAD: scratchpad,
        CONFIG_KEY_RUNTIME: runtime,
    },
)
  • CONFIG_KEY_SEND:绑定到 writes.extend——当 ChannelWrite.do_write 被调用时,写入元组被追加到任务的 writes deque。deque.extend 是线程安全的。
  • CONFIG_KEY_READ:绑定到 local_read 函数——条件边通过此函数读取"应用了当前任务写入后"的状态快照。这确保条件判断基于最新状态。

7.3 PULL 任务 vs PUSH 任务

LangGraph 中有两种根本不同的任务触发方式:

PULL 任务

PULL 任务是标准的 BSP 触发方式。在 prepare_single_task 中,对于 (PULL, name) 路径:

python
if task_path[0] == PULL:
    name = task_path[1]
    proc = processes[name]
    # 检查触发条件
    if _triggers(channels, checkpoint["channel_versions"],
                 checkpoint["versions_seen"].get(name),
                 null_version, proc):
        # 读取输入
        val = _proc_input(proc, managed, channels,
                          for_execution=True, ...)
        if val is MISSING:
            return  # Channel 为空,跳过
        # 创建任务
        return PregelExecutableTask(name, val, node, writes, ...)

PULL 任务的输入来自 Channel:_proc_input 根据 proc.channels 配置读取指定的 Channel 值,如果有 mapper 则进行类型转换。

PUSH 任务

PUSH 任务通过两种途径创建:

  1. Send API(prepare_push_task_send:当 __pregel_tasks Topic Channel 中有 Send 对象时
  2. Functional API(prepare_push_task_functional:当任务路径以 Call 对象结尾时

对于 Send API 的 PUSH 任务:

python
if task_path[0] == PUSH:
    # 获取 Send 对象
    send = tasks_channel.get()[task_path[1]]
    name = send.node
    val = send.arg  # 直接使用 Send 的参数作为输入
    proc = processes[name]
    # 创建任务(不检查 _triggers)
    return PregelExecutableTask(name, val, node, writes, ...)

关键区别:PUSH 任务不检查 _triggers——它们总是被执行。输入直接来自 Send.arg,而非从 Channel 读取。这使得同一个节点可以被多次调用,每次使用不同的输入。

7.4 PregelRunner:并行调度器

PregelRunner 定义在 pregel/_runner.py 中,负责在每个超步中并行执行所有任务:

python
class PregelRunner:
    def __init__(self, *, submit, put_writes,
                 use_astream=False, node_finished=None):
        self.submit = submit          # 提交函数(弱引用)
        self.put_writes = put_writes  # 写入保存函数(弱引用)
        self.use_astream = use_astream
        self.node_finished = node_finished

7.4.1 同步 tick 的执行流程

python
def tick(self, tasks, *, reraise=True, timeout=None,
         retry_policy=None, get_waiter=None, schedule_task):
    tasks = tuple(tasks)
    futures = FuturesDict(
        callback=weakref.WeakMethod(self.commit),
        event=threading.Event(),
        future_type=concurrent.futures.Future,
    )
    # 让出控制权给调用者
    yield

    # 快速路径:单任务无超时
    if len(tasks) == 1 and timeout is None and get_waiter is None:
        t = tasks[0]
        try:
            run_with_retry(t, retry_policy, ...)
            self.commit(t, None)
        except Exception as exc:
            self.commit(t, exc)
            ...
        return

    # 调度所有任务到线程池
    for t in tasks:
        fut = self.submit()(
            run_with_retry, t, retry_policy, ...
        )
        futures[fut] = t

    # 等待任务完成,逐个处理
    while len(futures) > 0:
        done, inflight = concurrent.futures.wait(
            futures,
            return_when=concurrent.futures.FIRST_COMPLETED,
            timeout=...,
        )
        for fut in done:
            futures.pop(fut)
        if _should_stop_others(done):
            break
        yield  # 让出控制权给调用者处理流式输出

    # 等待所有回调完成
    futures.event.wait(timeout=...)
    yield

    # 检查异常
    _panic_or_proceed(futures.done, panic=reraise)

7.4.2 FuturesDict:智能的并发管理

FuturesDict 是一个自定义的 dict,它在 Future 完成时自动调用回调并管理计数器:

python
class FuturesDict(dict):
    event: threading.Event  # 所有任务完成时设置
    callback: weakref.ref   # commit 回调
    counter: int            # 活跃任务计数
    done: set              # 已完成的 Future 集合

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if value is not None:
            self.event.clear()
            self.counter += 1
            key.add_done_callback(partial(self.on_done, value))

    def on_done(self, task, fut):
        try:
            if cb := self.callback():
                cb(task, _exception(fut))
        finally:
            self.done.add(fut)
            self.counter -= 1
            if self.counter == 0 or _should_stop_others(self.done):
                self.event.set()

基于 VitePress 构建