从 model.generate 到手写 generate

1. 这一篇要解决什么问题

在上一篇文章里,我们从整体上看了一个 LLM 请求在推理系统里的生命周期。

一个请求会从 API 层进入系统,经过 tokenizer,变成内部 Request 对象,再进入 waiting 队列。调度器会决定它什么时候做 prefill,什么时候做 decode。模型执行完成后,sampler 会从 logits 中选出下一个 token。如果请求还没结束,它会继续进入下一轮 decode;如果已经结束,系统会释放它占用的 KV cache。

不过在真正开始写调度器、KV cache 和 PagedAttention 之前,我们还需要先回答一个更基础的问题:

一个 token 到底是怎么被生成出来的?

很多人在使用大语言模型时,最常写的是这样的代码:

outputs = model.generate(**inputs, max_new_tokens=128)

这行代码非常方便,但它隐藏了很多细节。

它隐藏了 tokenizer 如何把文本变成 token id。

它隐藏了模型 forward 之后 logits 是什么。

它隐藏了 greedy search、temperature、top k、top p 这些采样策略。

它也隐藏了停止条件是怎么判断的。

如果我们想实现自己的 mini vLLM,就不能一直依赖这个黑盒接口。我们需要把 model.generate() 拆开,用自己的代码完成一个最小推理闭环。

这一篇文章的目标很简单:

写一个最小版本的 generate 函数。

它暂时不考虑 KV cache,不考虑 batching,不考虑 continuous batching,也不考虑流式服务。我们只关心一件事:给定一个 prompt,如何一步一步生成新的 token。

完成这一篇后,我们会得到一个这样的接口:

text = simple_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Explain KV cache in LLM inference.",
    max_new_tokens=128,
    temperature=0.8,
)

它看起来和 model.generate() 很像,但内部逻辑完全由我们自己控制。

这就是从“调用模型”走向“实现推理引擎”的第一步。

2. 准备模型和 tokenizer

为了降低理解成本,这一篇先直接使用 Hugging Face Transformers 加载模型。

后续我们不会长期依赖 model.generate(),但会继续复用 tokenizer 和模型权重。因为 mini vLLM 的重点不是从零训练一个模型,而是从零实现推理调度系统。

先写一个最小加载代码:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cuda",
)

model.eval()

这里有几个点需要注意。

第一,AutoTokenizer 负责文本和 token id 之间的转换。

第二,AutoModelForCausalLM 加载的是 causal language model,也就是常见的 decoder only 生成式语言模型。

第三,model.eval() 会把模型切换到推理模式,关闭 dropout 等训练相关行为。

第四,推理时我们通常会配合 torch.inference_mode(),这样可以避免构建梯度图,减少显存占用。

如果本地没有 GPU,也可以先用 CPU 跑一个小模型。代码会慢一些,但不影响理解主流程。

3. 文本是怎么变成 token 的

模型不能直接处理字符串,它只能处理整数形式的 token id。

例如输入:

Explain KV cache in LLM inference.

经过 tokenizer 后,会变成类似这样的整数序列:

input_ids = tokenizer.encode(prompt, return_tensors="pt")

input_ids 的 shape 通常是:

[batch_size, sequence_length]

对于单请求推理来说,batch_size 等于 1。

我们可以写一个小例子观察它:

prompt = "Explain KV cache in LLM inference."

input_ids = tokenizer.encode(prompt, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

print(input_ids)
print(tokens)

你会看到一段文本被切成了多个 token。每个 token 对应一个整数 id。

从这一刻开始,模型关心的就不再是原始字符串,而是这一串 token id。

4. 模型 forward 会输出什么

接下来,我们把 input_ids 送进模型:

with torch.inference_mode():
    outputs = model(input_ids=input_ids.to(model.device))

对于 causal language model 来说,outputs.logits 是最重要的结果。

它的 shape 通常是:

[batch_size, sequence_length, vocab_size]

其中:

batch_size 表示 batch 中有几个请求
sequence_length 表示当前输入序列长度
vocab_size 表示词表大小

如果我们输入了 10 个 token,模型就会输出 10 个位置上的 logits。

每个位置的 logits 都是一个长度为 vocab_size 的向量。它表示模型认为“这个位置的下一个 token 应该是什么”的原始分数。

对于文本生成来说,我们通常只关心最后一个位置的 logits。

next_token_logits = outputs.logits[:, -1, :]

它的 shape 是:

[batch_size, vocab_size]

对于单请求来说,就是:

[1, vocab_size]

这个向量里,每个元素对应词表中一个 token 的分数。分数越高,说明模型越倾向于选择这个 token 作为下一个 token。

有了 next_token_logits 之后,最简单的生成方式就是选分数最高的 token。

这叫 greedy search。

next_token_id = torch.argmax(next_token_logits, dim=-1)

然后把这个 token 拼回输入序列后面:

input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)

这样,下一轮模型就会基于“原始 prompt 加新生成 token”的完整上下文继续生成。

我们可以把这个过程写成一个循环:

def greedy_generate(model, tokenizer, prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    for _ in range(max_new_tokens):
        with torch.inference_mode():
            outputs = model(input_ids=input_ids)

        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)

        if next_token_id.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

这就是一个最小版本的 generate。

它已经包含了完整的生成闭环:

prompt
tokenize
model forward
取最后一个位置的 logits
选择下一个 token
拼回输入序列
判断是否结束
decode 成文本

这段代码很慢,也很朴素,但它非常重要。

因为后面的所有复杂系统,本质上都没有逃离这个循环。vLLM 做的事情,是把这个循环变成一个可以高效服务大量请求的推理引擎。

6. greedy search 的问题

greedy search 的优点是简单、稳定、可复现。

但它也有明显的问题:每一步都选当前概率最高的 token,容易导致输出过于保守,甚至重复。

比如模型在某一步认为:

token A 的概率是 0.31
token B 的概率是 0.29
token C 的概率是 0.28

greedy search 一定会选择 token A。

但实际上 token B 和 token C 也都很合理。对于开放式生成来说,我们往往希望模型保留一定随机性。

这时就需要 sampling。

sampling 的基本思路是:把 logits 转成概率分布,然后从这个概率分布里随机抽取一个 token。

probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)

不过,直接从完整词表中采样也会有问题。因为词表里有很多低概率 token,它们虽然概率很小,但仍然可能被抽中,从而导致输出质量变差。

所以我们通常会结合 temperature、top k 和 top p。

7. temperature:控制随机性

temperature 用来控制概率分布的平滑程度。

代码上非常简单:

logits = logits / temperature

当 temperature 较低时,分布会变得更尖锐,高概率 token 更容易被选中。

当 temperature 较高时,分布会变得更平滑,随机性更强。

如果 temperature 接近 0,效果会越来越接近 greedy search。

我们可以写一个函数:

def apply_temperature(logits, temperature):
    if temperature is None or temperature <= 0:
        return logits
    return logits / temperature

实际工程里,temperature 等于 0 通常会被视为 greedy 模式,不再走随机采样。

8. top k:只从前 k 个 token 中采样

top k 的思路是:每一步只保留 logits 最高的 k 个 token,其余 token 全部屏蔽掉。

def apply_top_k(logits, top_k):
    if top_k is None or top_k <= 0:
        return logits

    top_k = min(top_k, logits.size(-1))
    values, _ = torch.topk(logits, top_k)

    min_values = values[:, -1].unsqueeze(-1)
    logits = torch.where(
        logits < min_values,
        torch.full_like(logits, float("-inf")),
        logits,
    )
    return logits

这里的核心是把不在 top k 内的 token logits 设置为负无穷。这样经过 softmax 后,它们的概率就会变成 0。

9. top p:只从累计概率前 p 的 token 中采样

top p 也叫 nucleus sampling。

它不是固定保留前 k 个 token,而是先按概率从高到低排序,然后保留累计概率不超过 p 的那一部分 token。

例如:

token A: 0.40
token B: 0.25
token C: 0.15
token D: 0.10
token E: 0.10

如果 top p 等于 0.8,那么可能会保留 A、B、C,因为它们累计概率正好达到 0.8。

相比 top k,top p 更灵活。因为有些时候模型非常确定,只需要保留少量 token;有些时候模型不确定,就会保留更多候选 token。

一个简单实现如下:

def apply_top_p(logits, top_p):
    if top_p is None or top_p <= 0 or top_p >= 1:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_mask = cumulative_probs > top_p

    sorted_mask[:, 1:] = sorted_mask[:, :-1].clone()
    sorted_mask[:, 0] = False

    sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))

    new_logits = torch.full_like(logits, float("-inf"))
    new_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)

    return new_logits

这段代码做了三件事。

首先,把 logits 从大到小排序。

然后,计算排序后 token 的累计概率。

最后,把累计概率超过 top p 的 token 屏蔽掉,再恢复到原来的词表顺序。

10. 写一个 Sampler

现在我们可以把 temperature、top k、top p 封装成一个 Sampler。

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingParams:
    max_new_tokens: int = 128
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: Optional[float] = None


class Sampler:
    def __init__(self, params: SamplingParams):
        self.params = params

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        if self.params.temperature == 0:
            return torch.argmax(logits, dim=-1)

        logits = apply_temperature(logits, self.params.temperature)
        logits = apply_top_k(logits, self.params.top_k)
        logits = apply_top_p(logits, self.params.top_p)

        probs = torch.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id.squeeze(-1)

有了 Sampler 之后,generate 函数就不需要关心具体采样策略了。它只需要把 logits 交给 sampler,然后拿到下一个 token。

这正是我们后面做 mini vLLM 时会反复使用的设计思路:

让 engine 负责流程,让 sampler 负责采样,让 scheduler 负责调度,让 cache manager 负责缓存管理。

每个模块只做自己的事情。

11. 实现 simple_generate

现在我们可以组合出一个稍微完整一点的生成函数。

def simple_generate(
    model,
    tokenizer,
    prompt: str,
    sampling_params: SamplingParams,
) -> str:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    sampler = Sampler(sampling_params)

    prompt_length = input_ids.size(1)

    for _ in range(sampling_params.max_new_tokens):
        with torch.inference_mode():
            outputs = model(input_ids=input_ids)

        logits = outputs.logits[:, -1, :]
        next_token_id = sampler.sample(logits)

        input_ids = torch.cat([input_ids, next_token_id[:, None]], dim=-1)

        if tokenizer.eos_token_id is not None:
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    generated_ids = input_ids[0, prompt_length:]
    generated_text = tokenizer.decode(
        generated_ids,
        skip_special_tokens=True,
    )

    return generated_text

这里我们没有返回完整文本,而是只返回新生成的部分。

这是有意为之。

在推理服务里,用户通常已经知道自己的 prompt,更关心模型新增的输出。后面做流式输出时,我们也会每次返回新增 token 对应的文本片段。

调用方式如下:

params = SamplingParams(
    max_new_tokens=128,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
)

text = simple_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Explain KV cache in LLM inference.",
    sampling_params=params,
)

print(text)

到这里,我们已经完成了一个手写版本的 generate。

12. 这个版本为什么很慢

虽然 simple_generate 可以工作,但它有一个非常大的问题:

每生成一个 token,它都会把完整的 input_ids 再送进模型算一遍。

例如 prompt 有 100 个 token,我们生成 50 个 token。

第 1 轮,模型处理 100 个 token。

第 2 轮,模型处理 101 个 token。

第 3 轮,模型处理 102 个 token。

一直到第 50 轮,模型处理 149 个 token。

这意味着模型在不断重复计算历史 token。

而这些历史 token 对应的 key 和 value 其实已经算过了。理论上,我们不应该每次都重新算。

这就是 KV cache 要解决的问题。

KV cache 的核心思想是:在 attention 里,把历史 token 的 key 和 value 保存下来。下一轮 decode 时,只需要计算新 token 的 key 和 value,然后和历史 cache 一起做 attention。

如果没有 KV cache,decode 会越来越慢。

如果有 KV cache,decode 每一轮主要只处理新增 token,同时访问历史 cache。

这也是 LLM 推理系统的第一个核心优化点。

不过在这一篇里,我们先不实现 KV cache。我们先把最小生成闭环彻底弄清楚。

下一篇,我们会在这个 generate 函数的基础上加入 KV cache。

13. 把 generate 抽象成推理引擎的 step

现在再往 mini vLLM 的方向思考一步。

simple_generate 是一个完整函数,它会一直循环,直到生成结束。

但真正的推理引擎通常不会这样写。

推理引擎更像是一个可以不断执行的状态机。每调用一次 step(),它就向前推进一小步。

比如:

engine.add_request(prompt="Explain KV cache.", sampling_params=params)

while engine.has_unfinished_requests():
    outputs = engine.step()
    for output in outputs:
        print(output.text, end="", flush=True)

为什么要这么设计?

因为一旦我们进入多请求场景,就不能让某一个请求独占整个 generate 循环。

假设请求 A 要生成 1000 个 token,请求 B 此时刚刚到达。如果请求 A 一直占着循环不放,请求 B 就只能等待。

更合理的方式是:每一轮 step,系统从多个请求中挑选一批,一起执行一次模型 forward。执行完后,每个请求各自追加一个 token。然后系统进入下一轮 step。

这就是后面 continuous batching 的基础。

所以,从这一篇开始,我们要逐渐转变思维:

不要把 generate 看成一个一次性函数。

要把 generate 看成一个可以被调度器逐步推进的过程。

这一点非常重要。

14. 当前代码和 mini vLLM 的关系

这一篇的代码虽然简单,但它已经对应了 mini vLLM 里的几个关键模块。

SamplingParams 对应请求中的采样参数。

Sampler 对应后续的采样模块。

input_ids 对应一个请求当前维护的 token 序列。

model(input_ids=...) 对应后续的 Model Runner。

for 循环对应后续 engine 的多轮 step。

eos_token_id 判断对应请求完成条件。

也就是说,我们现在写的不是一次性的 demo,而是后续推理引擎的雏形。

后面我们会一步步把它拆开。

首先,把 token 序列和请求状态封装成 Sequence。

然后,把单请求循环改造成 engine.step。

再之后,加入多个 Sequence,让它们共享同一个调度循环。

最后,加入 KV cache 和 block manager,让每个请求不再反复计算历史 token。

15. 小结

这一篇我们完成了一个非常重要的基础工作:手写了一个最小版本的 generate。

它做了这些事情:

  1. 把 prompt 转成 token id

  2. 调用模型 forward

  3. 取最后一个位置的 logits

  4. 使用 greedy search 或 sampling 选择下一个 token

  5. 把新 token 拼回输入序列

  6. 判断 EOS 和 max_new_tokens

  7. 把生成结果 decode 成文本

这个版本还很慢,因为每一轮都会重新计算完整上下文。但它已经把 LLM 生成的核心循环暴露了出来。

下一篇文章,我们会开始优化这个循环。

我们会引入 KV cache,解释 attention 里的 key 和 value 到底是什么,为什么它们可以被缓存,以及如何在代码里使用 past_key_values 避免重复计算历史 token。

到那时,我们的 generate 函数会从“能跑”变成“更接近真实推理引擎”。