Hub 文档

在 Hugging Face 使用 Stable-Baselines3

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

在 Hugging Face 上使用 Stable-Baselines3

stable-baselines3 是一套使用 PyTorch 实现的可靠强化学习算法。

在 Hub 中探索 Stable-Baselines3

您可以在 模型页面 左侧的筛选器中找到 Stable-Baselines3 模型。

Hub 上的所有模型都附带了有用的功能

  1. 自动生成的模型卡片,其中包含描述、训练配置等。
  2. 有助于发现的元数据标签。
  3. 评估结果,以便与其他模型进行比较。
  4. 一个视频小部件,您可以在其中观看您的代理执行操作。

安装库

要安装 stable-baselines3 库,您需要安装两个包

  • stable-baselines3:Stable-Baselines3 库。
  • huggingface-sb3:加载和从 Hub 上传 Stable-baselines3 模型的其他代码。
pip install stable-baselines3
pip install huggingface-sb3

使用现有模型

您可以使用 load_from_hub 函数简单地从 Hub 下载模型

checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)

您需要定义两个参数

  • --repo-id:您要下载的 Hugging Face 仓库的名称。
  • --filename:您要下载的文件。

分享您的模型

您可以使用两种不同的函数轻松上传您的模型

  1. package_to_hub():保存模型、评估模型、生成模型卡片并在将完整仓库推送到 Hub 之前记录代理的回放视频。
package_to_hub(model=model, 
               model_name="ppo-LunarLander-v2",
               model_architecture="PPO",
               env_id=env_id,
               eval_env=eval_env,
               repo_id="ThomasSimonini/ppo-LunarLander-v2",
               commit_message="Test commit")

您需要定义七个参数

  • --model:您训练的模型。
  • --model_architecture:模型架构的名称(DQN、PPO、A2C、SAC 等)。
  • --env_id:环境的名称。
  • --eval_env:用于评估代理的环境。
  • --repo-id:您要创建或更新的 Hugging Face 仓库的名称。它是 <您的 huggingface 用户名>/<仓库名称>
  • --提交消息.
  • --filename:您要推送到 Hub 的文件。
  1. push_to_hub():将文件简单推送到 Hub
push_to_hub(
    repo_id="ThomasSimonini/ppo-LunarLander-v2",
    filename="ppo-LunarLander-v2.zip",
    commit_message="Added LunarLander-v2 model trained with PPO",
)

您需要定义三个参数

  • --repo-id:您要创建或更新的 Hugging Face 仓库的名称。它是 <您的 huggingface 用户名>/<仓库名称>
  • --filename:您要推送到 Hub 的文件。
  • --提交消息.

其他资源

  • Hugging Face Stable-Baselines3 文档
  • Stable-Baselines3 文档
< > 在 GitHub 上更新