使用 Transformers.js 制作 ML 驱动的网页游戏

发布于 2023 年 7 月 5 日
在 GitHub 上更新

在这篇博文中,我将向您展示我是如何制作涂鸦冲刺的,这是一个完全在浏览器中运行的实时 ML 驱动网页游戏(感谢 Transformers.js)。本教程的目标是向您展示制作自己的 ML 驱动网页游戏是多么容易……正好赶上即将到来的开源 AI 游戏大会(2023 年 7 月 7 日至 9 日)。如果您还没有参加,请加入游戏大会!

快速链接

概述

在开始之前,让我们谈谈我们将要创建什么。这款游戏灵感来源于谷歌的 Quick, Draw! 游戏,在该游戏中,您会被给出一个单词,神经网络有 20 秒的时间来猜测您正在画什么(重复 6 次)。事实上,我们将使用他们的训练数据来训练我们自己的草图检测模型!您难道不喜欢开源吗? 😍

在我们的版本中,您将有一分钟的时间尽可能多地绘制物品,一次一个提示。如果模型预测出正确的标签,画布将被清除,您将获得一个新单词。一直这样做,直到计时器用完!由于游戏在您的浏览器中本地运行,我们根本不必担心服务器延迟。该模型能够在您绘制时进行实时预测,每秒超过 60 次预测…… 🤯 太棒了!

本教程分为 3 个部分

  1. 训练神经网络
  2. 使用 Transformers.js 在浏览器中运行
  3. 游戏设计

1. 训练神经网络

训练数据

我们将使用谷歌 Quick, Draw! 数据集的子集来训练我们的模型,该数据集包含 345 个类别的 500 多万张图画。以下是数据集中的一些示例:

Quick, Draw! dataset

模型架构

我们将微调 `apple/mobilevit-small`,这是一个轻量级且适合移动设备的视觉 Transformer,已在 ImageNet-1k 上进行预训练。它只有 5.6M 个参数(文件大小约 20MB),是浏览器内运行的理想选择!有关更多信息,请查看 MobileViT 论文和下面的模型架构。

MobileViT archtecture

微调

Open In Colab

为了使博客文章(相对)简短,我们准备了一个 Colab 笔记本,其中将向您展示我们对 `apple/mobilevit-small` 进行微调的精确步骤。总的来说,这包括

  1. 加载“Quick, Draw!”数据集

  2. 使用`MobileViTImageProcessor`转换数据集。

  3. 定义我们的排序函数评估指标

  4. 使用`MobileViTForImageClassification.from_pretrained`加载预训练的 MobileVIT 模型

  5. 使用`Trainer``TrainingArguments`辅助类训练模型。

  6. 使用🤗 Evaluate评估模型。

注:您可以在 Hugging Face Hub 上此处找到我们微调过的模型。

2. 使用 Transformers.js 在浏览器中运行

什么是 Transformers.js?

Transformers.js 是一个 JavaScript 库,允许您直接在浏览器中运行 🤗 Transformers(无需服务器)!它的设计旨在与 Python 库功能等效,这意味着您可以使用非常相似的 API 运行相同的预训练模型。

在幕后,Transformers.js 使用 ONNX Runtime,因此我们需要将我们微调过的 PyTorch 模型转换为 ONNX。

将我们的模型转换为 ONNX

幸运的是,🤗 Optimum 库使得将您的微调模型转换为 ONNX 变得超级简单!最简单(也是推荐的)方法是

  1. 克隆 Transformers.js 仓库并安装必要的依赖项

    git clone https://github.com/xenova/transformers.js.git
    cd transformers.js
    pip install -r scripts/requirements.txt
    
  2. 运行转换脚本(它在底层使用 Optimum

    python -m scripts.convert --model_id <model_id>
    

    其中 <model_id> 是您要转换的模型名称(例如 Xenova/quickdraw-mobilevit-small)。

设置我们的项目

让我们首先使用 Vite 搭建一个简单的 React 应用程序

npm create vite@latest doodle-dash -- --template react

接下来,进入项目目录并安装必要的依赖项

cd doodle-dash
npm install
npm install @xenova/transformers

然后可以通过运行以下命令启动开发服务器

npm run dev

在浏览器中运行模型

运行机器学习模型计算密集,因此在单独的线程中执行推理非常重要。这样我们就不会阻塞主线程,主线程用于渲染 UI 和响应您的绘图手势😉。 Web Workers API 使这变得超级简单!

在 `src` 目录中创建一个新文件(例如,`worker.js`),并添加以下代码

import { pipeline, RawImage } from "@xenova/transformers";

const classifier = await pipeline("image-classification", 'Xenova/quickdraw-mobilevit-small', { quantized: false });

const image = await RawImage.read('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png');

const output = await classifier(image.grayscale());
console.log(output);

我们现在可以通过将以下代码添加到 `App` 组件中,在 `App.jsx` 文件中使用此 worker

import { useState, useEffect, useRef } from 'react'
// ... rest of the imports

function App() {
    // Create a reference to the worker object.
    const worker = useRef(null);

    // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
    useEffect(() => {
        if (!worker.current) {
            // Create the worker if it does not yet exist.
            worker.current = new Worker(new URL('./worker.js', import.meta.url), {
                type: 'module'
            });
        }

        // Create a callback function for messages from the worker thread.
        const onMessageReceived = (e) => { /* See code */ };

        // Attach the callback function as an event listener.
        worker.current.addEventListener('message', onMessageReceived);

        // Define a cleanup function for when the component is unmounted.
        return () => worker.current.removeEventListener('message', onMessageReceived);
    });

    // ... rest of the component
}

您可以通过运行开发服务器(使用 `npm run dev`),访问本地网站(通常是 https://:5173/),并打开浏览器控制台来测试一切是否正常工作。您应该会看到模型输出被记录到控制台。

[{ label: "skateboard", score: 0.9980043172836304 }]

太棒了!🥳 尽管上述代码只是最终产品的一小部分,但它展示了机器学习方面的简单性!其余的只是使其看起来美观并添加一些游戏逻辑。

3. 游戏设计

在本节中,我将简要讨论游戏设计过程。提醒一下,您可以在 GitHub 上找到该项目的完整源代码,因此我不会详细介绍代码本身。

利用实时性能

在浏览器中进行推理的主要优点之一是我们可以实时进行预测(每秒超过 60 次)。在最初的 Quick, Draw! 游戏中,模型每隔几秒才进行一次新预测。我们可以在我们的游戏中做同样的事情,但那样我们就无法利用其实时性能了!因此,我决定重新设计主游戏循环

  • 我们的版本不再是六个 20 秒的回合(每个回合对应一个新单词),而是让玩家在 60 秒内正确绘制尽可能多的涂鸦(一次一个提示)。
  • 如果您遇到无法绘制的单词,您可以跳过它(但这将花费您剩余时间的 3 秒)。
  • 在原版游戏中,由于模型每隔几秒钟就会进行一次猜测,它会慢慢地从列表中划掉标签,直到最终猜对。在我们的版本中,我们反而会降低模型对前 n 个错误标签的分数,其中 n 会随着用户持续绘图而逐渐增加。

生活质量改进

原始数据集包含 345 个不同的类别,由于我们的模型相对较小(约 20MB),它有时无法正确猜测某些类别。为了解决这个问题,我们删除了一些单词,这些单词要么是

  • 与其他标签过于相似(例如,“谷仓”与“房子”)
  • 太难理解(例如,“动物迁徙”)
  • 太难画出足够的细节(例如,“大脑”)
  • 模棱两可(例如,“蝙蝠”)

筛选后,我们仍然剩下 300 多个不同的类别!

额外内容:构思名称

本着开源的精神,我决定向 Hugging Chat 寻求一些游戏名称的想法……毋庸置疑,它没有让我失望!

Game name suggestions by Hugging Chat

我喜欢“涂鸦冲刺”(建议 #4)的头韵,所以我决定选择这个。感谢 Hugging Chat!🤗


希望您喜欢和我一起制作这款游戏!如果您有任何问题或建议,可以在 TwitterGitHub🤗 Hub 上找到我。此外,如果您想改进游戏(游戏模式?道具?动画?音效?),请随意派生项目并提交拉取请求!我很乐意看到您的作品!

附言:别忘了参加开源 AI 游戏节!希望这篇博文能激发您使用 Transformers.js 制作自己的网页游戏!😉 游戏节见!🚀

社区

有趣的游戏

太棒了

注册登录发表评论