Distributed Training with JAX and Flax NNX: A Practical Guide to Sharding

Community Article, published March 26, 2025

Training large machine learning models often exceeds the limits of a single GPU or TPU. To scale efficiently, computation and memory need to be distributed across multiple devices. JAX, with its powerful jit compiler and explicit sharding capabilities, provides an excellent toolkit for this.

Recently, the Flax team introduced NNX, a new API designed to manage neural network state more explicitly. This post works through a practical example (based on an official Flax NNX example) that demonstrates how to combine JAX's sharding capabilities with Flax NNX to implement distributed training strategies such as Fully Sharded Data Parallelism (FSDP).

Our goal is to break the code down step by step and make the concepts of JAX sharding, and its integration with the new Flax NNX API, easier to understand.

Note: The code presented here is derived from the official Flax examples repository. You can find the original source at: https://github.com/google/flax/blob/f7d3873b203ac0f3c6859738b1d48c2385359ca0/examples/nnx_toy_examples/10_fsdp_and_optimizer.py. This post aims to provide a detailed explanation to aid understanding.

Let's dive in!

Setup: Imports and Simulated Devices

First, we need the necessary imports. We will use jax, flax.nnx, numpy, and JAX's sharding utilities. A key trick for development is to simulate multiple devices via the XLA_FLAGS environment variable. This lets us test sharding logic even on a single-CPU or single-GPU machine before deploying to a larger cluster.

import dataclasses
import os
# Forces JAX to behave as if 8 devices (e.g., CPU cores) are available,
# even if running on a machine with fewer physical accelerators.
# Useful for testing sharding logic without multi-GPU/TPU hardware.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

from matplotlib import pyplot as plt # For plotting results
# Utilities for creating device meshes easily
from jax.experimental import mesh_utils
# Core JAX sharding components: Mesh defines the device grid,
# PartitionSpec defines how tensor axes map to mesh axes,
# NamedSharding links PartitionSpec to a Mesh with named axes.
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax
import jax.numpy as jnp # JAX's accelerated NumPy
import numpy as np # Standard NumPy for data generation
# Import the Flax NNX API components
from flax import nnx
import typing as tp # For type hints
  • Imports and environment setup: standard imports for the various pieces of functionality.
  • os.environ['XLA_FLAGS'] = ...: this tells JAX's underlying XLA compiler to simulate 8 devices, making multi-device development easy (a quick check follows below).
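
As a quick sanity check (a small sketch, not part of the original example), we can confirm that JAX now reports eight devices:

# With XLA_FLAGS set before JAX initializes its backend, JAX reports 8 (CPU) devices.
print(jax.device_count())  # 8
print(jax.devices())       # e.g. [CpuDevice(id=0), ..., CpuDevice(id=7)], depending on JAX version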

Defining the Parallelism Strategy: the Device Mesh

The core concept for explicit sharding in JAX is the Mesh. It represents a logical grid of devices (real or simulated). We associate names with the axes of this grid to define our parallelism strategy. Here we create a 2x4 mesh (8 devices) whose axes are named 'data' and 'model'.

# Create a 2D mesh (grid) of devices with shape (2, 4), meaning 8 devices total.
# Assign logical names 'data' and 'model' to the axes of this grid.
# The first dimension (size 2) is named 'data'.
# The second dimension (size 4) is named 'model'.
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((2, 4)),
  ('data', 'model'),
)
  • Mesh definition: mesh_utils.create_device_mesh((2, 4)) arranges the devices, and jax.sharding.Mesh(...) turns them into a logical grid with named axes. This lets us declare things like "split the data along the 'data' axis" and "split the parameters along the 'model' axis" (see the inspection sketch below).
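
To make the mesh concrete, here is a small inspection sketch (not part of the original example):

# The 8 simulated devices are arranged in a 2x4 grid with named axes.
print(mesh.devices.shape)  # (2, 4)
print(mesh.axis_names)     # ('data', 'model')
print(mesh.shape)          # mapping of axis name to size: 'data' -> 2, 'model' -> 4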

Sharding Helpers: named_sharding and MeshRules

To make defining sharding specifications simpler and more organized, the example uses a helper function and a dataclass.

# A helper function to quickly create a NamedSharding object
# using the globally defined 'mesh'.
def named_sharding(*names: str | None) -> NamedSharding:
  # P(*names) creates a PartitionSpec, e.g., P('data', None)
  # NamedSharding binds this PartitionSpec to the 'mesh'.
  return NamedSharding(mesh, P(*names))
  • The named_sharding helper: simplifies creating a NamedSharding object, which binds a PartitionSpec (how dimensions map to mesh axes, e.g. P('data', None) shards dimension 0 along the 'data' axis and replicates dimension 1) to our specific mesh.
# A dataclass to hold sharding rules for different parts of the model/data.
# Makes it easy to manage and change sharding strategies.
@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  embed: str | None = None # Sharding rule for embedding-like dimensions
  mlp: str | None = None   # Sharding rule for MLP layers dimensions
  data: str | None = None  # Sharding rule for the data batch dimension

  # Allows calling the instance like `mesh_rules('embed', 'mlp')`
  # to get a tuple of the corresponding sharding rules.
  def __call__(self, *keys: str) -> tuple[str, ...]:
    return tuple(getattr(self, key) for key in keys)

# Create an instance of MeshRules defining the specific strategy:
# - 'embed' dimensions will be replicated (None).
# - 'mlp' dimensions will be sharded along the 'model' mesh axis.
# - 'data' dimensions will be sharded along the 'data' mesh axis.
mesh_rules = MeshRules(
  embed=None,
  mlp='model',
  data='data',
)
  • The MeshRules dataclass: provides a structured way to define and retrieve the sharding axis names ('data', 'model', or None for replication) for the logical parts of the model (embed, mlp) and for the data, as the sketch below illustrates.
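
To illustrate what these helpers produce, here is a brief, hypothetical usage sketch:

# MeshRules maps logical dimension names to mesh axis names (None means replicate).
print(mesh_rules('embed', 'mlp'))  # (None, 'model')
print(mesh_rules('mlp'))           # ('model',)

# named_sharding wraps those axis names in a NamedSharding bound to our mesh.
# P('data', None): shard dim 0 along the 'data' axis, replicate dim 1.
x = jax.device_put(jnp.ones((8, 4)), named_sharding('data', None))
jax.debug.visualize_array_sharding(x)  # shows dim 0 split across the two 'data' groups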

Building a Sharded Model with Flax NNX

Now let's define our MLP using the Flax NNX API. A key feature of NNX is its explicit state management. Note how the sharding intent is specified directly when creating parameters with nnx.Param.

# Define the MLP using Flax NNX API.
class MLP(nnx.Module):
  # Constructor takes input/hidden/output dimensions and an NNX Rngs object.
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    # Define the first weight matrix as an nnx.Param.
    self.w1 = nnx.Param(
      # Initialize with lecun_normal initializer using a key from rngs.
      nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
      # CRITICAL: Specify the desired sharding using MeshRules.
      # ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
      sharding=mesh_rules('embed', 'mlp'),
    )
    # Define the first bias vector as an nnx.Param.
    self.b1 = nnx.Param(
      jnp.zeros((dmid,)), # Initialize with zeros.
      # Sharding: ('mlp',) -> ('model',) -> Shard dim 0 along 'model' axis.
      sharding=mesh_rules('mlp'),
    )
    # Define the second weight matrix as an nnx.Param.
    self.w2 = nnx.Param(
      nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
       # Sharding: ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
      sharding=mesh_rules('embed', 'mlp'),
    )
    # Note: No second bias b2 is defined in this simple example.

  # The forward pass of the MLP.
  def __call__(self, x: jax.Array):
    # Standard MLP calculation: (x @ W1 + b1) -> ReLU -> @ W2
    # NNX automatically accesses the .value attribute of nnx.Param objects.
    return nnx.relu(x @ self.w1 + self.b1) @ self.w2
  • The MLP NNX module:
    • nnx.Param: defines a trainable parameter, wrapping a JAX array.
    • sharding=mesh_rules(...): this is the key NNX integration point. We attach metadata (e.g. (None, 'model')) directly to the parameter, indicating how it should be sharded across the mesh axes according to our MeshRules (see the inspection sketch below).
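
At this stage the sharding is only metadata attached to each variable; nothing has been placed on devices yet. Here is a small sketch of how that metadata could be inspected, assuming get_metadata() (which the optimizer below relies on) exposes the sharding tuple:

# Instantiate the model and look at the sharding metadata stored on each nnx.Param.
m = MLP(1, 32, 1, rngs=nnx.Rngs(0))
print(m.w1.get_metadata()['sharding'])  # (None, 'model')
print(m.b1.get_metadata()['sharding'])  # ('model',)
print(m.w2.get_metadata()['sharding'])  # (None, 'model')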

Handling Sharded Optimizer State

Parameters are not the only state we need to manage; optimizer state (such as momentum) must also be sharded consistently with its corresponding parameters. NNX's explicit variable system handles this elegantly.

# Define a custom type for SGD momentum state, inheriting from nnx.Variable.
# This allows it to be tracked as part of the NNX state tree.
class SGDState(nnx.Variable):
  pass

# Define the SGD optimizer using NNX API.
class SGD(nnx.Object):
  # Constructor takes the model parameters (as nnx.State), learning rate, and decay.
  def __init__(self, params: nnx.State, lr, decay=0.9):
    # Helper function to initialize momentum buffer for a given parameter.
    def init_optimizer_state(variable: nnx.Variable):
      # Create momentum state with zeros, same shape and metadata (incl. sharding)
      # as the parameter it corresponds to.
      return SGDState(
        jnp.zeros_like(variable.value), **variable.get_metadata()
      )

    self.lr = lr
    # Store a reference to the parameter State tree.
    self.params = params
    # Create the momentum state tree, mirroring the structure of 'params',
    # using the helper function. Momentum will have the same sharding as params.
    self.momentum = jax.tree.map(init_optimizer_state, self.params)
    self.decay = decay

  # Method to update parameters based on gradients.
  def update(self, grads: nnx.State):
    # Define the update logic for a single parameter/momentum/gradient triple.
    def update_fn(
      params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
    ):
      # Standard SGD with momentum update rule.
      # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
      momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value
      # θ_{t+1} = θ_t - α * v_t
      params.value -= self.lr * momentum.value # NOTE: Direct mutation of param value!

    # Apply the update function across the parameter, momentum, and gradient trees.
    # This performs the update in-place on the parameter values referenced by self.params.
    jax.tree.map(update_fn, self.params, self.momentum, grads)
  • The SGD NNX optimizer:
    • SGDState(nnx.Variable): a custom type for the momentum, making it part of the NNX state system.
    • init_optimizer_state: creates momentum buffers with matching shapes and, crucially, inherits the metadata (including the sharding tuple) from the corresponding parameter via **variable.get_metadata().
    • update: applies the gradients in place to the sharded parameters (referenced via self.params) and the momentum buffers; the plain-array sketch below shows the same arithmetic without the NNX machinery.
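
The update rule itself is ordinary SGD with momentum. Stripped of the NNX machinery, the same arithmetic on plain arrays looks like this (an illustrative sketch only):

# v_t = decay * v_{t-1} + (1 - decay) * grad ;  theta_{t+1} = theta_t - lr * v_t
lr, decay = 0.01, 0.9
theta = jnp.array([1.0, 2.0])
v = jnp.zeros_like(theta)
grad = jnp.array([0.5, -0.5])
v = decay * v + (1 - decay) * grad  # [0.05, -0.05]
theta = theta - lr * v              # [0.9995, 2.0005]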

Applying and Enforcing Sharding: the create_model Function

We have defined how to shard via metadata. Now we need to tell JAX to actually enforce this sharding layout during computation. This happens in the create_model function, using jax.lax.with_sharding_constraint and nnx.update.

# JIT-compile the model and optimizer creation function.
@nnx.jit
def create_model():
  # Instantiate the MLP model. rngs=nnx.Rngs(0) provides PRNG keys.
  model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
  # Create the optimizer. nnx.variables(model, nnx.Param) extracts
  # only the nnx.Param state variables from the model object.
  optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)

  # === Explicit Sharding Application ===
  # 1. Extract ALL state (model params + optimizer momentum) into a flat State pytree.
  state = nnx.state(optimizer)

  # 2. Define the target sharding for the state pytree.
  # This function maps state paths to NamedSharding objects based on stored metadata.
  def get_named_shardings(path: tuple, value: nnx.VariableState):
    # Assumes params and momentum use the sharding defined in their metadata.
    if path[0] in ('params', 'momentum'):
      # value.sharding contains the tuple like ('model',) or (None, 'model')
      # stored during Param/SGDState creation.
      return value.replace(NamedSharding(mesh, P(*value.sharding)))
    else:
      # Handle other state if necessary (e.g., learning rate if it were a Variable)
      raise ValueError(f'Unknown path: {path}')
  # Create the pytree of NamedSharding objects.
  named_shardings = state.map(get_named_shardings)

  # 3. Apply sharding constraint. This tells JAX how the 'state' pytree
  # SHOULD be sharded when computations involving it are run under jit/pjit.
  # It doesn't immediately move data but sets up the constraint for the compiler.
  sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)

  # 4. Update the original objects (model params, optimizer momentum)
  # with the constrained state values. This step makes the sharding
  # "stick" to the objects themselves for subsequent use outside this function.
  nnx.update(optimizer, sharded_state)

  # Return the model and optimizer objects, now containing sharded state variables.
  return model, optimizer

# Call the function to create the sharded model and optimizer.
model, optimizer = create_model()

# Visualize the sharding of the first weight's parameter tensor.
jax.debug.visualize_array_sharding(model.w1.value)
# Visualize the sharding of the first weight's momentum tensor.
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
  • The create_model function:
    1. nnx.state(optimizer): collects all nnx.Variable instances (parameters and momentum) into a single nnx.State pytree.
    2. state.map(get_named_shardings): builds a parallel pytree of NamedSharding specifications by reading the sharding metadata attached to each variable.
    3. jax.lax.with_sharding_constraint: the core JAX primitive. It tells the JIT compiler that the state pytree must conform to the named_shardings layout.
    4. nnx.update(optimizer, sharded_state): this NNX function pushes the constrained state back into the original model and optimizer objects, making the sharding stick for all subsequent use.
  • Visualization: jax.debug.visualize_array_sharding confirms that the parameter and momentum tensors are distributed across the device mesh as expected (see also the inspection sketch below).
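
Beyond the visualization, the sharding that ended up attached to the concrete arrays can be inspected directly, since every jax.Array exposes a .sharding attribute (a small sketch):

# After create_model(), the parameter and momentum arrays carry concrete NamedShardings.
print(model.w1.value.sharding)               # e.g. NamedSharding(..., PartitionSpec(None, 'model'))
print(optimizer.momentum.w1.value.sharding)  # momentum mirrors the parameter's sharding
print(model.w1.value.shape)                  # (1, 32): dim 1 is split across the 4 'model' devices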

The Distributed Training Step

With the model and optimizer state sharded correctly, defining the JIT-compiled training step is straightforward. JAX automatically handles the necessary communication (such as gradient aggregation) based on the sharding constraints we have established.

# JIT-compile the training step function.
@nnx.jit
def train_step(model: MLP, optimizer: SGD, x, y):
  # Define the loss function (Mean Squared Error).
  # Takes the model object as input, consistent with nnx.value_and_grad.
  def loss_fn(model):
    y_pred = model(x) # Forward pass
    loss = jnp.mean((y - y_pred) ** 2)
    return loss

  # Calculate loss and gradients w.r.t the model's state (its nnx.Param variables).
  # 'grad' will be an nnx.State object mirroring model's Param structure.
  loss, grad = nnx.value_and_grad(loss_fn)(model)

  # Call the optimizer's update method to apply gradients.
  # This updates the model parameters in-place.
  optimizer.update(grad)

  # Return the calculated loss.
  return loss
  • The train_step function:
    • @nnx.jit: compiles the function; JAX infers the distributed execution plan.
    • nnx.value_and_grad: computes the loss and the gradients with respect to the sharded variables in model.
    • optimizer.update(grad): applies the (implicitly sharded) gradients to the sharded state. A minimal usage sketch follows below.
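
As a minimal usage sketch (hypothetical dummy data, sized so the batch divides evenly across the 'data' axis), a single call looks like this:

# One step on a tiny dummy batch of 8 rows (divisible by the 'data' axis size of 2).
x_dummy = jax.device_put(np.ones((8, 1), dtype=np.float32), named_sharding('data'))
y_dummy = jax.device_put(np.zeros((8, 1), dtype=np.float32), named_sharding('data'))
print(train_step(model, optimizer, x_dummy, y_dummy))  # scalar MSE loss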

Data Loading and the Training Loop

Finally, the training loop generates data and feeds it to train_step. The key part here is using jax.device_put before each step to shard the input data batch along the 'data' axis.

# Generate synthetic dataset: y = 0.8*x^2 + 0.1 + noise
X = np.linspace(-2, 2, 100)[:, None] # Input features
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) # Target values

# A generator function to yield batches of data for training.
def dataset(batch_size, num_steps):
  for _ in range(num_steps):
    # Randomly sample indices for the batch.
    idx = np.random.choice(len(X), size=batch_size)
    # Yield the corresponding input and target pairs.
    yield X[idx], Y[idx]

# --- Training Loop ---
losses = [] # To store loss values for plotting
# Iterate through the dataset generator for 10,000 steps.
for step, (x_batch, y_batch) in enumerate(
  dataset(batch_size=32, num_steps=10_000)
):
  # CRITICAL: Place the NumPy data onto JAX devices AND apply sharding.
  # named_sharding('data') -> Shard along the 'data' mesh axis (first dim, size 2).
  # Each device along the 'data' axis gets a slice of the batch.
  x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))

  # Execute the JIT-compiled training step with the sharded model, optimizer, and data.
  loss = train_step(model, optimizer, x_batch, y_batch)

  # Record the loss (move scalar loss back to host CPU).
  losses.append(float(loss))
  # Log progress periodically.
  if step % 1000 == 0:
    print(f'Step {step}: Loss = {loss}')

# --- Plotting Results ---
plt.figure()
plt.title("Training Loss")
plt.plot(losses[20:]) # Plot loss, skipping initial noisy steps
plt.xlabel("Step")
plt.ylabel("MSE Loss")

# Get model predictions on the full dataset (X is on host CPU).
# The forward pass runs on device; the result is implicitly transferred back to the host for plotting.
y_pred = model(X)
plt.figure()
plt.title("Model Fit")
plt.scatter(X, Y, color='blue', label='Data') # Original data
plt.plot(X, y_pred, color='black', label='Prediction') # Model's predictions
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.show() # Display the plots
  • Data generation and dataset: standard data setup.
  • Training loop:
    • jax.device_put(..., named_sharding('data')): this implements data parallelism. It sends the NumPy batch to the devices and splits it along the 'data' mesh axis (see the inspection sketch below).
    • loss = train_step(...): executes the distributed training step.
  • Plotting: visualizes the training progress and the final model fit.
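
To confirm the data-parallel placement, the sharded batch can be inspected right after jax.device_put (a small sketch):

# The 32-row batch is split in half along the 'data' axis and replicated
# across the 4 devices of the 'model' axis, giving 8 shards of shape (16, 1).
print(x_batch.sharding)                                    # NamedSharding(..., PartitionSpec('data'))
print([s.data.shape for s in x_batch.addressable_shards])  # [(16, 1)] * 8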

Summary

This example demonstrates a powerful pattern for scalable machine learning training:

  1. Define a Mesh: map logical names ('data', 'model') to device axes.
  2. Use Flax NNX: define parameters (nnx.Param) and other state (nnx.Variable) explicitly.
  3. Attach sharding metadata: specify the desired sharding tuples directly when creating NNX variables.
  4. Enforce the sharding: use nnx.state, jax.lax.with_sharding_constraint, and nnx.update inside a JIT context to apply the constraints.
  5. Shard the data: distribute input batches with jax.device_put and a NamedSharding.
  6. JIT the training step: let JAX compile the distributed execution plan.

The combination of JAX's explicit sharding control and Flax NNX's explicit state management provides a clear and flexible way to implement sophisticated parallelism strategies such as FSDP, enabling the training of larger and more capable models.

We hope this detailed walkthrough of the official Flax example helps clarify how these powerful tools work together!
