TRL 文档
文本环境
并获得增强的文档体验
开始使用
文本环境
文本环境为语言智能体提供了一个学习场所。它允许语言模型使用工具来完成任务,例如使用 Python 解释器来回答数学问题,或使用搜索索引来回答琐事问题。访问工具使语言模型能够解决对于模型本身来说非常困难的任务,但对于合适的工具来说可能微不足道。一个很好的例子是大数的算术,一旦您可以使用计算器,它就变成了一个简单的复制粘贴任务。

让我们深入了解文本环境的工作原理,并从工具开始!
工具
文本环境的核心构建模块之一是模型可以用来解决任务的工具。通常,工具可以是任何以字符串作为输入并返回字符串的 Python 函数。TextEnvironment
为工具提供了两个选项:要么使用 transformers.Tool
中预定义的工具,要么使用具有 __call__
方法的您自己的函数或类。让我们看看这两种方法!
transformers.Tool
文本环境完全支持 transformers.Tool
类的工具。在该框架中构建工具的优势在于它们可以轻松共享
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
这些工具可以从 Hub 或本地文件夹加载。使用工具就像使用文本查询调用它们一样简单
calc_tool("1/2")
>>> "0.5"
请注意,输入和返回值都是字符串,以便于与语言模型一起使用。
自定义工具
以下是一个将两个整数相加的工具示例
def add(text):
int_1, int_2 = text.split("+")
result = int(int_1) + int(int_2)
return str(result)
print(add("1+1"))
>>> "2"
我们查看了诸如计算器之类的基本示例,但该原理也适用于更复杂的工具,例如 Web 搜索工具,您可以在其中输入查询并获得搜索结果作为回报。现在让我们看看模型如何使用调用语法来使用工具。
调用语法
为了让模型能够以统一的方式调用工具,我们创建了一个简单的语法,如下所示
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
涉及一些特殊的 token,让我们分解一下:首先,模型可以通过发出 <request>
token 来表示它想要使用工具。之后,我们想知道要调用的工具的名称,这可以通过用 <>
括号括起工具名称来完成。一旦我们知道要调用哪个工具,就会紧跟工具查询,该查询以自由文本形式存在。 <call>
token 表示查询的结束并停止模型生成。此时,将解析模型输出,并将查询发送到工具。环境将工具响应附加到字符串,后跟 <response>
token 以显示工具输出的结束。
让我们看一下计算器的具体示例,并假设它的名称是 Calculator
(稍后会详细介绍如何推断工具的名称)
"<request><Calculator>1/2<call>0.5<response>"
最后,当模型生成 <submit>
时,episode 结束并停止生成,这标志着交互已完成。
现在让我们看一下如何创建一个新的文本环境!
创建 TextEnvironment
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
result_parsed = result.split("=")[1].split("<")[0]
return int(result_parsed==answer)
text_env = TextEnvironemnt(
model=model,
tokenizer=tokenizer,
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
reward_fn=exact_match_reward,
prompt=prompt,
max_turns=1
max_tool_response=100
generation_kwargs={"do_sample": "true"}
)
让我们分解一下设置
参数 | 描述 |
---|---|
model | 与环境交互并生成请求的语言模型。 |
tokenizer | 语言模型的 tokenizer,用于处理字符串的 token 化。 |
tools | 工具的 list 或 dict 。 如果是前者,则工具的名称从类名推断,否则它是字典的键。 |
reward_fn | 一个以字符串作为输入并返回的函数。可以有传递给 .run() 的额外参数,例如 ground truth。 |
prompt | 预先添加到每个任务的提示。通常是一些示例,用于演示模型如何在 few-shot 方式中使用工具。 |
max_turns | 模型和工具之间交互的最大次数,超过此次数 episode 结束。 |
max_tool_response | 工具响应将被截断为此数字,以避免模型上下文耗尽。 |
max_length | 一个 episode 中允许的最大 token 数。 |
generation_kwargs | 语言模型使用的生成设置。 |
您可以根据您的需求自定义环境,并添加自定义工具和设置。让我们看看如何使用环境让模型与可用工具进行交互!
运行 Episode
要通过文本环境运行一组查询,只需使用 run
方法。
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
这将为每个查询执行模型/工具反馈循环,直到不再调用任何工具、达到最大轮数或 episode 中的最大 token 数超出为止。传递给 run
的额外 kwargs
(例如上面的 answers=answers
)将传递给奖励函数。
run
返回五个对象
queries
:token 化的查询列表responses
:在环境中生成的所有 token,包括模型和工具 tokenmasks
:指示哪些 token 由模型生成,哪些 token 由工具生成的掩码rewards
:每个查询/响应的奖励列表histories
:TextHistory
对象列表,这些对象是有用的对象,包含上述所有内容以及文本等效项
掩码对于训练至关重要,因为我们不想优化模型未生成的 token,即工具生成的 token。
接下来,我们将使用生成的响应训练 PPO 步骤!
训练
在 TextEnvironment
的 episode 上进行训练非常简单,只需将除 TextHistory
对象之外的所有返回变量转发到 step
方法即可
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
TextHistory
TextHistory
对象存储模型和文本环境之间的交互。它存储每个 turn 中生成的 token 和文本及其在每个 turn 中的来源(模型或系统)以及奖励。让我们浏览一下类属性和方法。
属性
下表总结了 TextEnvironment
类的可用属性
属性 | 描述 |
---|---|
text | 在文本环境中生成的文本的完整字符串,包含模型和系统生成的文本。 |
text_spans | 包含每个模型或系统生成的文本段的跨度的元组列表。 |
system_spans | 指示该段是模型生成还是系统生成的布尔值列表。 |
tokens | 在文本环境中生成的所有 token,包含模型和系统生成的 token。 |
token_spans | 与 text_spans 类似,token_spans 指示模型和系统生成的 token 的边界。 |
token_masks | token 掩码可用于通过掩盖系统生成的 token 来忽略它们。 |
completed | 指示与环境的交互是否已完成。 |
truncated | 指示与环境的交互是否由于达到最大长度而完成。 |
使用这些属性,您可以重建模型与 TextEnvironment
的每次交互。 TextHistory
还允许您可视化文本历史记录。让我们看一下!
可视化
当模型在 TextEnvironment
内部交互时,可视化和区分文本输出的哪些部分由模型生成,哪些部分来自系统和工具可能很有用。为此,有两个方法 TextHistory.show_text() 和 TextHistory.show_tokens()。它们分别打印文本和 token,并使用 rich
库 高亮显示各个段(请确保在使用这些方法之前安装它)。
您可以看到提示以灰色突出显示,而系统段(如查询和工具响应)以绿色突出显示。模型生成的所有段均以蓝色突出显示,除了纯文本输出外,奖励还以李子色文本显示。这是 show_text
的示例

有时,当显示解码后的文本时,可能会隐藏与 token 化相关的棘手问题。因此,TextHistory
还提供了一个选项,可以使用 show_tokens
直接在 token 上显示相同的高亮显示

请注意,您可以通过传递 show_legend=True
来打开颜色图例。
API 文档
class trl.TextEnvironment
< source >( model = None tokenizer = None tools = None reward_fn = None prompt = None max_turns = 4 max_tool_response = 100 max_length = None generation_kwargs = None )
TextEnvironment 允许 LLM 使用工具与环境进行交互。
计算一系列历史记录的奖励。
为一系列历史记录生成回复。
解析请求字符串。 预期格式:<request><tool_name>query<call>
在一系列查询上运行环境。
将环境向前推进一步。
检查当前生成序列是否已完成。
检查当前生成序列是否已完成。
TextHistory
类跟踪语言模型和环境之间交互的历史记录。
append_segment
< source >( text tokens system = True )
向历史记录中追加新片段。
将历史记录标记为已完成。
打印颜色图例。
打印文本历史记录。
打印历史记录 tokens。
将 tokens 分割成查询和回复 tokens。