Hub 文档

在 Hugging Face 使用 mlx-image

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

在 Hugging Face 上使用 mlx-image

mlx-image 是由 Riccardo Musmeci 开发的图像模型库,基于 Apple MLX 构建。它试图复制出色的 timm,但适用于 MLX 模型。

在 Hub 上探索 mlx-image

您可以使用 mlx-image 库名称进行过滤来查找 mlx-image 模型,例如 此查询。还有一个开放的 mlx-vision 社区,供贡献者转换和发布 MLX 格式的权重。

安装

pip install mlx-image

模型

模型权重可在 HuggingFace 上的 mlx-vision 社区中找到。

要加载具有预训练权重的模型

from mlxim.model import create_model

# loading weights from HuggingFace (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18") # pretrained weights loaded from HF

# loading weights from local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")

列出所有可用模型

from mlxim.model import list_models
list_models()

截至今天 (2024-03-15),mlx 不支持 nn.Conv2dgroup 参数。因此,诸如 resnextregnetefficientnet 等架构在 mlx-image 中尚不受支持。

ImageNet-1K 结果

请访问 results-imagenet-1k.csv 以检查每个转换为 mlx-image 的模型及其在不同设置下在 ImageNet-1K 上的性能。

简而言之,性能与 PyTorch 实现中的原始模型相当。

与 PyTorch 及其他常用工具的相似性

mlx-image 试图尽可能接近 PyTorch

  • DataLoader -> 你可以定义自己的 collate_fn,也可以使用 num_workers 加速数据加载

  • Dataset -> mlx-image 已经支持 LabelFolderDataset(传统的 PyTorch ImageFolder)和 FolderDataset(包含图像的通用文件夹)

  • ModelCheckpoint -> 跟踪最佳模型并将其保存到磁盘(类似于 PyTorchLightning)。它还建议提前停止

训练

训练类似于 PyTorch。以下是如何训练模型的示例

import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18") # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, target))
    return loss

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        train_step_fn = nn.value_and_grad(model, train_step)
        loss, grads = train_step_fn(x, target)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)

附加资源

联系方式

如有任何问题,请发送邮件至 [email protected]

< > 在 GitHub 上更新