我们能否在仅使用约 48GB RTX6000 和 10 美元的有限计算资源下,为一个 3B 模型注入回溯、自我反思、逻辑推理等能力?这个问题引发了我们的探索。这篇文章将介绍我们使用轻量级强化学习算法 —— Reinforce-Lite 的突破性发现。通过接下来的示例,我们将看到经过端到端 RL 微调的模型如何展示出智能、回溯、自我反思和逻辑推理等迹象。
强化学习是一种屡次带来惊艳结果的强大算法,例如 Deepmind 的 AlphaGo、OpenAI 的 DOTA5、Mujocu 和 Atari 实验、用于对齐 LLM 的 RLHF,以及最近 DeepSeek 在 RL 上的全力投入。
然而,强化学习因其多个活动部件而存在若干复杂性,需要对重要元素进行精细设计,例如适当的信用分配、演员/评论家的超参数调整、RL 算法类型(基于模型或无模型)等,因此其在更广泛范围内的应用受到限制。
在 LLM 环境中使用强化学习可能涉及多达 5 个模型:
这些复杂的组件不仅带来了巨大的计算负担,还带来了训练稳定性的挑战。这促使我们思考:是否存在一种更简单、更高效的方法?
为了应对这些挑战,我们从头开始重新构想整个算法,回归基本原理,提出了一个更简单的替代方案:Reinforce-Lite。这个方案通过单一策略网络实现稳定训练,同时消除了对代理比率/旧策略模型的需求。
在传统强化学习环境(如 Mujoco、Atari、Dota 等)中,PPO 需要对每个批次进行多次更新,因为数据收集成本高昂,重用样本可以提高样本效率。然而,在 LLM 的场景下,这种方法反而成为了累赘:
通过每个批次只进行一次更新,结合组标准化等技术,我们可以实现稳定的训练,同时显著降低计算成本。这种简化的优化过程不仅高效,还消除了跟踪旧策略模型以计算代理比率的需要。
我们的算法做出了以下关键简化:
这些简化让我们得到了一个轻量级的强化学习算法,将优化问题简化为经典的 Reinforce 形式。在优势计算中,我们使用组相对策略优化的标准化技术,每组包含 10 个问题响应,利用标准化来减少梯度更新的方差。
让我们来看看 Reinforce-Lite 的核心代码实现:
def reinforce_lite(batch, policy_model, tokenizer, device, step, save_dir):
policy_model.train()
prompts, targets = zip(*batch)
batch_size = len(prompts)
evaluated_group = 0
all_logprobs = []
all_rewards = []
all_responses = []
all_lengths = []
for group_idx in range(config.GROUP_SIZE):
formatted_prompts = [format_prompt(p, tokenizer) for p in prompts]
inputs = tokenizer(
formatted_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=config.MAX_SEQ_LENGTH
).to(device)
generate_kwargs = {
**inputs,
"max_new_tokens": config.MAX_NEW_TOKENS,
"do_sample": True,
"temperature": 0.7,
"top_p": 0.9,
"pad_token_id": tokenizer.pad_token_id,
"return_dict_in_generate": True,
}
if group_idx == evaluated_group:
generated = policy_model.generate(**generate_kwargs)
generated_ids = generated.sequences
outputs = policy_model(
generated_ids,
attention_mask=(generated_ids != tokenizer.pad_token_id).long()
)
prompt_length = inputs.input_ids.shape[1]
response_length = generated_ids.shape[1] - prompt_length
if response_length > 0:
logits = outputs.logits[:, prompt_length-1:-1, :]
response_tokens = generated_ids[:, prompt_length:]
log_probs = torch.log_softmax(logits, dim=-1)
token_log_probs = torch.gather(log_probs, -1, response_tokens.unsqueeze(-1)).squeeze(-1)
sequence_log_probs = token_log_probs.sum(dim=1)
else:
sequence_log_probs = torch.zeros(batch_size, device=device)
else:
with torch.no_grad():
generated = policy_model.generate(**generate_kwargs)
sequence_log_probs = torch.zeros(batch_size, device=device)
responses = tokenizer.batch_decode(
generated.sequences[:, inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
rewards = torch.tensor([get_reward(resp, tgt) for resp, tgt in zip(responses, targets)], device=device)
all_responses.extend(responses)
all_rewards.append(rewards)
all_logprobs.append(sequence_log_probs)
all_lengths.extend([len(r.split()) for r in responses])
rewards_tensor = torch.stack(all_rewards)
logprobs_tensor = torch.stack(all_logprobs)
evaluated_rewards = rewards_tensor[evaluated_group]
others_rewards = torch.cat([
rewards_tensor[:evaluated_group],
rewards_tensor[evaluated_group+1:]
], dim=0)
baseline = others_rewards.mean(dim=0)
advantages = (evaluated_rewards - baseline) / (others_rewards.std(dim=0) + 1e-8)
advantages = torch.clamp(advantages, -2.0, 2.0)
policy_loss = -(logprobs_tensor[evaluated_group] * advantages.detach()).mean()
return policy_loss, rewards_tensor.mean().item(), policy_loss.item(), 0.0, all_responses[0], all_lengths
我们的训练过程包含以下步骤:
为了验证我们的方法,我们使用了 GSM8K 小学数学数据集。这个数据集的问题格式如下:
问题:Natalia 在四月向 48 个朋友卖了夹子,五月卖了四月一半数量的夹子。Natalia 在四月和五月总共卖了多少夹子? 答案:Natalia 在五月卖了 48/2 = <<48/2=24>>24 个夹子。Natalia 在四月和五月总共卖了 48+24 = <<48+24=72>>72 个夹子。#### 72
我们只关注 ### 之后的最终答案,要求模型在
处理输入数据的代码:
def format_prompt(question: str) -> str:
return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Solve this math problem: {question}
Show your reasoning first in <th>, then put the final answer in \\boxed{{}}."""
我们的奖励机制简单明了:
具体实现如下:
def get_reward(completion: str, target: str) -> float:
reward = -1.0
try:
completion = completion.strip()
start_tag = "<answer>"
end_tag = "</answer>"
start_idx = completion.rfind(start_tag)
if start_idx != -1:
substring_after_start = completion[start_idx + len(start_tag):]
end_idx = substring_after_start.find(end_tag)
if end_idx != -1:
answer = substring_after_start[:end_idx].strip()
if not answer and end_idx > 0:
answer = substring_after_start[:end_idx].strip()
numbers = ''.join(char for char in answer if char.isdigit() or char == '.')
if numbers:
generated_num = float(numbers)
target_num = float(str(target).strip())
if abs(generated_num - target_num) < 1e-6:
reward = 1.0
except Exception as e:
pass
return reward
我们在 RTX A6000 上使用 Reinforce-Lite 算法训练 3B 模型 12 小时:
在 GSM8K 上,Reinforce-Lite 相比指令模型取得了小幅提升:
所有实验都在 FP16 环境下运行。
Reinforce-Lite 调整后的模型展现出了多项改进:
这些能力在原始的指令模型中都未被观察到。
Reinforce-Lite 的实现即将开源!我们期待:
让我们一起探讨这个激动人心的方向。欢迎在评论区分享您的想法和建议!🚀
原文链接:Overnight End-to-End RL Training a 3B Model on a Grade School Math Dataset Leads to Reasoning