├── .gitignore ├── 1_getting_started.ipynb ├── 2_gym_wrappers_saving_loading.ipynb ├── 3_multiprocessing.ipynb ├── 4_callbacks_hyperparameter_tuning.ipynb ├── 5_custom_gym_env.ipynb ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | *.zip 4 | videos/ -------------------------------------------------------------------------------- /1_getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": { 7 | "colab_type": "text", 8 | "id": "view-in-github" 9 | }, 10 | "source": [ 11 | "\"Open" 12 | ] 13 | }, 14 | { 15 | "attachments": {}, 16 | "cell_type": "markdown", 17 | "metadata": { 18 | "colab_type": "text", 19 | "id": "hyyN-2qyK_T2" 20 | }, 21 | "source": [ 22 | "# Stable Baselines3 Tutorial - Getting Started\n", 23 | "\n", 24 | "Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/\n", 25 | "\n", 26 | "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n", 27 | "\n", 28 | "Documentation: https://stable-baselines3.readthedocs.io/en/master/\n", 29 | "\n", 30 | "SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib\n", 31 | "\n", 32 | "RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo\n", 33 | "\n", 34 | "[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL), using Stable Baselines3.\n", 35 | "\n", 36 | "It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.\n", 37 | "\n", 38 | "\n", 39 | "## Introduction\n", 40 | "\n", 41 | "In this notebook, you will learn the basics of using the Stable Baselines3 library: how to create an RL model, train it and evaluate it. 
Because all algorithms share the same interface, we will see how simple it is to switch from one algorithm to another.\n", 42 | "\n", 43 | "\n", 44 | "## Install Dependencies and Stable Baselines3 Using Pip\n", 45 | "\n", 46 | "The full list of dependencies can be found in the [README](https://github.com/DLR-RM/stable-baselines3).\n", 47 | "\n", 48 | "\n", 49 | "```\n", 50 | "pip install stable-baselines3[extra]\n", 51 | "```" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# for autoformatting\n", 61 | "# %load_ext jupyter_black" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "colab": { 69 | "base_uri": "https://localhost:8080/", 70 | "height": 784 71 | }, 72 | "colab_type": "code", 73 | "id": "gWskDE2c9WoN", 74 | "outputId": "03477445-4249-49c3-ddba-4e12df09e98e" 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "!apt-get install ffmpeg freeglut3-dev xvfb # For visualization\n", 79 | "!pip install \"stable-baselines3[extra]>=2.0.0a4\"" 80 | ] 81 | }, 82 | { 83 | "attachments": {}, 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "colab_type": "text", 87 | "id": "FtY8FhliLsGm" 88 | }, 89 | "source": [ 90 | "## Imports" 91 | ] 92 | }, 93 | { 94 | "attachments": {}, 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "colab_type": "text", 98 | "id": "gcX8hEcaUpR0" 99 | }, 100 | "source": [ 101 | "Stable-Baselines3 works on environments that follow the [gym interface](https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html).\n", 102 | "You can find a list of available environments [here](https://gymnasium.farama.org/environments/classic_control/).\n", 103 | "\n", 104 | "Not all algorithms work with all action spaces; you can find more in this [recap table](https://stable-baselines3.readthedocs.io/en/master/guide/algos.html)." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | 
"execution_count": null, 110 | "metadata": { 111 | "colab": {}, 112 | "colab_type": "code", 113 | "id": "BIedd7Pz9sOs" 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "import gymnasium as gym\n", 118 | "import numpy as np" 119 | ] 120 | }, 121 | { 122 | "attachments": {}, 123 | "cell_type": "markdown", 124 | "metadata": { 125 | "colab_type": "text", 126 | "id": "Ae32CtgzTG3R" 127 | }, 128 | "source": [ 129 | "The first thing you need to import is the RL model; check the documentation to see which algorithms can be used on which problem." 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": { 136 | "colab": {}, 137 | "colab_type": "code", 138 | "id": "R7tKaBFrTR0a" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "from stable_baselines3 import PPO" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "colab_type": "text", 150 | "id": "-0_8OQbOTTNT" 151 | }, 152 | "source": [ 153 | "The next thing you need to import is the policy class that will be used to create the networks (for the policy/value functions).\n", 154 | "This step is optional as you can directly use strings in the constructor: \n", 155 | "\n", 156 | "```PPO('MlpPolicy', env)``` instead of ```PPO(MlpPolicy, env)```\n", 157 | "\n", 158 | "Note that some algorithms like `SAC` have their own `MlpPolicy`; that's why using a string for the policy is the recommended option."
159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "colab": {}, 166 | "colab_type": "code", 167 | "id": "ROUJr675TT01" 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "from stable_baselines3.ppo.policies import MlpPolicy" 172 | ] 173 | }, 174 | { 175 | "attachments": {}, 176 | "cell_type": "markdown", 177 | "metadata": { 178 | "colab_type": "text", 179 | "id": "RapkYvTXL7Cd" 180 | }, 181 | "source": [ 182 | "## Create the Gym env and instantiate the agent\n", 183 | "\n", 184 | "For this example, we will use the CartPole environment, a classic control problem.\n", 185 | "\n", 186 | "\"A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. \"\n", 187 | "\n", 188 | "Cartpole environment: [https://gymnasium.farama.org/environments/classic_control/cart_pole/](https://gymnasium.farama.org/environments/classic_control/cart_pole/)\n", 189 | "\n", 190 | "![Cartpole](https://cdn-images-1.medium.com/max/1143/1*h4WTQNVIsvMXJTCpXm_TAw.gif)\n", 191 | "\n", 192 | "\n", 193 | "We chose the MlpPolicy because the observation of the CartPole task is a feature vector, not an image.\n", 194 | "\n", 195 | "The type of action to use (discrete/continuous) will be automatically deduced from the environment action space.\n", 196 | "\n", 197 | "Here we are using the [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) algorithm, which is an Actor-Critic method: it uses a value function to improve the policy gradient estimate (by reducing its variance).\n", 198 | "\n", 199 | "It combines ideas from [A2C](https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html) (having multiple workers and using an entropy bonus for 
exploration) and [TRPO](https://stable-baselines.readthedocs.io/en/master/modules/trpo.html) (it uses a trust region to improve stability and avoid catastrophic drops in performance).\n", 200 | "\n", 201 | "PPO is an on-policy algorithm, which means that the trajectories used to update the networks must be collected using the latest policy.\n", 202 | "It is usually less sample efficient than off-policy algorithms like [DQN](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html), [SAC](https://stable-baselines3.readthedocs.io/en/master/modules/sac.html) or [TD3](https://stable-baselines3.readthedocs.io/en/master/modules/td3.html), but it is much faster in terms of wall-clock time.\n" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "colab": {}, 210 | "colab_type": "code", 211 | "id": "pUWGZp3i9wyf" 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "env = gym.make(\"CartPole-v1\")\n", 216 | "\n", 217 | "model = PPO(MlpPolicy, env, verbose=0)" 218 | ] 219 | }, 220 | { 221 | "attachments": {}, 222 | "cell_type": "markdown", 223 | "metadata": { 224 | "colab_type": "text", 225 | "id": "4efFdrQ7MBvl" 226 | }, 227 | "source": [ 228 | "We create a helper function to evaluate the agent:" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": {}, 236 | "colab_type": "code", 237 | "id": "63M8mSKR-6Zt" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "from stable_baselines3.common.base_class import BaseAlgorithm\n", 242 | "\n", 243 | "\n", 244 | "def evaluate(\n", 245 | " model: BaseAlgorithm,\n", 246 | " num_episodes: int = 100,\n", 247 | " deterministic: bool = True,\n", 248 | ") -> float:\n", 249 | " \"\"\"\n", 250 | " Evaluate an RL agent for `num_episodes`.\n", 251 | "\n", 252 | " :param model: the RL Agent\n", 254 | " :param num_episodes: number of evaluation episodes\n", 255 | " :param 
deterministic: Whether to use deterministic or stochastic actions\n", 256 | " :return: Mean reward for the last `num_episodes`\n", 257 | " \"\"\"\n", 258 | " # This function will only work for a single environment\n", 259 | " vec_env = model.get_env()\n", 260 | " obs = vec_env.reset()\n", 261 | " all_episode_rewards = []\n", 262 | " for _ in range(num_episodes):\n", 263 | " episode_rewards = []\n", 264 | " done = False\n", 265 | " # Note: SB3 VecEnv resets automatically:\n", 266 | " # https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api\n", 267 | " # obs = vec_env.reset()\n", 268 | " while not done:\n", 269 | " # _states are only useful when using LSTM policies\n", 270 | " # `deterministic` is to use deterministic actions\n", 271 | " action, _states = model.predict(obs, deterministic=deterministic)\n", 272 | " # here, action, rewards and dones are arrays\n", 273 | " # because we are using a vectorized env\n", 274 | " obs, reward, done, _info = vec_env.step(action)\n", 275 | " episode_rewards.append(reward)\n", 276 | "\n", 277 | " all_episode_rewards.append(sum(episode_rewards))\n", 278 | "\n", 279 | " mean_episode_reward = np.mean(all_episode_rewards)\n", 280 | " print(f\"Mean reward: {mean_episode_reward:.2f} - Num episodes: {num_episodes}\")\n", 281 | "\n", 282 | " return mean_episode_reward" 283 | ] 284 | }, 285 | { 286 | "attachments": {}, 287 | "cell_type": "markdown", 288 | "metadata": { 289 | "colab_type": "text", 290 | "id": "zjEVOIY8NVeK" 290 | }, 292 | "source": [ 293 | "Let's evaluate the untrained agent; it should behave like a random agent."
294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": { 300 | "colab": { 301 | "base_uri": "https://localhost:8080/", 302 | "height": 35 303 | }, 304 | "colab_type": "code", 305 | "id": "xDHLMA6NFk95", 306 | "outputId": "231b2170-a607-48ed-e9d9-daef596f6384" 307 | }, 308 | "outputs": [], 309 | "source": [ 310 | "# Random Agent, before training\n", 311 | "mean_reward_before_train = evaluate(model, num_episodes=100, deterministic=True)" 312 | ] 313 | }, 314 | { 315 | "attachments": {}, 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "colab_type": "text", 319 | "id": "QjjPxrwkYJ2i" 320 | }, 321 | "source": [ 322 | "Stable-Baselines already provides you with that helper:" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": { 329 | "colab": {}, 330 | "colab_type": "code", 331 | "id": "8z6K9YImYJEx" 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "from stable_baselines3.common.evaluation import evaluate_policy" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": { 342 | "colab": {}, 343 | "colab_type": "code", 344 | "id": "4oPTHjxyZSOL" 345 | }, 346 | "outputs": [], 347 | "source": [ 348 | "mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n", 349 | "\n", 350 | "print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")" 351 | ] 352 | }, 353 | { 354 | "attachments": {}, 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "colab_type": "text", 358 | "id": "r5UoXTZPNdFE" 359 | }, 360 | "source": [ 361 | "## Train the agent and evaluate it" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "colab": {}, 369 | "colab_type": "code", 370 | "id": "e4cfSXIB-pTF" 371 | }, 372 | "outputs": [], 373 | "source": [ 374 | "# Train the agent for 10000 steps\n", 375 | "model.learn(total_timesteps=10_000)" 376 | ] 377 | }, 378 | 
{ 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "colab": {}, 383 | "colab_type": "code", 384 | "id": "ygl_gVmV_QP7" 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "# Evaluate the trained agent\n", 389 | "mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)\n", 390 | "\n", 391 | "print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")" 392 | ] 393 | }, 394 | { 395 | "attachments": {}, 396 | "cell_type": "markdown", 397 | "metadata": { 398 | "colab_type": "text", 399 | "id": "A00W6yY3NkHG" 400 | }, 401 | "source": [ 402 | "Apparently the training went well: the mean reward increased a lot!" 403 | ] 404 | }, 405 | { 406 | "attachments": {}, 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "colab_type": "text", 410 | "id": "xVm9QPNVwKXN" 411 | }, 412 | "source": [ 413 | "### Prepare video recording" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": { 420 | "colab": {}, 421 | "colab_type": "code", 422 | "id": "MPyfQxD5z26J" 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "# Set up fake display; otherwise rendering will fail\n", 427 | "import os\n", 428 | "os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n", 429 | "os.environ['DISPLAY'] = ':1'" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "colab": {}, 437 | "colab_type": "code", 438 | "id": "SLzXxO8VMD6N" 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | "import base64\n", 443 | "from pathlib import Path\n", 444 | "\n", 445 | "from IPython import display as ipythondisplay\n", 446 | "\n", 447 | "\n", 448 | "def show_videos(video_path=\"\", prefix=\"\"):\n", 449 | " \"\"\"\n", 450 | " Taken from https://github.com/eleurent/highway-env\n", 451 | "\n", 452 | " :param video_path: (str) Path to the folder containing videos\n", 453 | " :param prefix: (str) Filter the videos, showing only the ones starting with this prefix\n", 
454 | " \"\"\"\n", 455 | " html = []\n", 456 | " for mp4 in Path(video_path).glob(\"{}*.mp4\".format(prefix)):\n", 457 | " video_b64 = base64.b64encode(mp4.read_bytes())\n", 458 | " html.append(\n", 459 | " \"\"\"<video alt=\"{}\" autoplay\n", 460 | " loop controls style=\"height: 400px;\">\n", 461 | " <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n", 462 | " </video>\"\"\".format(\n", 463 | " mp4, video_b64.decode(\"ascii\")\n", 464 | " )\n", 465 | " )\n", 466 | " ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))" 467 | ] 468 | }, 469 | { 470 | "attachments": {}, 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "colab_type": "text", 474 | "id": "LTRNUfulOGaF" 475 | }, 476 | "source": [ 477 | "We will record a video using the [VecVideoRecorder](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecvideorecorder) wrapper; you will learn about these wrappers in the next notebook." 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": { 484 | "colab": {}, 485 | "colab_type": "code", 486 | "id": "Trag9dQpOIhx" 487 | }, 488 | "outputs": [], 489 | "source": [ 490 | "from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv\n", 491 | "\n", 492 | "\n", 493 | "def record_video(env_id, model, video_length=500, prefix=\"\", video_folder=\"videos/\"):\n", 494 | " \"\"\"\n", 495 | " :param env_id: (str) ID of the environment to record\n", 496 | " :param model: (RL model) Agent used to select the actions\n", 497 | " :param video_length: (int) Number of steps to record\n", 498 | " :param prefix: (str) Prefix of the video file names\n", 499 | " :param video_folder: (str) Folder where the videos are saved\n", 500 | " \"\"\"\n", 501 | " eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode=\"rgb_array\")])\n", 502 | " # Start the video at step=0 and record `video_length` steps\n", 503 | " eval_env = VecVideoRecorder(\n", 504 | " eval_env,\n", 505 | " video_folder=video_folder,\n", 506 | " record_video_trigger=lambda step: step == 0,\n", 507 | " video_length=video_length,\n", 508 | " name_prefix=prefix,\n", 509 | " )\n", 510 | "\n", 511 | " obs = eval_env.reset()\n", 512 | " for _ in range(video_length):\n", 513 | " action, _ = model.predict(obs)\n", 514 | " obs, _, _, _ = eval_env.step(action)\n", 515 | "\n", 516 | " # Close the video recorder\n", 517 | " eval_env.close()" 518 | ] 519 | }, 520 | { 521 | "attachments": {}, 522 | "cell_type": "markdown", 523 | "metadata": { 524 | "colab_type": "text", 525 | "id": "KOObbeu5MMlR" 526 | }, 527 | "source": [ 528 | "### Visualize trained agent\n", 529 | "\n", 530 | "" 531 | ] 532 | }, 533 | { 534 | "cell_type": 
"code", 534 | "execution_count": null, 535 | "metadata": { 536 | "colab": { 537 | "base_uri": "https://localhost:8080/", 538 | "height": 35 539 | }, 540 | "colab_type": "code", 541 | "id": "iATu7AiyMQW2", 542 | "outputId": "68acb027-6c94-4389-8456-2cfb11494814" 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "record_video(\"CartPole-v1\", model, video_length=500, prefix=\"ppo-cartpole\")" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "metadata": { 553 | "colab": {}, 554 | "colab_type": "code", 555 | "id": "-n4i-fW3NojZ" 556 | }, 557 | "outputs": [], 558 | "source": [ 559 | "show_videos(\"videos\", prefix=\"ppo\")" 560 | ] 561 | }, 562 | { 563 | "attachments": {}, 564 | "cell_type": "markdown", 565 | "metadata": { 566 | "colab_type": "text", 567 | "id": "9Y8zg4V566qD" 568 | }, 569 | "source": [ 570 | "## Bonus: Train an RL Model in One Line\n", 571 | "\n", 572 | "The policy class to use will be inferred and the environment will be automatically created. This works because both are [registered](https://stable-baselines3.readthedocs.io/en/master/guide/quickstart.html)."
573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": null, 578 | "metadata": { 579 | "colab": {}, 580 | "colab_type": "code", 581 | "id": "iaOPfOrwWEP4" 582 | }, 583 | "outputs": [], 584 | "source": [ 585 | "model = PPO('MlpPolicy', \"CartPole-v1\", verbose=1).learn(1000)" 586 | ] 587 | }, 588 | { 589 | "attachments": {}, 590 | "cell_type": "markdown", 591 | "metadata": { 592 | "colab_type": "text", 593 | "id": "FrI6f5fWnzp-" 594 | }, 595 | "source": [ 596 | "## Conclusion\n", 597 | "\n", 598 | "In this notebook we have seen:\n", 599 | "- how to define and train an RL model using Stable Baselines3: it takes only one line of code ;)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": { 606 | "colab": {}, 607 | "colab_type": "code", 608 | "id": "73ji3gbNDkf7" 609 | }, 610 | "outputs": [], 611 | "source": [] 612 | } 613 | ], 614 | "metadata": { 615 | "accelerator": "GPU", 616 | "colab": { 617 | "collapsed_sections": [], 618 | "include_colab_link": true, 619 | "name": "1_getting_started.ipynb", 620 | "provenance": [] 621 | }, 622 | "kernelspec": { 623 | "display_name": "Python 3 (ipykernel)", 624 | "language": "python", 625 | "name": "python3" 626 | }, 627 | "language_info": { 628 | "codemirror_mode": { 629 | "name": "ipython", 630 | "version": 3 631 | }, 632 | "file_extension": ".py", 633 | "mimetype": "text/x-python", 634 | "name": "python", 635 | "nbconvert_exporter": "python", 636 | "pygments_lexer": "ipython3", 637 | "version": "3.10.9" 638 | } 639 | }, 640 | "nbformat": 4, 641 | "nbformat_minor": 4 642 | } 643 | -------------------------------------------------------------------------------- /2_gym_wrappers_saving_loading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 
14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "3ezJ3Y7XRUnj" 18 | }, 19 | "source": [ 20 | "# Stable Baselines3 Tutorial - Gym wrappers, saving and loading models\n", 21 | "\n", 22 | "Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/\n", 23 | "\n", 24 | "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n", 25 | "\n", 26 | "Documentation: https://stable-baselines3.readthedocs.io/en/master/\n", 27 | "\n", 28 | "SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib\n", 29 | "\n", 30 | "RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo\n", 31 | "\n", 32 | "\n", 33 | "## Introduction\n", 34 | "\n", 35 | "In this notebook, you will learn how to use *Gym Wrappers*, which enable monitoring, normalization, limiting the number of steps, feature augmentation, ...\n", 36 | "\n", 37 | "\n", 38 | "You will also see the *loading* and *saving* functions, and how to read the output files for possible exporting.\n", 39 | "\n", 40 | "## Install Dependencies and Stable Baselines3 Using Pip" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# for autoformatting\n", 50 | "# %load_ext jupyter_black" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "colab": {}, 58 | "colab_type": "code", 59 | "id": "YFdlFByORUnl" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "!pip install swig\n", 64 | "!pip install \"stable-baselines3[extra]>=2.0.0a4\"" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "colab": {}, 72 | "colab_type": "code", 73 | "id": "grXe85G9RUnp" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "import gymnasium as gym\n", 78 | "from stable_baselines3 import A2C, SAC, PPO, TD3" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "colab_type": 
"text", 85 | "id": "hMPAn1SRd32f" 86 | }, 87 | "source": [ 88 | "# Saving and loading\n", 89 | "\n", 90 | "Saving and loading stable-baselines models is straightforward: you can directly call `.save()` and `.load()` on the models." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "colab": {}, 98 | "colab_type": "code", 99 | "id": "vBNFnN4Gd32g" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "import os\n", 104 | "\n", 105 | "# Create save dir\n", 106 | "save_dir = \"/tmp/gym/\"\n", 107 | "os.makedirs(save_dir, exist_ok=True)\n", 108 | "\n", 109 | "model = PPO(\"MlpPolicy\", \"Pendulum-v1\", verbose=0).learn(8_000)\n", 110 | "# The model will be saved under PPO_tutorial.zip\n", 111 | "model.save(f\"{save_dir}/PPO_tutorial\")\n", 112 | "\n", 113 | "# sample an observation from the environment\n", 114 | "obs = model.env.observation_space.sample()\n", 115 | "\n", 116 | "# Check prediction before saving\n", 117 | "print(\"pre saved\", model.predict(obs, deterministic=True))\n", 118 | "\n", 119 | "del model # delete trained model to demonstrate loading\n", 120 | "\n", 121 | "loaded_model = PPO.load(f\"{save_dir}/PPO_tutorial\")\n", 122 | "# Check that the prediction is the same after loading (for the same observation)\n", 123 | "print(\"loaded\", loaded_model.predict(obs, deterministic=True))" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "colab_type": "text", 130 | "id": "gXWPrVqId32o" 131 | }, 132 | "source": [ 133 | "Saving in Stable-Baselines is quite powerful, as you save the training hyperparameters along with the current weights. In practice, this means you can simply load a custom model, without redefining the parameters, and continue learning.\n", 134 | "\n", 135 | "The loading function can also update the model's class variables when loading."
136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "colab": {}, 143 | "colab_type": "code", 144 | "id": "LCtxrAbXd32q" 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "import os\n", 149 | "from stable_baselines3.common.vec_env import DummyVecEnv\n", 150 | "\n", 151 | "# Create save dir\n", 152 | "save_dir = \"/tmp/gym/\"\n", 153 | "os.makedirs(save_dir, exist_ok=True)\n", 154 | "\n", 155 | "model = A2C(\"MlpPolicy\", \"Pendulum-v1\", verbose=0, gamma=0.9, n_steps=20).learn(8000)\n", 156 | "# The model will be saved under A2C_tutorial.zip\n", 157 | "model.save(f\"{save_dir}/A2C_tutorial\")\n", 158 | "\n", 159 | "del model # delete trained model to demonstrate loading\n", 160 | "\n", 161 | "# load the model, and when loading set verbose to 1\n", 162 | "loaded_model = A2C.load(f\"{save_dir}/A2C_tutorial\", verbose=1)\n", 163 | "\n", 164 | "# show the saved hyperparameters\n", 165 | "print(f\"loaded: gamma={loaded_model.gamma}, n_steps={loaded_model.n_steps}\")\n", 166 | "\n", 167 | "# as the environment is not serializable, we need to set a new instance of the environment\n", 168 | "loaded_model.set_env(DummyVecEnv([lambda: gym.make(\"Pendulum-v1\")]))\n", 169 | "# and continue training\n", 170 | "loaded_model.learn(8_000)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "colab_type": "text", 177 | "id": "hKwupU-Jgxjm" 178 | }, 179 | "source": [ 180 | "# Gym and VecEnv wrappers" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "colab_type": "text", 187 | "id": "ds4AAfmISQIA" 188 | }, 189 | "source": [ 190 | "## Anatomy of a gym wrapper" 191 | ] 192 | }, 193 | { 194 | "attachments": {}, 195 | "cell_type": "markdown", 196 | "metadata": { 197 | "colab_type": "text", 198 | "id": "gnTS9e9hTzZZ" 199 | }, 200 | "source": [ 201 | "A gym wrapper follows the [gym](https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) interface: it 
has `reset()` and `step()` methods.\n", 202 | "\n", 203 | "Because a wrapper is *around* an environment, we can access it with `self.env`, which allows us to easily interact with it without modifying the original env.\n", 204 | "Many wrappers are already predefined; for a complete list, refer to the [Gymnasium documentation](https://gymnasium.farama.org/api/wrappers/)." 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "colab": {}, 212 | "colab_type": "code", 213 | "id": "hYo0C0TQSL3c" 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "class CustomWrapper(gym.Wrapper):\n", 218 | " \"\"\"\n", 219 | " :param env: (gym.Env) Gym environment that will be wrapped\n", 220 | " \"\"\"\n", 221 | "\n", 222 | " def __init__(self, env):\n", 223 | " # Call the parent constructor, so we can access self.env later\n", 224 | " super().__init__(env)\n", 225 | "\n", 226 | " def reset(self, **kwargs):\n", 227 | " \"\"\"\n", 228 | " Reset the environment\n", 229 | " \"\"\"\n", 230 | " obs, info = self.env.reset(**kwargs)\n", 231 | "\n", 232 | " return obs, info\n", 233 | "\n", 234 | " def step(self, action):\n", 235 | " \"\"\"\n", 236 | " :param action: ([float] or int) Action taken by the agent\n", 237 | " :return: (np.ndarray, float, bool, bool, dict) observation, reward, is this a final state (episode finished),\n", 238 | " is the max number of steps reached (episode finished artificially), additional information\n", 239 | " \"\"\"\n", 240 | " obs, reward, terminated, truncated, info = self.env.step(action)\n", 241 | " return obs, reward, terminated, truncated, info" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": { 247 | "colab_type": "text", 248 | "id": "4zeGuyICUN26" 249 | }, 250 | "source": [ 251 | "## First example: limit the episode length\n", 252 | "\n", 253 | "One practical use case of a wrapper is when you want to limit the number of steps per episode; for that, you will need to overwrite 
the `truncated` signal when the limit is reached (the episode ends artificially, not because a terminal state was reached). It is also a good practice to pass that information in the `info` dictionary." 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": { 260 | "colab": {}, 261 | "colab_type": "code", 262 | "id": "Eb2U4_K6SNUx" 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "class TimeLimitWrapper(gym.Wrapper):\n", 267 | " \"\"\"\n", 268 | " :param env: (gym.Env) Gym environment that will be wrapped\n", 269 | " :param max_steps: (int) Max number of steps per episode\n", 270 | " \"\"\"\n", 271 | "\n", 272 | " def __init__(self, env, max_steps=100):\n", 273 | " # Call the parent constructor, so we can access self.env later\n", 274 | " super(TimeLimitWrapper, self).__init__(env)\n", 275 | " self.max_steps = max_steps\n", 276 | " # Counter of steps per episode\n", 277 | " self.current_step = 0\n", 278 | "\n", 279 | " def reset(self, **kwargs):\n", 280 | " \"\"\"\n", 281 | " Reset the environment\n", 282 | " \"\"\"\n", 283 | " # Reset the counter\n", 284 | " self.current_step = 0\n", 285 | " return self.env.reset(**kwargs)\n", 286 | "\n", 287 | " def step(self, action):\n", 288 | " \"\"\"\n", 289 | " :param action: ([float] or int) Action taken by the agent\n", 290 | " :return: (np.ndarray, float, bool, bool, dict) observation, reward, terminated?, truncated?, additional information\n", 291 | " \"\"\"\n", 292 | " self.current_step += 1\n", 293 | " obs, reward, terminated, truncated, info = self.env.step(action)\n", 294 | " # Overwrite the truncation signal when the number of steps reaches the maximum\n", 295 | " if self.current_step >= self.max_steps:\n", 296 | " truncated = True\n", 297 | " return obs, reward, terminated, truncated, info" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": { 303 | "colab_type": "text", 304 | "id": "oZufaUJwVM9w" 305 | }, 306 | "source": [ 307 | "#### Test the wrapper" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | 
"execution_count": null, 313 | "metadata": { 314 | "colab": {}, 315 | "colab_type": "code", 316 | "id": "szZ43D5PVB07" 317 | }, 318 | "outputs": [], 319 | "source": [ 320 | "from gymnasium.envs.classic_control.pendulum import PendulumEnv\n", 321 | "\n", 322 | "# We create the environment directly here; otherwise gym.make() would already wrap it in a TimeLimit wrapper\n", 323 | "env = PendulumEnv()\n", 324 | "# Wrap the environment\n", 325 | "env = TimeLimitWrapper(env, max_steps=100)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": { 332 | "colab": {}, 333 | "colab_type": "code", 334 | "id": "cencka9iVg9V" 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "obs, _ = env.reset()\n", 339 | "done = False\n", 340 | "n_steps = 0\n", 341 | "while not done:\n", 342 | " # Take random actions\n", 343 | " random_action = env.action_space.sample()\n", 344 | " obs, reward, terminated, truncated, info = env.step(random_action)\n", 345 | " done = terminated or truncated\n", 346 | " n_steps += 1\n", 347 | "\n", 348 | "print(n_steps, info)" 349 | ] 350 | }, 351 | { 352 | "attachments": {}, 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "colab_type": "text", 356 | "id": "jkMYA63sV9aA" 357 | }, 358 | "source": [ 359 | "In practice, `gym` already has a wrapper for that named `TimeLimit` (`gym.wrappers.TimeLimit`) that is used by most environments."
360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": { 365 | "colab_type": "text", 366 | "id": "VIIJbSyQW9R-" 367 | }, 368 | "source": [ 369 | "## Second example: normalize actions\n", 370 | "\n", 371 | "It is usually a good idea to normalize observations and actions before giving them to the agent; this prevents this [hard-to-debug issue](https://github.com/hill-a/stable-baselines/issues/473).\n", 372 | "\n", 373 | "In this example, we are going to normalize the action space of *Pendulum-v1* so it lies in [-1, 1] instead of [-2, 2].\n", 374 | "\n", 375 | "Note: here we are dealing with continuous actions, hence the `gym.spaces.Box` space." 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "colab": {}, 383 | "colab_type": "code", 384 | "id": "F5E6kZfzW8vy" 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "import numpy as np\n", 389 | "\n", 390 | "\n", 391 | "class NormalizeActionWrapper(gym.Wrapper):\n", 392 | " \"\"\"\n", 393 | " :param env: (gym.Env) Gym environment that will be wrapped\n", 394 | " \"\"\"\n", 395 | "\n", 396 | " def __init__(self, env):\n", 397 | " # Call the parent constructor, so we can access self.env later\n", 398 | " super().__init__(env)\n", 399 | " # Retrieve the action space\n", 400 | " action_space = env.action_space\n", 401 | " assert isinstance(\n", 402 | " action_space, gym.spaces.Box\n", 403 | " ), \"This wrapper only works with continuous action space (spaces.Box)\"\n", 404 | " # Retrieve the max/min values\n", 405 | " self.low, self.high = action_space.low, action_space.high\n", 406 | "\n", 407 | " # We modify the action space of the wrapper, so all actions will lie in [-1, 1]\n", 408 | " # (the wrapped environment keeps its original action space)\n", 409 | " self.action_space = gym.spaces.Box(\n", 410 | " low=-1, high=1, shape=action_space.shape, dtype=np.float32\n", 411 | " )\n", 412 | "\n", 413 | " def rescale_action(self, scaled_action):\n", 414 | " \"\"\"\n", 415 | " Rescale the action from [-1, 1] 
to [low, high]\n", 416 | " (no need for symmetric action space)\n", 417 | " :param scaled_action: (np.ndarray)\n", 418 | " :return: (np.ndarray)\n", 419 | " \"\"\"\n", 420 | " return self.low + (0.5 * (scaled_action + 1.0) * (self.high - self.low))\n", 421 | "\n", 422 | " def reset(self, **kwargs):\n", 423 | " \"\"\"\n", 424 | " Reset the environment\n", 425 | " \"\"\"\n", 426 | " return self.env.reset(**kwargs)\n", 427 | "\n", 428 | " def step(self, action):\n", 429 | " \"\"\"\n", 430 | " :param action: ([float] or int) Action taken by the agent\n", 431 | " :return: (np.ndarray, float, bool, bool, dict) observation, reward, terminated (final state reached?), truncated (time limit reached?), additional information\n", 432 | " \"\"\"\n", 433 | " # Rescale action from [-1, 1] to original [low, high] interval\n", 434 | " rescaled_action = self.rescale_action(action)\n", 435 | " obs, reward, terminated, truncated, info = self.env.step(rescaled_action)\n", 436 | " return obs, reward, terminated, truncated, info" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "colab_type": "text", 443 | "id": "TmJ0eahNaR6K" 444 | }, 445 | "source": [ 446 | "#### Test before rescaling actions" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "metadata": { 453 | "colab": {}, 454 | "colab_type": "code", 455 | "id": "UEnjBwisaQIx" 456 | }, 457 | "outputs": [], 458 | "source": [ 459 | "original_env = gym.make(\"Pendulum-v1\")\n", 460 | "\n", 461 | "print(original_env.action_space.low)\n", 462 | "for _ in range(10):\n", 463 | " print(original_env.action_space.sample())" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": { 469 | "colab_type": "text", 470 | "id": "jvcll2L3afVd" 471 | }, 472 | "source": [ 473 | "#### Test the NormalizeAction wrapper" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": { 480 | "colab": {}, 481 | 
"colab_type": "code", 482 | "id": "WsCM9AUGaeBN" 483 | }, 484 |
"outputs": [], 485 | "source": [ 486 | "env = NormalizeActionWrapper(gym.make(\"Pendulum-v1\"))\n", 487 | "\n", 488 | "print(env.action_space.low)\n", 489 | "\n", 490 | "for _ in range(10):\n", 491 | " print(env.action_space.sample())" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": { 497 | "colab_type": "text", 498 | "id": "V5h5kk2mbGNs" 499 | }, 500 | "source": [ 501 | "#### Test with an RL algorithm\n", 502 | "\n", 503 | "We are going to use the Monitor wrapper of Stable Baselines3, which allows monitoring training stats (mean episode reward, mean episode length)" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": { 510 | "colab": {}, 511 | "colab_type": "code", 512 | "id": "R9FNCN8ybOVU" 513 | }, 514 | "outputs": [], 515 | "source": [ 516 | "from stable_baselines3.common.monitor import Monitor\n", 517 | "from stable_baselines3.common.vec_env import DummyVecEnv" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": { 524 | "colab": { 525 | "base_uri": "https://localhost:8080/", 526 | "height": 53 527 | }, 528 | "colab_type": "code", 529 | "id": "wutM3c1GbfGP", 530 | "outputId": "eda3b489-ab0f-45cd-8acd-c36835f063df" 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "env = Monitor(gym.make(\"Pendulum-v1\"))\n", 535 | "env = DummyVecEnv([lambda: env])" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": { 542 | "colab": {}, 543 | "colab_type": "code", 544 | "id": "8cxnE5bdaQ_3" 545 | }, 546 | "outputs": [], 547 | "source": [ 548 | "model = A2C(\"MlpPolicy\", env, verbose=1).learn(int(1000))" 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": { 554 | "colab_type": "text", 555 | "id": "EJFSM-Drb3Wc" 556 | }, 557 | "source": [ 558 | "With the action wrapper" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": null, 564 | "metadata": { 565 
| "colab": 
{}, 566 | "colab_type": "code", 567 | "id": "GszFZthob2wM" 568 | }, 569 | "outputs": [], 570 | "source": [ 571 | "normalized_env = Monitor(gym.make(\"Pendulum-v1\"))\n", 572 | "# Note that we can use multiple wrappers\n", 573 | "normalized_env = NormalizeActionWrapper(normalized_env)\n", 574 | "normalized_env = DummyVecEnv([lambda: normalized_env])" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "metadata": { 581 | "colab": {}, 582 | "colab_type": "code", 583 | "id": "wrKJEO4NcIMd" 584 | }, 585 | "outputs": [], 586 | "source": [ 587 | "model_2 = A2C(\"MlpPolicy\", normalized_env, verbose=1).learn(int(1000))" 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": { 593 | "colab_type": "text", 594 | "id": "5BxqXd_6dpJx" 595 | }, 596 | "source": [ 597 | "## Additional wrappers: VecEnvWrappers\n", 598 | "\n", 599 | "In the same vein as gym wrappers, Stable Baselines3 provides wrappers for `VecEnv`. Among the different wrappers that exist (and you can create your own), you should know: \n", 600 | "\n", 601 | "- VecNormalize: it computes a running mean and standard deviation to normalize observations and returns\n", 602 | "- VecFrameStack: it stacks several consecutive observations (useful to integrate time in the observation, e.g. successive frames of an Atari game)\n", 603 | "\n", 604 | "More info in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#wrappers)\n", 605 | "\n", 606 | "Note: when using the `VecNormalize` wrapper, you must save the running mean and std along with the model, otherwise you will not get proper results when loading the agent again.
If you use the [rl zoo](https://github.com/DLR-RM/rl-baselines3-zoo), this is done automatically." 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": null, 612 | "metadata": { 613 | "colab": {}, 614 | "colab_type": "code", 615 | "id": "zuIcbfv3g9dd" 616 | }, 617 | "outputs": [], 618 | "source": [ 619 | "from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack\n", 620 | "\n", 621 | "env = DummyVecEnv([lambda: gym.make(\"Pendulum-v1\")])\n", 622 | "normalized_vec_env = VecNormalize(env)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": null, 628 | "metadata": { 629 | "colab": {}, 630 | "colab_type": "code", 631 | "id": "-PAbu21pg90A" 632 | }, 633 | "outputs": [], 634 | "source": [ 635 | "obs = normalized_vec_env.reset()\n", 636 | "for _ in range(10):\n", 637 | " action = [normalized_vec_env.action_space.sample()]\n", 638 | " obs, reward, _, _ = normalized_vec_env.step(action)\n", 639 | " print(obs, reward)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "markdown", 644 | "metadata": { 645 | "colab_type": "text", 646 | "id": "UEpTys28Wz05" 647 | }, 648 | "source": [ 649 | "## Exercise: code your own monitor wrapper\n", 650 | "\n", 651 | "Now that you know how a wrapper works and what you can do with it, it's time to experiment.\n", 652 | "\n", 653 | "The goal here is to create a wrapper that will monitor the training progress, storing both the episode reward (sum of rewards for one episode) and episode length (number of steps in the last episode).\n", 654 | "\n", 655 | "You will return those values using the `info` dict at the end of each episode."
656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": { 662 | "colab": {}, 663 | "colab_type": "code", 664 | "id": "8FWeDRd5W7hO" 665 | }, 666 | "outputs": [], 667 | "source": [ 668 | "class MyMonitorWrapper(gym.Wrapper):\n", 669 | " \"\"\"\n", 670 | " :param env: (gym.Env) Gym environment that will be wrapped\n", 671 | " \"\"\"\n", 672 | "\n", 673 | " def __init__(self, env):\n", 674 | " # Call the parent constructor, so we can access self.env later\n", 675 | " super().__init__(env)\n", 676 | " # === YOUR CODE HERE ===#\n", 677 | " # Initialize the variables that will be used\n", 678 | " # to store the episode length and episode reward\n", 679 | "\n", 680 | " # ====================== #\n", 681 | "\n", 682 | " def reset(self, **kwargs):\n", 683 | " \"\"\"\n", 684 | " Reset the environment\n", 685 | " \"\"\"\n", 686 | " obs, info = self.env.reset(**kwargs)\n", 687 | " # === YOUR CODE HERE ===#\n", 688 | " # Reset the variables\n", 689 | "\n", 690 | " # ====================== #\n", 691 | " return obs, info\n", 692 | "\n", 693 | " def step(self, action):\n", 694 | " \"\"\"\n", 695 | " :param action: ([float] or int) Action taken by the agent\n", 696 | " :return: (np.ndarray, float, bool, bool, dict)\n", 697 | " observation, reward, is the episode over?, is the episode truncated?, additional information\n", 698 | " \"\"\"\n", 699 | " obs, reward, terminated, truncated, info = self.env.step(action)\n", 700 | " # === YOUR CODE HERE ===#\n", 701 | " # Update the current episode reward and episode length\n", 702 | "\n", 703 | " # ====================== #\n", 704 | "\n", 705 | " if terminated or truncated:\n", 706 | " # === YOUR CODE HERE ===#\n", 707 | " # Store the episode length and episode reward in the info dict\n", 708 | " pass\n", 709 | "\n", 710 | " # ====================== #\n", 711 | " return obs, reward, terminated, truncated, info" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": 
{ 717 | "colab_type": 
"text", 718 | "id": "d4fY4QwWXNFK" 719 | }, 720 | "source": [ 721 | "#### Test your wrapper" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": null, 727 | "metadata": { 728 | "colab": {}, 729 | "colab_type": "code", 730 | "id": "bJbUG-A_liYt" 731 | }, 732 | "outputs": [], 733 | "source": [ 734 | "# To use LunarLander, you need to install the box2d bindings (pip) and swig (apt-get)\n", 735 | "!pip install box2d-py" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": null, 741 | "metadata": { 742 | "colab": {}, 743 | "colab_type": "code", 744 | "id": "oWZp1olSXMUg" 745 | }, 746 | "outputs": [], 747 | "source": [ 748 | "env = gym.make(\"LunarLander-v2\")\n", 749 | "# === YOUR CODE HERE ===#\n", 750 | "# Wrap the environment\n", 751 | "\n", 752 | "# Reset the environment\n", 753 | "\n", 754 | "# Take random actions in the environment and check\n", 755 | "# that it returns the correct values after the end of each episode\n", 756 | "\n", 757 | "# ====================== #" 758 | ] 759 | }, 760 | { 761 | "cell_type": "markdown", 762 | "metadata": { 763 | "colab_type": "text", 764 | "id": "dJ2IqSM2eOt8" 765 | }, 766 | "source": [ 767 | " # Conclusion\n", 768 | " \n", 769 | " In this notebook, we have seen:\n", 770 | " - how to easily save and load a model\n", 771 | " - what a wrapper is and what we can do with it\n", 772 | " - how to create your own wrapper" 773 | ] 774 | }, 775 | { 776 | "cell_type": "markdown", 777 | "metadata": { 778 | "colab_type": "text", 779 | "id": "qhWB_bHpSkas" 780 | }, 781 | "source": [ 782 | "## Wrapper Bonus: changing the observation space: a wrapper for episodes of fixed length" 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "execution_count": null, 788 | "metadata": { 789 | "colab": {}, 790 | "colab_type": "code", 791 | "id": "bBlS9YxYSpJn" 792 | }, 793 | "outputs": [], 794 | "source": [ 795 | "from gymnasium.wrappers import TimeLimit\n", 796 | "\n", 797 | 
"\n", 798 | "class 
TimeFeatureWrapper(gym.Wrapper):\n", 799 | " \"\"\"\n", 800 | " Add remaining time to observation space for fixed length episodes.\n", 801 | " See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.\n", 802 | "\n", 803 | " :param env: (gym.Env)\n", 804 | " :param max_steps: (int) Max number of steps of an episode\n", 805 | " if it is not wrapped in a TimeLimit object.\n", 806 | " :param test_mode: (bool) In test mode, the time feature is constant,\n", 807 | " equal to one. This allows checking that the agent did not overfit this feature,\n", 808 | " learning a deterministic pre-defined sequence of actions.\n", 809 | " \"\"\"\n", 810 | "\n", 811 | " def __init__(self, env, max_steps=1000, test_mode=False):\n", 812 | " assert isinstance(env.observation_space, gym.spaces.Box)\n", 813 | " # Add a time feature to the observation\n", 814 | " low, high = env.observation_space.low, env.observation_space.high\n", 815 | " low, high = np.concatenate((low, [0])), np.concatenate((high, [1.0]))\n", 816 | " env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)\n", 817 | "\n", 818 | " super().__init__(env)\n", 819 | "\n", 820 | " if isinstance(env, TimeLimit):\n", 821 | " self._max_steps = env._max_episode_steps\n", 822 | " else:\n", 823 | " self._max_steps = max_steps\n", 824 | " self._current_step = 0\n", 825 | " self._test_mode = test_mode\n", 826 | "\n", 827 | " def reset(self, **kwargs):\n", 828 | " self._current_step = 0\n", 829 | " obs, info = self.env.reset(**kwargs)\n", 830 | " return self._get_obs(obs), info\n", 831 | "\n", 832 | " def step(self, action):\n", 833 | " self._current_step += 1\n", 834 | " obs, reward, terminated, truncated, info = self.env.step(action)\n", 835 | " return self._get_obs(obs), reward, terminated, truncated, info\n", 836 | "\n", 837 | " def _get_obs(self, obs):\n", 838 | " \"\"\"\n", 839 | " Concatenate the time feature to the current observation.\n", 840 | "\n", 841 | " :param obs: 
(np.ndarray)\n", 842 | " :return: (np.ndarray)\n", 843 | " \"\"\"\n", 844 | " # Remaining time is more general\n", 845 | " time_feature = 1 - (self._current_step / self._max_steps)\n", 846 | " if self._test_mode:\n", 847 | " time_feature = 1.0\n", 848 | " # Optionally: concatenate [time_feature, time_feature ** 2]\n", 849 | " # Cast to float32 so the observation matches the dtype of the observation space\n", 850 | " return np.concatenate((obs, [time_feature])).astype(np.float32)" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": null, 855 | "metadata": { 856 | "colab": {}, 857 | "colab_type": "code", 858 | "id": "z-vWgkZzd4F1" 859 | }, 860 | "outputs": [], 861 | "source": [] 862 | }, 863 | { 864 | "cell_type": "markdown", 865 | "metadata": { 866 | "colab_type": "text", 867 | "id": "Ojn4nvNNRUoT" 868 | }, 869 | "source": [ 870 | "## Going further - Saving format \n", 871 | "\n", 872 | "The format for saving and loading models is a zip-archived JSON dump and NumPy zip archive of the arrays:\n", 873 | "```\n", 874 | "saved_model.zip/\n", 875 | "├── data JSON file of class-parameters (dictionary)\n", 876 | "├── parameter_list JSON file of model parameters and their ordering (list)\n", 877 | "├── parameters Bytes from numpy.savez (a zip file of the numpy arrays). ...\n", 878 | " ├── ... Being a zip-archive itself, this object can also be opened ...\n", 879 | " ├── ... 
as a zip-archive and browsed.\n", 880 | "```" 881 | ] 882 | }, 883 | { 884 | "cell_type": "markdown", 885 | "metadata": { 886 | "colab_type": "text", 887 | "id": "QWAcc8RFRUoU" 888 | }, 889 | "source": [ 890 | "## Save and find " 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "execution_count": null, 896 | "metadata": { 897 | "colab": {}, 898 | "colab_type": "code", 899 | "id": "4tcQxzSCRUoV" 900 | }, 901 | "outputs": [], 902 | "source": [ 903 | "# Create save dir\n", 904 | "save_dir = \"/tmp/gym/\"\n", 905 | "os.makedirs(save_dir, exist_ok=True)\n", 906 | "\n", 907 | "model = PPO(\"MlpPolicy\", \"Pendulum-v1\", verbose=0).learn(8000)\n", 908 | "model.save(save_dir + \"/PPO_tutorial\")" 909 | ] 910 | }, 911 | { 912 | "cell_type": "code", 913 | "execution_count": null, 914 | "metadata": { 915 | "colab": {}, 916 | "colab_type": "code", 917 | "id": "rGaMNz4HRUoX" 918 | }, 919 | "outputs": [], 920 | "source": [ 921 | "!ls /tmp/gym/PPO_tutorial*" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": null, 927 | "metadata": { 928 | "colab": {}, 929 | "colab_type": "code", 930 | "id": "gYY3nQyyRUoa" 931 | }, 932 | "outputs": [], 933 | "source": [ 934 | "import zipfile\n", 935 | "\n", 936 | "archive = zipfile.ZipFile(\"/tmp/gym/PPO_tutorial.zip\", \"r\")\n", 937 | "for f in archive.filelist:\n", 938 | " print(f.filename)" 939 | ] 940 | }, 941 | { 942 | "attachments": {}, 943 | "cell_type": "markdown", 944 | "metadata": { 945 | "colab_type": "text", 946 | "id": "cPKkkTvjRUo2" 947 | }, 948 | "source": [ 949 | "## Exporting saved models\n", 950 | "\n", 951 | "And finally some further reading for those who want to export to tensorflowJS or Java.\n", 952 | "\n", 953 | "https://stable-baselines3.readthedocs.io/en/master/guide/export.html" 954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": {}, 959 | "source": [] 960 | } 961 | ], 962 | "metadata": { 963 | "colab": { 964 | "collapsed_sections": [], 965 | "include_colab_link": 
true, 966 | "name": "2_gym_wrappers_saving_loading.ipynb", 967 | "provenance": [] 968 | }, 969 | "kernelspec": { 970 | "display_name": "Python 3 (ipykernel)", 971 | "language": "python", 972 | "name": "python3" 973 | }, 974 | "language_info": { 975 | "codemirror_mode": { 976 | "name": "ipython", 977 | "version": 3 978 | }, 979 | "file_extension": ".py", 980 | "mimetype": "text/x-python", 981 | "name": "python", 982 | "nbconvert_exporter": "python", 983 | "pygments_lexer": "ipython3", 984 | "version": "3.10.9" 985 | }, 986 | "vscode": { 987 | "interpreter": { 988 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 989 | } 990 | } 991 | }, 992 | "nbformat": 4, 993 | "nbformat_minor": 4 994 | } 995 | -------------------------------------------------------------------------------- /3_multiprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "KnPeMWYi0vAx" 18 | }, 19 | "source": [ 20 | "# Stable Baselines3 Tutorial - Multiprocessing of environments\n", 21 | "\n", 22 | "Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/\n", 23 | "\n", 24 | "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n", 25 | "\n", 26 | "Documentation: https://stable-baselines3.readthedocs.io/en/master/\n", 27 | "\n", 28 | "SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib\n", 29 | "\n", 30 | "RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo\n", 31 | "\n", 32 | "\n", 33 | "## Introduction\n", 34 | "\n", 35 | "In this notebook, you will learn how to use *Vectorized Environments* (aka multiprocessing) to make training faster. 
You will also see that this speedup comes at the cost of sample efficiency.\n", 36 | "\n", 37 | "## Install Dependencies and Stable Baselines3 Using Pip" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# for autoformatting\n", 47 | "# %load_ext jupyter_black" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "colab": {}, 55 | "colab_type": "code", 56 | "id": "ClRYNMkVvpUX" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "!apt install swig\n", 61 | "!pip install \"stable-baselines3[extra]>=2.0.0a4\"" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "colab_type": "text", 68 | "id": "OQunADhw1EXX" 69 | }, 70 | "source": [ 71 | "## Vectorized Environments and Imports\n", 72 | "\n", 73 | "[Vectorized Environments](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html) are a method for stacking multiple independent environments into a single environment. Instead of training an RL agent on 1 environment per step, it allows us to train it on n environments per step. This provides two benefits:\n", 74 | "* Agent experience can be collected more quickly\n", 75 | "* The experience will contain a more diverse range of states, which usually improves exploration\n", 76 | "\n", 77 | "Stable-Baselines3 provides two types of vectorized environments:\n", 78 | "- SubprocVecEnv, which runs each environment in a separate process\n", 79 | "- DummyVecEnv, which runs all environments sequentially in the same process\n", 80 | "\n", 81 | "In practice, for lightweight environments such as CartPole, DummyVecEnv is often faster than SubprocVecEnv because of the communication overhead between subprocesses."
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "colab": {}, 89 | "colab_type": "code", 90 | "id": "AvO5BGrVv2Rk" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "import time\n", 95 | "import numpy as np\n", 96 | "import matplotlib.pyplot as plt\n", 97 | "\n", 98 | "%matplotlib inline\n", 99 | "\n", 100 | "import gymnasium as gym\n", 101 | "\n", 102 | "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", 103 | "from stable_baselines3.common.utils import set_random_seed\n", 104 | "from stable_baselines3 import PPO, A2C" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "colab_type": "text", 111 | "id": "JcdG_UZS1-yO" 112 | }, 113 | "source": [ 114 | "Import the evaluation function" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "colab": {}, 122 | "colab_type": "code", 123 | "id": "NHslfVkuwALj" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "from stable_baselines3.common.evaluation import evaluate_policy" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": { 133 | "colab_type": "text", 134 | "id": "WWsIT2vP2FzB" 135 | }, 136 | "source": [ 137 | "## Define an environment function\n", 138 | "\n", 139 | "The multiprocessing implementation requires a function that can be called inside the process to instantiate a gym env." 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "colab": {}, 147 | "colab_type": "code", 148 | "id": "6S95WiPGwF6z" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "def make_env(env_id, rank, seed=0):\n", 153 | " \"\"\"\n", 154 | " Utility function for multiprocessed env.\n", 155 | "\n", 156 | " :param env_id: (str) the environment ID\n", 157 | " :param seed: (int) the initial seed for RNG\n", 158 | " :param rank: (int) index of the subprocess\n", 159 | " \"\"\"\n", 160 | "\n", 161 | " def 
_init():\n", 162 | " env = gym.make(env_id)\n", 163 | " # use a seed for reproducibility\n", 164 | " # Important: use a different seed for each environment\n", 165 | " # otherwise they would generate the same experiences\n", 166 | " env.reset(seed=seed + rank)\n", 167 | " return env\n", 168 | "\n", 169 | " set_random_seed(seed)\n", 170 | " return _init" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "colab_type": "text", 177 | "id": "9-QID4O2bd7c" 178 | }, 179 | "source": [ 180 | "Stable-Baselines3 also provides a helper to create vectorized environments directly:" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "colab": {}, 188 | "colab_type": "code", 189 | "id": "Gk7Ukbqlbl-i" 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "from stable_baselines3.common.env_util import make_vec_env" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": { 199 | "colab_type": "text", 200 | "id": "DJUP0PQi2WEE" 201 | }, 202 | "source": [ 203 | "## Define a few constants (feel free to try out other environments and algorithms)\n", 204 | "We will be using the CartPole environment: [https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/)\n", 205 | "\n", 206 | "![Cartpole](https://cdn-images-1.medium.com/max/1143/1*h4WTQNVIsvMXJTCpXm_TAw.gif)\n", 207 | "\n" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "colab": {}, 215 | "colab_type": "code", 216 | "id": "bmdNV8UVwTht" 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "env_id = \"CartPole-v1\"\n", 221 | "# The different numbers of processes that will be used\n", 222 | "PROCESSES_TO_TEST = [1, 2, 4, 8, 16]\n", 223 | "NUM_EXPERIMENTS = 3 # RL algorithms can often be unstable, so we run several experiments (see https://arxiv.org/abs/1709.06560)\n", 224 | "TRAIN_STEPS = 5000\n", 225 | "# Number of 
episodes for evaluation\n", 226 
| "EVAL_EPS = 20\n", 227 | "ALGO = A2C\n", 228 | "\n", 229 | "# We will create one environment to evaluate the agent on\n", 230 | "eval_env = gym.make(env_id)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "colab_type": "text", 237 | "id": "y08bJGxj2ezh" 238 | }, 239 | "source": [ 240 | "## Iterate through the different numbers of processes\n", 241 | "\n", 242 | "For each number of processes, several experiments are run.\n", 243 | "This may take a couple of minutes." 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "colab": { 251 | "base_uri": "https://localhost:8080/", 252 | "height": 106 253 | }, 254 | "colab_type": "code", 255 | "id": "kcYpsA8ExB9T", 256 | "outputId": "11e28f5c-c3d3-4669-ab4b-acff3e710ac1" 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "reward_averages = []\n", 261 | "reward_std = []\n", 262 | "training_times = []\n", 263 | "total_procs = 0\n", 264 | "for n_procs in PROCESSES_TO_TEST:\n", 265 | " total_procs += n_procs\n", 266 | " print(f\"Running for n_procs = {n_procs}\")\n", 267 | " if n_procs == 1:\n", 268 | " # if there is only one process, there is no need to use multiprocessing\n", 269 | " train_env = DummyVecEnv([lambda: gym.make(env_id)])\n", 270 | " else:\n", 271 | " # Here we use the \"fork\" method for launching the processes, more information is available in the doc\n", 272 | " # This is roughly equivalent to make_vec_env(env_id, n_envs=n_procs, vec_env_cls=SubprocVecEnv, vec_env_kwargs=dict(start_method='fork'))\n", 273 | " train_env = SubprocVecEnv(\n", 274 | " [make_env(env_id, i + total_procs) for i in range(n_procs)],\n", 275 | " start_method=\"fork\",\n", 276 | " )\n", 277 | "\n", 278 | " rewards = []\n", 279 | " times = []\n", 280 | "\n", 281 | " for experiment in range(NUM_EXPERIMENTS):\n", 282 | " # it is recommended to run several experiments due to variability in results\n", 283 | " train_env.reset()\n", 284 | " model = 
ALGO(\"MlpPolicy\", train_env, verbose=0)\n", 285 | " start = time.time()\n", 286 | " model.learn(total_timesteps=TRAIN_STEPS)\n", 287 | " times.append(time.time() - start)\n", 288 | " mean_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes=EVAL_EPS)\n", 289 | " rewards.append(mean_reward)\n", 290 | " # Important: when using subprocesses, don't forget to close them\n", 291 | " # otherwise, you may have memory issues when running a lot of experiments\n", 292 | " train_env.close()\n", 293 | " reward_averages.append(np.mean(rewards))\n", 294 | " reward_std.append(np.std(rewards))\n", 295 | " training_times.append(np.mean(times))" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": { 301 | "colab_type": "text", 302 | "id": "2z5paN1q3AaC" 303 | }, 304 | "source": [ 305 | "## Plot the results" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": { 312 | "id": "CGnZ8SccKG4D" 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "def plot_training_results(training_steps_per_second, reward_averages, reward_std):\n", 317 | " \"\"\"\n", 318 | " Utility function for plotting the results of training\n", 319 | "\n", 320 | " :param training_steps_per_second: List[float]\n", 321 | " :param reward_averages: List[float]\n", 322 | " :param reward_std: List[float]\n", 323 | " \"\"\"\n", 324 | " plt.figure(figsize=(9, 4))\n", 325 | " plt.subplots_adjust(wspace=0.5)\n", 326 | " plt.subplot(1, 2, 1)\n", 327 | " plt.errorbar(\n", 328 | " PROCESSES_TO_TEST,\n", 329 | " reward_averages,\n", 330 | " yerr=reward_std,\n", 331 | " capsize=2,\n", 332 | " c=\"k\",\n", 333 | " marker=\"o\",\n", 334 | " )\n", 335 | " plt.xlabel(\"Processes\")\n", 336 | " plt.ylabel(\"Average return\")\n", 337 | " plt.subplot(1, 2, 2)\n", 338 | " plt.bar(range(len(PROCESSES_TO_TEST)), training_steps_per_second)\n", 339 | " plt.xticks(range(len(PROCESSES_TO_TEST)), PROCESSES_TO_TEST)\n", 340 | " plt.xlabel(\"Processes\")\n", 341 | " 
plt.ylabel(\"Training steps per second\")" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "colab": { 349 | "base_uri": "https://localhost:8080/", 350 | "height": 279 351 | }, 352 | "colab_type": "code", 353 | "id": "fPWfc96JxT-k", 354 | "outputId": "df2b74d5-61ea-487b-9364-8ec33b4e0624" 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "training_steps_per_second = [TRAIN_STEPS / t for t in training_times]\n", 359 | "\n", 360 | "plot_training_results(training_steps_per_second, reward_averages, reward_std)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "colab_type": "text", 367 | "id": "R5xE8EX63PO9" 368 | }, 369 | "source": [ 370 | "## Sample efficiency vs wall clock time trade-off\n", 371 | "There is clearly a trade-off between sample efficiency, diverse experience and wall clock time. Let's try getting the best performance in a fixed amount of time, say 10 seconds per experiment" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "colab": { 379 | "base_uri": "https://localhost:8080/", 380 | "height": 106 381 | }, 382 | "outputId": "2cdda2c8-e2f0-401b-a6ea-99c80d91fe8e" 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "SECONDS_PER_EXPERIMENT = 10\n", 387 | "steps_per_experiment = [int(SECONDS_PER_EXPERIMENT * fps) for fps in training_steps_per_second]\n", 388 | "reward_averages = []\n", 389 | "reward_std = []\n", 390 | "training_times = []\n", 391 | "\n", 392 | "for n_procs, train_steps in zip(PROCESSES_TO_TEST, steps_per_experiment):\n", 393 | " total_procs += n_procs\n", 394 | " print(f\"Running for n_procs = {n_procs} for steps = {train_steps}\")\n", 395 | " if n_procs == 1:\n", 396 | " # if there is only one process, there is no need to use multiprocessing\n", 397 | " train_env = DummyVecEnv([lambda: gym.make(env_id)])\n", 398 | " else:\n", 399 | " train_env = SubprocVecEnv([make_env(env_id, 
i+total_procs) for i in range(n_procs)], start_method=\"spawn\")\n", 400 | " # Alternatively, you can use a DummyVecEnv if the communication overhead is the bottleneck\n", 401 | " # train_env = DummyVecEnv([make_env(env_id, i+total_procs) for i in range(n_procs)])\n", 402 | "\n", 403 | " rewards = []\n", 404 | " times = []\n", 405 | "\n", 406 | " for experiment in range(NUM_EXPERIMENTS):\n", 407 | " # it is recommended to run several experiments due to variability in results\n", 408 | " train_env.reset()\n", 409 | " model = ALGO(\"MlpPolicy\", train_env, verbose=0)\n", 410 | " start = time.time()\n", 411 | " model.learn(total_timesteps=train_steps)\n", 412 | " times.append(time.time() - start)\n", 413 | " mean_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes=EVAL_EPS)\n", 414 | " rewards.append(mean_reward)\n", 415 | "\n", 416 | " train_env.close()\n", 417 | " reward_averages.append(np.mean(rewards))\n", 418 | " reward_std.append(np.std(rewards))\n", 419 | " training_times.append(np.mean(times))\n" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": { 425 | "colab_type": "text", 426 | "id": "G7a7ZiVw5A11" 427 | }, 428 | "source": [ 429 | "## Plot the results" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "colab": { 437 | "base_uri": "https://localhost:8080/", 438 | "height": 297 439 | }, 440 | "colab_type": "code", 441 | "id": "EQXJ1hI46DVB", 442 | "outputId": "d5b47716-3551-47b1-f690-16d726e89a05" 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "training_steps_per_second = [s / t for s, t in zip(steps_per_experiment, training_times)]\n", 447 | "\n", 448 | "plot_training_results(training_steps_per_second, reward_averages, reward_std)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": { 454 | "colab_type": "text", 455 | "id": "0FcOcVf5rY3C" 456 | }, 457 | "source": [ 458 | "## DummyVecEnv vs SubprocVecEnv" 459 | ] 460 | }, 461 | { 462 | 
"cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "colab": { 466 | "base_uri": "https://localhost:8080/", 467 | "height": 106 468 | }, 469 | "colab_type": "code", 470 | "id": "MebaTHQvqhoH", 471 | "outputId": "637e9934-e6b1-4ce3-a401-c20f23437e67" 472 | }, 473 | "outputs": [], 474 | "source": [ 475 | "reward_averages = []\n", 476 | "reward_std = []\n", 477 | "training_times = []\n", 478 | "total_procs = 0\n", 479 | "for n_procs in PROCESSES_TO_TEST:\n", 480 | " total_procs += n_procs\n", 481 | " print(f'Running for n_procs = {n_procs}')\n", 482 | " # Here we are using only one process even for n_envs > 1\n", 483 | " # this is roughly equivalent to DummyVecEnv([make_env(env_id, i + total_procs) for i in range(n_procs)])\n", 484 | " train_env = make_vec_env(env_id, n_envs=n_procs)\n", 485 | "\n", 486 | " rewards = []\n", 487 | " times = []\n", 488 | "\n", 489 | " for experiment in range(NUM_EXPERIMENTS):\n", 490 | " # it is recommended to run several experiments due to variability in results\n", 491 | " train_env.reset()\n", 492 | " model = ALGO(\"MlpPolicy\", train_env, verbose=0)\n", 493 | " start = time.time()\n", 494 | " model.learn(total_timesteps=TRAIN_STEPS)\n", 495 | " times.append(time.time() - start)\n", 496 | " mean_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes=EVAL_EPS)\n", 497 | " rewards.append(mean_reward)\n", 498 | "\n", 499 | " train_env.close()\n", 500 | " reward_averages.append(np.mean(rewards))\n", 501 | " reward_std.append(np.std(rewards))\n", 502 | " training_times.append(np.mean(times))" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": null, 508 | "metadata": { 509 | "colab": { 510 | "base_uri": "https://localhost:8080/", 511 | "height": 297 512 | }, 513 | "colab_type": "code", 514 | "id": "kmMr_c1hqmoi", 515 | "outputId": "cc174025-ed75-4897-f745-c08944493366" 516 | }, 517 | "outputs": [], 518 | "source": [ 519 | "training_steps_per_second = [TRAIN_STEPS / t for t in 
training_times]\n", 520 | "\n", 521 | "plot_training_results(training_steps_per_second, reward_averages, reward_std)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": { 527 | "colab_type": "text", 528 | "id": "e9PNbT35spZW" 529 | }, 530 | "source": [ 531 | "### What's happening?\n", 532 | "\n", 533 | "It seems that having only one process for n environments is faster in our case.\n", 534 | "Here, the bottleneck does not come from the environment computation, but from synchronisation and communication between processes. To learn more about that problem, you can start [here](https://github.com/hill-a/stable-baselines/issues/322#issuecomment-492202915)." 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "colab_type": "text", 541 | "id": "GlcJPYN-6ebp" 542 | }, 543 | "source": [ 544 | "## Conclusions\n", 545 | "This notebook has highlighted some of the pros and cons of multiprocessing. It is worth mentioning that Colab notebooks only provide two CPU cores, so we do not see linear scaling of the environment FPS with the number of processes. 
State-of-the-art deep RL research has scaled parallel processing to tens of thousands of CPU cores, for example [OpenAI RAPID](https://openai.com/blog/how-to-train-your-openai-five/) and [IMPALA](https://arxiv.org/abs/1802.01561).\n", 546 | "\n", 547 | "Do you think this direction of research is transferable to real-world robots / intelligent agents?\n", 548 | "\n", 549 | "Things to try:\n", 550 | "* Another algorithm / environment.\n", 551 | "* Increase the number of experiments.\n", 552 | "* Train for more iterations.\n" 553 | ] 554 | } 555 | ], 556 | "metadata": { 557 | "colab": { 558 | "collapsed_sections": [], 559 | "include_colab_link": true, 560 | "name": "3_multiprocessing.ipynb", 561 | "provenance": [] 562 | }, 563 | "kernelspec": { 564 | "display_name": "Python 3 (ipykernel)", 565 | "language": "python", 566 | "name": "python3" 567 | }, 568 | "language_info": { 569 | "codemirror_mode": { 570 | "name": "ipython", 571 | "version": 3 572 | }, 573 | "file_extension": ".py", 574 | "mimetype": "text/x-python", 575 | "name": "python", 576 | "nbconvert_exporter": "python", 577 | "pygments_lexer": "ipython3", 578 | "version": "3.10.9" 579 | }, 580 | "vscode": { 581 | "interpreter": { 582 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 583 | } 584 | } 585 | }, 586 | "nbformat": 4, 587 | "nbformat_minor": 4 588 | } 589 | -------------------------------------------------------------------------------- /4_callbacks_hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "i8lIXBiHRYb6" 18 | }, 19 | "source": [ 20 | "# Stable Baselines3 Tutorial - Callbacks and hyperparameter tuning\n", 21 | "\n", 22 | "Github repo: 
https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/\n", 23 | "\n", 24 | "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n", 25 | "\n", 26 | "Documentation: https://stable-baselines3.readthedocs.io/en/master/\n", 27 | "\n", 28 | "SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib\n", 29 | "\n", 30 | "RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo\n", 31 | "\n", 32 | "\n", 33 | "## Introduction\n", 34 | "\n", 35 | "In this notebook, you will learn how to use [Callbacks](https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html), which let you do monitoring, auto-saving, model manipulation, progress bars, and more.\n", 36 | "\n", 37 | "\n", 38 | "You will also see that finding good hyperparameters is key to success in RL.\n", 39 | "\n", 40 | "## Install Dependencies and Stable Baselines3 Using Pip" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# for autoformatting\n", 50 | "# %load_ext jupyter_black" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "colab": {}, 58 | "colab_type": "code", 59 | "id": "owKXXp8rRZI7" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "!apt install swig\n", 64 | "!pip install \"stable-baselines3[extra]>=2.0.0a4\"" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "colab": {}, 72 | "colab_type": "code", 73 | "id": "18ivrnsaSWUn" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "import gymnasium as gym\n", 78 | "from stable_baselines3 import A2C, SAC, PPO, TD3" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "colab_type": "text", 85 | "id": "PytOtL9GdmrE" 86 | }, 87 | "source": [ 88 | "# The importance of hyperparameter tuning\n", 89 | "\n", 90 | "When compared with Supervised Learning, Deep Reinforcement Learning is far more sensitive to the 
choice of hyperparameters such as learning rate, number of neurons, number of layers, optimizer, etc.\n", 91 | "A poor choice of hyperparameters can lead to poor or unstable convergence. This challenge is compounded by the variability in performance across random seeds (used to initialize the network weights and the environment).\n", 92 | "\n", 93 | "Here we demonstrate this on a toy example: the [Soft Actor Critic](https://arxiv.org/abs/1801.01290) algorithm applied to the Pendulum environment. Note the change in performance between the default and \"tuned\" parameters." 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "colab": {}, 101 | "colab_type": "code", 102 | "id": "w5oVvYHwdnYv" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "import numpy as np\n", 107 | "\n", 108 | "from stable_baselines3.common.evaluation import evaluate_policy" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "colab": { 116 | "base_uri": "https://localhost:8080/", 117 | "height": 53 118 | }, 119 | "colab_type": "code", 120 | "id": "-a0v3fgwe54j", 121 | "outputId": "52f15317-d898-4aae-cd53-893e928909b3" 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "eval_env = gym.make(\"Pendulum-v1\")" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "colab": {}, 133 | "colab_type": "code", 134 | "id": "5WRR7kmIeqEB" 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "default_model = SAC(\n", 139 | " \"MlpPolicy\",\n", 140 | " \"Pendulum-v1\",\n", 141 | " verbose=1,\n", 142 | " seed=0,\n", 143 | " batch_size=64,\n", 144 | " policy_kwargs=dict(net_arch=[64, 64]),\n", 145 | ").learn(8000)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/", 154 | "height": 35 155 | }, 156 | "colab_type": "code", 157 | "id": 
"jQbDcbEheqWj", 158 | "outputId": "4f664eeb-0374-4db0-c29e-1b8cd131b22b" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "mean_reward, std_reward = evaluate_policy(default_model, eval_env, n_eval_episodes=100)\n", 163 | "print(f\"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}\")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "colab": {}, 171 | "colab_type": "code", 172 | "id": "smMdkZnvfL1g" 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "tuned_model = SAC(\n", 177 | " \"MlpPolicy\",\n", 178 | " \"Pendulum-v1\",\n", 179 | " batch_size=256,\n", 180 | " verbose=1,\n", 181 | " policy_kwargs=dict(net_arch=[256, 256]),\n", 182 | " seed=0,\n", 183 | ").learn(8000)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "colab": { 191 | "base_uri": "https://localhost:8080/", 192 | "height": 35 193 | }, 194 | "colab_type": "code", 195 | "id": "DN05_Io8fMAr", 196 | "outputId": "a009b1ea-17f7-4f6f-b021-35cf7356b2ce" 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "mean_reward, std_reward = evaluate_policy(tuned_model, eval_env, n_eval_episodes=100)\n", 201 | "print(f\"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}\")" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "colab_type": "text", 208 | "id": "pi9IwxBYVMl8" 209 | }, 210 | "source": [ 211 | "Exploring hyperparameter tuning is out of the scope (and time schedule) of this tutorial. 
However, you need to know that we provide tuned hyperparameters in the [rl zoo](https://github.com/DLR-RM/rl-baselines3-zoo) as well as automatic hyperparameter optimization using [Optuna](https://github.com/pfnet/optuna).\n" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "colab_type": "text", 218 | "id": "irHk8FXdRUnw" 219 | }, 220 | "source": [ 221 | "# Callbacks\n", 222 | "\n", 223 | "\n", 224 | "Please read the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html). Although Stable-Baselines3 provides you with a callback collection (e.g. for creating checkpoints or for evaluation), we are going to re-implement some so you can get a good understanding of how they work.\n", 225 | "\n", 226 | "To build a custom callback, you need to create a class that derives from `BaseCallback`. This will give you access to events (`_on_training_start()`, `_on_step()`) and useful variables (like `self.model` for the RL model).\n", 227 | "\n", 228 | "`_on_step()` must return a boolean: `True` to continue training, `False` to stop it early.\n", 229 | "\n", 230 | "Thanks to access to the model's variables, in particular `self.model`, we can even change the model's parameters without halting the training or changing the model's code."
231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "colab": {}, 238 | "colab_type": "code", 239 | "id": "uE30k2i7kohh" 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "from stable_baselines3.common.callbacks import BaseCallback" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "colab": {}, 251 | "colab_type": "code", 252 | "id": "wjRvJ8zBftL3" 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "class CustomCallback(BaseCallback):\n", 257 | " \"\"\"\n", 258 | " A custom callback that derives from ``BaseCallback``.\n", 259 | "\n", 260 | " :param verbose: (int) Verbosity level 0: no output 1: info 2: debug\n", 261 | " \"\"\"\n", 262 | "\n", 263 | " def __init__(self, verbose=0):\n", 264 | " super().__init__(verbose)\n", 265 | " # Those variables will be accessible in the callback\n", 266 | " # (they are defined in the base class)\n", 267 | " # The RL model\n", 268 | " # self.model = None # type: BaseAlgorithm\n", 269 | " # An alias for self.model.get_env(), the environment used for training\n", 270 | " # self.training_env = None # type: VecEnv\n", 271 | " # Number of times the callback was called\n", 272 | " # self.n_calls = 0 # type: int\n", 273 | " # self.num_timesteps = 0 # type: int\n", 274 | " # local and global variables\n", 275 | " # self.locals = None # type: Dict[str, Any]\n", 276 | " # self.globals = None # type: Dict[str, Any]\n", 277 | " # The logger object, used to report things in the terminal\n", 278 | " # self.logger = None # type: logger.Logger\n", 279 | " # # Sometimes, for event callbacks, it is useful\n", 280 | " # # to have access to the parent object\n", 281 | " # self.parent = None # type: Optional[BaseCallback]\n", 282 | "\n", 283 | " def _on_training_start(self) -> None:\n", 284 | " \"\"\"\n", 285 | " This method is called before the first rollout starts.\n", 286 | " \"\"\"\n", 287 | " pass\n", 288 | 
"\n", 289 | " def _on_rollout_start(self) -> None:\n", 290 | " \"\"\"\n", 291 | " A rollout is the collection of environment interactions\n", 292 | " using the current policy.\n", 293 | " This event is triggered before collecting new samples.\n", 294 | " \"\"\"\n", 295 | " pass\n", 296 | "\n", 297 | " def _on_step(self) -> bool:\n", 298 | " \"\"\"\n", 299 | " This method will be called by the model after each call to `env.step()`.\n", 300 | "\n", 301 | " For a child callback (of an `EventCallback`), this will be called\n", 302 | " when the event is triggered.\n", 303 | "\n", 304 | " :return: (bool) If the callback returns False, training is aborted early.\n", 305 | " \"\"\"\n", 306 | " return True\n", 307 | "\n", 308 | " def _on_rollout_end(self) -> None:\n", 309 | " \"\"\"\n", 310 | " This event is triggered before updating the policy.\n", 311 | " \"\"\"\n", 312 | " pass\n", 313 | "\n", 314 | " def _on_training_end(self) -> None:\n", 315 | " \"\"\"\n", 316 | " This event is triggered before exiting the `learn()` method.\n", 317 | " \"\"\"\n", 318 | " pass" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": { 324 | "colab_type": "text", 325 | "id": "OqpPtxaCfynB" 326 | }, 327 | "source": [ 328 | "Here we have a simple callback that can only be called twice:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": { 335 | "colab": {}, 336 | "colab_type": "code", 337 | "id": "7ILY0AkFfzPJ" 338 | }, 339 | "outputs": [], 340 | "source": [ 341 | "class SimpleCallback(BaseCallback):\n", 342 | " \"\"\"\n", 343 | " A simple callback that can only be called twice.\n", 344 | "\n", 345 | " :param verbose: (int) Verbosity level 0: no output 1: info 2: debug\n", 346 | " \"\"\"\n", 347 | "\n", 348 | " def __init__(self, verbose=0):\n", 349 | " super().__init__(verbose)\n", 350 | " self._called = False\n", 351 | "\n", 352 | " def _on_step(self):\n", 353 | " if not self._called:\n", 354 | " 
print(\"callback - first call\")\n", 355 | " self._called = True\n", 356 | " return True # returns True, training continues.\n", 357 | " print(\"callback - second call\")\n", 358 | " return False # returns False, training stops." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": { 365 | "colab": {}, 366 | "colab_type": "code", 367 | "id": "5gTXaNLARUnw" 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "model = SAC(\"MlpPolicy\", \"Pendulum-v1\", verbose=1)\n", 372 | "model.learn(8000, callback=SimpleCallback())" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": { 378 | "colab_type": "text", 379 | "id": "adsKMvDkRUn0" 380 | }, 381 | "source": [ 382 | "## First example: Auto-saving the best model\n", 383 | "In RL, it is quite useful to keep a clean copy of the model while training, because training can degrade and end up with a worse policy than one seen earlier. This is a typical use case for callbacks, as they can call the save function of the model and observe the training over time.\n", 384 | "\n", 385 | "Using the monitoring wrapper, we can save statistics of the environment, and use them to determine the mean training reward.\n", 386 | "This allows us to save the best model while training.\n", 387 | "\n", 388 | "Note that this is not the proper way to evaluate an RL agent: you should create a test environment and evaluate the agent's performance in the callback (cf. `EvalCallback`). For simplicity, we will be using the training reward as a proxy."
389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "colab": {}, 396 | "colab_type": "code", 397 | "id": "IDI3lKTiiKP9" 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "import os\n", 402 | "\n", 403 | "import numpy as np\n", 404 | "\n", 405 | "from stable_baselines3.common.monitor import Monitor\n", 406 | "from stable_baselines3.common.vec_env import DummyVecEnv\n", 407 | "from stable_baselines3.common.env_util import make_vec_env\n", 408 | "from stable_baselines3.common.results_plotter import load_results, ts2xy" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": { 415 | "colab": {}, 416 | "colab_type": "code", 417 | "id": "nzMHj7r3h78m" 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "class SaveOnBestTrainingRewardCallback(BaseCallback):\n", 422 | " \"\"\"\n", 423 | " Callback for saving a model (the check is done every ``check_freq`` steps)\n", 424 | " based on the training reward (in practice, we recommend using ``EvalCallback``).\n", 425 | "\n", 426 | " :param check_freq: (int) Check the training reward every ``check_freq`` calls of the callback.\n", 427 | " :param log_dir: (str) Path to the folder where the model will be saved.\n", 428 | " It must contain the file created by the ``Monitor`` wrapper.\n", 429 | " :param verbose: (int)\n", 430 | " \"\"\"\n", 431 | "\n", 432 | " def __init__(self, check_freq, log_dir, verbose=1):\n", 433 | " super().__init__(verbose)\n", 434 | " self.check_freq = check_freq\n", 435 | " self.log_dir = log_dir\n", 436 | " self.save_path = os.path.join(log_dir, \"best_model\")\n", 437 | " self.best_mean_reward = -np.inf\n", 438 | "\n", 439 | " def _init_callback(self) -> None:\n", 440 | " # Create folder if needed\n", 441 | " if self.save_path is not None:\n", 442 | " os.makedirs(self.save_path, exist_ok=True)\n", 443 | "\n", 444 | " def _on_step(self) -> bool:\n", 445 | " if self.n_calls % self.check_freq == 0:\n", 446 | "\n", 447 | " # Retrieve training reward\n", 448 | " x, y = 
ts2xy(load_results(self.log_dir), \"timesteps\")\n", 449 | " if len(x) > 0:\n", 450 | " # Mean training reward over the last 100 episodes\n", 451 | " mean_reward = np.mean(y[-100:])\n", 452 | " if self.verbose > 0:\n", 453 | " print(\"Num timesteps: {}\".format(self.num_timesteps))\n", 454 | " print(\n", 455 | " \"Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}\".format(\n", 456 | " self.best_mean_reward, mean_reward\n", 457 | " )\n", 458 | " )\n", 459 | "\n", 460 | " # New best model, you could save the agent here\n", 461 | " if mean_reward > self.best_mean_reward:\n", 462 | " self.best_mean_reward = mean_reward\n", 463 | " # Example for saving best model\n", 464 | " if self.verbose > 0:\n", 465 | " print(\"Saving new best model at {} timesteps\".format(x[-1]))\n", 466 | " print(\"Saving new best model to {}.zip\".format(self.save_path))\n", 467 | " self.model.save(self.save_path)\n", 468 | "\n", 469 | " return True" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "colab": {}, 477 | "colab_type": "code", 478 | "id": "1TuYLBEaRUn0" 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "# Create log dir\n", 483 | "log_dir = \"/tmp/gym/\"\n", 484 | "os.makedirs(log_dir, exist_ok=True)\n", 485 | "\n", 486 | "# Create and wrap the environment\n", 487 | "env = make_vec_env(\"CartPole-v1\", n_envs=1, monitor_dir=log_dir)\n", 488 | "# it is equivalent to:\n", 489 | "# env = gym.make('CartPole-v1')\n", 490 | "# env = Monitor(env, log_dir)\n", 491 | "# env = DummyVecEnv([lambda: env])\n", 492 | "\n", 493 | "# Create Callback\n", 494 | "callback = SaveOnBestTrainingRewardCallback(check_freq=20, log_dir=log_dir, verbose=1)\n", 495 | "\n", 496 | "model = A2C(\"MlpPolicy\", env, verbose=0)\n", 497 | "model.learn(total_timesteps=5000, callback=callback)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "colab_type": "text", 504 | "id": "Mx18FkEORUn3" 505 | }, 506 | 
"source": [ 507 | "## Second example: Realtime plotting of performance\n", 508 | "While training, it is sometimes useful to see how the training progresses over time in terms of episodic reward.\n", 509 | "For this, Stable-Baselines3 has [Tensorboard support](https://stable-baselines3.readthedocs.io/en/master/guide/tensorboard.html); however, this can be very cumbersome, especially in terms of disk space usage.\n", 510 | "\n", 511 | "**NOTE: Unfortunately live plotting does not work out of the box on Google Colab**\n", 512 | "\n", 513 | "Here, we can use a callback again to plot the episodic reward in real time, using the monitoring wrapper:" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": { 520 | "colab": {}, 521 | "colab_type": "code", 522 | "id": "c0Bu1HWKRUn4" 523 | }, 524 | "outputs": [], 525 | "source": [ 526 | "import matplotlib.pyplot as plt\n", 527 | "import numpy as np\n", 528 | "%matplotlib notebook\n", 529 | "\n", 530 | "\n", 531 | "class PlottingCallback(BaseCallback):\n", 532 | " \"\"\"\n", 533 | " Callback for plotting the performance in realtime.\n", 534 | "\n", 535 | " :param verbose: (int)\n", 536 | " \"\"\"\n", 537 | " def __init__(self, verbose=1):\n", 538 | " super().__init__(verbose)\n", 539 | " self._plot = None\n", 540 | "\n", 541 | " def _on_step(self) -> bool:\n", 542 | " # get the monitor's data\n", 543 | " x, y = ts2xy(load_results(log_dir), 'timesteps')\n", 544 | " if self._plot is None: # make the plot\n", 545 | " plt.ion()\n", 546 | " fig = plt.figure(figsize=(6,3))\n", 547 | " ax = fig.add_subplot(111)\n", 548 | " line, = ax.plot(x, y)\n", 549 | " self._plot = (line, ax, fig)\n", 550 | " plt.show()\n", 551 | " else: # update and rescale the plot\n", 552 | " self._plot[0].set_data(x, y)\n", 553 | " self._plot[-2].relim()\n", 554 | " self._plot[-2].set_xlim([self.locals[\"total_timesteps\"] * -0.02, \n", 555 | " self.locals[\"total_timesteps\"] * 1.02])\n", 556 | "
self._plot[-2].autoscale_view(True,True,True)\n", 557 | " self._plot[-1].canvas.draw()\n", 558 | " return True # returning None would be treated as False and stop training\n", 559 | "# Create log dir\n", 560 | "log_dir = \"/tmp/gym/\"\n", 561 | "os.makedirs(log_dir, exist_ok=True)\n", 562 | "\n", 563 | "# Create and wrap the environment\n", 564 | "env = make_vec_env('MountainCarContinuous-v0', n_envs=1, monitor_dir=log_dir)\n", 565 | "\n", 566 | "plotting_callback = PlottingCallback()\n", 567 | " \n", 568 | "model = PPO('MlpPolicy', env, verbose=0)\n", 569 | "model.learn(10000, callback=plotting_callback)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": { 575 | "colab_type": "text", 576 | "id": "49RVX7ieRUn7" 577 | }, 578 | "source": [ 579 | "## Third example: Progress bar\n", 580 | "Quality-of-life improvements are always welcome when developing and using RL. Here, we use [tqdm](https://tqdm.github.io/) to show a progress bar of the training, along with the number of timesteps per second and the estimated time remaining until the end of training.\n", 581 | "\n", 582 | "Please note that this callback is already included in SB3 and can be used by passing `progress_bar=True` to the `learn()` method."
583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "metadata": { 589 | "colab": {}, 590 | "colab_type": "code", 591 | "id": "pXa8f6FsRUn8" 592 | }, 593 | "outputs": [], 594 | "source": [ 595 | "from tqdm.auto import tqdm\n", 596 | "\n", 597 | "\n", 598 | "class ProgressBarCallback(BaseCallback):\n", 599 | " \"\"\"\n", 600 | " :param pbar: (tqdm.tqdm) Progress bar object\n", 601 | " \"\"\"\n", 602 | "\n", 603 | " def __init__(self, pbar):\n", 604 | " super().__init__()\n", 605 | " self._pbar = pbar\n", 606 | "\n", 607 | " def _on_step(self):\n", 608 | " # Update the progress bar:\n", 609 | " self._pbar.n = self.num_timesteps\n", 610 | " self._pbar.update(0)\n", 611 | " return True # returning None would be treated as False and stop training\n", 612 | "\n", 613 | "# this callback uses the 'with' block, allowing for correct initialisation and destruction\n", 614 | "class ProgressBarManager(object):\n", 615 | " def __init__(self, total_timesteps): # init object with total timesteps\n", 616 | " self.pbar = None\n", 617 | " self.total_timesteps = total_timesteps\n", 618 | "\n", 619 | " def __enter__(self): # create the progress bar and callback, return the callback\n", 620 | " self.pbar = tqdm(total=self.total_timesteps)\n", 621 | "\n", 622 | " return ProgressBarCallback(self.pbar)\n", 623 | "\n", 624 | " def __exit__(self, exc_type, exc_val, exc_tb): # close the progress bar\n", 625 | " self.pbar.n = self.total_timesteps\n", 626 | " self.pbar.update(0)\n", 627 | " self.pbar.close()\n", 628 | "\n", 629 | "\n", 630 | "model = TD3(\"MlpPolicy\", \"Pendulum-v1\", verbose=0)\n", 631 | "# Using a context manager guarantees that the tqdm progress bar closes correctly\n", 632 | "with ProgressBarManager(2000) as callback:\n", 633 | " model.learn(2000, callback=callback)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "metadata": { 639 | "colab_type": "text", 640 | "id": "lBF4ij46RUoC" 641 | }, 642 | "source": [ 643 | "## Fourth example: Composition\n", 644 | "Thanks to the functional nature of callbacks, 
it is possible to compose several callbacks into a single callback. This means we can auto-save our best model and show the progress bar of the training at the same time.\n", 645 | "\n", 646 | "The callbacks are automatically composed when you pass a list to the `learn()` method. Under the hood, a `CallbackList` is created." 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": { 653 | "colab": {}, 654 | "colab_type": "code", 655 | "id": "5hU3T9tkRUoD" 656 | }, 657 | "outputs": [], 658 | "source": [ 659 | "from stable_baselines3.common.callbacks import CallbackList\n", 660 | "\n", 661 | "# Create log dir\n", 662 | "log_dir = \"/tmp/gym/\"\n", 663 | "os.makedirs(log_dir, exist_ok=True)\n", 664 | "\n", 665 | "# Create and wrap the environment\n", 666 | "env = make_vec_env('CartPole-v1', n_envs=1, monitor_dir=log_dir)\n", 667 | "\n", 668 | "# Create callbacks\n", 669 | "auto_save_callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)\n", 670 | "\n", 671 | "model = PPO('MlpPolicy', env, verbose=0)\n", 672 | "with ProgressBarManager(1000) as progress_callback:\n", 673 | " # This is equivalent to callback=CallbackList([progress_callback, auto_save_callback])\n", 674 | " model.learn(1000, callback=[progress_callback, auto_save_callback])" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "metadata": { 680 | "colab_type": "text", 681 | "id": "SRB4-qIxg_c9" 682 | }, 683 | "source": [ 684 | "## Exercise: Code your own callback\n", 685 | "\n", 686 | "\n", 687 | "The previous examples showed the basics of what a callback is and what you can do with it.\n", 688 | "\n", 689 | "The goal of this exercise is to create a callback that will evaluate the model using a test environment and save it if it is the best model so far.\n", 690 | "\n", 691 | "To make things easier, we are going to use a class instead of a function with the magic method `__call__`."
692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": null, 697 | "metadata": { 698 | "colab": {}, 699 | "colab_type": "code", 700 | "id": "MOn0Sr3OhC2U" 701 | }, 702 | "outputs": [], 703 | "source": [ 704 | "class EvalCallback(BaseCallback):\n", 705 | " \"\"\"\n", 706 | " Callback for evaluating an agent.\n", 707 | "\n", 708 | " :param eval_env: (gym.Env) The environment used for evaluation\n", 709 | " :param n_eval_episodes: (int) The number of episodes to test the agent\n", 710 | " :param eval_freq: (int) Evaluate the agent every eval_freq calls of the callback.\n", 711 | " \"\"\"\n", 712 | "\n", 713 | " def __init__(self, eval_env, n_eval_episodes=5, eval_freq=20):\n", 714 | " super().__init__()\n", 715 | " self.eval_env = eval_env\n", 716 | " self.n_eval_episodes = n_eval_episodes\n", 717 | " self.eval_freq = eval_freq\n", 718 | " self.best_mean_reward = -np.inf\n", 719 | "\n", 720 | " def _on_step(self):\n", 721 | " \"\"\"\n", 722 | " This method will be called by the model.\n", 723 | "\n", 724 | " :return: (bool) If False, training is aborted early.\n", 725 | " \"\"\"\n", 726 | "\n", 727 | " # self.n_calls is automatically updated because\n", 728 | " # we derive from BaseCallback\n", 729 | " if self.n_calls % self.eval_freq == 0:\n", 730 | " # === YOUR CODE HERE ===#\n", 731 | " # Evaluate the agent:\n", 732 | " # you need to run self.n_eval_episodes episodes using self.eval_env\n", 733 | " # hint: you can use self.model.predict(obs, deterministic=True)\n", 734 | "\n", 735 | " # Save the agent if needed\n", 736 | " # and update self.best_mean_reward\n", 737 | "\n", 738 | " print(\"Best mean reward: {:.2f}\".format(self.best_mean_reward))\n", 739 | "\n", 740 | " # ====================== #\n", 741 | " return True" 742 | ] 743 | }, 744 | { 745 | "cell_type": "markdown", 746 | "metadata": { 747 | "colab_type": "text", 748 | "id": "IO0I81jAkQ0z" 749 | }, 750 | "source": [ 751 | "### Test your callback" 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": 
null, 757 | "metadata": { 758 | "colab": {}, 759 | "colab_type": "code", 760 | "id": "_OMop3TlkTbx" 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "# Env used for training\n", 765 | "env = gym.make(\"CartPole-v1\")\n", 766 | "# Env for evaluating the agent\n", 767 | "eval_env = gym.make(\"CartPole-v1\")\n", 768 | "\n", 769 | "# === YOUR CODE HERE ===#\n", 770 | "# Create the callback object\n", 771 | "callback = None\n", 772 | "\n", 773 | "# Create the RL model\n", 774 | "model = None\n", 775 | "\n", 776 | "# ====================== #\n", 777 | "\n", 778 | "# Train the RL model\n", 779 | "model.learn(int(100000), callback=callback)\n" 780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "metadata": { 785 | "colab_type": "text", 786 | "id": "5wS20a_NfMAh" 787 | }, 788 | "source": [ 789 | "# Conclusion\n", 790 | "\n", 791 | "\n", 792 | "In this notebook we have seen:\n", 793 | "- that good hyperparameters are key to the success of RL; you should not expect the default ones to work on every problem\n", 794 | "- what a callback is and what you can do with it\n", 795 | "- how to create your own callback\n" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": null, 801 | "metadata": { 802 | "colab": {}, 803 | "colab_type": "code", 804 | "id": "uA4gCDtogIaD" 805 | }, 806 | "outputs": [], 807 | "source": [] 808 | } 809 | ], 810 | "metadata": { 811 | "accelerator": "GPU", 812 | "colab": { 813 | "collapsed_sections": [], 814 | "include_colab_link": true, 815 | "name": "4_callbacks_hyperparameter_tuning.ipynb", 816 | "provenance": [] 817 | }, 818 | "kernelspec": { 819 | "display_name": "Python 3 (ipykernel)", 820 | "language": "python", 821 | "name": "python3" 822 | }, 823 | "language_info": { 824 | "codemirror_mode": { 825 | "name": "ipython", 826 | "version": 3 827 | }, 828 | "file_extension": ".py", 829 | "mimetype": "text/x-python", 830 | "name": "python", 831 | "nbconvert_exporter": "python", 832 | "pygments_lexer": "ipython3", 833 | 
| "version": "3.10.9" 834 | }, 835 | "vscode": { 836 | "interpreter": { 837 | "hash": "3201c96db5836b171d01fee72ea1be894646622d4b41771abf25c98b548a611d" 838 | } 839 | } 840 | }, 841 | "nbformat": 4, 842 | "nbformat_minor": 4 843 | } 844 | -------------------------------------------------------------------------------- /5_custom_gym_env.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "AoxOjIlOImwx" 18 | }, 19 | "source": [ 20 | "# Stable Baselines3 Tutorial - Creating a custom Gym environment\n", 21 | "\n", 22 | "Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/\n", 23 | "\n", 24 | "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n", 25 | "\n", 26 | "Documentation: https://stable-baselines3.readthedocs.io/en/master/\n", 27 | "\n", 28 | "SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib\n", 29 | "\n", 30 | "RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo\n", 31 | "\n", 32 | "\n", 33 | "## Introduction\n", 34 | "\n", 35 | "In this notebook, you will learn how to use your own environment following the OpenAI Gym interface.\n", 36 | "Once it is done, you can easily use any compatible (depending on the action space) RL algorithm from Stable Baselines on that environment.\n", 37 | "\n", 38 | "## Install Dependencies and Stable Baselines3 Using Pip\n", 39 | "\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# for autoformatting\n", 49 | "# %load_ext jupyter_black" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "colab": {}, 57 | 
"colab_type": "code", 58 | "id": "Sp8rSS4DIhEV" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "!pip install \"stable-baselines3[extra]>=2.0.0a4\"" 63 | ] 64 | }, 65 | { 66 | "attachments": {}, 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "colab_type": "text", 70 | "id": "rzevZcgmJmhi" 71 | }, 72 | "source": [ 73 | "## First steps with the gym interface\n", 74 | "\n", 75 | "As you have noticed in the previous notebooks, an environment that follows the gym interface is quite simple to use.\n", 76 | "It provides to this user mainly three methods, which have the following signature (for gym versions > 0.26)\n", 77 | "- `reset()` called at the beginning of an episode, it returns an observation and a dictionary with additional info (defaults to an empty dict)\n", 78 | "- `step(action)` called to take an action with the environment, it returns the next observation, the immediate reward, whether new state is a terminal state (episode is finished), whether the max number of timesteps is reached (episode is artificially finished), and additional information\n", 79 | "- (Optional) `render()` which allow to visualize the agent in action. Note that graphical interface does not work on google colab, so we cannot use it directly (we have to rely on `render_mode='rbg_array'` to retrieve an image of the scene).\n", 80 | "\n", 81 | "Under the hood, it also contains two useful properties:\n", 82 | "- `observation_space` which one of the gym spaces (`Discrete`, `Box`, ...) and describe the type and shape of the observation\n", 83 | "- `action_space` which is also a gym space object that describes the action space, so the type of action that can be taken\n", 84 | "\n", 85 | "The best way to learn about [gym spaces](https://gymnasium.farama.org/api/spaces/) is to look at the [source code](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/spaces), but you need to know at least the main ones:\n", 86 | "- `gym.spaces.Box`: A (possibly unbounded) box in $R^n$. 
Specifically, a Box represents the Cartesian product of n closed intervals. Each interval has the form of one of [a, b], (-oo, b], [a, oo), or (-oo, oo). Example: A 1D-Vector or an image observation can be described with the Box space.\n", 87 | "```python\n", 88 | "# Example for using image as input:\n", 89 | "observation_space = spaces.Box(low=0, high=255, shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)\n", 90 | "``` \n", 91 | "\n", 92 | "- `gym.spaces.Discrete`: A discrete space in $\\{ 0, 1, \\dots, n-1 \\}$\n", 93 | " Example: if you have two actions (\"left\" and \"right\") you can represent your action space using `Discrete(2)`, the first action will be 0 and the second 1.\n", 94 | "\n", 95 | "\n", 96 | "[Documentation on custom env](https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html)\n", 97 | "\n", 98 | "Also keep in mind that Stable-Baselines3 internally uses the previous gym API (<0.26) for its VecEnv: every VecEnv returns only the observation after resetting, and its `step()` returns a 4-tuple instead of a 5-tuple (`terminated` and `truncated` are combined into `done`)."
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "colab": {}, 106 | "colab_type": "code", 107 | "id": "I98IKKyNJl6K" 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "import gymnasium as gym\n", 112 | "\n", 113 | "env = gym.make(\"CartPole-v1\")\n", 114 | "\n", 115 | "# Box(4,) means that it is a Vector with 4 components\n", 116 | "print(\"Observation space:\", env.observation_space)\n", 117 | "print(\"Shape:\", env.observation_space.shape)\n", 118 | "# Discrete(2) means that there are two discrete actions\n", 119 | "print(\"Action space:\", env.action_space)\n", 120 | "\n", 121 | "# The reset method is called at the beginning of an episode\n", 122 | "obs, info = env.reset()\n", 123 | "# Sample a random action\n", 124 | "action = env.action_space.sample()\n", 125 | "print(\"Sampled action:\", action)\n", 126 | "obs, reward, terminated, truncated, info = env.step(action)\n", 127 | "# Note that obs is a numpy array\n", 128 | "# info is an empty dict for now but can contain any debugging info\n", 129 | "# reward is a scalar\n", 130 | "print(obs.shape, reward, terminated, truncated, info)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "colab_type": "text", 137 | "id": "RqxatIwPOXe_" 138 | }, 139 | "source": [ 140 | "## Gym env skeleton\n", 141 | "\n", 142 | "In practice, this is what a gym environment looks like.\n", 143 | "Here, we have implemented a simple grid world where the agent must learn to always go left."
144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "colab": {}, 151 | "colab_type": "code", 152 | "id": "rYzDXA9vJfz1" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "import numpy as np\n", 157 | "import gymnasium as gym\n", 158 | "from gymnasium import spaces\n", 159 | "\n", 160 | "\n", 161 | "class GoLeftEnv(gym.Env):\n", 162 | " \"\"\"\n", 163 | " Custom Environment that follows the gym interface.\n", 164 | " This is a simple env where the agent must learn to always go left.\n", 165 | " \"\"\"\n", 166 | "\n", 167 | " # Because of Google Colab, we cannot implement the GUI ('human' render mode)\n", 168 | " metadata = {\"render_modes\": [\"console\"]}\n", 169 | "\n", 170 | " # Define constants for clearer code\n", 171 | " LEFT = 0\n", 172 | " RIGHT = 1\n", 173 | "\n", 174 | " def __init__(self, grid_size=10, render_mode=\"console\"):\n", 175 | " super(GoLeftEnv, self).__init__()\n", 176 | " self.render_mode = render_mode\n", 177 | "\n", 178 | " # Size of the 1D-grid\n", 179 | " self.grid_size = grid_size\n", 180 | " # Initialize the agent at the right of the grid\n", 181 | " self.agent_pos = grid_size - 1\n", 182 | "\n", 183 | " # Define action and observation space\n", 184 | " # They must be gym.spaces objects\n", 185 | " # Example when using discrete actions, we have two: left and right\n", 186 | " n_actions = 2\n", 187 | " self.action_space = spaces.Discrete(n_actions)\n", 188 | " # The observation will be the coordinate of the agent\n", 189 | " # this can be described by both a Discrete and a Box space\n", 190 | " self.observation_space = spaces.Box(\n", 191 | " low=0, high=self.grid_size, shape=(1,), dtype=np.float32\n", 192 | " )\n", 193 | "\n", 194 | " def reset(self, seed=None, options=None):\n", 195 | " \"\"\"\n", 196 | " Important: the observation must be a numpy array\n", 197 | " :return: (np.array, dict)\n", 198 | " \"\"\"\n", 199 | " super().reset(seed=seed, options=options)\n", 200 | " # Initialize
the agent at the right of the grid\n", 201 | " self.agent_pos = self.grid_size - 1\n", 202 | " # here we convert to float32 to make it more general (in case we want to use continuous actions)\n", 203 | " return np.array([self.agent_pos]).astype(np.float32), {} # empty info dict\n", 204 | "\n", 205 | " def step(self, action):\n", 206 | " if action == self.LEFT:\n", 207 | " self.agent_pos -= 1\n", 208 | " elif action == self.RIGHT:\n", 209 | " self.agent_pos += 1\n", 210 | " else:\n", 211 | " raise ValueError(\n", 212 | " f\"Received invalid action={action} which is not part of the action space\"\n", 213 | " )\n", 214 | "\n", 215 | " # Account for the boundaries of the grid\n", 216 | " self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size)\n", 217 | "\n", 218 | " # Are we at the left of the grid?\n", 219 | " terminated = bool(self.agent_pos == 0)\n", 220 | " truncated = False # we do not limit the number of steps here\n", 221 | "\n", 222 | " # Null reward everywhere except when reaching the goal (left of the grid)\n", 223 | " reward = 1 if self.agent_pos == 0 else 0\n", 224 | "\n", 225 | " # Optionally we can pass additional info, we are not using that for now\n", 226 | " info = {}\n", 227 | "\n", 228 | " return (\n", 229 | " np.array([self.agent_pos]).astype(np.float32),\n", 230 | " reward,\n", 231 | " terminated,\n", 232 | " truncated,\n", 233 | " info,\n", 234 | " )\n", 235 | "\n", 236 | " def render(self):\n", 237 | " # agent is represented as a cross, rest as a dot\n", 238 | " if self.render_mode == \"console\":\n", 239 | " print(\".\" * self.agent_pos, end=\"\")\n", 240 | " print(\"x\", end=\"\")\n", 241 | " print(\".\" * (self.grid_size - self.agent_pos))\n", 242 | "\n", 243 | " def close(self):\n", 244 | " pass" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "colab_type": "text", 251 | "id": "Zy5mlho1-Ine" 252 | }, 253 | "source": [ 254 | "### Validate the environment\n", 255 | "\n", 256 | "Stable Baselines3 provides a 
[helper](https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html) to check that your environment follows the Gym interface. It also optionally checks that the environment is compatible with Stable-Baselines (and emits warnings if necessary)." 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": { 263 | "colab": {}, 264 | "colab_type": "code", 265 | "id": "9DOpP_B0-LXm" 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "from stable_baselines3.common.env_checker import check_env" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": { 276 | "colab": {}, 277 | "colab_type": "code", 278 | "id": "1CcUVatq-P0l" 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "env = GoLeftEnv()\n", 283 | "# If the environment doesn't follow the interface, an error will be thrown\n", 284 | "check_env(env, warn=True)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": { 290 | "colab_type": "text", 291 | "id": "eJ3khFtkSE0g" 292 | }, 293 | "source": [ 294 | "### Testing the environment" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": { 301 | "colab": {}, 302 | "colab_type": "code", 303 | "id": "i62yf2LvSAYY" 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "env = GoLeftEnv(grid_size=10)\n", 308 | "\n", 309 | "obs, _ = env.reset()\n", 310 | "env.render()\n", 311 | "\n", 312 | "print(env.observation_space)\n", 313 | "print(env.action_space)\n", 314 | "print(env.action_space.sample())\n", 315 | "\n", 316 | "GO_LEFT = 0\n", 317 | "# Hardcoded best agent: always go left!\n", 318 | "n_steps = 20\n", 319 | "for step in range(n_steps):\n", 320 | " print(f\"Step {step + 1}\")\n", 321 | " obs, reward, terminated, truncated, info = env.step(GO_LEFT)\n", 322 | " done = terminated or truncated\n", 323 | " print(\"obs=\", obs, \"reward=\", reward, \"done=\", done)\n", 324 | " env.render()\n", 325 | " if
done:\n", 326 | " print(\"Goal reached!\", \"reward=\", reward)\n", 327 | " break" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "colab_type": "text", 334 | "id": "Pv1e1qJETfHU" 335 | }, 336 | "source": [ 337 | "### Try it with Stable-Baselines\n", 338 | "\n", 339 | "Once your environment follows the gym interface, it is quite easy to plug in any compatible algorithm from Stable-Baselines3" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": { 346 | "colab": {}, 347 | "colab_type": "code", 348 | "id": "PQfLBE28SNDr" 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "from stable_baselines3 import PPO, A2C, DQN\n", 353 | "from stable_baselines3.common.env_util import make_vec_env\n", 354 | "\n", 355 | "# Instantiate the env\n", 356 | "vec_env = make_vec_env(GoLeftEnv, n_envs=1, env_kwargs=dict(grid_size=10))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": { 363 | "colab": {}, 364 | "colab_type": "code", 365 | "id": "zRV4Q7FVUKB6" 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "# Train the agent\n", 370 | "model = A2C(\"MlpPolicy\", vec_env, verbose=1).learn(5000)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": { 377 | "colab": {}, 378 | "colab_type": "code", 379 | "id": "BJbeiF0RUN-p" 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "# Test the trained agent\n", 384 | "# using the vecenv\n", 385 | "obs = vec_env.reset()\n", 386 | "n_steps = 20\n", 387 | "for step in range(n_steps):\n", 388 | " action, _ = model.predict(obs, deterministic=True)\n", 389 | " print(f\"Step {step + 1}\")\n", 390 | " print(\"Action: \", action)\n", 391 | " obs, reward, done, info = vec_env.step(action)\n", 392 | " print(\"obs=\", obs, \"reward=\", reward, \"done=\", done)\n", 393 | " vec_env.render()\n", 394 | " if done:\n", 395 | " # Note that the VecEnv resets automatically\n", 396 | " # when a done
signal is encountered\n", 397 | " print(\"Goal reached!\", \"reward=\", reward)\n", 398 | " break" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": { 404 | "colab_type": "text", 405 | "id": "jOggIa9sU--b" 406 | }, 407 | "source": [ 408 | "## It is your turn now, be creative!\n", 409 | "\n", 410 | "As an exercise, it's now your turn to build a custom gym environment.\n", 411 | "There is no constraint on what to do, so be creative! (but not too creative, there is not enough time for that)\n", 412 | "\n", 413 | "If you don't have any ideas, here is a list of environments you can implement:\n", 414 | "- Transform the discrete grid world into a continuous one (you will need to change the logic and the action space a bit)\n", 415 | "- Create a 2D grid world and add walls\n", 416 | "- Create a tic-tac-toe game\n" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": { 423 | "colab": {}, 424 | "colab_type": "code", 425 | "id": "lBDp4Pm-Uh4D" 426 | }, 427 | "outputs": [], 428 | "source": [] 429 | } 430 | ], 431 | "metadata": { 432 | "accelerator": "GPU", 433 | "colab": { 434 | "collapsed_sections": [], 435 | "include_colab_link": true, 436 | "name": "5.custom_gym_env.ipynb", 437 | "provenance": [] 438 | }, 439 | "kernelspec": { 440 | "display_name": "Python 3 (ipykernel)", 441 | "language": "python", 442 | "name": "python3" 443 | }, 444 | "language_info": { 445 | "codemirror_mode": { 446 | "name": "ipython", 447 | "version": 3 448 | }, 449 | "file_extension": ".py", 450 | "mimetype": "text/x-python", 451 | "name": "python", 452 | "nbconvert_exporter": "python", 453 | "pygments_lexer": "ipython3", 454 | "version": "3.10.9" 455 | }, 456 | "vscode": { 457 | "interpreter": { 458 | "hash": "3201c96db5836b171d01fee72ea1be894646622d4b41771abf25c98b548a611d" 459 | } 460 | } 461 | }, 462 | "nbformat": 4, 463 | "nbformat_minor": 4 464 | } 465 |
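The note in `5_custom_gym_env.ipynb` says that SB3's VecEnv uses the pre-0.26 API: `reset()` returns only the observation and `step()` returns a 4-tuple with a single `done` flag. The adapter below is a minimal, hypothetical sketch of how a Gymnasium 5-tuple could be folded into that older form. It is not SB3's actual implementation. `FiveToFourTupleAdapter` and the toy `CountdownEnv` are names made up for this illustration. The two info keys, `TimeLimit.truncated` and `terminal_observation`, are the ones SB3's documentation describes for its VecEnvs. Observations are plain floats so the sketch runs without Gymnasium installed.

```python
from typing import Any, Dict, Tuple


class CountdownEnv:
    """Hypothetical toy env using the Gymnasium API (reset -> (obs, info), step -> 5-tuple).

    It terminates after `n` steps and gives a reward of 1.0 on the final step.
    """

    def __init__(self, n: int = 3):
        self.n = n
        self.t = 0

    def reset(self, seed=None, options=None) -> Tuple[float, Dict[str, Any]]:
        self.t = 0
        return float(self.n), {}

    def step(self, action) -> Tuple[float, float, bool, bool, Dict[str, Any]]:
        self.t += 1
        terminated = self.t >= self.n
        truncated = False  # no time limit in this toy env
        reward = 1.0 if terminated else 0.0
        return float(self.n - self.t), reward, terminated, truncated, {}


class FiveToFourTupleAdapter:
    """Hypothetical adapter exposing a Gymnasium-style env through the old 4-tuple API."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        # Old API: reset() returns only the observation, the info dict is dropped
        obs, _info = self.env.reset()
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # terminated and truncated are merged into a single done flag
        done = terminated or truncated
        info = dict(info)  # copy so the wrapped env's dict is not mutated
        # Record time-limit truncations, which `done` alone can no longer express
        info["TimeLimit.truncated"] = truncated and not terminated
        if done:
            # Like an SB3 VecEnv: reset automatically and keep the real last observation
            info["terminal_observation"] = obs
            obs = self.reset()
        return obs, reward, done, info


env = FiveToFourTupleAdapter(CountdownEnv(n=3))
obs = env.reset()
for step in range(3):
    obs, reward, done, info = env.step(0)
    print(f"step={step + 1} obs={obs} reward={reward} done={done}")
```

Because the adapter resets automatically, the `obs` returned on the final step is already the first observation of the next episode. The last observation of the finished episode is only available through `info["terminal_observation"]`. This matches the comment in the notebook's test loop that the VecEnv resets automatically when a done signal is encountered.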
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Antonin RAFFIN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Baselines3 RL tutorial 2 | 3 | Stable-Baselines reinforcement learning tutorial for Journées Nationales de la Recherche en Robotique 2019. 
4 | 5 | Website: https://jnrr2019.loria.fr/ 6 | 7 | Slides: https://araffin.github.io/slides/rl-tuto-jnrr19/#/ 8 | 9 | Stable-Baselines3 repo: https://github.com/DLR-RM/stable-baselines3 10 | 11 | Documentation: https://stable-baselines3.readthedocs.io/en/master/ 12 | 13 | RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo 14 | 15 | This tutorial was created by [Edward Beeching](https://github.com/edbeeching), [Ashley Hill](https://github.com/hill-a) and [Antonin Raffin](https://araffin.github.io/) 16 | 17 | ## Content 18 | 19 | 1. Getting Started [Colab Notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/1_getting_started.ipynb) 20 | 2. Gym Wrappers, saving and loading models [Colab Notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/2_gym_wrappers_saving_loading.ipynb) 21 | 3. Multiprocessing [Colab Notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/3_multiprocessing.ipynb) 22 | 4. Callbacks and hyperparameter tuning [Colab Notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/4_callbacks_hyperparameter_tuning.ipynb) 23 | 5. Creating a custom gym environment [Colab Notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb) 24 | 25 | ## Bonus 26 | 27 | RL baselines zoo: [Colab Notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb) 28 | 29 | ## Contributors 30 | 31 | We would like to thank our contributors: [@rbahumi](https://github.com/rbahumi) [@stefanbschneider](https://github.com/stefanbschneider) 32 | --------------------------------------------------------------------------------