欢迎 Gemma 2 - Google 全新的开放式 LLM

发布于 2024 年 6 月 27 日
在 GitHub 上更新

Google 发布了 Gemma 2,这是其最先进的开放式 LLM 系列的最新成员,我们很高兴能与 Google 合作,确保其在 Hugging Face 生态系统中实现最佳集成。您可以在 Hub 上找到这 4 个开放权重模型(2 个基础模型和 2 个微调模型)。已发布的功能和集成包括:

目录

什么是 Gemma 2?

Gemma 2 是 Google 最新迭代的开放式 LLM。它有两种大小,90 亿和 270 亿参数,并提供基础(预训练)和指令微调版本。Gemma 基于 Google Deepmind Gemini,上下文长度为 8K token。

Gemma 2 模型的训练数据量是其第一代模型的两倍,其中 27B 版本总计 13 万亿 token,9B 版本总计 8 万亿 token,这些数据主要来源于网络数据(主要是英文)、代码和数学。我们不清楚训练混合的具体细节,只能猜测更大、更仔细的数据整理是性能改进的重要因素。

Gemma 2 采用与第一代相同的许可,这是一个宽松的许可,允许重新分发、微调、商业用途和衍生作品。

Gemma 2 的技术进步

Gemma 2 与第一代有许多相似之处。它具有 8192 个 token 的上下文长度,并使用旋转位置嵌入(RoPE)。与原始 Gemma 相比,Gemma 2 有四项主要改进:

  • 滑动窗口注意力:交错使用滑动窗口注意力和全二次注意力,以实现高质量生成。
  • Logit 软限制:通过将 logit 缩放到固定范围来防止其过度增长,从而改善训练。
  • 知识蒸馏:利用更大的教师模型来训练更小的模型(针对 9B 模型)。
  • 模型合并:将两个或更多 LLM 合并成一个新模型

Gemma 2 在 Google Cloud TPU (27B 使用 v5p9B 使用 TPU v4) 上训练,使用 JAXML Pathways。Gemma 2 Instruct 已针对对话应用进行优化,并使用监督微调 (SFT)、从大型模型进行蒸馏、使用更注重对话能力的奖励模型进行人工反馈强化学习 (RLHF) 以及使用 WARP 进行模型合并以提高整体性能,在合成和人工生成的提示-响应对的混合数据上进行了训练。

与预训练混合类似,关于微调数据集或与 SFT 和 RLHF 相关的超参数的详细信息尚未公开。

滑动窗口注意力

滑动窗口注意力是一种减少 transformer 模型注意力计算内存和时间需求的方法,已在 Mistral 等模型中使用。Gemma 2 的新颖之处在于,滑动窗口应用于每隔一层(局部 - 4096 个 token),而中间层仍使用全二次全局注意力(8192 个 token)。我们认为这是一种在长上下文情况下提高质量的方法(一半的层仍然关注所有 token),同时部分受益于滑动注意力的优势。

软限制和注意力实现

软限制是一种防止 logit 过度增长而不截断它们的技术。它的工作原理是将 logit 除以一个最大值阈值 (soft_cap),然后通过一个 tanh 层(确保它们在 (-1, 1) 范围内),最后再次乘以该阈值。这保证了最终值将在 (-soft_cap, +soft_cap) 区间内,而不会损失太多信息,但能稳定训练。

总而言之,logit 的计算公式为:logits ← soft_cap * tanh(logits/soft_cap)

Gemma 2 对最后一层和每个注意力层都采用了软限制。注意力 logit 限制在 50.0,最终 logit 限制在 30.0。

在发布时,软限制与 Flash Attention/SDPA 不兼容,但它们仍可在推理中使用以实现最大效率。Gemma 2 团队观察到在推理期间移除软限制时差异非常小。

注意:为了稳定的微调运行,您仍然需要启用软限制,因此,我们建议使用 eager 注意力而不是 SDPA 进行微调。

知识蒸馏

知识蒸馏是一种流行的技术,用于训练一个较小的*学生*模型来模仿一个较大但性能更好的*教师*模型的行为。它的工作原理是,通过使用来自教师模型(例如 GPT-4、Claude 或 Gemini)的 token 概率分布来增强 LLM 的下一个 token 预测任务,这为学生模型提供了更丰富的学习信号。

根据 Gemma 2 技术报告,知识蒸馏用于预训练 9B 模型,而 27B 模型则从头开始预训练。

对于后训练,Gemma 2 团队从一个教师模型(报告中未指定,但 presumably Gemini Ultra)生成了一组多样化的完成,然后用 SFT 在这些合成数据上训练学生模型。这是许多开放模型的基础,例如 ZephyrOpenHermes,它们完全基于大型 LLM 的合成数据进行训练。

尽管这种方法有效,但它也有缺点,因为学生和教师模型之间的容量不匹配可能导致*训练-推理不匹配*,即学生模型在推理过程中生成的文本与训练期间看到的文本处于分布之外。

为了解决这个问题,Gemma 2 团队使用了“按策略蒸馏”,其中学生模型从 SFT 提示生成完成。然后,这些完成用于计算教师和学生模型的 logit 之间的 KL 散度。通过在整个训练过程中最小化 KL 散度,学生模型学会准确地模拟教师模型的行为,同时最大限度地减少训练-推理不匹配。

这种方法非常有趣,正如我们在社区中看到的那样,在线 DPO 等按策略方法会产生更强大的模型,而按策略蒸馏的一个优点是,您只需要教师模型的 logit,因此您不需要依赖奖励模型或 LLM-as-a-judge 来改进模型。看到这种方法在未来几个月内是否会在微调者中变得更受欢迎,这将令人兴奋!

模型合并

模型合并是一种将两个或多个 LLM 合并成一个新模型的技术。它相对较新且仍在实验中,无需加速器即可使用。Mergekit 是一个流行的开源工具包,用于合并 LLM。它实现了线性、SLERP、TIES、DARE 和其他合并技术。

根据技术报告,Gemma 2 使用了 Warp,这是一种新的合并技术,它分三个不同阶段合并模型

  1. 指数移动平均 (EMA):在强化学习 (RL) 微调过程中应用。
  2. 球面线性插值 (SLERP):在多个策略的 RL 微调后应用。

    65
  3. 向初始化线性插值 (LITI):此阶段在 SLERP 阶段之后应用。

Gemma 2 评估

Gemma 模型表现如何?以下是根据技术报告和开放 LLM 排行榜新版本与其他开放模型的性能比较。

技术报告结果

Gemma 2 的技术报告比较了不同开放式 LLM 在之前的开放式 LLM 排行榜基准测试中的性能。

Llama 3 (70B) Qwen 1.5 (32B) Gemma 2 (27B)
MMLU 79.2 74.3 75.2
GSM8K 76.9 61.1 75.1
ARC-c 68.8 63.6 71.4
HellaSwag 88.0 85.0 86.4
Winogrande 85.3 81.5 83.7

该报告还比较了小型语言模型的性能。

基准测试 Mistral (7B) Llama 3 (8B) Gemma (8B) Gemma 2 (9B)
MMLU 62.5 66.6 64.4 71.3
GSM8K 34.5 45.7 50.9 62.3
ARC-C 60.5 59.2 61.1 68.4
HellaSwag 83.0 82.0 82.3 81.9
Winogrande 78.5 78.5 79.0 80.6

开放 LLM 排行榜结果

注意:我们目前正在新的开放 LLM 排行榜基准上单独评估 Google Gemma 2,并将在今天晚些时候更新此部分。

如何提示 Gemma 2

基础模型没有提示格式。像其他基础模型一样,它们可以用于继续输入序列,生成可信的延续,或者用于零样本/少样本推理。Instruct 版本具有非常简单的对话结构

<start_of_turn>user
knock knock<end_of_turn>
<start_of_turn>model
who is there<end_of_turn>
<start_of_turn>user
LaMDA<end_of_turn>
<start_of_turn>model
LaMDA who?<end_of_turn><eos>

必须精确复现这种格式才能有效使用。我们稍后将展示如何使用 transformers 中提供的聊天模板轻松复现 instruct 提示。

演示

您可以在 Hugging Chat 上与 Gemma 27B Instruct 模型聊天!请点击此处链接:https://huggingface.co/chat/models/google/gemma-2-27b-it

使用 Hugging Face Transformers

通过 Transformers release 4.42,您可以使用 Gemma 并利用 Hugging Face 生态系统中的所有工具。要将 Gemma 模型与 transformers 一起使用,请确保使用最新的 transformers 版本

pip install "transformers>=4.42.3" --upgrade

以下代码片段展示了如何使用 gemma-2-9b-ittransformers。它需要大约 18 GB 的内存,这适用于许多消费级 GPU。相同的代码片段也适用于 gemma-2-27b-it,后者需要 56GB 的内存,使其成为生产用例中一个非常有趣的模型。通过以 8 位或 4 位模式加载,可以进一步减少内存消耗。

from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]
outputs = pipe(
    messages,
    max_new_tokens=256,
    do_sample=False,
)
assistant_response = outputs[0]["generated_text"][-1]["content"]
print(assistant_response)

哈喽,各位伙伴!我是一艘谦逊的文字之船,航行在数字海洋中。他们叫我 Gemma,是 Google DeepMind 的杰出作品。我通过大量文本训练而成,学会了像一个真正的海盗一样说话和写作。

问我问题吧,我会尽力回答,是的!🦜📚

我们使用 bfloat16,因为那是指令微调模型的参考精度。在 float16 模式下运行可能在您的硬件上更快,并且在 9B 模型上结果应该相似。但请注意,27B 指令微调模型在使用 float16 时会产生不稳定的输出:您必须为该模型权重使用 bfloat16。

您还可以自动量化模型,以 8 位甚至 4 位模式加载。加载大型 27B 版本的 4 位模型大约需要 18 GB 内存才能运行,使其与许多消费级显卡和 Google Colab 中的 GPU 兼容。这就是您以 4 位模式加载生成管道的方式

pipeline = pipeline(
    "text-generation",
    model=model,
    model_kwargs={
        "torch_dtype": torch.bfloat16,
        "quantization_config": {"load_in_4bit": True}
    },
)

有关如何将模型与 transformers 一起使用的更多详细信息,请查看模型卡

与 Google Cloud 集成

注意:我们目前正在努力将新的容器添加到 GKE 和 Vertex AI 中,以高效运行 Google Gemma 2。容器可用后,我们将立即更新此部分。

使用 🤗 TRL 进行微调

训练 LLM 在技术和计算上都具有挑战性。在本节中,我们将介绍 Hugging Face 生态系统中可用于在消费级 GPU 上高效训练 Gemma 的工具

下面是针对 OpenAssistant 聊天数据集微调 Gemma 的示例命令。我们使用 4 位量化和 QLoRA 来节省内存,以针对所有注意力块的线性层。请注意,与密集 Transformer 不同,不应针对 MLP 层,因为它们是稀疏的并且与 PEFT 相互作用不良。

首先,安装 🤗 TRL 的夜间版本并克隆仓库以访问训练脚本

pip install "transformers>=4.42.3" --upgrade
pip install --upgrade bitsandbytes
pip install --ugprade peft
pip install git+https://github.com/huggingface/trl
git clone https://github.com/huggingface/trl
cd trl

然后你可以运行脚本

# peft tuning; single GPU; https://wandb.ai/costa-huang/huggingface/runs/l1l53cst
python \
    examples/scripts/sft.py \
    --model_name google/gemma-2-27b \
    --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
    --dataset_text_field="text" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-4 \
    --report_to wandb \
    --bf16 \
    --max_seq_length 1024 \
    --lora_r 16 --lora_alpha 32 \
    --lora_target_modules q_proj k_proj v_proj o_proj \
    --load_in_4bit \
    --use_peft \
    --attn_implementation eager \
    --logging_steps=10 \
    --gradient_checkpointing \
    --output_dir models/gemma2

alt_text

如果您有更多 GPU 可用,可以使用 DeepSpeed 和 ZeRO Stage 3 进行训练

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft.py \
    --model_name google/gemma-2-27b \
    --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
    --dataset_text_field="text" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-5 \
    --report_to wandb \
    --bf16 \
    --max_seq_length 1024 \
    --attn_implementation eager \
    --logging_steps=10 \
    --gradient_checkpointing \
    --output_dir models/gemma2

alt_text

与推理端点集成

您可以使用文本生成推理作为后端,将 Gemma 2 部署到 Hugging Face 的 推理端点文本生成推理是 Hugging Face 开发的生产就绪推理容器,可轻松部署大型语言模型。它具有连续批处理、token 流式传输、多 GPU 快速推理的张量并行化以及生产就绪日志记录和跟踪等功能。

要部署 Gemma 2 模型,请转到模型页面并单击部署 -> 推理端点小部件。推理端点支持 OpenAI 兼容的消息 API,它允许您通过简单地更改 URL 从另一个封闭模型切换到开放模型。

from openai import OpenAI

# initialize the client but point it to TGI
client = OpenAI(
    base_url="<ENDPOINT_URL>" + "/v1/",  # replace with your endpoint url
    api_key="<HF_API_TOKEN>",  # replace with your token
)
chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "user", "content": "Why is open-source software important?"},
    ],
    stream=True,
    max_tokens=500
)

# iterate and print stream
for message in chat_completion:
    print(message.choices[0].delta.content, end="")

额外资源

致谢

如果没有许多社区成员的贡献,包括对 LLM 评估做出贡献的 ClémentineNathan;对文本生成推理提供支持的 Nicolas;将 Gemma 2 集成到 transformers 中的 ArthurSanchitJoaoLysandre;以及使 Gemma 2 在 Hugging Chat 中可用的 NathanVictor,发布这些模型并在生态系统中提供支持和评估是不可能的。

感谢 Google 团队发布 Gemma 2 并将其提供给开源 AI 社区!

社区

注册登录以发表评论