什么是自动微分?

社区文章 发布于2024年3月19日

澄清


在我们开始之前,本文涉及的技术和数学内容较多,涵盖微积分、线性代数和机器学习 (ML) 相关主题。例如,我们将讨论偏导数、梯度、多元函数和损失函数。为节省时间,我假设您对这些主题有一些基本知识或熟悉程度。

TL;DR


本文的主要目的是强调自动微分(AD)及其在训练神经网络中基于梯度的优化中的应用。我将介绍:

  • 数值微分:一种近似函数导数的过程,因截断误差和舍入误差等问题,在神经网络优化中被避免使用。

  • 符号微分:一种通过变换推导表达式导数的系统方法。因代码约束、表达式膨胀和重复计算等问题,在基于梯度的优化中被避免使用。

  • 自动微分(AD):一种通过将导数与函数的基本操作交织在一起,来增强算术计算的方法。我还将描述求值跟踪和计算图——这在正向和反向模式AD中非常有用。

  • 正向模式AD:一种应用自动微分的方法,通过在函数正向传播过程中计算变量的偏导数。它被认为有用,但在某些条件下速度较慢,因此不适用于基于梯度的优化。

  • 反向模式AD:另一种模式,在函数求值后,在反向传播过程中计算变量的偏导数。由于反向传播和向量-雅可比积等特性,该模式在优化神经网络的背景下比正向模式更快。

最后,我将使用运算符重载在Python中实现正向和反向模式AD。通过深入了解AD在代码中的构建方式,我将把该实现与熟悉的机器学习框架进行对比。

引言


在机器学习中,神经网络实现了广泛的数学甚至科学技术,才使得我们今天拥有的AI工具(例如LLM)成为可能。该领域已经从使用简单的神经网络对手写数字进行分类,发展到有时被认为是“有意识的” Transformer 架构。其核心仍然是解决优化问题——即,我们如何才能教会这些模型学习?

可能有一些更好的方法来解决这个问题,某位天才有朝一日会发现,但今天的共识是通过梯度下降。梯度下降是一种优化算法,旨在通过迭代增量步骤来提高模型的性能。以下是其分解:

梯度下降

  1. 使用目标[损失]函数,计算模型在给定输入集上的预测与这些输入的真实值之间的损失(误差)。

  2. 通过计算损失函数对模型每个参数的偏导数(即梯度),找出模型对损失的影响。

  3. 通过将每个参数减去其各自的梯度乘以一个称为学习率的超参数,来朝着最小化损失的方向调整模型参数。

  4. 清除所有梯度,并重复该过程,直到模型收敛;换句话说,直到模型不再改进并达到最佳性能。

这个过程显然构建了强大的深度神经网络,但是,它需要一个困难的过程,那就是找到更新模型参数的梯度。从数学意义上讲,我们如何才能计算损失函数对模型参数的导数?为此,我将介绍自动微分

在本博客中,我将解释什么是自动微分,我们还将深入探讨其实现。但在进一步深入之前,我们需要了解两种替代方法,以及为什么它们无法满足神经网络基于梯度优化的需求:数值微分符号微分

数值微分


数值微分是一种我们可以用来计算梯度以优化神经网络的方法,它使用导数的极限定义。

dfdx=limh0f(x+h)f(x)h(1) \tag{1}\frac{df}{dx} = \lim_{ h\to 0}\frac{f(x + h) - f(x)}{h}

为了计算函数 f:RR f: \R \to \R 在输入 x x 处的导数,我们找到了函数在 x x 处的切线斜率。切线斜率可以分解为函数在 x x 处的上升和运行的比值。为了计算上升,我们在两个输入处评估函数,一个使用原始输入,另一个将输入微调一个小常数 h h ;这在等式1的分子中可以看到。接下来,我们除以运行;由于我们将输入移动了 h h ,因此运行是 h h 。当 h h 趋近于零时,通过取极限,函数 f(x) f(x) x x 处的导数近似值变得更加准确。

在实践中,神经网络对多维数组(通常称为张量)执行算术运算,并且程序化地取极限是没有意义的,因此我们放弃极限并重写表达式以对网络参数进行操作。

fθif(θ+hei)f(θ)h+O(h)(2) \tag{2}\frac{\partial f}{\partial \bold{\theta}_i} \approx \frac{f(\bold{\theta} + h \cdot \bold{e}_i) - f(\bold{\theta})}{h} + O(h)

上面是计算多元函数偏导数的前向差分
f:RnR f: \R^{n} \to \R 关于单个参数 θi \theta{_i} ,该参数来自参数向量 θ \bold{\theta} 。符号 ei \bold{e}_i 表示单位向量,其中第 i-th i \text{-th} 个元素为1,而所有其他元素为零。现在计算 θ \bold{\theta} 中第 i-th i \text{-{th}} 个参数的偏导数就像用 h0 h \approx 0 评估等式2一样简单。在神经网络领域中,f f 代表目标损失函数,而 θ \bold{\theta} 则是模型的参数化。通过对所有模型参数评估等式2,我们将获得梯度下降一步所需的梯度。

我尚未提到的是等式2末尾添加的大O项。这个符号——在算法的时间和空间复杂度领域之外——是关于 h h 的函数,表示截断误差。截断误差定义为由于截断 fθi \frac{\partial f}{\partial \bold{\theta}_i} 泰勒级数中的某些值而产生的误差。更具体地说,因为我们使用 h h 来近似 θi \theta{_i} 的偏导数,所以我们用一些依赖于 h h 的误差来错误地近似它。此外,截断误差与 h h 成正比,这意味着无论 h h 乘以什么因子,截断误差都会按相同因子缩放。

现在,有一些方法可以最小化这种误差。首先,我们可以将我们的近似方法改为如下所示的中心差分方法。

fθif(θ+hei)f(θhei)2h+O(h2)(3) \tag{3}\frac{\partial f}{\partial \bold{\theta}_i} \approx \frac{f(\bold{\theta} + h \cdot \bold{e}_i) - f(\bold{\theta} - h \cdot \bold{e}_i)}{2h} + O(h^2)

中心差分是等式2中的前向差分和后向差分的组合。通过从前向差分中减去后向差分并简化,O(h) O(h) 中的一阶误差项将相互抵消,留下二阶误差项作为主导。误差现在与 h h 的平方成正比,这意味着如果 h h 减小一个数量级,误差将减小两个数量级。

导数近似

derivatives
导数近似:图示为使用数值微分方法(所有计算均使用32位浮点数)对 cos(x) cos(x) x=2 x = 2 处,以及 h=0.5 h = 0.5 时,导数的近似值。该图突出显示了中心差分优于前向和后向差分的近似效果。此外,还可以看出实际导数 sin(2) - \sin(2) 与因截断误差导致的近似导数之间的差异。

作为稳定这种方法的另一种方式,我们可以简单地减小 h h ,因为当 h0 h \approx 0 时,截断误差将不复存在。理论上,这应该消除我们在数值微分中遇到的误差。然而,这样做会带来副作用,导致我们进入下一节。

数值微分的问题


了解数值微分后,我们可以探讨为什么在神经网络优化实现中避免使用它。关于减小 h h 以减轻截断误差的解决方案,我们还引入了另一种误差,称为舍入误差

舍入误差是由于数字在计算机中表示的不准确性而引起的误差。诸如IEEE 754之类的标准已经推广了使用单精度浮点数(float32)来表示程序中的实数。神经网络依赖于这些表示,但是它们是有限的。浮点数被分配了固定数量的空间(在大多数情况下为32位),这限制了任意大或小值的精度。将其与数值微分联系起来,如果数字变得太小,它们将下溢为零,并在过程中丢失数值信息。

这很重要,因为当我们尝试减小 h h 以减轻截断误差时,我们增加了舍入误差。事实上,舍入误差与 h h 的比例成反比。例如,如果我们将 h h 减半,舍入误差就会翻倍。截断误差和舍入误差之间的这种平衡在选择一个可行的 h h 来计算精确梯度时引入了一个权衡。

截断误差与舍入误差

errors
截断误差与舍入误差:上图显示了函数 f(x)=(x10)2(3x23x+1) f(x) = (x - 10)^2(3x^2 -3x + 1) 的前向差分法(等式2)和中心差分法(等式3)计算中出现的误差样本。对于 h h ,使用了范围 [107,1][10^{-7}, 1] 的单精度浮点值。可以看出,当 h h 减小,截断误差减小,同时引入舍入误差;反之,当 h h 增大时,情况则相反。

有人可能会建议使用更高精度的数据类型(例如float64),但这会增加硬件限制,因为需要更多的内存和额外的计算——这是另一个完全不必要的权衡。转而考虑,数值微分的运行时复杂度也带来了另一个问题。

为了实际计算梯度,我们必须评估感兴趣的函数。在找到具有 n n 个输入和标量输出的函数的梯度时,我们需要 O(n) O(n) 次操作。这甚至还没有考虑到神经网络中出现的向量值函数。例如,如果我们有函数 f:RnRm f: \R^{n} \to \R^{m} ,我们大约需要 O(mn) O(mn) 次操作来计算梯度,这使得对于大的 m m n n 值,总计算效率低下。

鉴于梯度下降是一个对数百万甚至数十亿参数进行的迭代过程,我们可以看到数值微分对于神经网络优化来说扩展性不够。了解它的不足之处后,我们可以转向另一种方法:符号微分。

符号微分


符号微分是我们接下来要探讨的梯度计算方法。它是一个系统过程,将由算术运算和符号组成的表达式转换为表示其导数的表达式。这是通过将微积分的导数规则(例如,求和规则)应用于闭合形式表达式来实现的。

实际上,符号微分是计算机推导表达式导数的方式。例如,对于下面两个函数 f f g g ,我们可以使用微积分推导出其导数的表达式。

g(x)=cos(x)+2xex g(x)=cos(x)+ 2x - e^x

f(g)=4g2 f(g)=4g^2

f(g(x))=4(cos(x)+2xex)2(4) \tag{4}f(g(x)) = 4(cos(x)+ 2x - e^x)^2

dfdx=dfdgdgdx=8(cos(x)+2xex)(sin(x)+2ex)(5) \tag{5} \frac{df}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx} = 8(cos(x)+2x - e^x) \cdot (-sin(x) + 2-e^x)

为了找到输入 f(g(x)) f(g(x)) 的导数,我们只需将其代入上述变换后的表达式并进行求值。在实践中,我们将以程序方式实现此过程,并且表示的变量将不仅仅是标量(例如,向量、矩阵或张量)。下面是使用Python中的SymPy* 如何符号微分等式4以得到等式5的示例。

from sympy import symbols, cos, exp, diff

x = symbols("x")
fog = 4 * (cos(x) + 2 * x - exp(x)) ** 2
dfdx = diff(fog, x)
print(dfdx)
4*(2*x - exp(x) + cos(x))*(-2*exp(x) - 2*sin(x) + 4)
* 导数表达式可能看起来与等式5不同,但它们实际上是相同的。这些项被略微重新排序,并且2 2 dfdg \frac{df}{dg} 被分配到dgdx \frac{dg}{dx}

这解决了数值微分中出现的数值不准确和不稳定问题(参见“导数近似”和“截断误差与舍入误差”图示),因为我们有一个可以直接计算函数梯度的表达式。然而,我们仍然面临限制其在优化神经网络方面可行性的问题,我们将在下一节中解决这些问题。

符号微分的问题


符号微分最主要的问题是“表达式膨胀”。表达式膨胀会导致导数表达式通过转换呈指数级增长,这是系统地将导数规则应用于原始表达式的代价。例如,下面的乘法规则。

ddxf(x)g(x)=f(x)g(x)+g(x)f(x) \frac{d}{dx}f(x)g(x)=f'(x)g(x) + g'(x)f(x)

导数表达式不仅在项数上增加,而且在计算量上也增加。这还没有考虑到f f g g 本身可能是复杂的函数——这可能会导致更多的表达式膨胀。

当推导dfdx \frac{df}{dx} 时,我们看到了一些表达式膨胀,而这还是一个相对简单的函数。现在想象一下,对于许多复合函数,它们可能一次又一次地应用导数规则。这样做,考虑到神经网络代表着许多复杂的复合函数,是极其不切实际的。

表达式膨胀

f(x)=ewx+b+e(wx+b)ewx+be(wx+b) f(x) = \frac{e^{wx+b} + e^{-(wx+b)}}{e^{wx+b} - e^{-(wx+b)}}

fw=(xebwxxeb+wx)(ebwx+eb+wx)(ebwx+eb+wx)2+xebwx+xeb+wxebwx+eb+wx \frac{\partial f}{\partial w} = \frac{(- x e^{- b - w x} - x e^{b + w x}) (e^{- b - w x} + e^{b + w x})}{(- e^{- b - w x} + e^{b + w x})^{2}} + \frac{- x e^{- b - w x} + x e^{b + w x}}{- e^{- b - w x} + e^{b + w x}}

表达式膨胀:图中展示了神经网络中的线性投影,后接非线性激活函数tanh \text{tanh} 。在不进行简化和优化的情况下,计算梯度以更新权重w w 可能导致大量的表达式膨胀和重复计算。

符号微分面临的另一个缺点是,它仅限于封闭形式的表达式。编程之所以有用,在于能够使用控制流根据程序状态改变其行为,同样的原则也常应用于神经网络。如果我们要根据特定输入改变操作的执行方式,或者希望模型根据其模式表现不同,那该怎么办?这种功能无法进行符号微分,因此,我们将失去实现各种模型架构所需的任何动态特性。

无控制流

from sympy import symbols, diff

def f(x):
    if x > 2:
        return x * 2 + 5
    return x / 2 + 5

x = symbols("x")
dfdx = diff(f(x))
print(dfdx)
TypeError: cannot determine truth value of Relational

最后一个缺点,在《表达式膨胀》示例中有所提及,就是我们可能会遇到重复计算。对于等式4和5,我们计算ex e^x 三次:一次在计算
等式4中,两次在等式5中。对于更复杂的函数,这可能会以更大的规模进行,为符号微分带来更多不切实际的问题。我们可以通过缓存结果来减少这个问题,但这不一定能解决表达式膨胀的问题。

总而言之,是表达式膨胀、表达式必须是封闭形式的要求以及重复计算,限制了符号微分在神经网络优化方面的应用。但是,应用导数规则和缓存(作为重复计算的解决方案)的直觉,构成了自动微分的基础。

自动微分


自动微分,简称 AD,将复合函数分解为构成它们的基本变量和基本运算*。所有的数值计算都围绕着这些运算展开,由于我们知道它们的导数,我们可以将它们链接起来,从而得到整个函数的导数。简而言之,AD 是数值计算的一种增强版本,它不仅评估数学函数,还同时计算它们的导数。

* 基本运算是指具有明确导数的原子数学运算:加、减、乘、除。超越函数(例如自然对数和余弦)在技术上不被视为基本运算,但在自动微分的上下文中,它们通常被视为基本运算,因为它们的导数是明确的。

为了实现这一点,我们可以利用“求值追踪”。求值追踪是一种特殊的表格,它记录中间变量以及创建它们的运算。每一行对应一个中间变量和导致该变量的基本运算。这些变量被称为“基元”,通常用vi v_i 表示,用于函数f:RnRm f:\R^n \to \R^m 并遵循这些规则

  • 输入变量:vin=xi,i=1,...,n v_{i-n}=x_i, i=1,...,n
  • 中间变量:vi,i=1,...,l v_i, i=1,...,l
  • 输出变量:ymi=vli,i=m1,...,0 y_{m-i}=v_{l-i},i=m-1,...,0

下面,我留下一个示例,仅展示接受两个输入x1 x_1 x2 x_2 的函数的原始计算的求值追踪。

y=f(x1,x2)=x1x2+x2ln(x1)x1=2,x2=4(6) \tag{6}y=f(x_1, x_2)=x_1x_2 + x_2 - \ln(x_1) \\ x_1=2, x_2=4

前向原始追踪(等式6)

前向原始追踪 输出
v₋₁ = x₁ 2
v₀ = x₂ 4
v₁ = v₋₁v₀ 2(4) = 8
v₂ = ln(v₋₁) ln(2) = 0.693
v₃ = v₁ + v₀ 8 + 4 = 12
v₄ = v₃ − v₂ 12 - 0.693 = 11.307
y = v₄ 11.307

在求值追踪的基础上,我们可以使用有向无环图(DAG)作为数据结构来算法表示求值追踪。DAG中的节点表示输入变量、中间变量和输出变量,而边则描述了从输入到输出转换的计算层次结构。最后,为了确保正确的计算流,图必须是有向且无环的。总而言之,这种类型的DAG通常被称为“计算图”。

计算图(等式6)

computational graph

这些工具(求值追踪和计算图)的引入对于理解和实现 AD 至关重要,特别是它的两种模式:正向模式反向模式

正向模式 AD


正向模式 AD 采用了我们之前讨论过的求值追踪原则,但引入了与基元vi v_i 对应的“切线”,表示为vi˙ \dot{v_i} 。这些切线携带着基元关于特定感兴趣输入变量的偏导数信息。

回顾等式6,如果我们对求解yx2 \frac{\partial{y}}{\partial{x_2}} 感兴趣,则会有以下切线定义。

vi˙=vix2 \dot{v_i} = \frac{\partial{v_i}}{\partial{x_2}}

根据这个定义,我们可以构建“正向原始”和“正向切线追踪”来计算当x1=3 x_1 = 3 x2=4 x_2 = -4 x˙1=x1x2=0 \dot{x}_1 = \frac{\partial x_1}{\partial x_2} = 0 x2˙=x2x2=1 \dot{x_2} = \frac{\partial x_2}{\partial x_2} = 1 yx2 \frac{\partial y}{\partial x_2}

正向模式追踪(等式6)

前向原始追踪 输出 正向切线追踪 输出
v₋₁ = x₁ 3 v̇₋₁ = ẋ₋₁ 0
v₀ = x₂ -4 v̇₀ = ẋ₂ 1
v₁ = v₋₁v₀ 3 ⋅ -4 = -12 v̇₁ = v̇₋₁v₀ + v̇₀v₋₁ 0 ⋅ -4 + 1 ⋅ 3 = 3
v₂ = ln(v₋₁) ln(3) = 1.10 v̇₂ = v̇₋₁ ⋅ (1 / v₋₁) 0 ⋅ (1 / 3) = 0
v₃ = v₁ + v₀ -12 + -4 = -16 v̇₃ = v̇₁ + v̇₀ 3 + 1 = 4
v₄ = v₃ − v₂ -16 - 1.10 = -17.10 v̇₄ = v̇₃ − v̇₂ 4 - 0 = 4
y = v₄ -17.10 ẏ = v̇₄ 4

这个过程是前向模式AD(自动微分)的精髓。对于给定函数的每个基本操作,通过应用基本算术操作计算中间变量(原始值),并同步使用我们从微积分中得知的知识计算它们的导数(切线)。

通过这种方法,我们不仅可以计算导数,还可以计算雅可比矩阵(Jacobians)。对于向量值函数f:RnRm f: \R^n \to \R^m ,我们选择一组输入aRn \bold{a} \in \R^n (其中x=a \bold{x} = \bold{a} )和切线x˙=ei \bold{\dot{x}} = \bold{e}_i ,其中i=1,...,n i=1,...,n 。现在,将这些输入应用于我们的前向模式函数,将生成所有输出变量yj y_j (对于j=1,...,m j=1,...,m )相对于单个输入变量xi x_i 的偏导数。实质上,前向模式AD中的每次前向传播都会生成雅可比矩阵的一列——对应于所有输出相对于单个输入的偏导数。

雅可比矩阵

J=[y1x1y1xnymx1ymxn] \large \bold{J} = \LARGE \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_n} \\ \end{bmatrix}

因为函数f:RnRm f: \R^n \to \R^mn n 个输入,并且前向模式中的每次前向传播都会生成雅可比矩阵的一列,因此计算完整的m×n m \times n 雅可比矩阵需要O(n) O(n) 次评估。如果您不记得线性代数,完整的雅可比矩阵表示所有输出相对于所有输入的偏导数;对于我们的目的而言,这是我们试图为优化推导的梯度。

此特性可推广到雅可比向量积(Jacobian-vector product,JVP)。JVP是函数JRm×n \bold{J} \in \R^{m \times n} 的雅可比矩阵与列向量rRn \bold{r} \in \R^n 的点积。点积的结果返回一个m m 维列向量,编码了当输入受到扰动时输出的变化。更准确地说,它描述了当输入被r \bold{r} 方向性微调时输出的变化。

特别是在前向模式AD中,其特殊之处在于我们不需要计算完整的雅可比矩阵。通过选择一组输入并设置扰动向量r \bold{r} ,函数在前向模式中的一次评估即可输出JVP,而无需计算整个雅可比矩阵。

雅可比向量积

Jr=[y1x1y1xnymx1ymxn][r1rn] \large \bold{J} \cdot \bold{r} = \LARGE \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_n} \\ \end{bmatrix} \cdot \begin{bmatrix} r_1 \\ \vdots \\ r_n \end{bmatrix}

总而言之,这使得前向模式AD在某些情况下非常实用。具体来说,当前向模式AD评估一个函数f:RnRm f: \R^{n} \to \R^m 并且nm n \ll m 时非常有效。例如,一个具有一个输入和m m 个输出的函数,在这种模式下只需进行一次前向传播即可计算其雅可比矩阵。相反,一个具有n n 个输入和一个输出(f:RnR f: \R^n \to \R )的函数则需要nn 次前向传播才能获得其雅可比矩阵。

这种情况很重要,因为通常情况下,神经网络的参数表示Rn \R^n ,而由模型参数引起的标量损失表示R \R 。因此,如果我们将前向模式AD用于基于梯度的优化,那将是次优的。

总而言之,前向模式AD优于数值微分和符号微分,因为它没有数值不稳定性或表达式膨胀等问题(参见“截断误差与舍入误差”图示和“表达式膨胀”示例)。但是,由于它缺乏神经网络优化所需的扩展性,我们可以转向AD的第二种模式,即反向模式。

反向模式AD


至此,我们有了反向模式AD——与前向模式相似,但在方法上有所不同。我们首先定义伴随vˉi\bar{v}_i,它表示函数f:RnRm f: \R^n \to \R^m 的输出yj y_j 相对于中间变量vi v_i 的偏导数(其中i=1,...,n i = 1,...,n j=1,...,mj = 1,..., m )。我们可以将伴随形式化定义为

vˉi=yjvi \bar{v}_i = \frac{\partial y_j}{\partial v_i}

在反向模式AD中,我们通过应用基本操作进行前向传播以计算中间变量,但在该阶段,伴随不会像前向模式AD中的切线那样与其原始对应物一起计算。相反,计算vˉi \bar{v}_i 所需的所有依赖项都存储在计算图中。

接下来,我们利用对基本操作导数、链式法则和缓存依赖项(来自前向传播)的熟悉程度来计算伴随。伴随的计算顺序是从输出变量开始,到导致输出变量的所有输入变量结束。这个阶段通常被称为反向传播(reverse pass)。如果您还没有看出,正是“反向”传播赋予了这种AD模式其名称——导数以反向方式计算。

有了对反向模式AD的直观理解,我们来看看使用与“前向模式跟踪”中相同的输入变量值,对公式6进行反向模式评估的跟踪。
公式6的反向模式跟踪

反向模式跟踪(公式6)

前向原始追踪 输出 反向伴随跟踪 输出
v₋₁ = x₁ 3 v̅₋₁ = x̅₁ = v̅₂ ⋅ (1 / v₋₁) + v̅₁ ⋅ v₀ -1 ⋅ (1 / 3) + 1 ⋅ -4 = -4.33
v₀ = x₂ -4 v̅₀ = x̅₂ = v̅₃ ⋅ 1 + v̅₁ ⋅ v₋₁ 1 ⋅ 1 + 1 ⋅ 3 = 4
v₁ = v₋₁v₀ 3 ⋅ -4 = -12 v̅₁ = v̅₃ ⋅ 1 1 ⋅ 1 = 1
v₂ = ln(v₋₁) ln(3) = 1.10 v̅₂ = v̅₄ ⋅ −1 1 ⋅ -1 = -1
v₃ = v₁ + v₀ -12 + -4 = -16 v̅₃ = v̅₄ ⋅ 1 1 ⋅ 1 = 1
v₄ = v₃ − v₂ -16 - 1.10 = -17.10 v̅₄ = y̅ 1
y = v₄ -17.10 1

在此特定的跟踪中,我们从伴随yˉ=yy=1 \bar{y} = \frac{\partial y}{\partial y} = 1 开始,并通过应用求导规则将其发送到其所有依赖项(导致它的变量)。最终,任何对输出y y 有贡献的输入变量x x 都会填充其伴随。

您可能会对vˉ1 \bar{v}_{-1} vˉ0 \bar{v}_0 的计算感到困惑。在我看来,这有点反直觉,但由于它们的原始值通过多条路径(在v2 v_2 v1 v_1 的计算中可以看到)对输出y y 有贡献,它们都将有两个传入导数。我们不会丢弃任何导数信息而偏向其中一个,因为那样会失去x1 x_1 x2 x_2 y y 的影响。相反,我们累积它们各自的导数。这样,x1 x_1 x2 x_2 的总贡献都包含在它们的伴随x1ˉ \bar{x_1} x2ˉ \bar{x_2} 中。

* 请记住,v1 v_{-1} v0 v_0 只是x1 x_1 x2 x_2 的别名;它们的伴随也是如此。

如前向模式中所见,雅可比矩阵也可以为向量值函数计算
f:RnRm f: \R^n \to \R^m 。通过选择输入aRn\bold{a} \in \R^n,将x=a \bold{x = a} 赋值,并设置yˉ=ej \bold{\bar{y}} = \bold{e}_j,对于
j=1,...,m j = 1,...,m ——每次反向传播都会生成第j-th j\text{-th} 输出关于所有输入变量xi x_i 的偏导数,其中i=1,...,n i = 1,...,n 。由于有m m 行,并且每次反向传播计算雅可比矩阵的一行,因此在反向模式自动微分(AD)中,需要m m 次评估才能获得函数f f 的完整雅可比矩阵。

在此基础上,我们可以计算**向量-雅可比乘积**(*VJP*)。VJP是转置行向量rTR1×m \bold{r}^T \in \R^{1 \times m} (通常称为余切向量)与函数雅可比矩阵JRm×n \bold{J} \in \R^{m \times n} 的左乘。VJP的计算会生成一个n n 维行向量,其中包含一个输出在受到rT \bold{r}^T 扰动时,该输出相对于其所有输入的偏导数。

向量-雅可比乘积

rTJ=[r1rm]T[y1x1y1xnymx1ymxn] \large \bold{r}^T \cdot \bold{J} = \LARGE \begin{bmatrix} r_1 \dots r_m \end{bmatrix}^T \cdot \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_n} \\ \end{bmatrix}

向量-雅可比乘积 (另一种形式)

JTr=[y1x1ymx1y1xnymxn]T[r1rm] \large \bold{J}^T \cdot \bold{r} = \LARGE \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_1} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_1}{\partial x_n} & \dots & \frac{\partial y_m}{\partial x_n} \\ \end{bmatrix}^T \cdot \begin{bmatrix} r_1 \\ \vdots \\ r_m \end{bmatrix}

VJP 直接与优化神经网络相关,因为我们可以将J \bold{J} 表示为模型输出关于其输入的偏导数,而将rT \bold{r}^T 表示为目标损失函数输出关于模型输出的偏导数。在此背景下应用 VJP 会产生优化所需的梯度。此外,与 JVP 类似,VJP 不需要函数的完整雅可比矩阵,并且可以在单次反向传播中计算。

总结一下我们对反向模式自动微分的讨论,它只需要一次反向传播即可计算输出关于所有输入的梯度,并在计算m m 个输出的梯度时需要m m 次反向传播。由于这些特性,当nm n \gg m 时,反向模式自动微分能发挥最佳作用。事实上,这使得反向模式成为优化神经网络的最佳选择。对于一个生成标量损失函数的函数,它只需要一次反向传播即可计算其关于模型n n 个参数的梯度;回想一下f:RnR f: \R^n \to \R 的情况。
个参数的梯度。

注意:由于导数计算的依赖关系必须存储在计算图中,因此内存复杂度与函数的操作数量成正比。这是一个缺点,但并不会削弱其在优化神经网络方面的实用性。

总而言之,反向模式自动微分显然是基于梯度优化的最佳选择。我们只需一次反向传播即可完成梯度下降的一个步骤,同时会增加内存——考虑到我们更注重时间而非空间,这是一个可以接受的权衡。

实现


了解了前向模式和反向模式自动微分后,我们可以深入探讨两者的代码实现。实现此目的的几种方法包括使用[*专用编译器*](https://en.wikipedia.org/wiki/Source-to-source_compiler)或[*源代码转换*](https://en.wikipedia.org/wiki/Program_transformation)。这两种实现都可行,但比基本演示所需的复杂得多。相反,我们将选择**运算符重载**方法。

运算符重载——在自动微分的语境中——涉及重写自定义类型中运算符的方法,以便在其中集成自动微分的功能。您可以将这种类型视为用户定义的类、结构体或对象(取决于语言),它具有启用自动微分的属性。通过正确实现运算符重载,对启用自动微分的类型执行的任何算术运算都将轻松实现求导。

Python 是一种相对简单的语言,并且支持运算符重载,这就是我们将其用于前向和反向模式自动微分实现的原因。

前向模式自动微分实现

class Variable:

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        primal = self.primal + other.primal
        tangent = self.tangent + other.tangent
        return Variable(primal, tangent)

    def __sub__(self, other):
        primal = self.primal - other.primal
        tangent = self.tangent - other.tangent
        return Variable(primal, tangent)

    def __mul__(self, other):
        primal = self.primal * other.primal
        tangent = self.tangent * other.primal + other.tangent * self.primal
        return Variable(primal, tangent)

    def __truediv__(self, other):
        primal = self.primal / other.primal
        tangent = (self.tangent / other.primal) + (
            -self.primal / other.primal**2
        ) * other.tangent
        return Variable(primal, tangent)

    def __repr__(self):
        return f"primal: {self.primal}, tangent: {self.tangent}"

从 `Variable` 类型(我们的 AD 类型)开始,我们将接收 `primal` 和 `tangent` 两个参数,并将它们初始化为属性以供后续使用。显然,`primal` 代表算术运算前向传播过程中使用的原值。同样,`tangent` 是算术运算前向传播过程中用于导数计算的切线。为简单起见,这两个属性都将是标量,但可以使用 [NumPy](https://numpy.com.cn/) 扩展功能以在多维数组上操作。

接下来,我们开始重载 Python 内置的算术运算符。特别是,我们只重载* `+`、`-`、`*` 和 `/`——分别对应 `__add__`、`__sub__`、`__mul__` 和 `__truediv__`。简单地说,重载这些运算符定义了当遇到 `a + b` 时(以 `__add__` 为例)的行为——其中 `a` (`self` 参数) 是 `Variable` 类型,而 `b` (`other` 参数) 是某种其他类型。为简单起见,`b` 将始终是 `Variable` 类型。如前所述,我们可以通过重载更多运算符(例如 `__pow__` 用于 `a ** b`)来添加更多功能,但我正尝试保持简单。

* `__repr__` 也被重载,它规定了当在 `Variable` 对象上调用 `repr()`、`print()` 或 `str()` 时的行为。这样做只是为了我们打印 `Variable` 时能够表示它。

对于每个重载的算术运算符,我们实现以下过程。

前向模式自动微分过程:

  1. 使用其操作数(`self` 和 `other`)评估运算符。

  2. 应用微积分的导数规则,计算输出对每个输入的偏导数。

  3. 将导数相加得到`tangent`——即输出对**两个**输入的导数。

  4. 创建一个新的 `Variable` 对象,其中包含前向计算的结果和导出的切线,并将其返回。

让我们以 `__mul__`(两个数字的乘法)为例,通过将其分解为每个组件来帮助我们理解此过程。

乘法过程:

  1. 我们通过计算 `self.primal * other.primal` 来评估运算符及其操作数,然后将结果存储在另一个变量 `primal` 中。

  2. 我们通过计算 `self.tangent * other.primal` 和 `other.tangent * self.primal` 来找到输出相对于每个输入的偏导数。

  3. 接下来,我们将第 2 步的值求和并存储在 `tangent` 中。这是输出相对于两个输入的导数。

  4. 最后,我们返回一个携带算术运算输出和相关 `tangent` 的新变量,即 `return Variable(primal, tangent)`。

如果运算符重载在具有明确导数的基本算术运算上正确实现,则可以将运算组合起来形成可微分的复合函数。下面,我留下了一些基本函数,用于测试 `Variable` 协助计算表达式及其导数的能力。

前向模式自动微分计算

def mul_add(a, b, c):
    return a * b + c * a

def div_sub(a, b, c):
    return a / b - c

a, b, c = Variable(25.0, 1.0), Variable(4.0, 0.0), Variable(-5.0, 0.0)
print(f"{a = }, {b = }, {c = }")
print(f"{mul_add(a, b, c) = }")
a.tangent, b.tangent, c.tangent = 0.0, 1.0, 0.0
print(f"{div_sub(a, b, c) = }")
a = primal: 25.0, tangent: 1.0, b = primal: 4.0, tangent: 0.0, c = primal: -5.0, tangent: 0.0
mul_add(a, b, c) = primal: -25.0, tangent: -1.0
div_sub(a, b, c) = primal: 11.25, tangent: -1.5625
前向模式下的 AD 计算:在第一个函数中,我们计算 y=a2(b+c)y = a^2 \cdot (b + c) 并求导 ya \frac{\partial y}{\partial a} 。在第二个函数中,我们计算 y=abcy = \frac{a}{b} - c 并求导 yb \frac{\partial y}{\partial b}

反向模式自动微分实现

class Variable:

    def __init__(self, primal, adjoint=0.0):
        self.primal = primal
        self.adjoint = adjoint

    def backward(self, adjoint):
        self.adjoint += adjoint

    def __add__(self, other):
        variable = Variable(self.primal + other.primal)

        def backward(adjoint):
            variable.adjoint += adjoint
            self_adjoint = adjoint * 1.0
            other_adjoint = adjoint * 1.0
            self.backward(self_adjoint)
            other.backward(other_adjoint)

        variable.backward = backward
        return variable

    def __sub__(self, other):
        variable = Variable(self.primal - other.primal)

        def backward(adjoint):
            variable.adjoint += adjoint
            self_adjoint = adjoint * 1.0
            other_adjoint = adjoint * -1.0
            self.backward(self_adjoint)
            other.backward(other_adjoint)

        variable.backward = backward
        return variable

    def __mul__(self, other):
        variable = Variable(self.primal * other.primal)

        def backward(adjoint):
            variable.adjoint += adjoint
            self_adjoint = adjoint * other.primal
            other_adjoint = adjoint * self.primal
            self.backward(self_adjoint)
            other.backward(other_adjoint)

        variable.backward = backward
        return variable

    def __truediv__(self, other):
        variable = Variable(self.primal / other.primal)

        def backward(adjoint):
            variable.adjoint += adjoint
            self_adjoint = adjoint * (1.0 / other.primal)
            other_adjoint = adjoint * (-1.0 * self.primal / other.primal**2)
            self.backward(self_adjoint)
            other.backward(other_adjoint)

        variable.backward = backward
        return variable

    def __repr__(self) -> str:
        return f"primal: {self.primal}, adjoint: {self.adjoint}"

反向模式自动微分的实现与前向模式实现非常相似——前提是 `adjoint` 和 `tangent` 具有相同的用途。两者的不同之处在于,我们将 `adjoint` 默认为 `0`。这是因为在反向模式中,我们从输出到输入传播导数然后累积它们;在创建 `Variable` 时,我们不知道要为 `Variable` 累积哪些导数,因此暂时将其设置为零。

我们再来谈谈 `Variable` 类型的默认 `backward` 方法。此方法所做的只是接受 `adjoint` 参数,并将其累积到调用它的 `Variable` 对象的 `adjoint` 属性中。本质上,它的目的是为那些没有自定义 `backward` 方法的**叶子** `Variable` 累积导数。这目前可能没有意义,但随着我们探索反向模式自动微分的实现,它的目的将更加清晰。

* *将叶子 `Variable` 视为我们在自动微分部分讨论的独立输入变量 xi x_i 。*

有了足够的背景知识,让我们看看为 `Variable` 类型启用反向模式自动微分的过程。

反向模式自动微分过程:

  1. 创建一个 `Variable` 对象,其中包含运算符及其操作数的结果。

  2. 定义一个[*闭包函数*](https://en.wikipedia.org/wiki/Closure_(computer_programming))`backward`,其作用如下:

    • 接受一个 `adjoint` 作为参数,并将其累积到步骤 1 中 `Variable` 对象的 `adjoint` 中。
    • 使用运算符的导数和 `adjoint`(以链式传递传入的导数)来计算输出相对于**每个**输入的偏导数。
    • 使用各自的导数(第二个项目符号)对每个输入调用 `backward()`,以继续反向传播。
  3. 返回步骤 1 中得到的 `Variable` 对象,其 `backward` 方法被步骤 2 中定义的闭包函数覆盖。

为了进一步理解,让我们来看一下在 `__truediv__` 中实现的这个过程——即两个数字之间的浮点除法。

除法过程:

  1. 我们使用运算符及其操作数(`variable = Variable(self.primal / other.primal)`)的结果创建一个新的 `Variable`。

  2. 接下来,我们创建闭包函数 `backward(adjoint)`,其中我们

    • 通过执行 `variable.adjoint += adjoint` 将 `adjoint` 参数累积到 `variable` 中。

    • 通过定义 `self_adjoint = adjoint * (1.0 / other.primal)` 和 `other_adjoint = adjoint * (-1.0 * self.primal / other.primal**2)`,使用商规则和 `adjoint`(用于链式导数)计算每个输入的偏导数。

    • 通过调用 `self.backward(self_adjoint)` 和 `other.backward(other_adjoint)` 继续对两个输入进行反向传播。

  3. 最后,我们将闭包函数绑定并返回经过修改的 `Variable` 对象,该对象已为反向模式求导准备就绪,即 `variable.backward = backward` 和 `return variable`。

回顾一下,正是因为这个实现,我们才需要默认的 `backward` 方法。最终,导数将传播到叶子 `Variable`,由于它们不需要自己传播导数,所以当闭包函数调用它们时,我们只需累积从 `backward` 传递过来的导数。

和之前一样,在具有明确导数的基本算术运算上正确实现运算符重载可以实现可微分复合函数的自动微分。下面是与前向模式示例相同的测试代码,但现在使用的是我们的反向模式实现。

反向模式自动微分计算

def mul_add(a, b, c):
    return a * b + c * a

def div_sub(a, b, c):
    return a / b - c

a, b, c = Variable(25.0, 1.0), Variable(4.0, 0.0), Variable(-5.0, 0.0)

print(f"{a = }, {b = }, {c = }")
d = mul_add(a, b, c)
d.backward(1.0)
print(f"{d = }")
print(f"{a.adjoint = }, {b.adjoint = }, {c.adjoint = }")

a.adjoint, b.adjoint, c.adjoint = 0.0, 0.0, 0.0
e = div_sub(a, b, c)
e.backward(1.0)
print(f"{e = }")
print(f"{a.adjoint = }, {b.adjoint = }, {c.adjoint = }")
a = primal: 25.0, adjoint: 0.0, b = primal: 4.0, adjoint: 0.0, c = primal: -5.0, adjoint: 0.0
d = primal: -25.0, adjoint: 1.0
a.adjoint = -1.0, b.adjoint = 25.0, c.adjoint = 25.0
e = primal: 11.25, adjoint: 1.0
a.adjoint = 0.25, b.adjoint = -1.5625, c.adjoint = -1.0
反向模式下的 AD 计算:代码沿用了前向模式实现的相同函数(y=a2(b+c) y = a^2 \cdot (b + c) y=abc y = \frac{a}{b} - c ),但现在我们已经计算了所有输入的偏导数,而不仅仅是前向模式中的一个。此外,请注意,在调用 `div_sub` 之前,我们将 `adjoint` 归零。如果不这样做,我们会将它计算出的偏导数与 `mul_add` 计算出的偏导数累加起来。

自动梯度


稍微提示一下,这个实现借鉴了 PyTorch 的 autograd API。如果你曾使用他们的框架训练模型,你可能遇到过 loss.backward()。这个方法(至少对我来说)看起来像某种魔法,但实际上它会使用与我们上面类似的方法自动微分损失函数对模型参数的导数。唯一的区别是 PyTorch 的实现更高级,并且将其功能扩展到基本的算术运算符之外,使其成为一个可行的机器学习研究框架……不像我们的。

我对 PyTorch 框架感到惊叹,于是决定在 nura 中开发自己的框架。它远未完成,但这是一个有趣的项目,展示了如何仅使用 numpy 构建一个自动微分引擎和机器学习框架。它的主要功能是提供反向和正向自动微分功能,但也包括创建神经网络的能力,类似于 PyTorch 中的 torch.nn 接口。为了让你有一个更直观的了解,下面是一个代码片段,展示了如何使用正向自动微分来评估函数并计算其雅可比矩阵。

import nura
from nura.autograd.functional import jacfwd

def fn(a, b, c):
    return a * b + c

a = nura.tensor([1.0, 2.0, 3.0, 4.0])
b = nura.tensor([5.0, 6.0, 7.0, 8.0])
c = nura.tensor(1.0)
r = nura.ones(4).double()

output, jacobian = jacfwd((a, b, c), fn, pos=1)
print(f"output:\n{output}\n\njacobian:\n{jacobian}")
output:
tensor([ 6. 13. 22. 33.]) dtype=double)

jacobian:
tensor([[1. 0. 0. 0.]
       [0. 2. 0. 0.]
       [0. 0. 3. 0.]
       [0. 0. 0. 4.]]) dtype=double)

结论


在这篇博客中,我们探讨了有效计算梯度以优化神经网络的挑战。我们发现数值微分和符号微分是潜在的解决方案,但它们的问题将我们引向了自动微分。在 AD 中,我们学习了如何利用评估跟踪和计算图来以正向模式计算偏导数。然而,我们注意到在神经网络和梯度下降方面,反向模式的特性更有效地处理了这项任务。最后,我们通过在 Python 中使用我们的 Variable 类型实现和测试两种模式,加强了我们对 AD 的理解。

总而言之,我希望这篇博客不仅强调了 AD 在通过梯度下降优化神经网络方面的实用性,还强调了我们如何利用数学和系统设计思维过程来解决机器学习领域中的挑战性问题。

链接


参考资料

个人

社区

注册登录 以评论