学习 JAX — 一个用于高性能机器学习的框架
最近,我参加了 Huggingface x Google Cloud 社区冲刺,尽管它被命名为 ControlNet 冲刺,但其范围非常广泛:涉及扩散模型,使用 JAX,并免费使用 Google Cloud 提供的 TPU。在相对短的时间内,诞生了许多很酷的项目。
我们的项目雄心勃勃:将我的硕士论文工作——结合 逐步展开去噪自编码器(与离散扩散模型大致相关)与 VQ-GAN,全部移植到 JAX,然后添加对文本条件生成支持。有了这个新模型,我们将从头开始训练一个新的文本到图像模型,就像 Dalle-mini 那样。
有趣的是,Dalle-mini 正是诞生于之前的 Huggingface 社区冲刺。这些社区活动可以产生许多出色的项目!
我原始论文中的无条件结果。
不幸的是,我们未能实现最终目标,因为代码中某个地方存在一个细微的错误。除了漂亮的颜色,我们从模型中没有得到太多成果。我真希望能展示一些有趣的输出,但尽管团队尽了最大努力,我们还是时间不够了。你可以在这里找到我们项目的 JAX 代码。尽管结果令人失望,但我仍然很高兴我参与了,因为我学到了很多东西。
我们社区冲刺模型的“非常棒”的样本。
为了社区冲刺做准备,我通过遵循 Aleksa Gordic 的一个出色教程系列,开始了 JAX 的学习。Aleksa 在视频开头就提到他自己也正在学习 JAX。毫无疑问,他现在已经更好了,但我仍然深受这种态度的启发:在自己学习的同时分享和教学。因此,我决定在冲刺结束后,我会秉承这种精神,分享我在学习 JAX 两个月后的所知。所以,我们现在在这里了。
虽然可以单独使用 JAX 实现所有功能——包括手动实现优化器和模型——但这并不是我喜欢的方法。在冲刺期间,我们大量使用了基于 JAX 构建的库,例如 Flax 和 Optax。自己尝试做所有事情当然很有价值,但如果你只是想开始,同样值得直接使用更高级别的框架,而且我觉得许多现有教程已经涵盖了从头开始的工作。
话虽如此,在这篇特定的博客中,我将只涵盖 JAX 本身——将使用高级库创建完整训练循环留到以后的文章。我最初尝试在一个单元中涵盖所有内容,但篇幅过长,难以处理。即使现在,只涵盖 JAX,这篇帖子也已经很长。我将其称之为 JAX 的深入探讨与入门介绍之间。我跳过了框架的某些部分,同时深入探讨了我认为理解起来很重要的概念。
要理解这篇文章,您应该对 Python 和数组操作库(如 NumPy)有一定的经验——机器学习经验会有帮助,但不是必需的。我希望它能成为进入 JAX 生态系统的一个良好切入点,并为那些经验更丰富的人提供一些独特的视角。
如果您对从头开始实现所有功能感到好奇,我建议您查看前面提到的 Aleksa Gordic 的教程,或者这里的官方教程。
闲话少说……
基本用法“几乎”和 NumPy 一样
JAX 是由 Google 开发并开源的高性能机器学习研究和数值计算框架。有人说这个名字来源于其三个核心组件的结合:即时(Just-in-time)编译、自动微分(Autodiff)和 XLA。JAX 的原始论文中提到它代表“Just After eXecution”(执行之后)。当我分享这个小知识时,似乎没有人对此感到困扰。
JAX 的一大吸引力在于它与 NumPy 共享相似的 API,但可以在 GPU 和 TPU 等快速加速器上执行,同时代码与加速器无关。熟悉的 API 也有助于工程师快速掌握 JAX——或者至少让他们入门。此外,与其他“可用于机器学习”的框架(如 PyTorch 和 Tensorflow)相比,它对多设备并行提供了非常好的内置支持。
尽管它的确旨在支持机器学习研究,但在我看来,它对机器学习的偏向较弱,更容易应用于其他领域。这有点类似于 NumPy,它是一个通用的数组操作库,因为它足够通用,可以做任何事情。然而,我认为您使用 JAX 的方式与 NumPy 非常不同,尽管最初看起来相似。
具体来说,如果说 NumPy 是关于逐个操作地操作数组,那么 JAX 则是关于定义操作和输入的计算图,并让编译器对其进行优化。换句话说,就是定义你想要发生的事情,然后让 JAX 承担起使其快速运行的繁重工作。在 NumPy 中,开发者需要通过调用快速且高度优化的函数并尽可能避免慢速的 Python 环境来优化一切。这种额外的负担确实赋予了比严格的 JAX 更大的灵活性。然而,在许多机器学习应用中,我们并不需要这种灵活性。
够了,意识形态的咆哮,让我们看看友好的 JAX Numpy API,首先初始化几个数组。
import jax
import jax.numpy as jnp
import numpy as np
L = [0, 1, 2, 3]
x_np = np.array(L, dtype=np.int32)
x_jnp = jnp.array(L, dtype=jnp.int32)
x_np, x_jnp
===
Out: (Array([0, 1, 2, 3], dtype=int32), array([0, 1, 2, 3], dtype=int32))
请注意,您在较旧的教程中可能会看到
import jax.numpy as np
这行代码。这已不再是惯例,之前这样做的建议将成为人类历史上的一个污点。
是不是惊人的相似?jax.numpy
接口与 numpy
接口非常相似,这意味着我们几乎可以用 jax.numpy
中类似的功能来完成 numpy
中的所有操作。
x1 = x_jnp*2
x2 = x_jnp+1
x3 = x1 + x2
x1, x2, x3
===
Out: (Array([0, 2, 4, 6], dtype=int32),
Array([1, 2, 3, 4], dtype=int32),
Array([ 1, 4, 7, 10], dtype=int32))
jnp.dot(x1, x2), jnp.outer(x1, x2)
===
Out: (Array(40, dtype=int32),
Array([[ 0, 0, 0, 0],
[ 2, 4, 6, 8],
[ 4, 8, 12, 16],
[ 6, 12, 18, 24]], dtype=int32))
如果您之前使用过 NumPy,所有这些都应该很熟悉。我不会通过枚举函数来让您感到厌烦——那是文档的用途。
第一个有趣的区别是 JAX 处理随机性的方式。在 NumPy 中,要从均匀分布生成随机数组,我们只需执行:```python random_np = np.random.random((5,)) random_np
输出:array([0.58337985, 0.87832186, 0.08315021, 0.16689551, 0.50940328])
In JAX it works differently. A key concept in JAX is that functions in it are
**pure**. This means that given the same input they will always return the same
output, and do not modify any global state from within the function. Using
random number generation that modifies some global psuedorandom number generator
(PRNG) clearly violates both principles. Therefore, we have to handle randomness
in a stateless way by manually passing around the PRNG key and splitting it to
create new random seeds. This has the added benefit of making randomness in code
more reproducible – ignoring accelerator-side stochasticity – as in JAX we are
forced to handle fixed seeds by default. Let's see what that looks like:
```python
seed = 0x123456789 # some integer seed. In hexadecimal just to be ✨✨
key = jax.random.PRNGKey(seed) # create the initial key
key, subkey = jax.random.split(key) # split the key
random_jnp = jax.random.uniform(subkey, (5,)) # use `subkey` to generate, `key` can be split into more subkeys later.
random_jnp
===
Out: Array([0.2918682 , 0.90834665, 0.13555491, 0.08107758, 0.9746183 ], dtype=float32)
如果希望每个随机操作都产生不同的输出,请务必不要重复使用相同的键:```python jax.random.normal(key, (2,)), jax.random.normal(key, (2,))
输出:(Array([-0.67039955, 0.02259737], dtype=float32), Array([-0.67039955, 0.02259737], dtype=float32))
You may be pleased to know that if we want to generate `N` random arrays, we don't need to call `jax.random.split` in a loop `N` times. Pass the number of keys you want to the function:
```python
key, *subkeys = jax.random.split(key, 5)
[jax.random.normal(s, (2,2)) for s in subkeys]
===
Out: [Array([[ 1.0308125 , -0.07533383],
[-0.36027843, -1.270425 ]], dtype=float32),
Array([[ 0.34779412, -0.11094793],
[ 1.0509511 , 0.52164143]], dtype=float32),
Array([[ 1.5565109 , -0.9507161 ],
[ 1.4706124 , 0.25808835]], dtype=float32),
Array([[-0.5725152 , -1.1480215 ],
[-0.6206856 , -0.12488112]], dtype=float32)]
另一个小区别是 JAX 不允许原地操作:```python x1[0] = 5
输出:TypeError 回溯(最近一次调用在最后)
在 <cell line: 1>() 中 ----> 1 x1[0] = 5
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in _unimplemented_setitem(self, i, x) 261 “或另一个 .at[] 方法:” 262 “https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html”) --> 263 raise TypeError(msg.format(type(self))) 264 265 def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array
TypeError: ‘<class 'jaxlib.xla_extension.ArrayImpl'>’ 对象不支持项赋值。JAX 数组是不可变的。代替 x[idx] = y
,使用 x = x.at[idx].set(y)
或其他 .at[] 方法:https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html 正如错误消息所说,JAX 数组是**不可变的**,因此同样的问题也适用于其他原地操作,例如 +=
、*=
等。同样,正如错误消息所说,我们可以使用 JAX 数组上的 at
属性来执行功能上纯净的等效操作。这会返回新数组,但将其设置为旧变量在数值上等同于真正的原地版本。```python x1_p999 = x1.at[0].add(999) x1, x1_p999
输出:(Array([0, 2, 4, 6], dtype=int32), Array([999, 2, 4, 6], dtype=int32))
> Applying `x1 += 5` and similar *does* work, but under the Python hood this is
just `x1 = x1 + 5` anyway. It just creates a new array and hence is still
immutable.
JAX functions also only accept NumPy or JAX array inputs. This is in contrast
to NumPy that will happily accept Python lists. JAX chooses to throw an error to
avoid silent degradation in performance.
One final difference is that out of bounds indexing does not raise an error. This is because raising an error from code running on an accelerator is difficult and our goal with "accelerated NumPy" is to use accelerators. This is similar to how invalid floating point arithmetic results in NaN values, rather than simply erroring.
When indexing to retrieve a value out of bounds, JAX will instead just clamp the
index to the bounds of the array:
```python
x1[0], x1[-1], x1[10000]
===
Out: (Array(0, dtype=int32), Array(6, dtype=int32), Array(6, dtype=int32))
当索引更新超出边界的值(例如使用 .at
属性)时,更新将被简单忽略:```python x1 = x1.at[10000].set(999) x1
输出:Array([0, 2, 4, 6], dtype=int32)
All somewhat interesting, but so far there isn't a great deal of pull towards
JAX over NumPy. It gets more concerning when we start timing the functions:
```python
x1_np, x2_np = np.asarray(x1), np.asarray(x2)
%timeit x1_np @ x2_np
%timeit (x1 @ x2).block_until_ready()
===
Out: 1.17 µs ± 6.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
7.27 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
上述乘法的 JAX 版本大约慢 6-7 倍,这是为什么?
基准测试需要
block_until_ready
函数。通常,JAX 在操作完成之前不会将控制权返回给 Python。它异步地调度到加速器。因此,返回时间可能快于实际计算时间,导致基准测试不准确。
这又回到了 NumPy 旨在以逐操作(或即时)方式进行数组操作,而 JAX 旨在定义图并让编译器为您优化的观点。通过像 NumPy 一样即时执行 JAX 函数,我们没有留下优化空间,并且由于 JAX 调度操作的额外开销,我们得到了一个较慢的函数。直截了当地说,如果您这样使用 JAX,那您就做错了。
那么,我们如何让 JAX 运行得快呢?通过利用 XLA 的强大功能。
进入 jax.jit
前面函数如此慢的原因是 JAX 每次只向加速器分派一个操作。JAX 的预期用法是使用 XLA 将多个操作——理想情况下是几乎所有操作——一起编译。为了指示要一起编译的区域,我们可以将要编译的函数传递给 jax.jit
函数或使用 @jax.jit
装饰器。该函数不会立即编译,而是在首次调用时编译——因此得名“即时编译”。
在首次调用期间,输入数组的形状将用于追踪计算图,使用 Python 解释器逐步执行函数并逐个执行操作,并在图中记录发生的情况。此中间表示可以提供给 XLA,随后进行编译、优化和缓存。如果使用相同的输入数组形状和数据类型调用相同的函数,将检索此缓存,从而跳过追踪和编译过程,直接调用高度优化的预编译二进制大对象。
让我们看看它的实际应用。
def fn(W, b, x):
return x @ W + b
key, w_key, b_key, x_key = jax.random.split(key, 4)
W = jax.random.normal(w_key, (4, 2)),
b = jax.random.uniform(b_key, (2,))
x = jax.random.normal(x_key, (4,))
print("`fn` time")
%timeit fn(W, b, x).block_until_ready()
print("`jax.jit(fn)` first call time")
jit_fn = jax.jit(fn)
%time jit_fn(W, b, x).block_until_ready()
print("`jit_fn` time")
%timeit jit_fn(W, b, x).block_until_ready()
===
Out:
`fn` time
26.1 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
`jit_fn` first call (warmup) time
CPU times: user 35.8 ms, sys: 38 µs, total: 35.9 ms
Wall time: 36.3 ms
`jit_fn` time
7.62 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
正如预期,第一次调用将比后续调用花费更长的时间。因此,务必将第一次调用排除在任何基准测试之外。我们还发现,即使对于这个简单的示例,编译后的函数也比原始函数执行得快得多。
我们可以通过在函数上调用 jax.make_jaxpr
来将追踪到的图视为 jaxpr
:```python jax.make_jaxpr(fn)(params, x)
输出:{ lambda ; a:f32[4,2] b:f32[2] c:f32[4]. let d:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] c a e:f32[2] = add d b in (e,) }
And also the compiled version of the function, albeit hard to read:
```python
print(jax.jit(fn).lower(params, x).compile().as_text())
===
HloModule jit_fn, entry_computation_layout={(f32[4,2]{1,0},f32[2]{0},f32[4]{0})->f32[2]{0}}, allow_spmd_sharding_propagation_to_output={true}
%fused_computation (param_0.1: f32[2], param_1.1: f32[4], param_2: f32[4,2]) -> f32[2] {
%param_1.1 = f32[4]{0} parameter(1)
%param_2 = f32[4,2]{1,0} parameter(2)
%dot.0 = f32[2]{0} dot(f32[4]{0} %param_1.1, f32[4,2]{1,0} %param_2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_name="jit(fn)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="<ipython-input-4-04cd19da0726>" source_line=2}
%param_0.1 = f32[2]{0} parameter(0)
ROOT %add.0 = f32[2]{0} add(f32[2]{0} %dot.0, f32[2]{0} %param_0.1), metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-4-04cd19da0726>" source_line=2}
}
ENTRY %main.6 (Arg_0.1: f32[4,2], Arg_1.2: f32[2], Arg_2.3: f32[4]) -> f32[2] {
%Arg_1.2 = f32[2]{0} parameter(1), sharding={replicated}
%Arg_2.3 = f32[4]{0} parameter(2), sharding={replicated}
%Arg_0.1 = f32[4,2]{1,0} parameter(0), sharding={replicated}
ROOT %fusion = f32[2]{0} fusion(f32[2]{0} %Arg_1.2, f32[4]{0} %Arg_2.3, f32[4,2]{1,0} %Arg_0.1), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-4-04cd19da0726>" source_line=2}
}
下面是一个更明确、更愚蠢的例子。
def stupid_fn(x):
y = jnp.copy(x)
for _ in range(1000):
x = x * x
return y
print("`stupid_fn` time")
%time stupid_fn(x).block_until_ready()
print("`jit_stupid_fn` first call")
jit_stupid_fn = jax.jit(stupid_fn)
%time jit_stupid_fn(x).block_until_ready()
print("`jit_stupid_fn` time")
%timeit jit_stupid_fn(x).block_until_ready()
===
Out:
`stupid_fn` time
CPU times: user 58.6 ms, sys: 1.06 ms, total: 59.7 ms
Wall time: 81.9 ms
`jit_stupid_fn` first call
CPU times: user 666 ms, sys: 13.9 ms, total: 680 ms
Wall time: 800 ms
`jit_stupid_fn` time
8.72 µs ± 735 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
在函数中,它将输入 x
复制到变量 y
,然后将输入自身乘以 1000 次。最后,它只是返回 y
,使这些乘法完全没有意义。在非 jit 版本中,程序会愉快且无意义地执行乘法。无知是福。
首次调用 jit 函数时,JAX 仍会逐步执行所有乘法以追踪计算图。然而,后续调用所使用的编译版本将飞速运行,因为 XLA 发现这些乘法对于获取最终输出而言并非必需,并将其优化掉。我们可以通过打印 jaxpr
来实际查看这一点:```python jax.make_jaxpr(stupid_fn)(x)
输出:{ lambda ; a:f32[4]. let b:f32[4] = copy a c:f32[4] = mul a a d:f32[4] = mul c c e:f32[4] = mul d d f:f32[4] = mul e e ... bmh:f32[4] = mul bmg bmg bmi:f32[4] = mul bmh bmh bmj:f32[4] = mul bmi bmi bmk:f32[4] = mul bmj bmj bml:f32[4] = mul bmk bmk bmm:f32[4] = mul bml bml _:f32[4] = mul bmm bmm in (b,) } 这显示了所有 1000 次乘法(相信我!)。与编译版本进行比较:
python print(jax.jit(stupid_fn).lower(x).compile().as_text())
输出:HloModule jit_stupid_fn, entry_computation_layout={(f32[4]{0})->f32[4]{0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.2 (Arg_0.1: f32[4]) -> f32[4] { %Arg_0.1 = f32[4]{0} parameter(0), sharding={replicated} ROOT %copy = f32[4]{0} copy(f32[4]{0} %Arg_0.1) }
Which contains only a single copy operation. Experiment with the above code
blocks yourself by changing the number of iterations in the loop. You will find
that the time to execute the original function will increase with number of
iterations, along with the time to trace the graph on first call to the jit
function. However, the time to execute the compiled version on subsequent calls
will not increase in a meaningful way.
The above is a contrived example, but demonstrates a critical point: **we can
let XLA do a lot of the heavy lifting for us optimisation-wise.** This is
different to other frameworks that execute eagerly, where it would
happily execute extremely pointless code. This isn't a fault of the
framework as eager execution has many benefits, but demonstrates the
point that compiling our functions using XLA can help optimise our code in ways
we didn't know about, or could reasonably anticipate.
What exact optimisations XLA applies is a topic outside the scope of this blog.
One quick example is that the earlier statement about JAX arrays not allowing
in-place operations results in no potential performance loss. This is because
XLA can identify cases where it can replace operations with in-place
equivalents. So basically don't sweat it if you were worried earlier about not
being able to do operations in-place!
Secondly, in order to let XLA be the best it can be, **`jax.jit` should be used
in the widest possible context**. For example, (again contrived) if we had only
jit compiled the multiplications in `stupid_fn`, XLA would be unaware that the
outermost loop was unnecessary and could not optimise it out – it is simply
outside the region to be compiled. A concrete machine learning example would be
wrapping the entire training step – forward, backwards and optimiser step – in
`jax.jit` for maximum effect.
Most machine learning applications can be expressed in this way: one monolithic
compiled function that we throw data and model parameters at. It just might take
some massaging. In the original JAX paper, they say "The design of JAX is
informed by the observation that ML workloads are typically dominated by **PSC
(pure-and-statically-composed) subroutines**" which lends itself well to this
compilation process. Even functions that seemingly cannot have static input
shapes can be converted into a static form, for example padding sequences in
language modeling tasks or rewriting our functions in clever ways.
Although eager mode execution is useful for development work, once
development is done there is less benefit to eager execution over heavily
optimised binary blobs, hungry for our data. However, said compilation and optimisations rely on following the rules of JAX.
## JIT needs static shapes
The biggest blocker to jit compiling functions is that **all arrays need to have
static shapes**. That is to say, given the **shapes** and shapes alone of the
function inputs, it should be possible to determine the shape of all other
variables in the traced graph at compile time.
Take for example the following function, that given an integer `length` returns
an array filled with the value `val`:
```python
def create_filled(val, length):
return jnp.full((length,), val)
print(create_filled(1.0, 5))
print(create_filled(2, 2))
jit_create_filled = jax.jit(create_filled)
jit_create_filled(2, 5)
===
Out: [1. 1. 1. 1. 1.]
[2 2]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-13-0ecd13642388> in <cell line: 8>()
6
7 jit_create_filled = jax.jit(create_filled)
----> 8 jit_create_filled(2, 5)
[... skipping hidden 12 frame]
3 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in canonicalize_shape(shape, context)
2037 except TypeError:
2038 pass
-> 2039 raise _invalid_shape_error(shape, context)
2040
2041 def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJAXprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function create_filled at <ipython-input-13-0ecd13642388>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
在即时执行中,函数返回我们预期的结果。然而,当追踪函数的 jit 版本时,会出现错误。这是因为在追踪时,jnp.full
函数会接收一个追踪器数组,该数组只包含有关形状和 dtype 的信息——而不包含用于确定形状的值。因此,无法追踪输出数组,因为在编译时形状是未知的。
我们可以通过使用 jax.jit
的一个参数 static_argnums
来解决这个问题。这指定了哪些参数不进行追踪,简单地在编译时将其视为常规 Python 值。在 jaxpr
图中,我们 Python 级别函数的 length
参数本质上变成了图中的一个常量。
jit_create_filled = jax.jit(create_filled, static_argnums=(1,))
print(jit_create_filled(2, 5))
print(jit_create_filled(1., 10))
print(jax.make_jaxpr(create_filled, static_argnums=(1,))(2, 5))
print(jax.make_jaxpr(create_filled, static_argnums=(1,))(1.6, 10))
===
Out: [2 2 2 2 2]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
{ lambda ; a:i32[]. let
b:i32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] a
in (b,) }
{ lambda ; a:f32[]. let
b:f32[10] = broadcast_in_dim[broadcast_dimensions=() shape=(10,)] a
in (b,) }
由于形状现在是图中的常量,每次向函数传递不同的 length
时,它都会重新编译。因此,这种方法只有在 length
的可能值非常有限的情况下才真正有效,否则我们将不断编译不同的图。
毫无疑问,尽管 Python 级别的函数是相同的,但针对不同静态输入调用的底层二进制文件是完全不同的。我们基本上将缓存从匹配函数和输入形状,转变为匹配函数、输入形状以及静态参数的值。
现在换一个例子:我们定义一个函数,它接受一个输入数组 x
和一个与 x
形状相同的布尔掩码 mask
,并返回一个新数组,其中被掩码的位置设置为一个较大的负数。
def mask_tensor(x, mask):
x = x.at[mask].set(-100.)
return x
key, x_key, mask_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (4,4))
mask = jax.random.uniform(mask_key, (4,4)) < 0.5
print("calling eager function")
print(mask_tensor(x, mask))
print("calling compiled function")
jit_mask_tensor = jax.jit(mask_tensor)
jit_mask_tensor(x, mask)
===
Out: calling eager function
[[-3.8728207e-01 -1.3147168e+00 -2.2046556e+00 4.1792620e-02]
[-1.0000000e+02 -1.0000000e+02 -8.2206033e-02 -1.0000000e+02]
[ 2.1814612e-01 9.6735013e-01 1.3497342e+00 -1.0000000e+02]
[-8.7061942e-01 -1.0000000e+02 -1.0000000e+02 -1.0000000e+02]]
calling compiled function
---------------------------------------------------------------------------
NonConcreteBooleanIndexError Traceback (most recent call last)
<ipython-input-23-2daf7923c05b> in <cell line: 14>()
12 print("calling compiled function")
13 jit_mask_tensor = jax.jit(mask_tensor)
---> 14 jit_mask_tensor(x, mask)
[... skipping hidden 12 frame]
5 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _expand_bool_indices(idx, shape)
4297 if not type(abstract_i) is ConcreteArray:
4298 # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 4299 raise errors.NonConcreteBooleanIndexError(abstract_i)
4300 elif _ndim(i) == 0:
4301 raise TypeError("JAX arrays do not support boolean scalar indices")
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4,4])
以即时模式执行该函数按预期工作。然而,中间变量的形状仅凭输入形状的知识是无法知道的,因为它取决于 mask
中为 True
的元素数量。因此,我们无法编译该函数,因为并非所有形状都是静态的。
此外,我们不能使用 static_argnum
,因为 mask
本身不可哈希,因此不能用于匹配对缓存二进制文件的调用。此外,即使我们可以,mask
的可能值也太高了。为了处理所有可能性,我们需要编译 2**16
或 65,536 个图。
然而,通常我们可以重写函数以执行相同的操作,并且在所有步骤中都具有已知的形状。
def mask_tensor(x, mask):
x = ~mask * x - mask*100.
return x
print("calling eager function")
print(mask_tensor(x, mask))
print("calling compiled function")
jit_mask_tensor = jax.jit(mask_tensor)
print(jit_mask_tensor(x, mask))
===
calling eager function
[[ 1.012518 -100. -0.8887863 -100. ]
[-100. -100. -100. 1.5008001 ]
[-100. -0.6636745 0.57624763 -0.94975847]
[ 1.1513114 -100. 0.88873196 -100. ]]
calling compiled function
[[ 1.012518 -100. -0.8887863 -100. ]
[-100. -100. -100. 1.5008001 ]
[-100. -0.6636745 0.57624763 -0.94975847]
[ 1.1513114 -100. 0.88873196 -100. ]]
所有中间形状都将在编译时已知。具体来说,当 mask
为 True 时,我们将 x
乘以零;当 mask
为 False 时,我们将 x
乘以一。然后,我们添加一个新数组,其中当 mask
为 False 时为零,当 mask
为 True 时为 -100
。此时,我们有两个具有具体形状的数组。将它们相加会得到正确的结果,该结果同样是具体的。
限制可能的输入形状数量
一个相关但可以“有点”jit 编译的情况是,形状可以在编译时确定,但输入形状变化很大。由于我们通过查看被调用的函数和输入的形状来检索缓存的编译函数,这将导致大量的编译。这很有道理,因为图本身是针对特定静态形状进行优化的,但这会导致无声的减速。
import random
def cube(x):
return x*x*x
def random_shape_test(fn):
length = random.randint(1, 1000)
return fn(jnp.empty((length,)))
print("random length eager time:")
%timeit -n1000 random_shape_test(cube).block_until_ready()
jit_cube = jax.jit(cube)
jit_cube(x1)
print("fixed length compiled time:")
%timeit -n1000 jit_cube(x1).block_until_ready()
print("random length compiled time:")
%timeit -n1000 random_shape_test(jit_cube).block_until_ready()
===
Out:
random length eager time:
The slowest run took 43.13 times longer than the fastest. This could mean that an intermediate result is being cached.
6.12 ms ± 8.37 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fixed length compiled time:
7.31 µs ± 241 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
random length compiled time:
The slowest run took 53.37 times longer than the fastest. This could mean that an intermediate result is being cached.
4.55 ms ± 6.11 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
因此,我们应该尽力限制 jit 函数作为输入接受的形状数量。常见的例子包括将序列填充到单个长度,或在数据加载器上设置 drop_last=True
以避免批处理中示例数量不同。
函数纯度和副作用
JAX 的转换和编译旨在仅作用于纯 Python 函数。粗略地说,一个功能纯粹的函数是指给定相同的输入,它将始终产生相同的输出,并且没有任何可观察的副作用。
例如,请看这个例子,其中 fn
的输出不仅依赖于 x
,还依赖于 shift
,我们在函数调用之间更改了 shift
。
shift = -1.0
x1 = jnp.array([0, 1, 2])
x2 = jnp.array([0, -1, 0])
x3 = jnp.array([0, 1, 2, 3])
def fn(x):
return x + shift
print(fn(x1))
shift = 1.0
print(fn(x2))
print(fn(x3))
shift = -1.0
jit_fn = jax.jit(fn)
print(jit_fn(x1))
shift = 1.0
print(jit_fn(x2))
print(jit_fn(x3))
===
Out:
[-1. 0. 1.]
[1. 0. 1.]
[1. 2. 3. 4.]
[-1. 0. 1.]
[-1. -2. -1.]
[1. 2. 3. 4.]
即时模式调用(前三个)代表我们的真实值,后三个是使用相同输入和全局 shift 值
的 jit 函数的输出。在 jit 函数中,给定形状的第一次调用(当我们追踪时)将使用正确的当前全局 shift 值。这是因为追踪利用了 Python 解释器,因此可以看到正确的全局值。
如果我们再次调用,JAX 找到了缓存的函数,它不会查看新的全局 shift,而是直接执行编译后的代码,其中旧值已作为常量烘焙到图中。然而,如果再次触发追踪(例如使用不同的输入形状),则将使用正确的 shift
。
这就是“JAX 转换和编译**旨在**仅作用于纯函数”的含义。它们仍然可以应用于非纯函数,但是当跳过追踪并直接使用编译函数时,函数的行为将与 Python 解释器有所不同。另一个例子是使用 print
函数的函数。
def fn(x):
print("called identity function")
return x
jit_fn = jax.jit(fn)
print("called `jit_fn(0.5)`")
_ = jit_fn(0.5)
print("called `jit_fn(1.0)`")
_ = jit_fn(1.0)
print("called `jit_fn([-1, 1])`")
_ = jit_fn(jnp.array([-1, 1]))
===
Out:
called `jit_fn(0.5)`
called identity function
called `jit_fn(1.0)`
called `jit_fn([-1, 1])`
called identity function
同样,**只要触发追踪,行为就与 Python 相同**,但只要使用缓存的函数,行为就会有所不同。这同样是不纯的,因为 print
是一个副作用。
如果我们使用的全局变量也是一个 JAX 数组呢?
b = jnp.array([1,2,3])
def fn(x):
return x + b
jit_fn = jax.jit(fn)
x = jnp.array([1,2,3])
print(jit_fn(x))
b = jnp.array([0,0,0])
print(jit_fn(x))
===
[2 4 6]
[2 4 6]
同样,由于 x
的输入形状没有改变,将使用编译版本,因此函数中 b
的值不会更新。然而,b
实际上是图中的一个变量,与我们前面修改 shift
的例子不同,在那个例子中 shift
是图中的一个常量。JAX 通过将 b
添加为追踪图中的**隐式参数**来维护编译函数中的函数纯度。因此,图是函数纯粹的,但是 b
对我们来说本质上是一个常量,因为我们无法在 Python 级别修改这个隐式参数而不重新编译。
一般来说,**最终编译后的函数**是纯粹的。然而,我们创建的 Python 级别函数不一定是纯粹的。尽管如此,jax.jit
仍然可以应用,但需要注意。我将把这些注意事项总结如下:
- 不操作 JAX 数组的代码不会被追踪,只在追踪期间调用(当 Python 解释器逐步执行函数时,像其他 Python 代码一样评估代码)。这包括
print
语句和设置 Python 级别变量,以及 Python 级别的条件和循环。 - 操纵 JAX 数组但 JAX 数组不是 Python 函数(也许它相对于函数是全局的)的参数的代码将被 jit 编译,但图中的那些变量将取它们在**编译时**所具有的任何值,并成为追踪图的隐式参数。
我感觉这两种不纯净的情况仍然有价值。例如,第一种情况在调试形状问题(例如在追踪过程中调试形状不匹配)或使用某些全局配置对象禁用函数的部分功能时很有用。
config = dict(relu=False)
@jax.jit
def fn(W, x):
y = x @ W
if config['relu']:
y = jax.nn.relu(y)
return y
W, x = jnp.ones((2,2)), jnp.ones((2,))
jax.make_jaxpr(fn)(W, x)
===
Out:
{ lambda ; a:f32[2,2] b:f32[2]. let
c:f32[2] = pjit[
jaxpr={ lambda ; d:f32[2,2] e:f32[2]. let
f:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] e d
in (f,) }
name=fn
] a b
in (c,) }
您可以在 jaxpr
中看到,只有 dot_general
出现在图中。relu
函数没有被追踪,因为 Python 解释器没有执行 if
语句的主体,因此没有将其添加到图中。重要的是要强调**只编译了一个条件分支**:最终图中没有分支。
可以说,如果您希望在程序的单次执行中同时使用两种选项,那么使用
static_argnums
是有道理的。但是,如果您的config
对象不会改变,我认为上述模式是没问题的!
可以在编译函数中添加条件语句。然而,Python 级别的条件语句只在追踪时使用。被遍历的分支将在追踪图中展开。必须使用特殊函数(稍后显示)才能在最终编译函数中添加条件语句。
第二点很有用,如果有一些我们知道不会改变的对象,例如一个我们只想快速运行推理的预训练机器学习模型。
bert = ... # some pretrained JAX BERT model that we can call
@jax.jit
def fn(x):
return bert(x)
上述方法是可行的,但对 bert
的更改不会反映在编译函数中,直到 x
的形状发生变化。我们甚至可以在第一次调用后将 bert
设置为 None
,fn
仍然可以工作,前提是我们使用相同的输入形状。
总的来说,我感觉 JAX 中对函数纯度的强调有点言过其实。在我(可能被误导的)看来,最好是简单地理解追踪时和编译时行为之间的差异,以及它们何时会触发。Python 具有令人难以置信的表达力,利用这一点是 JAX 强大之处的一部分,因此不必要地限制它将是一种耻辱。
编译函数中的条件和循环
我希望现在你对追踪时和编译时行为之间的区别有了一点直觉。如果没有,这里有一个总结:
- 当 jit 编译函数遇到一组尚未遇到的输入形状和静态参数值时,就会发生追踪。在这种情况下,JAX 依赖 Python 解释器逐步执行函数。所有正常的 Python 规则在这种情况下都适用。追踪的图将**包含在此特定追踪实例中遇到的可追踪操作**。
- 当 jit 编译函数被调用,并且输入形状和静态参数值的集合与缓存中的某个匹配时,就会调用编译版本。在这种情况下,**行为仅仅是调用编译函数,别无其他**。
这种行为很强大,因为它允许我们用表达性强的 Python 定义我们想要发生的事情,并依靠快速、优化的代码来实际执行。然而,它确实带来了一些问题:
- 每种输入形状和静态值的组合只能追踪一个条件路径。
- 由于追踪是逐操作进行的,循环将被简单地展开,而不是在最终编译的函数中形成循环。
有时这些特性很有吸引力。第一个可以用来简单地禁用我们不关心的分支——几乎就像 C 语言中的编译时标志。第二个对于少量循环迭代非常有用,其中可以优化跨迭代的依赖关系。然而,有时这会适得其反。
我们已经看到了一个例子,回顾 stupid_fn
:
def stupid_fn(x):
y = jnp.copy(x)
for _ in range(1000):
x = x * x
return y
jax.make_jaxpr(stupid_fn)(jnp.array([1.1, -1.1]))
===
Out:
Out: { lambda ; a:f32[4]. let
b:f32[4] = copy a
c:f32[4] = mul a a
d:f32[4] = mul c c
e:f32[4] = mul d d
f:f32[4] = mul e e
... <truncated>
bmh:f32[4] = mul bmg bmg
bmi:f32[4] = mul bmh bmh
bmj:f32[4] = mul bmi bmi
bmk:f32[4] = mul bmj bmj
bml:f32[4] = mul bmk bmk
bmm:f32[4] = mul bml bml
_:f32[4] = mul bmm bmm
in (b,) }
输出长得惊人。在追踪期间,整个循环都会被展开。这不仅看起来很烦人,而且还会使图的优化花费很长时间,导致函数的第一次调用完成时间很长。**JAX 并不知道我们处于 for 循环上下文中**,它只是简单地接收操作并将其添加到图中。
幸运的是,JAX 在其 jax.lax
子模块中暴露了控制流原语。
def less_stupid_fn(x):
y = jnp.copy(x)
x = jax.lax.fori_loop(start=0, stop=1000, body_fun=lambda i, x: x * x, init_val=x)
return y
jax.make_jaxpr(less_stupid_fn)(jnp.array([1.1, -1.1]))
===
Out:
{ lambda ; a:f32[2]. let
b:f32[2] = copy a
_:i32[] _:f32[2] = scan[
jaxpr={ lambda ; c:i32[] d:f32[2]. let
e:i32[] = add c 1
f:f32[2] = mul d d
in (e, f) }
length=1000
linear=(False, False)
num_carry=2
num_consts=0
reverse=False
unroll=1
] 0 a
in (b,) }
在上面的例子中,我们将 Python for 循环转换为 jax.lax.fori_loop
。它接受 for 循环范围的(整数)开始和结束,以及要在循环体中执行的函数和起始输入值作为参数。body_fun
的返回值必须与 init_val
具有相同的类型和形状,并且在所有迭代中都保持相同的类型和形状。此外,body_fun
的输入还接受当前的循环索引。
查看 jaxpr
,我们可以看到巨大的操作展开已被更紧凑的版本取代,使用了 scan
原语。这本质上是固定次数执行 body_fun
,并在迭代之间传递状态。scan
编译 body_fun
(就像 jax.jit
那样),因此需要固定的输入和输出形状。
如果循环次数不是静态的,那么我们将看到一个 while 循环原语!没有 for 循环原语,它只是通过
scan
或while
实现的。
让我们编译我们不那么愚蠢的函数 less_stupid_fn
,看看我们是否能得到相同的代码。即使使用我们花哨的原始函数,XLA 也应该以同样的方式优化该函数。```python print(jax.jit(less_stupid_fn).lower(x).compile().as_text())
输出:HloModule jit_less_stupid_fn, entry_computation_layout={(f32[2]{0})->f32[2]{0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.2 (Arg_0.1: f32[2]) -> f32[2] { %Arg_0.1 = f32[2]{0} parameter(0), sharding={replicated} ROOT %copy = f32[2]{0} copy(f32[2]{0} %Arg_0.1) }
And indeed, we get a single copy operation again.
A similar function exists for while loops named `jax.lax.while_loop`. An equivalent to `less_stupid_fn` would be:
```python
def less_stupid_fn(x):
y = jnp.copy(x)
x = jax.lax.while_loop(
cond_fun=lambda ix: ix[0] < 1000,
body_fun=lambda ix: (ix[0]+1, ix[1]*ix[1]),
init_val=(0, x)
)
return y
jax.make_jaxpr(less_stupid_fn)(jnp.array([1.1, -1.1]))
===
Out:
{ lambda ; a:f32[2]. let
b:f32[2] = copy a
_:i32[] _:f32[2] = while[
body_jaxpr={ lambda ; c:i32[] d:f32[2]. let
e:i32[] = add c 1
f:f32[2] = mul d d
in (e, f) }
body_nconsts=0
cond_jaxpr={ lambda ; g:i32[] h:f32[2]. let i:bool[] = lt g 1000 in (i,) }
cond_nconsts=0
] 0 a
in (b,) }
其中,只要 cond_fun
返回 True
,body_fun
就会继续执行,并在迭代之间传递状态,从状态 init_val
开始。
这些循环不像 Python 级别的等价物那样漂亮,但它们能完成任务。请记住,这些循环原语无法进行跨迭代优化,因为 body_fun
会作为自己的单元进行编译。相同的规则适用于 jax.jit
:使 body_fun
尽可能大,以便为 XLA 提供最大的上下文。
如果循环迭代次数**小且固定**,则可能值得改用 Python 循环。例如,您可以使用 fori_loop
来包装整个扩散模型进行推理,但使用常规循环训练未展开的模型仅需两个固定步骤。
对于编译函数中的条件语句,JAX 提供了许多选项。我不会在这里一一列举,JAX 文档中这里有一个很好的总结。与常规 if 语句行为最接近的函数是 jax.lax.cond
。
@jax.jit
def cond_fn(x):
pred = jnp.abs(x.max() - x.min()) <= 1.0
return jax.lax.cond(pred, lambda x: x, lambda x: x / 2, x)
print(cond_fn(jnp.array([0.1, 0.2])))
print(cond_fn(jnp.array([-0.5, 0.5])))
print(cond_fn(jnp.array([1.0, -1.0])))
===
Out: [0.1 0.2]
[-0.5 0.5]
[ 0.5 -0.5]
jax.lax.cond
接受一个布尔值、两个函数以及函数的运算数。如果 pred
为 True
,则第一个函数将使用 operands
执行;如果 pred
为 False
,则第二个函数执行。在上面的函数中,我们检查 x
的最小值和最大值之间的绝对差。如果它们小于或等于 1.0
,则数组不变返回;否则,数组减半。
我们可以打印 jaxpr
并看到两个分支都得到了追踪:```python jax.make_jaxpr(cond_fn)(jnp.array([1.0, -1.0]))
输出:{ lambda ; a:f32[2]. let b:f32[2] = pjit[ jaxpr={ lambda ; c:f32[2]. let d:f32[] = reduce_max[axes=(0,)] c e:f32[] = reduce_min[axes=(0,)] c f:f32[] = sub d e g:f32[] = abs f h:bool[] = le g 1.0 i:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h j:f32[2] = cond[ branches=( { lambda ; k:f32[2]. let l:f32[2] = div k 2.0 in (l,) } { lambda ; m:f32[2]. let in (m,) } ) linear=(False,) ] i c in (j,) } name=cond_fn ] a in (b,) }
The equivalent for `n` branches (rather than just the implied two with
`jax.lax.cond`) is `jax.lax.switch`. With this, we can implement a highly
performant `is_even` function!
```python
@jax.jit
def is_even_fast(x):
return jax.lax.switch(x, [
lambda: True,
lambda: False,
lambda: True,
lambda: False,
lambda: True,
lambda: False,
lambda: True,
lambda: False,
lambda: True,
... <truncated>
lambda: False
])
is_even_fast(123512)
===
Out: Array(True, dtype=bool)
不要看上面函数的
jaxpr
。
简要介绍 PyTrees
您可能已经注意到,到目前为止,我们使用 jax.jit
编译的所有函数都只接受像单个数组或值这样的扁平结构作为输入。如果以后我们想将 JAX 用于大规模机器学习问题,这就会带来问题。难道我们要一个一个地写 GPT-3 的所有参数数组吗?
实际上,我们可以使用任意的**PyTrees**作为 jit 编译函数的输入、中间值和输出。
PyTree 的正式定义是“由容器类 Python 对象构建的树状结构。如果类在 PyTree 注册表中,则被视为容器类。” 默认情况下,PyTree 注册表包含 list
、tuple
和 dict
类。此外,注册表中不包含的任何对象都被视为叶子(即:单个元素或单个数组)。PyTree 可以包含其他 PyTree,形成嵌套结构和叶子。
可以向 PyTree 注册表注册自己的自定义类,但这超出了本博客的范围。
当调用 jit 函数时,JAX 会检查是否存在具有**相同 PyTree 结构、叶子形状和静态参数值**的现有缓存编译函数。如果所有这些都匹配,则将重用已编译函数。就像尽可能保持参数形状相同以便使用缓存函数一样,您应该力求保持 PyTree 结构相同。
让我们举一个具体的例子,实现一个简单的多层感知器的前向传播。首先,我们将构建一个字典列表。列表中每个字典代表一个层,字典存储该层的权重和偏置。
dims = [784, 64, 10]
key, *subkeys = jax.random.split(key, len(dims))
params = [
{
'W': jax.random.normal(w_key, (out_dim, in_dim)),
'b': jnp.zeros((out_dim,))
}
for w_key, in_dim, out_dim in zip(subkeys, dims[:-1], dims[1:])
]
jax.tree_util.tree_structure(params), jax.tree_util.tree_map(lambda l: str(l.shape), params)
===
Out:
(PyTreeDef([{'W': *, 'b': *}, {'W': *, 'b': *}]),
[{'W': '(64, 784)', 'b': '(64,)'}, {'W': '(10, 64)', 'b': '(10,)'}])
变量 params
符合 PyTree 的定义。单元格的输出是 PyTree 的结构和另一个显示 params
叶子形状的 PyTree。让我们将前向传播定义为一个函数,它以 PyTree params
和数组 x
作为输入,并用 jax.jit
装饰它。
@jax.jit
def feed_forward(params, x):
for p in params:
x = jax.nn.tanh(p['W'] @ x + p['b'])
return x
key, x_key = jax.random.split(key)
feed_forward(params, jax.random.normal(x_key, (dims[0],)))
===
Out:
Array([-1. , -0.93132854, -1. , -0.99993926, 0.9998755 ,
-0.9970358 , -0.8498685 , 1. , -0.9999984 , 1. ], dtype=float32)
如果你曾打印过 PyTorch 模型
model.state_dict()
,你应该能看到我们如何仅使用嵌套字典就能实现类似的效果。我在上面的例子中只是用了一个列表来演示我们如何嵌套任意组合的容器,只要它们在 PyTree 注册表中即可。
在最简单的情况下,PyTrees 仅仅是很好的容器,帮助我们打包函数的输入。它们可以变得更加复杂,但我还没有深入研究这个话题。我想是时候另起炉灶了。
函数转换
一篇关于 JAX 的博客文章不可能不提到函数转换。在 JAX 的 Github 仓库中,你首先会看到“更深入地看,你会发现 JAX 实际上是一个可扩展的组合函数转换系统”。我自己已经开始尝试这个系统,但还没有深入到可以深入撰写文章的程度,尽管我怀疑它需要一篇完全独立的文章才能充分展现其价值。
为了让您尝尝甜头,请看这个仓库,它允许您向任意 JAX 函数添加 LoRA!
函数变换只是一个以另一个函数为输入,并返回另一个函数的函数。嗯,函数变换就是变换函数。
JAX 带有许多必须提及的内置函数转换。您已经见过 jax.jit
。另外两个是 jax.grad
和 jax.value_and_grad
转换,它们构成了 JAX 的自动微分组件。自动微分是训练机器学习模型的重要组成部分。
简而言之,jax.grad
接收一个函数 f
,并返回另一个计算 f
导数的函数。jax.value_and_grad
返回一个函数,该函数又返回一个元组 (value, grad)
,其中 value
是 f(x)
的输出,grad
是 jax.grad(f)(x)
的输出。
def fn(x):
return 2*x # derivative is 2 everywhere
print(fn(5.))
print(jax.grad(fn)(5.))
print(jax.value_and_grad(fn)(5.))
===
Out:
10.0
2.0
(Array(10., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True))
默认情况下,自动微分函数将对第一个函数参数求梯度,因此新函数 jax.grad(f)
的输出将与 f
的第一个参数具有**相同的形状和结构**。
def dummy_loss_fn(params, x):
y = feed_forward(params, x)
return y.sum()
grad_loss_fn = jax.grad(dummy_loss_fn)
grads = grad_loss_fn(params, jnp.zeros(dims[0]))
jax.tree_util.tree_structure(grads)
===
Out: PyTreeDef([{'W': *, 'b': *}, {'W': *, 'b': *}])
上面是一个虚拟示例,我们将模型的前向传播和“损失”计算打包到一个函数中。然后我们调用
jax.grad
来获取相对于模型参数的梯度。这是 JAX 训练循环中的常见模式,通常随后计算参数更新并计算新参数。在后续关于 Flax 的文章中,您会经常看到这种模式。
我们可以通过指定 argnums
参数为我们想要求导的参数索引来改变选择第一个参数的默认行为。我们甚至可以通过传递一个整数序列来指定多个参数。
我们甚至可以将 grad
应用于一个已经计算了一阶导数的函数,从而获得一个计算二阶导数的函数。
def fn(x):
return 2 * x**3
x = 1.0
grad_fn = jax.grad(fn)
grad_grad_fn = jax.grad(grad_fn)
print(f"d0x: {fn(x)}, d1x: {grad_fn(x)}, d2x: {grad_grad_fn(x)}")
===
Out: d0x: 2.0, d1x: 6.0, d2x: 12.0
上述行为在 PyTorch 或 Tensorflow 等其他机器学习框架中很难实现。但在 JAX 中,由于其对函数转换的强调,实现起来非常简单。
有时,我们希望计算一个函数(它也输出辅助数据)的梯度。一个常见的例子是损失函数,它也输出其他指标,如准确率。我们希望将这些辅助数据排除在梯度计算之外,这可以通过向 grad
传递 has_aux=True
来实现。在以下示例中,我们这样做是为了同时返回我们伪造的“损失”和 feed_forward
本身的输出,同时还计算相对于 params
的梯度!这涉及很多操作!
def dummy_loss_fn(params, x):
y = feed_forward(params, x)
return y.sum(), y
grad_loss_fn = jax.value_and_grad(dummy_loss_fn, has_aux=True)
values, grads = grad_loss_fn(params, jnp.zeros(dims[0]))
values, jax.tree_util.tree_structure(grads)
===
Out:
((Array(0., dtype=float32),
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
PyTreeDef([{'W': *, 'b': *}, {'W': *, 'b': *}]))
正如我前面提到的,**JAX 转换是可组合的,可以组合在一起以生成复杂的行为**。我们已经看到了通过两次应用 jax.grad
来获得二阶导数的一个例子。另一个例子是组合 jax.jit
和 jax.grad
来生成一个 jit 编译的自动微分函数!
冒着变成“自动微分”部分而不是函数转换部分的风险,我应该提到其他转换。一个特别著名的是 jax.vmap
,它只是将单个输入函数转换为可以接受批次输入的函数。
我个人觉得这没什么用,因为我太习惯于编写批处理代码了。但您的体验可能有所不同。
一个更强大的转换是 jax.pmap
,它将一个函数转换为可以在**多个加速器**上并行化的函数,通常以单程序、多数据(数据并行)的方式进行。使用 JAX 的一个主要吸引力在于其使用 pmap
和其他“p
”函数实现的内置和简易并行支持。然而,这是一个独立的主题,所以我将把它留到未来的博客中探讨。
总结
在这篇长篇帖子中,我介绍了 JAX 并深入探讨了其中的一些关键概念,也分享了一些非常有主见的看法。我还没有展示 JAX 中完整的机器学习训练循环,但我会在后续的文章中使用 Flax 和 Optax 等高级库来涵盖这一点。
如果要总结这篇文章的要点,它们将是:
jax.jit
功能非常强大,应尽可能在最广泛的上下文中加以利用。- 注意理解追踪时和编译时行为之间的差异。
- 大多数机器学习代码都可以用静态方式重写,并且应尽可能这样做,以充分利用 XLA。
JAX 的内容远不止于此,但我认为这些要点对于奠定基础理解很有帮助,可以在此基础上有效地进行后续学习。
可以说,开始编写训练循环并不需要如此长的 JAX 介绍,这也不是我最初的意图。然而,在写作过程中,我发现深入探究 JAX 及其行为的基础非常有趣,我希望这次探索对其他开始学习 JAX 的人,甚至对那些经验更丰富的人也有用。如果它不合您的口味,我保证未来的文章会更实用。
如果您喜欢这篇文章,请考虑在Twitter上关注我,或访问我的网站以获取更多关于机器学习和其他话题的胡言乱语。感谢您阅读到这里,希望您觉得有用!
致谢和更多资源
一些不错的额外资源
感谢 Kamil Hepak 对这篇博客文章进行的语言审阅!