MicroJAX

社区文章 发布于2024年8月25日

image/gif

这是一个关于如何构建一个微型 Jax/mlx 风格的转换引擎的微型开发博客,因为我发现网上没有简单易懂的解释函数转换引擎的资料。就像 Karpathy-senpai 的 micrograd 是 PyTorch 的一个简单版本一样,microjax 是 Jax 的一个更简单版本。

Github: microjax - 如果你觉得有用,请点个⭐。

Notebook: 分步学习

如果你有任何问题/更正,请在 Twitter 上私信我 @shxf0072

🤗 博客是 pythonstuff 的镜像

警告:此项目旨在通过递增的复杂性来帮助理解,而非追求绝对正确性。

现在有 Autodidax,但它相当难,使用了大量术语和 lambda 函数。image/png 我因此有了灵感 :3

很久以前,谷歌沉迷于 TensorFlow。谷歌有定制芯片 TPU 和 XLA,一个非常智能的线性代数编译器,让它飞速运行,但 TensorFlow 用起来很糟糕。其他新兴框架如 PyTorch 具有动态特性,但表现不佳。Jax 是谷歌的一个实验项目,它将 Python 跟踪转换为 XLA,然后可以用 MLIR 编译并在 GPU 和 TPU 等加速器上运行。我们可以将 Jax 分为两部分:函数转换和 XLA 编译器。函数转换使我们能够计算梯度、Hessian 并定义类似 vmap 的转换,而 XLA IR 转换则带来了速度。在这里,我们只介绍一个简单的函数转换引擎,因此 Jax 的 X 部分不在讨论范围内,但嘿,MLX 有 X,所以我不在乎,我称它为 microjax。尽管 Jax 最近变得更主流了,但 Jax 的故事可以追溯到 autograd。

Autograd 有一个更具创新性的目标。你需要为机器学习或科学计算计算梯度,而 Python 是科学计算的语言。那么,如何使 Python 这样的高级解释型语言可微分呢?

当然是编写一个解释器!

基础解释器

首先,我们将从原语开始。这些是 Jax 中的基本操作,所有其他操作都将基于它们。顺便说一句,这里只涉及标量值,因此更容易理解(虽然技术上支持 np 数组,但为了简化,我不会定义数组操作)。

from __future__ import annotations

import math
from contextlib import contextmanager
from typing import Any, Callable
class OPS:
    ADD = "add"
    MUL = "mul"
    NEG = "neg"
    RECIP = "recip"
    EXP = "exp"
    SIN = "sin"
# ik i can make this enum, i choose not too :P

对于大多数情况,你只需要这些操作,如果你喜欢,也可以定义自己的操作。

让我们从基础解释器开始,这就像一个抽象类。

class Interpreter:
    def __init__(self, level: int = 0, *args, **kwargs):
        self.level = level

    def process_primitive(self, prim, boxes, params):
        "in this function, either you process primitives or you unbox and send to lower level interpreter"
        raise NotImplementedError

image/png

在 JAX 中,这称为 Trace。Trace 跟踪 tracers,它们只是封装的值。我们把值和一些额外的信息(例如是否需要计算梯度,以及形状和要使用的解释器)一起放在盒子里。对于每种独特的转换类型,我们都会有一个该类型的盒子和一个该类型的解释器。

现在我们将有多个解释器。为了知道封装的值应该在哪个上下文中评估,我们需要跟踪解释器。我们将使用最常见的数据结构:栈来实现这一点。

STACK: list[Interpreter] = []


def push_interpreter(interpreter: Interpreter):
    STACK.append(interpreter)
    return STACK


def pop_interpreter() -> Interpreter:
    return STACK.pop()


@contextmanager
def interpreter_context(interpreter_type: Interpreter):
    stack_item = interpreter_type(level=len(STACK))
    push_interpreter(stack_item)
    try:
        yield stack_item
    finally:
        pop_interpreter()

现在我们来定义 Box,Box(tracer)是实际流经你定义的函数的值。我们需要重写一些双下划线方法,使其与 Python 兼容。

class Box:
    _interpreter: Interpreter

    def aval(self):
        raise NotImplementedError

    def __add__(self, other):
        return add(self, other)

    def __radd__(self, other):
        return add(other, self)

    def __mul__(self, other):
        return mul(self, other)

    def __rmul__(self, other):
        return mul(other, self)

    def __neg__(self):
        return neg(self)

    def __sub__(self, other):
        return add(self, neg(other))

    def __rsub__(self, other):
        return add(other, neg(self))

    def __truediv__(self, other):
        return mul(self, recip(other))

    def __rtruediv__(self, other):
        return mul(other, recip(self))

    def __iadd__(self, other):
        return add(self, other)

    def __imul__(self, other):
        return mul(self, other)

    def __isub__(self, other):
        return add(self, neg(other))

    def __itruediv__(self, other):
        return mul(self, recip(other))

# dont worry about this undefined functions, we add them later

我们抽象类快完成了,只剩几个辅助函数

当函数接收到多个封装值时,我们需要为它们找到最高级别的解释器。

def find_top_interpreter(args):
    """
    find the top level interpreter for the given arguments
    """
    interpreters = []
    for item in args:
        if isinstance(item, Box):
            interpreters.append(item._interpreter)

    if interpreters:
        return max(interpreters, key=lambda x: x.level)

    # if no interpreters are found, return the default EvalInterpreter
    return STACK[0]

如果一个值在第 2 层,另一个在第 3 层,我们需要用这个函数将它们提升到第 3 层。

def full_raise(interpreter: Interpreter | Any, out) -> Box | JVPBox:
    """
    if interpreter need values boxed
    if out is not boxed, box it (using interpreter.pure)
    ie. raise out to the box level
    """
    if not isinstance(out, Box):
        return interpreter.pure(out)
    return out

每个封装值都将分配一个解释器。每个解释器都有一个表示其在栈中位置的级别。`find_top_interpreter` 函数将找到所有解释器中级别最高的那个。`full_raise` 将把一个值提升到栈中当前解释器的级别。`bind_single` 只是一个小的包装器,用于处理 `bind` 返回的元组。

def bind(prim, *args, **params):
    interpreter = find_top_interpreter(args)
    
    # this will raise the boxes to the top level
    # eg converts primitive values to Boxes if interpreter is not the top level
    boxes = [full_raise(interpreter, arg) for arg in args]
    
    outs = interpreter.process_primitive(prim, boxes, params)
    
    return [out for out in outs]


def bind_single(prim, *args, **params):
    (out,) = bind(prim, *args, **params)
    return out

“bind”是重要的函数,它将调用解释器。

原语

这些是构建块,所有其他函数都将在此基础上构建。我喜欢称它们为模拟函数,因为它们不真正计算任何东西;它们更像是箱子到解释器的路由器。

def add(*args):
    return bind_single(OPS.ADD, *args)

def mul(*args):
    return bind_single(OPS.MUL, *args)

def neg(x):
    return bind_single(OPS.NEG, x)

def recip(x):
    return bind_single(OPS.RECIP, x)

def exp(x):
    return bind_single(OPS.EXP, x)

def sin(x):
    return bind_single(OPS.SIN, x)

原语就像模拟函数。当你调用 `mul(Box1(3), Box1(2))` 时,它会为 `Box1(3)` 和 `Box1(2)` 找到解释器,然后找到其中级别最高的解释器。它会解封装这些值,并告诉该解释器处理这些原语。对于每种操作类型,都有一个原语操作函数。

复合函数建立在原语之上。只要你可以用原语表达你的函数,你就可以使用任意复杂的函数。

def cos(x):
    return sin(x + math.pi / 2)

def sigmoid(x):
    return 1 / (1 + exp(-x))

def tanh(x): 
    return 2 * sigmoid(2 * x) - 1

def silu(x):
    return x * sigmoid(x)

Eval 解释器

即使我们有很好的抽象层,最终总需要有人来运行 `add` 或 `mul` 函数。这将由 eval 解释器完成。我们首先定义评估规则,然后定义 eval 解释器。

class EvalRules:
    def __init__(self):
        self.rules = {
            OPS.ADD: self.add,
            OPS.MUL: self.mul,
            OPS.NEG: self.neg,
            OPS.RECIP: self.recip,
            OPS.EXP: self.exp,
            OPS.SIN: self.sin,
        }

    def __getitem__(self, op):
        return self.rules[op]

    def add(self, primals, *args):
        x, y = primals
        return [x + y]

    def mul(self, primals, *args):
        x, y = primals
        return [x * y]

    def neg(self, primals, *args):
        (x,) = primals
        return [-x]

    def recip(self, primals, *args):
        (x,) = primals
        return [1 / x]

    def exp(self, primals, *args):
        (x,) = primals
        return [math.exp(x)]

    def sin(self, primals, *args):
        (x,) = primals
        return [math.sin(x)]

我们不期望任何封装值传递给 eval 解释器,因此我们可以直接在值上调用函数。这很简单:接收参数并返回结果。

class EvalInterpreter(Interpreter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rules = EvalRules()

    def pure(self, val):
        return val

    def process_primitive(self, prim, values, params):
        return self.rules[prim](values, *params)

基本解释器是 `EvalInterpreter`,现在我们运行基本程序,由于评估是原语的末尾,我们将在栈的底部推入 eval 解释器。

push_interpreter(EvalInterpreter())

现在我们可以启动基本程序了。

def func(x):
    return sin(x)*exp(x) + tanh(x)

x = 3.14

func(x)
1.033056645880499

image/png

你可能会觉得这样 ^,但这些抽象将用于构建更复杂的 AD。

自动微分

假设我们有函数 y=D(C(B(A(x)))) y = D(C(B(A(x))))

如果 x 是一个向量,那么它的梯度可以通过雅可比矩阵计算。

yx=Jv \frac{\partial y}{\partial x} = J \cdot v

如果你不知道什么是雅可比矩阵,它只是偏导数组成的矩阵,如果需要复习,请观看此视频 链接

现在我们可以用原语来定义我们的函数了。

y=D(c),c=C(b),b=B(a),a=A(x) y = D(c), \quad c = C(b), \quad b = B(a), \quad a = A(x)

同样地,我们可以将雅可比矩阵定义为每个函数导数的乘积,或者数学上我们可以通过链式法则定义大雅可比矩阵。F(x)=yx F'(x) = \frac{\partial y}{\partial x} yx=yccbbaax \frac{\partial y}{\partial x}= \frac{\partial y}{\partial c} \cdot \frac{\partial c}{\partial b} \cdot \frac{\partial b}{\partial a} \cdot \frac{\partial a}{\partial x}

yc=D(c)cb=C(b)ba=B(a)ax=A(x) \frac{\partial y}{\partial c} = D'(c) \quad \frac{\partial c}{\partial b} = C'(b) \quad \frac{\partial b}{\partial a} = B'(a) \quad \frac{\partial a}{\partial x} = A'(x)

因此,如果你将雅可比矩阵乘以向量,你将得到梯度。

F(x)=yx=[yx1yxn] F'(x) = \frac{\partial y}{\partial x} = \left[\frac{\partial y}{\partial x_1} \cdots \frac{\partial y}{\partial x_n}\right]

你可能从线性代数课上知道,如果乘法是结合律的,那么你可以从任何一边乘以雅可比链。

yc(cb(ba(ax)))=(((yc)cb)ba)ax \frac{\partial y}{\partial c} \cdot \left(\frac{\partial c}{\partial b} \cdot \left(\frac{\partial b}{\partial a} \cdot \left(\frac{\partial a}{\partial x}\right)\right)\right) = \left(\left( \left( \frac{\partial y}{\partial c} \right) \cdot \frac{\partial c}{\partial b}\right) \cdot \frac{\partial b}{\partial a}\right) \cdot \frac{\partial a}{\partial x}

这个等式成立,因为矩阵乘法是结合律的。

雅可比-向量积 (JVP)

雅可比-向量积,或前向自动微分,计算雅可比矩阵和向量的乘积。

JVP=Jv \text{JVP} = J \cdot v

其中 $J$ 是雅可比矩阵, $v$ 是一个向量。

在我们的链式法则示例中

JVP=yc(cb(ba(axv))) JVP = \frac{\partial y}{\partial c} \cdot \left(\frac{\partial c}{\partial b} \cdot \left(\frac{\partial b}{\partial a} \cdot \left(\frac{\partial a}{\partial x} \cdot v \right)\right)\right)

JVP 将导数从输入到输出向前传播通过计算图。

我们可以按照函数评估的相同方向计算梯度 A->B->C->D,dA->dB->dC->dD。

向量-雅可比积 (VJP)

向量-雅可比积,或反向自动微分,计算向量和雅可比矩阵的乘积。

VJP=vTJ \text{VJP} = v^T \cdot J

其中 $v^T$ 是向量 $v$ 的转置, $J$ 是雅可比矩阵。

在我们的链式法则示例中

VJP=(((vTyc)cb)ba)ax VJP = \left(\left(\left(v^T \cdot \frac{\partial y}{\partial c}\right) \cdot \frac{\partial c}{\partial b}\right) \cdot \frac{\partial b}{\partial a}\right) \cdot \frac{\partial a}{\partial x}

VJP 将导数从输出到输入反向传播通过计算图。这是训练神经网络中常用的反向传播算法的基础。

正向自动微分 (JVP)

前向自动微分真的很容易。我们将把值及其导数(初始化为 1)装箱。当我们向前计算函数时,我们也会计算其导数。

class JVPBox(Box):
    def __init__(self, interpretor: Interpreter, primal, tangent) -> None:
        super().__init__()
        self._interpreter = interpretor
        self.primal = primal
        self.tangent = tangent
    
    def __repr__(self):
        return f"JVPBox (primal={self.primal}, tangent={self.tangent})"

值称为原值,其导数称为切线。

当您有一个函数 f(x) = sin(x) 时,它的导数表示该点处切线的斜率。由于每个点的斜率都在变化,这些变化由梯度函数定义。因此,尽管我们使用 cos(x) 得到 sin(x) 在原点 x 处的导数,但我们称之为切线。它应该表示该点处切线的斜率。

image/png

我们将为每个原始操作定义规则。

注意,现在要定义这个规则,你只能使用你定义的原始函数,所以整个操作需要是封闭的,即只能使用上面定义的 cos,不能使用 math.cos(x)。

class JVPRules:
    def __init__(self):
        self.rules = {
            OPS.ADD: self.add,
            OPS.MUL: self.mul,
            OPS.NEG: self.neg,
            OPS.RECIP: self.recip,
            OPS.EXP: self.exp,
            OPS.SIN: self.sin,
        }

    # dont forget to return tuple(primals),tuple(tangents)
    def __getitem__(self, op):
        return self.rules[op]

    @staticmethod
    def add(primals, tangents):
        (x, y), (x_dot, y_dot) = primals, tangents
        return (x + y,), (x_dot + y_dot,)

    @staticmethod
    def mul(primals, tangents):
        (x, y), (x_dot, y_dot) = primals, tangents
        return (x * y,), (x_dot * y + x * y_dot,)

    @staticmethod
    def neg(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return (-x,), (-x_dot,)

    @staticmethod
    def recip(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = 1 / x
        return (y,), (-y * y * x_dot,)

    @staticmethod
    def exp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = exp(x)
        return (y,), (y * x_dot,)

    @staticmethod
    def sin(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return (sin(x),), (cos(x) * x_dot,)

JVP 解释器

现在我们来构建第一个真正的解释器。我们将值及其切线装箱。首先,我们解开值,并处理这些原值和切线。然后,我们将结果装箱。

image/png

class JVPInterpreter(Interpreter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rules = JVPRules()

    def pure(self, val):
        return JVPBox(self, val, 0.0)

    def process_primitive(self, prim, boxes, params):
        # unbox values
        primals = [box.primal for box in boxes]
        tangents = [box.tangent for box in boxes]
        
        # process primitive
        op = self.rules[prim]
        primals_out, tangents_out = op(primals, tangents, **params)

        # box values
        result = []
        for p, t in zip(primals_out, tangents_out):
            result.append(JVPBox(self, p, t))
            
        return result

JVP 简单来说就是一个函数,它接受一个函数、其输入和其切线,并返回该函数的输出和其切线。

def jvp_simple(func, primals, tangents):
    with interpreter_context(JVPInterpreter) as iptr:
        box_in = [JVPBox(iptr, x, t) for x, t in zip(primals, tangents)]
        out = func(*box_in)
        box_out = full_raise(iptr, out)
        primal_out, tangent_out = box_out.primal, box_out.tangent
    return primal_out, tangent_out


def func(x):
    return sin(x) + exp(x)

x,xdot = 3.14,1

y,y_dot = jvp_simple(func, (x,), (xdot,))

print(y,y_dot)
23.10545951163867 22.103868126994644

现在我们来定义一个包装函数,它将接收一个函数,修改输入以执行 JVP 并返回梯度。


def deriv(function):
    def jvp_forward(input_value):
        _, gradient = jvp_simple(function, (input_value,), (1,))
        return gradient

    return jvp_forward

def func(x):
    return sin(x) 
x = 3.14

print(func(x))
deriv_func = deriv(func)
print(deriv_func(x))
0.0015926529164868282
-0.9999987317275395

但是如果我们对导数函数求导会发生什么呢?

deriv_func = deriv(func)
deriv_func2 = deriv(deriv_func)
print(deriv_func2(x))
-0.0015926529164865067

我们得到二阶导数或 Hessian。

image/png

这里发生的是我们正在双重封装值,`BOX2( BOX1(原始值, 导数1), 导数2 )`,但是由于函数在组合下是封闭的,我们可以任意多次组合导数,

例如,我们来看一个函数

f(x)=xsin(x) f(x) = x \cdot sin(x)

import matplotlib.pyplot as plt
import random
plt.style.use("dark_background")

plt.figure(figsize=(10,5))

def forward(func, vec):
    return [func(x) for x in vec]


x = [i / 10 for i in range(-50, 50)]

def func(x):
    return x*sin(x)

primals = forward(func, x)

# first derivative
f_prime = forward(deriv(func), x)
f_prime2 = forward(deriv(deriv(func)), x)
f_prime3 = forward(deriv(deriv(deriv(func))), x)
f_prime4 = forward(deriv(deriv(deriv(deriv(func)))), x)

plt.plot(x, primals, label='f(x) = x*sin(x)')
plt.plot(x, f_prime, label="f'(x)")
plt.plot(x, f_prime2, label="f''(x)")
plt.plot(x, f_prime3, label="f'''(x)")
plt.plot(x, f_prime4, label="f''''(x)")

plt.grid(True, linestyle="--", alpha=0.15)
plt.box(False)
plt.legend(loc='upper right')
plt.show()

image/png

反向自动微分 (VJP)

就像我们在 JVP 中通过用切线封装值来添加额外信息一样,在 VJP 中,我们将添加额外的信息节点。这将创建一个可以反向遍历的图。因此称为反向自动微分。

网上关于反向自动微分的解释很多,推特上充斥着“我在 X 中实现了自动微分”的帖子。所以我不会详细解释。我认为最好的解释是 Karpathy(senpai)的视频 链接。看完这个你就会明白是怎么回事了。

定义节点,它将保存反向传播函数及其父节点。

class Node:
    def __init__(self, vjp: Callable, parents: list[Node]) -> None:
        self.vjp = vjp
        self.parents = parents

    @property
    def is_leaf(self):
        return len(self.parents) == 0


def get_leaf_nodes() -> Node:
    return Node(None, [])

反向传播规则

需要注意的是,即使在反向传播中,我们也只调用我们定义的原始函数。与 PyTorch 不同,在 PyTorch 中你可以在反向传播中做任何你想要的事情。在 Jax 中,你的反向传播需要是封闭的,即只能通过组合现有操作来实现。

因此,Torch 相对容易扩展(FAFO),而 Jax 则不然。



class VJPRules:
    def __init__(self):
        self.rules = {
            OPS.ADD: self.add,
            OPS.MUL: self.mul,
            OPS.NEG: self.neg,
            OPS.RECIP: self.recip,
            OPS.EXP: self.exp,
            OPS.SIN: self.sin,
        }
        """
        Jax define one of vjp or jvp rules
        it derives one from the other 
        but this is much more simple to understand
        """

    def __getitem__(self, op):
        return self.rules[op]

    def add(self, primals):
        x, y = primals

        def vjp_add(grad):
            return grad, grad

        return (x + y,), vjp_add

    def mul(self, primals):
        x, y = primals

        def vjp_mul(grad):
            return grad * y, grad * x

        return (x * y,), vjp_mul

    def tanh(self, primals):
        (x,) = primals
        y = tanh(x)

        def vjp_tanh(grad):
            return ((1 - y * y) * grad,)

        return (y,), vjp_tanh

    def neg(self, primals):
        (x,) = primals

        def vjp_neg(grad):
            return (-grad,)

        return (-x,), vjp_neg

    def recip(self, primals):
        (x,) = primals
        y = 1 / x

        def vjp_recip(grad):
            return (-y * y * grad,)

        return (y,), vjp_recip

    def exp(self, primals):
        (x,) = primals
        y = exp(x)

        def vjp_exp(grad):
            return (y * grad,)

        return (y,), vjp_exp

    def sin(self, primals):
        (x,) = primals
        y = sin(x)

        def vjp_sin(grad):
            return (cos(x) * grad,)

        return (y,), vjp_sin

VJP 盒子,我们把原始值和它的节点装在里面。

class VJPBox(Box):
    def __init__(self, interpreter: VJPInterpreter, primal, node: Node) -> None:
        super().__init__()
        self._interpreter = interpreter
        self.primal = primal
        self.node = node
        
    def pure(self,value):
        return VJPBox(self._interpreter, value, get_leaf_nodes())

    def __repr__(self):
        return f"VJPBox (primal={self.primal}, node={self.node})"


class VJPInterpreter(Interpreter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rules = VJPRules()

    def pure(self, val):
        return VJPBox(self, val, get_leaf_nodes())

    def process_primitive(self, prim, boxes, params):
        primals_in = [box.primal for box in boxes]
        nodes_in = [box.node for box in boxes]
        op = self.rules[prim]
        primals_out, vjp_out = op(primals_in, **params)
        nodes_out = [Node(vjp_out, nodes_in)]
        result = []
        for p, n in zip(primals_out, nodes_out):
            result.append(VJPBox(self, p, n))
        return result

正向传播将给出输出和图,这将在反向传播中用于获取梯度。再次,观看 Karpathy 的视频以获得直观理解。


def toposort(end_node):
    def _toposort(seen, node):
        result = []
        if id(node) not in seen:
            seen.add(id(node))
            for p in node.parents:
                result.extend(_toposort(seen, p))
            result.append(node)
        return result

    return reversed([n for n in _toposort(set(), end_node) if n.parents])


def add_grads(grad1, grad2):
    if grad1 is None:
        return grad2
    return grad1 + grad2


def backward_pass(in_nodes, out_node, gradient):
    node_map = {id(out_node): gradient}

    topo_sorted = toposort(out_node)
    for node in topo_sorted:
        node_grad = node_map.pop(id(node))

        input_grads = node.vjp(node_grad)

        for input_grad, parent in zip(input_grads, node.parents):
            parent_id = id(parent)
            node_map[parent_id] = add_grads(node_map.get(parent_id), input_grad)

    return [node_map.get(id(node)) for node in in_nodes]

现在,一个简单的 VJP 函数,它将接收一个函数,为所有输入添加叶节点,并返回输出和反向传播函数。


def vjp_simple(func, *args):
    with interpreter_context(VJPInterpreter) as iptr:
        box_in = [VJPBox(iptr, x, get_leaf_nodes()) for x in args]
        out = func(*box_in)
        box_out = full_raise(iptr, out)
        in_nodes = [box.node for box in box_in]
        out_node = box_out.node
        primal_out = box_out.primal

        def func_vjp(grad):
            return backward_pass(in_nodes, out_node, grad)

    return primal_out, func_vjp

grad 是 VJP 的一个小包装器。

def grad(func):
    def grad_func(*args):
        _, backward = vjp_simple(func, *args)
        return backward(1)[0]

    return grad_func


def func(x):
    # return x*x
    return 3 * x * x * x + 2 * x * x + 2 * x

print(grad(func)(2))
print(grad(grad(func))(2))

#    46
#    40

同样,你可以对函数进行任意多次求导,例如

def func(x):
    return tanh(x)

def forward(func, vec):
    return [func(x) for x in vec]

x = [i / 25 for i in range(-100, 100)]

plt.figure(figsize=(15,8))
primals = forward(func, x)

# first derivative
f_prime = forward(deriv(func), x)
f_prime2 = forward(deriv(deriv(func)), x)
f_prime3 = forward(deriv(deriv(deriv(func))), x)
f_prime4 = forward(deriv(deriv(deriv(deriv(func)))), x)

plt.plot(x, primals, label='f(x) = x*sin(x)')
plt.plot(x, f_prime, label="f'(x)")
plt.plot(x, f_prime2, label="f''(x)")
plt.plot(x, f_prime3, label="f'''(x)")
plt.plot(x, f_prime4, label="f''''(x)")

plt.grid(True, linestyle="--", alpha=0.15)
plt.box(False)
plt.legend(loc='upper right')
plt.show()

image/png

函数组合

由于 jvp 和 vjp 都是通过原始函数定义的,因此您可以在正向模式自动微分上执行反向传播。

def func(x): 
    return sin(x)+ tanh(x)*exp(x)

print("forward on backward")
print(deriv(grad(func))(2))

print("backward on forward")
print(grad(deriv(func))(2))
forward on backward
6.251514736700764
backward on forward
6.251514736700765

现在有一个问题,您只能将原始数据类型传递给函数,例如

def func(inputs):
    x,y = inputs
    return tanh(x) + y 

inputs = (6,9)

grad = deriv(func)

#print(grad(inputs))

如果您运行此代码,您将收到错误,

      1 def func(inputs):
----> 2     x,y = inputs
      3     return tanh(x) + y

TypeError: cannot unpack non-iterable JVPBox object

我们希望输入列表中包含盒装值,但此处我们的列表本身是盒装的 Box(list([x,y])) ,我们需要类似 [Box(x),Box(y)] 的东西,这就是 pytree 的用武之地。

Pytree

image/png

Pytree 是一种表示嵌套数据的数据结构。它解决了值装箱和拆箱的问题。它可以获取任何数据结构并将其转换为扁平化和 pytree。现在这个扁平化只是原始数据类型的列表,所以我们可以遍历它们并进行装箱。树保留了数据的结构,因此在装箱值之后,我们可以重建原始数据结构,其中每个值都已装箱。

from __future__ import annotations

import numpy as np
import numbers

from typing import Any, Hashable, Iterable

PyTreeTypes = list | dict | tuple | Any

每个 pytree 都有一个类型、元数据和子树。

类型是数据结构的类型,元数据是数据本身,子树是数据结构的子树。


class PyNode:
    def __init__(
        self, node_type: type, metadata: Hashable = None, child_tree: "PyNode" = None
    ):
        self.node_type = node_type
        self.metadata = metadata
        self.child_tree = child_tree

    def __repr__(self):
        s = f"({self.node_type.__name__ if self.node_type !='leaf' else 'leaf'}"
        if isinstance(self.metadata, np.ndarray) or self.metadata:
            s += f":{self.metadata.__class__.__name__}"
        if self.child_tree is not None:
            s += f",{self.child_tree}"
        return s + ")"

    @staticmethod
    def from_iter(pytree) -> tuple[Hashable, "PyNode"]:
        raise NotImplementedError("Not implemented")

    @staticmethod
    def to_iter() -> PyTreeTypes:
        raise NotImplementedError("Not implemented")

    def __eq__(self, other: PyNode) -> bool:
        if self.node_type != other.node_type:
            return False
        if self.child_tree != other.child_tree:
            return False
        return True


class ListNode(PyNode):
    @staticmethod
    def to_iter(lst):
        return None, lst

    @staticmethod
    def from_iter(_, iterable):
        return list(iterable)


class DictNode(PyNode):
    @staticmethod
    def from_iter(keys, vals):
        return dict(zip(keys, vals))

    @staticmethod
    def to_iter(dct):
        keys, values = [], []
        for key, value in sorted(dct.items()):
            keys.append(key)
            values.append(value)
        return keys, values


class TupleNode(PyNode):
    @staticmethod
    def from_iter(_, tup):
        return tuple(tup)

    @staticmethod
    def to_iter(tup):
        return None, tup


node_types: dict[Hashable, PyNode | None] = {
    list: ListNode,
    dict: DictNode,
    tuple: TupleNode,
}

现在我们将定义 tree_flatten 和 tree_unflatten。

tree_flatten 将获取任何数据结构并将其转换为扁平化和 pytree。

tree_unflatten 将获取扁平化列表和 pytree,并将其转换回原始数据结构。


def tree_flatten(x: Any) -> tuple[list[Any], PyNode]:
    def _flatten(x: Any) -> tuple[Iterable, PyNode]:
        data_type = type(x)
        node_type = node_types.get(data_type)
        if node_type is None:
            return [x], PyNode(node_type="leaf", metadata=x, child_tree=None)

        node_metadata, children = node_type.to_iter(x)

        children_flat, child_trees = [], []
        for node in children:
            flat, tree = _flatten(node)
            children_flat.extend(flat)
            child_trees.append(tree)

        subtree = PyNode(
            data_type,  # store the base type instead of the specific node type
            node_metadata,
            tuple(child_trees),
        )
        return children_flat, subtree

    flatten, pytree = _flatten(x)
    return flatten, pytree


def tree_unflatten(flattened_list: list, tree: PyNode) -> Any:
    def _unflatten(flattened_list: list, tree: PyNode) -> Any:
        if tree.node_type == "leaf":
            return next(flattened_list)

        children = []
        for child_tree in tree.child_tree:
            children.append(_unflatten(flattened_list, child_tree))

        node_type = node_types[tree.node_type]
        return node_type.from_iter(tree.metadata, children)

    return _unflatten(iter(flattened_list[:]), tree)




def display_tree(node: PyNode, indent: str = "") -> None:
    if node.node_type == "leaf":
        print(f"{indent}Leaf: {node.metadata}")
    else:
        node_type_name = node.node_type.__name__ if node.node_type != "leaf" else "leaf"
        print(f"{indent}{node_type_name}: {node.metadata}")
        for child in node.child_tree:
            display_tree(child, indent + "    ")


if __name__ == "__main__":
    x = [1, (2, {"a": 3, "b": 4}, 5), [6, 7]]
    flattened, tree = tree_flatten(x)
    print(x)
    print("\nTree structure:")
    display_tree(tree)
    print("\n")
    print("Flattened:", flattened)
    print("\n")

    reconstructed = tree_unflatten(flattened, tree)
    print("\nReconstructed:", reconstructed)
    assert x == reconstructed, "Reconstruction failed"
    print("Reconstruction successful!")

    [1, (2, {'a': 3, 'b': 4}, 5), [6, 7]]
    
    Tree structure:
    list: None
        Leaf: 1
        tuple: None
            Leaf: 2
            dict: ['a', 'b']
                Leaf: 3
                Leaf: 4
            Leaf: 5
        list: None
            Leaf: 6
            Leaf: 7
    
    
    Flattened: [1, 2, 3, 4, 5, 6, 7]
    
    
    
    Reconstructed: [1, (2, {'a': 3, 'b': 4}, 5), [6, 7]]
    Reconstruction successful!

我们有办法扁平化和反扁平化任何数据结构,现在我们需要扁平化和反扁平化函数。我们将创建一个函数,它接受函数和 pytree 并返回新函数和存储。一旦您评估函数,它将存储函数输出的 pytree。
(简化:这将把函数转换为接受扁平化列表并返回扁平化列表的函数)

def flatten_fun(func, in_tree):
    store = {}

    def flat_fun(*args_flat):
        pytree_args = tree_unflatten(args_flat, in_tree)
        out = func(*pytree_args)
        out_flat, out_tree = tree_flatten(out)
        assert len(store) == 0, "Store already has a value!"
        store["tree"] = out_tree
        return out_flat

    return flat_fun, store

一些辅助函数


# These functions create nested structures of ones or zeros that match the input structure

def nested_ones_like(item):
    """Create a nested structure of ones with the same shape as the input."""
    if isinstance(item, list):
        return [nested_ones_like(x) for x in item]
    if isinstance(item, tuple):
        return tuple(nested_ones_like(x) for x in item)
    if isinstance(item, dict):
        return {k: nested_ones_like(v) for k, v in item.items()}
    return 1.0 if isinstance(item, numbers.Number) else np.ones_like(item)


def nested_zero_like(item):
    """Create a nested structure of zeros with the same shape as the input."""
    if isinstance(item, list):
        return [nested_zero_like(x) for x in item]
    if isinstance(item, tuple):
        return tuple(nested_zero_like(x) for x in item)
    if isinstance(item, dict):
        return {k: nested_zero_like(v) for k, v in item.items()}
    return 0.0 if isinstance(item, numbers.Number) else np.zeros_like(item)

现在我们将使用 pytree 重新实现 jvp 和 vjp。首先,我们将函数作为输入并将其扁平化。当我们需要评估函数时,我们将扁平化输入并将其传递给函数。然后我们将反扁平化输出并将其返回。


### Refinement of JVP
def jvp_flat(func, primals, tangents):
    with interpreter_context(JVPInterpreter) as iptr:
        tracers_in = [JVPBox(iptr, x, t) for x, t in zip(primals, tangents)]

        outs = func(*tracers_in)

        tracers_out = [full_raise(iptr, out) for out in outs]

        primals_out, tangents_out = [], []
        for t in tracers_out:
            primals_out.append(t.primal)
            tangents_out.append(t.tangent)

    return primals_out, tangents_out


def jvp(func, primals, tangents):
    # Flatten the primals and tangents into flat lists
    primals_flat, in_tree = tree_flatten(primals)
    tangents_flat, in_tree2 = tree_flatten(tangents)
    assert in_tree == in_tree2, "Input trees for primals and tangents must match"

    # Flatten the function f according to the input tree structure
    func_flat, out_tree = flatten_fun(func, in_tree)

    # forward pass
    primals_out_flat, tangents_out_flat = jvp_flat(
        func_flat, primals_flat, tangents_flat
    )

    assert len(out_tree) == 1, "out tree dict must have only one item"
    out_tree: PyNode = out_tree["tree"]

    primals_out = tree_unflatten(primals_out_flat, out_tree)
    tangents_out = tree_unflatten(tangents_out_flat, out_tree)

    return primals_out, tangents_out


def deriv(func, argnums=0):
    if isinstance(argnums, int):
        argnums = [argnums]

    def jvp_forward(*input_value):
        # pass tangent 1 for argnums and 0 for others
        tangents = tuple(
            nested_ones_like(x) if idx in argnums else nested_zero_like(x)
            for idx, x in enumerate(input_value)
        )

        _, gradient = jvp(func, input_value, tangents)

        return gradient

    return jvp_forward

def func(x, y):
    k = tanh(x) * 2.0 + y * y
    z0 = -y + k 
    z1 = y*k
    return {" lets": z0,"f*in":z1, "go!": [x, y]}



print("## pytree.py ##")
x = 3.14
y = 2.71
print(deriv(func, argnums=0)(x, y))

    ## pytree.py ##
    {' lets': 0.01493120808257803, 'f*in': 0.040463573903786465, 'go!': [1.0, 0.0]}

vjp 也一样


### Refinement of VJP


def add_grads(grad1, grad2):
    if grad1 is None:
        return grad2
    return grad1 + grad2


def toposort(end_nodes):
    def _toposort(seen, node):
        result = []
        if node not in seen:
            seen.add(node)
            for p in node.parents:
                result.extend(_toposort(seen, p))
            result.append(node)
        return result

    outs = []
    seen = set()
    topo_sorted = []
    for end_node in end_nodes:
        topo_sorted.extend(_toposort(seen, end_node))

    for node in topo_sorted:
        if node.parents:
            outs.append(node)
    result = reversed(outs)
    return list(result)


def backward_pass(in_nodes, out_nodes, gradient):
    node_map = {out_node: g for g, out_node in zip(gradient, out_nodes)}

    topo_sorted = toposort(out_nodes)
    for node in topo_sorted:
        node_grad = node_map.pop(node)

        input_grads = node.vjp(node_grad)

        for input_grad, parent in zip(input_grads, node.parents):
            node_map[parent] = add_grads(node_map.get(parent), input_grad)

    return [node_map.get(node) for node in in_nodes]


def vjp_flat(func, args):
    with interpreter_context(VJPInterpreter) as iptr:
        box_in = [VJPBox(iptr, x, get_leaf_nodes()) for x in args]
        outs = func(*box_in)
        box_out = [full_raise(iptr, o) for o in outs]
        in_nodes = [box.node for box in box_in]
        out_nodes = [box.node for box in box_out]
        out_primals = [box.primal for box in box_out]

        def func_vjp(grad):
            return backward_pass(in_nodes, out_nodes, grad)

    return out_primals, func_vjp


def vjp(func, primals):
    # Flatten the primals and tangents into flat lists
    primals_flat, in_tree = tree_flatten(primals)

    # Flatten the function f according to the input tree structure
    func_flat, out_tree = flatten_fun(func, in_tree)

    # forward pass
    primals_out_flat, vjp_func = vjp_flat(
        func_flat,
        primals_flat,
    )

    assert len(out_tree) == 1, "out tree dict must have only one item"
    out_tree: PyNode = out_tree["tree"]

    primals_out = tree_unflatten(primals_out_flat, out_tree)

    return primals_out, vjp_func


def grad(func, argnums=0):
    if isinstance(argnums, int):
        argnums = [argnums]

    def vjp_func(*input_value):
        result, vjp_func = vjp(func, input_value)

        ones = nested_ones_like(result)
        flat, _ = tree_flatten(ones)
        grads = vjp_func(flat)
        _, in_tree = tree_flatten(input_value)
        grads = tree_unflatten(grads, in_tree)
        grads = tuple(g for idx, g in enumerate(grads) if idx in argnums)
        return grads[0] if len(argnums) == 1 else grads

    return vjp_func


def value_and_grad(func, argnums=0):
    if isinstance(argnums, int):
        argnums = [argnums]

    def vjp_forward(*input_value):
        result, vjp_func = vjp(func, input_value)

        # <hack>
        # jax dont do this nasted ones funnny busniess
        # it just requires output to be scalar
        # but I you can pass one to all output nodes
        # which is effectively like result = sum(result) I dont have redution op
        # basically result.sum().backward() in pytorch
        ones = nested_ones_like(result)
        flat, _ = tree_flatten(ones)
        # </hack>

        # backward pass
        grads = vjp_func(flat)

        output, in_tree = tree_flatten(input_value)
        grads = tree_unflatten(grads, in_tree)

        grads = tuple(g for idx, g in enumerate(grads) if idx in argnums)

        return result, grads[0] if len(argnums) == 1 else grads

    return vjp_forward

现在您可以做这样的事情,您可以传递状态字典并获取该状态字典的梯度,并构建复杂的、可微分的程序。


def linear(state,inputs):
    weight,bias = state["weights"], state["bias"]
    total = 0
    for w, x in zip(weight, inputs):
        prod = w * x 
        total = total + prod 
    return total + bias

state = {"weights":[1,2,3], "bias": 1}
inputs  = [0.3, 0.5, 0.7]

print(grad(linear)(state,inputs))
    {'bias': 1.0, 'weights': [0.3, 0.5, 0.7]}
value,grads = value_and_grad(linear)(state,inputs)
print(value)
print(grads)
    4.3999999999999995
    {'bias': 1.0, 'weights': [0.3, 0.5, 0.7]}

vmap、pmap 和 jit

我不会详细介绍这些,毕竟这是 microjax。但为了给您一个直观的认识,就像我们为 jvp 添加切线,为 vjp 添加节点一样,对于 vmap,我们封装了形状信息,编写了批处理解释器,并在规则级别执行 lambda x: [f(x[0]) for _ in range(x.shape[0])] ,是的,它只是一个映射。如果您并行执行此映射,您将获得 pmap,

就像我们为 jit 携带了切线信息一样,我们携带了函数的所有历史(图),并进行图优化,然后将其编译为 xla。当您第二次调用 jit 函数时,它会流向优化的图,而不是原始函数。这使得它更快。


ko-fi

社区

注册登录 发表评论