TRL 文档

Supervised Fine-tuning Trainer

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Supervised Fine-tuning Trainer

监督式微调(或简称 SFT)是 RLHF 中的关键步骤。在 TRL 中,我们提供了一个易于使用的 API,用于创建您的 SFT 模型,并用几行代码在您的数据集上对其进行训练。

请查看 trl/scripts/sft.py 中的完整灵活示例。视觉语言模型的实验性支持也包含在 examples/scripts/sft_vlm.py 示例中。

Quickstart

如果您的数据集托管在 🤗 Hub 上,您可以使用 TRL 的 SFTTrainer 轻松微调您的 SFT 模型。假设您的数据集是 imdb,您要预测的文本在数据集的 text 字段中,并且您想要微调 facebook/opt-350m 模型。以下代码段负责为您处理所有数据预处理和训练

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")

training_args = SFTConfig(
    max_length=512,
    output_dir="/tmp",
)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
)
trainer.train()

请确保为 max_length 传递正确的值,因为默认值将设置为 min(tokenizer.model_max_length, 1024)

您还可以在训练器外部构建模型并按如下方式传递它

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

training_args = SFTConfig(output_dir="/tmp")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=training_args,
)

trainer.train()

以上代码段将使用 SFTConfig 类的默认训练参数。如果您想修改默认值,请将您的修改传递给 SFTConfig 构造函数,并通过 args 参数将其传递给训练器。

Advanced usage

仅在补全上训练

您可以使用 DataCollatorForCompletionOnlyLM 仅在生成的提示上训练您的模型。请注意,这仅在 packing=False 的情况下有效。要为指令数据实例化该整理器,请传递响应模板和分词器。以下是如何在 CodeAlpaca 数据集上仅在补全上微调 opt-350m 的示例

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="/tmp"),
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()

要为助手风格的对话数据实例化该整理器,请传递响应模板、指令模板和分词器。以下是如何在 Open Assistant Guanaco 数据集上仅在助手补全上微调 opt-350m 的示例

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)

trainer = SFTTrainer(
    model,
    args=SFTConfig(output_dir="/tmp"),
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()

请确保 pad_token_ideos_token_id 不同,这可能会导致模型在生成过程中无法正确预测 EOS(句子结尾)标记。

直接为 response_template 使用 token_ids

某些分词器(如 Llama 2 (meta-llama/Llama-2-XXb-hf))根据序列是否具有上下文以不同的方式对序列进行分词。例如

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def print_tokens_with_ids(txt):
    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
    token_ids = tokenizer.encode(txt, add_special_tokens=False)
    print(list(zip(tokens, token_ids)))

prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt)  # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]

response_template = "### Assistant:"
print_tokens_with_ids(response_template)  # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]

在这种情况下,并且由于 response_template 中缺少上下文,因此相同的字符串(“### Assistant:”)以不同的方式进行分词

  • 文本(带上下文):[2277, 29937, 4007, 22137, 29901]
  • response_template(不带上下文):[835, 4007, 22137, 29901]

DataCollatorForCompletionOnlyLM 在数据集示例文本中找不到 response_template 时,这将导致错误

RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([    1,   835,  ...])

为了解决这个问题,您可以使用与数据集中相同的上下文对 response_template 进行分词,根据需要截断它,并将 token_ids 直接传递给 DataCollatorForCompletionOnlyLM 类的 response_template 参数。例如

response_template_with_context = "\n### Assistant:"  # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:]  # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`

data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

为聊天格式添加特殊标记

向语言模型添加特殊标记对于训练聊天模型至关重要。这些标记添加到对话中不同角色(如用户、助手和系统)之间,并帮助模型识别对话的结构和流程。这种设置对于使模型能够在聊天环境中生成连贯且上下文相关的响应至关重要。trl 中的 setup_chat_format() 函数可以轻松设置用于对话式 AI 任务的模型和分词器。此函数

  • 向分词器添加特殊标记,例如 <|im_start|><|im_end|>,以指示对话的开始和结束。
  • 调整模型嵌入层的大小以适应新标记。
  • 设置分词器的 chat_template,该模板用于将输入数据格式化为类似聊天的格式。默认值是 OpenAI 的 chatml
  • 可选地,您可以传递 resize_to_multiple_of 以将嵌入层的大小调整为 resize_to_multiple_of 参数的倍数,例如 64。如果您希望在将来看到支持更多格式,请在 trl 上打开 GitHub 问题
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)

设置好模型和分词器后,我们现在可以在对话数据集上微调我们的模型。以下是如何格式化数据集以进行微调的示例。

数据集格式支持

SFTTrainer 支持流行的数据集格式。这允许您直接将数据集传递给训练器,而无需任何预处理。支持以下格式

  • 对话格式
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
  • 指令格式
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}

如果您的数据集使用上述格式之一,您可以直接将其传递给训练器,而无需进行预处理。SFTTrainer 然后将使用模型的分词器中定义的格式和 apply_chat_template 方法为您格式化数据集。

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

...

# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

...

training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    args=training_args,
    train_dataset=dataset,
)

如果数据集不是这些格式之一,您可以预处理数据集以匹配格式,或者将格式化函数传递给 SFTTrainer 为您完成此操作。让我们来看一下。

格式化您的输入提示

对于指令微调,在数据集中有两个列非常常见:一个用于提示,另一个用于响应。这允许人们像 Stanford-Alpaca 那样格式化示例,如下所示

Below is an instruction ...

### Instruction
{prompt}

### Response:
{completion}

假设您的数据集有两个字段 questionanswer。因此,您可以直接运行

...
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

trainer.train()

为了正确格式化您的输入,请确保通过循环处理所有示例并返回处理后的文本列表。查看有关如何在 alpaca 数据集上使用 SFTTrainer 的完整示例 here

打包数据集 ( ConstantLengthDataset )

SFTTrainer 支持示例打包,其中多个短示例打包在同一输入序列中,以提高训练效率。这是通过 ConstantLengthDataset 实用程序类完成的,该类从示例流返回恒定长度的标记块。要启用此数据集类的使用,只需将 packing=True 传递给 SFTConfig 构造函数。

...
training_args = SFTConfig(packing=True)

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args
)

trainer.train()

请注意,如果您使用打包数据集,并且如果您在训练参数中传递 max_steps,您可能会训练模型超过几个 epoch,具体取决于您配置打包数据集和训练协议的方式。仔细检查您是否知道并理解您在做什么。如果您不想打包您的 eval_dataset,您可以将 eval_packing=False 传递给 SFTConfig init 方法。

使用打包数据集自定义您的提示

如果您的数据集有多个您想要组合的字段,例如,如果数据集有 questionanswer 字段,并且您想要组合它们,您可以将格式化函数传递给训练器,该函数将处理该问题。例如

def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_func
)

trainer.train()

您还可以通过直接将参数传递给 SFTConfig 构造函数来更多地自定义 ConstantLengthDataset。有关更多信息,请参阅该类的签名。

控制预训练模型

您可以将 from_pretrained() 方法的 kwargs 直接传递给 SFTConfig。例如,如果您想以不同的精度加载模型,类似于

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)

...

training_args = SFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
    },
    output_dir="/tmp",
)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
)

trainer.train()

请注意,支持 from_pretrained() 的所有关键字参数。

训练适配器

我们还支持与 🤗 PEFT 库的紧密集成,以便任何用户都可以方便地训练适配器并在 Hub 上共享它们,而不是训练整个模型。

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

dataset = load_dataset("trl-lib/Capybara", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_token"],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
    peft_config=peft_config
)

trainer.train()

如果聊天模板包含特殊标记,如 <|im_start|> (ChatML) 或 <|eot_id|> (Llama),则嵌入层和 LM 头必须通过 modules_to_save 参数包含在可训练参数中。否则,微调后的模型将产生无界限或无意义的生成。如果聊天模板不包含特殊标记(例如 Alpaca),则可以忽略 modules_to_save 参数或将其设置为 None

您还可以继续训练您的 PeftModel。为此,首先在 SFTTrainer 外部加载 PeftModel,并将其直接传递给训练器,而无需传递 peft_config 参数。

使用基础 8 位模型训练适配器

为此,您需要首先在训练器外部加载您的 8 位模型,并将 PeftConfig 传递给训练器。例如

...

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(),
    peft_config=peft_config,
)

trainer.train()

使用 Flash Attention 和 Flash Attention 2

您可以从 Flash Attention 1 和 2 中受益,开箱即用地使用 SFTTrainer,只需进行最少的代码更改。首先,为了确保您拥有 transformers 的所有最新功能,请从源代码安装 transformers

pip install -U git+https://github.com/huggingface/transformers.git

请注意,Flash Attention 现在仅在 GPU 上和半精度方案下工作(当使用适配器时,基础模型以半精度加载)另请注意,这两个功能与其他工具(如量化)完全兼容。

BetterTransformer API 并强制调度 API 以使用 Flash Attention 内核。首先,安装最新的 optimum 包

pip install -U optimum

加载模型后,将 trainer.train() 调用包装在 with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): 上下文管理器下

...

+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    trainer.train()

请注意,你不能在任意数据集上使用 Flash Attention 1 训练模型,因为如果你使用 Flash Attention 内核,torch.scaled_dot_product_attention 不支持使用填充 tokens 进行训练。因此,你只能将此功能与 packing=True 一起使用。如果你的数据集包含填充 tokens,请考虑切换到 Flash Attention 2 集成。

以下是使用 Flash Attention 1 在单个 NVIDIA-T4 16GB 上获得的一些速度和内存效率数据。

use_flash_attn_1 model_name max_seq_len batch_size time per training step
x facebook/opt-350m 2048 8 ~59.1秒
facebook/opt-350m 2048 8 OOM (内存溢出)
x facebook/opt-350m 2048 4 ~30.3秒
facebook/opt-350m 2048 4 ~148.9秒

使用 Flash Attention-2

要使用 Flash Attention 2,首先安装最新的 flash-attn

pip install -U flash-attn

并在调用 from_pretrained 时添加 attn_implementation="flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2"
)

如果你不使用量化,请确保你的模型以半精度加载,并将你的模型分派到支持的 GPU 设备上。加载模型后,你可以按原样训练它,或者连接适配器并在其上训练适配器,以防你的模型被量化。

与 Flash Attention 1 相比,此集成使得可以在也包含填充 tokens 的任意数据集上训练模型。

使用模型创建实用程序

我们包含了一个实用程序函数来创建你的模型。

class trl.ModelConfig

< >

( model_name_or_path: typing.Optional[str] = None model_revision: str = 'main' torch_dtype: typing.Optional[str] = None trust_remote_code: bool = False attn_implementation: typing.Optional[str] = None use_peft: bool = False lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 lora_target_modules: typing.Optional[list[str]] = None lora_modules_to_save: typing.Optional[list[str]] = None lora_task_type: str = 'CAUSAL_LM' use_rslora: bool = False use_dora: bool = False load_in_8bit: bool = False load_in_4bit: bool = False bnb_4bit_quant_type: str = 'nf4' use_bnb_nested_quant: bool = False )

参数

  • model_name_or_path (strNone, 可选, 默认为 None) — 用于权重初始化的模型检查点。
  • model_revision (str, 可选, 默认为 "main") — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID。
  • torch_dtype (Literal["auto", "bfloat16", "float16", "float32"]None, 可选, 默认为 None) — 覆盖默认的 torch.dtype 并在该 dtype 下加载模型。可能的值为:

    • "bfloat16": torch.bfloat16
    • "float16": torch.float16
    • "float32": torch.float32
    • "auto": 从模型的权重自动推导 dtype。
  • trust_remote_code (bool, 可选, 默认为 False) — 是否允许在 Hub 上自定义模型及其自己的建模文件中定义。此选项仅应针对你信任且已阅读代码的存储库设置为 True,因为它将在你的本地计算机上执行 Hub 上存在的代码。
  • attn_implementation (strNone, 可选, 默认为 None) — 要使用的注意力实现。你可以运行 --attn_implementation=flash_attention_2,在这种情况下,你必须手动安装它,通过运行 pip install flash-attn --no-build-isolation
  • use_peft (bool, 可选, 默认为 False) — 是否使用 PEFT 进行训练。
  • lora_r (int, 可选, 默认为 16) — LoRA R 值。
  • lora_alpha (int, 可选, 默认为 32) — LoRA alpha 值。
  • lora_dropout (float, 可选, 默认为 0.05) — LoRA dropout 率。
  • lora_target_modules (Union[str, list[str]]None, 可选, 默认为 None) — LoRA 目标模块。
  • lora_modules_to_save (list[str]None, 可选, 默认为 None) — 要解冻和训练的模型层。
  • lora_task_type (str, 可选, 默认为 "CAUSAL_LM") — 要为 LoRA 传递的任务类型(奖励建模使用 "SEQ_CLS")。
  • use_rslora (bool, 可选, 默认为 False) — 是否使用 Rank-Stabilized LoRA,它将适配器缩放因子设置为 lora_alpha/√r,而不是原始默认值 lora_alpha/r
  • use_dora (bool, 可选, 默认为 False) — 启用 Weight-Decomposed Low-Rank Adaptation (DoRA) (权重分解低秩适配)。此技术将权重的更新分解为两部分:幅值和方向。方向由普通的 LoRA 处理,而幅值由单独的可学习参数处理。这可以提高 LoRA 的性能,尤其是在低秩时。目前,DoRA 仅支持线性层和 Conv2D 层。DoRA 引入的比纯 LoRA 更大的开销,因此建议合并权重以进行推理。
  • load_in_8bit (bool, 可选, 默认为 False) — 是否对基础模型使用 8 位精度。仅适用于 LoRA。
  • load_in_4bit (bool, 可选, 默认为 False) — 是否对基础模型使用 4 位精度。仅适用于 LoRA。
  • bnb_4bit_quant_type (str, 可选, 默认为 "nf4") — 量化类型 ("fp4""nf4")。
  • use_bnb_nested_quant (bool, 可选, 默认为 False) — 是否使用嵌套量化。

模型的配置类。

使用 HfArgumentParser,我们可以将此类转换为可以在命令行上指定的 argparse 参数。

from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
model_args = ModelConfig(
    model_name_or_path="facebook/opt-350m"
    attn_implementation=None, # or "flash_attention_2"
)
torch_dtype = (
    model_args.torch_dtype
    if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
trainer = SFTTrainer(
    ...,
    model=model_args.model_name_or_path,
    peft_config=get_peft_config(model_args),
)

使用 NEFTune 增强模型性能

NEFTune 是一种提升聊天模型性能的技术,由 Jain 等人在论文 “NEFTune: Noisy Embeddings Improve Instruction Finetuning” 中介绍。它包括在训练期间向嵌入向量添加噪声。根据该论文的摘要:

使用 Alpaca 对 LLaMA-2-7B 进行标准微调,在 AlpacaEval 上达到 29.79% 的准确率,而使用 noisy embeddings(噪声嵌入)则升至 64.69%。NEFTune 还在现代指令数据集上优于强大的基线模型。使用 Evol-Instruct 训练的模型性能提升 10%,使用 ShareGPT 提升 8%,使用 OpenPlatypus 提升 8%。即使是使用 RLHF 进一步改进的强大模型(如 LLaMA-2-Chat)也受益于使用 NEFTune 进行的额外训练。

要在 SFTTrainer 中使用它,只需在创建 SFTConfig 实例时传递 neftune_noise_alpha。请注意,为避免任何意外行为,NEFTune 在训练后被禁用,以恢复嵌入层的原始行为。

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")

training_args = SFTConfig(
    neftune_noise_alpha=5,
)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
)
trainer.train()

我们通过在 OpenAssistant 数据集上训练 mistralai/Mistral-7B-v0.1 测试了 NEFTune,并验证了使用 NEFTune 可以使 MT Bench 的性能提升约 25%。

但请注意,性能提升的幅度取决于数据集,特别是,在 UltraChat 等合成数据集上应用 NEFTune 通常产生的收益较小。

使用 unsloth 加速微调 2 倍

你可以使用 unsloth 库进一步加速 QLoRA / LoRA(速度快 2 倍,内存减少 60%),该库与 SFTTrainer 完全兼容。目前,unsloth 仅支持 Llama(Yi、TinyLlama、Qwen、Deepseek 等)和 Mistral 架构。下面列出了一些在 1x A100 上的基准测试:

1 A100 40GB 数据集 🤗 🤗 + Flash Attention 2 🦥 Unsloth 🦥 VRAM 节省
Code Llama 34b Slim Orca 1x 1.01x 1.94x -22.7%
Llama-2 7b Slim Orca 1x 0.96x 1.87x -39.3%
Mistral 7b Slim Orca 1x 1.17x 1.88x -65.9%
Tiny Llama 1.1b Alpaca 1x 1.55x 2.74x -57.8%

首先根据官方文档安装 unsloth。安装完成后,您可以非常简单地将 unsloth 融入您的工作流程中;只需加载 FastLanguageModel,而不是加载 AutoModelForCausalLM,如下所示

import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

training_args = SFTConfig(output_dir="./output", max_length=max_length)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

保存的模型与 Hugging Face 的 transformers 库完全兼容。请在其官方仓库中了解更多关于 unsloth 的信息。

Liger-Kernel:将多 GPU 训练吞吐量提高 20%,内存减少 60%

Liger Kernel 是一系列专为 LLM 训练设计的 Triton 内核的集合。它可以有效地将多 GPU 训练吞吐量提高 20%,并将内存使用量减少 60%。这样,我们可以将上下文长度增加 4 倍,如下面的基准测试所述。他们已经实现了 Hugging Face 兼容的 RMSNormRoPESwiGLUCrossEntropyFusedLinearCrossEntropy 等,并且未来还会添加更多。该内核可以与 Flash AttentionPyTorch FSDPMicrosoft DeepSpeed 开箱即用。

通过大幅减少内存使用,您可以潜在地关闭 cpu_offloading 或梯度检查点,以进一步提升性能。

加速 内存减少
Speed up Memory
  1. 要在 SFTTrainer 中使用 Liger-Kernel,请先通过以下方式安装
pip install liger-kernel
  1. 安装完成后,在 SFTConfig 中设置 use_liger_kernel。无需其他更改!
training_args = SFTConfig(
  use_liger_kernel=True
)

要了解更多关于 Liger-Kernel 的信息,请访问其官方仓库

最佳实践

当使用该 trainer 训练模型时,请注意以下最佳实践

  • SFTTrainer 默认情况下始终将序列截断为 SFTConfigmax_length 参数。如果没有传递任何值,trainer 将从 tokenizer 中检索该值。某些 tokenizer 不提供默认值,因此会进行检查以检索 1024 和该值之间的最小值。请务必在训练前检查它。
  • 对于以 8 位训练 adapters,您可能需要调整 PEFT 中 prepare_model_for_kbit_training 方法的参数,因此我们建议用户使用 prepare_in_int8_kwargs 字段,或者在 SFTTrainer 外部创建 PeftModel 并传递它。
  • 为了使用 adapters 进行更节省内存的训练,您可以加载 8 位的基础模型,为此,只需在创建 SFTTrainer 时添加 load_in_8bit 参数,或者在 trainer 外部创建 8 位的基本模型并传递它。
  • 如果您在 trainer 外部创建模型,请确保不要向 trainer 传递任何与 from_pretrained() 方法相关的额外关键字参数。

多 GPU 训练

Trainer(以及 SFTTrainer)支持多 GPU 训练。如果您使用 python script.py 运行脚本,它将默认使用 DP 作为策略,这可能比预期慢。要使用 DDP(通常建议使用,有关更多信息,请参阅此处),您必须使用 python -m torch.distributed.launch script.pyaccelerate launch script.py 启动脚本。为了使 DDP 工作,您还必须检查以下内容

  • 如果您正在使用 gradient_checkpointing,请将以下内容添加到 TrainingArguments 中:gradient_checkpointing_kwargs={'use_reentrant':False}(更多信息请参阅此处
  • 确保模型放置在正确的设备上
from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
     ...
    device_map={'':device_string}
)

GPTQ 转换

在完成训练后,您可能会遇到一些 GPTQ 量化问题。将 gradient_accumulation_steps 降低到 4 将解决量化为 GPTQ 格式过程中的大多数问题。

为视觉语言模型扩展 SFTTrainer

SFTTrainer 本身不支持视觉语言数据。但是,我们提供了一个指南,说明如何调整 trainer 以支持视觉语言数据。具体来说,您需要使用与视觉语言数据兼容的自定义数据 collator。本指南概述了进行这些调整的步骤。有关具体示例,请参阅脚本 examples/scripts/sft_vlm.py,该脚本演示了如何在 HuggingFaceH4/llava-instruct-mix-vsft 数据集上微调 LLaVA 1.5 模型。

准备数据

数据格式是灵活的,只要它与我们稍后定义的自定义 collator 兼容即可。一种常见的方法是使用对话数据。鉴于数据包括文本和图像,因此需要相应地调整格式。以下是涉及文本和图像的对话数据格式示例

images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."}
        ]
    }
]

为了说明如何使用 LLaVA 模型处理此数据格式,您可以使用以下代码

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))

输出格式如下

Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States. 

用于处理多模态数据的自定义 collator

SFTTrainer 的默认行为不同,多模态数据的处理是在数据 collation 过程中即时完成的。为此,您需要定义一个自定义 collator,该 collator 处理文本和图像。此 collator 必须将示例列表作为输入(有关数据格式的示例,请参见上一节),并返回一批已处理的数据。以下是此类 collator 的示例

def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(texts, images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch

我们可以通过运行以下代码来验证 collator 是否按预期工作

from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]]  # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys())  # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])

训练视觉语言模型

现在我们已经准备好了数据并定义了 collator,我们可以继续训练模型。为了确保数据不被视为纯文本处理,我们需要在 SFTConfig 中设置几个参数,特别是将 remove_unused_columnsskip_prepare_dataset 设置为 True,以避免对数据集进行默认处理。以下是如何设置 SFTTrainer 的示例。

training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    processing_class=processor.tokenizer,
)

HuggingFaceH4/llava-instruct-mix-vsft 数据集上训练 LLaVa 1.5 的完整示例可以在脚本 examples/scripts/sft_vlm.py 中找到。

SFTTrainer

class trl.SFTTrainer

< >

( model: typing.Union[str, torch.nn.modules.module.Module, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.sft_config.SFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_loss_func: typing.Optional[typing.Callable] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) optimizer_cls_and_kwargs: typing.Optional[tuple[typing.Type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]]] = None preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Union[typing.Callable[[dict], str], typing.Callable[[dict], list[str]], NoneType] = None )

参数

  • model (Union[str, PreTrainedModel]) — 要训练的模型。可以是:

    • 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的模型 ID,或者是包含使用 save_pretrained 保存的模型权重的目录的路径,例如 './my_model_directory/'。该模型使用 from_pretrainedargs.model_init_kwargs 中的关键字参数加载。
    • 一个 PreTrainedModel 对象。仅支持因果语言模型。
  • args (SFTConfig可选,默认为 None)— 此 trainer 的配置。如果为 None,则使用默认配置。
  • data_collator (DataCollator可选) — 用于从已处理的 train_dataseteval_dataset 的元素列表中形成批次的函数。如果没有提供 processing_class,则默认为 default_data_collator;如果 processing_class 是特征提取器或 tokenizer,则默认为 DataCollatorWithPadding 的实例。
  • train_dataset (DatasetIterableDataset) — 用于训练的数据集。SFT 支持语言建模类型和提示完成类型。样本的格式可以是:

    • 标准:每个样本都包含纯文本。
    • 对话式:每个样本都包含结构化消息(例如,角色和内容)。

    trainer 也支持已处理的数据集(已分词),只要它们包含 input_ids 字段。

  • eval_dataset (DatasetIterableDatasetdict[str, Union[Dataset, IterableDataset]]) — 用于评估的数据集。它必须满足与 train_dataset 相同的要求。
  • processing_class (PreTrainedTokenizerBase可选,默认为 None)— 用于处理数据的处理类。如果为 None,则处理类从模型的名称加载,使用 from_pretrained
  • callbacks (TrainerCallback 列表,可选,默认为 None)— 用于自定义训练循环的回调列表。将这些添加到此处详细介绍的默认回调列表中。

    如果您想删除使用的默认回调之一,请使用 remove_callback 方法。

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]可选,默认为 (None, None))— 包含要使用的优化器和调度器的元组。将默认为模型上的 AdamW 实例和由 args 控制的 get_linear_schedule_with_warmup 给出的调度器。
  • optimizer_cls_and_kwargs (Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]可选,默认为 None)— 包含要使用的优化器类和关键字参数的元组。覆盖 args 中的 optimoptim_args。与 optimizers 参数不兼容。

    optimizers 不同,此参数避免了在初始化 Trainer 之前将模型参数放置在正确设备上的需要。

  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 可选, 默认为 None) — 在每个评估步骤缓存 logits 之前预处理 logits 的函数。必须接受两个 tensors,即 logits 和 labels,并返回处理后的 logits。此函数所做的修改将反映在 compute_metrics 收到的预测中。

    请注意,如果数据集没有 labels,则 labels(第二个参数)将为 None

  • peft_config (~peft.PeftConfig, 可选, 默认为 None) — 用于包装模型的 PEFT 配置。如果为 None,则模型不会被包装。
  • formatting_func (Optional[Callable]) — 应用于数据集的格式化函数,在分词之前使用。

用于监督式微调 (SFT) 方法的 Trainer。

此类是 transformers.Trainer 类的包装器,并继承了它的所有属性和方法。

示例

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
trainer.train()

compute_loss

< >

( model inputs return_outputs = False num_items_in_batch = None )

计算训练损失,并额外计算 token 准确率

create_model_card

< >

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

参数

  • model_name (strNone, 可选, 默认为 None) — 模型名称。
  • dataset_name (strNone, 可选, 默认为 None) — 用于训练的数据集名称。
  • tags (str, list[str]None, 可选, 默认为 None) — 要与模型卡关联的标签。

使用 Trainer 可用的信息创建模型卡的草稿。

SFTConfig

class trl.SFTConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_text_field: str = 'text' dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None pad_token: typing.Optional[str] = None max_length: typing.Optional[int] = 1024 packing: bool = False padding_free: bool = False eval_packing: typing.Optional[bool] = None dataset_batch_size: typing.Optional[int] = None num_of_sequences: typing.Optional[int] = None chars_per_token: typing.Optional[float] = None max_seq_length: typing.Optional[int] = None use_liger: typing.Optional[bool] = None )

控制模型的参数

  • model_init_kwargs (dict[str, Any]None, 可选, 默认为 None) — 用于 from_pretrained 的关键字参数,当 SFTTrainermodel 参数以字符串形式提供时使用。

控制数据预处理的参数

  • dataset_text_field (str, 可选, 默认为 "text") — 数据集中包含文本数据的列名。
  • dataset_kwargs (dict[str, Any]None, 可选, 默认为 None) — 用于数据集准备的可选关键字参数字典。唯一支持的键是 skip_prepare_dataset
  • dataset_num_proc (intNone, 可选, 默认为 None) — 用于处理数据集的进程数。
  • pad_token (strNone, 可选, 默认为 None) — 用于填充的 Token。 如果为 None,则默认为 processing_class.pad_token,如果也为 None,则回退到 processing_class.eos_token
  • max_length (intNone, 可选, 默认为 1024) — 分词序列的最大长度。 超过 max_length 的序列将从右侧截断。 如果为 None,则不应用截断。 启用 packing 时,此值设置序列长度。
  • packing (bool, 可选, 默认为 False) — 是否将多个序列打包成固定长度格式。 使用 max_length 定义序列长度。
  • padding_free (bool, 可选, 默认为 False) — 是否执行无填充的前向传播,方法是将批次中的所有序列展平为单个连续序列。这通过消除填充开销来减少内存使用。目前,这仅在使用 flash_attention_2 注意力实现时受支持,该实现可以有效地处理展平的批次结构。
  • eval_packing (boolNone, 可选, 默认为 None) — 是否打包评估数据集。如果为 None, 则使用与 packing 相同的值。

控制训练的参数

  • learning_rate (float, 可选, 默认为 2e-5) — AdamW 优化器的初始学习率。默认值替换了 TrainingArguments 的默认值。

用于 SFTTrainer 的配置类。

此处仅列出了特定于 SFT 训练的参数。有关其他参数的详细信息,请参阅 TrainingArguments 文档。

使用 HfArgumentParser,我们可以将此类转换为可以在命令行上指定的 argparse 参数。

数据集

在 SFTTrainer 中,除了其他样式的数据集外,我们还巧妙地支持 datasets.IterableDataset。如果您正在使用不想全部保存到磁盘的大型语料库,这将非常有用。即使启用打包,数据也将在运行时被标记化和处理。

此外,在 SFTTrainer 中,如果预标记化的数据集是 datasets.Datasetdatasets.IterableDataset,我们也支持它们。换句话说,如果这样的数据集有一列 input_ids,则不会进行进一步的处理(标记化或打包),并且数据集将按原样使用。如果您已在此脚本之外预标记化了数据集并希望直接重用它,这将非常有用。

ConstantLengthDataset

class trl.trainer.ConstantLengthDataset

< >

( tokenizer dataset dataset_text_field = None formatting_func = None infinite = False seq_length = 1024 num_of_sequences = 1024 chars_per_token = 3.6 eos_token_id = 0 shuffle = True append_concat_token = True add_special_tokens = True )

参数

  • tokenizer (transformers.PreTrainedTokenizer) — 用于处理数据的处理器。
  • dataset (dataset.Dataset) — 包含文本文件的数据集。
  • dataset_text_field (strNone, 可选, 默认为 None) — 数据集中包含文本的字段名称。dataset_text_fieldformatting_func 只能提供一个。
  • formatting_func (Callable, 可选) — 在标记化之前格式化文本的函数。通常建议遵循某种模式,例如 "### Question: {question} ### Answer: {answer}"dataset_text_fieldformatting_func 只能提供一个。
  • infinite (bool, 可选, 默认为 False) — 如果为 True,则在数据集到达末尾后重置迭代器,否则停止。
  • seq_length (int, 可选, 默认为 1024) — 要返回的 token 序列的长度。
  • num_of_sequences (int, 可选, 默认为 1024) — 要在缓冲区中保留的 token 序列的数量。
  • chars_per_token (int, 可选, 默认为 3.6) — 每个 token 的字符数,用于估计文本缓冲区中的 token 数量。
  • eos_token_id (int, 可选, 默认为 0) — 如果传递的 tokenizer 没有 EOS token,则为序列结束 token 的 ID。
  • shuffle (bool, 可选, 默认为 True) — 在返回示例之前对其进行洗牌
  • append_concat_token (bool, 可选, 默认为 True) — 如果为 true,则在每个正在打包的样本末尾追加 eos_token_id
  • add_special_tokens (bool, 可选, 默认为 True) — 如果为 true,tokenizer 会将特殊 token 添加到每个正在打包的样本中。

可迭代数据集,从文本文件流中返回恒定长度的 token 块。该数据集还在标记化之前使用用户提供的特定格式格式化文本。

< > 在 GitHub 上更新