Llama-3.1-Storm-8B:通过自我筛选+模型合并改进的 SLM

作者:Ashvini Kumar Jindal、Pawan Kumar Rajpoot、Ankur Parikh、Akshita Sukhlecha
动机
在语言模型微调中,数据质量至关重要,特别是对于参数量不超过 8B 的小型语言模型(SLM)。我们对这一概念的探索始于 NeurIPS LLM 效率挑战赛 2023,参赛者需要在 24 小时内在一台普通 GPU 上微调一个开源 LLM。
我们的方法,凭借数据筛选赢得了第一名 🏅。 我们从约 500 万个开源示例中,通过“自我筛选”——利用模型识别有价值的训练示例,筛选出约 20 万个高质量样本。这种方法被证明非常有效,在有限资源下取得了显著的改进。详情请参阅我们的论文:Birbal:使用精选数据集微调的高效 7B 指令模型
在此成功的基础上,我们改进了技术,专注于训练数据的自我筛选,以高效地增强 SLM。本文介绍了我们的最新工作,该工作结合了两种自我筛选方法、有针对性的监督微调(SFT)和模型合并,在各种基准测试中显著优于 Llama-3.1-8B-Instruct 和 Hermes-3-Llama-3.1-8B。
概述
我们展示了 Llama-3.1-Storm-8B 模型,该模型在各种基准测试中显著优于 Meta AI 的 Llama-3.1-8B-Instruct 和 Hermes-3-Llama-3.1-8B 模型,如下一节的性能比较图所示。我们的方法包括三个关键步骤:
- 自我筛选:我们应用了两种自我筛选方法,从约 280 万个开源示例中选出了约 100 万个高质量示例。我们的筛选标准侧重于教育价值和难度级别,使用相同的 SLM 进行标注,而不是使用更大的模型(例如 70B、405B)。
- 有针对性的微调:我们对 Llama-3.1-8B-Instruct 模型进行了基于 Spectrum 的有针对性微调。Spectrum 方法通过根据信噪比 (SNR) 有选择地定位层模块来加速训练,并冻结其余模块。在我们的工作中,50% 的层被冻结。
- 模型合并:我们使用 SLERP 方法将我们微调后的模型与 Llama-Spark 模型合并。合并方法产生一个混合模型,其特性从两个父模型平滑插值而来,确保所得模型捕获其两个父模型的精髓。Llama-3.1-Storm-8B 在 10 个不同的基准测试中改进了 Llama-3.1-8B-Instruct。这些基准测试涵盖指令遵循、知识驱动的问答、推理、真实答案生成和函数调用等领域。
🏆 隆重推出 Llama-3.1-Storm-8B
Llama-3.1-Storm-8B 基于 Llama-3.1-8B-Instruct,旨在增强 8B 参数模型的对话和函数调用能力。
如上图左侧子图所示,Llama-3.1-Storm-8B 模型在各项基准测试中均优于 Meta-Llama-3.1-8B-Instruct,包括:指令遵循 (IFEval)、知识驱动问答基准 (GPQA、MMLU-Pro)、推理 (ARC-C、MuSR、BBH)、减少幻觉 (TruthfulQA) 和函数调用 (BFCL)。对于计算资源有限的 AI 开发者和爱好者来说,这项改进尤其重要。
我们还使用最近发布的基于 Llama-3.1-8B-Instruct 的模型 Hermes-3-Llama-3.1-8B 对我们的模型进行了基准测试。如上图右侧子图所示,Llama-3.1-Storm-8B 在 9 个基准测试中,有 7 个均优于 Hermes-3-Llama-3.1-8B,Hermes-3-Llama-3.1-8B 在 MuSR 基准测试中优于 Llama-3.1-Storm-8B,而这两个模型在 BBH 基准测试中表现相当。
Llama-3.1-Storm-8B 模型优势
Llama-3.1-Storm-8B 是一款强大的通用模型,适用于各种应用。我们邀请 AI 社区探索 Llama-3.1-Storm-8B,并期待它在各种项目和应用中发挥作用。
模型优势 | 相关基准 |
🎯 改进的指令遵循 | IFEval 严格 (+3.93%) |
🌐 增强的知识驱动问答 | GPQA (+7.21%)、MMLU-Pro (+0.55%)、AGIEval (+3.77%) |
🧠 更好的推理能力 | ARC-C (+3.92%)、MuSR (+2.77%)、BBH (+1.67%)、AGIEval (+3.77%) |
🤖 卓越的代理能力 | BFCL:整体准确率 (+7.92%),BFCL:AST 摘要 (+12.32%) |
🚫 减少幻觉 | TruthfulQA (+9%) |
注:所有改进均相对于 Meta-Llama-3.1-8B-Instruct 的绝对增益。
Llama-3.1-Storm-8B 模型
BF16
: Llama-3.1-Storm-8B- ⚡
FP8
: Llama-3.1-Storm-8B-FP8-Dynamic - ⚡
GGUF
: Llama-3.1-Storm-8B-GGUF - 🚀 Ollama:
ollama run ajindal/llama3.1-storm:8b
💻 如何使用模型
🚀 启动 Llama-3.1-Storm-8B Colab Notebook
Hugging Face transformers
库默认以 bfloat16
格式加载模型。这也是 Llama-3.1-Storm-8B 检查点所使用的类型,因此建议使用此方式运行以确保最佳结果。
安装
pip install --upgrade "transformers>=4.43.2" torch==2.3.1 accelerate vllm==0.5.3.post1
开发者可以使用 Transformers 和 vLLM 等流行库轻松将 Llama-3.1-Storm-8B 集成到他们的项目中。以下部分通过简单的实际示例说明了用法。
对话用例
与 🤗 Transformers 一起使用
使用 transformers.pipeline()
API
import transformers
import torch
model_id = "akjindal53244/Llama-3.1-Storm-8B"
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
outputs = pipeline(messages, max_new_tokens=128, do_sample=True, temperature=0.01, top_k=100, top_p=0.95)
print(outputs[0]["generated_text"][-1]) # Expected Output: {'role': 'assistant', 'content': '2 + 2 = 4'}
使用 model.generate()
API
pip install flash_attn==2.6.3
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
# Apply Llama3.1 chat-template
def format_prompt(user_query):
template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
return template.format(user_query)
model_id = 'akjindal53244/Llama-3.1-Storm-8B'
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
load_in_8bit=False,
load_in_4bit=False,
use_flash_attention_2=True
)
# Build final input prompt after applying chat-template
prompt = format_prompt("What is 2+2?")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
generated_ids = model.generate(input_ids, max_new_tokens=128, temperature=0.01, do_sample=True, eos_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(response) # Expected Output: '2 + 2 = 4'
与 vLLM 一起使用
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
model_id = "akjindal53244/Llama-3.1-Storm-8B" # FP8 model: "akjindal53244/Llama-3.1-Storm-8B-FP8-Dynamic"
num_gpus = 1
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = LLM(model=model_id, tensor_parallel_size=num_gpus)
sampling_params = SamplingParams(max_tokens=128, temperature=0.01, top_k=100, top_p=0.95)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize = False)
print(llm.generate([prompt], sampling_params)[0].outputs[0].text.strip()) # Expected Output: 2 + 2 = 4
与 LitGPT 一起使用
pip install 'litgpt[all]'
litgpt download akjindal53244/Llama-3.1-Storm-8B --model_name meta-llama/Meta-Llama-3.1-8B
from litgpt import LLM
llm = LLM.load(model="akjindal53244/Llama-3.1-Storm-8B")
llm.generate("What do Llamas eat?")
函数调用用例
与 Meta-Llama-3.1-8B-Instruct 相比,Llama-3.1-Storm-8B 具有令人印象深刻的函数调用能力,这在 BFCL 基准测试中得到了证明。
函数调用提示格式
Llama-3.1-Storm-8B 经过特定系统提示训练以进行函数调用
You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.
Here are the available functions:
<tools>LIST_OF_TOOLS</tools>
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>
以上系统提示应与传递 LIST_OF_TOOLS
作为输入一起使用。
与 vLLM 一起使用
import json
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
model_id = "akjindal53244/Llama-3.1-Storm-8B" # FP8 model: "akjindal53244/Llama-3.1-Storm-8B-FP8-Dynamic"
num_gpus = 1
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = LLM(model=model_id, tensor_parallel_size=num_gpus)
sampling_params = SamplingParams(max_tokens=128, temperature=0.01, top_k=100, top_p=0.95)
def create_system_prompt(tools_list):
system_prompt_format = """You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.
Here are the available functions:
<tools>{}</tools>
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>"""
# Convert the tools list to a string representation
tools_str = json.dumps(tools_list, ensure_ascii=False)
# Format the system prompt with the tools list
system_prompt = system_prompt_format.format(tools_str)
return system_prompt
# Example tools list
tools_list = [
{
"name": "peers",
"description": "Retrieves a list of company peers given a stock symbol.",
"parameters": {
"symbol": {
"description": "The stock symbol for the company.",
"type": "str",
"default": ""
}
}
},
{
"name": "web_chain_details",
"description": "python",
"parameters": {
"chain_slug": {
"description": "The slug identifier for the blockchain (e.g., 'ethereum' for Ethereum mainnet).",
"type": "str",
"default": "ethereum"
}
}
}
]
# Create the system prompt with the tools list
system_prompt = create_system_prompt(tools_list)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "I need to understand the details of the Ethereum blockchain for my cryptocurrency project. Can you fetch the details for 'ethereum'?"}
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize = False)
print(llm.generate([prompt], sampling_params)[0].outputs[0].text.strip()) # Expected Output: <tool_call>{'tool_name': 'web_chain_details', 'tool_arguments': {'chain_slug': 'ethereum'}}</tool_call>
与 Ollama 一起使用
import ollama
tools = [{
'type': 'function',
'function': {
'name': 'get_current_weather',
'description': 'Get the current weather for a city',
'parameters': {
'type': 'object',
'properties': {
'city': {
'type': 'string',
'description': 'The name of the city',
},
},
'required': ['city'],
},
},
},
{
'type': 'function',
'function': {
'name': 'get_places_to_vist',
'description': 'Get places to visit in a city',
'parameters': {
'type': 'object',
'properties': {
'city': {
'type': 'string',
'description': 'The name of the city',
},
},
'required': ['city'],
},
},
},
]
response = ollama.chat(
model='ajindal/llama3.1-storm:8b',
messages=[
{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'user', 'content': 'What is the weather in Toronto and San Francisco?'}
],
tools=tools
)
print(response['message']) # Expected Response: {'role': 'assistant', 'content': "<tool_call>{'tool_name': 'get_current_weather', 'tool_arguments': {'city': 'Toronto'}}</tool_call>"}
Llama-3.1-Storm-8B 背后原理
本节详细介绍了我们用于创建 Llama-3.1-Storm-8B 的三步配方
自我筛选
- 来源数据集:我们选择了 5 个开源数据集(The-Tome、agent-data、Magpie-Llama-3.1-Pro-300K-Filtered、openhermes_200k_unfiltered、Llama-3-Magpie-PO-100K-SML)。这些数据集总共包含约 280 万个示例。
- 数据筛选涉及为每个示例分配值,然后根据分配的值做出选择决策。通常,LLM 或机器学习模型用于分配这些值。有许多方法可以使用 LLM 为示例分配值。评估示例的两个最流行的值是教育价值和难度级别。教育价值决定了示例(指令 + 响应)的价值或信息量,难度级别决定了示例(指令 + 响应)的难度。教育价值范围从 1 到 5,其中 5 表示信息量最大,1 表示信息量最少。有 3 个难度级别——简单、中等和困难。由于我们的目标是在自我筛选框架下改进 SLM,我们专注于使用相同的模型——Llama-3.1-8B-Instruct,而不是使用更大的 LLM,例如 Llama-3.1-70B-Instruct、Llama-3.1-405B-Instruct 等。
- 自我筛选步骤
- 步骤 1:基于教育价值的筛选
- 我们使用 Llama-3.1-8B-Instruct 进行零样本推理,为约 280 万个示例分配教育价值(1-5)。接下来,我们选择教育价值 >=3 的示例,并删除其余示例。我们遵循了 FineWeb-Edu 数据集 的方法。这将总示例数从约 280 万减少到约 130 万。
- 步骤 2:基于难度级别的筛选
- 我们使用 Llama-3.1-8B-Instruct 对步骤 1 中的约 130 万个示例进行零样本推理,以分配难度级别(简单、中等和困难)。经过初步实验,我们选择了中等和困难级别的示例,并移除了其余示例。该策略类似于 Llama-3.1 技术报告 中描述的数据修剪方法。其中约 65 万个示例为中等难度,约 32.5 万个示例为困难难度。
- 步骤 1:基于教育价值的筛选
- 我们最终筛选出的数据集包含约 97.5 万个示例。我们将其分为约 96 万个训练示例和约 1.5 万个验证示例。
有针对性的监督指令微调
- 我们基于自我筛选的模型在 Llama-3.1-8B-Instruct 模型上,使用约 96 万个示例进行了 4 个 epoch 的微调。
- 我们采用了 Spectrum,一种有针对性的微调方法,以缩短训练时间、降低内存消耗并减少灾难性遗忘的风险。Spectrum 通过根据信噪比 (SNR) 选择性地训练特定层来优化 LLM 的训练过程。Spectrum 的核心概念简单而高效。它不是在训练期间更新模型的每一层,而是识别并优先处理对性能改进贡献最大的层(高 SNR),而低 SNR 的层则保持冻结。
- 在我们基于 Spectrum 的完全微调过程中,50% 的层被冻结。
模型合并
- 模型合并 的效果令人惊艳,已在 开放 LLM 排行榜 上产生了许多最先进的模型。受此启发,我们决定将我们基于自我筛选微调的模型与 Llama-Spark 模型合并,后者是 Llama-3.1-8B-Instruct 的衍生模型。
- 我们使用 SLERP 方法合并上述两个模型。SLERP 合并方法生成的混合模型,其特性平滑地插值自两个父模型,确保所得模型捕获了其两个父模型的精髓。
- 在我们的基准测试中,我们的自筛选 SFT 模型平均表现优于 Llama-Spark 模型。然而,合并后的模型比这两个模型中的任何一个都表现更好。
自我筛选和模型合并的影响
如上图所示,基于自我筛选的 SFT 方法在 10 个基准测试中,有 7 个表现优于 Llama-3.1-8B-Instruct,这强调了筛选高质量示例的关键作用。此外,这些结果表明,选择合适的模型进行合并可以进一步提高在评估基准上的性能。
展望未来
我们希望通过我们的自我筛选和模型合并方法改进其他 SLM,例如 Gemma-2、Phi-3 和 Qwen2。我们正在探索各种模型合并技术及其对模型能力的影响。我们的目标是继续为 AI 社区提供有价值的工具,特别是那些计算资源有限的社区。我们将在近期发布提示和精选数据集。
对齐说明
尽管 Llama-3.1-Storm-8B 未经过明确的模型对齐过程,但它可能仍然保留了从 Meta-Llama-3.1-8B-Instruct 模型继承的一些对齐属性。
致谢
我们感谢 Sebastian Raschka、Mark Saroufim、Lewis Tunstall、Maxime Labonne、Prateek Yadav 和 Dipanjan Sarkar 提供的宝贵反馈。我们衷心感谢 Lambda Labs 为这项工作提供计算资源赞助。
引用我们的工作
@misc {ashvini_kumar_jindal_2024,
author = { {Ashvini Kumar Jindal, Pawan Kumar Rajpoot, Ankur Parikh, Akshita Sukhlecha} },
title = { Llama-3.1-Storm-8B },
year = 2024,
url = { https://huggingface.co/akjindal53244/Llama-3.1-Storm-8B },
doi = { 10.57967/hf/2902 },
publisher = { Hugging Face }
}
支持我们的工作
我们的团队由三名成员组成,他们分布在三个不同的时区,我们赢得了 NeurIPS LLM 效率挑战赛 2023 和金融与阿拉伯语 LLM 领域的其他四项比赛。我们还发布了 SOTA 数学推理模型。
Llama-3.1-Storm-8B 是我们迄今为止对开源社区最有价值的贡献。我们致力于开发高效的通用 LLM。我们正在寻找计算资源和创新合作者来推动这项倡议。
附录
本节详细介绍了我们的评估设置,包括重现所有模型结果的分步说明。
评估框架
我们使用了 lm-eval-harness 工具包,这是一个广泛用于 LLM 评估的开源项目。此选择有助于在不同模型和研究工作中实现一致的比较。
所有模型在每个基准测试中均使用相同的代码库和脚本进行评估,从而消除了因实现差异而可能产生的潜在差异。我们使用了 HF Open LLM Leaderboard Branch 进行大部分评估,确保了版本的一致性。我们提供了每个基准测试中使用的确切脚本,允许任何人重现结果。
以下是每个基准测试使用的脚本:
# IFEval
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --device cuda:0 --tasks leaderboard_ifeval --batch_size 32 --apply_chat_template --fewshot_as_multiturn
# BBH (Big-Bench Hard)
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --device cuda:0 --tasks leaderboard_bbh --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 3
# GPQA
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_gpqa --batch_size 4 --apply_chat_template --fewshot_as_multiturn
# MMLU-Pro
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_mmlu_pro --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5
# Math Level-5
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_math_hard --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 4
# MuSR
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_musr --device cuda:0 --batch_size 32 --apply_chat_template --fewshot_as_multiturn
# ARC-C
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks arc_challenge --device cuda:0 --batch_size 32 --num_fewshot 0 --apply_chat_template
# TruthfulQA
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks truthfulqa_mc2 --device cuda:0 --batch_size 128 --apply_chat_template --fewshot_as_multiturn
对于 AGIEval,我们使用 lm-eval-harness 的主分支,因为 agieval_nous
任务在 HF 排行榜分支中不可用。
# AGIEval
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks agieval_nous --device cuda:0 --batch_size 32 --apply_chat_template
注意:给定基准的所有模型都使用相同的脚本和硬件配置进行评估。这种方法消除了批处理大小、GPU 数量或其他硬件相关因素差异可能产生的潜在副作用。
BFCL
我们对 Llama-3.1-Storm-8B 评估使用了以下提示:
You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.
Here are the available functions:
<tools>{}</tools>
Follow the below guidelines:
1. If any tool needed to answer the query is not available, you must return an empty list "[]" as response.
2. Else if query does not provide any must-have argument of a required tool, you must return an empty list "[]" as response.
3. Else, for each function call you must return a json object in response with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>
我们对 Llama-3.1-Storm-8B 的实验揭示了一个有趣的能力:Llama-3.1-Storm-8B 能够准确处理所需工具或参数缺失的情况,尽管它并未专门针对此类场景进行训练。通过简单地向 BFCL 提示添加两条直接的指令,我们利用模型增强的指令遵循能力来解决这些边缘情况。这表明 Llama-3.1-Storm-8B 仅通过提示工程就可以处理许多用例,这归因于其强大的指令遵循能力。