AI教母李飞飞团队训练出媲美DeepSeek R1推理模型,云计算费用不到50美元附代码

【论文标题】 s1: Simple test-time scaling

【论文链接】 https://arxiv.org/abs/2501.19393v2

【代码链接】 https://github.com/simplescaling/s1

【论文单位/作者】斯坦福大学/李飞飞团队

注:“不到50美元”仅为云计算服务费用,不包括服务器、显卡等硬件投入费用,因这部分已经由云厂商承担。

图片

【摘要】✨

2025 年 1 月,李飞飞团队提出了一种极简的测试时扩展(test-time scaling)方法,仅需对预训练模型进行少量监督微调(SFT)并结合动态推理控制技术,即可显著提升语言模型的数学推理能力。团队精心构建了包含 1000 个问题及对应推理过程的小型数据集 s1K,该数据集依据难度、多样性和质量三大标准进行严格筛选。同时,开发了预算强制(budget forcing)技术,此技术可在测试阶段精准控制模型的思考时间,灵活地强制终止或延长模型的推理过程。

【技术亮点/创新点】✨

  1. s1K数据集设计
    从59K初始数据中筛选,确保问题难度高(基于模型错误率和推理长度)、领域多样(覆盖数学、物理、化学等)、格式规范; 仅用1K样本微调即接近全量数据性能,验证“少而精”的数据价值。
  2. 预算强制(Budget Forcing)
    预算强制是一种在测试时控制模型计算资源使用的方法,旨在通过调节模型的“思考”时间来优化其性能。具体来说:
    • 设定步骤限制:为模型设置一个固定的步骤数限制,比如16、32等。每个步骤可以包含一定数量的标记(tokens),这些标记代表了模型的推理过程。
    • 控制结束条件:当模型尝试结束其推理过程时,如果尚未达到预设的步骤限制,则通过追加“Wait”标记的方式强迫模型继续思考。这实际上是在告诉模型:“你还没有完成思考,请再考虑一下你的答案。”

      如果达到了步骤限制,即使模型还想继续生成新的标记,系统也会强行终止其思考过程,并促使模型进入回答模式。

    • 结果分析:实验表明,这种方法有助于提高模型的准确性。例如,在AIME24任务中,当允许模型进行更长时间的思考(即增加步骤数)时,其表现从23.3%提升到了36.7%。
图片
在这里插入图片描述
  1. 并行扩展尝试:并行扩展尝试指的是利用多种策略来增加模型在测试时的计算负担以期获得更好的性能。以下是具体的实现方式之一:
    • 多数投票机制:对于每一个输入问题,运行多个独立的模型实例(如64次),然后根据大多数模型的选择来决定最终的答案。这种做法类似于民主投票,认为多数意见往往更接近正确答案。
    • 例如,给定一个样本,执行64次评估,温度参数设为1(这意味着较高的随机性)。然后,比较不同次数(2, 4, 8, 16, 32, 和 64)下的多数投票结果,观察随着参与投票的模型数量增加,准确性的变化情况。
    • 尽管增加了计算成本,但并行扩展尝试并不总是能带来预期中的性能提升。相比之下,预算强制因其简单且有效的特性而被证明是一个更为成功的策略。

【工作原理/方法】🔍

  1. 数据蒸馏
    • 使用 Gemini API 生成推理链:研究人员借助 Google Gemini Flash Thinking API,为收集到的大量问题生成推理过程和解决方案,以此获得高质量的推理链。这一过程依赖于 Gemini API 强大的推理能力,能够为后续构建数据集提供丰富的素材。
    • 结合人工筛选构建 s1K:在生成推理链后,并非直接使用所有数据,而是进行了人工筛选。通过严格的筛选过程,去除低质量数据,如存在 API 错误、格式问题的数据,确保数据的高质量。
    • 通过模型错误率和领域分类器确保数据多样性与挑战性:为保证数据集的质量,利用模型错误率来衡量问题难度。通过让 Qwen2.5-7B-Instruct 和 Qwen2.532B-Instruct 两个模型对每个问题进行解答,并由 Claude 3.5 Sonnet 评估正确性,去除模型能轻易解答的问题,留下更具挑战性的题目。在多样性方面,使用 Claude 3.5 Sonnet 基于美国数学学会的数学学科分类(MSC)系统对问题进行领域分类,然后从不同领域中选取问题,确保数据集涵盖多个领域,具有多样性。最终构建出了包含 1000 个高质量、多样且有难度问题的 s1K 数据集。
  2. 预算强制解码
    • 最大标记限制:在测试时,当模型生成的思考标记数量超过预先设定的阈值时,就插入终止符(如 “Final Answer:” 和思考结束令牌分隔符),使模型提前结束思考阶段,并给出当前的最佳答案。这种方式可以控制模型在有限的计算资源下及时输出结果,避免模型过度思考,浪费计算资源。
    • 最小标记扩展:如果希望模型在某个问题上花费更多的测试时计算资源,即希望模型进行更深入的推理,可以抑制模型生成终止符,并且在模型当前的推理轨迹后追加 “Wait”。这样做是为了鼓励模型继续探索和思考,通过延长推理步骤,让模型有机会修正之前的推理错误或进行更全面的分析,从而得到更好的答案。
    • 支持线性扩展性能:预算强制方法能够使模型的性能随着测试时计算量的增加而呈现线性扩展的趋势。以 AIME24 任务为例,随着分配给模型更多的思考令牌(即增加计算量),模型的准确率稳定提升。这表明预算强制方法有效地利用了测试时计算资源,提升了模型在推理任务中的表现。
  3. 对比实验
    • 验证并行扩展效果弱于顺序扩展:研究中对比了并行扩展和顺序扩展两种方式。并行扩展以多数投票为例,通过多次生成不同的解决方案,然后选取出现频率最高的结果作为最终答案。而顺序扩展则是基于之前的推理结果逐步生成后续的解决方案,预算强制属于顺序扩展的一种。实验结果显示,并行扩展的效果不如顺序扩展,凸显了预算强制这种顺序扩展方法在提升模型性能方面的优势。例如,在对 Qwen2.532B-Instruct 模型的实验中,通过多数投票进行并行扩展的性能无法与使用预算强制的 s1-32B 模型相比。
    • 展示模型在有限上下文窗口下的性能饱和现象:实验发现,当不断增加测试时计算资源(如通过预算强制让模型思考更多令牌)时,模型的性能最终会趋于平稳,不再提升,即出现性能饱和现象。这是因为模型的上下文窗口有限,随着推理的进行,模型可能会陷入重复思考或超出上下文窗口的限制,导致无法进一步提升性能。
    • 提出结合树搜索的混合方案:为了解决模型在有限上下文窗口下的性能限制问题,研究提出结合树搜索的混合方案。例如使用 REBASE(一种基于树搜索的方法),它利用过程奖励模型来平衡搜索过程中的探索和剪枝。通过这种方式,在遇到复杂问题时,模型可以通过树搜索更有效地探索解决方案空间,一定程度上突破上下文窗口的限制,提升模型的性能。实验表明,这种结合树搜索的方法在扩展测试时计算资源方面比多数投票更有效,甚至在某些情况下优于单纯的顺序扩展。

【实验结果】📈

  1. 基准测试
    AIME24:56.7%准确率(超越o1-preview的44.6%);
    MATH500:93.0%(接近o1-preview的94.8%);
    GPQA Diamond:59.6%(接近专家水平的69.7%)。
    图片
  2. 效率对比
    仅需1K样本微调,性能超越使用800K样本的DeepSeek R1蒸馏模型;训练成本仅7 H100 GPU小时,远低于全量数据训练的394小时。

【总结】

这篇文章在某种程度上证明了极简数据+动态控制可大幅提升语言模型推理能力,挑战了传统大规模训练范式。开源模型与代码为社区提供了透明高效的解决方案,未来可探索与强化学习结合进一步突破性能上限。核心启示:高质量数据筛选测试时计算优化是解锁模型潜力的关键。

【代码】

Structure
eval/: Evaluation scripts
data/: Synthetic data creation scripts & co
train/: Training scripts
Inference
vLLM
Install the vllm library and run:

from vllm import LLM, SamplingParams
model = LLM(
    "simplescaling/s1-32B",
    tensor_parallel_size=2,
)
tok = AutoTokenizer.from_pretrained("simplescaling/s1-32B")

stop_token_ids = tok("<|im_end|>")["input_ids"]

sampling_params = SamplingParams(
    max_tokens=32768,
    min_tokens=0,
    stop_token_ids=stop_token_ids,
)

prompt = "How many r in raspberry"
prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"

o = model.generate(prompt, sampling_params=sampling_params)
print(o[0].outputs[0].text)
vLLM with budget forcing
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Decide on a token limit for thinking; As the model's max tokens is 32768, 32000 usually ensures there is enough space for the model to still answer
MAX_TOKENS_THINKING = 32000
# Decide how often to ignore end-of-thinking token
NUM_IGNORE = 1

model = LLM(
    "simplescaling/s1-32B",
    tensor_parallel_size=2,
)
tok = AutoTokenizer.from_pretrained(
    "simplescaling/s1-32B"
)

stop_token_ids = tok("<|im_end|>")["input_ids"]
sampling_params = SamplingParams(
    max_tokens=32768,
    min_tokens=0,
    stop_token_ids=stop_token_ids,
    skip_special_tokens=False,
    temperature=0.0,
)

# For the exact raspberry sample in the paper, change
# model to `qfq/1k_qr_bt_dm_po_steps` (an earlier version of s1)
# & prompt to `How many r in raspberry?`
prompts = [
    "How many r in raspberry",
]

for i, p in enumerate(prompts):
    prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + p + "<|im_end|>\n<|im_start|>assistant\n"
    stop_token_ids = tok("<|im_start|><|im_end|>")["input_ids"]
    sampling_params = SamplingParams(
        max_tokens=MAX_TOKENS_THINKING,
        min_tokens=0,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=False,
        temperature=0.0,
    )
    prompt += "<|im_start|>think"
    o = model.generate(
        prompt,
        sampling_params=sampling_params
    )
    ignore_str = "Wait"
    max_tokens_thinking_tmp = MAX_TOKENS_THINKING
    # Num of times to skip stop token
    for i in range(NUM_IGNORE):
        max_tokens_thinking_tmp -= len(o[0].outputs[0].token_ids)
        prompt += o[0].outputs[0].text + ignore_str
        sampling_params = SamplingParams(
            max_tokens=max_tokens_thinking_tmp,
            min_tokens=1,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=False,
            temperature=0.0,
        )
        o = model.generate(
            prompt,
            sampling_params=sampling_params
        )
    ### Final answer ###
    prompt += o[0].outputs[0].text
    stop_token_ids = tok("<|im_end|>")["input_ids"]
    sampling_params = SamplingParams(
        max_tokens=32768,
        min_tokens=0,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=False,
        temperature=0.0,
    )
    o = model.generate(
        prompt,
        sampling_params=sampling_params,
    )
    print("With budget forcing:")
    print(prompt + o[0].outputs[0].text)
transformers
Install the transformers & torch libraries and run:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "simplescaling/s1-32B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "How many r in raspberry"
messages = [
    {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
Training
To run training, you can find our script at train/sft.py which you can invoke via one of the train/sft*sh scripts which in turn you can launch via train/launch.sh if you are on a SLURM cluster (requires editing the file for your cluster setup).

来源:AI前沿速递

THE END