实现教学版 PagedAttention:从 block table 读取历史 KV cache
前面两篇文章,我们已经把问题推进到了 vLLM 最核心的部分。
第九篇里,我们解释了为什么 KV cache 需要分页式管理。在线 LLM serving 中,请求会不断进入、增长和结束。如果每个请求都持有一整段连续 cache,显存浪费会非常严重。如果按需申请连续空间,又会遇到碎片化。PagedAttention 的核心思想是:让 Sequence 在逻辑上拥有连续上下文,但物理上使用不连续的 KV cache blocks。
第十篇里,我们实现了 Block Manager。它负责维护全局 free blocks,为 prefill 分配 block,为 decode 追加 slot,在请求结束后释放 block。到这一步,我们已经有了 block table,也有了 physical KV cache pool 的概念。
但是,还有一个关键问题没有解决:
模型做 attention 时,怎么根据 block table 找到正确的 K 和 V?
这就是这一篇要讲的内容。
我们会实现一个教学版 PagedAttention。它不会追求高性能,也不会一上来写 CUDA kernel。它的目标只有一个:把 PagedAttention 的寻址过程讲清楚,让读者真正理解从 logical token index 到 physical KV cache address 的映射。
高性能可以以后再做,但正确性必须先看懂。
PagedAttention 改变的是 KV cache 的访问方式
先强调一个很重要的点。
PagedAttention 没有改变 attention 的数学含义。
对于一个正在 decode 的 token,它仍然会拿自己的 Q 去和历史所有 K 做匹配,然后用 softmax 得到 attention weights,再用这些 weights 加权求和历史 V。
概念上还是:
scores = Q @ K.T
weights = softmax(scores)
output = weights @ V
变化发生在 K 和 V 从哪里来。
普通 attention 假设历史 K 和 V 是连续存放的。
K cache:
token 0, token 1, token 2, token 3, ...
V cache:
token 0, token 1, token 2, token 3, ...
所以访问第 i 个历史 token 的 K 和 V,直接读第 i 个位置就行。
PagedAttention 下,历史 K 和 V 被切成多个 physical block。Sequence 的逻辑上下文仍然是连续的,但它对应的 physical block 可能是离散的。
例如 block size 是 4,某个 Sequence 的 block table 是:
[7, 2, 9]
这表示:
logical block 0 -> physical block 7
logical block 1 -> physical block 2
logical block 2 -> physical block 9
那么这个 Sequence 的逻辑 token 到物理位置的映射就是:
token 0 -> physical block 7, offset 0
token 1 -> physical block 7, offset 1
token 2 -> physical block 7, offset 2
token 3 -> physical block 7, offset 3
token 4 -> physical block 2, offset 0
token 5 -> physical block 2, offset 1
token 6 -> physical block 2, offset 2
token 7 -> physical block 2, offset 3
token 8 -> physical block 9, offset 0
...
所以 PagedAttention 的第一步不是计算 attention,而是做地址映射。
它要把逻辑上的历史 token 序列,还原成 attention 可以使用的 K 和 V。
从逻辑 token 到物理 block
PagedAttention 的核心寻址关系非常简单。
给定一个历史 token 的逻辑位置 token_index,以及 block size,可以先算出它属于哪个 logical block:
logical_block_id = token_index // block_size
再算出它在 block 内部的偏移:
block_offset = token_index % block_size
然后通过 block table 找到对应的 physical block:
physical_block_id = block_table[logical_block_id]
最终,K 和 V 的物理位置就是:
physical_block_id, block_offset
这条映射链路就是 PagedAttention 的核心:
logical token index
-> logical block id
-> physical block id
-> block offset
-> K/V cache address
只要这条链路是正确的,PagedAttention 在语义上就能得到和连续 KV cache 一样的结果。
区别只是底层存储不再连续。
这也是我们实现教学版 PagedAttention 的目标。
先不要考虑 kernel fusion,也不要考虑 shared memory、warp、memory coalescing 这些底层优化。先保证逻辑 token 序列能正确读回历史 K 和 V。
先定义一个简单的 KV cache 布局
为了方便理解,我们先定义一个简单的 physical KV cache pool。
假设我们只关注单层 attention,先忽略多层结构。
可以把 key cache 和 value cache 看成两个张量:
key_cache:
[num_blocks, block_size, num_heads, head_dim]
value_cache:
[num_blocks, block_size, num_heads, head_dim]
其中:
num_blocks 是全局 physical block 数量
block_size 是每个 block 能存多少 token
num_heads 是 attention head 数
head_dim 是每个 head 的维度
在真实模型里,每一层都需要自己的 K 和 V,所以通常还会多一个 layer 维度。
例如可以组织成:
[num_layers, num_blocks, block_size, num_heads, head_dim]
或者把每层的 cache 分开存。
教学版本里,为了减少干扰,我们可以先只讲单层。多层只是对每一层重复同样的过程。
现在,假设某个 Sequence 的 block table 是:
[7, 2, 9]
如果我们要读取逻辑 token 5 的 key,那么:
logical_block_id = 5 // 4 = 1
block_offset = 5 % 4 = 1
physical_block_id = block_table[1] = 2
所以它对应:
key_cache[2, 1]
value_cache[2, 1]
这个位置里保存的就是逻辑 token 5 的 K 和 V。
如果 block table 没错,Block Manager 写入 slot 没错,那么这里读出来的值就应该等价于连续 cache 里第 5 个位置的值。
教学版实现的思路:先 gather,再 attention
高性能 PagedAttention 不会真的把所有历史 K 和 V gather 成一个连续张量再算 attention。那样会增加额外拷贝,效率不高。
但教学版可以这么做。
因为我们的目标不是最高性能,而是让逻辑更清楚。
对于一个 Sequence,我们可以先根据 block table 把它的历史 K 和 V 收集出来,恢复成逻辑上连续的形式:
logical K:
[token 0, token 1, token 2, ..., token context_len - 1]
logical V:
[token 0, token 1, token 2, ..., token context_len - 1]
然后再用普通 attention 公式计算。
这个过程可以拆成两步。
第一步,paged gather。
根据 block table 从 physical KV cache 中取出历史 K 和 V
第二步,standard attention。
用当前 query 对 gather 后的 K 和 V 做普通 attention
这种实现虽然慢,但非常适合验证正确性。
我们甚至可以写两个版本做对比。
一个版本使用普通连续 KV cache。
一个版本使用 paged KV cache,先 gather 再 attention。
只要输入相同、cache 内容相同、mask 相同,两者输出应该一致。
这就是教学版 PagedAttention 最重要的验证方式。
单个 Sequence 的 paged gather
先考虑最简单情况:batch size 等于 1,只处理一个 Sequence。
输入包括:
query: 当前 token 的 Q
key_cache: 全局 physical key cache
value_cache: 全局 physical value cache
block_table: 当前 Sequence 的 block table
context_len: 当前可见历史长度
block_size: 每个 block 的 token 数
我们要做的是,把逻辑位置从 0 到 context_len 减 1 的 K 和 V 读出来。
伪代码大概是:
keys = []
values = []
for token_index in range(context_len):
logical_block_id = token_index // block_size
block_offset = token_index % block_size
physical_block_id = block_table[logical_block_id]
k = key_cache[physical_block_id, block_offset]
v = value_cache[physical_block_id, block_offset]
keys.append(k)
values.append(v)
最后把 keys 和 values 拼成连续张量:
keys: [context_len, num_heads, head_dim]
values: [context_len, num_heads, head_dim]
这个过程就是最直观的 paged gather。
它把分散在 physical blocks 中的 K 和 V,按照逻辑 token 顺序重新收集起来。
接下来,就可以做普通 attention。
从 gather 后的 K/V 做 attention
假设当前 decode token 的 query 形状是:
query: [num_heads, head_dim]
gather 后的 keys 和 values 是:
keys: [context_len, num_heads, head_dim]
values: [context_len, num_heads, head_dim]
为了方便计算,可以把维度调整成按 head 组织:
query: [num_heads, head_dim]
keys: [num_heads, context_len, head_dim]
values: [num_heads, context_len, head_dim]
然后对每个 head 做 attention。
scores = query @ keys.T
weights = softmax(scores)
output = weights @ values
输出形状是:
[num_heads, head_dim]
最后再把多个 head 合并回 hidden size。
这个过程和普通 attention 没有本质区别。
区别只在于 keys 和 values 是通过 block table gather 出来的。
所以可以这样理解:
PagedAttention = paged KV lookup + standard attention
高性能实现会把这两个阶段融合起来,不显式生成中间连续 K/V。
但教学版本可以先把它拆开,这样最容易看懂。
扩展到 batch:每个 Sequence 都有自己的 block table
真实推理里,decode batch 通常会同时处理多个 running Sequence。
每个 Sequence 都有自己的 block table,也有自己的 context length。
例如:
Sequence A:
block_table = [7, 2, 9]
context_len = 10
Sequence B:
block_table = [4, 1]
context_len = 6
Sequence C:
block_table = [8, 5, 3, 6]
context_len = 15
decode batch 中每个 Sequence 都输入一个 token,每个 token 都有自己的 query。
queries:
[batch_size, num_heads, head_dim]
PagedAttention 需要分别为每个 Sequence 读取自己的历史 K 和 V。
伪代码会变成:
outputs = []
for seq_idx in range(batch_size):
query = queries[seq_idx]
block_table = block_tables[seq_idx]
context_len = context_lens[seq_idx]
keys, values = gather_kv(
key_cache,
value_cache,
block_table,
context_len,
)
output = attention(query, keys, values)
outputs.append(output)
最后 outputs 拼成:
[batch_size, num_heads, head_dim]
这个实现非常直观,但性能很差。
原因是 Python 循环很多,gather 很慢,而且每个 Sequence 的 context length 不同,难以形成高效的大矩阵计算。
但作为教学版,它足够清楚。
它能帮助我们验证 block table 和 cache 写入逻辑是否正确。
只有当这个版本正确,我们才有资格去写更复杂的 Triton 或 CUDA kernel。
attention mask 在哪里
在 decode 阶段,如果 context_len 表示当前 token 可以看到的历史长度,那么 gather 出来的 K 和 V 本身就已经只包含可见 token。
此时不一定需要额外的 causal mask。
因为当前 token只对过去的 token 做 attention,不存在未来 token。
但在 prefill 阶段,情况不同。
prefill 一次性处理完整 prompt。prompt 内部每个 token 都只能看自己和之前的 token,不能看未来 token。这时需要 causal mask。
教学版 PagedAttention 可以先主要关注 decode。
原因是 PagedAttention 在 vLLM 中最核心的场景就是 decode 阶段高效访问历史 KV cache。
prefill 阶段也可以使用 block 化 cache 写入,但计算完整 prompt attention 时,很多实现会采用不同路径,例如普通 attention 或其他优化 backend。
所以这一篇我们主要实现 decode 侧的 PagedAttention。
这不是说 prefill 不重要,而是为了让主线更清晰。
prefill 负责把 prompt tokens 的 K 和 V 写入 blocks。
decode 负责通过 block table 读取这些 blocks,并生成后续 token。
写入 slot 和读取 block table 必须一致
这里有一个极其重要的正确性条件:
Block Manager 写入 cache 的位置,必须和 PagedAttention 读取 cache 的位置完全一致。
decode 时,Block Manager 会为当前输入 token 分配一个 slot。
physical_block_id, block_offset
ModelRunner 执行当前 token 的 forward 后,会把这个 token 的 K 和 V 写入这个 slot。
下一轮 decode 时,PagedAttention 会把这个 token 当作历史 token 读取出来。
如果写入时使用的 token index 和读取时使用的 token index 不一致,模型就会读错历史。
这类 bug 很难排查。
因为程序可能不会报 shape 错误,也不会报显存错误。它只是生成结果变差,或者在长输出时逐渐跑偏。
所以实现时必须明确两件事。
第一,Sequence 的 cached_token_count 表示已经写入 KV cache 的 token 数。
第二,Sequence 的 generated_token_ids 表示已经采样出来、可以返回给用户的 token 数。
这两个数并不总是同步。
为什么?
因为 prefill 处理 prompt 后,会采样出第一个 generated token。但这个 generated token 的 K 和 V 还没有写入 cache。它会作为下一轮 decode 的输入。只有下一轮 decode 执行完,它的 K 和 V 才被写入 cache。
所以一种清晰的语义是:
prefill:
输入 prompt tokens
写入 prompt tokens 的 KV
采样 token_1
generated_token_ids = [token_1]
cached_token_count = prompt_len
decode 第 1 轮:
输入 token_1
写入 token_1 的 KV
采样 token_2
generated_token_ids = [token_1, token_2]
cached_token_count = prompt_len + 1
decode 第 2 轮:
输入 token_2
写入 token_2 的 KV
采样 token_3
generated_token_ids = [token_1, token_2, token_3]
cached_token_count = prompt_len + 2
这里的 cached_token_count 是 PagedAttention 读取历史 KV 的重要依据。
decode 第 1 轮输入 token_1 时,它应该能看到 prompt tokens 的 KV。
decode 第 2 轮输入 token_2 时,它应该能看到 prompt tokens 加 token_1 的 KV。
如果 cached_token_count 提前加了,attention 可能读到还没写入的 token_2 的位置。
如果 cached_token_count 晚加了,attention 可能看不到 token_1。
这就是为什么状态语义比代码本身更重要。
和连续 KV cache 对齐验证
实现教学版 PagedAttention 时,我建议一定要做一个对齐测试。
思路是这样的:
先构造一份连续 KV cache。
contiguous_key_cache:
[context_len, num_heads, head_dim]
contiguous_value_cache:
[context_len, num_heads, head_dim]
再构造一份 paged KV cache。
把同样的 K 和 V 按照 block table 写入 physical blocks。
然后用同一个 query 分别计算两种 attention。
连续版本:
output_contiguous = attention(query, contiguous_k, contiguous_v)
Paged 版本:
paged_k, paged_v = gather_by_block_table(...)
output_paged = attention(query, paged_k, paged_v)
最后比较:
output_contiguous ≈ output_paged
如果这一步不一致,说明问题一定出在 block table、slot 写入、gather 顺序或 context_len 上。
这个测试非常重要。
因为 PagedAttention 的数学结果应该和普通 attention 一致。
它只是改变 KV cache 存储和访问方式,不应该改变模型输出。
在教学实现中,先用随机张量验证,再接入真实模型验证。
随机张量测试更容易定位问题,因为没有 tokenizer、采样、模型层数这些干扰。
为什么教学版会慢
教学版 PagedAttention 先 gather 再 attention,清晰但慢。
慢在哪里?
第一,它会显式构造连续 K/V。
这相当于多做了一次内存拷贝。
第二,它对每个 Sequence 单独循环。
batch 中每个请求 context_len 不同,Python 层循环会很慢。
第三,它没有利用 GPU 的高效内存访问模式。
真实 kernel 会尽量让线程以合适的方式读取 block,减少随机访问带来的开销。
第四,它没有融合 softmax、value 聚合和 block 读取。
高性能 attention kernel 通常会把多个步骤融合,避免中间结果反复写回显存。
所以教学版不能代表最终性能。
但它有一个巨大优势:可读。
它能让我们清楚地看到 PagedAttention 到底在做什么。
工程学习里,这一步很重要。
如果直接从 CUDA kernel 开始看,很容易被线程块、warp、shared memory、向量化加载这些细节淹没,反而看不清系统设计。
先写一个慢但正确的版本,再逐步优化,是理解复杂系统的更好路径。
从教学版到高性能版的方向
教学版正确之后,优化方向就清楚了。
第一个方向是去掉显式 gather。
不要先把 K/V 收集成连续张量,而是在 attention 计算时直接通过 block table 读取 physical cache。
也就是说,kernel 在遍历历史 token 时,直接完成:
token_index -> block_id -> offset -> load K/V
第二个方向是减少 Python 循环。
batch 内多个 Sequence 的 attention 应该交给 GPU kernel 并行处理,而不是在 Python 里逐个处理。
第三个方向是优化内存访问。
block size、cache layout、head 维度排列都会影响读取效率。高性能实现会尽量让相邻线程读取相邻内存,提高带宽利用率。
第四个方向是融合计算。
softmax、score 计算、value 聚合可以在 kernel 内部融合,减少中间张量的读写。
第五个方向是针对 decode 场景优化。
decode 每个 Sequence 通常只有一个 query,但 context_len 可能很长。这和 prefill 阶段的大矩阵 attention 不一样。kernel 设计需要围绕这个形态优化。
这些优化都很重要,但它们不改变 PagedAttention 的核心抽象。
核心仍然是:
每个 Sequence 通过 block table 访问自己的逻辑上下文
PagedAttention 如何接入 mini vLLM
接入 mini vLLM 后,几个模块的关系会变得更完整。
Block Manager 负责分配 blocks 和 slot。
Sequence 保存 block table、cached token 数和生成状态。
Scheduler 在调度前检查 token budget 和 block budget。
ModelRunner 在执行 decode 时,拿到当前 token、block table、context length、slot mapping 和 physical KV cache pool。
PagedAttention 根据 block table 读取历史 K/V,根据 slot mapping 写入当前 token 的 K/V。
Engine 在采样后更新 Sequence 状态,并在请求结束时释放 block。
这条链路可以概括成:
Scheduler 选择谁运行
Block Manager 决定 cache 写到哪里
ModelRunner 执行模型
PagedAttention 读取历史 KV
Sampler 生成新 token
Sequence 更新状态
Block Manager 回收 finished 请求的 blocks
到这里,mini vLLM 的核心形态就已经非常清楚了。
它不再只是一个调用 past_key_values 的外壳,而是开始拥有自己的 cache 管理和 attention 访问方式。
这正是从“使用模型推理”走向“实现推理系统”的分界线。
小结
这一篇我们实现了教学版 PagedAttention 的核心思想。
PagedAttention 不改变 attention 的数学语义。它改变的是 KV cache 的存储和读取方式。
在普通 attention 中,历史 K 和 V 通常被看作连续数组。
在 PagedAttention 中,Sequence 的上下文逻辑上连续,但物理上分散在多个 KV cache blocks 中。Sequence 通过 block table 记录逻辑 block 到 physical block 的映射。
读取某个历史 token 的 KV 时,需要经过这样的映射链路:
logical token index
-> logical block id
-> physical block id
-> block offset
-> K/V cache address
教学版实现可以先根据 block table 把历史 K/V gather 成连续张量,再调用普通 attention。这个版本性能不好,但非常适合理解和验证正确性。
实现时最需要注意的是写入 slot 和读取 block table 必须一致,尤其要区分 generated tokens 和已经写入 KV cache 的 cached tokens。很多 PagedAttention 的隐蔽 bug 都来自这个状态语义不清。
到这里,我们已经完成了一个正确性优先的 PagedAttention 版本。
下一篇文章,我们会讨论如何从正确走向更快。
我们会分析教学版为什么慢,高性能 attention kernel 要解决哪些问题,以及 Triton、FlashAttention、FlashInfer 这类 backend 在推理系统中分别扮演什么角色。