MicroJAX

这是一个关于如何构建一个微型 Jax/mlx 风格的转换引擎的微型开发博客,因为我发现网上没有简单易懂的解释函数转换引擎的资料。就像 Karpathy-senpai 的 micrograd 是 PyTorch 的一个简单版本一样,microjax 是 Jax 的一个更简单版本。
Github: microjax - 如果你觉得有用,请点个⭐。
Notebook: 分步学习
如果你有任何问题/更正,请在 Twitter 上私信我 @shxf0072。
🤗 博客是 pythonstuff 的镜像
警告:此项目旨在通过递增的复杂性来帮助理解,而非追求绝对正确性。
现在有 Autodidax,但它相当难,使用了大量术语和 lambda 函数。 我因此有了灵感 :3
很久以前,谷歌沉迷于 TensorFlow。谷歌有定制芯片 TPU 和 XLA,一个非常智能的线性代数编译器,让它飞速运行,但 TensorFlow 用起来很糟糕。其他新兴框架如 PyTorch 具有动态特性,但表现不佳。Jax 是谷歌的一个实验项目,它将 Python 跟踪转换为 XLA,然后可以用 MLIR 编译并在 GPU 和 TPU 等加速器上运行。我们可以将 Jax 分为两部分:函数转换和 XLA 编译器。函数转换使我们能够计算梯度、Hessian 并定义类似 vmap 的转换,而 XLA IR 转换则带来了速度。在这里,我们只介绍一个简单的函数转换引擎,因此 Jax 的 X 部分不在讨论范围内,但嘿,MLX 有 X,所以我不在乎,我称它为 microjax。尽管 Jax 最近变得更主流了,但 Jax 的故事可以追溯到 autograd。
Autograd 有一个更具创新性的目标。你需要为机器学习或科学计算计算梯度,而 Python 是科学计算的语言。那么,如何使 Python 这样的高级解释型语言可微分呢?
当然是编写一个解释器!
基础解释器
首先,我们将从原语开始。这些是 Jax 中的基本操作,所有其他操作都将基于它们。顺便说一句,这里只涉及标量值,因此更容易理解(虽然技术上支持 np 数组,但为了简化,我不会定义数组操作)。
from __future__ import annotations
import math
from contextlib import contextmanager
from typing import Any, Callable
class OPS:
ADD = "add"
MUL = "mul"
NEG = "neg"
RECIP = "recip"
EXP = "exp"
SIN = "sin"
# ik i can make this enum, i choose not too :P
对于大多数情况,你只需要这些操作,如果你喜欢,也可以定义自己的操作。
让我们从基础解释器开始,这就像一个抽象类。
class Interpreter:
def __init__(self, level: int = 0, *args, **kwargs):
self.level = level
def process_primitive(self, prim, boxes, params):
"in this function, either you process primitives or you unbox and send to lower level interpreter"
raise NotImplementedError
在 JAX 中,这称为 Trace。Trace 跟踪 tracers,它们只是封装的值。我们把值和一些额外的信息(例如是否需要计算梯度,以及形状和要使用的解释器)一起放在盒子里。对于每种独特的转换类型,我们都会有一个该类型的盒子和一个该类型的解释器。
现在我们将有多个解释器。为了知道封装的值应该在哪个上下文中评估,我们需要跟踪解释器。我们将使用最常见的数据结构:栈来实现这一点。
STACK: list[Interpreter] = []
def push_interpreter(interpreter: Interpreter):
STACK.append(interpreter)
return STACK
def pop_interpreter() -> Interpreter:
return STACK.pop()
@contextmanager
def interpreter_context(interpreter_type: Interpreter):
stack_item = interpreter_type(level=len(STACK))
push_interpreter(stack_item)
try:
yield stack_item
finally:
pop_interpreter()
现在我们来定义 Box,Box(tracer)是实际流经你定义的函数的值。我们需要重写一些双下划线方法,使其与 Python 兼容。
class Box:
_interpreter: Interpreter
def aval(self):
raise NotImplementedError
def __add__(self, other):
return add(self, other)
def __radd__(self, other):
return add(other, self)
def __mul__(self, other):
return mul(self, other)
def __rmul__(self, other):
return mul(other, self)
def __neg__(self):
return neg(self)
def __sub__(self, other):
return add(self, neg(other))
def __rsub__(self, other):
return add(other, neg(self))
def __truediv__(self, other):
return mul(self, recip(other))
def __rtruediv__(self, other):
return mul(other, recip(self))
def __iadd__(self, other):
return add(self, other)
def __imul__(self, other):
return mul(self, other)
def __isub__(self, other):
return add(self, neg(other))
def __itruediv__(self, other):
return mul(self, recip(other))
# dont worry about this undefined functions, we add them later
我们抽象类快完成了,只剩几个辅助函数
当函数接收到多个封装值时,我们需要为它们找到最高级别的解释器。
def find_top_interpreter(args):
"""
find the top level interpreter for the given arguments
"""
interpreters = []
for item in args:
if isinstance(item, Box):
interpreters.append(item._interpreter)
if interpreters:
return max(interpreters, key=lambda x: x.level)
# if no interpreters are found, return the default EvalInterpreter
return STACK[0]
如果一个值在第 2 层,另一个在第 3 层,我们需要用这个函数将它们提升到第 3 层。
def full_raise(interpreter: Interpreter | Any, out) -> Box | JVPBox:
"""
if interpreter need values boxed
if out is not boxed, box it (using interpreter.pure)
ie. raise out to the box level
"""
if not isinstance(out, Box):
return interpreter.pure(out)
return out
每个封装值都将分配一个解释器。每个解释器都有一个表示其在栈中位置的级别。`find_top_interpreter` 函数将找到所有解释器中级别最高的那个。`full_raise` 将把一个值提升到栈中当前解释器的级别。`bind_single` 只是一个小的包装器,用于处理 `bind` 返回的元组。
def bind(prim, *args, **params):
interpreter = find_top_interpreter(args)
# this will raise the boxes to the top level
# eg converts primitive values to Boxes if interpreter is not the top level
boxes = [full_raise(interpreter, arg) for arg in args]
outs = interpreter.process_primitive(prim, boxes, params)
return [out for out in outs]
def bind_single(prim, *args, **params):
(out,) = bind(prim, *args, **params)
return out
“bind”是重要的函数,它将调用解释器。
原语
这些是构建块,所有其他函数都将在此基础上构建。我喜欢称它们为模拟函数,因为它们不真正计算任何东西;它们更像是箱子到解释器的路由器。
def add(*args):
return bind_single(OPS.ADD, *args)
def mul(*args):
return bind_single(OPS.MUL, *args)
def neg(x):
return bind_single(OPS.NEG, x)
def recip(x):
return bind_single(OPS.RECIP, x)
def exp(x):
return bind_single(OPS.EXP, x)
def sin(x):
return bind_single(OPS.SIN, x)
原语就像模拟函数。当你调用 `mul(Box1(3), Box1(2))` 时,它会为 `Box1(3)` 和 `Box1(2)` 找到解释器,然后找到其中级别最高的解释器。它会解封装这些值,并告诉该解释器处理这些原语。对于每种操作类型,都有一个原语操作函数。
复合函数建立在原语之上。只要你可以用原语表达你的函数,你就可以使用任意复杂的函数。
def cos(x):
return sin(x + math.pi / 2)
def sigmoid(x):
return 1 / (1 + exp(-x))
def tanh(x):
return 2 * sigmoid(2 * x) - 1
def silu(x):
return x * sigmoid(x)
Eval 解释器
即使我们有很好的抽象层,最终总需要有人来运行 `add` 或 `mul` 函数。这将由 eval 解释器完成。我们首先定义评估规则,然后定义 eval 解释器。
class EvalRules:
def __init__(self):
self.rules = {
OPS.ADD: self.add,
OPS.MUL: self.mul,
OPS.NEG: self.neg,
OPS.RECIP: self.recip,
OPS.EXP: self.exp,
OPS.SIN: self.sin,
}
def __getitem__(self, op):
return self.rules[op]
def add(self, primals, *args):
x, y = primals
return [x + y]
def mul(self, primals, *args):
x, y = primals
return [x * y]
def neg(self, primals, *args):
(x,) = primals
return [-x]
def recip(self, primals, *args):
(x,) = primals
return [1 / x]
def exp(self, primals, *args):
(x,) = primals
return [math.exp(x)]
def sin(self, primals, *args):
(x,) = primals
return [math.sin(x)]
我们不期望任何封装值传递给 eval 解释器,因此我们可以直接在值上调用函数。这很简单:接收参数并返回结果。
class EvalInterpreter(Interpreter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.rules = EvalRules()
def pure(self, val):
return val
def process_primitive(self, prim, values, params):
return self.rules[prim](values, *params)
基本解释器是 `EvalInterpreter`,现在我们运行基本程序,由于评估是原语的末尾,我们将在栈的底部推入 eval 解释器。
push_interpreter(EvalInterpreter())
现在我们可以启动基本程序了。
def func(x):
return sin(x)*exp(x) + tanh(x)
x = 3.14
func(x)
1.033056645880499
你可能会觉得这样 ^,但这些抽象将用于构建更复杂的 AD。
自动微分
假设我们有函数
如果 x 是一个向量,那么它的梯度可以通过雅可比矩阵计算。
如果你不知道什么是雅可比矩阵,它只是偏导数组成的矩阵,如果需要复习,请观看此视频 链接
现在我们可以用原语来定义我们的函数了。
同样地,我们可以将雅可比矩阵定义为每个函数导数的乘积,或者数学上我们可以通过链式法则定义大雅可比矩阵。
因此,如果你将雅可比矩阵乘以向量,你将得到梯度。
你可能从线性代数课上知道,如果乘法是结合律的,那么你可以从任何一边乘以雅可比链。
这个等式成立,因为矩阵乘法是结合律的。
雅可比-向量积 (JVP)
雅可比-向量积,或前向自动微分,计算雅可比矩阵和向量的乘积。
其中 $J$ 是雅可比矩阵, $v$ 是一个向量。
在我们的链式法则示例中
JVP 将导数从输入到输出向前传播通过计算图。
我们可以按照函数评估的相同方向计算梯度 A->B->C->D,dA->dB->dC->dD。
向量-雅可比积 (VJP)
向量-雅可比积,或反向自动微分,计算向量和雅可比矩阵的乘积。
其中 $v^T$ 是向量 $v$ 的转置, $J$ 是雅可比矩阵。
在我们的链式法则示例中
VJP 将导数从输出到输入反向传播通过计算图。这是训练神经网络中常用的反向传播算法的基础。
正向自动微分 (JVP)
前向自动微分真的很容易。我们将把值及其导数(初始化为 1)装箱。当我们向前计算函数时,我们也会计算其导数。
class JVPBox(Box):
def __init__(self, interpretor: Interpreter, primal, tangent) -> None:
super().__init__()
self._interpreter = interpretor
self.primal = primal
self.tangent = tangent
def __repr__(self):
return f"JVPBox (primal={self.primal}, tangent={self.tangent})"
值称为原值,其导数称为切线。
当您有一个函数 f(x) = sin(x) 时,它的导数表示该点处切线的斜率。由于每个点的斜率都在变化,这些变化由梯度函数定义。因此,尽管我们使用 cos(x) 得到 sin(x) 在原点 x 处的导数,但我们称之为切线。它应该表示该点处切线的斜率。
我们将为每个原始操作定义规则。
注意,现在要定义这个规则,你只能使用你定义的原始函数,所以整个操作需要是封闭的,即只能使用上面定义的 cos,不能使用 math.cos(x)。
class JVPRules:
def __init__(self):
self.rules = {
OPS.ADD: self.add,
OPS.MUL: self.mul,
OPS.NEG: self.neg,
OPS.RECIP: self.recip,
OPS.EXP: self.exp,
OPS.SIN: self.sin,
}
# dont forget to return tuple(primals),tuple(tangents)
def __getitem__(self, op):
return self.rules[op]
@staticmethod
def add(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return (x + y,), (x_dot + y_dot,)
@staticmethod
def mul(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return (x * y,), (x_dot * y + x * y_dot,)
@staticmethod
def neg(primals, tangents):
(x,), (x_dot,) = primals, tangents
return (-x,), (-x_dot,)
@staticmethod
def recip(primals, tangents):
(x,), (x_dot,) = primals, tangents
y = 1 / x
return (y,), (-y * y * x_dot,)
@staticmethod
def exp(primals, tangents):
(x,), (x_dot,) = primals, tangents
y = exp(x)
return (y,), (y * x_dot,)
@staticmethod
def sin(primals, tangents):
(x,), (x_dot,) = primals, tangents
return (sin(x),), (cos(x) * x_dot,)
JVP 解释器
现在我们来构建第一个真正的解释器。我们将值及其切线装箱。首先,我们解开值,并处理这些原值和切线。然后,我们将结果装箱。
class JVPInterpreter(Interpreter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.rules = JVPRules()
def pure(self, val):
return JVPBox(self, val, 0.0)
def process_primitive(self, prim, boxes, params):
# unbox values
primals = [box.primal for box in boxes]
tangents = [box.tangent for box in boxes]
# process primitive
op = self.rules[prim]
primals_out, tangents_out = op(primals, tangents, **params)
# box values
result = []
for p, t in zip(primals_out, tangents_out):
result.append(JVPBox(self, p, t))
return result
JVP 简单来说就是一个函数,它接受一个函数、其输入和其切线,并返回该函数的输出和其切线。
def jvp_simple(func, primals, tangents):
with interpreter_context(JVPInterpreter) as iptr:
box_in = [JVPBox(iptr, x, t) for x, t in zip(primals, tangents)]
out = func(*box_in)
box_out = full_raise(iptr, out)
primal_out, tangent_out = box_out.primal, box_out.tangent
return primal_out, tangent_out
def func(x):
return sin(x) + exp(x)
x,xdot = 3.14,1
y,y_dot = jvp_simple(func, (x,), (xdot,))
print(y,y_dot)
23.10545951163867 22.103868126994644
现在我们来定义一个包装函数,它将接收一个函数,修改输入以执行 JVP 并返回梯度。
def deriv(function):
def jvp_forward(input_value):
_, gradient = jvp_simple(function, (input_value,), (1,))
return gradient
return jvp_forward
def func(x):
return sin(x)
x = 3.14
print(func(x))
deriv_func = deriv(func)
print(deriv_func(x))
0.0015926529164868282
-0.9999987317275395
但是如果我们对导数函数求导会发生什么呢?
deriv_func = deriv(func)
deriv_func2 = deriv(deriv_func)
print(deriv_func2(x))
-0.0015926529164865067
我们得到二阶导数或 Hessian。
这里发生的是我们正在双重封装值,`BOX2( BOX1(原始值, 导数1), 导数2 )`,但是由于函数在组合下是封闭的,我们可以任意多次组合导数,
例如,我们来看一个函数
import matplotlib.pyplot as plt
import random
plt.style.use("dark_background")
plt.figure(figsize=(10,5))
def forward(func, vec):
return [func(x) for x in vec]
x = [i / 10 for i in range(-50, 50)]
def func(x):
return x*sin(x)
primals = forward(func, x)
# first derivative
f_prime = forward(deriv(func), x)
f_prime2 = forward(deriv(deriv(func)), x)
f_prime3 = forward(deriv(deriv(deriv(func))), x)
f_prime4 = forward(deriv(deriv(deriv(deriv(func)))), x)
plt.plot(x, primals, label='f(x) = x*sin(x)')
plt.plot(x, f_prime, label="f'(x)")
plt.plot(x, f_prime2, label="f''(x)")
plt.plot(x, f_prime3, label="f'''(x)")
plt.plot(x, f_prime4, label="f''''(x)")
plt.grid(True, linestyle="--", alpha=0.15)
plt.box(False)
plt.legend(loc='upper right')
plt.show()
反向自动微分 (VJP)
就像我们在 JVP 中通过用切线封装值来添加额外信息一样,在 VJP 中,我们将添加额外的信息节点。这将创建一个可以反向遍历的图。因此称为反向自动微分。
网上关于反向自动微分的解释很多,推特上充斥着“我在 X 中实现了自动微分”的帖子。所以我不会详细解释。我认为最好的解释是 Karpathy(senpai)的视频 链接。看完这个你就会明白是怎么回事了。
定义节点,它将保存反向传播函数及其父节点。
class Node:
def __init__(self, vjp: Callable, parents: list[Node]) -> None:
self.vjp = vjp
self.parents = parents
@property
def is_leaf(self):
return len(self.parents) == 0
def get_leaf_nodes() -> Node:
return Node(None, [])
反向传播规则
需要注意的是,即使在反向传播中,我们也只调用我们定义的原始函数。与 PyTorch 不同,在 PyTorch 中你可以在反向传播中做任何你想要的事情。在 Jax 中,你的反向传播需要是封闭的,即只能通过组合现有操作来实现。
因此,Torch 相对容易扩展(FAFO),而 Jax 则不然。
class VJPRules:
def __init__(self):
self.rules = {
OPS.ADD: self.add,
OPS.MUL: self.mul,
OPS.NEG: self.neg,
OPS.RECIP: self.recip,
OPS.EXP: self.exp,
OPS.SIN: self.sin,
}
"""
Jax define one of vjp or jvp rules
it derives one from the other
but this is much more simple to understand
"""
def __getitem__(self, op):
return self.rules[op]
def add(self, primals):
x, y = primals
def vjp_add(grad):
return grad, grad
return (x + y,), vjp_add
def mul(self, primals):
x, y = primals
def vjp_mul(grad):
return grad * y, grad * x
return (x * y,), vjp_mul
def tanh(self, primals):
(x,) = primals
y = tanh(x)
def vjp_tanh(grad):
return ((1 - y * y) * grad,)
return (y,), vjp_tanh
def neg(self, primals):
(x,) = primals
def vjp_neg(grad):
return (-grad,)
return (-x,), vjp_neg
def recip(self, primals):
(x,) = primals
y = 1 / x
def vjp_recip(grad):
return (-y * y * grad,)
return (y,), vjp_recip
def exp(self, primals):
(x,) = primals
y = exp(x)
def vjp_exp(grad):
return (y * grad,)
return (y,), vjp_exp
def sin(self, primals):
(x,) = primals
y = sin(x)
def vjp_sin(grad):
return (cos(x) * grad,)
return (y,), vjp_sin
VJP 盒子,我们把原始值和它的节点装在里面。
class VJPBox(Box):
def __init__(self, interpreter: VJPInterpreter, primal, node: Node) -> None:
super().__init__()
self._interpreter = interpreter
self.primal = primal
self.node = node
def pure(self,value):
return VJPBox(self._interpreter, value, get_leaf_nodes())
def __repr__(self):
return f"VJPBox (primal={self.primal}, node={self.node})"
class VJPInterpreter(Interpreter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.rules = VJPRules()
def pure(self, val):
return VJPBox(self, val, get_leaf_nodes())
def process_primitive(self, prim, boxes, params):
primals_in = [box.primal for box in boxes]
nodes_in = [box.node for box in boxes]
op = self.rules[prim]
primals_out, vjp_out = op(primals_in, **params)
nodes_out = [Node(vjp_out, nodes_in)]
result = []
for p, n in zip(primals_out, nodes_out):
result.append(VJPBox(self, p, n))
return result
正向传播将给出输出和图,这将在反向传播中用于获取梯度。再次,观看 Karpathy 的视频以获得直观理解。
def toposort(end_node):
def _toposort(seen, node):
result = []
if id(node) not in seen:
seen.add(id(node))
for p in node.parents:
result.extend(_toposort(seen, p))
result.append(node)
return result
return reversed([n for n in _toposort(set(), end_node) if n.parents])
def add_grads(grad1, grad2):
if grad1 is None:
return grad2
return grad1 + grad2
def backward_pass(in_nodes, out_node, gradient):
node_map = {id(out_node): gradient}
topo_sorted = toposort(out_node)
for node in topo_sorted:
node_grad = node_map.pop(id(node))
input_grads = node.vjp(node_grad)
for input_grad, parent in zip(input_grads, node.parents):
parent_id = id(parent)
node_map[parent_id] = add_grads(node_map.get(parent_id), input_grad)
return [node_map.get(id(node)) for node in in_nodes]
现在,一个简单的 VJP 函数,它将接收一个函数,为所有输入添加叶节点,并返回输出和反向传播函数。
def vjp_simple(func, *args):
with interpreter_context(VJPInterpreter) as iptr:
box_in = [VJPBox(iptr, x, get_leaf_nodes()) for x in args]
out = func(*box_in)
box_out = full_raise(iptr, out)
in_nodes = [box.node for box in box_in]
out_node = box_out.node
primal_out = box_out.primal
def func_vjp(grad):
return backward_pass(in_nodes, out_node, grad)
return primal_out, func_vjp
grad 是 VJP 的一个小包装器。
def grad(func):
def grad_func(*args):
_, backward = vjp_simple(func, *args)
return backward(1)[0]
return grad_func
def func(x):
# return x*x
return 3 * x * x * x + 2 * x * x + 2 * x
print(grad(func)(2))
print(grad(grad(func))(2))
# 46
# 40
同样,你可以对函数进行任意多次求导,例如
def func(x):
return tanh(x)
def forward(func, vec):
return [func(x) for x in vec]
x = [i / 25 for i in range(-100, 100)]
plt.figure(figsize=(15,8))
primals = forward(func, x)
# first derivative
f_prime = forward(deriv(func), x)
f_prime2 = forward(deriv(deriv(func)), x)
f_prime3 = forward(deriv(deriv(deriv(func))), x)
f_prime4 = forward(deriv(deriv(deriv(deriv(func)))), x)
plt.plot(x, primals, label='f(x) = x*sin(x)')
plt.plot(x, f_prime, label="f'(x)")
plt.plot(x, f_prime2, label="f''(x)")
plt.plot(x, f_prime3, label="f'''(x)")
plt.plot(x, f_prime4, label="f''''(x)")
plt.grid(True, linestyle="--", alpha=0.15)
plt.box(False)
plt.legend(loc='upper right')
plt.show()
函数组合
由于 jvp 和 vjp 都是通过原始函数定义的,因此您可以在正向模式自动微分上执行反向传播。
def func(x):
return sin(x)+ tanh(x)*exp(x)
print("forward on backward")
print(deriv(grad(func))(2))
print("backward on forward")
print(grad(deriv(func))(2))
forward on backward
6.251514736700764
backward on forward
6.251514736700765
现在有一个问题,您只能将原始数据类型传递给函数,例如
def func(inputs):
x,y = inputs
return tanh(x) + y
inputs = (6,9)
grad = deriv(func)
#print(grad(inputs))
如果您运行此代码,您将收到错误,
1 def func(inputs):
----> 2 x,y = inputs
3 return tanh(x) + y
TypeError: cannot unpack non-iterable JVPBox object
我们希望输入列表中包含盒装值,但此处我们的列表本身是盒装的 Box(list([x,y]))
,我们需要类似 [Box(x),Box(y)]
的东西,这就是 pytree 的用武之地。
Pytree
Pytree 是一种表示嵌套数据的数据结构。它解决了值装箱和拆箱的问题。它可以获取任何数据结构并将其转换为扁平化和 pytree。现在这个扁平化只是原始数据类型的列表,所以我们可以遍历它们并进行装箱。树保留了数据的结构,因此在装箱值之后,我们可以重建原始数据结构,其中每个值都已装箱。
from __future__ import annotations
import numpy as np
import numbers
from typing import Any, Hashable, Iterable
PyTreeTypes = list | dict | tuple | Any
每个 pytree 都有一个类型、元数据和子树。
类型是数据结构的类型,元数据是数据本身,子树是数据结构的子树。
class PyNode:
def __init__(
self, node_type: type, metadata: Hashable = None, child_tree: "PyNode" = None
):
self.node_type = node_type
self.metadata = metadata
self.child_tree = child_tree
def __repr__(self):
s = f"({self.node_type.__name__ if self.node_type !='leaf' else 'leaf'}"
if isinstance(self.metadata, np.ndarray) or self.metadata:
s += f":{self.metadata.__class__.__name__}"
if self.child_tree is not None:
s += f",{self.child_tree}"
return s + ")"
@staticmethod
def from_iter(pytree) -> tuple[Hashable, "PyNode"]:
raise NotImplementedError("Not implemented")
@staticmethod
def to_iter() -> PyTreeTypes:
raise NotImplementedError("Not implemented")
def __eq__(self, other: PyNode) -> bool:
if self.node_type != other.node_type:
return False
if self.child_tree != other.child_tree:
return False
return True
class ListNode(PyNode):
@staticmethod
def to_iter(lst):
return None, lst
@staticmethod
def from_iter(_, iterable):
return list(iterable)
class DictNode(PyNode):
@staticmethod
def from_iter(keys, vals):
return dict(zip(keys, vals))
@staticmethod
def to_iter(dct):
keys, values = [], []
for key, value in sorted(dct.items()):
keys.append(key)
values.append(value)
return keys, values
class TupleNode(PyNode):
@staticmethod
def from_iter(_, tup):
return tuple(tup)
@staticmethod
def to_iter(tup):
return None, tup
node_types: dict[Hashable, PyNode | None] = {
list: ListNode,
dict: DictNode,
tuple: TupleNode,
}
现在我们将定义 tree_flatten 和 tree_unflatten。
tree_flatten 将获取任何数据结构并将其转换为扁平化和 pytree。
tree_unflatten 将获取扁平化列表和 pytree,并将其转换回原始数据结构。
def tree_flatten(x: Any) -> tuple[list[Any], PyNode]:
def _flatten(x: Any) -> tuple[Iterable, PyNode]:
data_type = type(x)
node_type = node_types.get(data_type)
if node_type is None:
return [x], PyNode(node_type="leaf", metadata=x, child_tree=None)
node_metadata, children = node_type.to_iter(x)
children_flat, child_trees = [], []
for node in children:
flat, tree = _flatten(node)
children_flat.extend(flat)
child_trees.append(tree)
subtree = PyNode(
data_type, # store the base type instead of the specific node type
node_metadata,
tuple(child_trees),
)
return children_flat, subtree
flatten, pytree = _flatten(x)
return flatten, pytree
def tree_unflatten(flattened_list: list, tree: PyNode) -> Any:
def _unflatten(flattened_list: list, tree: PyNode) -> Any:
if tree.node_type == "leaf":
return next(flattened_list)
children = []
for child_tree in tree.child_tree:
children.append(_unflatten(flattened_list, child_tree))
node_type = node_types[tree.node_type]
return node_type.from_iter(tree.metadata, children)
return _unflatten(iter(flattened_list[:]), tree)
def display_tree(node: PyNode, indent: str = "") -> None:
if node.node_type == "leaf":
print(f"{indent}Leaf: {node.metadata}")
else:
node_type_name = node.node_type.__name__ if node.node_type != "leaf" else "leaf"
print(f"{indent}{node_type_name}: {node.metadata}")
for child in node.child_tree:
display_tree(child, indent + " ")
if __name__ == "__main__":
x = [1, (2, {"a": 3, "b": 4}, 5), [6, 7]]
flattened, tree = tree_flatten(x)
print(x)
print("\nTree structure:")
display_tree(tree)
print("\n")
print("Flattened:", flattened)
print("\n")
reconstructed = tree_unflatten(flattened, tree)
print("\nReconstructed:", reconstructed)
assert x == reconstructed, "Reconstruction failed"
print("Reconstruction successful!")
[1, (2, {'a': 3, 'b': 4}, 5), [6, 7]]
Tree structure:
list: None
Leaf: 1
tuple: None
Leaf: 2
dict: ['a', 'b']
Leaf: 3
Leaf: 4
Leaf: 5
list: None
Leaf: 6
Leaf: 7
Flattened: [1, 2, 3, 4, 5, 6, 7]
Reconstructed: [1, (2, {'a': 3, 'b': 4}, 5), [6, 7]]
Reconstruction successful!
我们有办法扁平化和反扁平化任何数据结构,现在我们需要扁平化和反扁平化函数。我们将创建一个函数,它接受函数和 pytree 并返回新函数和存储。一旦您评估函数,它将存储函数输出的 pytree。
(简化:这将把函数转换为接受扁平化列表并返回扁平化列表的函数)
def flatten_fun(func, in_tree):
store = {}
def flat_fun(*args_flat):
pytree_args = tree_unflatten(args_flat, in_tree)
out = func(*pytree_args)
out_flat, out_tree = tree_flatten(out)
assert len(store) == 0, "Store already has a value!"
store["tree"] = out_tree
return out_flat
return flat_fun, store
一些辅助函数
# These functions create nested structures of ones or zeros that match the input structure
def nested_ones_like(item):
"""Create a nested structure of ones with the same shape as the input."""
if isinstance(item, list):
return [nested_ones_like(x) for x in item]
if isinstance(item, tuple):
return tuple(nested_ones_like(x) for x in item)
if isinstance(item, dict):
return {k: nested_ones_like(v) for k, v in item.items()}
return 1.0 if isinstance(item, numbers.Number) else np.ones_like(item)
def nested_zero_like(item):
"""Create a nested structure of zeros with the same shape as the input."""
if isinstance(item, list):
return [nested_zero_like(x) for x in item]
if isinstance(item, tuple):
return tuple(nested_zero_like(x) for x in item)
if isinstance(item, dict):
return {k: nested_zero_like(v) for k, v in item.items()}
return 0.0 if isinstance(item, numbers.Number) else np.zeros_like(item)
现在我们将使用 pytree 重新实现 jvp 和 vjp。首先,我们将函数作为输入并将其扁平化。当我们需要评估函数时,我们将扁平化输入并将其传递给函数。然后我们将反扁平化输出并将其返回。
### Refinement of JVP
def jvp_flat(func, primals, tangents):
with interpreter_context(JVPInterpreter) as iptr:
tracers_in = [JVPBox(iptr, x, t) for x, t in zip(primals, tangents)]
outs = func(*tracers_in)
tracers_out = [full_raise(iptr, out) for out in outs]
primals_out, tangents_out = [], []
for t in tracers_out:
primals_out.append(t.primal)
tangents_out.append(t.tangent)
return primals_out, tangents_out
def jvp(func, primals, tangents):
# Flatten the primals and tangents into flat lists
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree2 = tree_flatten(tangents)
assert in_tree == in_tree2, "Input trees for primals and tangents must match"
# Flatten the function f according to the input tree structure
func_flat, out_tree = flatten_fun(func, in_tree)
# forward pass
primals_out_flat, tangents_out_flat = jvp_flat(
func_flat, primals_flat, tangents_flat
)
assert len(out_tree) == 1, "out tree dict must have only one item"
out_tree: PyNode = out_tree["tree"]
primals_out = tree_unflatten(primals_out_flat, out_tree)
tangents_out = tree_unflatten(tangents_out_flat, out_tree)
return primals_out, tangents_out
def deriv(func, argnums=0):
if isinstance(argnums, int):
argnums = [argnums]
def jvp_forward(*input_value):
# pass tangent 1 for argnums and 0 for others
tangents = tuple(
nested_ones_like(x) if idx in argnums else nested_zero_like(x)
for idx, x in enumerate(input_value)
)
_, gradient = jvp(func, input_value, tangents)
return gradient
return jvp_forward
def func(x, y):
k = tanh(x) * 2.0 + y * y
z0 = -y + k
z1 = y*k
return {" lets": z0,"f*in":z1, "go!": [x, y]}
print("## pytree.py ##")
x = 3.14
y = 2.71
print(deriv(func, argnums=0)(x, y))
## pytree.py ##
{' lets': 0.01493120808257803, 'f*in': 0.040463573903786465, 'go!': [1.0, 0.0]}
vjp 也一样
### Refinement of VJP
def add_grads(grad1, grad2):
if grad1 is None:
return grad2
return grad1 + grad2
def toposort(end_nodes):
def _toposort(seen, node):
result = []
if node not in seen:
seen.add(node)
for p in node.parents:
result.extend(_toposort(seen, p))
result.append(node)
return result
outs = []
seen = set()
topo_sorted = []
for end_node in end_nodes:
topo_sorted.extend(_toposort(seen, end_node))
for node in topo_sorted:
if node.parents:
outs.append(node)
result = reversed(outs)
return list(result)
def backward_pass(in_nodes, out_nodes, gradient):
node_map = {out_node: g for g, out_node in zip(gradient, out_nodes)}
topo_sorted = toposort(out_nodes)
for node in topo_sorted:
node_grad = node_map.pop(node)
input_grads = node.vjp(node_grad)
for input_grad, parent in zip(input_grads, node.parents):
node_map[parent] = add_grads(node_map.get(parent), input_grad)
return [node_map.get(node) for node in in_nodes]
def vjp_flat(func, args):
with interpreter_context(VJPInterpreter) as iptr:
box_in = [VJPBox(iptr, x, get_leaf_nodes()) for x in args]
outs = func(*box_in)
box_out = [full_raise(iptr, o) for o in outs]
in_nodes = [box.node for box in box_in]
out_nodes = [box.node for box in box_out]
out_primals = [box.primal for box in box_out]
def func_vjp(grad):
return backward_pass(in_nodes, out_nodes, grad)
return out_primals, func_vjp
def vjp(func, primals):
# Flatten the primals and tangents into flat lists
primals_flat, in_tree = tree_flatten(primals)
# Flatten the function f according to the input tree structure
func_flat, out_tree = flatten_fun(func, in_tree)
# forward pass
primals_out_flat, vjp_func = vjp_flat(
func_flat,
primals_flat,
)
assert len(out_tree) == 1, "out tree dict must have only one item"
out_tree: PyNode = out_tree["tree"]
primals_out = tree_unflatten(primals_out_flat, out_tree)
return primals_out, vjp_func
def grad(func, argnums=0):
if isinstance(argnums, int):
argnums = [argnums]
def vjp_func(*input_value):
result, vjp_func = vjp(func, input_value)
ones = nested_ones_like(result)
flat, _ = tree_flatten(ones)
grads = vjp_func(flat)
_, in_tree = tree_flatten(input_value)
grads = tree_unflatten(grads, in_tree)
grads = tuple(g for idx, g in enumerate(grads) if idx in argnums)
return grads[0] if len(argnums) == 1 else grads
return vjp_func
def value_and_grad(func, argnums=0):
if isinstance(argnums, int):
argnums = [argnums]
def vjp_forward(*input_value):
result, vjp_func = vjp(func, input_value)
# <hack>
# jax dont do this nasted ones funnny busniess
# it just requires output to be scalar
# but I you can pass one to all output nodes
# which is effectively like result = sum(result) I dont have redution op
# basically result.sum().backward() in pytorch
ones = nested_ones_like(result)
flat, _ = tree_flatten(ones)
# </hack>
# backward pass
grads = vjp_func(flat)
output, in_tree = tree_flatten(input_value)
grads = tree_unflatten(grads, in_tree)
grads = tuple(g for idx, g in enumerate(grads) if idx in argnums)
return result, grads[0] if len(argnums) == 1 else grads
return vjp_forward
现在您可以做这样的事情,您可以传递状态字典并获取该状态字典的梯度,并构建复杂的、可微分的程序。
def linear(state,inputs):
weight,bias = state["weights"], state["bias"]
total = 0
for w, x in zip(weight, inputs):
prod = w * x
total = total + prod
return total + bias
state = {"weights":[1,2,3], "bias": 1}
inputs = [0.3, 0.5, 0.7]
print(grad(linear)(state,inputs))
{'bias': 1.0, 'weights': [0.3, 0.5, 0.7]}
value,grads = value_and_grad(linear)(state,inputs)
print(value)
print(grads)
4.3999999999999995
{'bias': 1.0, 'weights': [0.3, 0.5, 0.7]}
vmap、pmap 和 jit
我不会详细介绍这些,毕竟这是 microjax。但为了给您一个直观的认识,就像我们为 jvp 添加切线,为 vjp 添加节点一样,对于 vmap,我们封装了形状信息,编写了批处理解释器,并在规则级别执行 lambda x: [f(x[0]) for _ in range(x.shape[0])]
,是的,它只是一个映射。如果您并行执行此映射,您将获得 pmap,
就像我们为 jit 携带了切线信息一样,我们携带了函数的所有历史(图),并进行图优化,然后将其编译为 xla。当您第二次调用 jit 函数时,它会流向优化的图,而不是原始函数。这使得它更快。