Optimum 文档
训练器
并获得增强的文档体验
开始使用
训练器
ORTTrainer
class optimum.onnxruntime.ORTTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None args: ORTTrainingArguments = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
参数
- model (PreTrainedModel 或
torch.nn.Module
, 可选) — 要训练、评估或用于预测的模型。如果未提供,则必须传递model_init
。ORTTrainer
经过优化,可以与 transformers 库提供的 PreTrainedModel 一起使用。您仍然可以使用您自己定义的torch.nn.Module
模型,通过 ONNX Runtime 后端进行训练,并通过 PyTorch 后端进行推理,只要它们的工作方式与 🤗 Transformers 模型相同即可。 - args (
ORTTrainingArguments
, 可选) — 用于调整训练的参数。如果未提供,将默认为ORTTrainingArguments
的基本实例,并将output_dir
设置为当前目录中名为 tmp_trainer 的目录。 - data_collator (
DataCollator
, 可选) — 用于从train_dataset
或eval_dataset
的元素列表中形成批次的函数。如果未提供tokenizer
,将默认为 default_data_collator,否则为 DataCollatorWithPadding 的实例。 - train_dataset (
torch.utils.data.Dataset
或torch.utils.data.IterableDataset
, 可选) — 用于训练的数据集。如果它是 Dataset,则会自动删除model.forward()
方法不接受的列。请注意,如果它是具有某些随机化的torch.utils.data.IterableDataset
,并且您以分布式方式进行训练,则您的可迭代数据集应使用内部属性generator
,该属性是一个torch.Generator
,用于在所有进程上必须相同的随机化(并且 ORTTrainer 将在每个 epoch 手动设置此generator
的种子),或者具有一个set_epoch()
方法,该方法在内部设置使用的 RNG 的种子。 - eval_dataset (Union[
torch.utils.data.Dataset
, Dict[str,torch.utils.data.Dataset
]), 可选) — 用于评估的数据集。如果它是 Dataset,则会自动删除model.forward()
方法不接受的列。如果它是一个字典,它将在每个数据集上进行评估,并将字典键添加到指标名称的前面。 - tokenizer (PreTrainedTokenizerBase, 可选) — 用于预处理数据的分词器。如果提供,将用于在批处理输入时自动将输入填充到最大长度,并且它将与模型一起保存,以便更轻松地重新运行中断的训练或重用微调模型。
- model_init (
Callable[[], PreTrainedModel]
, 可选) — 一个实例化要使用的模型的功能。如果提供,每次调用ORTTrainer.train
都将从该函数给出的模型的新实例开始。该函数可能没有参数,或者只有一个包含 optuna/Ray Tune/SigOpt 试验对象的参数,以便能够根据超参数(例如层数、内层大小、dropout 概率等)选择不同的架构。 - compute_metrics (
Callable[[EvalPrediction], Dict]
, 可选) — 将用于计算评估指标的函数。必须接受EvalPrediction
并返回从字符串到指标值的字典。 - callbacks (TrainerCallback 列表, 可选) — 用于自定义训练循环的回调列表。将这些添加到此处详述的默认回调列表中。如果要删除使用的默认回调之一,请使用
ORTTrainer.remove_callback
方法。 - optimizers (
Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
, 可选) — 包含要使用的优化器和调度器的元组。将默认为模型上的AdamW
实例和由args
控制的get_linear_schedule_with_warmup
给出的调度器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
, 可选) — 一个在每次评估步骤中预处理 logits 的函数,就在缓存它们之前。必须接受两个张量,logits 和 labels,并返回处理后的 logits,如预期的那样。此函数所做的修改将反映在compute_metrics
收到的预测中。请注意,如果数据集没有标签(第二个参数),则 labels(第二个参数)将为None
。
ORTTrainer 是一个简单但功能完整的 ONNX Runtime 训练和评估循环,针对 🤗 Transformers 进行了优化。
重要属性
- model — 始终指向核心模型。如果使用 transformers 模型,它将是 PreTrainedModel 子类。
- model_wrapped — 在存在一个或多个其他模块包装原始模型的情况下,始终指向最外部的模型。这是应该用于前向传递的模型。例如,在
DeepSpeed
下,内部模型首先包装在ORTModule
中,然后包装在DeepSpeed
中,然后再包装在torch.nn.DistributedDataParallel
中。如果内部模型未被包装,则self.model_wrapped
与self.model
相同。 - is_model_parallel — 模型是否已切换到模型并行模式(不同于数据并行,这意味着某些模型层在不同的 GPU 上拆分)。
- place_model_on_device — 是否自动将模型放置在设备上 - 如果使用模型并行或 deepspeed,或者如果默认的
ORTTrainingArguments.place_model_on_device
被覆盖以返回False
,则将其设置为False
。 - is_in_train — 模型当前是否正在运行
train
(例如,当在train
中调用evaluate
时)
设置优化器。
我们提供了一个合理的默认值,效果很好。如果您想使用其他值,可以通过 optimizers
在 ORTTrainer 的 init 中传递一个元组,或者在子类中子类化并覆盖此方法。
get_ort_optimizer_cls_and_kwargs
< source >( args: ORTTrainingArguments )
根据 ORTTrainingArguments
返回 ONNX Runtime 中实现的优化器类和优化器参数。
train
< source >( resume_from_checkpoint: typing.Union[bool, str, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )
参数
- resume_from_checkpoint (
str
或bool
, 可选) — 如果是str
,则为本地路径,指向先前ORTTrainer
实例保存的已保存检查点。如果为bool
且等于True
,则加载先前ORTTrainer
实例保存在 args.output_dir 中的最后一个检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态恢复。 - trial (
optuna.Trial
或Dict[str, Any]
, 可选) — 试运行 (trial run) 或用于超参数搜索的超参数字典。 - ignore_keys_for_eval (
List[str]
, 可选) — 模型输出(如果输出是字典)中在训练期间评估指标时不应考虑的键列表。 - kwargs (
Dict[str, Any]
, 可选) — 用于隐藏已弃用参数的附加关键字参数。
使用 ONNX Runtime 加速器进行训练的主要入口点。
ORTSeq2SeqTrainer
class optimum.onnxruntime.ORTSeq2SeqTrainer
< 源代码 >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None args: ORTTrainingArguments = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
evaluate
< 源代码 >( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' **gen_kwargs )
参数
- eval_dataset (
Dataset
, 可选) — 如果您希望覆盖self.eval_dataset
,请传递数据集。 如果它是 Dataset,则会自动删除model.forward()
方法不接受的列。 它必须实现__len__
方法。 - ignore_keys (
List[str]
, 可选) — 模型输出(如果输出是字典)中在收集预测结果时应忽略的键列表。 - metric_key_prefix (
str
, 可选, 默认为"eval"
) — 用作指标键前缀的可选前缀。 例如,如果前缀为"eval"
(默认值),则指标 “bleu” 将被命名为 “eval_bleu”。 - max_length (
int
, 可选) — 使用generate
方法进行预测时的最大目标长度。 - num_beams (
int
, 可选) — 使用generate
方法进行预测时将使用的束搜索 (beam search) 的束数量。 1 表示不使用束搜索。 - gen_kwargs — 附加的
generate
特定关键字参数。
运行评估并返回指标。
调用脚本将负责提供计算指标的方法,因为它们是任务相关的(将其传递给 init compute_metrics
参数)。
您还可以子类化并覆盖此方法以注入自定义行为。
predict
< 源代码 >( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' **gen_kwargs )
参数
- test_dataset (
Dataset
) — 用于运行预测的数据集。 如果它是 Dataset,则会自动删除model.forward()
方法不接受的列。 必须实现__len__
方法。 - ignore_keys (
List[str]
, 可选) — 模型输出(如果输出是字典)中在收集预测结果时应忽略的键列表。 - metric_key_prefix (
str
, 可选, 默认为"eval"
) — 用作指标键前缀的可选前缀。 例如,如果前缀为"eval"
(默认值),则指标 “bleu” 将被命名为 “eval_bleu”。 - max_length (
int
, 可选) — 使用generate
方法进行预测时的最大目标长度。 - num_beams (
int
, 可选) — 使用generate
方法进行预测时将使用的束搜索 (beam search) 的束数量。 1 表示不使用束搜索。 - gen_kwargs — 附加的
generate
特定关键字参数。
运行预测并返回预测结果和潜在的指标。
根据数据集和您的使用情况,您的测试数据集可能包含标签。 在这种情况下,此方法还将像 evaluate()
中一样返回指标。
如果您的预测结果或标签具有不同的序列长度(例如,因为您在标记分类任务中执行动态填充),则预测结果将被填充(在右侧)以允许连接成一个数组。 填充索引为 -100。
返回:NamedTuple 一个具名元组,包含以下键
- predictions (
np.ndarray
):test_dataset
上的预测结果。 - label_ids (
np.ndarray
, 可选): 标签(如果数据集包含标签)。 - metrics (
Dict[str, float]
, 可选): 潜在的指标字典(如果数据集包含标签)。
ORTTrainingArguments
class optimum.onnxruntime.ORTTrainingArguments
< 源代码 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: typing.List[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' )
ORTSeq2SeqTrainingArguments
class optimum.onnxruntime.ORTSeq2SeqTrainingArguments
< 源代码 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: typing.List[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None )