Distributed Training with JAX and Flax NNX: A Practical Guide to Sharding
Training large machine learning models often outgrows a single GPU or TPU. To scale efficiently, computation and memory must be distributed across multiple devices. JAX, with its powerful jit compiler and explicit sharding capabilities, provides an excellent toolkit for this.
Recently, the Flax team introduced NNX, a new API designed to manage neural network state more explicitly. This post walks through a practical example (based on an official Flax NNX example) to show how JAX's sharding features combine with Flax NNX to implement distributed training strategies such as Fully Sharded Data Parallelism (FSDP).
Our goal is to break the code down step by step, making the concepts of JAX sharding and its integration with the new Flax NNX API easier to understand.
Note: The code presented here comes from the official Flax examples repository; you can find the original source at https://github.com/google/flax/blob/f7d3873b203ac0f3c6859738b1d48c2385359ca0/examples/nnx_toy_examples/10_fsdp_and_optimizer.py. This post aims to provide a detailed explanation to aid understanding.
Let's dive in!
Setup: Imports and Simulating Devices
First, we need the necessary imports. We will use jax, flax.nnx, numpy, and JAX's sharding utilities. A key trick for development is to simulate multiple devices via the XLA_FLAGS environment variable. This lets us test sharding logic even on a single-CPU or single-GPU machine before deploying to a larger cluster.
import dataclasses
import os
# Forces JAX to behave as if 8 devices (e.g., CPU cores) are available,
# even if running on a machine with fewer physical accelerators.
# Useful for testing sharding logic without multi-GPU/TPU hardware.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
from matplotlib import pyplot as plt # For plotting results
# Utilities for creating device meshes easily
from jax.experimental import mesh_utils
# Core JAX sharding components: Mesh defines the device grid,
# PartitionSpec defines how tensor axes map to mesh axes,
# NamedSharding links PartitionSpec to a Mesh with named axes.
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax
import jax.numpy as jnp # JAX's accelerated NumPy
import numpy as np # Standard NumPy for data generation
# Import the Flax NNX API components
from flax import nnx
import typing as tp # For type hints
- Imports and environment setup: standard imports for the various pieces of functionality.
- os.environ['XLA_FLAGS'] = ...: tells JAX's underlying XLA compiler to simulate 8 devices, making multi-device development easy.
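As a quick sanity check, a minimal sketch (the exact device representation depends on your platform and JAX version) to confirm that JAX now reports eight devices:
print(jax.device_count())  # 8 simulated devices
print(jax.devices()[0])    # e.g. the first simulated (CPU-backed) device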
Defining the Parallelism Strategy: The Device Mesh
The core concept behind explicit sharding in JAX is the Mesh. It represents a logical grid of devices (real or simulated). We associate names with the axes of this grid to define our parallelism strategy. Here we create a 2x4 mesh (8 devices) whose axes are named 'data' and 'model'.
# Create a 2D mesh (grid) of devices with shape (2, 4), meaning 8 devices total.
# Assign logical names 'data' and 'model' to the axes of this grid.
# The first dimension (size 2) is named 'data'.
# The second dimension (size 4) is named 'model'.
mesh = jax.sharding.Mesh(
mesh_utils.create_device_mesh((2, 4)),
('data', 'model'),
)
- Mesh definition: mesh_utils.create_device_mesh((2, 4)) arranges the devices, and jax.sharding.Mesh(...) builds the logical grid with named axes. This lets us declare intents such as "split the data along the 'data' axis" and "split the parameters along the 'model' axis".
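To make the mesh concrete, a brief inspection sketch (the values in the comments are what we expect with the 2x4 mesh above):
print(mesh.axis_names)     # ('data', 'model')
print(mesh.devices.shape)  # (2, 4): one entry per simulated device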
Sharding Helpers: named_sharding and MeshRules
To make sharding specifications simpler and more organized to define, the example uses a helper function and a dataclass.
# A helper function to quickly create a NamedSharding object
# using the globally defined 'mesh'.
def named_sharding(*names: str | None) -> NamedSharding:
# P(*names) creates a PartitionSpec, e.g., P('data', None)
# NamedSharding binds this PartitionSpec to the 'mesh'.
return NamedSharding(mesh, P(*names))
- named_sharding helper: simplifies the creation of a NamedSharding object, which links a PartitionSpec (how tensor dimensions map to mesh axes; for example, P('data', None) shards dimension 0 along the 'data' axis and replicates dimension 1) to our specific mesh.
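For illustration, two hedged usage examples of the helper (the variable names here are ours, not part of the original script):
# Shard dim 0 along the 'data' axis and replicate dim 1; equivalent to NamedSharding(mesh, P('data', None)).
batch_spec = named_sharding('data', None)
# With no names, P() is produced, which means full replication across the mesh.
replicated_spec = named_sharding()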
# A dataclass to hold sharding rules for different parts of the model/data.
# Makes it easy to manage and change sharding strategies.
@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
embed: str | None = None # Sharding rule for embedding-like dimensions
mlp: str | None = None # Sharding rule for MLP layers dimensions
data: str | None = None # Sharding rule for the data batch dimension
# Allows calling the instance like `mesh_rules('embed', 'mlp')`
# to get a tuple of the corresponding sharding rules.
def __call__(self, *keys: str) -> tuple[str, ...]:
return tuple(getattr(self, key) for key in keys)
# Create an instance of MeshRules defining the specific strategy:
# - 'embed' dimensions will be replicated (None).
# - 'mlp' dimensions will be sharded along the 'model' mesh axis.
# - 'data' dimensions will be sharded along the 'data' mesh axis.
mesh_rules = MeshRules(
embed=None,
mlp='model',
data='data',
)
- MeshRules dataclass: provides a structured way to define and retrieve the sharding axis names needed for the logical parts of the model (embed, mlp) and for the data ('data', 'model', or None for replication).
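Calling the MeshRules instance simply translates logical dimension names into mesh axis names; these tuples are exactly what gets attached to the parameters below. A quick usage sketch:
print(mesh_rules('embed', 'mlp'))  # (None, 'model'): replicate dim 0, shard dim 1
print(mesh_rules('mlp'))           # ('model',)
print(mesh_rules('data'))          # ('data',)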
Building a Sharded Model with Flax NNX
Now let's define our MLP with the Flax NNX API. A key feature of NNX is its explicit state management. Note how the sharding intent is specified directly when the parameters are created with nnx.Param.
# Define the MLP using Flax NNX API.
class MLP(nnx.Module):
# Constructor takes input/hidden/output dimensions and an NNX Rngs object.
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
# Define the first weight matrix as an nnx.Param.
self.w1 = nnx.Param(
# Initialize with lecun_normal initializer using a key from rngs.
nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
# CRITICAL: Specify the desired sharding using MeshRules.
# ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
sharding=mesh_rules('embed', 'mlp'),
)
# Define the first bias vector as an nnx.Param.
self.b1 = nnx.Param(
jnp.zeros((dmid,)), # Initialize with zeros.
# Sharding: ('mlp',) -> ('model',) -> Shard dim 0 along 'model' axis.
sharding=mesh_rules('mlp'),
)
# Define the second weight matrix as an nnx.Param.
self.w2 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
# Sharding: ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
sharding=mesh_rules('embed', 'mlp'),
)
# Note: No second bias b2 is defined in this simple example.
# The forward pass of the MLP.
def __call__(self, x: jax.Array):
# Standard MLP calculation: (x @ W1 + b1) -> ReLU -> @ W2
# NNX automatically accesses the .value attribute of nnx.Param objects.
return nnx.relu(x @ self.w1 + self.b1) @ self.w2
- The MLP NNX module: nnx.Param defines a trainable parameter that wraps a JAX array. sharding=mesh_rules(...) is the key NNX integration point: we attach metadata such as (None, 'model') directly to each parameter, describing how it should be sharded across the mesh axes according to our MeshRules.
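At this point the sharding is only metadata attached to each parameter; nothing has been physically partitioned yet (that happens in create_model below). As a hedged inspection sketch, relying on the same get_metadata() method the optimizer uses later in this example:
probe = MLP(1, 32, 1, rngs=nnx.Rngs(0))  # throwaway instance, just for inspection
print(probe.w1.get_metadata())           # expected to include 'sharding': (None, 'model')
print(probe.b1.get_metadata())           # expected to include 'sharding': ('model',)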
Handling Sharded Optimizer State
Parameters are not the only state we need to manage; optimizer state (such as momentum) must also be sharded consistently with the parameters it belongs to. NNX's explicit variable system handles this elegantly.
# Define a custom type for SGD momentum state, inheriting from nnx.Variable.
# This allows it to be tracked as part of the NNX state tree.
class SGDState(nnx.Variable):
pass
# Define the SGD optimizer using NNX API.
class SGD(nnx.Object):
# Constructor takes the model parameters (as nnx.State), learning rate, and decay.
def __init__(self, params: nnx.State, lr, decay=0.9):
# Helper function to initialize momentum buffer for a given parameter.
def init_optimizer_state(variable: nnx.Variable):
# Create momentum state with zeros, same shape and metadata (incl. sharding)
# as the parameter it corresponds to.
return SGDState(
jnp.zeros_like(variable.value), **variable.get_metadata()
)
self.lr = lr
# Store a reference to the parameter State tree.
self.params = params
# Create the momentum state tree, mirroring the structure of 'params',
# using the helper function. Momentum will have the same sharding as params.
self.momentum = jax.tree.map(init_optimizer_state, self.params)
self.decay = decay
# Method to update parameters based on gradients.
def update(self, grads: nnx.State):
# Define the update logic for a single parameter/momentum/gradient triple.
def update_fn(
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
):
# Standard SGD with momentum update rule.
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value
# θ_{t+1} = θ_t - α * v_t
params.value -= self.lr * momentum.value # NOTE: Direct mutation of param value!
# Apply the update function across the parameter, momentum, and gradient trees.
# This performs the update in-place on the parameter values referenced by self.params.
jax.tree.map(update_fn, self.params, self.momentum, grads)
- The SGD NNX optimizer: SGDState(nnx.Variable) is a custom type for momentum, making it part of the NNX state system. init_optimizer_state creates momentum buffers of matching shape and, crucially, inherits the metadata (including the sharding tuple) from the corresponding parameter via **variable.get_metadata(). update applies the gradients in place to the sharded parameters (referenced via self.params) and momentum buffers.
Applying and Enforcing Sharding: The create_model Function
We have declared how to shard via metadata. Now we need to tell JAX to actually enforce this layout during computation. This happens in the create_model function, using jax.lax.with_sharding_constraint and nnx.update.
# JIT-compile the model and optimizer creation function.
@nnx.jit
def create_model():
# Instantiate the MLP model. rngs=nnx.Rngs(0) provides PRNG keys.
model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
# Create the optimizer. nnx.variables(model, nnx.Param) extracts
# only the nnx.Param state variables from the model object.
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
# === Explicit Sharding Application ===
# 1. Extract ALL state (model params + optimizer momentum) into a flat State pytree.
state = nnx.state(optimizer)
# 2. Define the target sharding for the state pytree.
# This function maps state paths to NamedSharding objects based on stored metadata.
def get_named_shardings(path: tuple, value: nnx.VariableState):
# Assumes params and momentum use the sharding defined in their metadata.
if path[0] in ('params', 'momentum'):
# value.sharding contains the tuple like ('model',) or (None, 'model')
# stored during Param/SGDState creation.
return value.replace(NamedSharding(mesh, P(*value.sharding)))
else:
# Handle other state if necessary (e.g., learning rate if it were a Variable)
raise ValueError(f'Unknown path: {path}')
# Create the pytree of NamedSharding objects.
named_shardings = state.map(get_named_shardings)
# 3. Apply sharding constraint. This tells JAX how the 'state' pytree
# SHOULD be sharded when computations involving it are run under jit/pjit.
# It doesn't immediately move data but sets up the constraint for the compiler.
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
# 4. Update the original objects (model params, optimizer momentum)
# with the constrained state values. This step makes the sharding
# "stick" to the objects themselves for subsequent use outside this function.
nnx.update(optimizer, sharded_state)
# Return the model and optimizer objects, now containing sharded state variables.
return model, optimizer
# Call the function to create the sharded model and optimizer.
model, optimizer = create_model()
# Visualize the sharding of the first weight's parameter tensor.
jax.debug.visualize_array_sharding(model.w1.value)
# Visualize the sharding of the first weight's momentum tensor.
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
- The create_model function: nnx.state(optimizer) gathers all nnx.Variable instances (parameters and momentum) into a single nnx.State pytree. state.map(get_named_shardings) builds a parallel pytree of NamedSharding specifications by reading the sharding metadata attached to each variable. jax.lax.with_sharding_constraint is the core JAX primitive: it tells the JIT compiler that the state pytree must conform to the named_shardings layout. nnx.update(optimizer, sharded_state) pushes the constrained state back into the original model and optimizer objects, so the sharding sticks for all subsequent use.
- Visualization: jax.debug.visualize_array_sharding confirms that the parameter and momentum tensors are distributed across the device mesh as expected.
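If the NNX plumbing obscures what with_sharding_constraint itself does, here is a minimal standalone sketch on a plain array (not part of the original example), reusing the mesh and named_sharding helper defined above:
@jax.jit
def constrain(x):
  # Ask the compiler to lay x out as P(None, 'model'): replicate rows, shard columns.
  return jax.lax.with_sharding_constraint(x, named_sharding(None, 'model'))

x_demo = constrain(jnp.ones((8, 32)))
jax.debug.visualize_array_sharding(x_demo)  # columns split across the 4 'model' devices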
The Distributed Training Step
With the model and optimizer state correctly sharded, defining the JIT-compiled training step is straightforward. JAX automatically handles the necessary communication (such as gradient aggregation) based on the sharding constraints we established.
# JIT-compile the training step function.
@nnx.jit
def train_step(model: MLP, optimizer: SGD, x, y):
# Define the loss function (Mean Squared Error).
# Takes the model object as input, consistent with nnx.value_and_grad.
def loss_fn(model):
y_pred = model(x) # Forward pass
loss = jnp.mean((y - y_pred) ** 2)
return loss
# Calculate loss and gradients w.r.t the model's state (its nnx.Param variables).
# 'grad' will be an nnx.State object mirroring model's Param structure.
loss, grad = nnx.value_and_grad(loss_fn)(model)
# Call the optimizer's update method to apply gradients.
# This updates the model parameters in-place.
optimizer.update(grad)
# Return the calculated loss.
return loss
- The train_step function: @nnx.jit compiles the function, and JAX infers the distributed execution plan. nnx.value_and_grad computes the loss and the gradients with respect to the sharded variables in model. optimizer.update(grad) applies the (implicitly sharded) gradients to the sharded state.
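The same pattern extends naturally to evaluation. A minimal sketch of an eval step (our addition for illustration, not part of the original example):
@nnx.jit
def eval_step(model: MLP, x, y):
  # Forward pass only: no gradients and no in-place parameter updates.
  y_pred = model(x)
  return jnp.mean((y - y_pred) ** 2)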
Data Loading and the Training Loop
Finally, the training loop generates data and feeds it into train_step. The key piece here is sharding each input batch along the 'data' axis with jax.device_put before every step.
# Generate synthetic dataset: y = 0.8*x^2 + 0.1 + noise
X = np.linspace(-2, 2, 100)[:, None] # Input features
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) # Target values
# A generator function to yield batches of data for training.
def dataset(batch_size, num_steps):
for _ in range(num_steps):
# Randomly sample indices for the batch.
idx = np.random.choice(len(X), size=batch_size)
# Yield the corresponding input and target pairs.
yield X[idx], Y[idx]
# --- Training Loop ---
losses = [] # To store loss values for plotting
# Iterate through the dataset generator for 10,000 steps.
for step, (x_batch, y_batch) in enumerate(
dataset(batch_size=32, num_steps=10_000)
):
# CRITICAL: Place the NumPy data onto JAX devices AND apply sharding.
# named_sharding('data') -> Shard along the 'data' mesh axis (first dim, size 2).
# Each device along the 'data' axis gets a slice of the batch.
x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))
# Execute the JIT-compiled training step with the sharded model, optimizer, and data.
loss = train_step(model, optimizer, x_batch, y_batch)
# Record the loss (move scalar loss back to host CPU).
losses.append(float(loss))
# Log progress periodically.
if step % 1000 == 0:
print(f'Step {step}: Loss = {loss}')
# --- Plotting Results ---
plt.figure()
plt.title("Training Loss")
plt.plot(losses[20:]) # Plot loss, skipping initial noisy steps
plt.xlabel("Step")
plt.ylabel("MSE Loss")
# Get model predictions on the full dataset (X is a host NumPy array).
# The forward pass runs on device; the result is transferred back implicitly when plotted.
y_pred = model(X)
plt.figure()
plt.title("Model Fit")
plt.scatter(X, Y, color='blue', label='Data') # Original data
plt.plot(X, y_pred, color='black', label='Prediction') # Model's predictions
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.show() # Display the plots
- Data generation and dataset: standard synthetic-data setup.
- Training loop: jax.device_put(..., named_sharding('data')) implements data parallelism; it sends each NumPy batch to the devices and splits it along the 'data' mesh axis (see the sketch after this list). loss = train_step(...) then executes the distributed training step.
- Plotting: visualizes the training progress and the final model fit.
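As an optional check (a sketch, not in the original script), you can visualize how a batch is split along the 'data' axis:
x_probe, _ = next(dataset(batch_size=32, num_steps=1))
x_probe = jax.device_put(x_probe, named_sharding('data'))
jax.debug.visualize_array_sharding(x_probe)  # rows split across the 2 'data' devices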
Summary
This example demonstrates a powerful pattern for scalable machine learning training:
- Define a Mesh: map logical names ('data', 'model') to device axes.
- Use Flax NNX: define parameters (nnx.Param) and other state (nnx.Variable) explicitly.
- Attach sharding metadata: specify the desired sharding tuples directly when creating NNX variables.
- Enforce sharding: use nnx.state, jax.lax.with_sharding_constraint, and nnx.update within a JIT context to apply the constraints.
- Shard the data: distribute input batches with jax.device_put and NamedSharding.
- JIT the training step: let JAX compile the distributed execution plan.
Combining JAX's explicit sharding control with Flax NNX's explicit state management provides a clear and flexible way to implement sophisticated parallelism strategies like FSDP, enabling the training of larger, more capable models.
We hope this detailed walkthrough of the official Flax example helps clarify how these powerful tools work together!