TRL 文档
学习工具 (实验性 🧪)
并获得增强的文档体验
开始入门
学习工具 (实验性 🧪)
最近,将大型语言模型 (LLM) 与工具结合使用已成为一个热门话题,涌现出诸如 ToolFormer 和 ToolBench 等出色作品。在 TRL 中,我们提供了一个简单的示例,说明如何通过强化学习教 LLM 使用工具。
以下是 trl 仓库中脚本的概述
文件 | 描述 |
---|---|
calculator.py | 使用强化学习训练 LLM 使用计算器的脚本。 |
triviaqa.py | 训练 LLM 使用维基工具回答问题的脚本。 |
python_interpreter.py | 训练 LLM 使用 python 解释器解决数学难题的脚本。 |
请注意,以上脚本严重依赖 TextEnvironment
API,该 API 仍在积极开发中。该 API 将来可能会发生变化。有关相关文档,请参阅 TextEnvironment
。
学习使用计算器
大致思路如下
加载一个工具,例如 ybelkada/simple-calculator,它可以解析诸如
"14 + 34"
之类的文本计算并返回计算出的数字from transformers import AutoTokenizer, load_tool tool = load_tool("ybelkada/simple-calculator") tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
定义一个奖励函数,如果工具返回正确的答案,则返回正奖励。在脚本中,我们创建了一个虚拟奖励函数,例如
reward_fn = lambda x: 1
,但我们稍后直接覆盖奖励。创建关于如何使用工具的提示
# system prompt prompt = """\ What is 13.1-3? <request><SimpleCalculatorTool>13.1-3<call>10.1<response> Result=10.1<submit> What is 4*3? <request><SimpleCalculatorTool>4*3<call>12<response> Result=12<submit> What is 12.1+1? <request><SimpleCalculatorTool>12.1+1<call>13.1<response> Result=13.1<submit> What is 12.1-20? <request><SimpleCalculatorTool>12.1-20<call>-7.9<response> Result=-7.9<submit>"""
使用模型创建
trl.TextEnvironment
env = TextEnvironment( model, tokenizer, {"SimpleCalculatorTool": tool_fn}, reward_fn, prompt, generation_kwargs=generation_kwargs, )
然后生成一些数据,例如
tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]
并使用queries, responses, masks, rewards, histories = env.run(tasks)
运行环境。环境将在提示中查找<call>
标记,并将工具输出附加到响应;它还将返回与响应关联的掩码。您可以进一步使用histories
可视化模型和工具之间的交互;histories[0].show_text()
将显示带有颜色编码工具输出的文本,而histories[0].show_tokens(tokenizer)
将显示可视化标记。最后,我们可以使用
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
训练模型。训练器将使用掩码在计算损失时忽略工具输出,请确保将该参数传递给step
。
实验结果
我们使用上述脚本训练了一个模型,使用了 10 个随机种子。您可以使用以下命令重现运行。如果您无权访问 slurm 集群,请随意删除 --slurm-*
参数。
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
然后我们可以使用 openrlbenchmark
,它生成以下图表。
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
正如我们所见,虽然 1-2 个实验由于某种原因崩溃了,但大多数运行在计算器任务中获得了接近完美的熟练度。
(早期实验 🧪):学习使用维基工具进行问题解答
在 ToolFormer 论文中,它展示了一个有趣的用例,即利用维基百科搜索工具来帮助回答问题。在本节中,我们尝试进行类似的实验,但使用 RL 来教模型在 TriviaQA 数据集上使用维基工具。
请注意,许多设置都不同,因此结果不具有直接可比性。
构建搜索索引
由于 ToolFormer 没有开源,我们需要首先复制搜索索引。他们的论文中提到,作者使用 BM25 检索器构建了搜索索引,该检索器索引了来自 KILT 的维基百科转储。
幸运的是,pyserini
已经实现了 BM25 检索器,并为 KILT 维基百科转储提供了预构建索引。我们可以使用以下代码搜索索引。
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
hits = searcher.search(query, k=1)
hit = hits[0]
contents = json.loads(hit.raw)['contents']
return contents
print(search("tennis racket"))
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
然后,我们基本上将此代码片段部署为 Hugging Face Space 此处,以便稍后我们可以将该 Space 用作 transformers.Tool
。
实验设置
我们使用以下设置
- 使用
bigcode/starcoderbase
模型作为基础模型 - 使用
pyserini-wikipedia-kilt-doc
space 作为维基工具,并且仅使用搜索结果的第一段,允许TextEnvironment
从工具获得最多max_tool_reponse=400
个响应标记。 - 测试响应是否包含答案字符串,如果包含,则给予 1 的奖励,否则给予 0 的奖励。
- 请注意,这是一个简化的评估标准。在 ToolFormer 中,作者检查响应的前 20 个单词是否包含正确答案。
- 使用了以下提示,演示了维基工具的用法。
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
结果和讨论
我们的实验表明,智能体可以学习使用维基工具来回答问题。学习曲线大多会上涨,但其中一个实验确实崩溃了。
Wandb 报告位于 此处,以供进一步检查。
请注意,训练模型的正确率偏低,这可能是由于以下原因:
- 不正确的搜索: 当给出问题
"What is Bruce Willis' real first name?"
时,如果模型搜索Bruce Willis
,我们的维基工具返回“Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.但正确的搜索应该是
Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]”
不必要的长响应:维基工具默认情况下有时会输出非常长的序列。例如,当维基工具搜索 “Brown Act” 时
我们的维基工具返回 “The Ralph M. Brown Act, located at California Government Code 54950 “et seq.”, is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public’s right to attend and participate in meetings of local legislative bodies.”
ToolFormer 的维基工具返回 “The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public’s right to attend and participate in meetings of local legislative bodies.”,这更加简洁。
(早期实验 🧪):使用 python 解释器解决数学难题
在本节中,我们尝试教模型使用 python 解释器来解决数学难题。大致思路是给智能体一个如下所示的提示
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>8<response>
Result = 8 <submit>
Q: """
训练实验可以在 https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y 中找到