在 Hugging Face 上使用 Stable-Baselines3
stable-baselines3
是一套使用 PyTorch 实现的可靠强化学习算法。
在 Hub 中探索 Stable-Baselines3
您可以在 模型页面 左侧的筛选器中找到 Stable-Baselines3 模型。
Hub 上的所有模型都附带了有用的功能
- 自动生成的模型卡片,其中包含描述、训练配置等。
- 有助于发现的元数据标签。
- 评估结果,以便与其他模型进行比较。
- 一个视频小部件,您可以在其中观看您的代理执行操作。
安装库
要安装 stable-baselines3
库,您需要安装两个包
stable-baselines3
:Stable-Baselines3 库。huggingface-sb3
:加载和从 Hub 上传 Stable-baselines3 模型的其他代码。
pip install stable-baselines3
pip install huggingface-sb3
使用现有模型
您可以使用 load_from_hub
函数简单地从 Hub 下载模型
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)
您需要定义两个参数
--repo-id
:您要下载的 Hugging Face 仓库的名称。--filename
:您要下载的文件。
分享您的模型
您可以使用两种不同的函数轻松上传您的模型
package_to_hub()
:保存模型、评估模型、生成模型卡片并在将完整仓库推送到 Hub 之前记录代理的回放视频。
package_to_hub(model=model,
model_name="ppo-LunarLander-v2",
model_architecture="PPO",
env_id=env_id,
eval_env=eval_env,
repo_id="ThomasSimonini/ppo-LunarLander-v2",
commit_message="Test commit")
您需要定义七个参数
--model
:您训练的模型。--model_architecture
:模型架构的名称(DQN、PPO、A2C、SAC 等)。--env_id
:环境的名称。--eval_env
:用于评估代理的环境。--repo-id
:您要创建或更新的 Hugging Face 仓库的名称。它是<您的 huggingface 用户名>/<仓库名称>
。--提交消息
.--filename
:您要推送到 Hub 的文件。
push_to_hub()
:将文件简单推送到 Hub
push_to_hub(
repo_id="ThomasSimonini/ppo-LunarLander-v2",
filename="ppo-LunarLander-v2.zip",
commit_message="Added LunarLander-v2 model trained with PPO",
)
您需要定义三个参数
--repo-id
:您要创建或更新的 Hugging Face 仓库的名称。它是<您的 huggingface 用户名>/<仓库名称>
。--filename
:您要推送到 Hub 的文件。--提交消息
.