Hub 文档
在 Hugging Face 上使用 Stable-Baselines3
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
在 Hugging Face 上使用 Stable-Baselines3
stable-baselines3
是一组用 PyTorch 实现的可靠的强化学习算法。
在 Hub 中探索 Stable-Baselines3
您可以通过在模型页面的左侧进行筛选来找到 Stable-Baselines3 模型。
Hub 上的所有模型都具有有用的功能
- 自动生成的模型卡片,包含描述、训练配置等信息。
- 有助于发现的元数据标签。
- 用于与其他模型进行比较的评估结果。
- 一个视频小部件,您可以在其中观看您的智能体执行任务。
安装库
要安装 stable-baselines3
库,您需要安装两个软件包
stable-baselines3
:Stable-Baselines3 库。huggingface-sb3
:用于从 Hub 加载和上传 Stable-baselines3 模型的附加代码。
pip install stable-baselines3
pip install huggingface-sb3
使用现有模型
您可以使用 load_from_hub
函数从 Hub 简单地下载模型
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)
您需要定义两个参数
--repo-id
:您要下载的 Hugging Face 仓库的名称。--filename
:您要下载的文件。
分享您的模型
您可以使用两个不同的函数轻松上传您的模型
package_to_hub()
:保存模型,评估它,生成模型卡片,并记录您的智能体的重放视频,然后再将完整的仓库推送到 Hub。
package_to_hub(model=model,
model_name="ppo-LunarLander-v2",
model_architecture="PPO",
env_id=env_id,
eval_env=eval_env,
repo_id="ThomasSimonini/ppo-LunarLander-v2",
commit_message="Test commit")
您需要定义七个参数
--model
:您训练的模型。--model_architecture
:您的模型的架构名称 (DQN、PPO、A2C、SAC...)。--env_id
:环境名称。--eval_env
:用于评估智能体的环境。--repo-id
:您要创建或更新的 Hugging Face 仓库的名称。它是<您的 huggingface 用户名>/<仓库名称>
。--commit-message
.--filename
:您要推送到 Hub 的文件。
push_to_hub()
:简单地将文件推送到 Hub
push_to_hub(
repo_id="ThomasSimonini/ppo-LunarLander-v2",
filename="ppo-LunarLander-v2.zip",
commit_message="Added LunarLander-v2 model trained with PPO",
)
您需要定义三个参数
--repo-id
:您要创建或更新的 Hugging Face 仓库的名称。它是<您的 huggingface 用户名>/<仓库名称>
。--filename
:您要推送到 Hub 的文件。--commit-message
.