复现 DeepSeek「顿悟时刻」:10 美元打造 AI 推理突破 [译]
我们能否在仅使用约 48GB RTX6000 和 10 美元的有限计算资源下,为一个 3B 模型注入回溯、自我反思、逻辑推理等能力?这个问题引发了我们的探索。这篇文章将介绍我们使用轻量级强化学习算法 —— Reinforce-Lite 的突破性发现。通过接下来的示例,我们将看到经过端到端 RL 微调的模型如何展示出智能、回溯、自我反思和逻辑推理等迹象。
强化学习的魅力与挑战
强化学习是一种屡次带来惊艳结果的强大算法,例如 Deepmind 的 AlphaGo、OpenAI 的 DOTA5、Mujocu 和 Atari 实验、用于对齐 LLM 的 RLHF,以及最近 DeepSeek 在 RL 上的全力投入。
然而,强化学习因其多个活动部件而存在若干复杂性,需要对重要元素进行精细设计,例如适当的信用分配、演员/评论家的超参数调整、RL 算法类型(基于模型或无模型)等,因此其在更广泛范围内的应用受到限制。
在 LLM 环境中使用强化学习可能涉及多达 5 个模型:
- 策略模型 —— 正在训练的模型
- 旧策略模型 —— 用于计算代理
- 参考模型 —— 用于计算 KL 散度
- 奖励模型 —— 用于学习奖励函数
- 评论家模型 —— 用于计算价值函数
这些复杂的组件不仅带来了巨大的计算负担,还带来了训练稳定性的挑战。这促使我们思考:是否存在一种更简单、更高效的方法?
Reinforce-Lite:一个更简单的解决方案
为了应对这些挑战,我们从头开始重新构想整个算法,回归基本原理,提出了一个更简单的替代方案:Reinforce-Lite。这个方案通过单一策略网络实现稳定训练,同时消除了对代理比率/旧策略模型的需求。
为什么传统 PPO/GRPO 的代理比率是一种过度设计?
在传统强化学习环境(如 Mujoco、Atari、Dota 等)中,PPO 需要对每个批次进行多次更新,因为数据收集成本高昂,重用样本可以提高样本效率。然而,在 LLM 的场景下,这种方法反而成为了累赘:
- LLM 可以并行生成多样化的响应,自然创建丰富的数据集,无需重复更新
- 所有响应都可以使用相同的策略网络生成,并在序列生成结束时一次性反向传播梯度
- 在高维文本生成空间中,每个批次的多次更新可能导致过拟合,而非有意义的策略改进
通过每个批次只进行一次更新,结合组标准化等技术,我们可以实现稳定的训练,同时显著降低计算成本。这种简化的优化过程不仅高效,还消除了跟踪旧策略模型以计算代理比率的需要。
Reinforce-Lite 的核心设计
我们的算法做出了以下关键简化:
- KL 散度移除,不再需要参考模型 ❌ —— 改用梯度裁剪替代,虽然不那么自适应但足够有效
- 代理比率移除,不再需要旧策略模型 ❌ —— 使优化流程更加直接
- 使用组相对奖励(借鉴 DeepSeek 的 GRPO 风格)计算优势,不再需要评论家模型 ❌
这些简化让我们得到了一个轻量级的强化学习算法,将优化问题简化为经典的 Reinforce 形式。在优势计算中,我们使用组相对策略优化的标准化技术,每组包含 10 个问题响应,利用标准化来减少梯度更新的方差。
PyTorch 实现
让我们来看看 Reinforce-Lite 的核心代码实现:
def reinforce_lite(batch, policy_model, tokenizer, device, step, save_dir):
policy_model.train()
prompts, targets = zip(*batch)
batch_size = len(prompts)
evaluated_group = 0
all_logprobs = []
all_rewards = []
all_responses = []
all_lengths = []
for group_idx in range(config.GROUP_SIZE):
formatted_prompts = [format_prompt(p, tokenizer) for p in prompts]
inputs = tokenizer(
formatted_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=config.MAX_SEQ_LENGTH
).to(device)
generate_kwargs = {
**inputs,
"max_new_tokens": config.MAX_NEW_TOKENS,
"do_sample": True,
"temperature": 0.7,
"top_p": 0.9,
"pad_token_id": tokenizer.pad_token_id,
"return_dict_in_generate": True,
}
if group_idx == evaluated_group:
generated = policy_model.generate(**generate_kwargs)
generated_ids = generated.sequences
outputs = policy_model(
generated_ids,
attention_mask=(generated_ids != tokenizer.pad_token_id).long()
)
prompt_length = inputs.input_ids.shape[1]
response_length = generated_ids.shape[1] - prompt_length
if response_length > 0:
logits = outputs.logits[:, prompt_length-1:-1, :]
response_tokens = generated_ids[:, prompt_length:]
log_probs = torch.log_softmax(logits, dim=-1)
token_log_probs = torch.gather(log_probs, -1, response_tokens.unsqueeze(-1)).squeeze(-1)
sequence_log_probs = token_log_probs.sum(dim=1)
else:
sequence_log_probs = torch.zeros(batch_size, device=device)
else:
with torch.no_grad():
generated = policy_model.generate(**generate_kwargs)
sequence_log_probs = torch.zeros(batch_size, device=device)
responses = tokenizer.batch_decode(
generated.sequences[:, inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
rewards = torch.tensor([get_reward(resp, tgt) for resp, tgt in zip(responses, targets)], device=device)
all_responses.extend(responses)
all_rewards.append(rewards)
all_logprobs.append(sequence_log_probs)
all_lengths.extend([len(r.split()) for r in responses])
rewards_tensor = torch.stack(all_rewards)
logprobs_tensor = torch.stack(all_logprobs)
evaluated_rewards = rewards_tensor[evaluated_group]
others_rewards = torch.cat([
rewards_tensor[:evaluated_group],
rewards_tensor[evaluated_group+1:]
], dim=0)
baseline = others_rewards.mean(dim=0)
advantages = (evaluated_rewards - baseline) / (others_rewards.std(dim=0) + 1e-8)
advantages = torch.clamp(advantages, -2.0, 2.0)
policy_loss = -(logprobs_tensor[evaluated_group] * advantages.detach()).mean()
return policy_loss, rewards_tensor.mean().item(), policy_loss.item(), 0.0, all_responses[0], all_lengths
训练过程
我们的训练过程包含以下步骤:
- 初始化指令调整的 LLM,通过适当的提示使其在 标签中包含推理步骤
- 定义奖励函数,用于评估模型输出(例如 GSM8K 数学推理任务中的正确性)
- 直接优化策略,根据奖励计算梯度,无需代理损失
- 计算优势,使用组相对标准化消除对评论家模型的需求
- 使用标准的对数概率梯度技巧更新模型
GSM 8K 数据集实验
为了验证我们的方法,我们使用了 GSM8K 小学数学数据集。这个数据集的问题格式如下:
问题:Natalia 在四月向 48 个朋友卖了夹子,五月卖了四月一半数量的夹子。Natalia 在四月和五月总共卖了多少夹子? 答案:Natalia 在五月卖了 48/2 = <<48/2=24>>24 个夹子。Natalia 在四月和五月总共卖了 48+24 = <<48+24=72>>72 个夹子。#### 72
我们只关注 ### 之后的最终答案,要求模型在
数据处理和奖励机制
处理输入数据的代码:
def format_prompt(question: str) -> str:
return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Solve this math problem: {question}
Show your reasoning first in <th>, then put the final answer in \\boxed{{}}."""
我们的奖励机制简单明了:
- 答案错误:奖励 -1
- 答案正确:奖励 +1
具体实现如下:
def get_reward(completion: str, target: str) -> float:
reward = -1.0
try:
completion = completion.strip()
start_tag = "<answer>"
end_tag = "</answer>"
start_idx = completion.rfind(start_tag)
if start_idx != -1:
substring_after_start = completion[start_idx + len(start_tag):]
end_idx = substring_after_start.find(end_tag)
if end_idx != -1:
answer = substring_after_start[:end_idx].strip()
if not answer and end_idx > 0:
answer = substring_after_start[:end_idx].strip()
numbers = ''.join(char for char in answer if char.isdigit() or char == '.')
if numbers:
generated_num = float(numbers)
target_num = float(str(target).strip())
if abs(generated_num - target_num) < 1e-6:
reward = 1.0
except Exception as e:
pass
return reward
实验结果
训练细节
我们在 RTX A6000 上使用 Reinforce-Lite 算法训练 3B 模型 12 小时:
- 使用大小为 10 的组,适应计算限制
- 在训练初期,模型不断尝试增加输出长度
- 由于 48GB 显存限制,超过 1024 token 时常遇到 OOM
- 训练数百次迭代后,观察到策略网络在探索不同策略时的值波动
性能对比
在 GSM8K 上,Reinforce-Lite 相比指令模型取得了小幅提升:
- Meta Llama 3.2:提升 2.0%(70.5 -> 72.5)
- Phi3.5 Instruct:提升 0.6%(83.4 -> 84.0)
所有实验都在 FP16 环境下运行。
推理能力提升
Reinforce-Lite 调整后的模型展现出了多项改进:
- 逻辑思维能力增强
- 具备搜索和验证能力
- 能够创建表格进行粗略计算
- 展现试错和自我纠正能力
- 表现出系统性思维方式
这些能力在原始的指令模型中都未被观察到。
关键发现
技术突破
- 突破推理瓶颈 —— RL 微调显著提升了模型的结构化推理能力
- 简化架构 —— 证明单一策略网络足以完成 LLM 微调任务
- 计算效率 —— Reinforce-Lite 在保持性能的同时大幅降低了训练复杂度
- 模型能动性 —— 观察到模型主动尝试不同策略以获取更高奖励
技术限制
- 内存约束 —— 在 48GB GPU 上训练 3B 模型(FP16)时,超过 1024 token 容易发生 OOM
- 训练稳定性 —— 虽然去除了 KL 散度,但通过梯度裁剪仍能保持策略稳定
- 探索平衡 —— 模型在探索与利用之间的平衡还需要进一步优化
未来展望
Reinforce-Lite 的实现即将开源!我们期待:
- 社区能够进一步优化和改进算法
- 探索更多应用场景和任务类型
- 解决当前存在的技术限制
- 提高模型推理能力的上限
让我们一起探讨这个激动人心的方向。欢迎在评论区分享您的想法和建议!🚀
原文链接:Overnight End-to-End RL Training a 3B Model on a Grade School Math Dataset Leads to Reasoning