ggml 简介

发布于 2024 年 8 月 13 日
在 GitHub 上更新

ggml 是一个用 C 和 C++ 编写的机器学习 (ML) 库,专注于 Transformer 推理。该项目是开源的,并由一个不断壮大的社区积极开发。ggml 类似于 PyTorch 和 TensorFlow 等 ML 库,但它仍处于开发的早期阶段,其一些基础部分仍在快速变化。

随着时间的推移,ggml 与 llama.cppwhisper.cpp 等项目一起广受欢迎。许多其他项目也在底层使用 ggml 来实现在设备上运行 LLM,包括 ollamajanLM StudioGPT4All

人们选择使用 ggml 而非其他库的主要原因是:

  1. 极简主义:核心库自成一体,文件数量少于 5 个。虽然你可能希望包含额外的文件以支持 GPU,但这是可选的。
  2. 易于编译:你不需要复杂的构建工具。在没有 GPU 支持的情况下,你只需要 GCC 或 Clang!
  3. 轻量级:编译后的二进制文件大小小于 1MB,与通常占用数百 MB 的 PyTorch 相比非常小。
  4. 良好的兼容性:支持多种硬件,包括 x86_64、ARM、Apple Silicon、CUDA 等。
  5. 支持量化张量:张量可以被量化以节省内存(类似于 JPEG 压缩),并在某些情况下提高性能。
  6. 极高的内存效率:存储张量和执行计算的开销极小。

然而,ggml 也有一些缺点,在使用时需要注意(此列表可能会在 ggml 的未来版本中发生变化):

  • 并非所有张量操作都支持所有后端。例如,某些操作可能在 CPU 上有效,但在 CUDA 上无效。
  • 使用 ggml 进行开发可能不那么直接,可能需要深入的底层编程知识。
  • 该项目正处于活跃开发中,因此预计会有重大变更。

在本文中,我们将重点介绍 ggml 的基础知识,以帮助希望开始使用该库的开发者。我们不涉及更高级的任务,例如使用基于 ggml 构建的 llama.cpp 进行 LLM 推理。相反,我们将探讨 ggml 的核心概念和基本用法,为进一步学习和开发打下坚实的基础。

开始使用

太棒了,那么你该如何开始呢?

为简单起见,本指南将向你展示如何在 Ubuntu 上编译 ggml。实际上,你几乎可以在任何平台(包括 Windows、macOS 和 BSD)上编译 ggml。

# Start by installing build dependencies
# "gdb" is optional, but is recommended
sudo apt install build-essential cmake git gdb

# Then, clone the repository
git clone https://github.com/ggerganov/ggml.git
cd ggml

# Try compiling one of the examples
cmake -B build
cmake --build build --config Release --target simple-ctx

# Run the example
./build/bin/simple-ctx

预期输出

mul mat (4 x 3) (transposed result):
[ 60.00 55.00 50.00 110.00
 90.00 54.00 54.00 126.00
 42.00 29.00 28.00 64.00 ]

如果你看到了预期的结果,那就意味着我们可以继续了!

术语与概念

在深入研究 ggml 之前,我们应该了解一些关键概念。如果你来自 PyTorch 或 TensorFlow 等高级库,这些概念可能看起来很难理解。但是,请记住 ggml 是一个底层库。理解这些术语可以让你更好地控制性能。

  • ggml_context:一个“容器”,用于存放张量、图和可选的数据等对象。
  • ggml_cgraph:表示一个计算图。可以把它想象成将要传输到后端的“计算顺序”。
  • ggml_backend:表示一个用于执行计算图的接口。有多种类型的后端:CPU(默认)、CUDA、Metal(Apple Silicon)、Vulkan、RPC 等。
  • ggml_backend_buffer_type:表示一个缓冲区类型。可以把它想象成一个连接到每个 ggml_backend 的“内存分配器”。例如,如果要在 GPU 上执行计算,你需要通过 buffer_type(通常缩写为 buft)在 GPU 上分配内存。
  • ggml_backend_buffer:表示由 buffer_type 分配的缓冲区。记住:一个缓冲区可以容纳多个张量的数据。
  • ggml_gallocr:表示图内存分配器,用于高效地分配计算图中使用的张量。
  • ggml_backend_sched:一个可以并发使用多个后端的调度器。在处理大型模型或多个 GPU 时,它可以将计算分布在不同的硬件(例如 GPU 和 CPU)上。调度器还可以自动将 GPU 不支持的操作分配给 CPU,以确保最佳的资源利用率和兼容性。

简单示例

在此示例中,我们将逐步重现我们在开始使用部分运行的代码。我们需要创建两个矩阵,将它们相乘并得到结果。使用 PyTorch,代码如下:

import torch

# Create two matrices
matrix1 = torch.tensor([
  [2, 8],
  [5, 1],
  [4, 2],
  [8, 6],
])
matrix2 = torch.tensor([
  [10, 5],
  [9, 9],
  [5, 4],
])

# Perform matrix multiplication
result = torch.matmul(matrix1, matrix2.T)
print(result.T)

使用 ggml,必须执行以下步骤才能实现相同的结果:

  1. 分配 ggml_context 以存储张量数据
  2. 创建张量并设置数据
  3. 为 mul_mat 操作创建 ggml_cgraph
  4. 运行计算
  5. 检索结果(输出张量)
  6. 释放内存并退出

注意:在此示例中,为简单起见,我们将张量数据分配在 ggml_context 内部。在实践中,内存应作为设备缓冲区进行分配,我们将在下一节中看到。

要开始,让我们创建一个新目录 examples/demo

cd ggml # make sure you're in the project root

# create C source and CMakeLists file
touch examples/demo/demo.c
touch examples/demo/CMakeLists.txt

此示例的代码基于 simple-ctx.cpp

用以下内容编辑 examples/demo/demo.c

#include "ggml.h"
#include "ggml-cpu.h"
#include <string.h>
#include <stdio.h>

int main(void) {
    // initialize data of matrices to perform matrix multiplication
    const int rows_A = 4, cols_A = 2;
    float matrix_A[rows_A * cols_A] = {
        2, 8,
        5, 1,
        4, 2,
        8, 6
    };
    const int rows_B = 3, cols_B = 2;
    float matrix_B[rows_B * cols_B] = {
        10, 5,
        9, 9,
        5, 4
    };

    // 1. Allocate `ggml_context` to store tensor data
    // Calculate the size needed to allocate
    size_t ctx_size = 0;
    ctx_size += rows_A * cols_A * ggml_type_size(GGML_TYPE_F32); // tensor a
    ctx_size += rows_B * cols_B * ggml_type_size(GGML_TYPE_F32); // tensor b
    ctx_size += rows_A * rows_B * ggml_type_size(GGML_TYPE_F32); // result
    ctx_size += 3 * ggml_tensor_overhead(); // metadata for 3 tensors
    ctx_size += ggml_graph_overhead(); // compute graph
    ctx_size += 1024; // some overhead (exact calculation omitted for simplicity)

    // Allocate `ggml_context` to store tensor data
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 2. Create tensors and set data
    struct ggml_tensor * tensor_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_A, rows_A);
    struct ggml_tensor * tensor_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_B, rows_B);
    memcpy(tensor_a->data, matrix_A, ggml_nbytes(tensor_a));
    memcpy(tensor_b->data, matrix_B, ggml_nbytes(tensor_b));


    // 3. Create a `ggml_cgraph` for mul_mat operation
    struct ggml_cgraph * gf = ggml_new_graph(ctx);

    // result = a*b^T
    // Pay attention: ggml_mul_mat(A, B) ==> B will be transposed internally
    // the result is transposed
    struct ggml_tensor * result = ggml_mul_mat(ctx, tensor_a, tensor_b);

    // Mark the "result" tensor to be computed
    ggml_build_forward_expand(gf, result);

    // 4. Run the computation
    int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
    ggml_graph_compute_with_ctx(ctx, gf, n_threads);

    // 5. Retrieve results (output tensors)
    float * result_data = (float *) result->data;
    printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
    for (int j = 0; j < result->ne[1] /* rows */; j++) {
        if (j > 0) {
            printf("\n");
        }

        for (int i = 0; i < result->ne[0] /* cols */; i++) {
            printf(" %.2f", result_data[j * result->ne[0] + i]);
        }
    }
    printf(" ]\n");

    // 6. Free memory and exit
    ggml_free(ctx);
    return 0;
}

在你创建的 examples/demo/CMakeLists.txt 文件中写入以下行

set(TEST_TARGET demo)
add_executable(${TEST_TARGET} demo)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)

编辑 examples/CMakeLists.txt,在末尾添加此行

add_subdirectory(demo)

编译并运行它

cmake -B build
cmake --build build --config Release --target demo

# Run it
./build/bin/demo

预期结果

mul mat (4 x 3) (transposed result):
[ 60.00 55.00 50.00 110.00
 90.00 54.00 54.00 126.00
 42.00 29.00 28.00 64.00 ]

使用后端的示例

ggml 中的“后端”指的是可以处理张量操作的接口。后端可以是 CPU、CUDA、Vulkan 等。

后端抽象了计算图的执行。一旦定义,就可以使用相应的后端实现,利用可用的硬件来计算图。请注意,ggml 将自动为计算所需的任何中间张量保留内存,并根据这些中间结果的生命周期优化内存使用。

使用后端进行计算或推理时,需要执行的常见步骤是

  1. 初始化 ggml_backend
  2. 分配 ggml_context 以存储张量元数据(我们不需要立即分配张量数据)
  3. 创建张量元数据(仅它们的形状和数据类型)
  4. 分配一个 ggml_backend_buffer 来存储所有张量
  5. 将张量数据从主内存(RAM)复制到后端缓冲区
  6. 为 mul_mat 操作创建 ggml_cgraph
  7. 为 cgraph 分配创建一个 ggml_gallocr
  8. 可选:使用 ggml_backend_sched 调度 cgraph
  9. 运行计算
  10. 检索结果(输出张量)
  11. 释放内存并退出

此示例的代码基于 simple-backend.cpp

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    // initialize data of matrices to perform matrix multiplication
    const int rows_A = 4, cols_A = 2;
    float matrix_A[rows_A * cols_A] = {
        2, 8,
        5, 1,
        4, 2,
        8, 6
    };
    const int rows_B = 3, cols_B = 2;
    float matrix_B[rows_B * cols_B] = {
        10, 5,
        9, 9,
        5, 4
    };

    // 1. Initialize backend
    ggml_backend_t backend = NULL;
#ifdef GGML_USE_CUDA
    fprintf(stderr, "%s: using CUDA backend\n", __func__);
    backend = ggml_backend_cuda_init(0); // init device 0
    if (!backend) {
        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
    }
#endif
    // if there aren't GPU Backends fallback to CPU backend
    if (!backend) {
        backend = ggml_backend_cpu_init();
    }

    // Calculate the size needed to allocate
    size_t ctx_size = 0;
    ctx_size += 2 * ggml_tensor_overhead(); // tensors
    // no need to allocate anything else!

    // 2. Allocate `ggml_context` to store tensor data
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
    };
    struct ggml_context * ctx = ggml_init(params);

    // Create tensors metadata (only there shapes and data type)
    struct ggml_tensor * tensor_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_A, rows_A);
    struct ggml_tensor * tensor_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_B, rows_B);

    // 4. Allocate a `ggml_backend_buffer` to store all tensors
    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // 5. Copy tensor data from main memory (RAM) to backend buffer
    ggml_backend_tensor_set(tensor_a, matrix_A, 0, ggml_nbytes(tensor_a));
    ggml_backend_tensor_set(tensor_b, matrix_B, 0, ggml_nbytes(tensor_b));

    // 6. Create a `ggml_cgraph` for mul_mat operation
    struct ggml_cgraph * gf = NULL;
    struct ggml_context * ctx_cgraph = NULL;
    {
        // create a temporally context to build the graph
        struct ggml_init_params params0 = {
            /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
        };
        ctx_cgraph = ggml_init(params0);
        gf = ggml_new_graph(ctx_cgraph);

        // result = a*b^T
        // Pay attention: ggml_mul_mat(A, B) ==> B will be transposed internally
        // the result is transposed
        struct ggml_tensor * result0 = ggml_mul_mat(ctx_cgraph, tensor_a, tensor_b);

        // Add "result" tensor and all of its dependencies to the cgraph
        ggml_build_forward_expand(gf, result0);
    }

    // 7. Create a `ggml_gallocr` for cgraph computation
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(allocr, gf);

    // (we skip step 8. Optionally: schedule the cgraph using `ggml_backend_sched`)

    // 9. Run the computation
    int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
    if (ggml_backend_is_cpu(backend)) {
        ggml_backend_cpu_set_n_threads(backend, n_threads);
    }
    ggml_backend_graph_compute(backend, gf);

    // 10. Retrieve results (output tensors)
    // in this example, output tensor is always the last tensor in the graph
    struct ggml_tensor * result = ggml_graph_node(gf, -1);
    float * result_data = malloc(ggml_nbytes(result));
    // because the tensor data is stored in device buffer, we need to copy it back to RAM
    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
    printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
    for (int j = 0; j < result->ne[1] /* rows */; j++) {
        if (j > 0) {
            printf("\n");
        }

        for (int i = 0; i < result->ne[0] /* cols */; i++) {
            printf(" %.2f", result_data[j * result->ne[0] + i]);
        }
    }
    printf(" ]\n");
    free(result_data);

    // 11. Free memory and exit
    ggml_free(ctx_cgraph);
    ggml_gallocr_free(allocr);
    ggml_free(ctx);
    ggml_backend_buffer_free(buffer);
    ggml_backend_free(backend);
    return 0;
}

编译并运行它,你应该会得到与上一个示例相同的结果。

cmake -B build
cmake --build build --config Release --target demo

# Run it
./build/bin/demo

预期结果

mul mat (4 x 3) (transposed result):
[ 60.00 55.00 50.00 110.00
 90.00 54.00 54.00 126.00
 42.00 29.00 28.00 64.00 ]

打印计算图

ggml_cgraph 表示计算图,它定义了将由后端执行的操作顺序。打印图可能是一个有用的调试工具,尤其是在处理更复杂的模型和计算时。

你可以添加 ggml_graph_print 来打印 cgraph

...

// Mark the "result" tensor to be computed
ggml_build_forward_expand(gf, result0);

// Print the cgraph
ggml_graph_print(gf);

运行它

=== GRAPH ===
n_nodes = 1
 -   0: [     4,     3,     1]          MUL_MAT  
n_leafs = 2
 -   0: [     2,     4]     NONE           leaf_0
 -   1: [     2,     3]     NONE           leaf_1
========================================

此外,你可以将 cgraph 绘制为 graphviz dot 格式

ggml_graph_dump_dot(gf, NULL, "debug.dot");

你可以使用 dot 命令或这个在线网站debug.dot 渲染成最终图像。

ggml-debug

结论

本文对 ggml 进行了介绍性概述,涵盖了关键概念、一个简单的使用示例以及一个使用后端的示例。虽然我们已经涵盖了基础知识,但 ggml 还有更多值得探索的内容。

在接下来的文章中,我们将更深入地探讨其他与 ggml 相关的主题,例如 GGUF 格式、量化以及不同后端的组织和使用方式。此外,你可以访问 ggml 示例目录 查看更高级的用例和示例代码。敬请期待未来更多 ggml 相关内容!

社区

simple-backend.cpp 第 10 步的第一行应该是

struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf) - 1);

而不是

struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];

因为 gf 被用作不透明指针。

·
文章作者

这篇博文是在 ggml_graph_n_nodes 引入之前写的,所以内容不是最新的。如果你愿意,可以随时提交一个 PR 来修正。谢谢。

在 ggml 项目、llama.cpp、像 ollama 这样的外部项目之间...更新流程是怎样的?

对 ggml 的更改应该提交到 ggml-org\ggml,然后拉取到 llama.cpp 和其他项目中吗?为了避免碎片化,ggml 似乎应该单独维护,并作为 llama.cpp 的子模块(但我远非 git 专家)。

我想我需要在第二个 demo.c 中加上 #include “ggml-cpu.h” 这一行。

注册登录以发表评论