使用图像回归进行销售预测

社区文章 发布于 2024年5月24日

image/png

概述

在本文中,我们将训练自己的图像回归模型,然后用它根据产品图像预测销售额。图像回归是一种机器学习技术,它从图像中预测一个连续的数值。我们将使用我的简单 图像回归模型训练器 工具进行训练、上传和推理。

这个项目的主要动机之一是,在撰写本文时,我在 🤗 上找不到任何关于图像回归的资源。图像回归模型训练器 是基于 🤗 Transformers 和 PyTorch 构建的,旨在集成到 🤗 生态系统中。

数据集

模型训练器将 🤗 数据集 ID 作为输入,因此您的数据集必须上传到 🤗。它应该包含一列图像和一列值(浮点数或整数)。如果您需要帮助创建 🤗 图像数据集,请查看 🤗 创建图像数据集。您需要将图像格式化到一个文件夹中,并附带一个 metadata.csv 文件,如下所示:

folder/metadata.csv
folder/0001.png
folder/0002.png
folder/0003.png

您的 metadata.csv 文件将如下所示:

file_name,sales
0001.png,1000
0002.png,20000
0003.png,100000

上传到 🤗 Hub

from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="/path/to/folder")
dataset.push_to_hub("tonyassi/clothing-sales-ds")

您的数据集应该类似于 tonyassi/clothing-sales-ds(值列的名称可以随意命名)。

image/png

使用 PyTorch 和 🤗 Transformers 进行图像回归

我们的图像回归模型将是 Google 的 Vision Transformer (ViT) 的微调版本。Google ViT 通过将 224x224 像素的图像分成 16x16 像素的补丁来处理图像分类等任务。您可以在这篇论文中了解更多信息。

image/png

我们需要自定义模型以输出**连续数值**而不是**图像分类标签**。让我们深入了解 图像回归模型训练器 的内部,看看模型是如何定义的。

class ViTRegressionModel(nn.Module):
    def __init__(self):
        super(ViTRegressionModel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

让我们来分解一下。这行代码从 Hugging Face 模型中心加载了一个预训练的 Vision Transformer 模型(ViT)google/vit-base-patch16-224

self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')

此行定义了一个线性层(全连接层),该层将 ViT 模型的隐藏层大小作为输入,并输出单个值。此层将用于输出预测的回归值。

self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

图像回归模型训练器 的另一个重要组成部分是 🤗 Transformers 训练器。训练器是一个完整的 PyTorch 模型训练和评估循环,因此您只需向其传递训练所需的必要部分(模型、数据集、训练超参数等),训练器类就会处理其余部分。

以下是训练参数的样子

training_args = TrainingArguments(
        output_dir='./results',
        evaluation_strategy="epoch",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=10,
        learning_rate=1e-4,
        save_steps=10,
        save_total_limit=2,
        logging_steps=10,
        remove_unused_columns=False,
        resume_from_checkpoint=True,
)

我们只需要将模型、参数和数据集传递给训练器即可

model = ViTRegressionModel()
trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        data_collator=collate_fn,
)
trainer.train()

图像回归模型训练器 将这些细节抽象化,因此我们无需深入研究 PyTorch/Transformers 代码。

下载

从 GitHub 下载 图像回归模型训练器

git clone https://github.com/TonyAssi/ImageRegression.git
cd ImageRegression

安装

安装所需的库

pip install -r requirements.txt

训练

最后,让我们训练我们的模型!图像回归模型训练器 使其变得非常简单。如果您使用自己的数据集,请确保其已正确上传到 🤗 Hub。value_column_name 变量将是您值列的名称。请随意尝试 test_splitnum_train_epochslearning_rate(以下值是很好的起点)。

  • dataset_id 🤗 数据集 ID
  • value_column_name 数据集中预测值的列名
  • test_split 训练/测试分割的测试比例
  • output_dir 保存检查点的目录
  • num_train_epochs 训练轮次
  • learning_rate 学习率
train_model(dataset_id='tonyassi/clothing-sales-ds',
            value_column_name='sales',
            test_split=0.2,
            output_dir='./results',
            num_train_epochs=10,
            learning_rate=1e-4)

训练器会将检查点保存在 output_dir 位置。model.safetensors 是您将用于推理(预测)的训练权重。

当模型进行训练时,您应该会看到一些信息被打印出来。均方误差 (MSE) 是回归任务中常用的损失函数,用于衡量预测值与实际值之间的差异。您应该会看到该值在每个 epoch 后都在下降。

上传模型

建议将您的模型上传到 🤗 Hub,因为这将大大简化推理过程,并且会自动生成模型卡。您需要为模型选择一个唯一的名称 model_id,生成一个 token,并定义检查点文件夹。转到训练的 output_dir 位置,您应该会看到检查点文件夹——选择最新的检查点。

  • model_id 模型 ID 的名称
  • token 前往此处创建新的 🤗 token
  • checkpoint_dir 将被上传的检查点文件夹
upload_model(model_id='sales-prediction',
             token='YOUR_HF_TOKEN',
             checkpoint_dir='./results/checkpoint-940')

前往您的 🤗 个人资料,您会找到已上传的模型,它应该类似于 tonyassi/sales-prediction。上传功能会自动为您生成模型卡,其中列出了数据集信息、训练参数和使用说明。

推理

现在我们可以使用自定义训练的模型,根据图像预测浮点值。在我们的示例中,数据集是产品图片和销售额,因此我们可以使用此模型来预测新产品的销售额。您将需要上一步中的模型 repo ID 和图片路径来预测其值。

  • repo_id 模型的 🤗 仓库 ID
  • image_path 图片路径
predict(repo_id='tonyassi/sales-prediction',
        image_path='image.jpg')

首次调用此函数时,它将下载 safetensor 模型。后续函数调用将运行得更快。

其他应用

这种图像回归方法除了销售预测外,还可以用于许多不同的应用。

  • 根据图像预测一个人的年龄
  • 根据美学对图像进行评分
  • 预测医学图像中肿瘤的大小,例如核磁共振或CT扫描
  • 根据航空或卫星图像估算作物产量
  • 通过分析天空或水体的图像评估空气或水质
  • 从道路图像预测交通密度或车辆数量
  • 任何需要根据图像预测数字的机器学习问题

关于我

大家好,我叫 Tony Assi。我是洛杉矶的一名设计师。我拥有软件、时尚和营销背景。目前我在一家电子商务时尚品牌工作。查看我的 🤗 个人资料,了解更多应用、模型和数据集。

如有任何问题、意见、业务咨询或工作机会,请随时发送电子邮件至 tony.assi.media@gmail.com

社区

注册登录以发表评论