学习工具(实验性 🧪)
最近,使用工具的大型语言模型 (LLM) 成为一个热门话题,出现了许多优秀的成果,例如 ToolFormer 和 ToolBench。在 TRL 中,我们提供了一个简单的示例,说明如何使用强化学习来训练 LLM 使用工具。
以下是 trl 代码库 中脚本的概述
文件 | 描述 |
---|---|
calculator.py | 使用强化学习训练 LLM 使用计算器的脚本。 |
triviaqa.py | 训练 LLM 使用维基工具回答问题的脚本。 |
python_interpreter.py | 训练 LLM 使用 Python 解释器解决数学难题的脚本。 |
请注意,以上脚本严重依赖于 TextEnvironment
API,该 API 仍处于积极开发阶段。API 未来可能会发生变化。有关相关文档,请参阅 TextEnvironment
。
学习使用计算器
大致思路如下
加载一个工具,例如 ybelkada/simple-calculator,它可以解析类似
"14 + 34"
的文本计算并返回计算结果。from transformers import AutoTokenizer, load_tool tool = load_tool("ybelkada/simple-calculator") tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
定义一个奖励函数,如果工具返回正确的答案,则返回正奖励。在脚本中,我们创建了一个虚拟奖励函数,例如
reward_fn = lambda x: 1
,但稍后我们会直接覆盖奖励。创建有关如何使用工具的提示
# system prompt prompt = """\ What is 13.1-3? <request><SimpleCalculatorTool>13.1-3<call>10.1<response> Result=10.1<submit> What is 4*3? <request><SimpleCalculatorTool>4*3<call>12<response> Result=12<submit> What is 12.1+1? <request><SimpleCalculatorTool>12.1+1<call>13.1<response> Result=13.1<submit> What is 12.1-20? <request><SimpleCalculatorTool>12.1-20<call>-7.9<response> Result=-7.9<submit>"""
使用模型创建
trl.TextEnvironment
env = TextEnvironment( model, tokenizer, {"SimpleCalculatorTool": tool_fn}, reward_fn, prompt, generation_kwargs=generation_kwargs, )
然后生成一些数据,例如
tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]
,并使用queries, responses, masks, rewards, histories = env.run(tasks)
运行环境。环境将在提示中查找<call>
标记,并将工具输出附加到响应中;它还将返回与响应关联的掩码。您可以进一步使用histories
来可视化模型和工具之间的交互;histories[0].show_text()
将显示带颜色编码的工具输出的文本,而histories[0].show_tokens(tokenizer)
将显示可视化标记。最后,我们可以使用
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
训练模型。训练器将使用掩码在计算损失时忽略工具输出,请确保将该参数传递给step
。
实验结果
我们使用上述脚本训练了一个模型,使用了 10 个随机种子。您可以使用以下命令复现运行结果。如果您没有访问 slurm 集群的权限,可以随意删除--slurm-*
参数。
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
然后,我们可以使用openrlbenchmark
,它会生成以下图表。
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
如我们所见,虽然 1-2 个实验由于某种原因崩溃了,但大多数运行在计算器任务中都获得了接近完美的熟练度。
(早期实验 🧪):学习使用维基工具进行问答
在ToolFormer论文中,它展示了一个有趣的用例,该用例利用维基百科搜索工具来帮助回答问题。在本节中,我们尝试进行类似的实验,但使用强化学习来训练模型在TriviaQA数据集上使用维基工具。
请注意,许多设置都不同,因此结果不可直接比较。
构建搜索索引
由于ToolFormer没有开源,因此我们需要首先复制搜索索引。在他们的论文中提到,作者使用了一个基于 BM25 的检索器构建了搜索索引,该检索器对来自KILT的维基百科转储进行索引。
幸运的是,pyserini
已经实现了 BM25 检索器,并为 KILT 维基百科转储提供了一个预构建的索引。我们可以使用以下代码搜索索引。
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
hits = searcher.search(query, k=1)
hit = hits[0]
contents = json.loads(hit.raw)['contents']
return contents
print(search("tennis racket"))
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
然后,我们基本上将此代码段部署为一个 Hugging Face 空间这里,以便我们稍后将其用作transformers.Tool
。
实验设置
我们使用以下设置
- 使用
bigcode/starcoderbase
模型作为基础模型。 - 使用
pyserini-wikipedia-kilt-doc
空间作为维基工具,并且仅使用搜索结果的第一段,允许TextEnvironment
从工具获取最多max_tool_reponse=400
个响应标记。 - 测试响应是否包含答案字符串,如果是,则奖励 1,否则奖励 0。
- 请注意,这是一个简化的评估标准。在ToolFormer中,作者检查响应的前 20 个词是否包含正确答案。
- 使用了以下演示了维基工具用法的提示。
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
结果与讨论
我们的实验表明,智能体可以学习使用维基工具来回答问题。学习曲线大多会上涨,但其中一个实验确实崩溃了。
您可以点击此处查看 Wandb 报告以进行进一步检查。
请注意,训练模型的正确率处于较低水平,这可能是由于以下原因造成的
- 搜索错误:当给出问题“布鲁斯·威利斯的真名是什么?”时,如果模型搜索“布鲁斯·威利斯”,我们的维基工具会返回“帕特里克·波瓦耶(1948 年 2 月 18 日出生)是一位法国演员。他尤其以他的声音而闻名:自 1988 年以来,他一直是布鲁斯·威利斯的法语配音演员。”但正确的搜索应该是“沃尔特·布鲁斯·威利斯(1955 年 3 月 19 日出生)是一位美国前演员。他凭借在喜剧剧情系列《月光光》中(1985-1989 年)的主角而声名鹊起,并出演了 100 多部电影,在《虎胆龙威》系列电影(1988-2013 年)中饰演约翰·麦克莱恩以及其他角色后,获得了动作英雄的认可。[1][2]”
响应过长:维基工具默认情况下有时会输出非常长的序列。例如,当维基工具搜索“布朗法案”时。
我们的维基工具会返回“拉尔夫·M·布朗法案位于加利福尼亚州政府法典第 54950 条“等”中,是加利福尼亚州立法机构的一项法案,由议员拉尔夫·M·布朗起草并于 1953 年通过,该法案保证公众有权参加和参与地方立法机构的会议。”
ToolFormer的维基工具返回“拉尔夫·M·布朗法案是加利福尼亚州立法机构的一项法案,该法案保证公众有权参加和参与地方立法机构的会议。”更简洁。
(早期实验 🧪):使用 Python 解释器解决数学难题
在本节中,我们尝试训练模型使用 Python 解释器来解决数学难题。大致的想法是向智能体提供如下提示
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
训练实验可以在这里找到:https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
< > GitHub 更新