Building an autograd engine: tinytorch 03

Cleanup and refactoring
(Don't skip this; we add new things here too.)
Let's get professional. It's time to add type hints and do some cleanup. Let's start with the Tensor class.
from __future__ import annotations
import numpy as np

class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)
We import annotations so the type hints don't cause problems on older Python versions. We add a new _data_to_numpy method so people can pass any input, like Tensor(2), Tensor([1,2]), or Tensor(np.arange(10)), and we get consistent data under the hood. We also add ensure_tensor, a simple method that makes sure the data is a Tensor.
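A quick check of the normalization, assuming just the class above: every variant ends up as a numpy array under the hood.

t1 = Tensor(2)              # int -> np.array([2])
t2 = Tensor([1, 2])         # list -> np.array([1, 2])
t3 = Tensor(np.arange(10))  # ndarray is passed through as-is
t4 = Tensor(t1)             # another Tensor -> a copy of its data
print(type(t1.data), type(t2.data), type(t3.data), type(t4.data))
# all four are <class 'numpy.ndarray'>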
def __add__(self, other):
    fn = Function(Add, self, other)
    result = Add.forward(self, other)
    result._ctx = fn
    return result
This looks ugly; let's tidy it up by reworking the Function class.
class Function:
    def __init__(self, op, *args):
        self.op: type[Function] = op
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        raise NotImplementedError

    @staticmethod
    def backward(*args):
        raise NotImplementedError
apply is a classmethod that essentially does what we were doing inside __add__ and __mul__. The idea is that every op inherits from Function and implements static methods; then callers just do Add.apply(). The base forward and backward raise NotImplementedError to signal that subclasses must implement them. Plus some type annotations.
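To see the pattern in action, here is a hypothetical Neg op (not part of the actual file, just an illustration): subclass Function, implement the two static methods, and call apply.

class Neg(Function):
    @staticmethod
    def forward(x):
        return Tensor(-x.data)

    @staticmethod
    def backward(ctx, grad):
        return Tensor([-1]) * grad  # d(-x)/dx = -1

y = Neg.apply(Tensor([3.0]))  # y._ctx now records the op and its inputs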
Let's update Add and Mul so they pick up the nice Function behavior.
class Add(Function):
    ...

class Mul(Function):
    ...
Time to change the magic methods.
def __add__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __mul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))

def __radd__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __rmul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))
Look at that, much cleaner.
The new __radd__ method handles reflected addition when the left operand isn't a Tensor: Tensor(1) + 1 uses __add__, while 1 + Tensor(1) uses __radd__.
We also want += to work, and for that we need __iadd__; let's add those too (a quick check follows the code below).
def __iadd__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __imul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))
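A quick check of the operator variants, assuming the Tensor class above:

a = Tensor([1.0])
b = a + 1   # __add__: the plain int is coerced via ensure_tensor
c = 1 + a   # __radd__: Python falls back to the reflected method
a += 2      # __iadd__: note this rebinds a to a brand-new Tensor
print(b, c, a)  # tensor([2.]) tensor([2.]) tensor([3.])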
While we're at it, let's also add detach(), which removes a node from the graph. How do we do that? By setting _ctx to None.
class Tensor:
    ...
    def detach(self) -> Tensor:
        self._ctx = None
        return self
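Detach in action, as a minimal sketch: once _ctx is None, backward() returns immediately and nothing flows back.

x = Tensor([2.0])
z = x * x
z.detach()
z.backward()
print(x.grad)  # None: the graph link was cut before backprop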
Now that we have detach, let's use it to avoid building any unnecessary graph during the backward pass.
def backward(self, grad=None):
    if self._ctx is None:
        return
    if grad is None:
        grad = Tensor([1.])
    self.grad = grad
    op = self._ctx.op
    child_nodes = self._ctx.args
    grads = op.backward(self._ctx, grad)
    if len(self._ctx.args) == 1:
        grads = [grads]
    for tensor, grad in zip(child_nodes, grads):
        if grad is None:
            continue
        if tensor.grad is None:
            tensor.grad = Tensor(np.zeros_like(tensor.data))
        tensor.grad += grad.detach()
        tensor.backward(grad)
OK, we did a few things here. Most importantly, we detach the gradients, so backprop itself doesn't extend the graph. We skip nodes that have no gradient. And since we'll later add ops that transform a single node, when a function has only one argument we wrap its gradient in a list so the loop still works correctly.
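A small check of the accumulation logic, assuming the code above: x feeds into the product twice, so both branch gradients are summed into x.grad.

x = Tensor([3.0])
z = x * x      # dz/dx = x + x = 6
z.backward()
print(x.grad)  # tensor([6.])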
As a bonus, we'll add cloning.
def clone(self) -> Tensor:
    return Tensor(self.data.copy())
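A quick note on the difference, under the semantics above: clone() copies the data and starts with no graph history, while detach() keeps the same data but cuts the existing graph link.

a = Tensor([1.0, 2.0])
b = a.clone()
b.data[0] = 99.0
print(a)  # tensor([1. 2.]): the copy is fully independent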
Let's also add a few small utilities. First, a shape property on the Tensor class.
class Tensor:
    ...
    @property
    def shape(self) -> tuple:
        return self.data.shape
And two functions for creating ones and zeros.
def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))

def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))
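Used together, again assuming the definitions above:

a = ones((2, 3))
b = zeros((2, 3))
print(a.shape, b.shape)  # (2, 3) (2, 3)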
Commit 96bd8a262398ad56699175cf64592b88ebd4be11
In the end the file should look like this:
from __future__ import annotations
import numpy as np


class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

    def __add__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __mul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __radd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __rmul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __iadd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __imul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __repr__(self):
        return f"tensor({self.data})"

    @property
    def shape(self) -> tuple:
        return self.data.shape

    def detach(self) -> Tensor:
        self._ctx = None
        return self

    def clone(self) -> Tensor:
        return Tensor(self.data.copy())

    def backward(self, grad=None):
        if self._ctx is None:
            return
        if grad is None:
            grad = Tensor([1.0])
        self.grad = grad
        op = self._ctx.op
        child_nodes = self._ctx.args
        grads = op.backward(self._ctx, grad)
        if len(self._ctx.args) == 1:
            grads = [grads]
        for tensor, grad in zip(child_nodes, grads):
            if grad is None:
                continue
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            tensor.grad += grad.detach()
            tensor.backward(grad)


class Function:
    def __init__(self, op, *args):
        self.op: type[Function] = op
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        raise NotImplementedError

    @staticmethod
    def backward(*args):
        raise NotImplementedError


class Add(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data + y.data)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor([1]) * grad, Tensor([1]) * grad


class Mul(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data * y.data)  # z = x*y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor(y.data) * grad, Tensor(x.data) * grad  # dz/dx, dz/dy


def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))


def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))


if __name__ == "__main__":
    def f(x):
        return x * x * x + x

    x = Tensor([1.2])
    z = f(x)
    z.backward()
    print(f"X: {x} grad: {x.grad}")