TRL 文档

文本环境

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

文本环境

文本环境为语言智能体提供了一个学习场所。它允许语言模型使用工具来完成任务,例如使用 Python 解释器来回答数学问题,或使用搜索索引来回答琐事问题。访问工具使语言模型能够解决对于模型本身来说非常困难的任务,但对于合适的工具来说可能微不足道。一个很好的例子是大数的算术,一旦您可以使用计算器,它就变成了一个简单的复制粘贴任务。

让我们深入了解文本环境的工作原理,并从工具开始!

工具

文本环境的核心构建模块之一是模型可以用来解决任务的工具。通常,工具可以是任何以字符串作为输入并返回字符串的 Python 函数。TextEnvironment 为工具提供了两个选项:要么使用 transformers.Tool 中预定义的工具,要么使用具有 __call__ 方法的您自己的函数或类。让我们看看这两种方法!

transformers.Tool

文本环境完全支持 transformers.Tool 类的工具。在该框架中构建工具的优势在于它们可以轻松共享

from transformers import load_tool

# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")

# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")

# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")

这些工具可以从 Hub 或本地文件夹加载。使用工具就像使用文本查询调用它们一样简单

calc_tool("1/2")
>>> "0.5"

请注意,输入和返回值都是字符串,以便于与语言模型一起使用。

自定义工具

以下是一个将两个整数相加的工具示例

def add(text):
    int_1, int_2 = text.split("+")
    result = int(int_1) + int(int_2)
    return str(result)

print(add("1+1"))
>>> "2"

我们查看了诸如计算器之类的基本示例,但该原理也适用于更复杂的工具,例如 Web 搜索工具,您可以在其中输入查询并获得搜索结果作为回报。现在让我们看看模型如何使用调用语法来使用工具。

调用语法

为了让模型能够以统一的方式调用工具,我们创建了一个简单的语法,如下所示

"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"

涉及一些特殊的 token,让我们分解一下:首先,模型可以通过发出 <request> token 来表示它想要使用工具。之后,我们想知道要调用的工具的名称,这可以通过用 <> 括号括起工具名称来完成。一旦我们知道要调用哪个工具,就会紧跟工具查询,该查询以自由文本形式存在。 <call> token 表示查询的结束并停止模型生成。此时,将解析模型输出,并将查询发送到工具。环境将工具响应附加到字符串,后跟 <response> token 以显示工具输出的结束。

让我们看一下计算器的具体示例,并假设它的名称是 Calculator(稍后会详细介绍如何推断工具的名称)

"<request><Calculator>1/2<call>0.5<response>"

最后,当模型生成 <submit> 时,episode 结束并停止生成,这标志着交互已完成。

现在让我们看一下如何创建一个新的文本环境!

创建 TextEnvironment

prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""

def reward_fn(result, answer):
    """Simplified reward function returning 1 if result matches answer and 0 otherwise."""
    result_parsed = result.split("=")[1].split("<")[0]
    return int(result_parsed==answer)

text_env = TextEnvironemnt(
    model=model, 
    tokenizer=tokenizer,
    tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=exact_match_reward,
    prompt=prompt, 
    max_turns=1
    max_tool_response=100
    generation_kwargs={"do_sample": "true"}
)

让我们分解一下设置

参数 描述
model 与环境交互并生成请求的语言模型。
tokenizer 语言模型的 tokenizer,用于处理字符串的 token 化。
tools 工具的 listdict。 如果是前者,则工具的名称从类名推断,否则它是字典的键。
reward_fn 一个以字符串作为输入并返回的函数。可以有传递给 .run() 的额外参数,例如 ground truth。
prompt 预先添加到每个任务的提示。通常是一些示例,用于演示模型如何在 few-shot 方式中使用工具。
max_turns 模型和工具之间交互的最大次数,超过此次数 episode 结束。
max_tool_response 工具响应将被截断为此数字,以避免模型上下文耗尽。
max_length 一个 episode 中允许的最大 token 数。
generation_kwargs 语言模型使用的生成设置。

您可以根据您的需求自定义环境,并添加自定义工具和设置。让我们看看如何使用环境让模型与可用工具进行交互!

运行 Episode

要通过文本环境运行一组查询,只需使用 run 方法。

queries = ["What is 1/2?"]
answers = ["0.5"]

queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)

这将为每个查询执行模型/工具反馈循环,直到不再调用任何工具、达到最大轮数或 episode 中的最大 token 数超出为止。传递给 run 的额外 kwargs(例如上面的 answers=answers)将传递给奖励函数。

run 返回五个对象

  • queries:token 化的查询列表
  • responses:在环境中生成的所有 token,包括模型和工具 token
  • masks:指示哪些 token 由模型生成,哪些 token 由工具生成的掩码
  • rewards:每个查询/响应的奖励列表
  • historiesTextHistory 对象列表,这些对象是有用的对象,包含上述所有内容以及文本等效项

掩码对于训练至关重要,因为我们不想优化模型未生成的 token,即工具生成的 token。

接下来,我们将使用生成的响应训练 PPO 步骤!

训练

TextEnvironment 的 episode 上进行训练非常简单,只需将除 TextHistory 对象之外的所有返回变量转发到 step 方法即可

train_stats = ppo_trainer.step(queries, responses, rewards, masks)

TextHistory

TextHistory 对象存储模型和文本环境之间的交互。它存储每个 turn 中生成的 token 和文本及其在每个 turn 中的来源(模型或系统)以及奖励。让我们浏览一下类属性和方法。

属性

下表总结了 TextEnvironment 类的可用属性

属性 描述
text 在文本环境中生成的文本的完整字符串,包含模型和系统生成的文本。
text_spans 包含每个模型或系统生成的文本段的跨度的元组列表。
system_spans 指示该段是模型生成还是系统生成的布尔值列表。
tokens 在文本环境中生成的所有 token,包含模型和系统生成的 token。
token_spans text_spans 类似,token_spans 指示模型和系统生成的 token 的边界。
token_masks token 掩码可用于通过掩盖系统生成的 token 来忽略它们。
completed 指示与环境的交互是否已完成。
truncated 指示与环境的交互是否由于达到最大长度而完成。

使用这些属性,您可以重建模型与 TextEnvironment 的每次交互。 TextHistory 还允许您可视化文本历史记录。让我们看一下!

可视化

当模型在 TextEnvironment 内部交互时,可视化和区分文本输出的哪些部分由模型生成,哪些部分来自系统和工具可能很有用。为此,有两个方法 TextHistory.show_text()TextHistory.show_tokens()。它们分别打印文本和 token,并使用 rich 高亮显示各个段(请确保在使用这些方法之前安装它)。

您可以看到提示以灰色突出显示,而系统段(如查询和工具响应)以绿色突出显示。模型生成的所有段均以蓝色突出显示,除了纯文本输出外,奖励还以李子色文本显示。这是 show_text 的示例

有时,当显示解码后的文本时,可能会隐藏与 token 化相关的棘手问题。因此,TextHistory 还提供了一个选项,可以使用 show_tokens 直接在 token 上显示相同的高亮显示

请注意,您可以通过传递 show_legend=True 来打开颜色图例。

API 文档

class trl.TextEnvironment

< >

( model = None tokenizer = None tools = None reward_fn = None prompt = None max_turns = 4 max_tool_response = 100 max_length = None generation_kwargs = None )

TextEnvironment 允许 LLM 使用工具与环境进行交互。

compute_reward

< >

( histories **reward_kwargs )

计算一系列历史记录的奖励。

generate

< >

( histories )

为一系列历史记录生成回复。

parse_tool_call

< >

( text )

解析请求字符串。 预期格式:<request><tool_name>query<call>

run

< >

( queries **rewards_kwargs )

参数

  • queries (list[str]) — 要在环境中运行模型的查询列表。

在一系列查询上运行环境。

step

< >

( history )

参数

  • history (TextHistory) — 要向前推进的历史记录。

将环境向前推进一步。

task_end_check

< >

( history model_turn = True )

检查当前生成序列是否已完成。

tasks_end_check

< >

( histories model_turn = True )

检查当前生成序列是否已完成。

class trl.TextHistory

< >

( text tokens system = True )

TextHistory 类跟踪语言模型和环境之间交互的历史记录。

append_segment

< >

( text tokens system = True )

参数

  • text (str) — 新片段的文本。
  • tokens (torch.LongTensor) — 新片段的 tokens。
  • system (bool, optional) — 新片段是系统片段还是用户片段。

向历史记录中追加新片段。

complete

< >

( truncated = False )

将历史记录标记为已完成。

show_colour_legend

< >

( )

打印颜色图例。

show_text

< >

( show_legend = False )

打印文本历史记录。

show_tokens

< >

( tokenizer show_legend = False )

打印历史记录 tokens。

split_query_response_tokens

< >

( )

将 tokens 分割成查询和回复 tokens。

< > Update on GitHub