使用 Flax 和 Optax 轻松实现 JAX 训练循环
在我之前的博客文章中,我以一种非典型的方式讨论了 JAX——一个用于高性能数值计算和机器学习的框架。**我没有创建一个训练循环**,只展示了几个看起来模糊地像机器学习的模式。如果你还没有读过那篇博客文章,可以在这里阅读。
这种方法是故意的,因为我觉得 JAX——尽管是为机器学习研究设计的——其通用性远超于此。使用它的步骤是:定义你想要发生的事情,用`jax.jit`包裹它,让 JAX 将你的函数追踪成一个中间图表示,然后将其传递给 XLA 进行编译和优化。结果是一个单一的、高度优化的二进制 blob,随时准备接收你的数据。这种方法自然适用于许多机器学习应用,以及其他科学计算任务。因此,只针对机器学习没有意义。而且,这方面已经有广泛的覆盖——我想对 JAX 的入门教程采取不同的视角。
在上一篇文章中,我提到完全有可能在纯 JAX 中开发一个完整的机器学习训练循环——包括模型、优化器等所有内容。这不言而喻,因为 JAX 是通用目的的。这是一个很好的练习,但不是我喜欢采用的策略。在这篇博客文章中,我想介绍两个建立在 JAX 之上的高级库,它们在编写机器学习应用程序时为我们完成了大量的繁重工作。这些库就是 **Flax** 和 **Optax**。
库总结如下
- **JAX**——提供一个**高级神经网络 API**,让开发者可以像在 PyTorch 中一样,以组件的方式而不是以需要参数作为输入的 JAX 函数来思考模型。
- **Optax** —— 一个包含大量模型训练实用工具的库,例如**优化器、损失函数、学习率调度器**等等!非常地开箱即用。
在这篇文章的末尾,我们将实现并训练一个非常简单的**类别条件图像生成模型**,称为**变分自编码器**(VAE),用于生成 MNIST 数字。
使用 Flax 的神经网络 API
纯 JAX 中训练循环的高级结构大致如下所示
dataset = ... # initialise training dataset that we can iterate over
params = ... # initialise trainable parameters of our model
epochs = ...
def model_forward(params, batch):
... # perform a forward pass of our model on `batch` using `params`
return outputs
def loss_fn(params, batch):
model_output = model_forward(params, batch)
loss = ... # compute a loss based on `batch` and `model_output`
return loss
@jax.jit
def train_step(params, batch):
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
grads = ... # transform `grads` (clipping, multiply by learning rate, etc.)
params = ... # update `params` using `grads` (such as via SGD)
return params, loss
for _ in range(epochs):
for batch in dataset:
params, loss = train_step(params, batch)
... # report metrics like loss, accuracy, and the like.
我们以函数式方式定义模型:一个以模型参数和批次作为输入,并返回模型输出的函数。类似地,我们定义损失函数,它也以参数和批次作为输入,但返回损失。
我们的最终函数是训练步骤本身,我们将其封装在 `jax.jit` 中——这为 XLA 提供了最大的上下文来编译和优化训练步骤。它首先使用函数转换 `jax.value_and_grad` 计算损失函数的梯度,然后操作返回的梯度(可能通过学习率进行缩放),并更新参数。我们返回新的参数,并在下一次调用 `train_step` 时使用它们。这在循环中进行调用,在每个训练步骤之前从数据集中获取新的批次。
大多数机器学习程序都遵循上述模式。但在像 PyTorch 这样的框架中,我们将模型前向传播和模型参数的管理封装到一个表示模型的状态对象中——从而简化了训练循环。如果我们可以在无状态的 JAX 中模仿这种行为,允许开发者以基于类的方式思考模型,那就太好了。这正是 Flax 的神经网络 API —— `flax.linen` —— 旨在实现的目标。
以纯粹无状态的函数式方式编写模型是否优于有状态的类式方式,这不是本篇博客文章的主题。两者都有其优点。**无论如何,在执行过程中,无论我们是否使用 Flax,最终结果都是相同的。我们得到一个无状态的、高度优化的二进制大对象,我们可以将数据抛给它。**毕竟,一切都是 JAX。
在 Flax 中定义模块主要有两种方式:一种是 PyTorch 风格,另一种是紧凑表示。
import flax.linen as nn
from typing import Callable
class Model(nn.Module):
dim: int
activation_fn: Callable = nn.relu
def setup(self):
self.layer = nn.Dense(self.dim)
def __call__(self, x):
x = self.layer(x)
return self.activation_fn(x)
class ModelCompact(nn.Module):
dim: int
activation_fn: Callable = nn.relu
@nn.compact
def __call__(self, x):
x = nn.Dense(self.dim)(x)
return self.activation_fn(x)
如果我们的初始化逻辑很复杂,前者可能更合适。反之,如果模块相对简单,我们可以利用 `nn.compact` 表示,仅通过前向传播就能自动定义模块。
与其他框架一样,我们可以将模块嵌套在彼此内部以实现复杂的模型行为。正如我们已经看到的,`flax.linen` 提供了一些预构建的模块,如 `nn.Dense`(与 PyTorch 的 `nn.Linear` 相同)。我不会一一列举所有模块,但卷积、嵌入等常用模块都已具备。
如果您正在将 PyTorch 模型移植到 Flax,请记住默认的权重初始化可能不同。例如,在 PyTorch 中,默认的偏置初始化是 LeCun normal,但在 Flax 中,它被初始化为零。
然而,目前我们无法调用此模型,即使我们初始化了类本身。根本没有参数可用。此外,模块永远不是参数的容器。**Flax 模块的实例只是一个空壳,它松散地将操作与参数和稍后作为输入传递的输入关联起来。**
为了理解我的意思,让我们为模型初始化一些参数
key = jax.random.PRNGKey(0xffff)
key, model_key = jax.random.split(key)
model = Model(dim=4)
params = model.init(model_key, jnp.zeros((1, 8)))
params
===
Out:
FrozenDict({
params: {
layer: {
kernel: Array([[-0.05412389, -0.28172645, -0.07438638, 0.5238516 ],
[-0.13562573, -0.17592733, 0.45305118, -0.0650041 ],
[ 0.25177842, 0.13981569, -0.41496065, -0.15681015],
[ 0.13783392, -0.6254694 , -0.09966562, -0.04283331],
[ 0.48194656, 0.07596914, 0.0429794 , -0.2127948 ],
[-0.6694777 , 0.15849823, -0.4057232 , 0.26767966],
[ 0.22948688, 0.00706845, 0.0145666 , -0.1280596 ],
[ 0.62309605, 0.12575962, -0.05112049, -0.316764 ]], dtype=float32),
bias: Array([0., 0., 0., 0.], dtype=float32),
},
},
})
在上面的单元格中,我们首先初始化了模型类,它返回了一个 `Model` 实例,并将其赋值给变量 `model`。正如我所说,它不包含任何参数,它只是一个我们传入参数和输入的空壳。我们可以通过打印 `model` 变量本身来看到这一点
model
===
Out: Model(
# attributes
dim = 4
activation_fn = relu
)
我们也可以直接调用模块本身,即使我们已经定义了 `__call__` 方法,但它会失败: ```python model(jnp.zeros((1, 8)))
输出:/usr/local/lib/python3.10/dist-packages/flax/linen/module.py 中的 getattr(self, name) 935 msg += (f' 如果在 ".setup()" 中定义了 "{name}",请记住这些字段 936 只能在 "init" 或 "apply" 内部访问。') --> 937 raise AttributeError(msg) 938 939 def dir(self) -> List[str]
AttributeError: "Model" 对象没有属性 "layer"。如果 "layer" 在 '.setup()' 中定义,请记住这些字段只能在 'init' 或 'apply' 内部访问。
To initialise the parameters, we passed a PRNG key and some dummy inputs to the
model's `init` function of the same shape and dtype as the inputs we will use
later. In this simple case, we just pass `x` as in the original module's
`__call__` definition, but could be multiple arrays, PyTrees, or PRNG keys. We
need the input shapes and dtypes in order to determine the shape and dtype of
the model parameters.
From the `model.init` call, we get a nested `FrozenDict` holding our model's
parameters. If you have seen PyTorch state dictionaries, the format of the
parameters is similar: nested dictionaries with meaningful named keys, with
parameter arrays as values. If you've read my previous blog post or read about
JAX before, you will know that this structure is a PyTree. Not only does Flax
help developers loosely associate parameters and operations, **it also helps
initialise model parameters based on the model definition**.
With the parameters, we can call the model using `model.apply` – providing the
parameters and inputs:
```python
key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 8))
y = model.apply(params, x)
y
===
Out: Array([[0.9296505 , 0.25998798, 0.01101626, 0. ]], dtype=float32)
`model.init` 返回的 PyTree 没有特别之处——它只是一个存储模型参数的常规 PyTree。`params` 可以与任何包含 `model` 期望参数的 PyTree 互换: ```python zero_params = jax.tree_map(jnp.zeros_like, params) # 生成一个与 `params` 结构相同,所有值都设置为 0 的 PyTree。 print(zero_params) model.apply(zero_params, x)
输出:FrozenDict({ params: { layer: { bias: Array([0., 0., 0., 0.], dtype=float32), kernel: Array([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], dtype=float32), }, }, })
数组([[0., 0., 0., 0.]], dtype=float32)
Forcing model calls to require explicitly passing parameters keeps it
stateless and returning parameters like any other PyTree, makes Flax
interoperable with JAX functions – as well as other libraries built on JAX.
**Essentially, by using Flax we aren't forced to use any other specific
frameworks and have access to all regular JAX features.**
If you are used to frameworks like PyTorch, calling models like this feels
unnatural at first. However, I personally quite like it this way – it feels
rather elegant to pass different parameters to the model to get different
behaviour rather than "load" the weights. A bit subjective and fuzzy, I know,
but I like it.
> To summarise the difference, if we aim to implement $f_\theta(x)$, a PyTorch
module is basically $f_\theta$ (which we can call on $x$). A Flax module is
simply $f$, which needs to be provided parameters $\theta$ before it can be
called on $x$ – or alternatively, we call $f$ on $(\theta, x)$.
All in all, the point of Flax is to **provide a familiar stateful API for
development** whilst **preserving JAX statelessness during runtime**. We can
build our neural network modules in terms of classes and objects, but **the
final result is a stateless function `model.apply` that takes in our inputs and
a PyTree of parameters.**
This is identical behaviour to what we began with (recall our `model_forward`
function at the start of this section), just now tied up nicely together.
Therefore, our function containing `model.apply` that takes as input our
PyTree, can be safely jit-compiled. The result is the same, a heavily-optimised
binary blob we bombard with data. Nothing changes during runtime, it just makes
development easier for those who prefer reasoning about neural networks in a
class-based way whilst remaining interoperable with, and keeping the
performance of JAX.
There's a lot more to Flax than this, especially outside the `flax.linen`
neural network API. For now though, we will move on to developing a full
training loop using Flax and **Optax**. We will swing back around to some extra
Flax points later, but I feel some concepts are hard to explain without first
showing a training loop.
## A full training loop with Optax and Flax
We've shown how to reduce the complexity of writing model code and parameter
initialisation. We can push this further by relying on Optax to handle the
gradient manipulation and parameter updates in `train_step`. For simple
optimisers, these steps can be quite simple. However, for more complex
optimisers or gradient transformation behaviour, it can get quite complex to
implement in JAX alone. Optax packages this complex behaviour into a simple
API.
```python
import optax
optimiser = optax.sgd(learning_rate=1e-3)
optimiser
===
Out: GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fa7185503a0>, update=<function chain.<locals>.update_fn at 0x7fa718550550>)
并不漂亮,但我们可以看到优化器只是一个**梯度变换**——实际上,Optax 中的所有优化器都实现为梯度变换。梯度变换被定义为一对函数 `init` 和 `update`,它们都是纯函数。与 Flax 模型一样,Optax 优化器内部不保留状态,在使用前必须进行初始化,并且开发人员必须将任何状态传递给 `update`: ```python optimiser_state = optimiser.init(params) optimiser_state
输出:(EmptyState(), EmptyState())
Of course, as SGD is a stateless optimiser, the initialisation call simply
returns an empty state. It must return this to maintain the API of a gradient
transformation. Let's try with a more complex optimiser like Adam:
```python
optimiser = optax.adam(learning_rate=1e-3)
optimiser_state = optimiser.init(params)
optimiser_state
===
Out: (ScaleByAdamState(count=Array(0, dtype=int32), mu=FrozenDict({
params: {
layer: {
bias: Array([0., 0., 0., 0.], dtype=float32),
kernel: Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
},
},
}), nu=FrozenDict({
params: {
layer: {
bias: Array([0., 0., 0., 0.], dtype=float32),
kernel: Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
},
},
})),
EmptyState())
这里,我们可以看到 Adam 优化器的一阶和二阶统计量,以及一个存储优化器更新次数的计数。与 SGD 一样,在调用 `update` 时需要将此状态传递给它。
像 Flax 参数一样,优化器状态也只是一个 PyTree。任何具有兼容结构的 PyTree 都可以使用。同样,这也可以实现与 JAX 和 `jax.jit` 以及其他基于 JAX 的库的互操作性。
具体来说,**Optax 梯度变换只是一个包含纯函数 `init` 和 `update` 的命名元组**。`init` 是一个纯函数,它接收一个要变换的梯度示例实例,并返回优化器初始状态。对于 `optax.sgd`,无论提供的示例是什么,它都会返回一个空状态。对于 `optax.adam`,我们得到一个更复杂的状态,其中包含与所提供示例具有相同 PyTree 结构的一阶和二阶统计量。
`update` 接收一个更新的 PyTree,其结构与提供给 `init` 的示例实例相同。此外,它还接收 `init` 返回的优化器状态,以及可选地接收模型本身的参数,某些优化器可能需要这些参数。此函数将返回变换后的梯度(**这可能是另一组梯度,或实际的参数更新**)和新的优化器状态。
这在文档此处有很好的解释。
提供一些虚拟数据,我们得到以下结果
import optax
params = jnp.array([0.0, 1.0, 2.0]) # some dummy parameters
optimiser = optax.adam(learning_rate=0.01)
opt_state = optimiser.init(params)
grads = jnp.array([4.0, 0.6, -3])# some dummy gradients
updates, opt_state = optimiser.update(grads, opt_state, params)
updates
===
Out: Array([-0.00999993, -0.00999993, 0.00999993], dtype=float32)
Optax 提供了一个辅助函数来将更新应用于我们的参数: ```python new_params = optax.apply_updates(params, updates) new_params
输出:Array([-0.00999993, 0.99000007, 2.01 ], dtype=float32)
It is important to emphasise that Optax optimisers are gradient transformations,
**but gradient transformations are not just optimisers.** We'll see more of that
later after we finish the training loop.
On that note, let's begin with said training loop. Recall that our goal is to
train a class-conditioned, variational autoencoder (VAE) on the MNIST dataset.
> I chose this example as it is slightly more interesting than the typical
classification example found in most tutorials.
Not strictly related to JAX, Flax, or Optax, but it is worth describing what a
VAE is. First, an autoencoder model is one that maps some input $x$ in our data
space to a **latent vector** $z$ in the **latent space** (a space with smaller
dimensionality than the data space) and back to the data space. It is trained to
minimise the reconstruction loss between the input and the output, essentially
learning the identity function through an **information bottleneck**.
The portion of the network that maps from the data space to the latent space is
called the **encoder** and the portion that maps from the latent space to the
data space is called the **decoder**. Applying the encoder is somewhat
analogous to lossy compression. Likewise, applying the decoder is akin to
lossy decompression.
What makes a VAE different to an autoencoder is that the encoder does not
output the latent vector directly. Instead, **it outputs the mean and
log-variance of a Gaussian distribution, which we then sample from in order
to obtain our latent vector**. We apply an extra loss term to make these mean and
log-variance outputs roughly follow the standard normal distribution.
> Interestingly, defining the encoder this way means for every given input $x$
we have many possible latent vectors which are sampled stochastically. Our
encoder is almost mapping to a sphere of possible latents centred at the mean
vector with radius scaling with log-variance.
The decoder is the same as before. However, now we can sample **a latent from
the normal distribution and pass it to the decoder in order to generate samples
like those in the dataset**! Adding the variational component turns our
autoencoder compression model into a VAE generative model.

> Abstract diagram of a VAE, pilfered from [this AWS blog](https://aws.amazon.com/blogs/machine-learning/deploying-variational-autoencoders-for-anomaly-detection-with-tensorflow-serving-on-amazon-sagemaker/)
Our goal is to implement the model code for the VAE as well as the training loop
with both the reconstruction and variational loss terms. Then, we can sample new
digits that look like those in the MNIST dataset! Additionally, we will provide
an extra input to the model – the class index – so we can control which number
we want to generate.
Let's begin by defining our configuration. For this educational example, we will
just define some constants in a cell:
```python
batch_size = 16
latent_dim = 32
kl_weight = 0.5
num_classes = 10
seed = 0xffff
以及一些导入和 PRNG 初始化
import jax # install correct wheel for accelerator you want to use
import flax
import optax
import orbax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike
from typing import Tuple, Callable
from math import sqrt
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
key = jax.random.PRNGKey(seed)
我们顺便获取 MNIST 数据集
train_dataset = MNIST('data', train = True, transform=T.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
JAX、Flax 和 Optax 不具备数据加载实用工具,所以我在这里直接使用了 PyTorch 实现的 MNIST 数据集,它完全可用。
现在来看我们的第一个真正的 Flax 模型。我们首先定义一个子模块 `FeedForward`,它实现了一堆带中间非线性层的线性层。
class FeedForward(nn.Module):
dimensions: Tuple[int] = (256, 128, 64)
activation_fn: Callable = nn.relu
drop_last_activation: bool = False
@nn.compact
def __call__(self, x: ArrayLike) -> ArrayLike:
for i, d in enumerate(self.dimensions):
x = nn.Dense(d)(x)
if i != len(self.dimensions) - 1 or not self.drop_last_activation:
x = self.activation_fn(x)
return x
key, model_key = jax.random.split(key)
model = FeedForward(dimensions = (4, 2, 1), drop_last_activation = True)
print(model)
params = model.init(model_key, jnp.zeros((1, 8)))
print(params)
key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 8))
y = model.apply(params, x)
y
===
Out:
FeedForward(
# attributes
dimensions = (4, 2, 1)
activation_fn = relu
drop_last_activation = True
)
FrozenDict({
params: {
Dense_0: {
kernel: Array([[ 0.0840368 , -0.18825287, 0.49946404, -0.4610112 ],
[ 0.4370267 , 0.21035315, -0.19604324, 0.39427406],
[ 0.00632685, -0.02732705, 0.16799504, -0.44181877],
[ 0.26044282, 0.42476758, -0.14758752, -0.29886967],
[-0.57811564, -0.18126923, -0.19411889, -0.10860331],
[-0.20605426, -0.16065307, -0.3016759 , 0.44704655],
[ 0.35531637, -0.14256613, 0.13841921, 0.11269159],
[-0.430825 , -0.0171169 , -0.52949774, 0.4862139 ]], dtype=float32),
bias: Array([0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: Array([[ 0.03389561, -0.00805947],
[ 0.47362345, 0.37944487],
[ 0.41766328, -0.15580587],
[ 0.5538078 , 0.18003668]], dtype=float32),
bias: Array([0., 0.], dtype=float32),
},
Dense_2: {
kernel: Array([[ 1.175035 ],
[-1.1607001]], dtype=float32),
bias: Array([0.], dtype=float32),
},
},
})
Array([[0.5336972]], dtype=float32)
这里我们使用了 `nn.compact` 装饰器,因为逻辑相对简单。我们遍历元组 `self.dimensions`,并通过 `nn.Dense` 模块传递当前的激活,然后应用 `self.activation_fn`。对于 `FeedForward` 中的最后一个线性层,此激活函数可以选择性地省略。这是因为 `nn.relu` 只输出非负值,而有时我们需要非负输出!
使用 `FeedForward`,我们可以定义完整的 VAE 模型
class VAE(nn.Module):
encoder_dimensions: Tuple[int] = (256, 128, 64)
decoder_dimensions: Tuple[int] = (128, 256, 784)
latent_dim: int = 4
activation_fn: Callable = nn.relu
def setup(self):
self.encoder = FeedForward(self.encoder_dimensions, self.activation_fn)
self.pre_latent_proj = nn.Dense(self.latent_dim * 2)
self.post_latent_proj = nn.Dense(self.encoder_dimensions[-1])
self.class_proj = nn.Dense(self.encoder_dimensions[-1])
self.decoder = FeedForward(self.decoder_dimensions, self.activation_fn, drop_last_activation=False)
def reparam(self, mean: ArrayLike, logvar: ArrayLike, key: jax.random.PRNGKey) -> ArrayLike:
std = jnp.exp(logvar * 0.5)
eps = jax.random.normal(key, mean.shape)
return eps * std + mean
def encode(self, x: ArrayLike):
x = self.encoder(x)
mean, logvar = jnp.split(self.pre_latent_proj(x), 2, axis=-1)
return mean, logvar
def decode(self, x: ArrayLike, c: ArrayLike):
x = self.post_latent_proj(x)
x = x + self.class_proj(c)
x = self.decoder(x)
return x
def __call__(
self, x: ArrayLike, c: ArrayLike, key: jax.random.PRNGKey) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
mean, logvar = self.encode(x)
z = self.reparam(mean, logvar, key)
y = self.decode(z, c)
return y, mean, logvar
key = jax.random.PRNGKey(0x1234)
key, model_key = jax.random.split(key)
model = VAE(latent_dim=4)
print(model)
key, call_key = jax.random.split(key)
params = model.init(model_key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), call_key)
recon, mean, logvar = model.apply(params, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), call_key)
recon.shape, mean.shape, logvar.shape
===
Out:
ClassVAE(
# attributes
encoder_dimensions = (256, 128, 64)
decoder_dimensions = (128, 256, 784)
latent_dim = 4
activation_fn = relu
)
((16, 784), (16, 4), (16, 4))
以上单元格内容很多。了解该模型的工作原理细节对于理解后续的训练循环并不重要,因为我们可以将该模型视为一个黑盒子。只需替换您选择的模型即可。话虽如此,我将简要解释每个函数:
- `setup`:创建网络的子模块,即两个 `FeedForward` 堆栈和两个 `nn.Linear` 层,用于向/从潜在空间进行投影。此外,它还初始化第三个 `nn.Linear` 层,将我们的类条件向量投影到与最后一个编码器层相同的维度。
- `reparam`:直接从随机高斯分布采样潜在变量是不可微分的,因此我们采用**重参数化技巧**。这涉及采样一个随机向量,按标准差缩放,然后加到均值上。由于它涉及随机数组生成,除了均值和对数方差之外,我们还将其作为输入。
- `encode`:将编码器和投影应用于输入到潜在空间。请注意,投影的输出实际上是潜在空间大小的两倍,因为我们将其一分为二以获得均值和对数方差。
- `decode`:从潜在空间向 `x` 应用投影,然后将 `class_proj` 在条件向量上的输出添加进去。这就是我们将类别信息注入模型的方式。最后,它通过解码器堆栈传递结果。
- `__call__`:这仅仅是完整的模型前向传播:`encode` 然后 `reparam` 然后 `decode`。这在训练期间使用。
上述例子还表明,除了 `setup` 和 `__call__` 之外,我们还可以向 Flax 模块添加其他函数。这对于更复杂的行为或我们只想执行模型部分功能(稍后会详细介绍)时非常有用。
现在我们有了模型、优化器和数据集。下一步是编写实现训练步骤的函数,然后对其进行 jit 编译。
def create_train_step(key, model, optimiser):
params = model.init(key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), jax.random.PRNGKey(0)) # dummy key just as example input
opt_state = optimiser.init(params)
def loss_fn(params, x, c, key):
reduce_dims = list(range(1, len(x.shape)))
c = jax.nn.one_hot(c, num_classes) # one hot encode the class index
recon, mean, logvar = model.apply(params, x, c, key)
mse_loss = optax.l2_loss(recon, x).sum(axis=reduce_dims).mean()
kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean ** 2 - jnp.exp(logvar), axis=reduce_dims)) # KL loss term to keep encoder output close to standard normal distribution.
loss = mse_loss + kl_weight * kl_loss
return loss, (mse_loss, kl_loss)
@jax.jit
def train_step(params, opt_state, x, c, key):
losses, grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, c, key)
loss, (mse_loss, kl_loss) = losses
updates, opt_state = optimiser.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss, mse_loss, kl_loss
return train_step, params, opt_state
这里,我没有直接定义训练步骤,而是定义了一个函数,该函数根据目标模型和优化器返回训练步骤函数,并返回新初始化的参数和优化器状态。
让我们把它全部拆解
- 首先,它使用一个示例输入初始化我们的模型。在这种情况下,这是一个 784 维数组,其中包含(展平的)MNIST 数字和一个随机、随机的密钥。
- 还使用我们刚刚初始化的参数初始化了优化器状态。
- 现在,它定义了损失函数。这仅仅是一个 `model.apply` 调用,它返回模型的输入重建,以及预测的均值和对数方差。然后我们计算均方误差损失和 KL 散度,最后计算加权和以得到最终损失。KL 损失项的作用是使编码器输出接近标准正态分布。
- 接下来是实际的训练步骤定义。它首先使用我们的老朋友 `jax.value_and_grad` 对 `loss_fn` 进行转换,该函数将返回损失和梯度。我们必须设置 `has_aux=True`,因为我们返回所有单个损失项以用于日志记录。我们将梯度、优化器状态和参数提供给 `optimiser.update`,它返回转换后的梯度和新的优化器状态。然后将转换后的梯度应用于参数。最后,我们返回新的参数、优化器状态和损失项——然后将整个过程封装在 `jax.jit` 中。呼……
生成训练步骤的函数只是我非常喜欢的一种模式,没有什么可以阻止你直接编写训练步骤。
让我们调用 `create_train_step`
key, model_key = jax.random.split(key)
model = VAE(latent_dim=latent_dim)
optimiser = optax.adamw(learning_rate=1e-4)
train_step, params, opt_state = create_train_step(model_key, model, optimiser)
当我们调用上述函数时,我们将得到一个准备好编译的 `train_step`,它能以极快的速度接受我们的参数、优化器状态和数据。与 jit 编译函数一样,第一次使用给定输入形状的调用会很慢,但在后续调用中会很快,因为我们跳过了编译和优化过程。
现在我们可以编写训练循环并训练模型了!
freq = 100
for epoch in range(10):
total_loss, total_mse, total_kl = 0.0, 0.0, 0.0
for i, (batch, c) in enumerate(train_loader):
key, subkey = jax.random.split(key)
batch = batch.numpy().reshape(batch_size, 784)
c = c.numpy()
params, opt_state, loss, mse_loss, kl_loss = train_step(params, opt_state, batch, c, subkey)
total_loss += loss
total_mse += mse_loss
total_kl += kl_loss
if i > 0 and not i % freq:
print(f"epoch {epoch} | step {i} | loss: {total_loss / freq} ~ mse: {total_mse / freq}. kl: {total_kl / freq}")
total_loss = 0.
total_mse, total_kl = 0.0, 0.0
===
Out:
epoch 0 | step 100 | loss: 49.439998626708984 ~ mse: 49.060447692871094. kl: 0.7591156363487244
epoch 0 | step 200 | loss: 37.1823616027832 ~ mse: 36.82903289794922. kl: 0.7066375613212585
epoch 0 | step 300 | loss: 33.82365036010742 ~ mse: 33.49456024169922. kl: 0.6581906080245972
epoch 0 | step 400 | loss: 31.904821395874023 ~ mse: 31.570871353149414. kl: 0.6679074764251709
epoch 0 | step 500 | loss: 31.095705032348633 ~ mse: 30.763246536254883. kl: 0.6649144887924194
epoch 0 | step 600 | loss: 29.771989822387695 ~ mse: 29.42426872253418. kl: 0.6954278349876404
...
epoch 9 | step 3100 | loss: 14.035745620727539 ~ mse: 10.833460807800293. kl: 6.404574871063232
epoch 9 | step 3200 | loss: 14.31241226196289 ~ mse: 11.043667793273926. kl: 6.53748893737793
epoch 9 | step 3300 | loss: 14.26440143585205 ~ mse: 11.01070785522461. kl: 6.5073771476745605
epoch 9 | step 3400 | loss: 13.96005630493164 ~ mse: 10.816412925720215. kl: 6.28728723526001
epoch 9 | step 3500 | loss: 14.166285514831543 ~ mse: 10.919700622558594. kl: 6.493169784545898
epoch 9 | step 3600 | loss: 13.819541931152344 ~ mse: 10.632755279541016. kl: 6.373570919036865
epoch 9 | step 3700 | loss: 14.452215194702148 ~ mse: 11.186063766479492. kl: 6.532294750213623
现在我们有了 `train_step` 函数,训练循环本身只是重复地获取数据,调用我们超快的 `train_step` 函数,并记录结果以便我们跟踪训练。我们可以看到损失正在下降,这意味着我们的模型正在训练!
请注意,KL 损失项在训练期间**增加**。只要它不会太高,这是可以的,否则从模型中采样将变得不可能。调整超参数 `kl_weight` 非常重要。太低会导致完美的重建但没有采样能力——太高则会导致输出变得模糊。
让我们从模型中进行采样,这样我们就可以看到它确实产生了一些合理的样本
def build_sample_fn(model, params):
@jax.jit
def sample_fn(z: jnp.array, c: jnp.array) -> jnp.array:
return model.apply(params, z, c, method=model.decode)
return sample_fn
sample_fn = build_sample_fn(model, params)
num_samples = 100
h, w = 10
key, z_key = jax.random.split(key)
z = jax.random.normal(z_key, (num_samples, latent_dim))
c = np.repeat(np.arange(h)[:, np.newaxis], w, axis=-1).flatten()
c = jax.nn.one_hot(c, num_classes)
sample = sample_fn(z, c)
z.shape, c.shape, sample.shape
===
Out: ((100, 32), (100, 10), (100, 784))
上面的单元格生成了 100 个样本——来自 10 个类别中的每个类别的 10 个示例。我们将采样函数 jit 编译,以防以后我们想再次采样。我们只调用 `model.decode` 方法,而不是整个模型,因为我们只需要解码我们随机采样的潜在变量。这通过在 `model.apply` 调用中指定 `method=model.decode` 来实现。
让我们使用 matplotlib 可视化结果
import matplotlib.pyplot as plt
import math
from numpy import einsum
sample = einsum('ikjl', np.asarray(sample).reshape(h, w, 28, 28)).reshape(28*h, 28*w)
plt.imshow(sample, cmap='gray')
plt.show()
看来我们的模型确实经过了训练,可以进行采样!此外,模型能够使用类别条件信号,以便我们可以控制生成哪些数字。因此,我们成功地使用 Flax 和 Optax 构建了一个完整的训练循环!
Flax 和 Optax 的额外小技巧
我想通过强调一些可能在您自己的应用程序中很有用的有趣和有用的功能来结束这篇博客文章。我不会深入探讨它们的任何细节,而只是简单地总结并为您指明正确的方向。
您可能已经注意到,当我们将参数、优化器状态和许多其他指标添加到 `train_step` 的返回调用中时,处理所有状态会变得有点麻烦。如果我们以后需要更复杂的状态,情况可能会变得更糟。一个解决方案是返回一个 `namedtuple`,这样我们至少可以将状态打包在一起。然而,Flax 提供了自己的解决方案,`flax.training.train_state.TrainState`,它有一些额外的函数,可以更轻松地更新组合状态(模型和优化器状态)。
最简单的方法是简单地使用我们之前的 `train_step`,并使用 `TrainState` 进行重构。
from flax.training.train_state import TrainState
def create_train_step(key, model, optimiser):
params = model.init(key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), jax.random.PRNGKey(0))
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimiser)
def loss_fn(state, x, c, key):
reduce_dims = list(range(1, len(x.shape)))
c = jax.nn.one_hot(c, num_classes)
recon, mean, logvar = state.apply_fn(state.params, x, c, key)
mse_loss = optax.l2_loss(recon, x).sum(axis=reduce_dims).mean()
kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean ** 2 - jnp.exp(logvar), axis=reduce_dims))
loss = mse_loss + kl_weight * kl_loss
return loss, (mse_loss, kl_loss)
@jax.jit
def train_step(state, x, c, key):
losses, grads = jax.value_and_grad(loss_fn, has_aux=True)(state, x, c, key)
loss, (mse_loss, kl_loss) = losses
state = state.apply_gradients(grads=grads)
return state, loss, mse_loss, kl_loss
return train_step, state
我们像以前一样在 `create_train_step` 中初始化参数。然而,下一步是使用 `TrainState.create` 创建状态,并传入模型前向调用、已初始化的参数以及我们想要使用的优化器。在内部,`TrainState.create` 将为我们初始化并存储优化器状态。
在 `loss_fn` 中,我们不再调用 `model.apply`,而是使用 `state.apply_fn`。两种方法是等效的,只是有时我们可能不在 `model` 的作用域内,因此无法访问 `model.apply`。
`train_step` 本身的变化最大。不再是先调用 `optimiser.update`,再调用 `optax.apply_updates`,我们直接调用 `state.apply_gradients`,它在内部更新优化器状态和参数。然后它返回新状态,我们将其返回并传递给 `train_step` 的下一次调用——就像我们处理 `params` 和 `opt_state` 一样。
可以通过继承 `TrainState` 来向其添加额外属性,例如添加属性来存储最新的损失。
总之,`TrainState` 使在训练循环中传递状态变得更容易,并抽象了优化器和参数更新。
Flax 的另一个有用功能是能够将参数**绑定**到模型,从而产生一个可以直接调用的交互式实例,就像一个带有内部状态的 PyTorch 模型一样。然而,这种状态是静态的,只有当我们再次绑定它时才能改变,这使得它无法用于训练。但是,它对于交互式调试或推理非常方便。
API 非常简单
key, model_key = jax.random.split(key)
model = nn.Dense(2)
params = model.init(model_key, jnp.zeros(8))
bound_model = model.bind(params)
bound_model(jnp.ones(8))
===
Out: Array([ 0.45935923, -0.691003 ], dtype=float32)
我们可以通过调用 `model.unbind` 来获取未绑定的模型及其参数: ```python bound_model.unbind()
输出:(Dense( # 属性 features = 2 use_bias = True dtype = None param_dtype = float32 precision = None kernel_init = init bias_init = zeros dot_general = dot_general ), FrozenDict({ params: { kernel: Array([[-0.11450272, -0.2808447 ], [-0.45104247, -0.3774913 ], [ 0.07462895, 0.3622056 ], [ 0.59189916, -0.34050766], [-0.10401642, -0.36226135], [ 0.157985 , 0.00198693], [-0.00792678, -0.1142673 ], [ 0.31233454, 0.4201768 ]], dtype=float32), bias: Array([0., 0.], dtype=float32), }, }))
I said I wouldn't enumerate layers in Flax as I don't see much value in doing
so, but I will highlight two particularly interesting ones. First is
`nn.Dropout` which is numerically the same as its PyTorch counterpart, but like
anything random in JAX, requires a PRNG key as input.
The dropout layer takes its random key by internally calling
`self.make_rng('dropout')`, which pulls and splits from a PRNG stream named
`'dropout'`. This means when we call `model.apply` we will need to define the
starting key for this PRNG stream. This can be done by passing a dictionary
mapping stream names to PRNG keys, to the `rngs` argument in `model.apply`:
```python
key, x_key = jax.random.split(key)
key, drop_key = jax.random.split(key)
x = jax.random.normal(x_key, (3,3))
model = nn.Dropout(0.5, deterministic=False)
y = model.apply({}, x, rngs={'dropout': drop_key}) # there is no state, just pass empty dictionary :)
x, y
===
Out: (Array([[ 1.7353934, -1.741734 , -1.3312583],
[-1.615281 , -0.6381292, 1.3057163],
[ 1.2640097, -1.986926 , 1.7818599]], dtype=float32),
Array([[ 3.4707868, 0. , -2.6625166],
[ 0. , 0. , 2.6114326],
[ 0. , -3.973852 , 0. ]], dtype=float32))
`model.init` 也接受一个 PRNG 键字典。如果您像我们目前所做的那样传入单个键,它将启动一个名为 `'params'` 的流。这相当于传入 `{'params': rng}`。
子模块可以访问流,因此无论 `nn.Dropout` 在模型中的哪个位置,它都可以调用 `self.make_rng('dropout')`。我们可以通过在 `model.apply` 调用中指定 PRNG 流来定义自己的 PRNG 流。在我们的 VAE 示例中,我们可以放弃手动传入键,而是使用 `self.make_rng('noise')` 或类似方法获取用于随机采样的键,然后在 `model.apply` 的 `rngs` 中传入一个起始键。对于具有大量随机性的模型,这样做可能值得。
第二个有用的内置模块是 `nn.Sequential`,它也类似于 PyTorch 的对应模块。它简单地将许多模块链接在一起,使得一个模块的输出流入下一个模块的输入。如果我们想快速定义大量堆叠的层,这很有用。
现在来谈谈 Optax 的一些小技巧!首先,Optax 带有一堆学习率调度器。除了在创建优化器时将浮点值传递给 `learning_rate`,我们还可以传递一个调度器。在应用更新时,Optax 会自动选择正确的学习率。让我们定义一个简单的线性调度器。
start_lr, end_lr = 1e-3, 1e-5
steps = 10_000
lr_scheduler = optax.linear_schedule(
init_value=start_lr,
end_value=end_lr,
transition_steps=steps,
)
optimiser = optax.adam(learning_rate=lr_scheduler)
你可以使用 `optax.join_schedules` 将调度器连接在一起,以获得更复杂的行为,例如学习率预热后接衰减。
warmup_start_lr, warmup_steps = 1e-6, 1000
start_lr, end_lr, steps = 1e-2, 1e-5, 10_000
lr_scheduler = optax.join_schedules(
[
optax.linear_schedule(
warmup_start_lr,
start_lr,
warmup_steps,
),
optax.linear_schedule(
start_lr,
end_lr,
steps - warmup_steps,
),
],
[warmup_steps],
)
optimiser = optax.adam(lr_scheduler)
`optax.join_schedules` 的最后一个参数应该是一个整数序列,定义不同调度器之间的步长边界。在本例中,我们在 `warmup_steps` 步后从预热切换到衰减。
Optax 在其 `opt_state` 中跟踪优化器步数,因此我们不需要自己跟踪。它将使用此计数自动选择正确的学习率。
与连接调度器类似,Optax 支持链式连接优化器。更具体地说,是梯度变换的链式连接。
optimiser = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(1e-2),
)
调用 `optimiser.update` 时,梯度将首先被裁剪,然后进行常规的 Adam 更新。像这样将变换链式连接起来是一种非常优雅的 API,可以实现复杂的行为。例如,在 PyTorch 中添加更新的指数移动平均(EMA)是非平凡的,而在 Optax 中,它就像将 `optax.ema` 添加到我们的 `optax.chain` 调用中一样简单。
optimiser = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(1e-2),
optax.ema(decay=0.999)
)
在这种情况下,`optax.ema` 是对最终更新的变换,而不是对未处理梯度的变换。
梯度累积在 Optax 中实现为优化器包装器,而不是梯度变换。
grad_accum = 4
optimiser = optax.MultiSteps(optax.adam(1e-2), grad_accum)
返回的优化器在 `optimiser.update` 调用期间收集更新,直到发生 `grad_accum` 步。在中间步中,返回的更新将是与 `params` 形状相同的零 PyTree,导致没有更新。每 `grad_accum` 步,将返回累积的更新。
`grad_accum` 也可以是一个函数,这为我们提供了一种通过调整参数更新之间的步数来在训练期间改变批量大小的方法。
如果我们只想训练某些参数呢?例如,在微调预训练模型时。如今,这是一种非常常见的做法,即获取预训练的大型语言模型并将其调整以适应特定的下游任务。
让我们从 Huggingface Hub 获取一个预训练的 BERT 模型: ```python from transformers import FlaxBertForSequenceClassification model = FlaxBertForSequenceClassification.from_pretrained('bert-base-uncased') model.params.keys()
输出:dict_keys(['bert', 'classifier'])
> Huggingface provides Flax versions of *most* of their models. The API to use
them is a bit different, calling `model(**inputs, params=params)` rather than
`model.apply`. Providing no parameters will use the pretrained weights stored
in `model.params` which is useful for inference-only tasks, but for training we
need to pass the current parameters to the call.
We can see there are two top-level keys in the parameter PyTree: `bert` and
`classifier`. Suppose we only want to finetune the classifier head and leave the
BERT backbone alone, we can achieve this using `optax.multi_transform`:
```python
optimiser = optax.multi_transform({'train': optax.adam(1e-3), 'freeze': optax.set_to_zero()}, {'bert': 'freeze', 'classifier': 'train'})
opt_state = optimiser.init(model.params)
grads = jax.tree_map(jnp.ones_like, model.params)
updates, opt_state = optimiser.update(grads, opt_state, model.params)
`optax.multi_transform` 接收两个输入,第一个是从标签到梯度变换的映射。第二个是一个 PyTree,其结构或前缀与更新相同(在上面的例子中我们使用前缀方法),映射到标签。与给定更新的标签匹配的变换将被应用。这允许参数的分区和对不同部分应用不同的更新。
第二个参数也可以是一个函数,该函数给定更新 PyTree,返回一个将更新(或其前缀)映射到标签的 PyTree。
这可用于其他情况,例如为不同的层使用不同的优化器(例如,为某些层禁用权重衰减),但在我们的情况下,我们仅将 `optax.adam` 用于可训练参数,并使用无状态变换 `optax.set_to_zero` 将其他区域的梯度置零。
在 jit 编译函数中,由于优化过程发现它们将始终为零,因此对它们应用了 `optax.set_to_zero` 的梯度将不会被计算。因此,我们从仅微调部分层中获得了预期的内存节省!
让我们打印更新,以便我们看到 BERT 主干中确实没有更新,而分类器头中有更新: ```python updates['classifier'], updates['bert']['embeddings']['token_type_embeddings']
输出:{'bias': Array([-0.00100002, -0.00100002], dtype=float32), 'kernel': Array([[-0.00100002, -0.00100002], [-0.00100002, -0.00100002], [-0.00100002, -0.00100002], ..., [-0.00100002, -0.00100002], [-0.00100002, -0.00100002], [-0.00100002, -0.00100002]], dtype=float32)} {'embedding': Array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}
We can verify that all updates are zero using `jax.tree_util.tree_reduce`:
```python
jax.tree_util.tree_reduce(lambda c, p: c and (jnp.count_nonzero(p) == 0), updates['bert'], True)
===
Out: Array(True, dtype=bool)
尽管 JAX 生态系统相对年轻,但 Flax 和 Optax 都功能丰富。我建议您打开 Flax 或 Optax API 参考,然后搜索您习惯在其他框架中使用的层、优化器、损失函数和功能。
我想谈的最后一件事涉及一个完全不同的基于 JAX 的库。**Orbax** 提供了 PyTree 检查点实用程序,用于保存和恢复任意 PyTree。我不会深入探讨细节,但这里将展示基本用法。没有什么比花费数小时训练却发现忘记添加检查点代码更糟糕的了!
以下是保存 BERT 分类器参数的基本用法
import orbax
import orbax.checkpoint
from flax.training import orbax_utils
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(model.params['classifier'])
orbax_checkpointer.save('classifier.ckpt', model.params['classifier'], save_args=save_args)
!ls
===
Out: classifier.ckpt
我们可以通过执行以下命令恢复: ```python orbax_checkpointer.restore('classifier.ckpt')
输出:{'bias': array([0., 0.], dtype=float32), 'kernel': array([[-0.06871808, -0.06338844], [-0.03397266, 0.00899913], [-0.00669084, -0.06431466], ..., [-0.02699363, -0.03812294], [-0.00148801, 0.01149782], [-0.01051403, -0.00801195]], dtype=float32)}
Which returns the raw PyTree. If you are using a custom dataclass with objects
that can't be serialised (such as a Flax train state where `apply_fn` and `tx`
can't be serialised) you can pass an example PyTree to `item` in the `restore`
call, to let Orbax know the structure you want.
Manually saving checkpoints like this is a bit old-fashioned. Orbax has a bunch
of automatic versioning and scheduling features built in, such as automatic
deleting of old checkpoints, tracking the best metric, and more. To use these
features, wrap the `orbax_checkpointer` in
`orbax.checkpoint.CheckpointManager`:
```python
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
'managed-checkpoint', orbax_checkpointer, options)
for step in range(10):
checkpoint_manager.save(step, model.params['classifier'], save_kwargs={'save_args': save_args})
!ls -l managed-checkpoint/*
===
Out:
managed-checkpoint/6:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/7:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/8:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/9:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
由于我们设置了 `max_to_keep=4`,因此只保留了最后四个检查点。
我们可以查看哪些步骤有检查点: ```python checkpoint_manager.all_steps()
输出:[6, 7, 8, 9]
As well as view if there is a checkpoint for a specific step:
```python
checkpoint_manager.should_save(6)
===
Out: False
以及最新保存的步骤是: ```python checkpoint_manager.latest_step()
输出:9
We can restore using the checkpoint manager. Rather than provide a path to the
`restore` function, we provide the step we want to restore:
```python
step = checkpoint_manager.latest_step()
checkpoint_manager.restore(step)
===
Out: {'bias': array([0., 0.], dtype=float32),
'kernel': array([[-0.06871808, -0.06338844],
[-0.03397266, 0.00899913],
[-0.00669084, -0.06431466],
...,
[-0.02699363, -0.03812294],
[-0.00148801, 0.01149782],
[-0.01051403, -0.00801195]], dtype=float32)}
对于特别大的检查点,Orbax 支持异步检查点,它将检查点移动到后台线程。您可以通过将 `orbax.checkpoint.AsyncCheckpointer` 包装在我们之前创建的 `orbax.checkpoint.PyTreeCheckpointer` 外部来实现这一点。
您可能会在网上看到 Flax 检查点实用程序的引用。然而,这些实用程序正在被弃用,建议改用 Orbax。
Orbax 的文档有点简略,但它有很多选项可供选择。值得阅读 `CheckpointManagerOptions` 类这里,并查看可用功能。
结论
在这篇博客文章中,我介绍了两个基于 JAX 的库:Flax 和 Optax。这更多是一篇关于如何使用这些库轻松实现 JAX 训练循环的实用指南,而不是像我之前关于 JAX 的博客文章那样进行意识形态讨论。
本文总结如下
- Flax 提供了一个神经网络 API,允许开发者以基于类的方式构建神经网络模块。与其他框架不同,这些模块内部不包含状态,本质上是松散地将函数与参数和输入关联起来的空壳,并提供简单的参数初始化方法。
- Optax 提供了一套庞大的优化器,用于更新我们的参数。这些优化器,就像 Flax 模块一样,不包含状态,必须手动传入状态。所有优化器都只是梯度变换:一对纯函数 `init` 和 `update`。Optax 还提供了其他梯度变换和包装器,以实现更复杂的行为,例如梯度裁剪和参数冻结。
- 这两个库都只操作并返回 PyTree,并且可以轻松地与基础 JAX 互操作——尤其与 `jax.jit`。这也使它们能够与基于 JAX 的其他库互操作。例如,通过选择 Flax,我们不会被限制使用 Optax,反之亦然。
这两个库的细节远不止本文所述,但我希望这是一个良好的开端,能帮助您在 JAX 中创建自己的训练循环。现在一个很好的练习是使用本文中的训练循环和模型代码,并将其调整以适应您自己的任务,例如另一个生成模型。
如果您喜欢这篇文章,请考虑在Twitter上关注我,或访问我的网站,阅读更多关于机器学习和其他主题的随笔。感谢您读到这里,希望您觉得有用!
致谢和额外资源
一些很好的额外资源
Flax 的一些替代品
我不知道 Optax 有相对成熟的替代品。如果你知道一些,请告诉我!
发现这篇博文有问题?请通过电子邮件或 Twitter 告诉我!