Building an Autograd Engine: tinytorch 03

Community Article · Published January 30, 2024



Cleanup and Refactoring

(Don't skip this; we add some new things here too.)

Let's get professional: it's time to add type hints and do some cleanup. We'll start with the Tensor class.

from __future__ import annotations
import numpy as np


class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

We import annotations from __future__ so the type annotations don't break on older Python versions. We also add a new _data_to_numpy method so people can pass any input, like Tensor(2), Tensor([1,2]), or Tensor(np.arange(10)), and we get consistent data under the hood. Finally, we add ensure_tensor, a simple helper that makes sure its input is a Tensor.
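A quick sanity check of the constructor, assuming the class above (expected output shown in the comments):

print(Tensor(2).data)             # [2]
print(Tensor([1, 2]).data)        # [1 2]
print(Tensor(np.arange(3)).data)  # [0 1 2]
print(Tensor(Tensor(5)).data)     # [5] (the underlying array is copied)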

    def __add__(self, other):
        fn = Function(Add, self, other)
        result = Add.forward(self, other)
        result._ctx = fn
        return result

This looks ugly; let's tidy it up by reworking the Function class.

class Function:
    def __init__(self, op, *args):
        self.op: type[Function] = op
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        raise NotImplementedError

    @staticmethod
    def backward(*args):
        raise NotImplementedError

apply is a classmethod that essentially does what we did by hand in __add__ and __mul__. The idea is that every operation will inherit from Function and use static methods; callers then just do Add.apply(). The forward and backward stubs signal that subclasses must implement them. Plus some type annotations.

Let's update Add and Mul to take advantage of the new Function machinery.


class Add(Function):
    ...
class Mul(Function):
    ...
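Spelled out, each op is now just a pair of staticmethods; these are exactly the versions that appear in the final file below:

class Add(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data + y.data)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor([1]) * grad, Tensor([1]) * grad


class Mul(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data * y.data)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor(y.data) * grad, Tensor(x.data) * grad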

Time to update the magic methods.

    def __add__(self, other)->Tensor:
        return Add.apply(self,Tensor.ensure_tensor(other))

    def __mul__(self, other)->Tensor:
        return Mul.apply(self,Tensor.ensure_tensor(other))

    def __radd__(self, other)->Tensor:
        return Add.apply(self,Tensor.ensure_tensor(other))

    def __rmul__(self, other)->Tensor:
        return Mul.apply(self,Tensor.ensure_tensor(other))

Look at that, much neater.

The new __radd__ method handles reflected addition when the left operand isn't a Tensor: Tensor(1) + 1 goes through __add__, while 1 + Tensor(1) goes through __radd__.
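A quick check of both dispatch paths, assuming the full class:

x = Tensor(1)
print(x + 1)   # tensor([2]), dispatches to __add__
print(1 + x)   # tensor([2]), int returns NotImplemented, so Python falls back to __radd__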

We should also handle += and *=; for that we need __iadd__ and __imul__, so let's add them.


    def __iadd__(self, other)->Tensor:
        return Add.apply(self,Tensor.ensure_tensor(other))

    def __imul__(self, other)->Tensor:
        return Mul.apply(self,Tensor.ensure_tensor(other))

While we're at it, let's also add detach(), which removes a node from the graph. How? By setting _ctx to None.


class Tensor:
    ...

    def detach(self)->Tensor:
        self._ctx = None
        return self
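Note that unlike PyTorch, where detach() returns a new tensor, ours mutates the tensor in place and returns it. A quick illustration:

x = Tensor([2.0])
z = (x * 3).detach()  # clears z's _ctx
z.backward()          # returns immediately, there is no graph to walk
print(x.grad)         # None, no gradient flowed back to x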

Now that we have detach, let's use it to avoid building any unnecessary graph during backpropagation.


    def backward(self, grad=None):
        if self._ctx is None:
            return

        # Seed the gradient at the output node.
        if grad is None:
            grad = Tensor([1.])
            self.grad = grad

        op = self._ctx.op
        child_nodes = self._ctx.args

        grads = op.backward(self._ctx, grad)
        # Single-argument ops return one gradient; wrap it so the loop works.
        if len(self._ctx.args) == 1:
            grads = [grads]

        for tensor, grad in zip(child_nodes, grads):
            if grad is None:
                continue
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            # Accumulate a detached gradient so we don't grow the graph.
            tensor.grad += grad.detach()
            tensor.backward(grad)

OK, we did a few things here. Most importantly, we detach the gradients, and we skip nodes that received no gradient. Down the line we'll add single-input operations, so when a function has only one argument we wrap the returned gradient in a list to keep the loop working correctly.
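Here's a small example of the accumulation at work; a tensor used twice gets gradient contributions from both uses:

x = Tensor([3.0])
z = x * x        # x appears as both arguments
z.backward()
print(x.grad)    # tensor([6.]), i.e. d(x^2)/dx = 2x at x = 3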

As a bonus, we'll add clone().

    def clone(self)->Tensor:
        return Tensor(self.data.copy())

Let's also add a few small utility helpers.

A shape property on the Tensor class:

class Tensor:
    ...
    @property
    def shape(self) -> tuple:
        return self.data.shape

And two functions for creating tensors of ones and zeros:

def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))

def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))
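For example:

a = ones((2, 3))
print(a.shape)   # (2, 3)
print(zeros(3))  # tensor([0. 0. 0.])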

Commit 96bd8a262398ad56699175cf64592b88ebd4be11

In the end the file should look like this:

from __future__ import annotations
import numpy as np


class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

    def __add__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __mul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __radd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __rmul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __iadd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __imul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __repr__(self):
        return f"tensor({self.data})"

    @property
    def shape(self) -> tuple:
        return self.data.shape

    def detach(self) -> Tensor:
        self._ctx = None
        return self

    def clone(self) -> Tensor:
        return Tensor(self.data.copy())

    def backward(self, grad=None):
        if self._ctx is None:
            return

        if grad is None:
            grad = Tensor([1.0])
            self.grad = grad

        op = self._ctx.op
        child_nodes = self._ctx.args

        grads = op.backward(self._ctx, grad)
        if len(self._ctx.args) == 1:
            grads = [grads]

        for tensor, grad in zip(child_nodes, grads):
            if grad is None:
                continue
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            tensor.grad += grad.detach()
            tensor.backward(grad)


class Function:
    def __init__(self, op, *args):
        self.op: type[Function] = op
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        raise NotImplementedError

    @staticmethod
    def backward(*args):
        raise NotImplementedError


class Add(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data + y.data)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor([1]) * grad, Tensor([1]) * grad


class Mul(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data * y.data)  # z = x*y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor(y.data) * grad, Tensor(x.data) * grad  #  dz/dx, dz/dy


def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))


def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))


if __name__ == "__main__":

    def f(x):
        return x * x * x + x

    x = Tensor([1.2])

    z = f(x)
    z.backward()
    print(f"X: {x} grad: {x.grad}")
