训练器实用工具
此页面列出了 Trainer 使用的所有实用函数。
如果您正在研究库中 Trainer 的代码,那么其中大多数函数才会派上用场。
实用工具
类 transformers.EvalPrediction
< 源代码 >( predictions: Union label_ids: Union inputs: Union = None )
评估输出(始终包含标签),用于计算指标。
类 transformers.IntervalStrategy
< 源代码 >( value names = None module = None qualname = None type = None start = 1 )
枚举。
用于在分布式训练期间实现可重复行为的辅助函数。参见
transformers.set_seed
< 源代码 >( seed: int deterministic: bool = False )
用于设置 random
、numpy
、torch
和/或 tf
(如果已安装)中的种子以实现可重复行为的辅助函数。
装饰器,用于使分布式训练中的所有进程等待每个 local_master 执行某些操作。
回调函数内部
类 transformers.trainer_callback.CallbackHandler
< 来源 >( callbacks model tokenizer optimizer lr_scheduler )
内部类,按顺序调用回调函数列表。
分布式评估
类 transformers.trainer_pt_utils.DistributedTensorGatherer
< 来源 >( world_size num_samples make_multiple_of = None padding_index = -100 )
一个负责通过块在 CPU 上正确收集张量(或张量的嵌套列表/元组)的类。
如果我们的数据集有 16 个样本,批大小为 2,在 3 个进程上,并且我们在每一步收集然后传输到 CPU,我们的采样器将生成以下索引
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]
以获得大小为 3 的倍数(以便每个进程获得相同的数据集长度)。然后进程 0、1 和 2 将负责对以下样本进行预测
- P0:
[0, 1, 2, 3, 4, 5]
- P1:
[6, 7, 8, 9, 10, 11]
- P2:
[12, 13, 14, 15, 0, 1]
每个进程处理的第一批将是
- P0:
[0, 1]
- P1:
[6, 7]
- P2:
[12, 13]
因此,如果我们在第一批结束时收集,我们将获得一个对应于以下索引的张量(张量的嵌套列表/元组)
[0, 1, 6, 7, 12, 13]
如果我们直接连接我们的结果而不采取任何预防措施,用户将在预测循环结束时按此顺序获得索引的预测
[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]
出于某种原因,这不会让他们满意。这个类就是为了解决这个问题。
将 arrays
添加到内部存储,将在传递的第一个数组处将存储初始化为完整大小,以便如果我们注定要遇到 OOM,它会在开始时发生。
返回正确收集的数组并截断为样本数量(因为采样器添加了一些额外的内容以使每个进程获得相同长度的数据集)。
训练器参数解析器
argparse.ArgumentParser
的这个子类使用数据类上的类型提示来生成参数。
该类旨在与原生 argparse 良好地配合使用。特别是,您可以在初始化后向解析器添加更多(非数据类支持的)参数,并且您将在解析后获得输出作为额外的命名空间。可选:要创建子参数组,请使用数据类中的 _argument_group_name
属性。
parse_args_into_dataclasses
< 源代码 >( args = None return_remaining_strings = False look_for_args_file = True args_filename = None args_file_flag = None ) → 由以下内容组成的元组
返回值
由以下内容组成的元组
- 数据类实例,其顺序与它们传递给初始化程序的顺序相同。abspath
- 如果适用,则为初始化后添加到解析器的更多(非数据类支持的)参数提供额外的命名空间。
- 剩余参数字符串的潜在列表。(与 argparse.ArgumentParser.parse_known_args 相同)
将命令行参数解析为指定数据类类型的实例。
这依赖于 argparse 的 ArgumentParser.parse_known_args
。请参阅以下文档:docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
parse_dict
< 源代码 >( args: Dict allow_extra_keys: bool = False ) → 由以下内容组成的元组
另一种辅助方法,它根本不使用 argparse
,而是使用字典并填充数据类类型。
parse_json_file
< 来源 >( json_file: Union allow_extra_keys: bool = False ) → 由以下部分组成的元组
替代辅助方法,完全不使用 argparse
,而是加载 json 文件并填充数据类类型。
parse_yaml_file
< 来源 >( yaml_file: Union allow_extra_keys: bool = False ) → 由以下部分组成的元组
替代辅助方法,完全不使用 argparse
,而是加载 yaml 文件并填充数据类类型。
调试实用程序
类 transformers.debug_utils.DebugUnderflowOverflow
< 源代码 >( model max_frames_to_save = 21 trace_batch_nums = [] abort_after_batch_num = None )
此调试类有助于检测和理解模型何时开始变得非常大或非常小,更重要的是检测权重和激活元素中的 nan
或 inf
。
有 2 种工作模式
- 下溢/溢出检测(默认)
- 特定批次绝对最小值/最大值跟踪,无需检测
模式 1:下溢/溢出检测
然后正常运行训练,如果在至少一个权重、输入或输出元素中检测到 nan
或 inf
,此模块将抛出异常并打印导致此事件的 max_frames_to_save
帧,每个帧报告
- 完全限定的模块名称加上运行其
forward
的类名 - 每个模块权重以及输入和输出的所有元素的绝对最小值和最大值
例如,以下是 google/mt5-small
在 fp16 中运行的检测报告的标题和最后几帧
混合精度
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
[...]
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
您可以在这里看到,T5DenseGatedGeluDense.forward
导致输出激活的绝对最大值约为 62.7K,非常接近 fp16 的上限 64K。在下一帧中,我们有 Dropout
,它在将某些元素归零后重新规范化权重,这将绝对最大值推到超过 64K,导致溢出。
如您所见,当数字对于 fp16 数字开始变得非常大时,我们需要查看的是之前的帧。
跟踪是在前向钩子中完成的,该钩子在 forward
完成后立即调用。
默认情况下,会打印最后 21 帧。您可以更改默认值以适应您的需求。例如
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
为了验证您是否正确设置了此调试功能,并且您打算在可能需要数小时才能完成的训练中使用它,请先按照下一节中的说明,对其中一批或几批启用正常的跟踪来运行它。
模式 2. 特定批次的绝对最小值/最大值跟踪,不进行检测
第二种工作模式是按批次跟踪,并关闭下溢/溢出检测功能。
假设您想查看给定批次的每次 forward
调调用的所有成分的绝对最小值和最大值,并且仅对批次 1 和 3 执行此操作。然后您将此类实例化如下
现在,将使用与上述相同的格式跟踪完整的批次 1 和 3。批次从 0 开始索引。
如果您知道程序在某个批次号之后开始出现异常,这将很有帮助,因此您可以快速转发到该区域。
提前停止
您还可以使用以下命令指定停止训练后的批次号:
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
此功能主要在跟踪模式下有用,但您可以在任何模式下使用它。
性能:
由于此模块会在每次前向传递时测量模型每个权重的绝对 min
/`max
,因此会降低训练速度。因此,请记住在满足调试需求后将其关闭。