迭代式训练器
迭代式微调是一种训练方法,它能够在优化步骤之间执行自定义操作(例如生成和过滤)。在 TRL 中,我们提供了一个易于使用的 API,只需几行代码即可以迭代方式微调您的模型。
用法
要快速入门,请实例化一个模型实例和一个分词器。
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
trainer = IterativeSFTTrainer(
model,
tokenizer
)
您可以选择向 step 函数提供字符串列表或张量列表。
使用张量列表作为输入:
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
trainer.step(**inputs)
使用字符串列表作为输入:
inputs = {
"texts": texts
}
trainer.step(**inputs)
对于因果语言模型,标签将自动从 input_ids 或文本创建。当使用序列到序列模型时,您需要提供自己的标签或 text_labels。
IterativeTrainer
class trl.IterativeSFTTrainer
< source >( model: Optional = None args: Optional = None tokenizer: Optional = None optimizers: Tuple = (None, None) data_collator: Optional = None eval_dataset: Union = None max_length: Optional = None truncation_mode: Optional = 'keep_end' preprocess_logits_for_metrics: Optional = None compute_metrics: Optional = None optimize_device_cache: Optional = False )
参数
- model (
PreTrainedModel
) — 要优化的模型,可以是 ‘AutoModelForCausalLM’ 或 ‘AutoModelForSeq2SeqLM’。有关更多详细信息,请查看PreTrainedModel
的文档。 - args (
transformers.TrainingArguments
) — 用于训练的参数。 - tokenizer (
PreTrainedTokenizerBase
) — 用于对数据进行编码的令牌化器。有关更多详细信息,请查看transformers.PreTrainedTokenizer
和transformers.PreTrainedTokenizerFast
的文档。 - optimizers (
Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用于训练的优化器和调度器。 - data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], 可选) — 用于训练并传递给数据加载器的 数据整理器。
- eval_dataset (
datasets.Dataset
) — 用于评估的数据集。 - max_length (
int
, 默认值为None
) — 输入的最大长度。 - compute_metrics (
Callable[[EvalPrediction], Dict]
, optional) — 用于计算指标的函数。 必须接收一个EvalPrediction
并返回一个字典,其中包含指标名称到值的映射。 - optimize_device_cache (
bool
, optional, defaults toFalse
) — 优化 CUDA 缓存,以实现稍微更节省内存的训练。
IterativeSFTTrainer 可用于使用在优化之间需要一些步骤的方法微调模型。
step
< source > ( input_ids: Optional = None attention_mask: Optional = None labels: Optional = None texts: Optional = None texts_labels: Optional = None ) → dict[str, Any]
参数
- input_ids (List
torch.LongTensor
) — 包含 input_ids 的张量列表(如果未提供,将使用文本) - attention_mask (List
torch.LongTensor
, , optional) — 包含 attention_mask 的张量列表 - labels (List
torch.FloatTensor
, optional) — 包含标签的张量列表(如果设置为 None,将默认为 input_ids) - texts (List
str
, optional) — 包含文本输入的字符串列表(如果未提供,将直接使用 input_ids) - texts_labels (List
str
, optional) — 包含文本标签的字符串列表(如果设置为 None,将默认为文本)
返回
dict[str, Any]
训练统计数据的摘要
根据 input_ids、attention_mask 和 labels 列表,或根据文本和文本标签列表,运行优化步骤。