Skip to content

第3章 StateGraph 图构建 API

本章基于 LangGraph 1.1.6 / langgraph-checkpoint 4.0.1 源码分析。源码路径:libs/langgraph/langgraph/graph/ 目录。

StateGraph 是开发者与 LangGraph 交互的首要入口。它提供了一套声明式的 API,让开发者可以用自然直觉的方式定义节点、边和条件分支,然后通过 compile() 一键转换为可执行的 Pregel 运行时。本章将深入 graph/state.py 源码,逐行剖析 StateGraph 的构建 API、节点类型系统、边的三种形态,以及编译过程中发生的关键转换。

本章要点

  • StateGraph 类的完整解剖:构造器、状态模式解析、内部数据结构
  • add_node 的五重重载:名称推断、输入模式推断、Command 返回类型解析
  • 三种边的实现差异:add_edge(直接边)、add_conditional_edges(条件边)、waiting_edges(汇聚边)
  • START/END 常量的本质:它们不是真正的节点,而是 Channel 触发机制
  • StateNodeSpec 与节点类型协议:理解节点函数的多种合法签名
  • MessageGraph 与 MessagesState:消息状态的便捷封装

3.1 StateGraph 类的构造

StateGraph 是开发者构建 LangGraph 应用的第一个接触点。它的设计目标是让图的定义尽可能直观和声明式——开发者只需要关注"做什么"(定义节点和边),而不需要关注"怎么执行"(Pregel 循环、Channel 管理等底层细节)。但要真正理解 StateGraph 的行为,特别是在遇到边界情况和错误时,我们需要深入其构造过程。

3.1.1 构造器签名

StateGraph 的构造器接受状态模式作为核心参数,并可选地指定输入/输出模式和上下文模式:

python
# 源码位置:langgraph/graph/state.py
class StateGraph(Generic[StateT, ContextT, InputT, OutputT]):
    def __init__(
        self,
        state_schema: type[StateT],
        context_schema: type[ContextT] | None = None,
        *,
        input_schema: type[InputT] | None = None,
        output_schema: type[OutputT] | None = None,
        **kwargs: Unpack[DeprecatedKwargs],
    ) -> None:

这里的泛型参数 StateT, ContextT, InputT, OutputT 为类型检查器提供了推断依据,但在运行时并不强制约束。构造器的核心逻辑分为两步:

第一步:初始化内部数据结构

python
# 源码位置:langgraph/graph/state.py,__init__ 方法
self.nodes = {}           # dict[str, StateNodeSpec] 节点注册表
self.edges = set()        # set[tuple[str, str]] 直接边集合
self.branches = defaultdict(dict)  # 条件边注册表
self.schemas = {}         # schema -> channel 映射缓存
self.channels = {}        # 全局 Channel 注册表
self.managed = {}         # 托管值注册表
self.compiled = False     # 是否已编译的标记
self.waiting_edges = set() # 多源汇聚边集合

第二步:解析状态模式

python
self.state_schema = state_schema
self.input_schema = cast(type[InputT], input_schema or state_schema)
self.output_schema = cast(type[OutputT], output_schema or state_schema)
self.context_schema = context_schema

# 核心:解析每个 schema 的字段,创建对应的 Channel
self._add_schema(self.state_schema)
self._add_schema(self.input_schema, allow_managed=False)
self._add_schema(self.output_schema, allow_managed=False)

注意 input_schemaoutput_schema 默认为 state_schema。这意味着如果不显式指定,图的输入和输出与状态具有相同的模式。但当你需要"输入只接受部分字段"或"输出只暴露部分字段"时,可以指定不同的 schema。

3.1.2 状态模式到 Channel 的转换

这是 StateGraph 构造过程中最重要的环节。LangGraph 的核心理念之一是让开发者用熟悉的 Python 类型系统来定义状态,然后由框架自动将类型注解转换为内部的 Channel 表示。这意味着开发者不需要直接与 Channel 打交道——他们只需要定义一个普通的 TypedDict 或 Pydantic 模型,框架就能理解每个字段应该使用什么样的存储和更新策略。

_add_schema 方法是理解 StateGraph 的关键。它调用 _get_channels 函数,将 Python 类型注解转换为 Channel 实例:

python
# 源码位置:langgraph/graph/state.py
def _get_channels(
    schema: type[dict],
) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec], dict[str, Any]]:
    if not hasattr(schema, "__annotations__"):
        # 没有字段注解的类型(如 Annotated[list, add_messages])
        # 使用 __root__ 作为单一 Channel
        return (
            {"__root__": _get_channel("__root__", schema, allow_managed=False)},
            {},
            {},
        )

    # 有字段注解的类型(TypedDict, dataclass, Pydantic BaseModel)
    type_hints = get_type_hints(schema, include_extras=True)
    all_keys = {
        name: _get_channel(name, typ)
        for name, typ in type_hints.items()
        if name != "__slots__"
    }
    return (
        {k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)},
        {k: v for k, v in all_keys.items() if is_managed_value(v)},
        type_hints,
    )

对于每个字段,_get_channel 函数按优先级检查三种情况:

对应的源码逻辑:

python
# 源码位置:langgraph/graph/state.py
def _get_channel(name, annotation, *, allow_managed=True):
    # 1. 检查是否为 ManagedValue(如 IsLastStep)
    if manager := _is_field_managed_value(name, annotation):
        return manager

    # 2. 检查是否有显式 Channel 注解
    #    如 Annotated[str, EphemeralValue]
    elif channel := _is_field_channel(annotation):
        channel.key = name
        return channel

    # 3. 检查是否有 Reducer 函数注解
    #    如 Annotated[list, operator.add]
    elif channel := _is_field_binop(annotation):
        channel.key = name
        return channel

    # 4. 默认:创建 LastValue Channel
    fallback: LastValue = LastValue(annotation)
    fallback.key = name
    return fallback

让我们看几个具体的映射例子:

python
from typing import Annotated
import operator

class State(TypedDict):
    # 情况1:无注解 -> LastValue
    name: str                              # LastValue(str)

    # 情况2:Reducer 函数 -> BinaryOperatorAggregate
    items: Annotated[list, operator.add]   # BinaryOperatorAggregate(list, operator.add)

    # 情况3:自定义 Reducer
    messages: Annotated[list, add_messages] # BinaryOperatorAggregate(list, add_messages)

    # 情况4:显式 Channel 类型
    temp: Annotated[str, EphemeralValue]   # EphemeralValue(str)

3.1.3 _is_field_binop 的 Reducer 检测

Reducer 检测逻辑值得特别关注。_is_field_binop 函数检查 Annotated 的最后一个元数据是否为接受两个参数的可调用对象:

python
# 源码位置:langgraph/graph/state.py
def _is_field_binop(typ):
    if hasattr(typ, "__metadata__"):
        meta = typ.__metadata__
        if len(meta) >= 1 and callable(meta[-1]):
            sig = signature(meta[-1])
            params = list(sig.parameters.values())
            if (
                sum(
                    p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
                    for p in params
                ) == 2
            ):
                return BinaryOperatorAggregate(typ, meta[-1])
            else:
                raise ValueError(
                    f"Invalid reducer signature. Expected (a, b) -> c. Got {sig}"
                )
    return None

这段代码做了两件事:(1) 检查最后一个元数据是否可调用;(2) 检查它是否恰好接受两个位置参数。如果签名不匹配(比如只有一个参数),会抛出明确的错误信息。这个设计确保了 Reducer 函数的签名正确性。

3.2 add_node:节点注册的五重重载

节点是 LangGraph 工作流中最基本的计算单元。每个节点封装了一段独立的处理逻辑——可以是调用大语言模型、执行外部工具、进行数据转换处理,或者任何 Python 可调用对象。add_node 是 StateGraph 最常用的方法,它被设计得足够灵活,以适应从简单函数到复杂 Runnable 对象的各种场景。它有五种重载签名来适应不同的调用方式:

3.2.1 名称推断机制

当第一个参数不是字符串时,LangGraph 会自动推断节点名称:

python
# 源码位置:langgraph/graph/state.py,add_node 方法
if not isinstance(node, str):
    action = node
    if isinstance(action, Runnable):
        node = action.get_name()      # Runnable 使用 get_name()
    else:
        node = getattr(action, "__name__", action.__class__.__name__)
        # 函数使用 __name__,类实例使用类名

这意味着:

python
def my_agent(state): ...
builder.add_node(my_agent)         # 名称自动推断为 "my_agent"

class MyAgent:
    def __call__(self, state): ...
builder.add_node(MyAgent())        # 名称自动推断为 "MyAgent"

builder.add_node("custom_name", my_agent)  # 显式指定名称

3.2.2 输入模式推断

输入模式推断是 StateGraph 的一个强大特性。它允许不同的节点看到不同的状态子集——一个只需要 query 字段的检索节点不需要接收完整的包含 messagestools 等字段的状态。这种选择性输入不仅减少了数据传递的开销,更重要的是明确了节点的依赖关系,使得代码的意图更加清晰。

add_node 不仅推断名称,还会尝试从函数签名中推断输入模式:

python
# 源码位置:langgraph/graph/state.py,add_node 方法(简化)
inferred_input_schema = None

if isfunction(action) or ismethod(action):
    hints = get_type_hints(action)
    if input_schema is None:
        first_parameter_name = next(
            iter(inspect.signature(action).parameters.keys())
        )
        if input_hint := hints.get(first_parameter_name):
            if isinstance(input_hint, type) and get_type_hints(input_hint):
                inferred_input_schema = input_hint

这段代码的含义是:如果节点函数的第一个参数有类型注解,且该注解是一个带有字段的类型(TypedDict/dataclass/Pydantic),则将其作为该节点的输入模式。这使得以下用法成立:

python
class FullState(TypedDict):
    x: int
    y: str
    z: list

class AgentInput(TypedDict):
    x: int
    y: str

# 通过类型注解推断,agent 只会接收 x 和 y 字段
def agent(state: AgentInput) -> dict:
    return {"z": [state["x"]]}

builder = StateGraph(FullState)
builder.add_node(agent)  # input_schema 自动推断为 AgentInput

3.2.3 Command 返回类型解析

LangGraph 1.x 引入了 Command 类型,允许节点通过返回值来控制流程。add_node 会解析函数的返回类型注解来提取目的地信息:

python
# 源码位置:langgraph/graph/state.py,add_node 方法(简化)
if rtn := hints.get("return"):
    rtn_origin = get_origin(rtn)
    # 处理 Union 类型:Union[dict, Command[Literal["a", "b"]]]
    if rtn_origin is Union:
        for arg in get_args(rtn):
            if get_origin(arg) is Command:
                rtn = arg
                break

    # 提取 Command[Literal["a", "b"]] 中的目的地
    if get_origin(rtn) is Command and (rargs := get_args(rtn)):
        if get_origin(rargs[0]) is Literal:
            ends = get_args(rargs[0])  # ("a", "b")

这使得以下模式可以正确渲染图的边:

python
def route_node(state: State) -> Command[Literal["agent", "tools"]]:
    if should_use_tools(state):
        return Command(goto="tools", update={"step": "tools"})
    return Command(goto="agent", update={"step": "agent"})

builder.add_node(route_node, destinations=("agent", "tools"))
# 或者依赖返回类型注解自动推断

基于 VitePress 构建