TRL 文档

数据实用工具

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

数据实用工具

is_conversational

trl.is_conversational

< >

( example: dict ) bool

参数

  • example (dict[str, Any]) — 数据集的单个数据条目。示例可以根据数据集类型具有不同的键。

返回值

bool

如果数据采用对话格式,则返回 True,否则返回 False

检查示例是否为对话格式。

示例

>>> example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
>>> is_conversational(example)
True
>>> example = {"prompt": "The sky is"})
>>> is_conversational(example)
False

apply_chat_template

trl.apply_chat_template

< >

( example: dict tokenizer: PreTrainedTokenizerBase tools: typing.Optional[list[typing.Union[dict, typing.Callable]]] = None )

将聊天模板应用于对话示例以及 tools 中函数列表的架构。

有关更多详细信息,请参阅 maybe_apply_chat_template()

maybe_apply_chat_template

trl.maybe_apply_chat_template

< >

( example: dict tokenizer: PreTrainedTokenizerBase tools: typing.Optional[list[typing.Union[dict, typing.Callable]]] = None ) dict[str, str]

参数

  • example (dict[str, list[dict[str, str]]) — 表示对话数据集的单个数据条目的字典。每个数据条目可以根据数据集类型具有不同的键。支持的数据集类型有:

    • 语言建模数据集:"messages"
    • 仅提示数据集:"prompt"
    • 提示-补全数据集:"prompt""completion"
    • 偏好数据集:"prompt""chosen""rejected"
    • 带有隐式提示的偏好数据集:"chosen""rejected"
    • 不成对的偏好数据集:"prompt""completion""label"

    对于键 "messages""prompt""chosen""rejected""completion",这些值是消息列表,其中每条消息都是一个字典,键为 "role""content"

  • tokenizer (PreTrainedTokenizerBase) — 用于应用聊天模板的分词器。
  • tools (list[Union[dict, Callable]]None可选,默认为 None) — 模型可以访问的工具(可调用函数)列表。如果模板不支持函数调用,则此参数无效

返回值

dict[str, str]

应用聊天模板的格式化示例。

如果示例是对话格式,则对其应用聊天模板。

注释

  • 此函数不更改键,但对于语言建模数据集,"messages" 将替换为 "text"

  • 在仅提示数据的情况下,如果最后一个角色是 "user",则将生成提示添加到提示中。否则,如果最后一个角色是 "assistant",则继续最后一条消息。

示例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
>>> example = {
...     "prompt": [{"role": "user", "content": "What color is the sky?"}],
...     "completion": [{"role": "assistant", "content": "It is blue."}]
... }
>>> apply_chat_template(example, tokenizer)
{'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}

maybe_convert_to_chatml

trl.maybe_convert_to_chatml

< >

( example: dict ) dict[str, list]

参数

  • example (dict[str, list]) — 包含消息列表的单个数据条目。

返回值

dict[str, list]

重新格式化为 ChatML 样式的示例。

将具有字段 fromvalue 的对话数据集转换为 ChatML 格式。

此函数修改对话数据以与 OpenAI 的 ChatML 格式对齐

  • 在消息字典中,将键 "from" 替换为 "role"
  • 在消息字典中,将键 "value" 替换为 "content"
  • "conversations" 重命名为 "messages",以与 ChatML 保持一致。

示例

>>> from trl import maybe_convert_to_chatml
>>> example = {
...     "conversations": [
...         {"from": "user", "value": "What color is the sky?"},
...         {"from": "assistant", "value": "It is blue."}
...     ]
... }
>>> maybe_convert_to_chatml(example)
{'messages': [{'role': 'user', 'content': 'What color is the sky?'},
              {'role': 'assistant', 'content': 'It is blue.'}]}

extract_prompt

trl.extract_prompt

< >

( example: dict )

从偏好数据示例中提取共享提示,其中提示隐式包含在选定和拒绝的补全中。

有关更多详细信息,请参阅 maybe_extract_prompt()

maybe_extract_prompt

trl.maybe_extract_prompt

< >

( example: dict ) dict[str, list]

参数

  • example (dict[str, list]) — 表示偏好数据集中的单个数据条目的字典。它必须包含键 "chosen""rejected",其中每个值可以是对话式或标准式 (str)。

返回值

dict[str, list]

包含以下内容的字典

  • "prompt": “chosen” 和 “rejected” 完成之间最长的公共前缀。
  • "chosen": “chosen” 完成的剩余部分,已移除 prompt。
  • "rejected": “rejected” 完成的剩余部分,已移除 prompt。

从偏好数据示例中提取共享提示,其中提示隐式包含在选定和拒绝的补全中。

如果示例已包含 "prompt" 键,则函数按原样返回示例。否则,该函数会识别 “chosen” 和 “rejected” 完成之间最长的公共对话轮次序列(前缀),并将其提取为 prompt。然后,它会从各自的 “chosen” 和 “rejected” 完成中移除此 prompt。

示例

>>> example = {
...     "chosen": [
...         {"role": "user", "content": "What color is the sky?"},
...         {"role": "assistant", "content": "It is blue."}
...     ],
...     "rejected": [
...         {"role": "user", "content": "What color is the sky?"},
...         {"role": "assistant", "content": "It is green."}
...     ]
... }
>>> extract_prompt(example)
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}

或者,使用 datasets.Datasetmap 方法

>>> from trl import extract_prompt
>>> from datasets import Dataset
>>> dataset_dict = {
...     "chosen": [
...         [
...             {"role": "user", "content": "What color is the sky?"},
...             {"role": "assistant", "content": "It is blue."},
...         ],
...         [
...             {"role": "user", "content": "Where is the sun?"},
...             {"role": "assistant", "content": "In the sky."},
...         ],
...     ],
...     "rejected": [
...         [
...             {"role": "user", "content": "What color is the sky?"},
...             {"role": "assistant", "content": "It is green."},
...         ],
...         [
...             {"role": "user", "content": "Where is the sun?"},
...             {"role": "assistant", "content": "In the sea."},
...         ],
...     ],
... }
>>> dataset = Dataset.from_dict(dataset_dict)
>>> dataset = dataset.map(extract_prompt)
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}

unpair_preference_dataset

trl.unpair_preference_dataset

< >

( dataset: ~DatasetType num_proc: typing.Optional[int] = None desc: typing.Optional[str] = None ) Dataset

参数

  • dataset (DatasetDatasetDict) — 要解对的偏好数据集。该数据集必须包含列 "chosen""rejected" 以及可选的 "prompt"
  • num_proc (intNone, *可选*, 默认为 None) — 用于处理数据集的进程数。
  • desc (strNone, *可选*, 默认为 None) — 在映射示例时,与进度条一起显示的有意义的描述。

返回值

Dataset

解对的偏好数据集。

解对偏好数据集。

示例

>>> from datasets import Dataset
>>> dataset_dict = {
...     "prompt": ["The sky is", "The sun is"]
...     "chosen": [" blue.", "in the sky."],
...     "rejected": [" green.", " in the sea."]
... }
>>> dataset = Dataset.from_dict(dataset_dict)
>>> dataset = unpair_preference_dataset(dataset)
>>> dataset
Dataset({
    features: ['prompt', 'completion', 'label'],
    num_rows: 4
})
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.', 'label': True}

maybe_unpair_preference_dataset

trl.maybe_unpair_preference_dataset

< >

( dataset: ~DatasetType num_proc: typing.Optional[int] = None desc: typing.Optional[str] = None ) Dataset or DatasetDict

参数

  • dataset (DatasetDatasetDict) — 要解对的偏好数据集。该数据集必须包含列 "chosen""rejected" 以及可选的 "prompt"
  • num_proc (intNone, *可选*, 默认为 None) — 用于处理数据集的进程数。
  • desc (strNone, *可选*, 默认为 None) — 在映射示例时,与进度条一起显示的有意义的描述。

返回值

Dataset 或 DatasetDict

如果偏好数据集已配对,则返回解对后的数据集;否则返回原始数据集。

如果偏好数据集已配对,则解对该数据集。

示例

>>> from datasets import Dataset
>>> dataset_dict = {
...     "prompt": ["The sky is", "The sun is"]
...     "chosen": [" blue.", "in the sky."],
...     "rejected": [" green.", " in the sea."]
... }
>>> dataset = Dataset.from_dict(dataset_dict)
>>> dataset = unpair_preference_dataset(dataset)
>>> dataset
Dataset({
    features: ['prompt', 'completion', 'label'],
    num_rows: 4
})
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.', 'label': True}

pack_examples

trl.pack_examples

< >

( examples: dict seq_length: int ) dict[str, list[list]]

参数

  • examples (dict[str, list[list]]) — 示例字典,键为字符串,值为列表的列表。
  • seq_length (int) — 最大序列长度。

返回值

dict[str, list[list]]

示例字典,键为字符串,值为列表的列表。

将示例打包成大小为 seq_length 的块。

示例

>>> from trl import pack_examples
>>> examples = {
...     "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
...     "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
... }
>>> pack_examples(examples, seq_length=5)
{'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8]], 'attention_mask': [[0, 1, 1, 0, 0], [1, 1, 1]]}
>>> pack_examples(examples, seq_length=2)
{'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]}

pack_dataset

trl.pack_dataset

< >

( dataset: ~DatasetType seq_length: int map_kwargs: typing.Optional[dict[str, typing.Any]] = None ) Dataset or DatasetDict

参数

  • dataset (DatasetDatasetDict) — 要打包的数据集
  • seq_length (int) — 目标打包序列长度。
  • map_kwargs (dictNone, *可选*, 默认为 None) — 在打包示例时,传递给数据集 map 方法的附加关键字参数。

返回值

Dataset 或 DatasetDict

包含打包序列的数据集。当序列被组合时,示例的数量可能会减少。

将数据集中的序列打包成大小为 seq_length 的块。

示例

>>> from datasets import Dataset
>>> examples = {
...     "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
...     "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
... }
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4)
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
 'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}

truncate_dataset

trl.truncate_dataset

< >

( dataset: ~DatasetType max_length: int map_kwargs: typing.Optional[dict[str, typing.Any]] = None ) Dataset or DatasetDict

参数

  • dataset (DatasetDatasetDict) — 要截断的数据集。
  • seq_length (int) — 要截断到的最大序列长度。
  • map_kwargs (dictNone, 可选, 默认为 None) — 传递给数据集的 map 方法以截断示例的其他关键字参数。

返回值

Dataset 或 DatasetDict

包含截断序列的数据集。

将数据集中的序列截断为指定的 max_length

示例

>>> from datasets import Dataset
>>> examples = {
...     "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
...     "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
... }
>>> dataset = Dataset.from_dict(examples)
>>> truncated_dataset = truncate_dataset(dataset, max_length=2)
>>> truncated_dataset[:]
{'input_ids': [[1, 2], [4, 5], [8]],
 'attention_mask': [[0, 1], [0, 0], [1]]}
< > 在 GitHub 上更新