TRL 文档

脚本实用工具

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

脚本实用工具

ScriptArguments

class trl.ScriptArguments

< >

( dataset_name: str dataset_config: typing.Optional[str] = None dataset_train_split: str = 'train' dataset_test_split: str = 'test' gradient_checkpointing_use_reentrant: bool = False ignore_bias_buffers: bool = False )

参数

  • dataset_name (str) — 数据集名称。
  • dataset_config (strNone, 可选, 默认为 None) — 数据集配置名称。 对应于 load_dataset 函数的 name 参数。
  • dataset_train_split (str, 可选, 默认为 "train") — 用于训练的数据集拆分。
  • dataset_test_split (str, 可选, 默认为 "test") — 用于评估的数据集拆分。
  • gradient_checkpointing_use_reentrant (bool, 可选, 默认为 False) — 是否对梯度检查点应用 use_reentrant
  • ignore_bias_buffers (bool, 可选, 默认为 False) — 分布式训练的调试参数。 修复 LM 偏差/掩码缓冲区的 DDP 问题 - 无效的标量类型、原地操作。 参见 https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992

所有脚本通用的参数。

TrlParser

class trl.TrlParser

< >

( dataclass_types: typing.Union[transformers.hf_argparser.DataClassType, typing.Iterable[transformers.hf_argparser.DataClassType], NoneType] = None **kwargs )

参数

  • dataclass_types (Union[DataClassType, Iterable[DataClassType]]None, 可选, 默认为 None) — 用于参数解析的数据类类型。
  • **kwargs — 传递给 transformers.HfArgumentParser 构造函数的附加关键字参数。

transformers.HfArgumentParser 的子类,旨在解析带有数据类支持配置的命令行参数,同时还支持配置文件加载和环境变量管理。

示例

# config.yaml
env:
    VAR1: value1
arg1: 23
# main.py
import os
from dataclasses import dataclass
from trl import TrlParser

@dataclass
class MyArguments:
    arg1: int
    arg2: str = "alpha"

parser = TrlParser(dataclass_types=[MyArguments])
training_args = parser.parse_args_and_config()

print(training_args, os.environ.get("VAR1"))
$ python main.py --config config.yaml
(MyArguments(arg1=23, arg2='alpha'),) value1

$ python main.py --arg1 5 --arg2 beta
(MyArguments(arg1=5, arg2='beta'),) None

parse_args_and_config

< >

( args: typing.Optional[typing.Iterable[str]] = None return_remaining_strings: bool = False )

将命令行参数和配置文件解析为指定数据类类型的实例。

此方法包装了 transformers.HfArgumentParser.parse_args_into_dataclasses,并且还解析了使用 --config 标志指定的配置文件。 配置文件(YAML 格式)提供参数值,这些值将替换数据类中的默认值。 命令行参数可以覆盖配置文件设置的值。 该方法还会设置配置文件 env 字段中指定的任何环境变量。

parse_args_into_dataclasses

< >

( args = None return_remaining_strings = False look_for_args_file = True args_filename = None args_file_flag = None ) 由以下各项组成的元组

参数

  • args — 要解析的字符串列表。 默认值取自 sys.argv。 (与 argparse.ArgumentParser 相同)
  • return_remaining_strings — 如果为 true,则同时返回剩余参数字符串的列表。
  • look_for_args_file — 如果为 true,将查找与此进程的入口点脚本具有相同基本名称的“.args”文件,并将其潜在内容附加到命令行参数中。
  • args_filename — 如果不是 None,将使用此文件而不是前一个参数中指定的“.args”文件。
  • args_file_flag — 如果不是 None,将查找命令行参数中使用此标志指定的文件。该标志可以多次指定,优先级由顺序决定(最后一个获胜)。

返回

由以下内容组成的元组

  • dataclass 实例,其顺序与它们传递给 initializer.abspath 的顺序相同
  • 如果适用,则为在初始化后添加到解析器的更多(非 dataclass 支持的)参数的附加命名空间。
  • 剩余参数字符串的潜在列表。(与 argparse.ArgumentParser.parse_known_args 相同)

将命令行参数解析为指定 dataclass 类型的实例。

这依赖于 argparse 的 ArgumentParser.parse_known_args。请参阅以下文档:docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

set_defaults_with_config

< >

( **kwargs )

使用通过关键字参数提供的那些值覆盖解析器的默认值。

任何具有更新默认值的参数,如果之前是必需的,也会被标记为非必需。

返回解析器未使用的字符串列表。

< > 更新 在 GitHub 上