Transformers 文档

Trainer 实用工具

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Trainer 实用工具

本页列出了 Trainer 使用的所有实用函数。

其中大部分仅在您研究库中 Trainer 的代码时才有意义。

实用工具

class transformers.EvalPrediction

< >

( predictions: numpy.ndarray | tuple[numpy.ndarray] label_ids: numpy.ndarray | tuple[numpy.ndarray] inputs: numpy.ndarray | tuple[numpy.ndarray] | None = None losses: numpy.ndarray | tuple[numpy.ndarray] | None = None )

参数

  • predictions (np.ndarray) — 模型的预测结果。
  • label_ids (np.ndarray) — 待匹配的目标(标签)。
  • inputs (np.ndarray, 可选) — 传递给模型的基础输入数据。
  • losses (np.ndarray, 可选) — 评估过程中计算的损失值。

评估输出(始终包含标签),用于计算指标。

class transformers.IntervalStrategy

< >

( value names = None module = None qualname = None type = None start = 1 )

一个枚举。

transformers.enable_full_determinism

< >

( seed: int warn_only: bool = False )

用于在分布式训练期间实现可复现行为的辅助函数。关于 pytorch,请参阅 https://pytorch.ac.cn/docs/stable/notes/randomness.html

transformers.set_seed

< >

( seed: int deterministic: bool = False )

参数

  • seed (int) — 要设置的种子。
  • deterministic (bool, 可选, 默认为 False) — 是否在可用时使用确定性算法。可能会减慢训练速度。

实现可复现行为的辅助函数,用于在 randomnumpytorch(如果已安装)中设置种子。

transformers.torch_distributed_zero_first

< >

( local_rank: int )

参数

  • local_rank (int) — 本地进程的排名 (rank)。

装饰器,使分布式训练中的所有进程等待每个 local_master 完成某些操作。

回调内部机制

class transformers.trainer_callback.CallbackHandler

< >

( callbacks model processing_class optimizer lr_scheduler )

按顺序调用回调列表的内部类。

Trainer 参数解析器

class transformers.HfArgumentParser

< >

( dataclass_types: typing.Union[transformers.hf_argparser.DataClassType, collections.abc.Iterable[transformers.hf_argparser.DataClassType], NoneType] = None **kwargs )

参数

  • dataclass_types (DataClassTypeIterable[DataClassType], 可选) — 数据类类型,或数据类类型列表,我们将使用解析后的参数“填充”其实例。
  • kwargs (dict[str, Any], 可选) — 以常规方式传递给 argparse.ArgumentParser()

argparse.ArgumentParser 的这个子类使用数据类上的类型提示来生成参数。

该类旨在与原生 argparse 良好配合。特别是,您可以在初始化后向解析器添加更多(非数据类支持的)参数,并在解析后获得一个额外的命名空间输出。可选:要创建子参数组,请在数据类中使用 _argument_group_name 属性。

parse_args_into_dataclasses

< >

( args = None return_remaining_strings = False look_for_args_file = True args_filename = None args_file_flag = None ) 组成的元组

参数

  • args — 要解析的字符串列表。默认取自 sys.argv。(与 argparse.ArgumentParser 相同)
  • return_remaining_strings — 如果为 true,则同时返回剩余参数字符串的列表。
  • look_for_args_file — 如果为 true,将寻找一个与该进程入口脚本同名但后缀为“.args”的文件,并将其潜在内容附加到命令行参数中。
  • args_filename — 如果不为 None,将使用此文件而不是前一个参数中指定的“.args”文件。
  • args_file_flag — 如果不为 None,将查找命令行参数中以此标志指定的文件。该标志可以指定多次,优先级由顺序决定(最后一个获胜)。

返回

组成的元组:

  • 数据类实例,顺序与传递给初始化的顺序相同。
  • 如果适用,一个额外的命名空间,用于初始化后添加到解析器的更多(非数据类支持的)参数。
  • 剩余参数字符串的潜在列表。(与 argparse.ArgumentParser.parse_known_args 相同)

将命令行参数解析为指定数据类类型的实例。

这依赖于 argparse 的 ArgumentParser.parse_known_args。请参阅文档:docs.python.org/3/library/argparse.html#argparse.ArgumentParser.parse_args

parse_dict

< >

( args: dict allow_extra_keys: bool = False ) 组成的元组

参数

  • args (dict) — 包含配置值的字典。
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。如果为 False,当字典中包含未解析的键时,将引发异常。

返回

组成的元组:

  • 数据类实例,顺序与传递给初始化的顺序相同。

另一种辅助方法,完全不使用 argparse,而是使用字典填充数据类类型。

parse_json_file

< >

( json_file: str | os.PathLike allow_extra_keys: bool = False ) 组成的元组

参数

  • json_file (stros.PathLike) — 要解析的 json 文件的文件名。
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。如果为 False,当 json 文件中包含未解析的键时,将引发异常。

返回

组成的元组:

  • 数据类实例,顺序与传递给初始化的顺序相同。

另一种辅助方法,完全不使用 argparse,而是加载 json 文件并填充数据类类型。

parse_yaml_file

< >

( yaml_file: str | os.PathLike allow_extra_keys: bool = False ) 组成的元组

参数

  • yaml_file (stros.PathLike) — 要解析的 yaml 文件的文件名。
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。如果为 False,当 yaml 文件中包含未解析的键时,将引发异常。

返回

组成的元组:

  • 数据类实例,顺序与传递给初始化的顺序相同。

另一种辅助方法,完全不使用 argparse,而是加载 yaml 文件并填充数据类类型。

调试实用工具

class transformers.debug_utils.DebugUnderflowOverflow

< >

( model max_frames_to_save = 21 trace_batch_nums = [] abort_after_batch_num = None )

参数

  • model (nn.Module) — 要调试的模型。
  • max_frames_to_save (int, 可选, 默认为 21) — 向前记录多少帧。
  • trace_batch_nums(list[int], 可选, 默认为 []) — 追踪哪些批次编号(这会关闭异常检测)。
  • abort_after_batch_num (int, 可选) — 是否在特定批次编号完成后中止。

此调试类有助于检测和理解模型何时开始出现极大或极小的数值,更重要的是检测 naninf 权重和激活元素。

共有 2 种工作模式:

  1. 下溢/上溢检测(默认)
  2. 特定批次的绝对最小值/最大值追踪(不带检测)

模式 1:下溢/上溢检测

要激活下溢/上溢检测,请使用模型初始化该对象,

debug_overflow = DebugUnderflowOverflow(model)

然后正常运行训练。如果在权重、输入或输出元素中至少有一个检测到 naninf,该模块将抛出异常并打印导致该事件的 max_frames_to_save 帧信息,每帧报告:

  1. 运行 forward 的模块全名加上类名
  2. 每个模块权重、输入和输出的所有元素的绝对最小值和最大值

例如,以下是在 fp16 混合精度下运行 google/mt5-small 时检测报告的页眉和最后几帧:

混合精度

Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min  abs max  metadata
[...]
                  encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
                  encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
                  encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00      inf output

您可以在此看到,T5DenseGatedGeluDense.forward 产生的输出激活值绝对最大值约为 62.7K,非常接近 fp16 的上限 64K。在下一帧中,我们有 Dropout 层,它在将某些元素置零后会对权重进行重新归一化,这将绝对最大值推高到了 64K 以上,从而导致了上溢(overflow)。

如您所见,当数值开始进入 fp16 的极高范围时,我们需要研究的是之前的那些帧。

追踪是通过 forward hook 完成的,该钩子在 forward 完成后立即被调用。

默认情况下会打印最后 21 帧。您可以根据需要调整默认值。例如:

debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

为了验证您是否正确设置了此调试功能,并打算在可能需要数小时才能完成的训练中使用它,请先按照下一节所述,对前几个批次启用正常追踪运行一下。

模式 2:特定批次的绝对最小值/最大值追踪(不带检测)

第二种工作模式是按批次追踪,此时下溢/上溢检测功能是关闭的。

假设您想观察给定批次中每次 forward 调用所有成分的绝对最小值和最大值,

且仅针对第 1 批和第 3 批执行此操作。那么您可以这样实例化此类:

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

现在,第 1 批和第 3 批的完整过程将使用上述格式进行追踪。批次索引从 0 开始。

如果您知道程序在特定批次编号后开始出现异常,这将非常有用,您可以直接快进到该区域。

提前停止

您还可以指定在哪个批次编号后停止训练,使用:

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)

此功能主要在追踪模式下有用,但也可用于任何模式。

性能:

由于此模块在每次 forward 时都会测量模型每个权重的绝对 min/max,因此会减慢训练速度。因此,请记得在满足调试需求后将其关闭。

在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.