从 model.generate 到手写 generate
1. 这一篇要解决什么问题
在上一篇文章里,我们从整体上看了一个 LLM 请求在推理系统里的生命周期。
一个请求会从 API 层进入系统,经过 tokenizer,变成内部 Request 对象,再进入 waiting 队列。调度器会决定它什么时候做 prefill,什么时候做 decode。模型执行完成后,sampler 会从 logits 中选出下一个 token。如果请求还没结束,它会继续进入下一轮 decode;如果已经结束,系统会释放它占用的 KV cache。
不过在真正开始写调度器、KV cache 和 PagedAttention 之前,我们还需要先回答一个更基础的问题:
一个 token 到底是怎么被生成出来的?
很多人在使用大语言模型时,最常写的是这样的代码:
outputs = model.generate(**inputs, max_new_tokens=128)
这行代码非常方便,但它隐藏了很多细节。
它隐藏了 tokenizer 如何把文本变成 token id。
它隐藏了模型 forward 之后 logits 是什么。
它隐藏了 greedy search、temperature、top k、top p 这些采样策略。
它也隐藏了停止条件是怎么判断的。
如果我们想实现自己的 mini vLLM,就不能一直依赖这个黑盒接口。我们需要把 model.generate() 拆开,用自己的代码完成一个最小推理闭环。
这一篇文章的目标很简单:
写一个最小版本的 generate 函数。
它暂时不考虑 KV cache,不考虑 batching,不考虑 continuous batching,也不考虑流式服务。我们只关心一件事:给定一个 prompt,如何一步一步生成新的 token。
完成这一篇后,我们会得到一个这样的接口:
text = simple_generate(
model=model,
tokenizer=tokenizer,
prompt="Explain KV cache in LLM inference.",
max_new_tokens=128,
temperature=0.8,
)
它看起来和 model.generate() 很像,但内部逻辑完全由我们自己控制。
这就是从“调用模型”走向“实现推理引擎”的第一步。
2. 准备模型和 tokenizer
为了降低理解成本,这一篇先直接使用 Hugging Face Transformers 加载模型。
后续我们不会长期依赖 model.generate(),但会继续复用 tokenizer 和模型权重。因为 mini vLLM 的重点不是从零训练一个模型,而是从零实现推理调度系统。
先写一个最小加载代码:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="cuda",
)
model.eval()
这里有几个点需要注意。
第一,AutoTokenizer 负责文本和 token id 之间的转换。
第二,AutoModelForCausalLM 加载的是 causal language model,也就是常见的 decoder only 生成式语言模型。
第三,model.eval() 会把模型切换到推理模式,关闭 dropout 等训练相关行为。
第四,推理时我们通常会配合 torch.inference_mode(),这样可以避免构建梯度图,减少显存占用。
如果本地没有 GPU,也可以先用 CPU 跑一个小模型。代码会慢一些,但不影响理解主流程。
3. 文本是怎么变成 token 的
模型不能直接处理字符串,它只能处理整数形式的 token id。
例如输入:
Explain KV cache in LLM inference.
经过 tokenizer 后,会变成类似这样的整数序列:
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids 的 shape 通常是:
[batch_size, sequence_length]
对于单请求推理来说,batch_size 等于 1。
我们可以写一个小例子观察它:
prompt = "Explain KV cache in LLM inference."
input_ids = tokenizer.encode(prompt, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(input_ids)
print(tokens)
你会看到一段文本被切成了多个 token。每个 token 对应一个整数 id。
从这一刻开始,模型关心的就不再是原始字符串,而是这一串 token id。
4. 模型 forward 会输出什么
接下来,我们把 input_ids 送进模型:
with torch.inference_mode():
outputs = model(input_ids=input_ids.to(model.device))
对于 causal language model 来说,outputs.logits 是最重要的结果。
它的 shape 通常是:
[batch_size, sequence_length, vocab_size]
其中:
batch_size 表示 batch 中有几个请求
sequence_length 表示当前输入序列长度
vocab_size 表示词表大小
如果我们输入了 10 个 token,模型就会输出 10 个位置上的 logits。
每个位置的 logits 都是一个长度为 vocab_size 的向量。它表示模型认为“这个位置的下一个 token 应该是什么”的原始分数。
对于文本生成来说,我们通常只关心最后一个位置的 logits。
next_token_logits = outputs.logits[:, -1, :]
它的 shape 是:
[batch_size, vocab_size]
对于单请求来说,就是:
[1, vocab_size]
这个向量里,每个元素对应词表中一个 token 的分数。分数越高,说明模型越倾向于选择这个 token 作为下一个 token。
5. 最简单的采样:greedy search
有了 next_token_logits 之后,最简单的生成方式就是选分数最高的 token。
这叫 greedy search。
next_token_id = torch.argmax(next_token_logits, dim=-1)
然后把这个 token 拼回输入序列后面:
input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)
这样,下一轮模型就会基于“原始 prompt 加新生成 token”的完整上下文继续生成。
我们可以把这个过程写成一个循环:
def greedy_generate(model, tokenizer, prompt, max_new_tokens=128):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
for _ in range(max_new_tokens):
with torch.inference_mode():
outputs = model(input_ids=input_ids)
next_token_logits = outputs.logits[:, -1, :]
next_token_id = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)
if next_token_id.item() == tokenizer.eos_token_id:
break
return tokenizer.decode(input_ids[0], skip_special_tokens=True)
这就是一个最小版本的 generate。
它已经包含了完整的生成闭环:
prompt
tokenize
model forward
取最后一个位置的 logits
选择下一个 token
拼回输入序列
判断是否结束
decode 成文本
这段代码很慢,也很朴素,但它非常重要。
因为后面的所有复杂系统,本质上都没有逃离这个循环。vLLM 做的事情,是把这个循环变成一个可以高效服务大量请求的推理引擎。
6. greedy search 的问题
greedy search 的优点是简单、稳定、可复现。
但它也有明显的问题:每一步都选当前概率最高的 token,容易导致输出过于保守,甚至重复。
比如模型在某一步认为:
token A 的概率是 0.31
token B 的概率是 0.29
token C 的概率是 0.28
greedy search 一定会选择 token A。
但实际上 token B 和 token C 也都很合理。对于开放式生成来说,我们往往希望模型保留一定随机性。
这时就需要 sampling。
sampling 的基本思路是:把 logits 转成概率分布,然后从这个概率分布里随机抽取一个 token。
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
不过,直接从完整词表中采样也会有问题。因为词表里有很多低概率 token,它们虽然概率很小,但仍然可能被抽中,从而导致输出质量变差。
所以我们通常会结合 temperature、top k 和 top p。
7. temperature:控制随机性
temperature 用来控制概率分布的平滑程度。
代码上非常简单:
logits = logits / temperature
当 temperature 较低时,分布会变得更尖锐,高概率 token 更容易被选中。
当 temperature 较高时,分布会变得更平滑,随机性更强。
如果 temperature 接近 0,效果会越来越接近 greedy search。
我们可以写一个函数:
def apply_temperature(logits, temperature):
if temperature is None or temperature <= 0:
return logits
return logits / temperature
实际工程里,temperature 等于 0 通常会被视为 greedy 模式,不再走随机采样。
8. top k:只从前 k 个 token 中采样
top k 的思路是:每一步只保留 logits 最高的 k 个 token,其余 token 全部屏蔽掉。
def apply_top_k(logits, top_k):
if top_k is None or top_k <= 0:
return logits
top_k = min(top_k, logits.size(-1))
values, _ = torch.topk(logits, top_k)
min_values = values[:, -1].unsqueeze(-1)
logits = torch.where(
logits < min_values,
torch.full_like(logits, float("-inf")),
logits,
)
return logits
这里的核心是把不在 top k 内的 token logits 设置为负无穷。这样经过 softmax 后,它们的概率就会变成 0。
9. top p:只从累计概率前 p 的 token 中采样
top p 也叫 nucleus sampling。
它不是固定保留前 k 个 token,而是先按概率从高到低排序,然后保留累计概率不超过 p 的那一部分 token。
例如:
token A: 0.40
token B: 0.25
token C: 0.15
token D: 0.10
token E: 0.10
如果 top p 等于 0.8,那么可能会保留 A、B、C,因为它们累计概率正好达到 0.8。
相比 top k,top p 更灵活。因为有些时候模型非常确定,只需要保留少量 token;有些时候模型不确定,就会保留更多候选 token。
一个简单实现如下:
def apply_top_p(logits, top_p):
if top_p is None or top_p <= 0 or top_p >= 1:
return logits
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_mask = cumulative_probs > top_p
sorted_mask[:, 1:] = sorted_mask[:, :-1].clone()
sorted_mask[:, 0] = False
sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))
new_logits = torch.full_like(logits, float("-inf"))
new_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)
return new_logits
这段代码做了三件事。
首先,把 logits 从大到小排序。
然后,计算排序后 token 的累计概率。
最后,把累计概率超过 top p 的 token 屏蔽掉,再恢复到原来的词表顺序。
10. 写一个 Sampler
现在我们可以把 temperature、top k、top p 封装成一个 Sampler。
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class SamplingParams:
max_new_tokens: int = 128
temperature: float = 1.0
top_k: Optional[int] = None
top_p: Optional[float] = None
class Sampler:
def __init__(self, params: SamplingParams):
self.params = params
def sample(self, logits: torch.Tensor) -> torch.Tensor:
if self.params.temperature == 0:
return torch.argmax(logits, dim=-1)
logits = apply_temperature(logits, self.params.temperature)
logits = apply_top_k(logits, self.params.top_k)
logits = apply_top_p(logits, self.params.top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
return next_token_id.squeeze(-1)
有了 Sampler 之后,generate 函数就不需要关心具体采样策略了。它只需要把 logits 交给 sampler,然后拿到下一个 token。
这正是我们后面做 mini vLLM 时会反复使用的设计思路:
让 engine 负责流程,让 sampler 负责采样,让 scheduler 负责调度,让 cache manager 负责缓存管理。
每个模块只做自己的事情。
11. 实现 simple_generate
现在我们可以组合出一个稍微完整一点的生成函数。
def simple_generate(
model,
tokenizer,
prompt: str,
sampling_params: SamplingParams,
) -> str:
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
sampler = Sampler(sampling_params)
prompt_length = input_ids.size(1)
for _ in range(sampling_params.max_new_tokens):
with torch.inference_mode():
outputs = model(input_ids=input_ids)
logits = outputs.logits[:, -1, :]
next_token_id = sampler.sample(logits)
input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)
if tokenizer.eos_token_id is not None:
if next_token_id.item() == tokenizer.eos_token_id:
break
generated_ids = input_ids[0, prompt_length:]
generated_text = tokenizer.decode(
generated_ids,
skip_special_tokens=True,
)
return generated_text
这里我们没有返回完整文本,而是只返回新生成的部分。
这是有意为之。
在推理服务里,用户通常已经知道自己的 prompt,更关心模型新增的输出。后面做流式输出时,我们也会每次返回新增 token 对应的文本片段。
调用方式如下:
params = SamplingParams(
max_new_tokens=128,
temperature=0.8,
top_k=50,
top_p=0.95,
)
text = simple_generate(
model=model,
tokenizer=tokenizer,
prompt="Explain KV cache in LLM inference.",
sampling_params=params,
)
print(text)
到这里,我们已经完成了一个手写版本的 generate。
12. 这个版本为什么很慢
虽然 simple_generate 可以工作,但它有一个非常大的问题:
每生成一个 token,它都会把完整的 input_ids 再送进模型算一遍。
例如 prompt 有 100 个 token,我们生成 50 个 token。
第 1 轮,模型处理 100 个 token。
第 2 轮,模型处理 101 个 token。
第 3 轮,模型处理 102 个 token。
一直到第 50 轮,模型处理 149 个 token。
这意味着模型在不断重复计算历史 token。
而这些历史 token 对应的 key 和 value 其实已经算过了。理论上,我们不应该每次都重新算。
这就是 KV cache 要解决的问题。
KV cache 的核心思想是:在 attention 里,把历史 token 的 key 和 value 保存下来。下一轮 decode 时,只需要计算新 token 的 key 和 value,然后和历史 cache 一起做 attention。
如果没有 KV cache,decode 会越来越慢。
如果有 KV cache,decode 每一轮主要只处理新增 token,同时访问历史 cache。
这也是 LLM 推理系统的第一个核心优化点。
不过在这一篇里,我们先不实现 KV cache。我们先把最小生成闭环彻底弄清楚。
下一篇,我们会在这个 generate 函数的基础上加入 KV cache。
13. 把 generate 抽象成推理引擎的 step
现在再往 mini vLLM 的方向思考一步。
simple_generate 是一个完整函数,它会一直循环,直到生成结束。
但真正的推理引擎通常不会这样写。
推理引擎更像是一个可以不断执行的状态机。每调用一次 step(),它就向前推进一小步。
比如:
engine.add_request(prompt="Explain KV cache.", sampling_params=params)
while engine.has_unfinished_requests():
outputs = engine.step()
for output in outputs:
print(output.text, end="", flush=True)
为什么要这么设计?
因为一旦我们进入多请求场景,就不能让某一个请求独占整个 generate 循环。
假设请求 A 要生成 1000 个 token,请求 B 此时刚刚到达。如果请求 A 一直占着循环不放,请求 B 就只能等待。
更合理的方式是:每一轮 step,系统从多个请求中挑选一批,一起执行一次模型 forward。执行完后,每个请求各自追加一个 token。然后系统进入下一轮 step。
这就是后面 continuous batching 的基础。
所以,从这一篇开始,我们要逐渐转变思维:
不要把 generate 看成一个一次性函数。
要把 generate 看成一个可以被调度器逐步推进的过程。
这一点非常重要。
14. 当前代码和 mini vLLM 的关系
这一篇的代码虽然简单,但它已经对应了 mini vLLM 里的几个关键模块。
SamplingParams 对应请求中的采样参数。
Sampler 对应后续的采样模块。
input_ids 对应一个请求当前维护的 token 序列。
model(input_ids=...) 对应后续的 Model Runner。
for 循环对应后续 engine 的多轮 step。
eos_token_id 判断对应请求完成条件。
也就是说,我们现在写的不是一次性的 demo,而是后续推理引擎的雏形。
后面我们会一步步把它拆开。
首先,把 token 序列和请求状态封装成 Sequence。
然后,把单请求循环改造成 engine.step。
再之后,加入多个 Sequence,让它们共享同一个调度循环。
最后,加入 KV cache 和 block manager,让每个请求不再反复计算历史 token。
15. 小结
这一篇我们完成了一个非常重要的基础工作:手写了一个最小版本的 generate。
它做了这些事情:
把 prompt 转成 token id
调用模型 forward
取最后一个位置的 logits
使用 greedy search 或 sampling 选择下一个 token
把新 token 拼回输入序列
判断 EOS 和 max_new_tokens
把生成结果 decode 成文本
这个版本还很慢,因为每一轮都会重新计算完整上下文。但它已经把 LLM 生成的核心循环暴露了出来。
下一篇文章,我们会开始优化这个循环。
我们会引入 KV cache,解释 attention 里的 key 和 value 到底是什么,为什么它们可以被缓存,以及如何在代码里使用 past_key_values 避免重复计算历史 token。
到那时,我们的 generate 函数会从“能跑”变成“更接近真实推理引擎”。