Transformers 文档

Trainer 的实用工具

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

Trainer 的实用工具

此页面列出了 Trainer 使用的所有实用函数。

如果您正在研究库中 Trainer 的代码,那么其中大多数函数才有用。

实用工具

class transformers.EvalPrediction

< >

( predictions: typing.Union[numpy.ndarray, tuple[numpy.ndarray]] label_ids: typing.Union[numpy.ndarray, tuple[numpy.ndarray]] inputs: typing.Union[numpy.ndarray, tuple[numpy.ndarray], NoneType] = None losses: typing.Union[numpy.ndarray, tuple[numpy.ndarray], NoneType] = None )

参数

  • predictions (np.ndarray) — 模型的预测结果。
  • label_ids (np.ndarray) — 要匹配的目标。
  • inputs (np.ndarray, 可选) — 传递给模型的输入数据。
  • losses (np.ndarray, 可选) — 评估期间计算的损失值。

评估输出(始终包含标签),用于计算指标。

class transformers.IntervalStrategy

< >

( value names = None module = None qualname = None type = None start = 1 )

一个枚举。

transformers.enable_full_determinism

< >

( seed: int warn_only: bool = False )

分布式训练期间实现可重现行为的辅助函数。请参阅

transformers.set_seed

< >

( seed: int deterministic: bool = False )

参数

  • seed (int) — 要设置的种子值。
  • deterministic (bool, 可选, 默认为 False) — 是否在可用时使用确定性算法。可能会减慢训练速度。

用于实现可重现行为的辅助函数,可在 randomnumpytorch 和/或 tf 中设置种子(如果已安装)。

transformers.torch_distributed_zero_first

< >

( local_rank: int )

参数

  • local_rank (int) — 本地进程的 rank。

装饰器,使分布式训练中的所有进程等待每个 local_master 完成某些操作。

Callbacks 内部机制

class transformers.trainer_callback.CallbackHandler

< >

( callbacks model processing_class optimizer lr_scheduler )

仅按顺序调用回调列表的内部类。

分布式评估

class transformers.trainer_pt_utils.DistributedTensorGatherer

< >

( world_size num_samples make_multiple_of = None padding_index = -100 )

参数

  • world_size (int) — 分布式训练中使用的进程数。
  • num_samples (int) — 数据集中的样本数。
  • make_multiple_of (int, 可选) — 如果传递此参数,则此类假定传递给每个进程的数据集是通过添加样本来使其成为此参数的倍数。
  • padding_index (int, 可选, 默认为 -100) — 如果数组的序列长度不完全相同,则使用的填充索引。

一个负责按块在 CPU 上正确收集张量(或张量的嵌套列表/元组)的类。

如果我们的数据集有 16 个样本,在 3 个进程上批次大小为 2,并且我们在每一步都收集然后在 CPU 上传输,那么我们的采样器将生成以下索引

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]

以获得大小为 3 的倍数(以便每个进程获得相同的数据集长度)。然后,进程 0、1 和 2 将负责预测以下样本

  • P0: [0, 1, 2, 3, 4, 5]
  • P1: [6, 7, 8, 9, 10, 11]
  • P2: [12, 13, 14, 15, 0, 1]

每个进程上处理的第一个批次将是

  • P0: [0, 1]
  • P1: [6, 7]
  • P2: [12, 13]

因此,如果我们在第一个批次结束时收集,我们将获得一个张量(张量的嵌套列表/元组),对应于以下索引

[0, 1, 6, 7, 12, 13]

如果我们在不采取任何预防措施的情况下直接连接我们的结果,那么用户将在预测循环结束时按此顺序获得索引的预测

[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]

由于某些原因,这不会让他们满意。此类旨在解决该问题。

add_arrays

< >

( arrays )

arrays 添加到内部存储。将在传递的第一个数组处初始化存储到完整大小,以便如果我们注定要发生 OOM,它会在开始时发生。

finalize

< >

( )

返回正确收集的数组并截断到样本数量(因为采样器添加了一些额外的数据以使每个进程的数据集长度相同)。

Trainer 参数解析器

class transformers.HfArgumentParser

< >

( dataclass_types: typing.Union[transformers.hf_argparser.DataClassType, collections.abc.Iterable[transformers.hf_argparser.DataClassType], NoneType] = None **kwargs )

参数

  • dataclass_types (DataClassTypeIterable[DataClassType], 可选) — Dataclass 类型,或我们将用解析后的参数“填充”实例的数据类类型列表。
  • kwargs (Dict[str, Any], 可选) — 以常规方式传递给 argparse.ArgumentParser()

argparse.ArgumentParser 的这个子类使用数据类上的类型提示来生成参数。

该类旨在与原生 argparse 良好配合。 特别是,您可以在初始化后向解析器添加更多(非数据类支持的)参数,并且在解析后您将获得作为附加命名空间的输出。 可选:要创建子参数组,请在数据类中使用 _argument_group_name 属性。

parse_args_into_dataclasses

< >

( args = None return_remaining_strings = False look_for_args_file = True args_filename = None args_file_flag = None ) 由...组成的元组

参数

  • args — 要解析的字符串列表。 默认值取自 sys.argv。(与 argparse.ArgumentParser 相同)
  • return_remaining_strings — 如果为 true,则还返回剩余参数字符串的列表。
  • look_for_args_file — 如果为 true,将查找与此进程的入口点脚本具有相同基本名称的“.args”文件,并将其潜在内容附加到命令行参数。
  • args_filename — 如果不为 None,将使用此文件而不是上一个参数中指定的“.args”文件。
  • args_file_flag — 如果不为 None,将在命令行参数中查找使用此标志指定的文件。 该标志可以多次指定,优先级由顺序决定(最后一个优先)。

返回

由...组成的元组

  • 数据类实例,其顺序与它们传递给初始化程序的顺序相同。abspath
  • 如果适用,则为在初始化后添加到解析器的更多(非数据类支持的)参数的附加命名空间。
  • 剩余参数字符串的潜在列表。(与 argparse.ArgumentParser.parse_known_args 相同)

将命令行参数解析为指定的数据类类型的实例。

这依赖于 argparse 的 ArgumentParser.parse_known_args。 有关文档,请参见:docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

parse_dict

< >

( args: dict allow_extra_keys: bool = False ) 由...组成的元组

参数

  • args (dict) — 包含配置值的字典
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。 如果为 False,如果字典包含未解析的键,则会引发异常。

返回

由...组成的元组

  • 数据类实例,其顺序与它们传递给初始化程序的顺序相同。

不使用 argparse 的替代辅助方法,而是使用字典并填充数据类类型。

parse_json_file

< >

( json_file: typing.Union[str, os.PathLike] allow_extra_keys: bool = False ) 由...组成的元组

参数

  • json_file (stros.PathLike) — 要解析的 json 文件的文件名
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。 如果为 False,如果 json 文件包含未解析的键,则会引发异常。

返回

由...组成的元组

  • 数据类实例,其顺序与它们传递给初始化程序的顺序相同。

不使用 argparse 的替代辅助方法,而是加载 json 文件并填充数据类类型。

parse_yaml_file

< >

( yaml_file: typing.Union[str, os.PathLike] allow_extra_keys: bool = False ) 由...组成的元组

参数

  • yaml_file (stros.PathLike) — 要解析的 yaml 文件的文件名
  • allow_extra_keys (bool, 可选, 默认为 False) — 默认为 False。 如果为 False,如果 json 文件包含未解析的键,则会引发异常。

返回

由...组成的元组

  • 数据类实例,其顺序与它们传递给初始化程序的顺序相同。

不使用 argparse 的替代辅助方法,而是加载 yaml 文件并填充数据类类型。

调试工具

class transformers.debug_utils.DebugUnderflowOverflow

< >

( model max_frames_to_save = 21 trace_batch_nums = [] abort_after_batch_num = None )

参数

  • model (nn.Module) — 要调试的模型。
  • max_frames_to_save (int, 可选, 默认为 21) — 要记录回溯多少帧
  • trace_batch_nums(List[int], 可选, 默认为 []) — 要追踪的批次号(关闭检测)
  • abort_after_batch_num (`int“, 可选) — 在完成某个批次号后是否中止

这个调试类帮助检测和理解模型何时开始变得非常大或非常小,更重要的是 naninf 权重和激活元素。

有两种工作模式

  1. 下溢/溢出检测(默认)
  2. 特定批次的绝对最小值/最大值追踪,不进行检测

模式 1:下溢/溢出检测

要激活下溢/溢出检测,请使用模型初始化对象

debug_overflow = DebugUnderflowOverflow(model)

然后正常运行训练,如果在至少一个权重、输入或输出元素中检测到 naninf,此模块将抛出异常,并将打印导致此事件的 max_frames_to_save 帧,每帧报告

  1. 完全限定的模块名称加上运行了 forward 的类名
  2. 每个模块的权重以及输入和输出的所有元素的绝对最小值和最大值

例如,以下是使用 fp16 运行的 google/mt5-small 的检测报告中的标题和最后几帧

混合精度

Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min  abs max  metadata
[...]
                  encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
                  encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
                  encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00      inf output

您可以在这里看到,T5DenseGatedGeluDense.forward 导致输出激活,其绝对最大值约为 62.7K,非常接近 fp16 的 64K 上限。在下一帧中,我们有 Dropout,它重新归一化权重,在将某些元素归零后,这会将绝对最大值推高到超过 64K,然后我们得到溢出。

正如您所看到的,当数字开始变得对于 fp16 数字来说非常大时,我们需要查看的是之前的帧。

跟踪是在前向钩子中完成的,该钩子在 forward 完成后立即调用。

默认情况下,打印最后 21 帧。您可以更改默认值以适应您的需求。例如

debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

为了验证您已正确设置此调试功能,并且打算在可能需要数小时才能完成的训练中使用它,请首先为几个批次启用正常跟踪运行它,如下一节所述。

模式 2. 特定批次的绝对最小值/最大值追踪,不进行检测

第二种工作模式是按批次追踪,并关闭下溢/溢出检测功能。

假设您想观察每个 forward 调用的所有成分的绝对最小值和最大值,对于

给定批次,并且仅对批次 1 和 3 执行此操作。然后您可以按如下方式实例化此类

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

现在将使用与上述相同的格式追踪完整的批次 1 和 3。批次从 0 开始索引。

如果您知道程序在某个批次号之后开始出现异常,这将很有帮助,因此您可以快速前进到该区域。

提前停止

您还可以指定在哪个批次号之后停止训练,使用

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)

此功能主要在追踪模式下有用,但您可以在任何模式下使用它。

性能:

由于此模块测量模型每次前向传播时每个权重的绝对 min/`max,因此会降低训练速度。因此,请记住在满足调试需求后将其关闭。

< > 在 GitHub 上更新