nanoJAXGPT:JAX/Equinox 的教学介绍

社区文章 发布于 2024 年 10 月 23 日

简介

自推出以来,JAX 在机器学习 (ML) 社区中人气显著上升。简单的网络搜索即可发现其庞大的社区支持、各种衍生项目以及围绕 JAX 构建的众多 Python 库。这引出了一个不可避免的问题:什么是 JAX,以及我为什么应该关注它?

嗯,根据官方文档,JAX 是一个用于加速器导向的数组计算的 Python 库……

等一下,我们先暂停一下!如果你真的想了解官方文档中概述的 JAX 简介,你会直接去那里,而不是在这里阅读这篇博客文章。话虽如此,虽然有很多资源可以帮助你使用 JAX 启动你的机器学习项目,但本文不仅仅是赞美 JAX 作为机器学习框架,也不是向初学者介绍如何使用它进行机器学习。我们将卷起袖子,亲自动手,拿一个知名的仓库(Andrej Karpathy 的 nanoGPT),并使用 JAXEquinox 从头到尾重写它。

嗯…Equinox?

是的,如果你还没听说过,Equinox 是一个围绕 JAX 构建的库,旨在尽可能顺利地构建神经网络 (NN)。它的独特之处在于其熟悉的 PyTorch 式语法,使得从 PyTorch 背景转过来的人能够轻松过渡。但不要被它的简单性所迷惑。在底层,Equinox 正在勤奋地将你的模型注册为 JAX PyTree,这是 JAX 中一个强大的数据结构,允许进行复杂的转换和计算。

为了将其置于上下文中,我们将通过一个实际示例来说明这个过程。这是一个代码片段,演示了如何使用 Equinox 定义一个线性层

# Code extracted from https://docs.kidger.site/equinox/all-of-equinox/

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

现在,Equinox 提供了各种预构建的神经网络层,包括我们上面定义的 Linear 层,可以用于构建复杂的架构。Equinox 作为使用 JAX 训练深度学习模型的库的一个显著优势是它能够将任意 Python 对象,更具体地说是激活函数,合并到 PyTree 定义中。它还提供了额外功能,以促进 JAXjax.jitjax.grad 装饰器的使用,因为它们要求所有输入都是数组的 PyTree,通过分别实现过滤转换作为 equinox.filter_jitequinox.filter_grad 装饰器。你可以在这里找到有关 Equinox 中过滤的更多信息。

先决条件

本博客的以下部分假设读者对 JAX 有基本了解。下面,我们整理了一个全面但不穷尽的资源列表,以帮助您入门。

清晰说明

  • 在 PyTorch 中,常规做法是在模块中定义一个 forward 方法,用于在训练阶段的前向传播中执行操作。这种方法也可以在 equinox 模块中采用。但是,通常也会在类的 __call__ 定义中定义前向传播的计算。这提供了一种定义模型前向传播的简便方法,但需要注意的是,任何方法都可以使用,并且没有特殊处理的方法。因此,在接下来的部分中,当我们提到前向传播时,建议读者将注意力集中在相应模块的 __call__ 定义,或开发人员选择用于此目的的任何其他方法。

nanoGPT

nanoGPT 是一个简单快速的仓库,用于训练或微调中等规模的 GPT(生成式预训练 Transformer)。我们将使用 JAX/Equinox 重写这个深度学习仓库。该仓库的内容如图 1 所示,我们重点关注 model.pytrain.py

Description of the image
图 1:nanoGPT 的项目结构


model.py

此文件中概述的模型从 GPT-2 架构中汲取灵感,融入了各种模块以模拟相似的结构。它的设计旨在易于访问和理解,即使对于该领域的新手也是如此。让我们首先在下面概述此模型定义中最重要的模块。

class CausalSelfAttention(nn.Module):
  ...

class MLP(nn.Module):
  ...

class Block(nn.Module):
  ...

class GPT(nn.Module):
  ...

train.py

model.py 文件中定义了模型架构后,此文件中包含一个用于使用 PyTorch 训练模型的训练脚本。您可以在上面链接的原始仓库中查看此文件的内容。由于 JAX 中的训练范式与 PyTorch 中的训练范式大相径庭,因此我们在此不提取和概述此文件的结构。

重写 model.py

nanoGPT 中引入 SwiGLU

在重写 nanoGPT 的过程中,我们力求在最终输出中引入一个独特的元素。为此,我们用 SwiGLU 激活函数取代了 MLP 模块中标准的 GELU 激活函数。SwiGLUGLU 激活函数的一种变体,其显著特点是能够根据特定的训练任务动态调整非线性。对于有兴趣深入了解 SwiGLU 的读者,可以在此处找到更多信息。

SwiGLU 激活函数的数学表达式如下:SwiGLU(x,W,V,b,c,β)=Swishβ(xW+b)(xV+c)SwiGLU(x, W, V, b, c, \beta) = Swish_{\beta}(xW + b) \otimes (xV + c)

其中 W,V,b,cW, V, b, c 都是神经网络中的可训练参数,我们可以按照下面的代码块实现它。让我们一步步分解这段代码

  • 我们首先创建一个 eqx.Module 类的子类,因为此激活函数具有可训练参数,因此我们需要在 PyTree 定义中注册它。
  • 我们使用三个参数 dim_indim_outkey 定义了 __init__ 方法。前两个必须在初始化此模块时定义,我们将根据输入和输出参数的数量推断出适当的值。
  • __call__ 方法实现了 SwiGLU 激活函数的定义。我们将 Swish 激活函数应用于输入的一次变换,并与输入的另一次变换进行逐元素乘法。
import equinox as eqx
import jax
import jax.nn as nn
import jax.numpy as jnp

class SwiGLU(eqx.Module):
    """
    Implementation of the SwiGLU activation function in the paper by Noam Shazeer at Google

    References:
        GLU Variants Improve Transformer paper  : https://arxiv.org/abs/2002.05202
        Aziz et al. Paper Summaries             : https://azizbelaweid.substack.com/p/what-is-swiglu-how-to-implement-it
    """

    W: jax.Array
    V: jax.Array
    b: jax.Array
    c: jax.Array

    def __init__(self, dim_in, dim_out, key):
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.W = jax.random.normal(k1, (dim_in, dim_out))
        self.V = jax.random.normal(k2, (dim_in, dim_out))
        self.b = jax.random.normal(k3, (dim_out,))
        self.c = jax.random.normal(k4, (dim_out,))

    def __call__(self, x):
        return jax.nn.swish(jnp.dot(x, self.W) + self.b) * (jnp.dot(x, self.V) + self.c)
Icon

在大多数即将推出的模块中,您可能会注意到有一个 config 参数。我们将从以下 GPTConfig 定义初始化的 dataclass 对象作为此参数的参数传递。它包含模型架构的预定义配置。

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

MLP 模块

import torch.nn as nn

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

鉴于我们从头开始使用 equinox 构建模块的经验,转换上述 MLP 层应该相对简单。我们概述了此转换的步骤如下

  1. 首先,将此类别从 torch.nn 更改为 equinox 模块。

    class MLP(eqx.Module):
    
  2. 接下来,我们重写 __init__ 方法以在 JAX 中初始化 MLP 层。我们已将 PyTorchnn.Linearnn.Dropout 层替换为它们的 Equinox 等效项,保持参数一致以保留原始行为。我们在 Equinox 版本中初始化 SwiGLU 模块,仔细选择 dim_indim_out 参数以匹配前一个 Linear 层的输出维度和后一个 Linear 层的输入维度,两者均为 4 * config.n_embd

    class MLP(eqx.Module):
        c_fc    : eqx.nn.Linear
        swiglu  : SwiGLU
        c_proj  : eqx.nn.Linear
        dropout : eqx.nn.Dropout
    
        def __init__(self, config, key):
            lkey1, lkey2, skey = jax.random.split(key, 3)
    
            self.c_fc     = eqx.nn.Linear(config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=lkey1)
            self.swiglu   = SwiGLU(4 * config.n_embd, 4 * config.n_embd, skey)
            self.c_proj   = eqx.nn.Linear(4 * config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2)
            self.dropout  = eqx.nn.Dropout(config.dropout)
    
  3. 最后,我们已将前向传播中的激活函数 self.gelu(x) 替换为 self.swiglu(x)。正如您可能已经观察到的,我们在前向传播的某些步骤中使用了转换函数 jax.vmap。当我们以逐层方式剖析整个架构,解释每个模块接收的输入的维度以及在这种情况下 vmap 的必要性时,将进一步阐述这一点。

    然而,目前,让我们继续重写模型中剩余的模块。

    class MLP(eqx.Module):
        c_fc: eqx.nn.Linear
        swiglu: SwiGLU
        c_proj: eqx.nn.Linear
        dropout: eqx.nn.Dropout
    
        def __init__(self, config, key):
            lkey1, lkey2, skey = jax.random.split(key, 3)
    
            self.c_fc = eqx.nn.Linear(config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=lkey1)
            self.swiglu = SwiGLU(4 * config.n_embd, 4 * config.n_embd, skey)
            self.c_proj = eqx.nn.Linear(4 * config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2)
            self.dropout = eqx.nn.Dropout(config.dropout)
    
        def __call__(self, x):
            x = jax.vmap(self.c_fc)(x)
            x = jax.vmap(self.swiglu)(x)
            x = jax.vmap(self.c_proj)(x)
            x = self.dropout(x)
            return x
    

CausalSelfAttention 模块

接下来,模块转换过程应该看起来相当简单,因为它反映了前面 MLP 模块中采取的步骤。但是,我们将重点指出即将到来的模块定义中应用的独特更改。

PyTorch 版本:

# Code extracted from https://github.com/karpathy/nanoGPT/blob/master/model.py

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

Equinox 版本:

class CausalSelfAttention(eqx.Module):
    c_attn: eqx.nn.Linear
    c_proj: eqx.nn.Linear
    attn_dropout: eqx.nn.Dropout
    resid_dropout: eqx.nn.Dropout
    bias: jax.Array = eqx.field(static=True)

    _config: GPTConfig = eqx.field(static=True)

    def __init__(self, config, key):
        assert config.n_embd % config.n_head == 0

        # PRNGKey
        lkey1, lkey2 = jax.random.split(key, 2)

        # key, query, value projections for all heads, but in a batch
        self.c_attn = eqx.nn.Linear(config.n_embd, 3 * config.n_embd, use_bias=config.bias, key=lkey1)
        # output projection
        self.c_proj = eqx.nn.Linear(config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2)
        # regularization
        self.attn_dropout = eqx.nn.Dropout(config.dropout)
        self.resid_dropout = eqx.nn.Dropout(config.dropout)
        self._config = config
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # Has been made a buffer by using lax.stop_gradient whenever it is used.
        # Immutability calls for reshape, plus there is no view for jnp (or numpy) arrays.
        self.bias = jnp.tril(jnp.ones((config.block_size, config.block_size))).reshape(1, 1, config.block_size,
                                                                                       config.block_size)

    def __call__(self, x):
        T, C = jnp.shape(x)  # sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = jnp.split(jax.vmap(self.c_attn)(x), 3, axis=1)
        # Immutability calls for reshape, plus there is no view for jnp (or numpy) arrays.
        k = jnp.swapaxes(k.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1)  # (nh, T, hs)
        q = jnp.swapaxes(q.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1)  # (nh, T, hs)
        v = jnp.swapaxes(v.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1)  # (nh, T, hs)

        # manual implementation of attention
        att = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(jnp.shape(k)[-1])
        # Note: Added the stop_gradient just to be safe, I see no update rule acting on the bias inside the
        # forward pass.
        att = jnp.where(lax.stop_gradient(self.bias[:, :, :T, :T]) == 0, float('-inf'), att)
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att)
        y = jnp.matmul(att, v)  # (nh, T, T) x (nh, T, hs) -> (nh, T, hs)
        # Reshaping with Immutability creates a new copy
        y = jnp.swapaxes(y, 1, 2).reshape(T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(jax.vmap(self.c_proj)(y))
        return y
  • 我们已经重写了这个注意力模块的架构,使其在 __init__ 方法中看起来几乎相同,除了最后几行。
  • 在此模块中,以及随后的几个模块中,我们将 config 参数注册为类字段。这是一个特殊情况,我们正在注册一个不构成 NN 架构中层的字段。在这种情况下,必须使用 eqx.field(static=True) 将其设置为 Equinox 静态字段。
  • 在前向传播中,您会注意到我们将 B, T, C = x.size() 更改为 B, T, C = jnp.size(x)。这是一个重要的区别,它突出了 JAX 的函数式编程风格。在 PyTorch 中,像 x 这样的张量是具有可调用方法的对象,因此您可以直接在 x 上调用 size 方法。但在 JAX 中,数组作为参数传递给 jax.numpy 中的函数。当我们浏览代码时,请留意这种将数组传递给 JAX 函数的函数式模式。
Warning

需要注意的是,虽然 JAX 植根于函数式编程范式,并且通常需要将 JAX 数组作为参数传递给函数,而不是在数组对象上调用方法,但它确实为我们提供了某些作为数组方法的功能。一个典型的例子是 jax.numpy.transpose 函数,除了其在函数式编程中的传统用法外,也可以作为 JAX 数组的方法来调用。

  • 因此,关于 numpy 数组(以及扩展到 jax.numpy 数组)的重点是:它们不附带 view 方法。为了将我们的数组转换为下一步转换所需的形状,我们决定使用方便的 jnp.reshape 函数。
  • 在我们的实现中,我们跳过了闪速注意力部分,直接手动实现了注意力机制。您可能会注意到我们的方法与原始方法之间存在一些相似之处,除了我们正在使用 JAX 的函数式 API 之外。
    • 一个关键区别在于我们使用 jnp.matmul 函数执行矩阵乘法,取代了 @ 运算符。
    • 另一个需要注意的是,jnp.transpose 的工作方式与 torch.transpose 略有不同。在 JAX 中,您会希望使用 jnp.swapaxes 函数来实现与 PyTorch 相同的结果。

块模块

让我们仔细看看块模块,它是 Transformer 架构的关键组成部分。您会发现它使用了我们之前定义的大部分模块。需要注意的是,在原始的 PyTorch 版本中,nanoGPT 的作者为 LayerNorm 层中的 bias 参数传入了一个参数。如果您是 PyTorch 老手(或者只是查阅了文档),您可能会发现内置的 LayerNorm 模块实际上没有这个参数!作者从头开始实现了他们自己的自定义 LayerNorm,以支持这个可选的偏置功能。然而,在我们使用 Equinox 库重写时,内置的 LayerNorm 模块默认方便地包含了 bias 参数,因此我们可以直接使用它,而无需自定义实现。

PyTorch 版本:

# Code extracted from https://github.com/karpathy/nanoGPT/blob/master/model.py

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

Equinox 版本:

class Block(eqx.Module):
    ln_1: eqx.nn.LayerNorm
    attn: CausalSelfAttention
    ln_2: eqx.nn.LayerNorm
    mlp: MLP

    def __init__(self, config, key):
        ckey, mkey = jax.random.split(key, 2)

        self.ln_1 = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.attn = CausalSelfAttention(config, ckey)
        self.ln_2 = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.mlp = MLP(config, mkey)

    def __call__(self, x):
        x = x + self.attn(jax.vmap(self.ln_1)(x))
        x = x + self.mlp(jax.vmap(self.ln_2)(x))
        return x

GPT 模块

我们现在已经到达模型结构的最顶层。这个模块的原始版本有很多方法,不仅仅是构造函数(__init__)和 __call__ 方法。但是,我们为了简化并专注于我们决定在代码中实现的 JAXEquinox 部分,删减了大部分这些方法。

PyTorch 版本:

# Code extracted from https://github.com/karpathy/nanoGPT/blob/master/model.py

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

       return idx

Equinox 版本:

原始代码库将 Transformer 层定义为模块字典(来自 PyTorchModuleDict)。然而,由于在 Equinox 中,我们必须在构造函数之前将类的层定义为字段,因此无法按照原始结构组织代码。因此,为了简洁起见,我们将 Transformer 层提取到它自己的模块中,并将其命名为 TransformerLayer

TransformerLayer 模块
class TransformerLayer(eqx.Module):
    _config: GPTConfig = eqx.field(static=True)

    wte: eqx.nn.Embedding
    wpe: eqx.nn.Embedding
    drop: eqx.nn.Dropout
    h: list
    ln_f: eqx.nn.LayerNorm

    def __init__(self, config, key):
        ekey, pkey, hkey, fkey = jax.random.split(key, 4)

        assert config.vocab_size is not None
        assert config.block_size is not None
        self._config = config

        self.wte = eqx.nn.Embedding(config.vocab_size, config.n_embd, key=ekey)
        self.wpe = eqx.nn.Embedding(config.block_size, config.n_embd, key=pkey)
        self.drop = eqx.nn.Dropout(config.dropout)
        self.h = [Block(config, hkey) for _ in range(config.n_layer)]
        self.ln_f = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)

    def __call__(self, idx):
        t, = idx.shape
        assert t <= self._config.block_size, f"Cannot forward sequence of length {t}, block size is only {self._config.block_size}"
        pos = jnp.arange(0, t, dtype=jnp.int64)

        tok_emb = jax.vmap(self.wte)(idx)  # token embeddings of shape (t, n_embd)
        pos_emb = jax.vmap(self.wpe)(pos)  # position embeddings of shape (t, n_embd)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = jax.vmap(self.ln_f)(x)

        return x

我们希望读者注意这样一个事实:在前向传播的第一行,我们只能从输入中解包 token 维度的长度。这与 PyTorch 实现形成对比,在 PyTorch 实现中,批处理维度也已获取。这里的区别在于,我们不会处理一批输入,而是处理包含一系列 token 的单个输入。别担心!!! 当我们构建训练循环时,将在批处理维度上应用向量化映射,这将在以后变得清晰。

由于 Transformer 层位于单独的模块中,GPT 模块变得尽可能简单。我们将在下面向您展示 GPT 模块最精简的版本。

GPT 模块
class GPT(eqx.Module):
    _config: GPTConfig = eqx.field(static=True)

    transformer: TransformerLayer
    lm_head: eqx.nn.Linear

    def __init__(self, config, key):
        tkey, lmhkey = jax.random.split(key, 2)

        assert config.vocab_size is not None
        assert config.block_size is not None
        self._config = config

        self.transformer = TransformerLayer(config, tkey)

        self.lm_head = eqx.nn.Linear(config.n_embd, config.vocab_size, use_bias=False, key=lmhkey)

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(self, eqx.is_array)))
        if non_embedding:
            n_params -= sum(self.transformer.wpe.weight.shape)
        return n_params
    
    ## CODE STRIPPED FOR DEMONSTRATION
    
    def __call__(self, idx, train_mode=False):
        x = self.transformer(idx)

        if train_mode:
            logits = jax.vmap(self.lm_head)(x)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = jax.vmap(self.lm_head)(x[[-1], :])  # note: using list [-1] to preserve the time dim

        return logits

在我们的 GPT 模块的前向传播中,您可能会发现我们没有像 PyTorch 实现那样设计该方法以采用可选的 target 参数。在我们的版本中,我们在训练循环中计算损失。稍后会详细介绍。但是,在这种情况下,我们接受一个参数来确定调用前向传播的模式:训练模式或推理模式。因此,我们可以在推理时实现适当的逻辑,如原始仓库中所示。

现在,我们有必要向读者展示我们如何实现原始 GPT 模块中的其余逻辑。我们分情况处理此任务,为每个方法划分了不同的部分。对于每个方法,我们在此处也采用自底向上的方法,通过展示所有依赖项的实现并逐步向上。

我们首先在项目中定义一个 helper 包,以添加一些功能组件,帮助我们更快地实现 GPT 模块中的某些逻辑,更重要的是:抽象逻辑使其更接近 PyTorch。我们在 helper 模块中定义了两个单独的模块,如下所示

.
└── helpers/
    ├── eqx.py
    └── init.py
init.py
def normal_(array: jax.Array, mean: float, std: float, key: jax.random.PRNGKey = jax.random.PRNGKey(0)) -> None:
    new_array = jax.random.normal(key, array.shape) * std + mean
    return new_array


def zeros_(array: jax.Array) -> None:
    new_array = jax.numpy.zeros(array.shape)
    return new_array

虽然第二种方法本身就解释了其意图,但我们解释了第一种函数的意图。它旨在初始化一个输入 JAX 数组,使其具有给定标准差和均值的正态分布。这将在初始化 GPT 模块时派上用场。

eqx.py
def named_parameters(model: eqx.Module):
    out = []

    for path, p in jax.tree_util.tree_flatten_with_path(eqx.filter(model, eqx.is_array))[0]:
        pn = ''

        for index in range(len(path)):
            if isinstance(path[index], str):  # Check if path[index] is a string
                pn += '.' + path[index]
            else:
                pn += str(path[index])

        out.append((pn[1:], p))
    
    return out


def find_sub_tree(model: eqx.Module, sub_tree_name: str, filter_fn: Callable = None):
    out = []
    for path, p in jax.tree_util.tree_flatten_with_path(model, is_leaf=filter_fn)[0]:
        pn = ''
    
        for index in range(len(path)):
            if isinstance(path[index], jax._src.tree_util.DictKey):
                pn += '.' + path[index].key
            else:
                pn += str(path[index])
    
        if filter_fn:
            if filter_fn(p) and pn.endswith(sub_tree_name):
                out.append(p)
        elif pn.endswith(sub_tree_name):
            out.append(p)
    
    return out

在此模块中,第一个函数旨在复制 torch.Module 类中可用作方法的功能(请阅读此处)。它将任何 Equinox 模块作为参数,并返回一个元组列表,每个元组包含一个表示模型中参数路径的字符串和参数本身。

第二个函数可用于查找其全名以给定字符串结尾的参数。我们将在接下来的几节中看到这些函数如何派上用场。

回到 GPT 模块,重点关注 _init_weights 方法,您可能会注意到在 PyTorch 版本中,此方法用作 LinearEmbedding 层的自定义初始化器。如果您仔细查看构造函数,您还会发现此方法应用于模型后,还有另一段自定义初始化器逻辑。这段逻辑专门用于残差投影权重(c_proj.weight)。在我们的实现中,我们将所有这些初始化器逻辑组合到一个函数中,如下所示。

_init_weights GPT 方法
@staticmethod
def _init_weights(model: eqx.Module, config: GPTConfig, key: jax.random.PRNGKey):
    def init_layer(model, is_layer: Callable, mean: float, std: float):
        get_weights = lambda m: [x.weight
                                  for x in jax.tree_util.tree_leaves(m, is_leaf=is_layer)
                                  if is_layer(x)]
        weights = get_weights(model)

        new_weights = [init.normal_(weight, mean=mean, std=std, key=subkey)
                        for weight, subkey in zip(weights, jax.random.split(key, len(weights)))]

        return eqx.tree_at(get_weights, model, new_weights)

    def init_linear(model):
        is_linear = lambda x: isinstance(x, eqx.nn.Linear)

        model = init_layer(model, is_linear, mean=0.0, std=0.2)

        get_biases = lambda m: [x.bias
                                for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                                if is_linear(x) and x.bias is not None]
        biases = get_biases(model)

        new_biases = [init.zeros_(bias) for bias in biases]

        return eqx.tree_at(get_biases, model, new_biases)

    def init_embedding(model):
        is_embedding = lambda x: isinstance(x, eqx.nn.Embedding)

        return init_layer(model, is_embedding, mean=0.0, std=0.2)

    def init_c_proj_weights_with_normal(model):
        get_c_proj_weights = lambda m: eqx_helper.find_sub_tree(m, "c_proj.weight")

        old_weights = get_c_proj_weights(model)
        new_weights = [init.normal_(weight, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer), key=subkey)
                        for weight, subkey in zip(old_weights, jax.random.split(key, len(old_weights)))]

        return eqx.tree_at(get_c_proj_weights, model, new_weights)

    initialized_model = init_linear(model)
    initialized_model = init_embedding(initialized_model)
    # apply special scaled init to the residual projections, per GPT-2 paper
    initialized_model = init_c_proj_weights_with_normal(initialized_model)

    return initialized_model

我知道!你可能会好奇几行 PyTorch 代码是如何变成这样的。我向你保证,当我们把代码分解成更小的块来解释时,这会听起来很简单。但是,在继续之前,我们提醒读者 JAX 数组的不可变性。因此,对模型的任何更新都不能在其中完成,而应该作为新的 PyTree 返回。

init_layer 此函数被编写为一个抽象,允许初始化通过 is_layer 可调用函数过滤的任何层。它将使用由指定均值和标准差定义的正态分布中采样的值初始化输入模型中与过滤器匹配的层。

Icon

此代码只不过是 Equinox 文档中关于 自定义参数初始化 的代码的简单抽象(请阅读此处)。鼓励读者参考我们也在先决条件部分列出的此文档。

init_linear 在这里,我们简单地调用 init_layer 函数,并使用过滤器来识别模型中的 Linear 层,然后对返回的模型进行额外的初始化,将 Linear 层的偏差置为零。

init_embeddinginit_linear 函数非常相似。

init_c_proj_weights_with_normal 实现其名称所示的功能。c_proj.weights 用自定义正态分布初始化。

我们调用这些定义的函数并返回新的更新模型。然而,您可能已经注意到,尽管我们已经在 GPT 模块中定义了 _init_weights 方法,但它并未在构造函数中调用,因此在以传统方式创建实例时,它不会对模型进行必要的更新。为了实现这一点,我们创建了一个额外的静态方法,用于创建具有这些更新权重的 GPT 实例。

@staticmethod
def create_instance(config, key):
    key1, key2 = jax.random.split(key, 2)

    inst = GPT(config, key1)
    new_inst = GPT._init_weights(inst, config, key2)

    return new_inst
Warning

我们避免使用 _init_weight 来创建更新的实例,而是简单地替换 self 对象。相反,我们返回一个包含更新权重的新实例。

要创建一个新的 GPT 实例,我们只需调用 GPT.create_instance,而不是简单地调用 GPT。至此,model.py 文件中的最后一个方法也已实现。现在,我们将转向 train.py 文件,在那里我们将展示如何使用此模型从头开始预训练语言模型。

但首先,让我们尝试在下一节中理解 JAX 中的向量化映射如何工作。这个概念对于读者理解即将到来的章节中训练循环是如何构建的至关重要。


理解向量化映射 (vmap) 流程

在本博客的这一部分,我们打算分解输入数据的流向,以理解 vmap 在每个模块中如何从上到下工作。我们将使用一个松散引用的数学符号来简化事物。

输入到模型的将是批次(ℬ)的 token(𝒯),表示将用于预训练模型的文本。

Icon

此预训练数据可以是您选择的数据集,您可以按照 data 文件夹中的 prepare.py 脚本将它们构造成适合我们训练范式的数据。

因此输入将是一个形状为以下内容的 jnp 数组:

ℬ × 𝒯

由于我们将在训练脚本中将此输入传递给模型,我们将对第 0 维使用 vmap 转换。

jax.vmap(model, in_axes=(0, None))(x, True)

在上面的代码片段中,请记住我们必须为传递给 vmap 函数的每个参数定义批处理维度。因此,对于参数 x,我们分别指定第 0 维,对于第二个参数 True,我们指定 None 作为批处理维度。

现在,从高层次来看,GPT 模块的前向方法只接收一个 token 流(𝒯),并且批处理作为一系列独立函数并行执行。

然后,我们通过 self.transformer(idx) 将此 𝒯 传递给 Transformer。

Transformer 中的前两个 Embedding 层将接收一个标量值并将其转换为给定大小的嵌入向量。然而,我们正在尝试嵌入一个 token 流 𝒯,以获取与我们初始输入对应的嵌入 token 列表。因此,我们需要在第 0 维上对 idx 进行批处理,以便 Embedding 层将使用 𝒯 中的单个标量值进行调用。生成的数组将是 𝒯 × ℰ 大小,其中 ℰ 是嵌入维度的数量

位置嵌入也是如此。结果数组将通过 Block 模块。

Block 的前向传播中,需要对每个 token 的嵌入向量进行层归一化。也就是说,在这种情况下,token 维度充当批处理。我们对第 0 维应用 vmap。返回的数组与输入相同。

读者现在应该具备足够的经验来剖析 vmap 过程。因此,我们将其余的 vmap 探索留给读者作为练习。


重写 train.py

现在我们已经完成了模型的构建,可以开始编写训练脚本了。我们将重点关注导致训练过程的主要代码段,其余部分将不言自明。

get_batch

此函数将使用从数据文件夹中执行相关数据集脚本后获得的训练/验证集的预处理 bin 文件。在我们的实验中,我们为 tinystories 数据集执行了 prepare.py 文件。

在下面的函数中,我们随机检索指定大小的批量数据,其格式适合于预训练大型语言模型。

Icon

请注意,在此训练练习中,原始仓库打算使用 600,000 个批次来训练模型,这与通常的 epoch 约定不同。

def get_batch(split: str):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'validation.bin'), dtype=np.uint16, mode='r')

    ix = np.random.randint(len(data) - block_size, size=(batch_size,))
    x = jnp.stack([jnp.array(data[i:i + block_size], dtype=jnp.int64) for i in ix])
    y = jnp.stack([jnp.array(data[i + 1:i + 1 + block_size], dtype=jnp.int64) for i in ix])

    return x, y

convert_model_to_dtype

此函数用于将我们的模型(一个 PyTree)转换为指定的数据类型。请注意,我们使用的是全局定义的数据类型,并且也简单地覆盖了全局模型。我们在模型初始化为三种起始状态之一(从头开始、恢复或从 gpt-2)后调用此函数。

def convert_model_to_dtype():
    global model
    def convert_pytree_to_dtype(pytree, dtype):
        def _convert(leaf):
            if eqx.is_array(leaf):
                return leaf.astype(dtype)
            else:
                return leaf
    
        return jax.tree_util.tree_map(_convert, pytree)
    
    
    if dtype == 'bfloat16':
        model = convert_pytree_to_dtype(model, jnp.bfloat16)
    elif dtype == 'float16':
        model = convert_pytree_to_dtype(model, jnp.float16)
    elif dtype == 'float32':
        model = convert_pytree_to_dtype(model, jnp.float32)

lr_scheduler

我们定义了一个简单的余弦衰减学习率调度器,如下所示。decay_steps 的定义是为了当训练脚本旨在恢复训练过程时,调度器能了解剩余的步数以进行学习率衰减。

Warning

这种恢复调度器的方法在深度学习实践中并非最理想或最标准。然而,我们之所以采用如此基本粗略的逻辑,是因为我们在保存优化器状态(即学习率调度器)时遇到了一个未解决的错误。如果有一位好奇的读者能提供保存和恢复 Equinox 模型优化器状态的解决方案,我们将不胜感激。

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)

optimizer

我们使用 optax 定义了一个简单的 AdamW 优化器。我们还使用了 optax 包装器 inject_hyperparameters,以便我们能够访问根据调度器更新的当前学习率。

optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)

compute_loss

如果您还记得,在定义 GPT 模块的前向传播时,我们提到将在训练循环中计算损失。此损失计算被定义为一个函数,如所示。此函数使用 eqx.filter_jit 转换进行 JIT 编译,因为我们正在将 Equinox 模型传递给它。

@eqx.filter_jit
def compute_loss(model, x, y):
    logits = jax.vmap(model, in_axes=(0, None))(x, True)

    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, # B, T, C
        labels=y, # B, T
    )

    return jnp.mean(loss)

make_step

这是训练循环中每次迭代调用的顶层函数。此函数执行一系列关键步骤来训练模型。我们将尝试逐行分解它。

@eqx.filter_jit
def make_step(
        model,
        optimizer_state,
        x,
        y
):
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, model)
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss

第 1 行

我们之前编写的 compute_loss 函数首先通过 filter_value_and_grad 函数进行转换,该函数将为我们计算损失和梯度。在这里,我们方便地在一行中执行了前向和后向传播。

Icon

eqx.filter_value_and_grad 函数是 Equinoxjax.value_and_grad 转换的实现,以考虑模型中存在的非 JAX 数组。

第 2 行

利用计算出的梯度,我们用当前的优化器状态计算模型所需的更新。

第 3 行

计算出的更新现在应用于模型。这是实际采取的步骤,旨在减少受参数影响的模型损失。

第 4 行

更新后的模型、优化器状态以及步进前的损失被返回,以便从训练循环中访问。

estimate_loss

此函数用于根据训练设置确定的固定间隔计算训练和评估损失,并在训练循环中执行。

def estimate_loss(model):
    out = {}
    model = eqx.nn.inference_mode(model)
    for split in ['train', 'val']:
        losses = jnp.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            loss = compute_loss(model, jax.lax.stop_gradient(x), y)
            losses = losses.at[k].set(loss.item())
        out[split] = jnp.mean(losses)
    return out

训练循环

现在我们向您展示代码中实现的最简洁的训练循环版本。在初始化优化器状态后,我们每迭代一步。该循环也适用于恢复阶段。您可以查看我们项目中使用的日志步骤以获取更多视角。

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

for local_iter_num in range(iter_num, max_iters):
    x, y = get_batch("train")
    
    # do a training step
    model, optimizer_state, loss = make_step(model, optimizer_state, x, y)

保存模型

我们使用以下逻辑来保存模型参数和训练配置。我们再次鼓励读者参考我们的存储库,以获取此逻辑的完整实现。

checkpoint_params = {
    "model_args": gptconf,
    "iter_num": local_iter_num,
    "val_loss": losses["val"],
    "learning_rate": lr,
    "config": config,
}

checkpoint_file = os.path.join(out_dir, 'model.eqx')
checkpoint_params_file = os.path.join(out_dir, 'params.pkl')

eqx.tree_serialise_leaves(checkpoint_file, model)

with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)

结论

如果您已经读到这里,恭喜您对探索 JAXEquinox 的执着!在这篇博文中,我们采取了一种独特的方法来学习这些强大的框架,通过逐步重写著名的 nanoGPT 存储库。

在此过程中,我们遇到并克服了 JAX 的不可变特性和 PyTree 定义所特有的几个挑战。从重新构想模型架构到调整训练循环,每一步都帮助我们学习了如何有效利用 JAXEquinox 来完成复杂的深度学习任务。我们看到了如何:

  1. 实现自定义初始化。
  2. 将模型参数作为 PyTree 处理。
  3. 使用 Equinox 的过滤转换,例如 equinox.filter_jitequinox.filter_grad,来处理模型中的非数组对象。

我们探索了 JAX 的转换,特别是 vmap,以创建高效、并行化的代码,用于处理模型各层中的批量输入。Equinox 能够与 JAX 无缝集成,同时提供类似 PyTorch 的熟悉界面来构建神经网络,这被证明是无价的。值得注意的是,Equinox 的过滤转换对于将 JAX 强大的 JIT 编译和自动微分应用于我们的模型至关重要,正如我们在 compute_lossmake_step 函数中看到的那样。

这次重写不仅是一个学习练习,还展示了 JAXEquinox 在处理复杂深度学习模型方面的灵活性和强大功能。通过这个例子,我们希望您对这些框架有了更深入的理解,并对将它们应用于自己的项目更有信心。

总结时,请记住这仅仅是个开始。机器学习领域不断发展,而 JAXEquinox 等框架只是永无止境旅程中的一个停靠站。我们鼓励您继续探索、实验,并突破这些工具及其他工具所能达到的极限。

对于那些有兴趣深入研究的人,此项目的整个代码库都是开源的,可以在 https://github.com/surgeglobal/nanoJAXGPT 找到。我们希望此资源能成为您探索 JAXEquinox 的跳板。愿您的机器学习之旅充满激动人心的发现和开创性的创新!

repo-icon

surgeglobal/nanoJAXGPT

由 Surge Global 创建 • 2024 年 6 月 6 日更新

致谢

  • 感谢 Andrej Karpathy 精心设计的 nanoGPT 存储库,它帮助我们理解了 GPT 架构并贡献了其项目的 JAX/Equinox 版本。
  • 我们还要感谢 Anh Tong,他的 EquinoxnanoGPT 是我们独特重写的灵感来源。我们也推荐参考他的 nanoGPT 版本:https://github.com/anh-tong/nanoGPT-equinox
  • 感谢 JAX 团队提供的出色框架。
  • 感谢 Equinox 团队,让 JAX 用起来感觉像 PyTorch。
  • 感谢 Modal 团队为使无服务器 GPU 使用变得可访问且价格合理所做的努力。最重要的是,为您的每个工作区提供 30 美元的免费积分。
  • 本博客文章由 Icons8 提供的免费图标支持。

社区

注册登录 发表评论