如何利用对抗性数据动态训练模型

发布于 2022 年 7 月 16 日
在 GitHub 上更新
您将在这里学到什么
  • 💡动态对抗性数据收集的基本概念及其重要性。
  • ⚒ 如何动态收集对抗性数据并在其上训练您的模型——以 MNIST 手写数字识别任务为例。

动态对抗性数据收集 (DADC)

静态基准虽然是评估模型性能的广泛使用方法,但存在许多问题:它们饱和、存在偏差或漏洞,并且经常导致研究人员追求指标的增加,而不是构建可供人类使用的值得信赖的模型1

动态对抗性数据收集(DADC)作为一种缓解静态基准部分问题的方法,前景广阔。在 DADC 中,人类创建示例来**欺骗**最先进(SOTA)模型。这个过程提供两个好处:

  1. 它允许用户评估其模型的真实鲁棒性;
  2. 它产生的数据可以用于进一步训练更强大的模型。

这种在对抗性收集的数据上欺骗和训练模型的过程会重复多轮,从而产生一个更鲁棒且与人类对齐的模型1

利用对抗性数据动态训练您的模型

在这里,我将向您展示如何动态地从用户那里收集对抗性数据并根据它们训练您的模型——以 MNIST 手写数字识别任务为例。

在 MNIST 手写数字识别任务中,模型经过训练,可以根据手写数字的 28x28 灰度图像输入(参见下图中的示例)预测数字。数字范围从 0 到 9。

图片来源:mnist | Tensorflow Datasets

这项任务被广泛认为是计算机视觉的“hello world”,并且很容易训练模型在标准(和静态)基准测试集上实现高精度。然而,已经表明这些 SOTA 模型在人类书写数字(并将其作为输入提供给模型)时仍然难以预测正确的数字:研究人员认为这很大程度上是因为静态测试集不能充分代表人类书写方式的非常多样性。因此,需要人类参与循环,为模型提供**对抗性**样本,这将有助于它们更好地泛化。

本教程将分为以下几个部分:

  1. 配置您的模型
  2. 与您的模型交互
  3. 标记您的模型
  4. 整合所有部分

配置您的模型

首先,您需要定义您的模型架构。我简单的模型架构由两个卷积网络组成,它们连接到一个 50 维的全连接层和一个用于 10 个类别的最终层。最后,我们使用 softmax 激活函数将模型的输出转换为类别上的概率分布。

# Adapted from: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

现在您已经定义了模型的结构,您需要将其在标准 MNIST 训练/开发数据集上进行训练。

与您的模型交互

至此,我们假设您已经训练好了模型。尽管该模型已经训练,但我们旨在通过人机协作对抗性数据使其变得健壮。为此,您需要一种用户与模型交互的方式:具体来说,您希望用户能够在画布上书写/绘制数字 0-9,并让模型尝试对其进行分类。您可以使用 🤗 Spaces 完成所有这些操作,它允许您快速轻松地为您的机器学习模型构建演示。在此处了解有关 Spaces 以及如何构建它们的更多信息here

下面是一个简单的 Space,用于与我训练了 20 个 epoch 的 `MNIST_Model` 交互(在测试集上获得了 89% 的准确率)。您在白色画布上绘制一个数字,模型会从您的图像中预测该数字。完整的 Space 可以在此处访问。尝试欺骗这个模型😁。使用您最有趣的笔迹;在画布的侧面书写;尽情发挥吧!

标记您的模型

您能骗过上面的模型吗?😀 如果能,那么是时候**标记**您的对抗性示例了。标记包括:

  1. 将对抗性示例保存到数据集
  2. 在收集到一定数量的样本后,对对抗性示例进行模型训练。
  3. 重复步骤 1-2 若干次。

我编写了一个自定义的 `flag` 函数来完成所有这些操作。有关更多详细信息,请随时在此处查阅完整代码。

注意:Gradio 有一个内置的标记回调,可以让您轻松标记模型的对抗性样本。在此处阅读更多相关信息:here

将所有内容整合

最后一步是将所有三个组件(配置模型、与模型交互和标记模型)整合到一个演示空间中!为此,我创建了 MNIST 对抗性空间,用于 MNIST 手写识别任务的动态对抗性数据收集。请随意在下面进行测试。

结论

动态对抗性数据收集 (DADC) 在机器学习社区中作为一种收集多样化、非饱和、与人类对齐的数据集、改进模型评估和任务性能的方式,正获得越来越多的关注。通过动态收集带有模型循环的人工生成的对抗性数据,我们可以提高模型的泛化潜力。

这种在对抗性收集的数据上欺骗和训练模型的过程应该重复多轮1Eric Wallace 等人在自然语言推理任务的实验中表明,尽管短期内标准非对抗性数据收集表现更好,但从长远来看,动态对抗性数据收集带来了显著更高的准确率。

使用 🤗 Spaces,构建一个平台来动态收集模型的对抗性数据并进行训练变得相对容易。

社区

注册登录 评论