TRL 文档
Supervised Fine-tuning Trainer
and get access to the augmented documentation experience
to get started
Supervised Fine-tuning Trainer
监督式微调(或简称 SFT)是 RLHF 中的关键步骤。在 TRL 中,我们提供了一个易于使用的 API,用于创建您的 SFT 模型,并用几行代码在您的数据集上对其进行训练。
请查看 trl/scripts/sft.py
中的完整灵活示例。视觉语言模型的实验性支持也包含在 examples/scripts/sft_vlm.py
示例中。
Quickstart
如果您的数据集托管在 🤗 Hub 上,您可以使用 TRL 的 SFTTrainer 轻松微调您的 SFT 模型。假设您的数据集是 imdb
,您要预测的文本在数据集的 text
字段中,并且您想要微调 facebook/opt-350m
模型。以下代码段负责为您处理所有数据预处理和训练
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
max_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
请确保为 max_length
传递正确的值,因为默认值将设置为 min(tokenizer.model_max_length, 1024)
。
您还可以在训练器外部构建模型并按如下方式传递它
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
training_args = SFTConfig(output_dir="/tmp")
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=training_args,
)
trainer.train()
以上代码段将使用 SFTConfig 类的默认训练参数。如果您想修改默认值,请将您的修改传递给 SFTConfig
构造函数,并通过 args
参数将其传递给训练器。
Advanced usage
仅在补全上训练
您可以使用 DataCollatorForCompletionOnlyLM
仅在生成的提示上训练您的模型。请注意,这仅在 packing=False
的情况下有效。要为指令数据实例化该整理器,请传递响应模板和分词器。以下是如何在 CodeAlpaca 数据集上仅在补全上微调 opt-350m
的示例
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['instruction'])):
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
output_texts.append(text)
return output_texts
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
要为助手风格的对话数据实例化该整理器,请传递响应模板、指令模板和分词器。以下是如何在 Open Assistant Guanaco 数据集上仅在助手补全上微调 opt-350m
的示例
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
trainer = SFTTrainer(
model,
args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
data_collator=collator,
)
trainer.train()
请确保 pad_token_id
与 eos_token_id
不同,这可能会导致模型在生成过程中无法正确预测 EOS(句子结尾)标记。
直接为 response_template 使用 token_ids
某些分词器(如 Llama 2 (meta-llama/Llama-2-XXb-hf
))根据序列是否具有上下文以不同的方式对序列进行分词。例如
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def print_tokens_with_ids(txt):
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
token_ids = tokenizer.encode(txt, add_special_tokens=False)
print(list(zip(tokens, token_ids)))
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
在这种情况下,并且由于 response_template
中缺少上下文,因此相同的字符串(“### Assistant:”)以不同的方式进行分词
- 文本(带上下文):
[2277, 29937, 4007, 22137, 29901]
response_template
(不带上下文):[835, 4007, 22137, 29901]
当 DataCollatorForCompletionOnlyLM
在数据集示例文本中找不到 response_template
时,这将导致错误
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
为了解决这个问题,您可以使用与数据集中相同的上下文对 response_template
进行分词,根据需要截断它,并将 token_ids
直接传递给 DataCollatorForCompletionOnlyLM
类的 response_template
参数。例如
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
为聊天格式添加特殊标记
向语言模型添加特殊标记对于训练聊天模型至关重要。这些标记添加到对话中不同角色(如用户、助手和系统)之间,并帮助模型识别对话的结构和流程。这种设置对于使模型能够在聊天环境中生成连贯且上下文相关的响应至关重要。trl 中的 setup_chat_format()
函数可以轻松设置用于对话式 AI 任务的模型和分词器。此函数
- 向分词器添加特殊标记,例如
<|im_start|>
和<|im_end|>
,以指示对话的开始和结束。 - 调整模型嵌入层的大小以适应新标记。
- 设置分词器的
chat_template
,该模板用于将输入数据格式化为类似聊天的格式。默认值是 OpenAI 的chatml
。 - 可选地,您可以传递
resize_to_multiple_of
以将嵌入层的大小调整为resize_to_multiple_of
参数的倍数,例如 64。如果您希望在将来看到支持更多格式,请在 trl 上打开 GitHub 问题
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
设置好模型和分词器后,我们现在可以在对话数据集上微调我们的模型。以下是如何格式化数据集以进行微调的示例。
数据集格式支持
SFTTrainer 支持流行的数据集格式。这允许您直接将数据集传递给训练器,而无需任何预处理。支持以下格式
- 对话格式
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
- 指令格式
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
如果您的数据集使用上述格式之一,您可以直接将其传递给训练器,而无需进行预处理。SFTTrainer 然后将使用模型的分词器中定义的格式和 apply_chat_template 方法为您格式化数据集。
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
...
# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
train_dataset=dataset,
)
如果数据集不是这些格式之一,您可以预处理数据集以匹配格式,或者将格式化函数传递给 SFTTrainer 为您完成此操作。让我们来看一下。
格式化您的输入提示
对于指令微调,在数据集中有两个列非常常见:一个用于提示,另一个用于响应。这允许人们像 Stanford-Alpaca 那样格式化示例,如下所示
Below is an instruction ...
### Instruction
{prompt}
### Response:
{completion}
假设您的数据集有两个字段 question
和 answer
。因此,您可以直接运行
...
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['question'])):
text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
output_texts.append(text)
return output_texts
trainer = SFTTrainer(
model,
args=training_args,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
为了正确格式化您的输入,请确保通过循环处理所有示例并返回处理后的文本列表。查看有关如何在 alpaca 数据集上使用 SFTTrainer 的完整示例 here
打包数据集 ( ConstantLengthDataset )
SFTTrainer 支持示例打包,其中多个短示例打包在同一输入序列中,以提高训练效率。这是通过 ConstantLengthDataset
实用程序类完成的,该类从示例流返回恒定长度的标记块。要启用此数据集类的使用,只需将 packing=True
传递给 SFTConfig 构造函数。
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args
)
trainer.train()
请注意,如果您使用打包数据集,并且如果您在训练参数中传递 max_steps
,您可能会训练模型超过几个 epoch,具体取决于您配置打包数据集和训练协议的方式。仔细检查您是否知道并理解您在做什么。如果您不想打包您的 eval_dataset
,您可以将 eval_packing=False
传递给 SFTConfig
init 方法。
使用打包数据集自定义您的提示
如果您的数据集有多个您想要组合的字段,例如,如果数据集有 question
和 answer
字段,并且您想要组合它们,您可以将格式化函数传递给训练器,该函数将处理该问题。例如
def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
formatting_func=formatting_func
)
trainer.train()
您还可以通过直接将参数传递给 SFTConfig 构造函数来更多地自定义 ConstantLengthDataset
。有关更多信息,请参阅该类的签名。
控制预训练模型
您可以将 from_pretrained()
方法的 kwargs 直接传递给 SFTConfig。例如,如果您想以不同的精度加载模型,类似于
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
...
training_args = SFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
},
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
请注意,支持 from_pretrained()
的所有关键字参数。
训练适配器
我们还支持与 🤗 PEFT 库的紧密集成,以便任何用户都可以方便地训练适配器并在 Hub 上共享它们,而不是训练整个模型。
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/Capybara", split="train")
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules="all-linear",
modules_to_save=["lm_head", "embed_token"],
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
"Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
peft_config=peft_config
)
trainer.train()
如果聊天模板包含特殊标记,如 <|im_start|>
(ChatML) 或 <|eot_id|>
(Llama),则嵌入层和 LM 头必须通过 modules_to_save
参数包含在可训练参数中。否则,微调后的模型将产生无界限或无意义的生成。如果聊天模板不包含特殊标记(例如 Alpaca),则可以忽略 modules_to_save
参数或将其设置为 None
。
您还可以继续训练您的 PeftModel
。为此,首先在 SFTTrainer
外部加载 PeftModel
,并将其直接传递给训练器,而无需传递 peft_config
参数。
使用基础 8 位模型训练适配器
为此,您需要首先在训练器外部加载您的 8 位模型,并将 PeftConfig
传递给训练器。例如
...
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLM.from_pretrained(
"EleutherAI/gpt-neo-125m",
load_in_8bit=True,
device_map="auto",
)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(),
peft_config=peft_config,
)
trainer.train()
使用 Flash Attention 和 Flash Attention 2
您可以从 Flash Attention 1 和 2 中受益,开箱即用地使用 SFTTrainer,只需进行最少的代码更改。首先,为了确保您拥有 transformers 的所有最新功能,请从源代码安装 transformers
pip install -U git+https://github.com/huggingface/transformers.git
请注意,Flash Attention 现在仅在 GPU 上和半精度方案下工作(当使用适配器时,基础模型以半精度加载)另请注意,这两个功能与其他工具(如量化)完全兼容。