Skip to content

第2章 EngineCore:引擎的心脏

"A conductor does not make a sound. He depends on his ability to make other people powerful." — Benjamin Zander

本章要点

  • 理解 EngineCore 作为"指挥者"的职责边界:它协调所有人,却不替任何人干活
  • 跟随 EngineCoreProc.run_busy_loop 逐行走完主循环的 6 个阶段
  • 掌握 add_request / abort_request 的数据契约:跨进程传的是什么,为什么这么传
  • 读懂 ZMQ 三 socket 通信协议:PULL / PUSH / PAIR 各自承担哪种语义
  • 看懂 EngineCoreClient 的三种形态(Sync / Async / MultiprocClient)以及调用方如何无感切换
  • 理解请求生命周期状态机的每一个迁移背后的触发条件
  • 分析 V1 如何用"异步流水线 + 后台线程"把 CPU 调度与 GPU 计算叠到一起
  • 掌握多数据并行(DP)下 EngineCore 的跨 rank 协调机制
  • 学会读 EngineCore 的信号处理与优雅退出路径,避免生产环境僵尸进程
  • 知道在哪些观测点挂钩子来排查"吞吐上不去"类问题

2.1 指挥者模式:不演奏,却不可或缺

在一个交响乐团中,指挥不演奏任何乐器。他不吹长笛,不拉小提琴,不敲定音鼓。但如果指挥离开,乐团会在几小节内陷入混乱——节奏失序,声部冲突,音乐瓦解。

EngineCore 就是 vLLM 的指挥。

它不做分词——那是 API Server 的工作。它不做 GPU 计算——那是 Worker 的工作。它甚至不做具体的调度决策——那是 Scheduler 的工作。但它协调所有这些组件的节奏,确保它们在正确的时刻做正确的事。

这种"不做具体事,只做协调"的设计,是典型的中介者模式(Mediator Pattern)的工程落地。放到 V1 重构的语境下看,它解决了 V0 时代一个非常痛苦的问题:Scheduler 和 Executor 之间的直接耦合让两者无法独立演化。把它们都挂到 EngineCore 下,用明确的 SchedulerOutput / ExecutorOutput 数据契约串起来,调度算法改了不影响执行路径,执行后端(GPU / TPU / CPU)换了不影响调度逻辑。

我们先从最宏观的视角看 EngineCore 在一次推理步骤中指挥的"声部":

注意这张图里 EngineCore 出现了 6 次,但没有任何一条箭头是"EngineCore 做 XXX"——它永远是"让别人做 XXX"的那一方。这就是指挥者模式的精髓:最重要的组件恰恰是最"无所事事"的那一个

让我们打开 vllm/v1/engine/core.py,看看这位指挥是怎么指挥的。

2.2 主循环:逐行拆解 run_busy_loop

EngineCore 有两个核心类:EngineCore(基类)与 EngineCoreProc(多进程形态下的子进程入口)。生产环境几乎总是走 EngineCoreProc,它把 EngineCore 的功能包在一个独立进程里,通过 ZMQ 和前端 API Server 通信。我们直接从 EngineCoreProc.run_busy_loop 开始:

python
# vllm/v1/engine/core.py(精简并加注释)
class EngineCoreProc(EngineCore):
    def run_busy_loop(self) -> None:
        """EngineCore 子进程的主循环。这个函数一跑起来就不会退出,
        除非收到 SHUTDOWN 信号或遇到未捕获异常。"""

        # (1) 等待所有 Worker 启动就绪 —— DP / TP 场景下需要 rendezvous
        self._ensure_workers_ready()

        while True:
            # (2) 批量拉取输入:把 ZMQ input_socket 上累积的所有消息一次性读光
            #     (非阻塞;如果没有新消息就立即返回空列表)
            self._process_input_queue()

            # (3) 如果没有正在进行的请求,阻塞等待下一个输入
            if not self.scheduler.has_requests():
                self._wait_for_work()
                continue

            # (4) 核心一步:调度 + 执行 + 处理输出
            outputs = self.step()

            # (5) 发回 API Server
            if outputs:
                self._send_outputs(outputs)

寥寥 10 行伪代码,却承载了整个 V1 推理引擎的节拍。我们逐步骤拆解。

(1) Worker 就绪同步

_ensure_workers_ready 内部做了一件听起来简单但实现很复杂的事:等所有 Worker 进程(在 TP / DP 场景下是多个 Python 进程、每个绑一张 GPU)都初始化完模型权重并进入可执行状态,再让主循环开始跑。为什么重要?因为 ZMQ 是一个不会"抗拒"连接的协议——input_socket 一创建就能收消息,如果 Worker 还没就绪就收到请求,EngineCore 会立刻把调度结果发下去,触发 Worker 的 CUDA 内核,但权重还没 load 完——结果就是一串莫名其妙的报错。

V1 用的是 Barrier 语义:主进程创建一个 multiprocessing.Event,每个 Worker 初始化完毕后 set 它,_ensure_workers_ready 阻塞 wait 直到所有 Worker 都 set。这种显式的屏障比 V0 时代的"睡一会儿试试看"优雅得多,也彻底消除了时序竞态。

(2) 批量拉取输入:减少系统调用

_process_input_queue 是一个贪心消费:

python
def _process_input_queue(self) -> None:
    """把 ZMQ socket 上当前可读的所有消息一次性处理完。"""
    while True:
        try:
            # DONTWAIT 让 recv 立即返回而不是阻塞
            frames = self.input_socket.recv_multipart(flags=zmq.DONTWAIT)
        except zmq.Again:
            # 当前没有更多消息了 —— 退出内层 while,回到主循环
            return

        msg_type, payload = self._decode(frames)
        if msg_type == EngineCoreRequestType.ADD:
            self._handle_add_request(payload)
        elif msg_type == EngineCoreRequestType.ABORT:
            self._handle_abort(payload)
        elif msg_type == EngineCoreRequestType.SHUTDOWN:
            self._handle_shutdown()

为什么不是每次主循环只处理一条消息?因为在高 QPS 场景下,前端可能短时间内塞进来几十条 add_request。如果每次主循环只处理一条,那处理完一条就要跑一整轮调度+GPU 计算,剩下的消息要积压若干个 step 才能入队——这会显著拖慢请求的首 token 延迟。批量拉取把"接收"和"调度"解耦,一次把所有等待中的消息吸进来,然后在后续的 step 中一起调度。

(3) 无请求时的等待策略

self.scheduler.has_requests() 返回 False 意味着目前没有任何 WAITING 或 RUNNING 的请求。这时候 GPU 完全空闲,如果还继续跑 step() 就会空转浪费 CPU。V1 的做法是调用 _wait_for_work()

python
def _wait_for_work(self) -> None:
    """阻塞直到 input_socket 上有消息。配合 ZMQ poller 使用。"""
    poller = zmq.Poller()
    poller.register(self.input_socket, zmq.POLLIN)
    poller.poll()  # 阻塞等待
    # 返回后立即回到主循环顶部重新 _process_input_queue

这里的关键是 poller.poll() 会阻塞当前线程,把 CPU 让给操作系统调度器。跟 time.sleep(0.01) 之类的轮询相比,这种 epoll-based 等待既没有轮询延迟,也不会占着 CPU 干等——典型的 Linux 事件驱动设计。

(4) step():引擎的一拍

python
def step(self) -> list[EngineCoreOutput]:
    # 4.1 调度:决定本步执行哪些请求、每个请求执行多少 Token
    scheduler_output = self.scheduler.schedule()

    # 4.2 执行:把调度结果交给 Executor 跑前向传播
    executor_output = self.model_executor.execute_model(scheduler_output)

    # 4.3 处理输出:更新请求状态、检查完成条件、构造输出包
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, executor_output
    )
    return engine_core_outputs

这三行是整个 V1 推理引擎的灵魂。注意它们的数据流:

注意 update_from_output 虽然方法名在 Scheduler 上,但它修改的是 Scheduler 内部状态(RUNNING / FINISHED 队列、KV 块分配表)然后把需要发给前端的数据打包成 list[EngineCoreOutput] 返回。这个方法是 V0 → V1 重构的受益者之一:V0 里这块逻辑散落在 LLMEngineBlockManagerOutputProcessor 三个类里,难以追踪状态变化;V1 把所有"一步结束后的状态更新"集中到 Scheduler,让一致性检查变得可行。

(5) 发回 API Server

_send_outputsoutput_socket(PUSH),把 EngineCoreOutputs 序列化后扔到 ZMQ buffer:

python
def _send_outputs(self, outputs: list[EngineCoreOutput]) -> None:
    # msgpack 比 pickle 快 3-5 倍,体积也小 30%+,是 V1 的默认序列化
    msg = msgpack.packb({
        "outputs": [o.to_dict() for o in outputs],
        "scheduler_stats": self.scheduler.make_stats(),
    })
    self.output_socket.send(msg, flags=zmq.DONTWAIT)

这里的 DONTWAIT 是一个值得关注的设计选择:如果 API Server 消费得不够快,ZMQ buffer 满了,这里会立即抛 zmq.Again 而不是阻塞 EngineCore。V1 把这种情况视为"背压信号"——EngineCore 不应该为了等 API Server 消费而停下来,而是应该继续推进计算,让 ZMQ 的 HWM(high water mark)自动丢弃最老的消息(实际上更常见的是把 HWM 设得很大,让这个"丢"永远不发生)。

2.3 请求的注入:add_request 全链路

从 API Server 发起 add_request 到请求真正被 Scheduler 看到,中间要穿过进程边界、经过数据结构转换、更新内部索引。我们跟着一个请求走完这段路:

关键点有三个。

第一,API Server 侧就完成分词。V0 时代分词是在 LLMEngine 里做的,导致引擎必须知道 tokenizer 路径。V1 把分词前移到 API Server,让 EngineCore 只处理 Token IDs——这带来两个好处:引擎进程不需要加载 tokenizer(节约几百 MB 内存),也可以用不同 tokenizer 前端同时接入同一个引擎。

第二,跨进程传递的是 EngineCoreRequest 不是 Request。这是一个精心设计的分离:

python
# vllm/v1/engine/__init__.py
@dataclass
class EngineCoreRequest:
    """最小化的跨进程传输对象。只包含不可变的请求信息。"""
    request_id: str
    prompt_token_ids: list[int]
    mm_inputs: Optional[list[MultiModalKwargs]]
    mm_hashes: Optional[list[str]]
    mm_placeholders: Optional[list[PlaceholderRange]]
    sampling_params: SamplingParams
    pooling_params: Optional[PoolingParams]
    eos_token_id: Optional[int]
    arrival_time: float
    lora_request: Optional[LoRARequest]
    cache_salt: Optional[str]

# vllm/v1/request.py
class Request:
    """EngineCore 内部使用的带状态对象。包含 KV 块引用、当前 token 数、
    状态标志(WAITING / RUNNING / FINISHED / ...)等 runtime 信息。"""
    def __init__(self, core_req: EngineCoreRequest, ...):
        self.request_id = core_req.request_id
        self.prompt_token_ids = core_req.prompt_token_ids
        self.num_computed_tokens = 0
        self.output_token_ids: list[int] = []
        self.status = RequestStatus.WAITING
        self.kv_block_ids: list[int] = []
        # ... 还有十几个 runtime 字段

跨进程传最小的那个,是因为每多传一个字节都要走 pickle/msgpack 序列化 + ZMQ 发送 + 反序列化。Runtime 状态(比如 KV 块分配)根本不需要让 API Server 知道,所以干脆不放进去。这种"数据契约最小化"的原则贯穿 V1:任何跨进程、跨网络的结构体都应该只携带接收方真正需要的字段。

第三,scheduler.add_request 在主循环外部被调用。这点在并发安全上很微妙——严格说来,add_request 是在 _process_input_queue 里调用的,而 _process_input_queue 在主循环的顶部运行,和 step() 不重叠。换句话说,Scheduler 的状态修改要么发生在处理输入阶段,要么发生在 step 阶段,两者是串行的。整个 EngineCore 进程其实是单线程的(只有一个主循环线程),所以不存在锁问题。这也是 V1 能把实现写得如此紧凑的底层原因——复杂性被并发安全约束限制得很死。

2.4 取消请求:abort_request 的三种时机

abort_request 是另一个必须实现正确才能在生产环境活下去的机制。用户关浏览器、HTTP 连接超时、应用层主动取消——这些都要能干净地把请求从引擎里撤走、释放资源。

取消可能发生在三个时刻:

对应的源码:

python
def abort_requests(self, request_ids: list[str]) -> None:
    """从引擎中移除给定的请求。支持同时取消多个。"""
    # 注意:V1 里 scheduler 维护了 requests 字典,用 request_id 直接查
    self.scheduler.finish_requests(
        request_ids=request_ids,
        finished_status=RequestStatus.ABORTED,
    )

# vllm/v1/core/sched/scheduler.py
def finish_requests(self, request_ids, finished_status):
    for rid in request_ids:
        req = self.requests.pop(rid, None)
        if req is None:
            continue  # 已经不存在(可能已完成或从未注册)
        if req.status == RequestStatus.WAITING:
            self.waiting.remove(req)
        elif req.status == RequestStatus.RUNNING:
            self.running.remove(req)
            # RUNNING 状态下必须释放 KV 块
            self.kv_cache_manager.free(req)
        req.status = finished_status

注意 pop 配合 if req is None: continue 的防御式写法。这是因为 abort_requests 可能收到一个"在飞行中"的请求 id——前端刚发 abort 时,引擎可能刚好把这个请求标记为 FINISHED 并清理掉了。这种 TOCTOU(Time Of Check Time Of Use)竞态在分布式系统里很常见,V1 的解决方案是让 abort 做到幂等:取消一个不存在的请求不应该报错。

还有一个不明显的细节:abort 通过普通的 ZMQ 消息传递,而不是走"紧急通道"。这意味着如果 _process_input_queue 前面排了很多 add_request,abort 要等它们全部处理完才会轮到。在极端高负载下这可能带来几毫秒延迟——但对于一个已经运行了若干秒的请求来说,几毫秒的取消延迟完全可以接受。

2.5 ZMQ 通信协议:三条 Socket 各司其职

MultiprocClient 和 EngineCoreProc 之间用 ZMQ 通信。具体架构如下:

三条 Socket 各自承担不同语义:

inputs (PUSH → PULL):前端 → 引擎的单向队列。MultiprocClient 的 add_request / abort_request 都走这条。PUSH/PULL 是 ZMQ 的"任务分发"模式,多生产者/多消费者天然 round-robin——不过我们这里永远是一对一。

outputs (PUSH → PULL):引擎 → 前端的单向队列。EngineCore 产生的 EngineCoreOutputs 都走这条。这里特别强调"单向"——API Server 不会往这条上发任何东西,引擎也不会从上面读。这种方向性约束让协议实现起来极其简单:每条消息读完就处理,不需要维护请求-响应关联。

ready (PAIR):握手通道。子进程启动完成后发送一个 ready 消息,让前端知道可以开始发请求了。PAIR 是 ZMQ 的双向 1-1 通道,适合这种"握手"场景。

为什么不把所有流量合并到一条 PAIR?因为 PUSH/PULL 的单向语义让背压控制更清晰。比如我们可以把 outputs socket 的 SNDHWM(发送方高水位)设置得很大(比如 10000 条),让引擎永远不会因为前端消费慢而阻塞;而 inputs socket 的 RCVHWM(接收方高水位)可以设小一些,避免积压过多未处理请求。这种精细化控制在 PAIR 上做不到。

消息编码:V1 默认用 msgpack。相比 pickle,msgpack 有三大优势:(1) 跨语言,未来如果想用 Rust / Go 写前端,协议不需要改;(2) 速度快 3-5 倍;(3) 序列化后体积小 30%+,减少 IPC 吞吐压力。

python
# vllm/v1/engine/core_client.py
import msgspec

class _SerializedEngineCoreRequest:
    """msgpack 友好的 EngineCoreRequest 版本,用 msgspec.Struct 定义。"""
    type: int  # EngineCoreRequestType
    data: bytes  # msgpack 序列化后的 payload

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(_SerializedEngineCoreRequest)

msgspec 是 Python 生态里性能最好的 msgpack 库之一,它在 C 层做 schema 绑定的序列化,不走 Python 字典反射,比 msgpack-python 还要快 2 倍左右。

2.6 EngineCoreClient:同一套接口,三张面孔

API Server 不直接和 EngineCore 打交道——它拿到一个 EngineCoreClient,按这个接口调用:

python
class EngineCoreClient(ABC):
    @abstractmethod
    def add_request(self, request: EngineCoreRequest) -> None: ...
    @abstractmethod
    def abort_requests(self, request_ids: list[str]) -> None: ...
    @abstractmethod
    def get_output(self) -> EngineCoreOutputs: ...
    @abstractmethod
    def shutdown(self) -> None: ...
    # ... 还有十几个方法

三种具体实现:

InprocClient——最简单的实现。EngineCore 在同一进程里,add_request 就是一个普通的 Python 方法调用,get_output 触发一次 step()。适合离线批处理场景(llm = LLM(model="...") 这种用法)。

AsyncMPClient——同进程但全 async。add_request / get_output 返回 Future,调用方不需要阻塞等 GPU 计算。其实内部也是单线程的,只是把 step 包在一个 asyncio 任务里跑。

SyncMPClient——子进程 + ZMQ 的同步形态。add_request 是发 ZMQ 消息然后立即返回,get_output 是阻塞地从 ZMQ 读消息。主要给某些需要"同步感觉但要跨进程"的测试工具用。

DPAsyncMPClient——生产环境用得最多的那个。N 个 DP rank,每个 rank 一个 EngineCoreProc 子进程,每个子进程自己绑若干张 GPU(一般是 TP=8 或 16 张)。这个 Client 内部维护 N 组 ZMQ socket,按负载把请求路由到合适的 rank。

关键是这四种 Client 的外部 API 完全一致。上层 API Server 代码写的时候只 client.add_request(...),切换部署模式只是 engine_core_client_from_engine_args(...) 里换一个分支:

python
def make_engine_core_client(engine_args):
    if engine_args.distributed_executor_backend is None:
        return InprocClient(engine_args)
    elif engine_args.distributed_executor_backend == "mp":
        if engine_args.data_parallel_size > 1:
            return DPAsyncMPClient(engine_args)
        else:
            return AsyncMPClient(engine_args)

一个接口、四种实现,这正是"依赖倒置"的实战价值——上层代码面向抽象,切底层不需要动。

2.7 请求的一生:状态机与时序

把请求在 EngineCore 中经历的所有状态画成状态机:

每一个迁移都有明确的触发条件,我们逐个过一遍。

WAITING → RUNNING(被调度):这是最常见的迁移,发生在 Scheduler.schedule() 选中这个请求时。进入 RUNNING 前必须完成两件事:给请求分配 KV 块(prefill 阶段一次性分配、decode 阶段按需增量分配)、把请求加入 running 队列。V1 里这一步的代码在 Scheduler._schedule_running_and_waiting() 中,后续章节会详细讲。

RUNNING → PREEMPTED(抢占):当 KV 池用满、新请求又到了、Scheduler 的抢占策略(FCFS / Priority)决定牺牲某个 RUNNING 请求时触发。V1 的抢占非常轻量:直接释放它的 KV 块,把状态改回 WAITING。重要的是 num_computed_tokens 不清零——下次被调度时可以增量恢复。

RUNNING → FINISHED_STOPPED:遇到 EOS token 或用户指定的 stop 字符串时触发。Scheduler.update_from_output() 每一步都会检查新生成的 token 是否触发了停止条件:

python
# 简化自 vllm/v1/core/sched/scheduler.py
def _check_stop(self, req, new_token_id):
    if new_token_id == req.eos_token_id and not req.ignore_eos:
        return FinishReason.STOP
    if len(req.output_token_ids) >= req.sampling_params.max_tokens:
        return FinishReason.LENGTH
    if req.sampling_params.stop_token_ids and new_token_id in req.sampling_params.stop_token_ids:
        return FinishReason.STOP
    return None

任何状态 → FINISHED_ABORTED:上一节已经讲过。

为什么要把 FINISHED 细分成三种?因为前端需要知道"为什么停止"来决定下一步行为——OpenAI 兼容 API 的 finish_reason 字段直接映射到这个枚举:stop / length / abort。把这个信息从 EngineCore 带到 API Server 非常重要,否则前端无法正确填充响应。

2.8 异步流水线:让 CPU 和 GPU 并行跑

EngineCore 的一个精妙之处是它怎么把**调度(CPU)执行(GPU)**重叠起来,不让任一方干等。

观察一个典型的 step 时间分布:

调度 (Scheduler.schedule)                    : ~2-5 ms CPU
构建输入张量 (prepare_inputs)                 : ~3-8 ms CPU
    ├─ 拼接 block_ids、input_ids、position_ids
    ├─ build slot_mapping
    └─ copy 到 GPU (cudaMemcpyAsync)
Worker 前向传播 (execute_model)               : ~20-200 ms GPU
    ├─ embedding lookup
    ├─ N 层 transformer block
    ├─ LM head
    └─ sampling
update_from_output                           : ~1-3 ms CPU

如果纯串行跑,一个 step 就是这些阶段的总和。V1 利用 CUDA Stream 的异步性质:

Step N:
  CPU:  [schedule][prepare][launch kernels][update_output]
                                  ↓(kernel 启动后立即返回)
  GPU:                            [execute]

  CPU(Step N+1): [schedule][prepare][launch]
  GPU(Step N+1):                   [execute]

关键在于 execute_model 并不真的等 GPU 跑完。它启动 CUDA kernel 后立刻返回(kernel 在 stream 里异步执行),CPU 就可以开始准备下一步了。下一步的 schedule 不依赖 GPU 结果(因为 Scheduler 看的是 token 数量而不是 token 值),所以可以先跑起来。这种"投机"式调度让 V1 的每 step 开销几乎被 GPU 计算本身吸收掉了。

但也有例外——如果采样时要做 stop token 检查,就必须等 GPU 出来新 token 才知道要不要继续。V1 的解法是"延迟一拍":step N 采样出的 token,在 step N+1 开始时再检查停止条件。这样 step N+1 的调度阶段不阻塞 step N 的 GPU 执行。代价是"overshoot 一个 token"——明明 step N 已经命中 stop 了,但 step N+1 又多算了一个 token 才发现。这一个 token 的浪费,换回了整个流水线的重叠,完全值得。

2.9 数据并行(DP)协调:多个 EngineCore 的共舞

当部署模式是 data_parallel_size > 1 时,会有多个独立的 EngineCore 子进程,每个绑一组 GPU。这时候 EngineCore 之间需要协调吗?

一般情况下不需要。DP 的核心理念就是"各干各的",每个 rank 独立地处理分配给它的请求。DPAsyncMPClient 在前端把新请求路由到某个 rank,然后这个 rank 的 EngineCore 独立走完 add → schedule → execute → output 的全流程。

但有一种情况需要:当某些 rank 没有请求、某些 rank 很忙时,忙的 rank 跑 all_reduce 类的集合通信会卡住(因为要等所有 rank 到齐)。V1 用了一个"dummy forward"的巧妙方案解决这个:

python
# 简化自 vllm/v1/engine/core.py
def _ensure_dp_barrier(self, scheduler_output):
    """DP 场景下,确保所有 rank 同步地调用 execute_model,即使
    某个 rank 这一步没有实际请求。"""
    num_reqs = scheduler_output.total_num_scheduled_tokens
    # 通过一个小 all_reduce 让所有 rank 知道彼此当前的负载情况
    all_num_reqs = dp_comm.all_gather(num_reqs)
    if num_reqs == 0 and any(n > 0 for n in all_num_reqs):
        # 别的 rank 在跑,我也得跑,否则它们的 all_reduce 会等我
        # 跑一个空的"占位"step
        return self._run_dummy_step()
    return self._run_real_step(scheduler_output)

这是一个挺反直觉的设计:一个 rank 明明没活干,还得跑个空 step 假装在干活。但这是使用集合通信的必然代价——torch.distributed 的 all_reduce 是同步 barrier,缺一个 rank 就全卡住。V1 把这部分复杂性收敛在 EngineCore 这一层,让上层完全无感。

2.10 容错与优雅退出

生产环境中,引擎不能随便崩溃,也不能死得难看。EngineCore 实现了几层保障:

信号处理——子进程启动时注册 SIGTERM / SIGINT handler:

python
def _install_signal_handlers(self):
    def _handle(signum, frame):
        logger.info("EngineCore received signal %d, initiating shutdown", signum)
        # 不直接退出 —— 往输入队列塞一条 SHUTDOWN 消息
        # 让主循环走正常的退出路径
        self._shutdown_requested = True

    signal.signal(signal.SIGTERM, _handle)
    signal.signal(signal.SIGINT, _handle)

注意这里没有调用 sys.exit 而是设置一个标志。为什么?因为信号可能打断正在进行的任何系统调用(包括 CUDA 内核启动),这时候直接退出会把 GPU 驱动留在一个不一致状态。V1 的做法是让信号处理器只做最小动作(设置标志 + log),真正的退出动作由主循环在"安全点"完成。

优雅退出流程

步骤 F 很关键:不让当前 step 跑完就强退,GPU 上正在跑的内核可能没有 sync,后续进程起来发现显存没释放干净。V1 用 torch.cuda.synchronize() 确保当前 step 的所有 GPU 操作都完成再往下走。

Worker 崩溃恢复——MultiprocExecutor 启动 Worker 时会 fork 子进程并监控它们。如果一个 Worker 意外退出(OOM、内核 bug),MultiprocExecutor 的 waitpid 回调会触发:

python
def _on_worker_exit(self, pid, exit_code):
    if exit_code != 0:
        logger.error("Worker %d died with exit code %d", pid, exit_code)
        # 把该 Worker 负责的所有请求都标记为 FAILED
        for req_id in self.worker_req_map[pid]:
            self.scheduler.finish_requests([req_id], RequestStatus.FINISHED_ABORTED)
        # 触发整个引擎的降级:要么重启 Worker,要么让整个 EngineCore 退出
        self._trigger_engine_shutdown()

这里选择了保守策略——不尝试单独重启 Worker 进程,而是让整个 EngineCore 退出,让外部编排系统(Kubernetes、systemd)重启整个服务。为什么这么保守?因为模型权重在内存里,重启 Worker 要重新 load 权重(几十秒),期间无法服务。倒不如直接死,让 K8s 起一个新 Pod,健康检查期间流量自动切走。这是一种典型的"fail fast + 外部恢复"哲学。

2.11 可观测性:EngineCore 在默默发出的信号

运维 vLLM 生产环境时,你会想知道这些事:现在有多少请求在 RUNNING?KV 块使用率多少?平均每步生成多少 token?EngineCore 的每一步都会产出一批统计数据:

python
# vllm/v1/engine/core.py 中 step 返回时附带 stats
class SchedulerStats:
    num_running_reqs: int
    num_waiting_reqs: int
    kv_cache_usage: float  # 0.0 – 1.0
    num_prompt_tokens_this_step: int
    num_generation_tokens_this_step: int
    preempted_reqs_this_step: int
    prefix_cache_hit_rate: float

这些指标通过 _send_outputsscheduler_stats 字段带回到 API Server,由 API Server 暴露为 Prometheus metric:

vllm_num_requests_running{...}        : gauge
vllm_num_requests_waiting{...}        : gauge
vllm_kv_cache_usage_ratio{...}        : gauge
vllm_prompt_tokens_per_step{...}      : histogram
vllm_generation_tokens_per_step{...}  : histogram
vllm_preempted_requests_total{...}    : counter
vllm_prefix_cache_hit_ratio{...}      : gauge

生产调优最常看的三个:

  1. kv_cache_usage_ratio——持续接近 1.0 意味着 KV 池快满,将触发抢占。要么加 GPU,要么降低 max_num_seqs
  2. num_requests_waiting——持续 > 0 意味着 Scheduler 吃不下了,请求开始排队。通常伴随首 token 延迟上升
  3. preempted_requests_total——持续增长意味着显存频繁紧张,被抢占的请求要重新 prefill,吞吐和延迟都会受损

此外 EngineCore 里有大量 tracing span(通过 vllm/tracing.py 的 OpenTelemetry 封装),可以追踪单个请求从 add 到 finish 的全链路耗时。生产环境建议接 Jaeger 或 Tempo 保留几天的 trace,发生"某个请求莫名其妙特别慢"的问题时能直接定位。

2.12 本章小结

EngineCore 是 vLLM 的心脏,但它的力量不在于"做什么",而在于"让正确的组件在正确的时刻做正确的事":

  • 指挥者模式——EngineCore 本身几乎没有业务逻辑,所有重活都交给 Scheduler、KV Cache Manager、Executor,自己只负责按节拍协调
  • 主循环 6 阶段——worker 就绪 → 批量拉输入 → 空闲等待 → 调度+执行+处理 → 发回输出,周而复始
  • 数据契约最小化——跨进程只传 EngineCoreRequest(不可变),进程内用 Request(带状态),界限清晰
  • 三条 ZMQ Socket——inputs (PUSH/PULL) + outputs (PUSH/PULL) + ready (PAIR),职责分离,背压可控
  • 四种 Client 面孔——Inproc / AsyncMP / SyncMP / DPAsyncMP,同一接口适配 4 种部署拓扑
  • 请求状态机——WAITING / RUNNING / PREEMPTED / FINISHED_{STOPPED,LENGTH,ABORTED},每个迁移都有明确触发条件
  • 异步流水线——CPU 调度与 GPU 计算通过 CUDA Stream 异步重叠,必要时"延迟一拍"避免阻塞
  • DP 协调——多 rank 通过 dummy step 保持集合通信同步
  • 容错——信号处理走"设置标志 + 主循环安全退出",Worker 崩溃走 fail-fast + 外部重启
  • 可观测——SchedulerStats 通过 scheduler_stats 携带到 API Server,暴露为 Prometheus 指标

下一章,我们将深入调度器 Scheduler——那个决定"谁先谁后、每人发多少 token"的裁判。它的每一个决策都直接影响吞吐量和延迟,是 vLLM 性能优化的核心战场。


源码导航

  • EngineCore 主类:vllm/v1/engine/core.pyEngineCore / EngineCoreProc
  • Client 实现:vllm/v1/engine/core_client.pyInprocClient / AsyncMPClient / SyncMPClient / DPAsyncMPClient
  • Request 数据结构:vllm/v1/request.py / vllm/v1/engine/__init__.py
  • MultiprocExecutor:vllm/v1/executor/multiproc_executor.py
  • Scheduler 入口:vllm/v1/core/sched/scheduler.py
  • Metrics:vllm/v1/metrics/

基于 VitePress 构建