在 Rust 中从零开始构建张量(第 1.1 部分):核心结构和索引

社区文章 发布于 2025 年 6 月 12 日

在本系列中,我们将从零开始用 Rust 构建一个张量库,类似于 PyTorch 或 NumPy,重点是理解张量操作背后的基本概念。在第一部分中,我们将创建基本的张量结构并实现索引操作。

张量结构

张量看起来像一个简单的多维数组,但我们在这里做出的设计选择将影响我们实现的每一个操作。问题是:我们如何有效地存储多维数据,同时保持切片、重塑和设备传输等操作的速度?

为了使我们的实现更简洁,我们将把张量结构分成两部分:形状组件和存储组件。

我们这样做有两个原因:

首先,我们希望能够灵活地存储数据,无论是在 RAM、VRAM 还是网络驱动器上的某个地方。

其次,操作分为两类:

  • 视图操作:在不复制数据的情况下重新解释或重塑数据(如切片或重塑)
  • 数据操作:实际修改或转换数据(如加法或乘法)

这意味着视图操作不需要访问实际数据,从而避免了不必要的数据检索。

所以我们想把张量分成两部分

  • TensorShape,存储形状。
  • TensorStorage,存储类型为 T 的原始值

让我们将它们组合到 Tensor 类型中。

#[derive(Debug, Clone, PartialEq)]
struct TensorShape {
    shape: Vec<usize>,
}

impl TensorShape {
    fn size(&self) -> usize {
        self.shape.iter().product()
    }
}

#[derive(Debug, Clone, PartialEq)]
struct TensorStorage<T> {
    data: Vec<T>,
}

#[derive(Debug, Clone, PartialEq)]
struct Tensor<T> {
    shape: TensorShape,
    storage: TensorStorage<T>,
}

您可能想知道我们为什么选择以这种方式将数据存储在 Vec 中,而不是存储在列表的列表或其他数据结构中

  • 在设备之间传输数据并将数据保存到磁盘更简单
  • 我们可以保证原始数据是连续的,以获得更好的性能

行主序(又称“C 风格”)是最流行的排序方式。在行主序中,在最后一个维度中移动一步对应于数据缓冲区中向前移动一步。

在行主序中,在最后一个维度中移动一步对应于数据缓冲区中向前移动一步。

因此,对于形状为 [2,2,2] 的张量:0 ->[0,0,0] 1 ->[0,0,1] 2 ->[0,1,0] 14 ->[1,1,0] 15 ->[1,1,1]

相反,列主序(在 Fortran 和 MATLAB 中使用)通过首先沿第一个维度前进存储元素。

0 ->[0,0,0] 1 ->[1,0,0] 2 ->[0,1,0] 14 ->[0,1,1] 15 ->[1,1,1]

生产环境:Candle

在继续之前,我认为快速了解一下 Candle(一个用 Rust 编写的现有张量框架)如何处理其张量会很有帮助。

首先,让我们看一下 Tensor_ 结构体

pub struct Tensor_ {
  storage: Arc<RwLock<Storage>>,
  layout: Layout,
  // Other fields omitted for brevity
}

您可以看到,它有单独的形状和存储结构,与我们自己的 Tensor 类型的结构类似。

Layout 结构体 包含有关如何索引给定张量的字段

pub struct Layout {
    shape: Shape,
    // Other fields omitted for brevity
}

包括一个 Shape 结构体,它与我们上面的 TensorShape 相同。

pub struct Shape(Vec<usize>);

Storage 枚举 看起来像

pub enum Storage {
    Cpu(CpuStorage),
    Cuda(CudaStorage),
    Metal(MetalStorage),
}

它有许多用于不同设备上数据存储的条目,但如果我们查看 CpuStorage 枚举

pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

我们可以看到它在概念上与我们的 TensorStorage 相似。

张量初始化

首先,我们希望能够创建张量。您可以创建的最简单的张量是零张量。

在大多数框架中,都有一个名为 zeros 的函数。让我们在 Tensor 上创建一个关联函数:Tensor::zeros

这只适用于张量的内部类型可以初始化为零的情况。我们可以为所有类型创建一个像 Zeroable 这样的 trait,它会提供如何将变量归零的说明,例如

trait Zeroable {
    fn zero() -> Self
}

impl Zeroable for f32 {
    fn zero() -> f32 {
        0.0
    }
}

或者我们可以引入 num-traits crate,它已经包含了 Zero trait 和许多其他有用的数值 trait。

这大大简化了事情。这使我们不必为每种支持的类型手动实现零初始化。

use num_traits::Zero;

impl<T: Zero + Clone> Tensor<T> {
    fn zeros(shape: Vec<usize>) -> Self {
        let shape = TensorShape { shape };
        let storage = TensorStorage::<T>::zeros(shape.size());
        Tensor { shape, storage }
    }
}

impl<T: Zero + Clone> TensorStorage<T> {
    fn zeros(size: usize) -> Self {
        TensorStorage {
            data: vec![T::zero(); size],
        }
    }
}

索引

我们真正需要做的第一件事是能够索引到我们的向量中并查看或更改单个元素。

为了索引到存储中,我们将实现扁平化和反扁平化。

扁平化

想象一下将一个二维电子表格扁平化成一个单列。您逐行迭代,将每个元素按顺序放入一个扁平缓冲区。这基本上就是任何多维张量存储在计算机线性内存中时发生的情况。它需要被“扁平化”成一个一维表示。

对于具有 shape 的张量的多维 index,这种映射的规范方式是

linear_index = index[-1] + index[-2]*shape[-1] + index[-3]*shape[-1]*shape[-2] + ....

您也可以将其视为索引向量和“步幅”向量的点积,其中

strides = [..., shape[-2]*shape[-1], shape[-1], 1]

这些步幅可以在设置形状时预先计算一次,以节省以后的一些时间。

对于形状为 [2, 3, 4] 的张量,步幅将为 [12, 4, 1]

  • 在维度 0 中移动一步跳跃 12 个位置,每个矩阵
  • 在维度 1 中移动一步跳跃 4 个位置,每个向量
  • 在维度 2 中移动一步跳跃 1 个位置,每个元素

例如,我们可以将灰度图像视为形状为 [h, w] 的二维张量。您可以想象这个线性索引按行向下遍历图像。

其中 i 索引高度 (h),j 索引宽度 (w)

linear_index = j + i*w = i*w + j

因此,每个 i 跳一行,每个 j 在行内选择。

impl TensorShape 中,我们可以添加 ravel_index 函数

fn ravel_index(&self, indices: &[usize]) -> usize {
    if indices.len() != self.shape.len() {
        panic!("Indices length does not match tensor shape dimensions.");
    }
    
    indices.iter().zip(self.shape.iter())
        .rev()
        .scan(1, |stride, (&idx, &dim_size)| {
            let result = idx * *stride;
            *stride *= dim_size;
            Some(result)
        })
        .sum()
}

Scan 是 Rust 中“有状态映射”的等价物。它类似于 map,但有一个累加器(在本例中为 stride)和一个初始状态(在本例中为 1)。

反扁平化

如果我们想从 linear_index 中获取 index,假设我们有原始 shape,这会有点棘手。如果您好奇,值得自己先解决这个问题。

让我们用灰度图像的例子来澄清,其中

linear_index = i*w + j

如果我们想从中取回 j,我们应该记住如何使用模数 (%) 获取除以 w 的余数。因为 j[0, w) 之间

j = linear_index % w

现在我们只需要 i,它是

i = linear_index - j = linear_index - linear_index % w = linear_index // w

其中 // 表示向下取整的除法,即向下舍入到最接近的整数的除法。

对于更高维的索引,模式如下

index[-1] = linear_index % shape[-1]

这确保了在我们的 linear_index 中任何乘以 shape[-1] 的值都变为 0

所以

  • = linear_index % shape[-1]
  • = (index[-1] + index[-2]*shape[-1] + index[-3]*shape[-1]*shape[-2] + ...) % shape[-1]
  • = index[-1] % shape[-1] + index[-2]*shape[-1] % shape[-1] + index[-3]*shape[-1]*shape[-2] % shape[-1] + ...
  • = index[-1] % shape[-1] + 0 + 0 + ...
  • = index[-1]

然后对于下一个,我们使用前面提到的向下取整除法 // 减去之前找到的索引

  • linear_index // shape[-1]
  • = (index[-1] + index[-2]*shape[-1] + index[-3]*shape[-1]*shape[-2] + ...) // shape[-1]
  • = index[-1] // shape[-1] + index[-2]*shape[-1] // shape[-1] + index[-3]*shape[-1]*shape[-2] // shape[-1] + ...
  • = 0 + index[-2] + index[-3]*shape[-2] + ...
  • = index[-2] + index[-3]*shape[-2] + ...

但由于我们有更多的维度,我们需要为下一个维度再次取模

  • index[-2] = (linear_index // shape[-1]) % shape[-2]
  • index[-3] = (linear_index // shape[-1]*shape[-2]) % shape[-3]
  • index[-4] = (linear_index // shape[-1]*shape[-2]*shape[-3]) % shape[-4]

让我们回到灰度图像的例子。

如果 4x5 图像的 linear_index = 15,我们得到 j = 15 % 5 = 0i = 15 // 5 = 3,所以我们位于 [3, 0] 位置,即第四行的第一个像素。

impl TensorShape 中,我们可以添加 unravel_index 函数

fn unravel_index(&self, index: usize) -> Vec<usize> {
    if self.shape.is_empty() {
        return vec![];
    }

    let mut indices = vec![0; self.shape.len()];
    let mut remaining_index = index;

    for (i, &dim_size) in self.shape.iter().enumerate().rev() {
        indices[i] = remaining_index % dim_size;
        remaining_index /= dim_size;
    }

    indices
}

索引最终位

对于我们的 TensorStorage,让我们实现 IndexIndexMut trait,它们只对我们的存储进行线性索引。

use std::operations::{Index, IndexMut};

impl<T> Index<usize> for TensorStorage<T> {
    type Output = T;

    fn index(&self, index: usize) -> &Self::Output {
        &self.data[index]
    }
}

impl<T> IndexMut<usize> for TensorStorage<T> {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.data[index]
    }
}

对于我们的 Tensor,让我们实现 IndexIndexMut trait,它们将多维索引扁平化为线性索引并获取值。

impl<T> Index<&[usize]> for Tensor<T> {
    type Output = T;

    fn index(&self, indices: &[usize]) -> &Self::Output {
        &self.storage[self.shape.ravel_index(indices)]
    }
}

impl<T> IndexMut<&[usize]> for Tensor<T> {
    fn index_mut(&mut self, indices: &[usize]) -> &mut Self::Output {
        &mut self.storage[self.shape.ravel_index(indices)]
    }
}

生产环境:Candle

现在我们了解了扁平化和反扁平化的基本原理,让我们看看 Candle 在实践中如何优化索引操作。我们的 ravel_index 函数每次都从头开始计算线性索引。

Candle 采用不同的方法,预先计算步幅——每个维度的乘数。Layout 结构体 包含形状和步幅信息

pub struct Layout {
    shape: Shape,
    stride: Vec<usize>,  // Pre-computed strides
    start_offset: usize, // Pre-computed start_offset
}

让我们看看 Candle 如何实现 AvgPool2D

以下是总结

let (b_sz, c, h, w) = layout.shape().dims4()?;
let mut src_index = layout.start_offset();
for b_idx in 0..b_sz {
    src_index += b_idx * stride[0];   // Add batch offset
    for c_idx in 0..c {
        src_index += c_idx * stride[1];   // Add channel offset
        for m in 0..kernel_h {
            for n in 0..kernel_w {
                let final_index = src_index + m * stride[2] + n * stride[3];
            }
        }
    }
}

这正在执行扁平化操作。

与我们目前的方法相比

let (b_sz, c, h, w) = layout.shape().dims4()?;
for b_idx in 0..b_sz {
    for c_idx in 0..c {
        for m in 0..kernel_h {
            for n in 0..kernel_w {
                let final_index = layout.start_offset() + layout.ravel_index(&[b_idx, c_idx, m, n]); 
            }
        }
    }
}

Candle 的步幅方法更有效率,因为它具有可预测的内存访问模式,编译器可以针对 CPU 缓存进行优化。

代码

要运行博客文章中的代码

首先,克隆 仓库

git clone git@github.com:greenrazer/easytensor.git

然后切换到此部分的标记提交:

git checkout part-1

然后运行 cargo test

cargo test

查看 src/ 中的代码

后续步骤

我们已经通过核心索引操作构建了张量库的基础,这些操作反映了像 Candle 这样的生产框架的内部工作方式。在下一部分中,我们将探讨如何使视图操作高效。

尝试使用代码,创建一些张量,索引它们,并查看扁平化/反扁平化如何与不同的形状协同工作。理解这些基础将使我们将来要介绍的更高级操作变得更加清晰。

社区

注册登录 发表评论