Transformers 文档

超参数搜索

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

超参数搜索

超参数搜索旨在发现能够产生最佳模型性能的超参数集。Trainer 通过 hyperparameter_search() 支持多种超参数搜索后端,包括 OptunaSigOptWeights & BiasesRay Tune,以优化单个目标或多个目标。

本指南将介绍如何为每个后端设置超参数搜索。

[!WARNING][SigOpt](https://github.com/sigopt/sigopt-server) 已处于公共归档模式,不再积极维护。请尝试使用 Optuna、Weights & Biases 或 Ray Tune 代替。

pip install optuna/sigopt/wandb/ray[tune]

要使用 hyperparameter_search(),您需要创建一个 model_init 函数。此函数包含基本的模型信息(参数和配置),因为在每次搜索试验运行中都需要重新初始化。

model_init 函数与 optimizers 参数不兼容。请继承 Trainer 并重写 create_optimizer_and_scheduler() 方法来创建自定义优化器和调度器。

以下是一个 model_init 函数的示例。

def model_init(trial):
    return AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
    )

model_init 连同训练所需的其他所有内容传递给 Trainer。然后,您可以调用 hyperparameter_search() 来开始搜索。

hyperparameter_search() 接受一个 direction 参数,用于指定是最小化、最大化还是同时最小化和最大化多个目标。您还需要设置正在使用的 后端,一个包含要优化的超参数的 对象,要运行的 试验次数,以及一个用于返回目标值的 compute_objective

如果未定义 compute_objective,则会调用默认的 compute_objective,它是评估指标(如 F1)的总和。

from transformers import Trainer

trainer = Trainer(
    model=None,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
    processing_class=tokenizer,
    model_init=model_init,
    data_collator=data_collator,
)
trainer.hyperparameter_search(...)

以下示例演示了如何使用不同的后端对学习率和训练批次大小执行超参数搜索。

Optuna
Ray Tune
SigOpt
Weights & Biases

Optuna 优化类别、整数和浮点数。

def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
    }

best_trials = trainer.hyperparameter_search(
    direction=["minimize", "maximize"],
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=20,
    compute_objective=compute_objective,
)

分布式数据并行

Trainer 仅支持 Optuna 和 SigOpt 后端的分布式数据并行 (DDP) 超参数搜索。只有 rank-zero 进程用于生成搜索试验,结果参数会传递给其他 ranks。

< > 在 GitHub 上更新