TRL文档

文本环境

Hugging Face's logo
加入Hugging Face社区

并获取增强的文档体验

开始使用

文本环境

文本环境为语言代理提供了一个学习平台。它允许语言模型使用工具来完成任务,例如使用Python解释器来回答数学问题或使用搜索索引来回答趣味问题。访问工具使得语言模型能够解决对于模型自身来说可能非常困难但对适当的工具来说却可能是简单的事务。一个好的例子是在你有了计算器后,大数算术变成了简单的复制粘贴任务。

让我们深入了解文本环境的工作原理,并从工具开始。

工具

文本环境的基石之一是模型可以用来解决任务的工具。一般来说,工具可以是任何以字符串为输入并返回字符串的 Python 函数。《TextEnvironment》提供了两种工具选项:要么使用来自 transformers.Tool 的预设工具,要么定义一个具有 __call__ 方法的自定义函数或类。让我们看看两种方法!

transformers.Tool

文本环境完全支持 transformers.Tool 类的工具。在这个框架中构建工具的优势在于它们可以很容易地共享。

from transformers import load_tool

# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")

# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")

# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")

这些工具可以来自 hub 或本地文件夹。使用工具就像用文本查询调用它们一样简单。

calc_tool("1/2")
>>> "0.5"

请注意,输入和返回值都是字符串,以便容易与语言模型一起使用。

自定义工具

以下是一个添加两个整数的工具示例。

def add(text):
    int_1, int_2 = text.split("+")
    result = int(int_1) + int(int_2)
    return str(result)

print(add("1+1"))
>>> "2"

我们查看了一些基本示例,如计算器,但这个原则也适用于更复杂的工具,如网页搜索工具,让您输入查询并返回搜索结果。现在让我们看看模型如何使用调用语法来使用工具。

调用语法

为了使模型调用工具的方式标准化,我们创建了一个简单的语法,如下所示

"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"

涉及一些特殊令牌,因此让我们来分解它:首先,模型可以通过发出<request>令牌来表明它想要使用一个工具。之后,我们想要知道要调用哪个工具的名称,这可以通过将工具名称括在<>括号中来完成。一旦我们知道了要调用的工具,随后是工具查询,这是自由文本形式的。<call>令牌表示查询的结束并停止模型生成。在此阶段,模型的输出被解析并将查询发送给工具。环境将工具的响应附加到字符串后面,然后跟随着<response>令牌以显示工具输出的结束。

让我们看看计算器的具体示例,假设它的名称是Calculator(稍后我们将讨论如何推断工具的名称)

"<request><Calculator>1/2<call>0.5<response>"

最后,当模型生成<submit>时,会结束这一幕并停止生成,这个令牌标记了交互为完成。

现在让我们看看我们如何创建一个新的文本环境!

创建文本环境

prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""

def reward_fn(result, answer):
    """Simplified reward function returning 1 if result matches answer and 0 otherwise."""
    result_parsed = result.split("=")[1].split("<")[0]
    return int(result_parsed==answer)

text_env = TextEnvironemnt(
    model=model, 
    tokenizer=tokenizer,
    tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=exact_match_reward,
    prompt=prompt, 
    max_turns=1
    max_tool_response=100
    generation_kwargs={"do_sample": "true"}
)

让我们来分解一下设置

参数 描述
model 与环境和生成请求进行交互的语言模型。
tokenizer 处理字符串分词的语言模型分词器。
tools 工具的listdict。如果前面不是,则工具的名称从类名推断出来,否则是从字典的键推断出来。
reward_fn 接受字符串作为输入并返回的函数。可以具有传递给.run()的额外参数,如基准真理。
prompt 每个任务前面都附加的提示。通常是一系列的例子,以演示模型如何通过少量示例使用工具。
max_turns 在剧情结束后,模型和工具之间交互的最大次数。
max_tool_response 工具响应被截断到这个数字以避免超出模型上下文。
max_length 一个剧情中允许的最大令牌数。
generation_kwargs 语言模型使用的生成设置。

您可以根据需要自定义环境并添加自定义工具和设置。让我们看看您如何使用环境让模型与可用的工具进行交互!

运行一集

要在一个文本环境中运行查询集,可以简单地使用 run 方法。

queries = ["What is 1/2?"]
answers = ["0.5"]

queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)

这将对每个查询执行模型/工具反馈循环,直到不再调用任何工具、达到最大回合数或角色片断中的最大令牌数超过。传递给 run 的额外 kwargs(例如上述 answers=answers)将进一步传递给奖励函数。

run 会返回五个对象:

  • queries:分词后的查询列表
  • responses:在环境中生成的所有令牌,包括模型和工具令牌
  • masks:标识哪些令牌由模型生成,哪些令牌由工具生成的掩码
  • rewards:每个查询/响应的奖励列表
  • histories:包含 TextHistory 对象的列表,这些对象包含以上所有内容,还有文本等价物

掩码对于训练至关重要,因为我们不希望优化模型未生成的令牌——即由工具产生的令牌。

接下来,我们将使用生成的响应进行 PPO 步骤的训练!

训练

TextEnvironment 的角色上训练很直接,只需将返回的所有变量(除 TextHistory 对象外)正向传递给 step 方法。

train_stats = ppo_trainer.step(queries, responses, rewards, masks)

文本历史记录

TextHistory 对象存储了模型和文本环境之间的交互。它存储每个回合中生成的令牌和文本及其来源(模型或系统)以及奖励。让我们来看看类的属性和方法。

属性

以下表格总结了TextEnvironment类的可用属性

属性 描述
text text环境中生成的文本完整字符串,包括模型和系统生成的文本。
text_spans 一个由每个模型或系统生成的文本片段的跨度组成的元组列表。
system_spans 一个表示该片段是否为模型或系统生成的布尔值列表。
tokens 在文本环境中生成的所有标记,包括模型和系统生成的标记。
token_spans text_spans类似,token_spans表示模型和系统生成的标记的边界。
token_masks 可以使用标记掩码忽略系统生成的标记。
completed 表示与环境的交互是否已完成。
truncated 表示交互是否已完成,因为达到了最大长度。

通过这些属性,您可以重建模型与TextEnvironment的每次交互。此外,TextHistory还允许您可视化文本历史。让我们看看!

可视化

当模型在TextEnvironment内部交互时,可视化并区分文本输出的哪些部分是由模型生成的以及哪些部分来自系统和技术是有用的。为此,有两个方法:TextHistory.show_text()TextHistory.show_tokens()。它们分别打印文本和标记,并使用rich突出显示不同的片段(在使用这些方法之前,请确保已安装它)。

您可以看到,提示以灰色突出显示,而系统生成的片段,如查询和工具响应,以绿色突出显示。模型生成的所有片段以蓝色突出显示,除纯文本输出外,奖励还以洋红色文本形式显示。以下为show_text的示例。

有时可能会有一些复杂的分词相关问题,在显示解码后的文本时这些问题是隐藏的。因此,TextHistory还提供选项使用show_tokens直接在标记上显示相同的突出显示。

请注意,可以通过传递show_legend=True来打开颜色图例。

trl.TextEnvironment

TextEnvironment 允许使用工具与环境进行聊天。

compute_reward

< >

( histories **reward_kwargs )

计算一系列历史事件的奖励。

生成

< >

( histories )

为一系列历史事件生成响应。

解析工具调用

< >

( text )

解析请求字符串。期望格式:<request><tool_name>query<call>

执行

< >

( queries **rewards_kwargs )

参数

  • queries (list[str]) — 在环境中运行模型的查询列表。

在查询列表上运行环境。

步骤

< >

( history )

参数

  • 历史 (TextHistory) - 向前跳跃的历史记录。

向前推进环境一步。

task_end_check

< >

( history model_turn = True )

检查当前生成序列是否结束。

tasks_end_check

< >

( histories model_turn = True )

检查当前生成序列是否结束。

class trl.TextHistory

< >

( text tokens system = True )

TextHistory类跟踪文本模型与环境交互的历史。

append_segment

< >

( text tokens system = True )

向历史中添加一个新的段落。

参数: text (str): 新段落的文本。tokens (torch.LongTensor): 新段落的标记。system (bool, 可选): 新段落是系统段落还是用户段落。

complete

< >

( truncated = False )

将历史标记为完成。

显示颜色图例

< >

( )

打印颜色图例。

show_text

< >

( show_legend = False )

打印文本历史。

show_tokens

< >

( tokenizer show_legend = False )

打印历史令牌。

split_query_response_tokens

< >

( )

将令牌分为查询和响应令牌。

< > 在GitHub上更新