JAX 中注意力机制的简单实现

社区文章 发布于 2025 年 3 月 3 日

我最近开始学习 JAX,尽管我已经作为机器学习工程师工作了一段时间,但我的主要经验是使用 PyTorch 和 TensorFlow。在我的空闲时间,我决定深入研究 JAX,而理解它的最佳方式莫过于实现我已熟知的概念。在本教程中,我将引导您通过注意力机制的简单实现——一个我多年来学习和使用的概念。在这里,您将找到使用 JAX 和 Flax 进行单头和多头注意力的详细解释,以及使用 JIT 编译进行性能基准测试。可能存在错误,并且可能进行改进,但这可以被视为我与大家一起学习的过程。


1. 单头注意力

在本节中,我们创建一个单头注意力模块。该模块首先使用全连接(线性)层将输入编码转换为三种表示:查询(queries)、键(keys)和值(values)。然后,它计算查询和键之间的点积(在适当对齐轴后)以衡量相似性。这些相似性会进行缩放以避免训练期间出现问题,并且可以应用可选的掩码以忽略某些位置。最后,Softmax 将这些分数转换为概率,这些概率用于组合值向量。

import jax
import jax.numpy as jnp
from flax import linen as nn

class Attention(nn.Module):
    d_model: int = 2
    row_dim: int = 0
    col_dim: int = 1

    @nn.compact
    def __call__(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # Create dense layers (without bias) to generate queries, keys, and values.
        W_q = nn.Dense(features=self.d_model, use_bias=False, name="W_q")
        W_k = nn.Dense(features=self.d_model, use_bias=False, name="W_k")
        W_v = nn.Dense(features=self.d_model, use_bias=False, name="W_v")
        
        # Project the input encodings into queries, keys, and values.
        q = W_q(encodings_for_q)
        k = W_k(encodings_for_k)
        v = W_v(encodings_for_v)
        
        # Swap axes of the key tensor to align dimensions for matrix multiplication.
        k_t = jnp.swapaxes(k, self.row_dim, self.col_dim)
        
        # Compute dot products between queries and the transposed keys.
        sims = jnp.matmul(q, k_t)
        
        # Scale the similarity scores by the square root of the key dimension.
        scale = jnp.sqrt(k.shape[self.col_dim])
        scaled_sims = sims / scale
        
        # If a mask is provided, apply it to ignore specified positions.
        if mask is not None:
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)
        
        # Apply softmax to convert similarity scores into attention probabilities.
        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)
        
        # Use the attention weights to compute a weighted sum of the value vectors.
        attention_scores = jnp.matmul(attention_percents, v)
        return attention_scores
  • 用于投影的全连接层
    使用三个全连接层从输入数据中计算查询(q)、键(k)和值(v)。这些投影至关重要,因为它们将输入转换为不同的子空间。

  • 键的轴交换
    通过交换键张量的轴,我们确保随后的矩阵乘法对齐正确的维度。这对于计算每个查询与每个键之间的点积至关重要。

  • 扩展
    将点积除以键维度的平方根有助于在训练期间稳定梯度,尤其是在维度较大时。

  • 掩码
    如果需要忽略某些位置(例如,填充标记),可选的掩码会将这些位置替换为非常低的值,以确保它们的影响可以忽略不计。

  • Softmax 和加权和
    Softmax 函数将缩放后的点积转换为概率分布。然后,这些概率用于加权值向量,从而产生一个突出输入最相关部分的聚焦输出。


2. 多头注意力

多头注意力通过并行运行多个注意力头来增强模型捕获不同模式的能力。每个注意力头独立运行并处理输入数据,然后将其输出沿特征维度进行拼接。这种技术允许模型共同关注来自不同表示子空间的信息。

class MultiHeadAttention(nn.Module):
    d_model: int = 2
    row_dim: int = 0
    col_dim: int = 1
    num_heads: int = 1

    def setup(self):
        # Initialize a list of attention heads.
        self.heads = [Attention(d_model=self.d_model, 
                                row_dim=self.row_dim, 
                                col_dim=self.col_dim)
                      for _ in range(self.num_heads)]

    def __call__(self, encodings_for_q, encodings_for_k, encodings_for_v):
        # Run each attention head independently and collect their outputs.
        head_outputs = [head(encodings_for_q, encodings_for_k, encodings_for_v)
                        for head in self.heads]
        # Concatenate the outputs along the specified dimension.
        return jnp.concatenate(head_outputs, axis=self.col_dim)
  • 使用 setup 初始化
    setup 方法用于创建单头注意力模块的多个实例。这确保每个注意力头都有自己的一组参数。

  • 独立处理
    每个注意力头独立处理输入编码,这意味着输入数据的不同方面可以由不同的注意力头捕获。

  • 拼接
    一旦所有注意力头都生成了它们的输出,这些输出将沿着选定的维度进行拼接。这种组合的输出比任何单个注意力头单独提供的信息更丰富。


3. 测试注意力模块

本节演示如何测试单头和多头注意力模块。我们定义作为 JAX 数组的样本令牌编码,使用随机键(对于 Flax 参数初始化至关重要)初始化模块,并将模块应用于计算注意力输出。

# Sample token encodings (3 tokens, each with 2 features)
encodings_for_q = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])
encodings_for_k = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])
encodings_for_v = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])

# Create a random key for parameter initialization.
key = jax.random.PRNGKey(42)

# --- Single-Head Attention Test ---
attention_module = Attention(d_model=2, row_dim=0, col_dim=1)
params = attention_module.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
single_head_output = attention_module.apply(params, encodings_for_q, encodings_for_k, encodings_for_v)
print("Single-head attention output:")
print(single_head_output)

# --- Multi-Head Attention Test (1 head) ---
multi_head_module_1 = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=1)
params_multi1 = multi_head_module_1.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
multi_head_output_1 = multi_head_module_1.apply(params_multi1, encodings_for_q, encodings_for_k, encodings_for_v)
print("Multi-head attention (1 head) output:")
print(multi_head_output_1)

# --- Multi-Head Attention Test (2 heads) ---
multi_head_module_2 = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=2)
params_multi2 = multi_head_module_2.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
multi_head_output_2 = multi_head_module_2.apply(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v)
print("Multi-head attention (2 heads) output:")
print(multi_head_output_2)
Single-head attention output:
[[1.668201   2.6169908 ]
 [2.433429   3.3817132 ]
 [0.51508707 1.4933776 ]]
Multi-head attention (1 head) output:
[[-0.7741511  -0.24243875]
 [-1.3947037   0.28557885]
 [-0.08808593 -0.9197984 ]]
Multi-head attention (2 heads) output:
[[-0.7741511  -0.24243875  2.0704143  -2.0301726 ]
 [-1.3947037   0.28557885  0.04033631 -0.86105233]
 [-0.08808593 -0.9197984   3.9204044  -3.142049  ]]
  • 样本数据
    这些数组表示三个令牌,每个令牌包含两个特征。这是一个简化的设置,用于验证我们的注意力模块是否按预期工作。

  • 用于初始化的随机键
    在 Flax 中,随机键是初始化模型参数所必需的。使用固定键可确保结果可重现。

  • 模块应用
    模块首先用给定的输入进行初始化,然后应用以生成注意力输出。这有助于确认实现是正确的并且正在运行。


4. 使用 JIT 进行基准测试和加速

JAX 的主要优势之一是它能够使用即时 (JIT) 编译。JIT 编译将您的 Python 函数转换为高度优化的机器代码。本节对多头注意力模块进行基准测试,比较有 JIT 编译和没有 JIT 编译的执行时间。

import time

# Define a function to repeatedly run multi-head attention for benchmarking purposes.
def run_multi_head(params, module, iterations=1000):
    for _ in range(iterations):
        _ = module.apply(params, encodings_for_q, encodings_for_k, encodings_for_v)

# Create a JIT-compiled version of the multi-head attention call for 2 heads.
jit_multi_head = jax.jit(lambda params, q, k, v: multi_head_module_2.apply(params, q, k, v))

# Warm-up call to trigger the JIT compilation process.
_ = jit_multi_head(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v)

# Benchmark the non-JIT version.
start = time.perf_counter()
run_multi_head(params_multi2, multi_head_module_2, iterations=1000)
end = time.perf_counter()
print("Non-JIT execution time: {:.6f} seconds".format(end - start))

# Benchmark the JIT version (after warm-up).
start = time.perf_counter()
for _ in range(1000):
    _ = jit_multi_head(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v)
end = time.perf_counter()
print("JIT execution time: {:.6f} seconds".format(end - start))
Non-JIT execution time: 25.080998 seconds
JIT execution time: 0.020293 seconds
  • 预热阶段
    对 JIT 编译函数的第一次调用包含编译开销。预热调用可确保计时测量仅反映执行时间,而不反映编译时间。

  • 重复执行
    多头注意力模块重复运行(1,000 次迭代),以获得可靠的执行时间测量结果。

  • 计时测量
    使用 Python 的 time.perf_counter(),我们测量并比较了非 JIT 和 JIT 编译执行所需的时间。通常,JIT 版本会显示出显著的速度提升。


结论

在本教程中,我们探讨了使用 JAX 和 Flax 实现注意力机制。从单头注意力的基础到更高级的多头注意力,我们详细讨论了每个组件,并演示了如何使用 JIT 编译优化性能。作为一名主要使用 PyTorch 和 TensorFlow 的人,这次练习对我来说是一个宝贵的实践学习 JAX 的机会。

我希望本教程有所帮助且清晰明了,即使可能存在需要改进的地方。请记住,学习是一个持续的过程,我很期待继续完善这些概念和实现。感谢您与我一起踏上这段学习之旅——让我们继续一起探索和构建!

社区

注册登录 发表评论