使用 AutoTrain 训练目标检测模型
目标检测是计算机视觉中的一项基本任务,它使模型能够识别和分类图像中的物体。AutoTrain 通过让用户轻松训练最先进的目标检测模型,简化了这一过程。在这篇博文中,我们将引导您完成准备数据、配置训练参数以及使用命令行界面 (CLI) 和用户界面 (UI) 在本地和 Hugging Face 云上高效训练目标检测模型的步骤。
准备数据
在训练模型之前,您需要整理图像并创建一个元数据文件。请遵循以下指南
为 UI 准备数据
创建 Zip 压缩包: 将您的图像和
metadata.jsonl
文件收集到一个 zip 文件中。您的文件结构应如下所示Archive.zip ├── 0001.png ├── 0002.png ├── 0003.png ├── ... └── metadata.jsonl
准备元数据:
metadata.jsonl
文件包含每张图像的信息,包括物体的边界框和类别。这是一个例子{"file_name": "0001.png", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "category": [0]}} {"file_name": "0002.png", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "category": [1]}} {"file_name": "0003.png", "objects": {"bbox": [[160.0, 31.0, 248.0, 616.0], [741.0, 68.0, 202.0, 401.0]], "category": [2, 2]}}
确保边界框采用 COCO 格式
[x, y, width, height]
。
为 CLI 准备数据
或者,如果您不使用 UI,也可以将数据整理到文件夹中
创建训练和验证文件夹: 将您的图像和
metadata.jsonl
文件分别整理到用于训练和验证的文件夹中。training/ ├── 0001.png ├── 0002.png ├── 0003.png ├── ... └── metadata.jsonl validation/ ├── 0004.png ├── 0005.png ├── ... └── metadata.jsonl
准备元数据: 与 UI 方法类似,
metadata.jsonl
文件应包含边界框和类别信息。
图片要求
- 格式: 所有图片必须为 JPEG、JPG 或 PNG 格式。
- 数量: 至少包含 5 张图片,以提供足够的学习样本。
- 专有性: zip 文件应仅包含图片和
metadata.jsonl
文件。不应包含其他文件或嵌套文件夹。
当 train.zip
解压时,它不应创建任何文件夹,只应包含图片和 metadata.jsonl
文件。
注意:您也可以使用来自 Hugging Face Hub 的数据集。这将在本博文中进一步讨论。
配置训练参数
AutoTrain 提供多种参数来自定义您的训练过程。以下是您可以配置的关键参数
基本参数
- --image-square-size: 将输入图像调整为指定大小的正方形 (默认为 600)。
- --batch-size: 设置训练批次大小。
- --seed: 用于可复现性的随机种子。
- --epochs: 训练轮数。
- --gradient_accumulation: 梯度累积步数。
- --disable_gradient_checkpointing: 禁用梯度检查点。
- --lr: 学习率。
- --log: 实验跟踪选项 (
none
、wandb
、tensorboard
)。
高级参数
- --image-column: 指定要使用的图像列。
- --target-column: 指定要使用的目标列。
- --warmup-ratio: 用于线性预热的训练比例 (默认为 0.1)。
- --optimizer: 选择优化器算法 (默认为
adamw_torch
)。 - --scheduler: 选择学习率调度器 (默认为
linear
,cosine
是另一个选项)。 - --weight-decay: 设置权重衰减率 (默认为 0.0)。
- --max-grad-norm: 用于梯度裁剪的梯度最大范数 (默认为 1.0)。
- --logging-steps: 确定记录训练进度的频率 (默认为 -1,表示自动确定)。
- --evaluation-strategy: 指定评估频率 (
no
、steps
、epoch
)。 - --save-total-limit: 限制要保存的模型检查点数量。
- --auto-find-batch-size: 根据硬件能力自动确定批次大小。
- --mixed-precision: 选择精度模式 (
fp16
、bf16
或 None)。
使用 CLI 进行训练
要使用 CLI 训练目标检测模型,您可以创建一个配置文件并运行 autotrain
命令。以下是在 Hugging Face Hub 上的 CPPE-5 数据集上进行训练的示例配置文件。
示例配置文件
task: object_detection
base_model: facebook/detr-resnet-50
project_name: autotrain-obj-det-cppe5-2
log: tensorboard
backend: local
data:
path: cppe-5
train_split: train
valid_split: test
column_mapping:
image_column: image
objects_column: objects
params:
image_square_size: 600
epochs: 100
batch_size: 8
lr: 5e-5
weight_decay: 1e-4
optimizer: adamw_torch
scheduler: linear
gradient_accumulation: 1
mixed_precision: fp16
early_stopping_patience: 50
early_stopping_threshold: 0.001
hub:
username: ${HF_USERNAME}
token: ${HF_TOKEN}
push_to_hub: true
运行训练
要开始训练,请使用以下命令
$ export HF_USERNAME=your_hugging_face_username
$ export HF_TOKEN=your_hugging_face_write_token
$ autotrain --config configfile.yml
此命令将使用 configfile.yml
中指定的配置来训练您的目标检测模型。
注意:仅当您将 push_to_hub
设置为 true
时,才需要导出您的用户名和令牌。
在某些 Hugging Face 数据集的情况下,数据集可能包含配置,此时您可以对 train_split 和 valid_split 使用 dataset_config:split_name
。例如,这个数据集有配置:full
和 mini
为此,配置文件的更改将是
data:
path: keremberke/license-plate-object-detection
train_split: full:train
valid_split: full:validation
column_mapping:
image_column: image
objects_column: objects
如果您的数据集存储在本地,则需要在配置 yaml 中更新以下内容
data:
path: /path/to/data/folder/
train_split: train # this folder contains images and metadata.jsonl
valid_split: val # this folder contains images and metadata.jsonl, optional, can be set to null
column_mapping:
image_column: image
objects_column: objects
使用 UI 进行训练
在本地,您可以通过运行以下命令启动 AutoTrain UI
$ pip install -U autotrain-advanced
$ autotrain app --host 127.0.0.1 --port 8000
应用程序将在 http://127.0.0.1:8000 启动
在 UI 中上传的数据格式与上面描述的 zip 文件相同。
如果您没有合适的硬件,也可以通过点击这里在 Hugging Face Spaces 上启动 UI。更多信息请阅读文档。
当使用来自 Hub 的数据集时,您必须正确映射列。使用本地数据集(文件夹或 zip)时,列映射应保持原样。
总结
AutoTrain 简化了训练目标检测模型这一复杂任务,使您能够专注于微调模型以获得最佳性能。通过遵循这些指南并利用可用的参数,您可以创建满足您特定需求的有效目标检测模型。无论使用 UI 还是 CLI,AutoTrain 都为构建强大的目标检测模型提供了简化的流程。
附言:所有使用 AutoTrain 训练的模型都可以通过 API 推理和推理端点进行部署。
如有任何问题或功能请求,请查看 GitHub 仓库。
训练愉快! :)