Improving Performance in Post-Training via Arena Learning
Problem
The effectiveness of chat LLMs depends largely on the high-quality instruction-following data used during post-training, which enables them to communicate effectively with humans. The challenge, however, lies in curating this high-quality instruction data and evaluating it efficiently.
Existing approaches rely on platforms such as the LMSYS Chatbot Arena for evaluation, where different chatbot models compete against each other in conversational challenges judged by human evaluators. While this approach has proven to deliver robust and comprehensive evaluations, it is resource- and time-intensive, and its unavoidable dependence on humans limits how far model improvement can scale.

Figure 1. LMSYS model comparison
Moreover, due to priority constraints, most models never get to participate in arena evaluations at all. This ultimately motivates the need for a more efficient and scalable arena-based pipeline to support LLM post-training and evaluation.
Solution
Arena Learning is a novel technique that reduces the manual effort and time costs associated with post-training by introducing an automated training and evaluation pipeline. Its main advantages:
- Iterative training.
- Automated evaluation with no humans in the loop. A "judge model" automatically mimics human annotators: it judges a pair of responses from two models and provides rankings, scores, and explanations accordingly.
In this blog, we attempt to replicate the model's training-data curation process and share the relevant scripts. The complete code is available in a Colab notebook.
How It Is Done
(Source: arena-learning-build-data-flywheel-for-llms-post-training-via-simulated-chatbot-arena)
In the post-training scenario shown in the figure above, Arena Learning simulates battles between the target model (here, WizardLM-β) and various state-of-the-art models on a large body of instruction data. These synthetic battle results are then used to enhance the target model WizardLM-β through several training strategies.
WizardLM-β is continuously updated and re-evaluated against the SOTA models.
Implementation Details
Judge LLM
To evaluate the quality of each LLM's responses with a judge LLM, we prompt-engineer the Llama3-70B-Chat model, feeding it the conversation history, the user instruction, and the responses of the two LLMs. The output is a score for each LLM together with an explanation, focusing on factors such as coherence, factual accuracy, context awareness, and overall quality to determine which response is better. To further counter potential position bias, a two-game setup is used in which the positions of the two LLMs are swapped.
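The paper does not publish its exact judging prompt, so the snippet below is only a minimal sketch of this setup. It assumes the same dataformer `AsyncLLM` client used later in this post is pointed at a Llama3-70B-Chat endpoint; the prompt wording and the `judge_pair` helper are our own illustrations, and the key idea is the two-game ordering swap.

```python
from dataformer.llms.asyncllm import AsyncLLM

# Hypothetical judge prompt; the paper's exact wording is not public.
JUDGE_TEMPLATE = """You are an impartial judge. Given the conversation history, the user instruction,
and two candidate responses, score each response from 1 to 10 for coherence, factual accuracy,
context awareness, and overall quality, then explain which response is better.

Conversation history:
{history}

User instruction:
{instruction}

Response A:
{response_a}

Response B:
{response_b}

Return JSON: {{"score_a": <1-10>, "score_b": <1-10>, "explanation": "..."}}"""

def judge_pair(llm: AsyncLLM, history: str, instruction: str, resp_model_1: str, resp_model_2: str):
    # Two-game setup: judge both orderings so neither model always sits in position A,
    # then average the two games to offset position bias.
    requests = []
    for resp_a, resp_b in [(resp_model_1, resp_model_2), (resp_model_2, resp_model_1)]:
        prompt = JUDGE_TEMPLATE.format(history=history, instruction=instruction,
                                       response_a=resp_a, response_b=resp_b)
        requests.append({"messages": [{"role": "user", "content": prompt}]})
    # llm is assumed to be an AsyncLLM pointed at a Llama3-70B-Chat endpoint,
    # constructed the same way as in the dataformer examples later in this post.
    return llm.generate(requests)
```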
Collecting a Large-Scale Instruction Training Dataset for Arena Learning
Training WizardLM-β requires a large corpus of conversational data (D). The initial model is first trained on a random sample of 10K ShareGPT records. Instructions are collected from the following publicly available datasets (a short loading sketch follows the list):
- WizardLM
- Stanford Alpaca
- Stack Exchange Preferences
- LMSYS Chat
- Flan dataset
- Open Orca
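As a point of reference, here is a minimal sketch of pulling two of these sources from the Hugging Face Hub with the `datasets` library and normalizing them into the role/content conversation format used by the scripts below. The Hub IDs (`tatsu-lab/alpaca`, `Open-Orca/OpenOrca`) are commonly used mirrors and are our assumption here, not necessarily the authors' exact sources.

```python
from datasets import load_dataset

# Assumed Hub mirrors for two of the sources listed above.
alpaca = load_dataset("tatsu-lab/alpaca", split="train")        # Stanford Alpaca
open_orca = load_dataset("Open-Orca/OpenOrca", split="train")   # Open Orca

def alpaca_to_conversation(row):
    # Fold the optional "input" field into the user turn.
    user_turn = row["instruction"] + (("\n\n" + row["input"]) if row["input"] else "")
    return [{"role": "user", "content": user_turn},
            {"role": "assistant", "content": row["output"]}]

def orca_to_conversation(row):
    return [{"role": "user", "content": row["question"]},
            {"role": "assistant", "content": row["response"]}]

# Role/content lists of dicts, the shape expected by the filtering scripts below.
conversations = [alpaca_to_conversation(r) for r in alpaca.select(range(1000))]
conversations += [orca_to_conversation(r) for r in open_orca.select(range(1000))]
print(conversations[0])
```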
The collected instructions are further refined through the following steps:
- Classify the conversations with an LLM (or several LLMs) and filter out all illegal and toxic ones.
Below is sample code implementing this with dataformer. Dataformer fires off multiple asynchronous requests while respecting the rate limits of different API providers and making use of caching.
```bash
# install library
pip install dataformer
```

```python
import json

from dataformer.llms.asyncllm import AsyncLLM

def generate_data(data, api_provider, model_name, api_key,
                  max_requests_per_minute, max_tokens_per_minute, max_concurrent_requests):
    llm = AsyncLLM(api_provider=api_provider, model=model_name, api_key=api_key,
                   max_requests_per_minute=max_requests_per_minute,
                   max_tokens_per_minute=max_tokens_per_minute,
                   max_concurrent_requests=max_concurrent_requests)
    request_list = []
    text = []
    prompt = ("Categorize the text as 'USE' or 'DONT' based on whether it contains any illegal or "
              "toxic language or references. If it does contain toxic or illegal content, label it "
              "as the 'DONT' category, else label it as the 'USE' category. The text is as follows:\n\n")
    # iterate over the data and bring each conversation into the proper format for the LLM
    for passage in data:
        # serialize the conversation (list of {"role", "content"} dicts) as JSON for the prompt
        answer_prompt_json = json.dumps(
            [{"role": j["role"], "content": j["content"]} for j in passage],
            ensure_ascii=False, indent=2,
        )
        # model input length, ignore or change if required
        if len(answer_prompt_json) > 8192:
            continue
        user_text = (prompt + answer_prompt_json + "\n\n"
                     + "Give results only as the categories 'USE' or 'DONT', do not give any other content.")
        data_dict = {
            "messages": [{"role": "user", "content": user_text}],
            # "temperature": 0.7
        }
        request_list.append(data_dict)
        text.append(passage)

    response_list = llm.generate(request_list)

    # collect answers as the USE / DONT category; later keep only the records labelled USE
    answers = []
    for row in response_list:
        try:
            answers.append(row[-1]["choices"][0]["message"]["content"])
        except Exception:
            print(row)
    return text, answers
```
- Remove conversations whose instruction length is less than 10. Here is code that does the same.
```python
token_len_filtered = []
for row in filtered_data:
    new_tp = []
    for convo_id in range(len(row)):
        break_bool = False
        # check only the instructions given by the user
        if row[convo_id]['role'] == 'user':
            if len(row[convo_id]['content']) > 10:
                new_tp.append(row[convo_id])
                # keep the assistant reply that follows the user instruction, if present
                if convo_id + 1 < len(row):
                    new_tp.append(row[convo_id + 1])
            else:
                break_bool = True
                print(row[convo_id])
        if break_bool:
            # skip the conversation entirely if any user instruction is shorter than 10
            new_tp = []
            break
    if len(new_tp) != 0:
        token_len_filtered.append(new_tp)

# The filtered conversations are stored in token_len_filtered; use only these instruction
# records and extract the corresponding conversations further downstream.
print(token_len_filtered[0], len(token_len_filtered))
```
- Eliminate duplicate instructions with a prefix of 10 using the MinHashLSH technique. Here is sample code for this using datatrove.
```python
import os
import gc
import logging

from datatrove.pipeline.dedup import MinhashDedupSignature
from datatrove.pipeline.dedup.minhash import (
    MinhashConfig,
    MinhashDedupBuckets,
    MinhashDedupCluster,
    MinhashDedupFilter,
)
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import TokensCounter
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.executor import LocalPipelineExecutor

# Configure logging
logging.basicConfig(level=logging.INFO)

# Configuration for Minhash
minhash_config = MinhashConfig(use_64bit_hashes=True)  # better precision -> fewer false positives (collisions)

# Paths for local data
LOCAL_DATA_PATH = "/content/sample_data/input_deduplicate"  # put the jsonl file in this folder
LOCAL_MINHASH_BASE_PATH = "/content/sample_data/minhash_deduplicate"
LOCAL_LOGS_FOLDER = "/content/sample_data/log_deduplicate"

# Ensure output directories exist
os.makedirs(LOCAL_MINHASH_BASE_PATH, exist_ok=True)
os.makedirs(LOCAL_LOGS_FOLDER, exist_ok=True)

# Total tasks for local execution
TOTAL_TASKS = 50

# This is the original data that we want to deduplicate
INPUT_READER = JsonlReader(LOCAL_DATA_PATH)

# Stage 1: Compute minhash signatures for each task
stage1 = LocalPipelineExecutor(
    pipeline=[
        INPUT_READER,
        MinhashDedupSignature(output_folder=f"{LOCAL_MINHASH_BASE_PATH}/signatures", config=minhash_config),
    ],
    tasks=TOTAL_TASKS,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/signatures",
)

# Run Stage 1 and collect garbage
try:
    logging.info("Running Stage 1: Minhash Signatures")
    stage1.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 1 failed: {e}")

# Stage 2: Find matches between signatures in each bucket
stage2 = LocalPipelineExecutor(
    pipeline=[
        MinhashDedupBuckets(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/signatures",
            output_folder=f"{LOCAL_MINHASH_BASE_PATH}/buckets",
            config=minhash_config,
        ),
    ],
    tasks=minhash_config.num_buckets,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/buckets",
    depends=stage1,
)

# Run Stage 2 and collect garbage
try:
    logging.info("Running Stage 2: Minhash Buckets")
    stage2.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 2 failed: {e}")

# Stage 3: Create clusters of duplicates using the results from all buckets
stage3 = LocalPipelineExecutor(
    pipeline=[
        MinhashDedupCluster(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/buckets",
            output_folder=f"{LOCAL_MINHASH_BASE_PATH}/remove_ids",
            config=minhash_config,
        ),
    ],
    tasks=1,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/clusters",
    depends=stage2,
)

# Run Stage 3 and collect garbage
try:
    logging.info("Running Stage 3: Minhash Clusters")
    stage3.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 3 failed: {e}")

# Stage 4: Read the original input data and remove all but 1 sample per duplicate cluster
deduplicated_output_folder = f"{LOCAL_MINHASH_BASE_PATH}/deduplicated_output"
os.makedirs(deduplicated_output_folder, exist_ok=True)

stage4 = LocalPipelineExecutor(
    pipeline=[
        INPUT_READER,
        TokensCounter(),  # see how many tokens we had before and after deduplication
        MinhashDedupFilter(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/remove_ids",
            exclusion_writer=JsonlWriter(f"{LOCAL_MINHASH_BASE_PATH}/removed"),
        ),
        JsonlWriter(output_folder=deduplicated_output_folder),
    ],
    tasks=TOTAL_TASKS,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/filter",
    depends=stage3,
)

# Execute the final stage
try:
    logging.info("Running Stage 4: Deduplication and Writing Output")
    stage4.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 4 failed: {e}")

# Verify the output
if not os.listdir(deduplicated_output_folder):
    logging.error("Deduplicated output folder is empty.")
else:
    logging.info("Deduplicated output has been successfully written.")
```
- To prevent test-data leakage, use the gte-large embedding model and exclude the 5 instructions most semantically similar to the following benchmarks:
- WizardArena
- Arena-Hard Auto
- MT Bench
- AlpacaEval
- OpenLLM Leaderboard
Here is sample code using the dataformer library mentioned earlier, this time against an embedding endpoint. Before running it, merge all of the benchmark data listed above into one combined dataset.
```python
import json

import requests
import torch
from dataformer.llms import AsyncLLM

def cosine_similarity(embeddings1, embeddings2):
    return torch.nn.functional.cosine_similarity(embeddings1, embeddings2)

def get_similarity(url, train_data, evaluation_data, model, api_key, max_requests_per_minute):
    filtered_data = []
    similarities = []
    excluded_indices = set()
    try:
        llm = AsyncLLM(base_url=url, model=model, api_key=api_key,
                       max_requests_per_minute=max_requests_per_minute)
        data = []
        eval_data = []
        batchsize_one_req = 200

        # batch the training instructions for the embedding endpoint
        t = []
        for idx, i in enumerate(train_data):
            t.append(str(i['text']))
            if len(t) >= batchsize_one_req or idx == len(train_data) - 1:
                data.append({"input": t})
                t = []

        # batch the benchmark (evaluation) instructions the same way
        t = []
        for idx, i in enumerate(evaluation_data):
            t.append(str(i))
            if len(t) >= batchsize_one_req or idx == len(evaluation_data) - 1:
                eval_data.append({"input": t})
                t = []

        # Compute embeddings for both sets in batches
        embeddings_train = []
        embeddings_test = []

        response_list_train = llm.generate(data)
        for one_res in response_list_train:
            for j in one_res[1]['data']:
                embeddings_train.append(torch.tensor(j['embedding']))
        embeddings_train = torch.stack(embeddings_train)

        response_list_test = llm.generate(eval_data)
        for one_res in response_list_test:
            for j in one_res[1]['data']:
                embeddings_test.append(torch.tensor(j['embedding']))
        embeddings_test = torch.stack(embeddings_test)
        print("all embeddings done")

        # pairwise similarity between every training instruction and every benchmark instruction
        for i in range(len(embeddings_train)):
            for j in range(len(embeddings_test)):
                sim = cosine_similarity(embeddings_train[i].view(1, -1),
                                        embeddings_test[j].view(1, -1)).item()
                similarities.append((sim, i, j))

        # Sort similarities and exclude the training instructions behind the top 5 matches
        similarities.sort(reverse=True, key=lambda x: x[0])
        for _, i, j in similarities:
            if len(excluded_indices) == 5:
                break
            excluded_indices.add(i)
        print("similarities calculated")

        # Filter out the excluded indices
        filtered_data = [item for idx, item in enumerate(train_data) if idx not in excluded_indices]

        # Save the filtered data as JSONL
        with open("/content/sample_data/filtered_data_excluded.jsonl", 'w') as f:
            for item in filtered_data:
                f.write(json.dumps(item) + "\n")
        print("Done, data in /content/sample_data/filtered_data_excluded.jsonl")
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
    return filtered_data, similarities, excluded_indices
```
- Filter by language if needed (this can be done with the langdetect or fasttext-langdetect library); a small sketch follows below.
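For example, a minimal sketch with the `langdetect` package, keeping English-only conversations (the target language is just an illustration):

```python
# pip install langdetect
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

def keep_language(conversations, lang="en"):
    kept = []
    for convo in conversations:
        # detect the language of the user turns only
        user_text = " ".join(m["content"] for m in convo if m["role"] == "user")
        try:
            if detect(user_text) == lang:
                kept.append(convo)
        except LangDetectException:
            # detection fails on empty or very short strings; drop those conversations
            continue
    return kept
```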
After these steps, the refined dataset D of 276K conversations is randomly split into 9 parts.
The simulated arena battle results are then used to generate training data for WizardLM-β under different training strategies: supervised fine-tuning (SFT), direct preference optimization (DPO), and proximal policy optimization (PPO).
The data is split evenly into D = {D0, D1, D2, ..., DN} for the subsequent rounds of iterative training and updates.
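A minimal sketch of that split (the count of 9 parts comes from the text above; the variable and function names are ours):

```python
import random

def split_dataset(records, num_parts=9, seed=42):
    shuffled = records[:]
    random.Random(seed).shuffle(shuffled)
    part_size = (len(shuffled) + num_parts - 1) // num_parts
    return [shuffled[i * part_size:(i + 1) * part_size] for i in range(num_parts)]

# refined_conversations: the output of the filtering steps above.
# D[0] trains the initial SFT model; D[1], D[2], ... feed the later battle rounds.
D = split_dataset(refined_conversations, num_parts=9)
```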
Iterative Battles and Model Evolution
Arena Learning trains and improves WizardLM-β through an iterative process; a sketch of turning battle results into SFT/DPO training data follows the list below.
- Train the initial version WizardLM-β-SFT-I0 on D0
- Select the top-ranked SOTA models M from the WizardArena test set
- Battle WizardLM-β-SFT-I0 against M on D1
- Extract the instances where WizardLM-β's response loses
- Fine-tune WizardLM-β-SFT-I1, using the winning model's responses as the target outputs
- For DPO: battle WizardLM-β-SFT-I1 against M on D2, treating the winning/losing responses as <chosen, rejected> pairs
- For PPO: battle WizardLM-β-DPO-I1 against M on D3 to obtain <chosen, rejected> pairs
- In the second iteration I2, select the best model, WizardLM-β-PPO-I1, as the initial competitor
- Repeat the process to train the next SFT, DPO, and PPO models
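To illustrate how the battle results feed these training strategies (see the list above), the sketch below assumes each battle record already carries the judge's verdict; the field names and the helper function are hypothetical, not the authors' implementation.

```python
def build_sft_and_dpo_data(battles):
    """battles: list of dicts like
    {"instruction": ..., "wizardlm_response": ..., "sota_response": ..., "winner": "wizardlm" | "sota"}"""
    sft_examples, dpo_pairs = [], []
    for b in battles:
        if b["winner"] == "sota":
            # SFT: relabel the lost instruction with the winning model's response as the target.
            sft_examples.append({"prompt": b["instruction"], "target": b["sota_response"]})
            # DPO: the winning response is "chosen", WizardLM-β's losing response is "rejected".
            dpo_pairs.append({"prompt": b["instruction"],
                              "chosen": b["sota_response"],
                              "rejected": b["wizardlm_response"]})
        else:
            dpo_pairs.append({"prompt": b["instruction"],
                              "chosen": b["wizardlm_response"],
                              "rejected": b["sota_response"]})
    return sft_examples, dpo_pairs
```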
Test Data Generation
The test data consists of two subsets:
Diverse Subset
- Covers a broad range of topics, styles, and conversational contexts
- Built with text-clustering techniques spanning roughly 500 categories
- Uses a state-of-the-art embedding model (e.g., gte-large)
- Selects 2 representative samples from each cluster, for a total of 1000 records (a clustering sketch follows this list)
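A minimal clustering sketch, assuming the gte-large checkpoint published on the Hugging Face Hub (`thenlper/gte-large`) loaded via sentence-transformers and scikit-learn's KMeans; taking the two points closest to each centroid as the "representative" samples is our simplification, not necessarily the authors' selection rule.

```python
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

def build_diverse_subset(instructions, n_clusters=500, per_cluster=2):
    model = SentenceTransformer("thenlper/gte-large")  # assumed Hub ID for gte-large
    embeddings = model.encode(instructions, normalize_embeddings=True)

    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)

    subset = []
    for c in range(n_clusters):
        members = np.where(kmeans.labels_ == c)[0]
        if len(members) == 0:
            continue
        # pick the samples closest to the cluster centroid as representatives
        dists = np.linalg.norm(embeddings[members] - kmeans.cluster_centers_[c], axis=1)
        subset.extend(instructions[i] for i in members[np.argsort(dists)[:per_cluster]])
    return subset  # ~500 clusters x 2 samples ≈ 1000 records
```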
Hard Subset
- Designed for complex and challenging scenarios
- Randomly selects 10000 records from the 500 categories
- Uses GPT-4-1106-preview to rate difficulty on a 0-10 scale
- Keeps the top 1000 entries as the hard test set (a scoring sketch follows this list)
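The exact difficulty-rating prompt is not published, so the following is only a sketch. It reuses dataformer's `AsyncLLM` from earlier in this post against a GPT-4-1106-preview endpoint; the prompt wording, the `api_provider` value, and the response parsing are assumptions.

```python
from dataformer.llms.asyncllm import AsyncLLM

# Hypothetical difficulty-rating prompt.
DIFFICULTY_PROMPT = (
    "Rate the difficulty of answering the following user instruction on a scale from 0 (trivial) "
    "to 10 (extremely hard). Reply with the number only.\n\nInstruction:\n{instruction}"
)

def build_hard_subset(instructions, api_key, top_k=1000):
    llm = AsyncLLM(api_provider="openai", model="gpt-4-1106-preview", api_key=api_key)
    requests = [{"messages": [{"role": "user", "content": DIFFICULTY_PROMPT.format(instruction=x)}]}
                for x in instructions]
    responses = llm.generate(requests)

    scored = []
    for instr, row in zip(instructions, responses):
        try:
            # same response shape as in the dataformer examples above
            score = float(row[-1]["choices"][0]["message"]["content"].strip())
            scored.append((score, instr))
        except (ValueError, KeyError, IndexError, TypeError):
            continue
    # keep the highest-scoring instructions as the hard test set
    scored.sort(key=lambda x: x[0], reverse=True)
    return [instr for _, instr in scored[:top_k]]
```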
Limitations
- The judge model may not mimic human evaluators accurately
- Risk of generating unethical or misleading information
Conclusion
Arena Learning offers a cost-effective and reliable alternative to traditional human-based evaluation systems. It progressively strengthens and scales the capabilities of large language models, providing an effective way to improve the post-training process while keeping costs down.