通过遵循方法论使用群组相对策略优化 (GRPO) 微调 SmolLM

群组相对策略优化 (GRPO) 是一种强化学习技术,旨在通过利用基于群组的奖励和策略优化来微调语言模型。它建立在近端策略优化 (PPO) 的概念之上,但通过考虑群组内生成输出的相对性能,引入了一种新颖的奖励计算和策略更新方法。
- 使用 GRPO 微调 SmolLM 模型涉及优化基于推理、准确性和格式等关键因素的奖励所推导的替代损失。微调过程遵循以下步骤:
- 安装所需包
- 加载和测试基础模型
- 定义辅助函数
- 定义奖励函数
- 设置GRPO配置和训练器
- 准备GSM8K数据集
- 配置训练器和模型
- 训练模型
- 使用微调模型进行推理
- 将模型上传到Hugging Face Hub
SmolLM2 135M Grpo 微调
资源 | 链接 |
---|---|
微调脚本1 | SmolLM_x_Grpo.ipynb |
微调脚本2 | SmolLM_x_Grpo.ipynb |
微调模型 | SmolLM2_135M_Grpo_Gsm8k |
微调检查点 | SmolLM2_135M_Grpo_Checkpoint |
方法一
步骤1:安装所需库
首先,安装必要的软件包。我们将从 Hugging Face 的 GitHub 仓库安装最新版本的 Transformers、Accelerate、Datasets 和 TRL 库。我们还安装 PEFT 用于参数高效微调。
!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git
!pip install -q datasets huggingface-hub trl
!pip install -q git+https://github.com/huggingface/peft.git
#!pip install flash-attn --no-build-isolation
步骤2:导入库并定义辅助函数
导入所有必需的库并定义辅助函数,用于解析和格式化模型输出。这些函数将帮助从模型的 XML 格式响应中提取答案。
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfiga # Note: This might be a typo. In your code, you later call LoraConfig.
from trl import GRPOConfig, GRPOTrainer
# System prompt that instructs the model to use a specific XML format.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# XML chain-of-thought format template.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# Function to extract the answer part from the XML response.
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
# Function to extract an answer if it is provided with a "####" delimiter.
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
步骤3:准备GSM8K数据集
我们使用 Hugging Face Hub 中的 GSM8K 数据集(一个小学数学问题集合)。在 `get_gsm8k_questions` 函数中,我们将每个示例转换为包含系统指令和用户问题的提示。(一例样本被注释掉,但如果需要可以启用。)
# Function to load and process the GSM8K dataset.
def get_gsm8k_questions(split="train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split] # Load the GSM8K dataset.
data = data.map(lambda x: { # Process each example.
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
# Uncomment the following lines to include a one-shot example.
# {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
# {'role': 'assistant', 'content': XML_COT_FORMAT.format(
# reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
# answer="7"
# )},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
# Load the processed dataset.
dataset = get_gsm8k_questions()
步骤4:定义奖励函数
定义了几个奖励函数来指导训练过程。这些函数评估模型输出的不同方面,例如正确性、格式和对 XML 格式的结构遵守。
# Reward function to check correctness: compares the extracted answer from the response with the known answer.
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# Reward function that checks if the response is a digit.
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# Reward function that checks if the response strictly follows the desired XML format.
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# Reward function with a softer check for the XML format.
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in responses]
# Function to count specific XML tokens and award a small reward for each.
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1]) * 0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
# Reward function that uses the XML token count.
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
步骤5:设置模型和分词器
我们从 Hugging Face Hub 中选择 SmolLM 模型(`HuggingFaceTB/SmolLM2-135M-Instruct`)。模型以 `bfloat16` 数据类型加载并移至 GPU。分词器也已加载,其填充标记设置为序列结束标记。
# Choose the model name.
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
# Alternatively, you can use:
# model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# Set output directories and run name based on the chosen model.
if "SmolLM2" in model_name:
output_dir = "outputs/SmolLM2-135M-GRPO"
run_name = "SmolLM2-135M-GRPO"
else:
output_dir = "outputs/Qwen-1.5B-GRPO"
run_name = "Qwen-1.5B-GRPO-gsm8k"
# Load the model.
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
#attn_implementation="flash_attention_2",
device_map=None
).to("cuda")
# Load the tokenizer and ensure that the pad token is set.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
步骤6:配置GRPO和PEFT
接下来,我们定义 GRPO 的训练配置以及使用 LoRA(低秩适应)的 PEFT(参数高效微调)配置。请注意,在下面的代码中,PEFT 配置已创建但未传递给训练器(它被注释掉了)。您可以通过取消注释相应的参数来启用它。
# GRPO training configuration.
training_args = GRPOConfig(
output_dir=output_dir,
run_name=run_name,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
logging_steps=1,
bf16=True,
per_device_train_batch_size=16, # Must be divisible by num_generations.
gradient_accumulation_steps=4,
num_generations=16, # Number of generations per prompt.
max_prompt_length=256,
max_completion_length=786,
num_train_epochs=1,
save_steps=100,
max_grad_norm=0.1,
report_to="none",
log_on_each_node=False,
)
# PEFT configuration using LoRA.
peft_config = LoraConfig(
r=16,
lora_alpha=64,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
注意:在导入部分我们使用了 `LoraConfiga`,这可能是一个拼写错误。请确保您从 `peft` 中正确导入并使用 `LoraConfig`。
步骤7:初始化GRPO训练器
现在我们使用模型、分词器(作为 `processing_class` 传入)、奖励函数、训练配置和数据集实例化 GRPOTrainer。奖励函数应用于每个生成,以提供细粒度反馈。
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func
],
args=training_args,
train_dataset=dataset,
# peft_config=peft_config # Uncomment this line to enable LoRA-based parameter-efficient fine-tuning.
)
步骤8:启动训练过程
最后,调用训练器上的 `train()` 方法,开始使用 GRPO 微调模型。
trainer.train()
方法二
1. 安装所需包
首先,安装最新版本的 `transformers` 和 `accelerate`(直接从 GitHub 获取),以及 `datasets` 和 `huggingface-hub`。
!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git
!pip install -q datasets huggingface-hub
2. 加载和测试基础模型
我们加载一个预训练的 SmolLM 模型并进行简单的推理来验证设置。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "HuggingFaceTB/SmolLM2-360M-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
messages = [
{"role": "system", "content": "Please respond in this specific format ONLY:\n<reasoning>\n input your reasoning behind your answer in between these reasoning tags.\n</reasoning>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"},
{"role": "user", "content": "How to add two numbers in Python?\n"},
]
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=256, temperature=0.2, top_p=0.9, do_sample=True, use_cache=False)
print(tokenizer.decode(outputs[0]))
3. 定义辅助函数
我们定义了用于处理提示和响应的辅助函数,包括提取用户查询、助手响应和 XML 格式的答案。这些函数将在训练和推理期间使用。
import re
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
# Reasoning Instruction
SYSTEM_PROMPT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking> </thinking> and
<answer> </answer> tags, respectively, i.e., <thinking> reasoning process here </thinking><answer> answer here </answer>.
Response Format rules:
- Always start your response with <thinking> tag and end with </answer>.
- Do not include any text or commentary before the opening <thinking> tag or after the closing </answer> tag.
- Do not include any text or commentary between the closing </thinking> tag and the opening <answer> tag.
For example, your response follow this format:
<thinking>
[Your detailed chain-of-thought goes here]
</thinking>
<answer>
[Your final answer goes here]
</answer>
"""
# Helpers
def get_user_prompt(prompt: str) -> str:
match = re.search(r"<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>", prompt, re.DOTALL)
if match:
return match.group(1).strip()
lines = prompt.splitlines()
result = []
for line in lines:
if not line.strip().lower().startswith("system"):
if line.strip().lower().startswith("user"):
result.append(line.strip()[4:].strip())
else:
result.append(line)
return "\n".join(result).strip()
def get_assistant_response(text: str) -> str:
match = re.search(r"<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>", text, re.DOTALL)
if match:
return match.group(1).strip()
lines = text.splitlines()
result = []
capture = False
for line in lines:
stripped = line.strip()
if stripped.lower().startswith("assistant"):
capture = True
continue
if capture:
result.append(line)
return "\n".join(result).strip()
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str:
if "####" not in text:
return text.strip()
return text.split("####", 1)[1].strip()
def count_xml(text: str) -> float:
count = 0.0
if text.count("<thinking>\n") == 1:
count += 0.225
if text.count("\n</thinking>\n") == 1:
count += 0.225
if text.count("\n<answer>\n") == 1:
count += 0.225
count -= len(text.split("\n</answer>")[-1]) * 0.001
if text.count("\n</answer>\n") == 1:
count += 0.225
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
def inference(prompt: str, model_path: str) -> str:
device = config.device
model_infer = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer_infer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer_infer(prompt, return_tensors="pt", max_length=config.max_prompt_length, truncation=False)
outputs = model_infer.generate(
inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device),
max_new_tokens=config.max_completion_length,
do_sample=True,
pad_token_id=tokenizer_infer.eos_token_id,
temperature=config.temperature,
num_return_sequences=1,
use_cache=False
)
full_text = tokenizer_infer.decode(outputs[0])
user_question = get_user_prompt(prompt)
assistant_response = get_assistant_response(full_text)
extracted_answer = extract_xml_answer(assistant_response)
return f"{'='*10} Inference {'='*10}\nQuestion:\n{user_question}\n\nModel Response:\n{assistant_response}\n\nExtracted:\n{extracted_answer}\n{'='*12} End {'='*12}\n"
4. 定义奖励函数
这些函数根据各种标准分配奖励:推理的长度和质量(在 `<thinking>` 标签内)、最终答案的准确性、格式遵从性、XML 标签计数以及答案是否为整数。这些奖励将指导 GRPO 更新。
# Rewards
def reasoning_reward(prompts, completions, answer, **kwargs) -> list:
rewards = []
transition_words = ["first", "next", "then", "because", "wait", "aha", "therefore", "finally", "in summary"]
pattern = r"<\s*thinking\s*>(.*?)<\s*/\s*thinking\s*>"
for comp in completions:
match = re.search(pattern, comp, re.DOTALL | re.IGNORECASE)
if match:
reasoning_text = match.group(1).strip()
words = reasoning_text.split()
reward = 0.0
# base reward if at least 25 words in between <thinking> </thinking> tags
if len(words) >= 25:
reward += 0.25
lower_text = reasoning_text.lower()
# transition words reward (case-insensitive)
transition_count = sum(1 for word in transition_words if word in lower_text)
if transition_count > 0:
reward += 0.5
# bonus reward if there are at least 30 words
if len(words) >= 50:
reward += 0.35
rewards.append(reward)
else:
rewards.append(0.0)
return rewards
def accuracy_reward(prompts, completions, answer, num_generated_samples_to_view=False, q_num=None, **kwargs) -> list:
q = prompts[0]
user_question = get_user_prompt(q)
assistant_responses = [get_assistant_response(r) for r in completions]
extracted_responses = [extract_xml_answer(get_assistant_response(r)) for r in completions]
if num_generated_samples_to_view:
print(f"{'='*15} Sample {q_num} {'='*15}\nQuestion:\n{user_question}\n\nAnswer:\n{answer[0]}\n\nResponse:\n{assistant_responses[0]}\n\nExtracted:\n{extracted_responses[0]}\n{'='*18} End {'='*18}\n")
return [2.0 if r.strip() == a.strip() else 0.0 for r, a in zip(extracted_responses, answer)]
def soft_format_reward(completions, **kwargs) -> list:
pattern = r"<thinking>.*?</thinking>\s*<answer>.*?</answer>"
return [0.5 if re.search(pattern, comp, re.DOTALL) else 0.0 for comp in completions]
def strict_format_reward(completions, **kwargs) -> list:
pattern = r"^<thinking>\n.*?\n</thinking>\n<answer>\n.*?\n</answer>\n$"
return [1.0 if re.fullmatch(pattern, comp) else 0.0 for comp in completions]
def xmlcount_reward(prompts, completions, answer, **kwargs) -> list:
return [count_xml(comp) * 0.5 for comp in completions]
def int_reward(completions, **kwargs) -> list:
return [0.5 if get_assistant_response(comp).strip().isdigit() else 0.0 for comp in completions]
5. 设置GRPO配置和训练器
现在我们定义一个配置类(`GRPOConfig`)来保存我们的训练参数,以及一个实现 GRPO 训练循环的 `GRPOTrainer` 类。训练器负责生成补全、计算奖励、计算带有 KL 惩罚的替代损失,并更新模型。
# GRPO Config
class GRPOConfig:
def __init__(self, **kwargs):
self.output_dir = kwargs.get("output_dir", "outputs")
self.run_name = kwargs.get("run_name", "custom_grpo")
self.learning_rate = kwargs.get("learning_rate", 1e-5)
self.weight_decay = kwargs.get("weight_decay", 0.01)
self.warmup_steps = kwargs.get("warmup_steps", 50)
self.num_generations = kwargs.get("num_generations", 1)
self.max_prompt_length = kwargs.get("max_prompt_length", 256)
self.max_completion_length = kwargs.get("max_completion_length", 256)
self.num_train_epochs = kwargs.get("num_train_epochs", 1)
self.gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps", 1)
self.clip_epsilon = kwargs.get("clip_epsilon", 0.2)
self.beta = kwargs.get("beta", 0.01)
self.logging_steps = kwargs.get("logging_steps", 1)
self.save_steps = kwargs.get("save_steps", 50)
self.max_steps = kwargs.get("max_steps", 1000)
self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
self.temperature = kwargs.get("temperature", 0.2)
self.num_generated_samples_to_view = kwargs.get("num_generated_samples_to_view", 10)
self.bf16 = kwargs.get("bf16", True)
self.per_device_train_batch_size = kwargs.get("per_device_train_batch_size", 4)
self.use_flash_attn_2 = kwargs.get("use_flash_attn_2", False)
self.use_vllm = kwargs.get("use_vllm", False)
self.vllm_device = kwargs.get("vllm_device", "cuda:0")
self.vllm_gpu_memory_utilization = kwargs.get("vllm_gpu_memory_utilization", 0.8)
self.vllm_dtype = kwargs.get("vllm_dtype", "float16")
self.vllm_max_model_len = kwargs.get("vllm_max_model_len", 512)
# GRPO Trainer
class GRPOTrainer:
def __init__(self, model, tokenizer, reward_funcs, config, train_dataset):
self.dataloader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size, shuffle=True, collate_fn=lambda x: x)
self.model = model.to(config.device)
self.tokenizer = tokenizer
self.reward_funcs = reward_funcs
self.config = config
self.train_dataset = train_dataset
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
total_steps = (len(train_dataset) // config.per_device_train_batch_size) * config.num_train_epochs
self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
num_warmup_steps=config.warmup_steps,
num_training_steps=total_steps)
self.ref_model = AutoModelForCausalLM.from_pretrained(model.config._name_or_path)
self.ref_model.to(config.device)
self.ref_model.eval()
for param in self.ref_model.parameters():
param.requires_grad = False
self.step = 0
self._metrics = defaultdict(list)
self.scaler = torch.cuda.amp.GradScaler() if config.device.startswith("cuda") else None
def get_per_token_logps(self, model, full_ids, attention_mask, num_logits_to_keep):
outputs = model(input_ids=full_ids, attention_mask=attention_mask)
logits = outputs.logits[:, :-1, :] # Exclude the last logit
logits_slice = logits[:, -num_logits_to_keep:, :]
token_ids = full_ids[:, -num_logits_to_keep:]
log_probs = torch.log_softmax(logits_slice, dim=-1)
token_log_probs = log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)
return token_log_probs
def compute_loss(self, input_ids, generation_output, advantages, old_logps, attention_mask):
num_logits_to_keep = generation_output.shape[1] - input_ids.shape[1]
full_ids = generation_output
# Compute current log probabilities from the updated model
per_token_logps = self.get_per_token_logps(self.model, full_ids, attention_mask, num_logits_to_keep)
with torch.no_grad():
ref_per_token_logps = self.get_per_token_logps(self.ref_model, full_ids, attention_mask, num_logits_to_keep)
# KL divergence per token (using Schulman et al.'s approximation)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# Compute mask for valid tokens via EOS detection
completion_ids = full_ids[:, input_ids.shape[1]:]
is_eos = (completion_ids == self.tokenizer.eos_token_id)
batch_size, seq_len = is_eos.size()
device = input_ids.device
eos_idx = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
for i in range(batch_size):
nonzero = torch.nonzero(is_eos[i], as_tuple=False)
if nonzero.numel() > 0:
eos_idx[i] = nonzero[0, 0]
sequence_indices = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
mask = (sequence_indices <= eos_idx.unsqueeze(1)).float()
# Calculate policy ratio using stored old log probabilities
ratio = torch.exp(per_token_logps - old_logps)
clipped_ratio = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon)
# Clipped surrogate objective
surrogate_loss = -torch.min(ratio * advantages.unsqueeze(1), clipped_ratio * advantages.unsqueeze(1))
# Add KL penalty term
per_token_loss = surrogate_loss + self.config.beta * per_token_kl
loss = ((per_token_loss * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)).mean()
mean_kl = (per_token_kl * mask).sum(dim=1).mean().item()
completion_length = mask.sum(dim=1).mean().item()
return loss, mean_kl, completion_length
def evaluate_rewards(self, prompt, completions, gt_answer):
rewards_dict = {}
for func in self.reward_funcs:
if func.__name__ in ["accuracy_reward", "xmlcount_reward", "reasoning_reward"]:
r = func([prompt] * len(completions), completions, [gt_answer] * len(completions))
else:
r = func(completions)
rewards_dict[func.__name__] = r
combined_rewards = [sum(rewards_dict[func_name][i] for func_name in rewards_dict)
for i in range(len(completions))]
return combined_rewards, rewards_dict
def train(self):
self.model.train()
accumulation_counter = 0
for epoch in range(self.config.num_train_epochs):
for batch in self.dataloader:
if self.step >= self.config.max_steps:
break
example = batch[0]
prompts = example["prompts"]
gt_answer = example["answer"]
prompt_text = self.tokenizer.apply_chat_template(prompts, tokenize=False)
inputs = self.tokenizer(prompt_text, return_tensors="pt", max_length=self.config.max_prompt_length, truncation=False)
input_ids = inputs.input_ids.to(self.config.device)
attention_mask = inputs.attention_mask.to(self.config.device)
with torch.autocast(
device_type=self.config.device,
enabled=(self.scaler is not None),
dtype=(torch.bfloat16 if self.config.bf16 else torch.float16)
):
generation_output = self.model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=self.config.max_completion_length,
do_sample=True,
temperature=self.config.temperature,
num_return_sequences=self.config.num_generations,
pad_token_id=self.tokenizer.eos_token_id,
use_cache=False
)
generation_output = generation_output.to(self.config.device)
completions = [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in generation_output]
completions = [c.replace(prompt_text, "").strip() if prompt_text in c else c for c in completions]
num_gens = len(completions)
view_flag = (self.step < self.config.num_generated_samples_to_view)
acc_rewards = accuracy_reward([prompt_text]*num_gens, completions, [gt_answer]*num_gens,
num_generated_samples_to_view=view_flag, q_num=self.step)
combined_rewards, reward_dict = self.evaluate_rewards(prompt_text, completions, gt_answer)
rewards_tensor = torch.tensor(combined_rewards, device=self.config.device, dtype=torch.float)
reward_avg = rewards_tensor.mean().item()
reward_std = rewards_tensor.std().item() if rewards_tensor.numel() > 1 else 0.0
reasoning_rewards = reward_dict.get("reasoning_reward", [0.0]*len(completions))
reasoning_reward_avg = sum(reasoning_rewards) / len(reasoning_rewards)
if self.config.num_generations > 1:
rewards_grouped = rewards_tensor.view(-1, self.config.num_generations)
mean_rewards = rewards_grouped.mean(dim=1)
std_rewards = rewards_grouped.std(dim=1) + 1e-4
advantages = (rewards_tensor - mean_rewards.repeat_interleave(self.config.num_generations)) / std_rewards.repeat_interleave(self.config.num_generations)
else:
advantages = rewards_tensor
advantages = torch.clamp(advantages, -5.0, 5.0)
num_logits_to_keep = generation_output.shape[1] - input_ids.shape[1]
old_logps = self.get_per_token_logps(self.model, generation_output, attention_mask, num_logits_to_keep).detach()
loss, mean_kl, completion_length = self.compute_loss(input_ids, generation_output, advantages, old_logps, attention_mask)
loss = loss / self.config.gradient_accumulation_steps
if self.scaler is not None:
self.scaler.scale(loss).backward()
else:
loss.backward()
accumulation_counter += 1
if accumulation_counter % self.config.gradient_accumulation_steps == 0:
if self.scaler is not None:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
accumulation_counter = 0
self._metrics["loss"].append(loss.item() * self.config.gradient_accumulation_steps)
self._metrics["completion_length"].append(completion_length)
self._metrics["reward"].append(reward_avg)
self._metrics["reward_std"].append(reward_std)
self._metrics["accuracy_reward"].append(sum(acc_rewards))
self._metrics["reasoning_reward"].append(reasoning_reward_avg)
self._metrics["kl"].append(mean_kl)
# Print without reasoning reward
print(f"Step {self.step} | Loss: {loss.item()*self.config.gradient_accumulation_steps:.4f} | Reward: {reward_avg:.4f} | Reward Std: {reward_std:.4f} | Completion Length: {completion_length:.4f} | KL: {mean_kl:.4f}\n")
self.step += 1
if self.step % self.config.save_steps == 0:
checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-{self.step}")
os.makedirs(checkpoint_path, exist_ok=True)
self.model.save_pretrained(checkpoint_path)
self.tokenizer.save_pretrained(checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}\n")
test_messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "Which is heavier 1k of steel or 1kg of wool?"}
]
test_prompt = self.tokenizer.apply_chat_template(test_messages, tokenize=False)
inf_result = inference(test_prompt, checkpoint_path)
print(inf_result)
if self.step >= self.config.max_steps:
break
if self.step >= self.config.max_steps:
break
final_model_path = os.path.join(self.config.output_dir, "final_model")
os.makedirs(final_model_path, exist_ok=True)
self.model.save_pretrained(final_model_path)
self.tokenizer.save_pretrained(final_model_path)
print(f"Final model saved to {final_model_path}")
plt.figure(figsize=(14, 10))
plt.subplot(3, 2, 1)
plt.plot(self._metrics["accuracy_reward"], label="Accuracy", color="blue")
plt.title("Accuracy vs Steps")
plt.xlabel("Steps")
plt.ylabel("Accuracy")
plt.legend()
plt.subplot(3, 2, 2)
plt.plot(self._metrics["reward"], label="Reward", color="green")
plt.title("Reward vs Steps")
plt.xlabel("Steps")
plt.ylabel("Reward")
plt.legend()
plt.subplot(3, 2, 3)
plt.plot(self._metrics["reward_std"], label="Reward Std", color="orange")
plt.title("Reward Std vs Steps")
plt.xlabel("Steps")
plt.ylabel("Reward Std")
plt.legend()
plt.subplot(3, 2, 4)
plt.plot(self._metrics["kl"], label="KL Penalty", color="red")
plt.title("KL Penalty vs Steps")
plt.xlabel("Steps")
plt.ylabel("KL Penalty")
plt.legend()
plt.subplot(3, 2, 5)
plt.plot(self._metrics["completion_length"], label="Avg Completion Length", color="purple")
plt.title("Avg Completion Length vs Steps")
plt.xlabel("Steps")
plt.ylabel("Completion Length")
plt.legend()
# plt.subplot(3, 2, 6)
# plt.plot(self._metrics["reasoning_reward"], label="Reasoning Reward", color="brown")
# plt.title("Reasoning Reward vs Steps")
# plt.xlabel("Steps")
# plt.ylabel("Reasoning Reward")
# plt.legend()
plt.tight_layout()
plt.show()
6. 准备GSM8K数据集
对于我们的训练数据,我们使用 GSM8K 数据集(数学问题集合)。我们重新格式化数据,使每个示例都包含一个提示(包含系统提示和用户问题)和相应的答案。
# GSM8K Dataset & Chat Temp
def get_gsm8k_data(split="train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split]
data = data.map(lambda x: {
'prompts': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
dataset = get_gsm8k_data()
7. 配置训练器和模型
设置 GRPO 配置并加载 SmolLM 模型(使用较小的 135M 变体)进行微调。另外,指定要使用的奖励函数。
# Trainer and Config
config = GRPOConfig(
output_dir="outputs/SmolLM2_135M_Grpo_Gsm8k",
run_name="smollm2_135m_grpo_gsm8k_reasoner",
learning_rate=5e-6,
weight_decay=0.01,
warmup_steps=100,
num_generations=2,
max_prompt_length=256,
max_completion_length=200,
num_train_epochs=1,
gradient_accumulation_steps=1,
clip_epsilon=0.2,
beta=0.01,
logging_steps=1,
save_steps=250,
max_steps=500,
temperature=0.2,
num_generated_samples_to_view=250,
bf16=True,
per_device_train_batch_size=1,
# use_flash_attn_2=True, # Enable Flash Attention 2 (GPU only)
# use_vllm=True, # use vLLM (GPU only)
# vllm_device="cuda:0", # vLLM device config (GPU only)
# vllm_gpu_memory_utilization=0.3 # vLLM GPU memory utilization (GPU only)
)
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2" if config.use_flash_attn_2 else None,
use_cache=False
).to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer.pad_token = tokenizer.eos_token
reward_functions = [reasoning_reward, accuracy_reward, soft_format_reward, strict_format_reward, int_reward, xmlcount_reward]
trainer = GRPOTrainer(model, tokenizer, reward_functions, config, dataset)
8. 训练模型
开始训练过程。训练器将生成多个补全,计算奖励,使用 GRPO 损失(包括 KL 惩罚)更新模型,并沿途保存检查点。
# Train
trainer.train()
9. 使用微调模型进行推理
训练完成后,在新问题上测试微调模型。
sample = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "If there are 12 cookies in a dozen and you have 5 dozen, how many cookies do you have?"}
]
final_prompt = tokenizer.apply_chat_template(sample, tokenize=False)
print(inference(final_prompt, os.path.join(config.output_dir, "final_model")))
10. 将模型上传到 Hugging Face Hub
最后,登录您的 Hugging Face 账户并将训练好的模型推送到仓库。
# Login Hf and Push to Repo
from huggingface_hub import notebook_login
notebook_login()
# Import the HfApi class from the huggingface_hub library.
from huggingface_hub import HfApi
api = HfApi()
repo_id = f"prithivMLmods/SmolLM2_135M_Grpo_Checkpoint"
try:
# Attempt to create a new repository on the Hugging Face Model Hub using the specified repo_id.
api.create_repo(repo_id)
print(f"Repo {repo_id} created")
except:
print(f"Repo {repo_id} already exists")
api.upload_folder(
folder_path="outputs/SmolLM2_135M_Grpo_Gsm8k/final_model", # The path to the folder to be uploaded
path_in_repo=".", # The path where the folder will be stored in the repository
repo_id=repo_id, # The ID of the repository where the folder will be uploaded
repo_type="model", # The type of the repository (in this case, a model repository)
revision="main" # Revision name
)
结论
GRPO 是一种强大的强化学习技术,用于微调语言模型。通过利用基于群组的相对奖励和 KL 散度正则化,它能够实现稳定高效的学习,同时鼓励生成高质量、结构化的输出。其灵活性使其适用于从推理到指令遵循的广泛任务。
通过遵循这些步骤并使用提供的代码,您可以使用 GRPO 方法自定义 SmolLM,以提高推理、格式遵从性和整体答案准确性。
微调愉快!🤗