TRL 文档
脚本实用工具
并获得增强的文档体验
开始使用
脚本实用工具
ScriptArguments
class trl.ScriptArguments
< source >( dataset_name: str dataset_config: typing.Optional[str] = None dataset_train_split: str = 'train' dataset_test_split: str = 'test' gradient_checkpointing_use_reentrant: bool = False ignore_bias_buffers: bool = False )
参数
- dataset_name (
str
) — 数据集名称。 - dataset_config (
str
或None
, 可选, 默认为None
) — 数据集配置名称。 对应于 load_dataset 函数的name
参数。 - dataset_train_split (
str
, 可选, 默认为"train"
) — 用于训练的数据集拆分。 - dataset_test_split (
str
, 可选, 默认为"test"
) — 用于评估的数据集拆分。 - gradient_checkpointing_use_reentrant (
bool
, 可选, 默认为False
) — 是否对梯度检查点应用use_reentrant
。 - ignore_bias_buffers (
bool
, 可选, 默认为False
) — 分布式训练的调试参数。 修复 LM 偏差/掩码缓冲区的 DDP 问题 - 无效的标量类型、原地操作。 参见 https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992。
所有脚本通用的参数。
TrlParser
class trl.TrlParser
< source >( dataclass_types: typing.Union[transformers.hf_argparser.DataClassType, typing.Iterable[transformers.hf_argparser.DataClassType], NoneType] = None **kwargs )
transformers.HfArgumentParser
的子类,旨在解析带有数据类支持配置的命令行参数,同时还支持配置文件加载和环境变量管理。
# main.py
import os
from dataclasses import dataclass
from trl import TrlParser
@dataclass
class MyArguments:
arg1: int
arg2: str = "alpha"
parser = TrlParser(dataclass_types=[MyArguments])
training_args = parser.parse_args_and_config()
print(training_args, os.environ.get("VAR1"))
$ python main.py --config config.yaml
(MyArguments(arg1=23, arg2='alpha'),) value1
$ python main.py --arg1 5 --arg2 beta
(MyArguments(arg1=5, arg2='beta'),) None
parse_args_and_config
< source >( args: typing.Optional[typing.Iterable[str]] = None return_remaining_strings: bool = False )
将命令行参数和配置文件解析为指定数据类类型的实例。
此方法包装了 transformers.HfArgumentParser.parse_args_into_dataclasses
,并且还解析了使用 --config
标志指定的配置文件。 配置文件(YAML 格式)提供参数值,这些值将替换数据类中的默认值。 命令行参数可以覆盖配置文件设置的值。 该方法还会设置配置文件 env
字段中指定的任何环境变量。
parse_args_into_dataclasses
< source >( args = None return_remaining_strings = False look_for_args_file = True args_filename = None args_file_flag = None ) → 由以下各项组成的元组
参数
- args — 要解析的字符串列表。 默认值取自 sys.argv。 (与 argparse.ArgumentParser 相同)
- return_remaining_strings — 如果为 true,则同时返回剩余参数字符串的列表。
- look_for_args_file — 如果为 true,将查找与此进程的入口点脚本具有相同基本名称的“.args”文件,并将其潜在内容附加到命令行参数中。
- args_filename — 如果不是 None,将使用此文件而不是前一个参数中指定的“.args”文件。
- args_file_flag — 如果不是 None,将查找命令行参数中使用此标志指定的文件。该标志可以多次指定,优先级由顺序决定(最后一个获胜)。
返回
由以下内容组成的元组
- dataclass 实例,其顺序与它们传递给 initializer.abspath 的顺序相同
- 如果适用,则为在初始化后添加到解析器的更多(非 dataclass 支持的)参数的附加命名空间。
- 剩余参数字符串的潜在列表。(与 argparse.ArgumentParser.parse_known_args 相同)
将命令行参数解析为指定 dataclass 类型的实例。
这依赖于 argparse 的 ArgumentParser.parse_known_args
。请参阅以下文档:docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
使用通过关键字参数提供的那些值覆盖解析器的默认值。
任何具有更新默认值的参数,如果之前是必需的,也会被标记为非必需。
返回解析器未使用的字符串列表。