AI、强化学习、推理·

复现 DeepSeek「顿悟时刻」:10 美元打造 AI 推理突破 [译]

能否用 48GB RTX6000 和 10 美元,让 3B 模型具备回溯、反思、推理能力?强化学习屡创奇迹,本文揭示轻量级实现之道。

我们能否在仅使用约 48GB RTX6000 和 10 美元的有限计算资源下,为一个 3B 模型注入回溯、自我反思、逻辑推理等能力?这个问题引发了我们的探索。这篇文章将介绍我们使用轻量级强化学习算法 —— Reinforce-Lite 的突破性发现。通过接下来的示例,我们将看到经过端到端 RL 微调的模型如何展示出智能、回溯、自我反思和逻辑推理等迹象。

强化学习的魅力与挑战

强化学习是一种屡次带来惊艳结果的强大算法,例如 Deepmind 的 AlphaGo、OpenAI 的 DOTA5、Mujocu 和 Atari 实验、用于对齐 LLM 的 RLHF,以及最近 DeepSeek 在 RL 上的全力投入。

然而,强化学习因其多个活动部件而存在若干复杂性,需要对重要元素进行精细设计,例如适当的信用分配、演员/评论家的超参数调整、RL 算法类型(基于模型或无模型)等,因此其在更广泛范围内的应用受到限制。

在 LLM 环境中使用强化学习可能涉及多达 5 个模型:

  • 策略模型 —— 正在训练的模型
  • 旧策略模型 —— 用于计算代理
  • 参考模型 —— 用于计算 KL 散度
  • 奖励模型 —— 用于学习奖励函数
  • 评论家模型 —— 用于计算价值函数

这些复杂的组件不仅带来了巨大的计算负担,还带来了训练稳定性的挑战。这促使我们思考:是否存在一种更简单、更高效的方法?

Reinforce-Lite:一个更简单的解决方案

为了应对这些挑战,我们从头开始重新构想整个算法,回归基本原理,提出了一个更简单的替代方案:Reinforce-Lite。这个方案通过单一策略网络实现稳定训练,同时消除了对代理比率/旧策略模型的需求。

为什么传统 PPO/GRPO 的代理比率是一种过度设计?

在传统强化学习环境(如 Mujoco、Atari、Dota 等)中,PPO 需要对每个批次进行多次更新,因为数据收集成本高昂,重用样本可以提高样本效率。然而,在 LLM 的场景下,这种方法反而成为了累赘:

  • LLM 可以并行生成多样化的响应,自然创建丰富的数据集,无需重复更新
  • 所有响应都可以使用相同的策略网络生成,并在序列生成结束时一次性反向传播梯度
  • 在高维文本生成空间中,每个批次的多次更新可能导致过拟合,而非有意义的策略改进

通过每个批次只进行一次更新,结合组标准化等技术,我们可以实现稳定的训练,同时显著降低计算成本。这种简化的优化过程不仅高效,还消除了跟踪旧策略模型以计算代理比率的需要。

Reinforce-Lite 的核心设计

我们的算法做出了以下关键简化:

  • KL 散度移除,不再需要参考模型 ❌ —— 改用梯度裁剪替代,虽然不那么自适应但足够有效
  • 代理比率移除,不再需要旧策略模型 ❌ —— 使优化流程更加直接
  • 使用组相对奖励(借鉴 DeepSeek 的 GRPO 风格)计算优势,不再需要评论家模型 ❌

这些简化让我们得到了一个轻量级的强化学习算法,将优化问题简化为经典的 Reinforce 形式。在优势计算中,我们使用组相对策略优化的标准化技术,每组包含 10 个问题响应,利用标准化来减少梯度更新的方差。

PyTorch 实现

让我们来看看 Reinforce-Lite 的核心代码实现:

def reinforce_lite(batch, policy_model, tokenizer, device, step, save_dir):  
    policy_model.train()  
    prompts, targets = zip(*batch)  
    batch_size = len(prompts)  
    evaluated_group = 0  

    all_logprobs = []  
    all_rewards = []  
    all_responses = []  
    all_lengths = []  

    for group_idx in range(config.GROUP_SIZE):  
        formatted_prompts = [format_prompt(p, tokenizer) for p in prompts]  
        inputs = tokenizer(  
            formatted_prompts,  
            return_tensors="pt",  
            padding=True,  
            truncation=True,  
            max_length=config.MAX_SEQ_LENGTH  
        ).to(device)  

        generate_kwargs = {  
            **inputs,  
            "max_new_tokens": config.MAX_NEW_TOKENS,  
            "do_sample": True,  
            "temperature": 0.7,  
            "top_p": 0.9,  
            "pad_token_id": tokenizer.pad_token_id,  
            "return_dict_in_generate": True,  
        }  

        if group_idx == evaluated_group:  
            generated = policy_model.generate(**generate_kwargs)  
            generated_ids = generated.sequences  
            outputs = policy_model(  
                generated_ids,  
                attention_mask=(generated_ids != tokenizer.pad_token_id).long()  
            )  
            prompt_length = inputs.input_ids.shape[1]  
            response_length = generated_ids.shape[1] - prompt_length  
            if response_length > 0:  
                logits = outputs.logits[:, prompt_length-1:-1, :]  
                response_tokens = generated_ids[:, prompt_length:]  
                log_probs = torch.log_softmax(logits, dim=-1)  
                token_log_probs = torch.gather(log_probs, -1, response_tokens.unsqueeze(-1)).squeeze(-1)  
                sequence_log_probs = token_log_probs.sum(dim=1)  
            else:  
                sequence_log_probs = torch.zeros(batch_size, device=device)  
        else:  
            with torch.no_grad():  
                generated = policy_model.generate(**generate_kwargs)  
                sequence_log_probs = torch.zeros(batch_size, device=device)  

        responses = tokenizer.batch_decode(  
            generated.sequences[:, inputs.input_ids.shape[1]:],  
            skip_special_tokens=True  
        )  
        rewards = torch.tensor([get_reward(resp, tgt) for resp, tgt in zip(responses, targets)], device=device)  

        all_responses.extend(responses)  
        all_rewards.append(rewards)  
        all_logprobs.append(sequence_log_probs)  
        all_lengths.extend([len(r.split()) for r in responses])  

    rewards_tensor = torch.stack(all_rewards)  
    logprobs_tensor = torch.stack(all_logprobs)  

    evaluated_rewards = rewards_tensor[evaluated_group]  
    others_rewards = torch.cat([  
        rewards_tensor[:evaluated_group],   
        rewards_tensor[evaluated_group+1:]  
    ], dim=0)  

    baseline = others_rewards.mean(dim=0)  
    advantages = (evaluated_rewards - baseline) / (others_rewards.std(dim=0) + 1e-8)  
    advantages = torch.clamp(advantages, -2.0, 2.0)  

    policy_loss = -(logprobs_tensor[evaluated_group] * advantages.detach()).mean()  

    return policy_loss, rewards_tensor.mean().item(), policy_loss.item(), 0.0, all_responses[0], all_lengths

训练过程

我们的训练过程包含以下步骤:

  1. 初始化指令调整的 LLM,通过适当的提示使其在 标签中包含推理步骤
  2. 定义奖励函数,用于评估模型输出(例如 GSM8K 数学推理任务中的正确性)
  3. 直接优化策略,根据奖励计算梯度,无需代理损失
  4. 计算优势,使用组相对标准化消除对评论家模型的需求
  5. 使用标准的对数概率梯度技巧更新模型

GSM 8K 数据集实验

为了验证我们的方法,我们使用了 GSM8K 小学数学数据集。这个数据集的问题格式如下:

问题:Natalia 在四月向 48 个朋友卖了夹子,五月卖了四月一半数量的夹子。Natalia 在四月和五月总共卖了多少夹子? 答案:Natalia 在五月卖了 48/2 = <<48/2=24>>24 个夹子。Natalia 在四月和五月总共卖了 48+24 = <<48+24=72>>72 个夹子。#### 72

我们只关注 ### 之后的最终答案,要求模型在 标签中输出答案。这更像是一个蒙特卡洛问题,我们在回合结束时获得奖励。

数据处理和奖励机制

处理输入数据的代码:

def format_prompt(question: str) -> str:  
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>  
Solve this math problem: {question}  
Show your reasoning first in <th>, then put the final answer in \\boxed{{}}."""

我们的奖励机制简单明了:

  • 答案错误:奖励 -1
  • 答案正确:奖励 +1

具体实现如下:

def get_reward(completion: str, target: str) -> float:  
    reward = -1.0  
    try:  
        completion = completion.strip()  
        start_tag = "<answer>"  
        end_tag = "</answer>"  
        start_idx = completion.rfind(start_tag)  
        if start_idx != -1:  
            substring_after_start = completion[start_idx + len(start_tag):]  
            end_idx = substring_after_start.find(end_tag)  
            if end_idx != -1:  
                answer = substring_after_start[:end_idx].strip()  
                if not answer and end_idx > 0:  
                    answer = substring_after_start[:end_idx].strip()  
                numbers = ''.join(char for char in answer if char.isdigit() or char == '.')  
                if numbers:  
                    generated_num = float(numbers)  
                    target_num = float(str(target).strip())  
                    if abs(generated_num - target_num) < 1e-6:  
                        reward = 1.0  
    except Exception as e:  
        pass  
    return reward

实验结果

训练细节

我们在 RTX A6000 上使用 Reinforce-Lite 算法训练 3B 模型 12 小时:

  • 使用大小为 10 的组,适应计算限制
  • 在训练初期,模型不断尝试增加输出长度
  • 由于 48GB 显存限制,超过 1024 token 时常遇到 OOM
  • 训练数百次迭代后,观察到策略网络在探索不同策略时的值波动

性能对比

在 GSM8K 上,Reinforce-Lite 相比指令模型取得了小幅提升:

  • Meta Llama 3.2:提升 2.0%(70.5 -> 72.5)
  • Phi3.5 Instruct:提升 0.6%(83.4 -> 84.0)

所有实验都在 FP16 环境下运行。

推理能力提升

Reinforce-Lite 调整后的模型展现出了多项改进:

  • 逻辑思维能力增强
  • 具备搜索和验证能力
  • 能够创建表格进行粗略计算
  • 展现试错和自我纠正能力
  • 表现出系统性思维方式

这些能力在原始的指令模型中都未被观察到。

关键发现

技术突破

  • 突破推理瓶颈 —— RL 微调显著提升了模型的结构化推理能力
  • 简化架构 —— 证明单一策略网络足以完成 LLM 微调任务
  • 计算效率 —— Reinforce-Lite 在保持性能的同时大幅降低了训练复杂度
  • 模型能动性 —— 观察到模型主动尝试不同策略以获取更高奖励

技术限制

  • 内存约束 —— 在 48GB GPU 上训练 3B 模型(FP16)时,超过 1024 token 容易发生 OOM
  • 训练稳定性 —— 虽然去除了 KL 散度,但通过梯度裁剪仍能保持策略稳定
  • 探索平衡 —— 模型在探索与利用之间的平衡还需要进一步优化

未来展望

Reinforce-Lite 的实现即将开源!我们期待:

  • 社区能够进一步优化和改进算法
  • 探索更多应用场景和任务类型
  • 解决当前存在的技术限制
  • 提高模型推理能力的上限

让我们一起探讨这个激动人心的方向。欢迎在评论区分享您的想法和建议!🚀


原文链接:Overnight End-to-End RL Training a 3B Model on a Grade School Math Dataset Leads to Reasoning


© 2025 智人飞扬