欢迎 Stable-baselines3 加入 Hugging Face Hub 🤗

发布于 2022 年 1 月 21 日
在 GitHub 上更新

在 Hugging Face,我们正致力于为深度强化学习研究人员和爱好者们打造一个良好的生态系统。因此,我们很高兴地宣布,我们将 Stable-Baselines3 集成到了 Hugging Face Hub。

Stable-Baselines3 是最受欢迎的 PyTorch 深度强化学习库之一,它能让你在各种环境(Gym、Atari、MuJoco、Procgen 等)中轻松训练和测试你的智能体。通过这次集成,你现在可以托管你保存的模型 💾,并从社区中加载强大的模型。

在本文中,我们将向你展示如何操作。

安装

要将 stable-baselines3 与 Hugging Face Hub 一起使用,你只需安装这两个库即可

pip install huggingface_hub
pip install huggingface_sb3

寻找模型

我们目前正在上传玩《太空侵略者 (Space Invaders)》、《打砖块 (Breakout)》、《月球着陆器 (LunarLander)》等游戏的智能体模型。除此之外,你可以在这里找到社区中所有 stable-baselines-3 模型

当你找到所需的模型时,只需复制仓库 ID 即可。

Image showing how to copy a repository id

从 Hub 下载模型

本次集成最酷的功能是,你现在可以非常轻松地将 Hub 上保存的模型加载到 Stable-baselines3 中。

为此,你只需复制包含已保存模型的仓库 ID (repo-id) 以及仓库中已保存模型的 zip 文件名。

例如:sb3/demo-hf-CartPole-v1

import gym

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository including the extension .zip
checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)

# Evaluate the agent and watch it
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
    model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

将模型分享到 Hub

只需一分钟,你就可以将保存的模型上传到 Hub。

首先,你需要登录 Hugging Face 才能上传模型。

  • 如果你正在使用 Colab/Jupyter Notebooks
from huggingface_hub import notebook_login
notebook_login()
  • 否则
huggingface-cli login

然后,在这个例子中,我们训练一个 PPO 智能体来玩 CartPole-v1,并将其推送到一个新的仓库 `ThomasSimonini/demo-hf-CartPole-v1`。

from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO

# Define a PPO model with MLP policy network
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)

# Train it for 10000 timesteps
model.learn(total_timesteps=10_000)

# Save the model
model.save("ppo-CartPole-v1")

# Push this saved model to the hf repo
# If this repo does not exists it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
    repo_id="ThomasSimonini/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    commit_message="Added Cartpole-v1 model trained with PPO",
)

快来试试并与社区分享你的模型吧!

下一步?

在接下来的几周和几个月里,我们将通过以下方式扩展生态系统:

  • 集成 RL-baselines3-zoo
  • RL-trained-agents 模型上传到 Hub:这是一个使用 stable-baselines3 预训练的强化学习智能体的庞大集合。
  • 集成其他深度强化学习库
  • 实现 Decision Transformers 🔥
  • 以及更多即将推出的内容 🥳

保持联系的最佳方式是加入我们的 discord 服务器,与我们以及社区进行交流。

如果你想更深入地了解,我们编写了一篇教程,你将学到:

  • 如何训练一个深度强化学习着陆器智能体,使其正确地在月球上着陆 🌕
  • 如何将其上传到 Hub 🚀

gif

  • 如何从 Hub 下载并使用一个玩《太空侵略者》的已保存模型 👾。

gif

👉 教程

结论

我们很高兴看到你使用 Stable-baselines3 进行的工作,并期待在 Hub 中试用你的模型 😍。

我们也很乐意听到你的反馈 💖。 📧 欢迎随时联系我们

最后,我们要感谢 SB3 团队,特别是 Antonin Raffin,感谢他们为该库的集成提供的宝贵帮助 🤗。

你想将你的库集成到 Hub 吗?

这次集成是借助 huggingface_hub 库实现的,该库包含了我们所有的组件以及所有支持的库的 API。如果你想将你的库集成到 Hub,我们为你准备了一份指南

社区

注册登录以发表评论