Optimum 文档

训练器

您正在查看的是需要从源码安装。如果您想进行常规 pip 安装,请查看最新稳定版本 (v1.27.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

训练器

ORTTrainer

class optimum.onnxruntime.ORTTrainer

< >

( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 参数: ORTTrainingArguments = None 数据收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 训练数据集: typing.Optional[torch.utils.data.dataset.Dataset] = None 评估数据集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分词器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 计算指标: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回调: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 优化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用于指标的预处理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

参数

  • 模型PreTrainedModeltorch.nn.Module可选)— 用于训练、评估或预测的模型。如果未提供,则必须传入一个 model_init

    ORTTrainer 经过优化,可与 transformers 库提供的 PreTrainedModel 配合使用。只要您自己的模型(定义为 torch.nn.Module)与 🤗 Transformers 模型的工作方式相同,您仍然可以使用它们进行 ONNX Runtime 后端训练和 PyTorch 后端推理。

  • 参数ORTTrainingArguments可选)— 用于训练的调整参数。如果未提供,将默认为 ORTTrainingArguments 的一个基本实例,其中 output_dir 设置为当前目录中名为 tmp_trainer 的目录。
  • 数据收集器DataCollator可选)— 用于从 train_dataseteval_dataset 元素列表中形成批次的功能。如果未提供 tokenizer,将默认为 default_data_collator,否则将默认为 DataCollatorWithPadding 的实例。
  • 训练数据集torch.utils.data.Datasettorch.utils.data.IterableDataset可选)— 用于训练的数据集。如果是 Dataset,则会自动删除 model.forward() 方法不接受的列。请注意,如果它是一个带有随机化的 torch.utils.data.IterableDataset,并且您以分布式方式进行训练,那么您的可迭代数据集应该要么使用一个内部属性 generator(它是一个 torch.Generator),用于所有进程上必须相同的随机化,要么有一个 set_epoch() 方法,该方法在内部设置所使用的 RNG 的种子。
  • 评估数据集(Union[torch.utils.data.Dataset, Dict[str, torch.utils.data.Dataset]],可选)— 用于评估的数据集。如果是 Dataset,则会自动删除 model.forward() 方法不接受的列。如果是一个字典,则会在每个数据集上进行评估,并将字典键作为度量名称的前缀。
  • 分词器PreTrainedTokenizerBase可选)— 用于预处理数据用的分词器。如果提供,它将在批处理输入时自动将输入填充到最大长度,并与模型一起保存,以便更容易地重新运行中断的训练或重用微调后的模型。
  • 模型初始化Callable[[], PreTrainedModel]可选)— 实例化要使用的模型的功能。如果提供,每次调用 ORTTrainer.train 都将从该功能给出的新模型实例开始。该功能可以不带参数,或带有一个参数(包含 optuna/Ray Tune/SigOpt 试验对象),以便能够根据超参数(如层数、内部层大小、dropout 概率等)选择不同的架构。
  • 计算指标Callable[[EvalPrediction], Dict]可选)— 将用于在评估时计算指标的功能。必须接受一个 EvalPrediction 并返回一个从字符串到指标值的字典。
  • 回调(List of TrainerCallback可选)— 用于自定义训练循环的回调列表。这些回调将被添加到此处详述的默认回调列表中。如果您想删除其中一个默认回调,请使用 ORTTrainer.remove_callback 方法。
  • 优化器Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]可选)— 包含要使用的优化器和调度器的元组。将默认为您的模型上的 AdamW 实例和由 args 控制的 get_linear_schedule_with_warmup 给出的调度器。
  • 用于指标的预处理logitsCallable[[torch.Tensor, torch.Tensor], torch.Tensor]可选)— 一个在每个评估步骤缓存 logits 之前立即预处理 logits 的函数。必须接受两个张量,即 logits 和标签,并返回经过处理的 logits。此函数所做的修改将反映在 compute_metrics 收到的预测中。请注意,如果数据集中没有标签,则标签(第二个参数)将为 None

ORTTrainer 是一个简单但功能完备的 ONNX Runtime 训练和评估循环,针对 🤗 Transformers 进行了优化。

重要属性

  • 模型 — 始终指向核心模型。如果使用 transformers 模型,它将是 PreTrainedModel 的子类。
  • 模型包装器 — 如果一个或多个其他模块包装了原始模型,则始终指向最外部的模型。这是应该用于前向传播的模型。例如,在 DeepSpeed 下,内部模型首先被 ORTModule 包装,然后被 DeepSpeed 包装,然后再次被 torch.nn.DistributedDataParallel 包装。如果内部模型尚未包装,则 self.model_wrappedself.model 相同。
  • is_model_parallel — 模型是否已切换到模型并行模式(与数据并行不同,这意味着一些模型层分布在不同的 GPU 上)。
  • place_model_on_device — 是否自动将模型放置在设备上 - 如果使用模型并行或 DeepSpeed,或者如果默认的 ORTTrainingArguments.place_model_on_device 被覆盖为返回 False,则此项将设置为 False
  • is_in_train — 模型当前是否正在运行 train(例如,在 train 中调用 evaluate 时)

创建优化器

< >

( )

设置优化器。

我们提供了一个效果良好的合理默认值。如果您想使用其他优化器,可以通过 optimizers 在 ORTTrainer 的初始化中传递一个元组,或者在子类中重写此方法。

获取 ort_optimizer_cls_and_kwargs

< >

( 参数: ORTTrainingArguments )

参数

  • 参数ORTTrainingArguments)— 训练会话的训练参数。

根据 ORTTrainingArguments 返回在 ONNX Runtime 中实现的优化器类和优化器参数。

训练

< >

( 从检查点恢复: typing.Union[bool, str, NoneType] = None 试用: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None 忽略评估键: typing.Optional[typing.List[str]] = None **kwargs )

参数

  • 从检查点恢复strbool可选)— 如果是 str,则为上一个 ORTTrainer 实例保存的检查点的本地路径。如果是 bool 且等于 True,则加载上一个 ORTTrainer 实例保存在 args.output_dir 中的最后一个检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态恢复。
  • 试用optuna.TrialDict[str, Any]可选)— 用于超参数搜索的试用运行或超参数字典。
  • 忽略评估键List[str]可选)— 模型输出(如果是字典)中应在训练期间收集评估预测时忽略的键列表。
  • kwargsDict[str, Any]可选)— 用于隐藏已弃用参数的附加关键字参数。

使用 ONNX Runtime 加速器进行训练的主要入口点。

ORTSeq2SeqTrainer

class optimum.onnxruntime.ORTSeq2SeqTrainer

< >

( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 参数: ORTTrainingArguments = None 数据收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 训练数据集: typing.Optional[torch.utils.data.dataset.Dataset] = None 评估数据集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分词器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 计算指标: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回调: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 优化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用于指标的预处理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

评估

< >

( 评估数据集: typing.Optional[torch.utils.data.dataset.Dataset] = None 忽略键: typing.Optional[typing.List[str]] = None 指标键前缀: str = 'eval' **gen_kwargs )

参数

  • 评估数据集Dataset可选)— 如果您希望覆盖 self.eval_dataset,请传递一个数据集。如果是 Dataset,则会自动删除 model.forward() 方法不接受的列。它必须实现 __len__ 方法。
  • 忽略键List[str]可选)— 模型输出(如果是字典)中应在收集预测时忽略的键列表。
  • 指标键前缀str可选,默认为 "eval")— 用于作为指标键前缀的可选前缀。例如,如果前缀为 "eval"(默认),则指标“bleu”将命名为“eval_bleu”。
  • 最大长度int可选)— 使用生成方法预测时要使用的最大目标长度。
  • 光束数量int可选)— 使用生成方法预测时将使用的光束搜索的光束数量。1 表示不进行光束搜索。
  • gen_kwargs — 其他特定于 generate 的关键字参数。

运行评估并返回指标。

调用脚本将负责提供一个计算指标的方法,因为它们是依赖于任务的(将其传递给初始化 compute_metrics 参数)。

您还可以通过子类化并覆盖此方法来注入自定义行为。

预测

< >

( 测试数据集: Dataset 忽略键: typing.Optional[typing.List[str]] = None 指标键前缀: str = 'test' **gen_kwargs )

参数

  • 测试数据集Dataset)— 用于运行预测的数据集。如果是 Dataset,则会自动删除 model.forward() 方法不接受的列。必须实现 __len__ 方法。
  • 忽略键List[str]可选)— 模型输出(如果是字典)中应在收集预测时忽略的键列表。
  • 指标键前缀str可选,默认为 "eval")— 用于作为指标键前缀的可选前缀。例如,如果前缀为 "eval"(默认),则指标“bleu”将命名为“eval_bleu”。
  • 最大长度int可选)— 使用生成方法预测时要使用的最大目标长度。
  • 光束数量int可选)— 使用生成方法预测时将使用的光束搜索的光束数量。1 表示不进行光束搜索。
  • gen_kwargs — 其他特定于 generate 的关键字参数。

运行预测并返回预测和潜在指标。

根据数据集和您的用例,您的测试数据集可能包含标签。在这种情况下,此方法也将返回指标,就像在 evaluate() 中一样。

如果您的预测或标签具有不同的序列长度(例如,因为您在令牌分类任务中进行动态填充),则预测将被填充(右侧)以允许连接成一个数组。填充索引为 -100。

返回:NamedTuple 具有以下键的命名元组

  • 预测(np.ndarray):在 test_dataset 上的预测。
  • 标签ID(np.ndarray可选):标签(如果数据集中包含)。
  • 指标(Dict[str, float]可选):潜在的指标字典(如果数据集中包含标签)。

ORTTrainingArguments

class optimum.onnxruntime.ORTTrainingArguments

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' )

参数

  • optim (strtraining_args.ORTOptimizerNamestransformers.training_args.OptimizerNames, 可选, 默认为 "adamw_hf") — 要使用的优化器,包括 Transformers 中的优化器:adamw_hf、adamw_torch、adamw_apex_fused 或 adafactor。以及 ONNX Runtime 实现的优化器:adamw_ort_fused。

ORTSeq2SeqTrainingArguments

class optimum.onnxruntime.ORTSeq2SeqTrainingArguments

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None )

参数

  • optim (strtraining_args.ORTOptimizerNamestransformers.training_args.OptimizerNames, 可选, 默认为 "adamw_hf") — 要使用的优化器,包括 Transformers 中的优化器:adamw_hf、adamw_torch、adamw_apex_fused 或 adafactor。以及 ONNX Runtime 实现的优化器:adamw_ort_fused。
< > 在 GitHub 上更新