Transformers 文档

Keras 回调函数

Hugging Face's logo
加入Hugging Face社区

并获得增强的文档体验

开始使用

Keras 回调函数

使用 Keras 训练 Transformers 模型时,有一些库特定的回调函数可用于自动化常见任务。

KerasMetricCallback

transformers.KerasMetricCallback

< >

( metric_fn: Callable eval_dataset: Union output_cols: Optional = None label_cols: Optional = None batch_size: Optional = None predict_with_generate: bool = False use_xla_generation: bool = False generate_kwargs: Optional = None )

参数

PushToHubCallback

transformers.PushToHubCallback

< >

( output_dir: Union save_strategy: Union = 'epoch' save_steps: Optional = None tokenizer: Optional = None hub_model_id: Optional = None hub_token: Optional = None checkpoint: bool = False **model_card_args )

参数

  • output_dir (str) — 模型预测和检查点将被写入并与 Hub 上的存储库同步的输出目录。
  • save_strategy (strIntervalStrategy, 可选,默认为 "epoch") — 训练期间采用的检查点保存策略。可能的值为:

    • "no":在训练结束时保存。
    • "epoch":在每个 epoch 结束时保存。
    • "steps":每隔 save_steps 保存。
  • save_steps (int, 可选) — 使用“steps” save_strategy 时,保存之间的步数。
  • tokenizer (PreTrainedTokenizerBase, 可选) — 模型使用的分词器。如果提供,将与权重一起上传到存储库。
  • hub_model_id (str, 可选) — 用于与本地 output_dir 保持同步的存储库的名称。它可以是一个简单的模型 ID,在这种情况下,模型将推送到您的命名空间。否则,它应该是整个存储库名称,例如 "user_name/model",这允许您使用 "organization_name/model" 推送到您是成员的组织。

    默认为 output_dir 的名称。

  • hub_token (str, 可选) — 用于将模型推送到 Hub 的令牌。默认为 huggingface-cli login 获取的缓存文件夹中的令牌。
  • checkpoint (bool, 可选,默认为 False) — 是否保存完整的训练检查点(包括 epoch 和优化器状态)以允许恢复训练。仅当 save_strategy"epoch" 时可用。

此回调函数会定期保存模型并将其推送到 Hub。默认情况下,它会在每个 epoch 推送一次,但可以通过 save_strategy 参数更改。推送的模型可以像 Hub 上的任何其他模型一样访问,例如使用 from_pretrained 方法。

from transformers.keras_callbacks import PushToHubCallback

push_to_hub_callback = PushToHubCallback(
    output_dir="./model_save",
    tokenizer=tokenizer,
    hub_model_id="gpt5-7xlarge",
)

model.fit(train_dataset, callbacks=[push_to_hub_callback])
< > GitHub 更新