Transformers 文档

微调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

微调

微调通过较小的专业数据集调整预训练模型以适应特定任务。与从头开始训练模型相比,这种方法所需的数据和计算量要少得多,因此对于许多用户来说,它是一种更易于访问的选项。

Transformers 提供了 Trainer API,它提供了一套全面的训练功能,用于微调 Hub 上的任何模型。

在我们的“资源”部分的“任务秘籍”中了解如何微调其他任务的模型!

本指南将向您展示如何使用 Trainer 微调模型以对 Yelp 评论进行分类。

使用您的用户令牌登录您的 Hugging Face 帐户,以确保您可以访问受控模型并在 Hub 上共享您的模型。

from huggingface_hub import login

login()

首先加载 Yelp Reviews 数据集并对其进行预处理(标记化、填充和截断)以进行训练。使用 map 一步预处理整个数据集。

from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

def tokenize(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

dataset = dataset.map(tokenize, batched=True)

在完整数据集的较小子集上进行微调,以减少所需时间。与在完整数据集上进行微调相比,结果不会那么好,但在提交完整数据集训练之前,先确保一切正常工作很有用。

small_train = dataset["train"].shuffle(seed=42).select(range(1000))
small_eval = dataset["test"].shuffle(seed=42).select(range(1000))

Trainer

Trainer 是一个经过优化的 Transformers 模型训练循环,无需手动编写训练代码即可轻松开始训练。在 TrainingArguments 中选择各种训练功能,例如梯度累积、混合精度以及报告和记录训练指标的选项。

加载模型并提供预期的标签数量(您可以在 Yelp Review 数据集卡 上找到此信息)。

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']"
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."

以上消息提醒,模型预训练的头部被丢弃,并替换为随机初始化的分类头部。随机初始化的头部需要针对您的特定任务进行微调,才能输出有意义的预测。

模型加载完成后,在 TrainingArguments 中设置训练超参数。超参数是控制训练过程的变量——例如学习率、批次大小、时期数——这反过来会影响模型性能。选择正确的超参数非常重要,您应该进行试验以找到适合您任务的最佳配置。

对于本指南,您可以使用默认超参数,它们提供了一个良好的基线。本指南中唯一要配置的设置是保存检查点的位置、如何在训练期间评估模型性能以及将模型推送到 Hub。

Trainer 需要一个函数来计算和报告您的指标。对于分类任务,您将使用 evaluate.loadEvaluate 库中加载 accuracy 函数。在 compute 中收集预测和标签以计算准确率。

import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # convert the logits to their predicted class
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

设置 TrainingArguments,指定保存模型的位置以及何时在训练期间计算准确率。下面的示例将其设置为 "epoch",这意味着在每个 epoch 结束时报告准确率。添加 push_to_hub=True 以在训练后将模型上传到 Hub。

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="yelp_review_classifier",
    eval_strategy="epoch",
    push_to_hub=True,
)

创建一个 Trainer 实例,并向其传递模型、训练参数、训练和测试数据集以及评估函数。调用 train() 开始训练。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)
trainer.train()

最后,使用 push_to_hub() 将您的模型和分词器上传到 Hub。

trainer.push_to_hub()

TensorFlow

Trainer 与 Transformers TensorFlow 模型不兼容。相反,由于这些模型是作为标准 tf.keras.Model 实现的,因此可以使用 Keras 对它们进行微调。

from transformers import TFAutoModelForSequenceClassification
from datasets import load_dataset
from transformers import AutoTokenizer

model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

def tokenize(examples):
    return tokenizer(examples["text"])

dataset = dataset.map(tokenize)

有两种方法可以将数据集转换为 tf.data.Dataset

  • prepare_tf_dataset() 是创建 tf.data.Dataset 的推荐方法,因为您可以检查模型以确定使用哪些列作为输入以及丢弃哪些列。这使您可以创建更简单、性能更高的数据集。
  • to_tf_datasetDatasets 库中更底层的函数,它通过指定要使用的列和标签列,让您更精细地控制数据集的创建方式。

将分词器添加到 prepare_tf_dataset() 以填充每个批次,您还可以选择随机打乱数据集。对于更复杂的预处理,可以将预处理函数传递给 collate_fn 参数。

tf_dataset = model.prepare_tf_dataset(
    dataset["train"], batch_size=16, shuffle=True, tokenizer=tokenizer
)

最后,编译拟合模型以开始训练。

没必要向 compile 传递损失参数,因为 Transformers 会自动选择适合任务和架构的损失。但是,如果您愿意,可以随时指定损失参数。

from tensorflow.keras.optimizers import Adam

model.compile(optimizer=Adam(3e-5))
model.fit(tf_dataset)

资源

有关各种任务的更详细训练脚本,请参阅 Transformers 示例。您还可以查看笔记本以获取交互式示例。

< > 在 GitHub 上更新