├── .github ├── ISSUE_TEMPLATE │ └── issue-template.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── conftest.py ├── docs ├── _static │ ├── css │ │ └── baselines_theme.css │ └── img │ │ └── logo.png └── misc │ └── changelog.rst ├── setup.cfg ├── setup.py ├── stable_baselines ├── __init__.py ├── common │ ├── __init__.py │ ├── base_class.py │ ├── buffers.py │ ├── distributions.py │ ├── evaluation.py │ ├── logger.py │ ├── monitor.py │ ├── noise.py │ ├── policies.py │ ├── running_mean_std.py │ ├── save_util.py │ ├── tile_images.py │ ├── utils.py │ └── vec_env │ │ ├── __init__.py │ │ ├── base_vec_env.py │ │ ├── dummy_vec_env.py │ │ ├── subproc_vec_env.py │ │ ├── util.py │ │ ├── vec_check_nan.py │ │ ├── vec_frame_stack.py │ │ ├── vec_normalize.py │ │ └── vec_video_recorder.py ├── ppo │ ├── __init__.py │ ├── policies.py │ └── ppo.py ├── py.typed └── td3 │ ├── __init__.py │ ├── policies.py │ └── td3.py └── tests ├── __init__.py └── test_run.py /.github/ISSUE_TEMPLATE/issue-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue Template 3 | about: How to create an issue for this repository 4 | 5 | --- 6 | 7 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. 8 | 9 | If you have any questions, feel free to create an issue with the tag [question]. 10 | If you wish to suggest an enhancement or feature request, add the tag [feature request]. 11 | If you are submitting a bug report, please fill in the following details. 12 | 13 | If your issue is related to a custom gym environment, please check it first using: 14 | 15 | ```python 16 | from stable_baselines.common.env_checker import check_env 17 | 18 | env = CustomEnv(arg1, ...) 19 | # It will check your custom environment and output additional warnings if needed 20 | check_env(env) 21 | ``` 22 | 23 | **Describe the bug** 24 | A clear and concise description of what the bug is. 25 | 26 | **Code example** 27 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 28 | 29 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) 30 | for both code and stack traces. 31 | 32 | ```python 33 | from stable_baselines import ... 34 | 35 | ``` 36 | 37 | ```bash 38 | Traceback (most recent call last): File ... 39 | 40 | ``` 41 | 42 | **System Info** 43 | Describe the characteristic of your environment: 44 | * Describe how the library was installed (pip, docker, source, ...) 45 | * GPU models and configuration 46 | * Python version 47 | * Tensorflow version 48 | * Versions of any other relevant libraries 49 | 50 | **Additional context** 51 | Add any other context about the problem here. 
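To fill in the **System Info** section, a snippet along these lines can be used to collect the relevant versions (adapt it to the libraries you actually use; the GPU listing relies on the TensorFlow 2 API):

```python
import platform

import tensorflow as tf

import stable_baselines

# Versions to copy into the "System Info" section
print("Python:", platform.python_version())
print("Stable-Baselines:", stable_baselines.__version__)
print("Tensorflow:", tf.__version__)
print("GPU devices:", tf.config.experimental.list_physical_devices('GPU'))
```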
52 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | ## Motivation and Context 7 | 8 | 9 | 10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes) 11 | 12 | ## Types of changes 13 | 14 | - [ ] Bug fix (non-breaking change which fixes an issue) 15 | - [ ] New feature (non-breaking change which adds functionality) 16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 17 | - [ ] Documentation (update in the documentation) 18 | 19 | ## Checklist: 20 | 21 | 22 | - [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**) 23 | - [ ] I have updated the [changelog](https://github.com/hill-a/stable-baselines/blob/master/docs/misc/changelog.rst) accordingly (**required**). 24 | - [ ] My change requires a change to the documentation. 25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 26 | - [ ] I have updated the documentation accordingly. 27 | - [ ] I have ensured `pytest` and `pytype` both pass. 28 | 29 | 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.pkl 4 | *.py~ 5 | *.bak 6 | .pytest_cache 7 | .pytype 8 | .DS_Store 9 | .idea 10 | .coverage 11 | .coverage.* 12 | __pycache__/ 13 | _build/ 14 | *.npz 15 | *.zip 16 | *.lprof 17 | 18 | # Setuptools distribution and build folders. 19 | /dist/ 20 | /build 21 | keys/ 22 | 23 | # Virtualenv 24 | /env 25 | /venv 26 | 27 | *.sublime-project 28 | *.sublime-workspace 29 | 30 | logs/ 31 | 32 | .ipynb_checkpoints 33 | ghostdriver.log 34 | 35 | htmlcov 36 | 37 | junk 38 | src 39 | 40 | *.egg-info 41 | .cache 42 | 43 | MUJOCO_LOG.TXT 44 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to Stable-Baselines 2 | 3 | If you are interested in contributing to Stable-Baselines, your contributions will fall 4 | into two categories: 5 | 1. You want to propose a new Feature and implement it 6 | - Create an issue about your intended feature, and we shall discuss the design and 7 | implementation. Once we agree that the plan looks good, go ahead and implement it. 8 | 2. You want to implement a feature or bug-fix for an outstanding issue 9 | - Look at the outstanding issues here: https://github.com/hill-a/stable-baselines/issues 10 | - Look at the roadmap here: https://github.com/hill-a/stable-baselines/projects/1 11 | - Pick an issue or feature and comment on the task that you want to work on this feature. 12 | - If you need more context on a particular issue, please ask and we shall provide. 
13 | 14 | Once you finish implementing a feature or bug-fix, please send a Pull Request to 15 | https://github.com/hill-a/stable-baselines/ 16 | 17 | 18 | If you are not familiar with creating a Pull Request, here are some guides: 19 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request 20 | - https://help.github.com/articles/creating-a-pull-request/ 21 | 22 | 23 | ## Developing Stable-Baselines 24 | 25 | To develop Stable-Baselines on your machine, here are some tips: 26 | 27 | 1. Clone a copy of Stable-Baselines from source: 28 | 29 | ```bash 30 | git clone https://github.com/hill-a/stable-baselines/ 31 | cd stable-baselines 32 | ``` 33 | 34 | 2. Install Stable-Baselines in develop mode, with support for building the docs and running tests: 35 | 36 | ```bash 37 | pip install -e .[docs,tests] 38 | ``` 39 | 40 | ## Codestyle 41 | 42 | We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows: 43 | 44 | 1. built-in 45 | 2. packages 46 | 3. current module 47 | 48 | with one blank line between each group, which gives for instance: 49 | ```python 50 | import os 51 | import warnings 52 | 53 | import numpy as np 54 | 55 | from stable_baselines import PPO2 56 | ``` 57 | 58 | In general, we recommend using PyCharm to format everything in an efficient way. 59 | 60 | Please document each function/method using the following template: 61 | 62 | ```python 63 | 64 | def my_function(arg1, arg2): 65 | """ 66 | Short description of the function. 67 | 68 | :param arg1: (arg1 type) describe what is arg1 69 | :param arg2: (arg2 type) describe what is arg2 70 | :return: (return type) describe what is returned 71 | """ 72 | ... 73 | return my_variable 74 | ``` 75 | 76 | ## Pull Request (PR) 77 | 78 | Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process. 79 | 80 | Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a , @araffin or @erniejunior ). 81 | A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch. 82 | 83 | Note: in rare cases, we can make an exception for a codacy failure. 84 | 85 | ## Test 86 | 87 | All new features must add tests in the `tests/` folder to ensure that everything works as expected. 88 | We use [pytest](https://pytest.org/). 89 | Also, when a bug fix is proposed, tests should be added to avoid regression; a minimal example test is sketched at the end of this guide. 90 | 91 | To run tests with `pytest` and type checking with `pytype`: 92 | 93 | ``` 94 | ./scripts/run_tests.sh 95 | ``` 96 | 97 | ## Changelog and Documentation 98 | 99 | Please do not forget to update the changelog and add documentation if needed. 100 | A README is present in the `docs/` folder for instructions on how to build the documentation. 101 | 102 | 103 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one. 
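As a complement to the testing guidelines above, here is a rough sketch of what a new test in the `tests/` folder could look like. The algorithm, environment id, test name and number of timesteps are only illustrative, and the constructor call assumes the usual `(policy, env)` signature from `stable_baselines/common/base_class.py`; adapt everything to the feature being tested:

```python
import pytest

from stable_baselines import PPO, TD3


@pytest.mark.parametrize("model_class", [PPO, TD3])
def test_my_feature(model_class):
    """
    Smoke test: check that the model can be created and trained for a few steps.

    :param model_class: (BaseRLModel) the algorithm class to test
    """
    # 'Pendulum-v0' has a continuous action space, so it suits both PPO and TD3
    model = model_class('MlpPolicy', 'Pendulum-v0', verbose=0)
    model.learn(total_timesteps=500)
```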
104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | Copyright (c) 2018-2020 Stable-Baselines Team 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Stable Baselines TF2 [Experimental] 4 | 5 | [Stable-Baselines TF1 version](https://github.com/hill-a/stable-baselines/). 6 | 7 | 8 | ## Citing the Project 9 | 10 | To cite this repository in publications: 11 | 12 | ``` 13 | @misc{stable-baselines, 14 | author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, 15 | title = {Stable Baselines}, 16 | year = {2018}, 17 | publisher = {GitHub}, 18 | journal = {GitHub repository}, 19 | howpublished = {\url{https://github.com/hill-a/stable-baselines}}, 20 | } 21 | ``` 22 | 23 | ## Maintainers 24 | 25 | Stable-Baselines is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/erniejunior) (aka @erniejunior), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli). 26 | 27 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email. 28 | 29 | 30 | ## How To Contribute 31 | 32 | To anyone interested in making the baselines better: there is still some documentation that needs to be done. 33 | If you want to contribute, please read the **CONTRIBUTING.md** guide first. 34 | 35 | 36 | ## Acknowledgments 37 | 38 | Stable Baselines was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en). 39 | 40 | Logo credits: [L.M. 
Tenkes](https://www.instagram.com/lucillehue/) 41 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | """Configures pytest to ignore certain unit tests unless the appropriate flag is passed. 2 | 3 | --rungpu: tests that require GPU. 4 | --expensive: tests that take a long time to run (e.g. training an RL algorithm for many timesteps).""" 5 | 6 | import pytest 7 | 8 | 9 | def pytest_addoption(parser): 10 | parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests") 11 | parser.addoption("--expensive", action="store_true", 12 | help="run expensive tests (which are otherwise skipped).") 13 | 14 | 15 | def pytest_collection_modifyitems(config, items): 16 | flags = {'gpu': '--rungpu', 'expensive': '--expensive'} 17 | skips = {keyword: pytest.mark.skip(reason="need {} option to run".format(flag)) 18 | for keyword, flag in flags.items() if not config.getoption(flag)} 19 | for item in items: 20 | for keyword, skip in skips.items(): 21 | if keyword in item.keywords: 22 | item.add_marker(skip) 23 | -------------------------------------------------------------------------------- /docs/_static/css/baselines_theme.css: -------------------------------------------------------------------------------- 1 | /* Main colors from https://color.adobe.com/fr/Copy-of-NOUEBO-Original-color-theme-11116609 */ 2 | :root{ 3 | --main-bg-color: #324D5C; 4 | --link-color: #14B278; 5 | } 6 | 7 | /* Header fonts y */ 8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 10 | } 11 | 12 | 13 | /* Docs background */ 14 | .wy-side-nav-search{ 15 | background-color: var(--main-bg-color); 16 | } 17 | 18 | /* Mobile version */ 19 | .wy-nav-top{ 20 | background-color: var(--main-bg-color); 21 | } 22 | 23 | /* Change link colors (except for the menu) */ 24 | a { 25 | color: var(--link-color); 26 | } 27 | 28 | a:hover { 29 | color: #4F778F; 30 | } 31 | 32 | .wy-menu a { 33 | color: #b3b3b3; 34 | } 35 | 36 | .wy-menu a:hover { 37 | color: #b3b3b3; 38 | } 39 | 40 | a.icon.icon-home { 41 | color: #b3b3b3; 42 | } 43 | 44 | .version{ 45 | color: var(--link-color) !important; 46 | } 47 | 48 | 49 | /* Make code blocks have a background */ 50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { 51 | background: #f8f8f8;; 52 | } 53 | -------------------------------------------------------------------------------- /docs/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/docs/_static/img/logo.png -------------------------------------------------------------------------------- /docs/misc/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changelog: 2 | 3 | Changelog 4 | ========== 5 | 6 | For download links, please look at `Github release page `_. 
7 | 8 | Pre-Release 3.0.0a0 (WIP) 9 | -------------------------- 10 | 11 | **TensorFlow 2 Version** 12 | 13 | Breaking Changes: 14 | ^^^^^^^^^^^^^^^^^ 15 | - Drop support for tensorflow 1.x, TensorFlow >=2.1.0 is required 16 | - New dependency: tensorflow-probability>=0.8.0 is now required 17 | - Drop support for pretrain, in favor of https://github.com/HumanCompatibleAI/imitation 18 | - Drop support of GAIL, in favor of https://github.com/HumanCompatibleAI/imitation 19 | - Drop support of MPI 20 | 21 | New Features: 22 | ^^^^^^^^^^^^^ 23 | 24 | Bug Fixes: 25 | ^^^^^^^^^^ 26 | 27 | Deprecations: 28 | ^^^^^^^^^^^^^ 29 | 30 | Others: 31 | ^^^^^^^ 32 | 33 | Documentation: 34 | ^^^^^^^^^^^^^^ 35 | 36 | 37 | Pre-Release 2.10.0a0 (WIP) 38 | -------------------------- 39 | 40 | Breaking Changes: 41 | ^^^^^^^^^^^^^^^^^ 42 | 43 | New Features: 44 | ^^^^^^^^^^^^^ 45 | - Parallelized updating and sampling from the replay buffer in DQN. (@flodorner) 46 | - Docker build script, `scripts/build_docker.sh`, can push images automatically. 47 | - Added callback collection 48 | 49 | Bug Fixes: 50 | ^^^^^^^^^^ 51 | 52 | - Fixed Docker build script, `scripts/build_docker.sh`, to pass `USE_GPU` build argument. 53 | 54 | Deprecations: 55 | ^^^^^^^^^^^^^ 56 | 57 | Others: 58 | ^^^^^^^ 59 | - Removed redundant return value from `a2c.utils::total_episode_reward_logger`. (@shwang) 60 | 61 | Documentation: 62 | ^^^^^^^^^^^^^^ 63 | - Add dedicated page for callbacks 64 | 65 | 66 | Release 2.9.0 (2019-12-20) 67 | -------------------------- 68 | 69 | *Reproducible results, automatic `VecEnv` wrapping, env checker and more usability improvements* 70 | 71 | Breaking Changes: 72 | ^^^^^^^^^^^^^^^^^ 73 | - The `seed` argument has been moved from `learn()` method to model constructor 74 | in order to have reproducible results 75 | - `allow_early_resets` of the `Monitor` wrapper now default to `True` 76 | - `make_atari_env` now returns a `DummyVecEnv` by default (instead of a `SubprocVecEnv`) 77 | this usually improves performance. 78 | - Fix inconsistency of sample type, so that mode/sample function returns tensor of tf.int64 in CategoricalProbabilityDistribution/MultiCategoricalProbabilityDistribution (@seheevic) 79 | 80 | New Features: 81 | ^^^^^^^^^^^^^ 82 | - Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow 83 | - Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor 84 | - Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation 85 | - Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation 86 | - `VecNormalize` changes: 87 | 88 | - Now supports being pickled and unpickled (@AdamGleave). 89 | - New methods `.normalize_obs(obs)` and `normalize_reward(rews)` apply normalization 90 | to arbitrary observation or rewards without updating statistics (@shwang) 91 | - `.get_original_reward()` returns the unnormalized rewards from the most recent timestep 92 | - `.reset()` now collects observation statistics (used to only apply normalization) 93 | 94 | - Add parameter `exploration_initial_eps` to DQN. (@jdossgollin) 95 | - Add type checking and PEP 561 compliance. 96 | Note: most functions are still not annotated, this will be a gradual process. 97 | - DDPG, TD3 and SAC accept non-symmetric action spaces. 
(@Antymon) 98 | - Add `check_env` util to check if a custom environment follows the gym interface (@araffin and @justinkterry) 99 | 100 | Bug Fixes: 101 | ^^^^^^^^^^ 102 | - Fix seeding, so it is now possible to have deterministic results on cpu 103 | - Fix a bug in DDPG where `predict` method with `deterministic=False` would fail 104 | - Fix a bug in TRPO: mean_losses was not initialized causing the logger to crash when there was no gradients (@MarvineGothic) 105 | - Fix a bug in `cmd_util` from API change in recent Gym versions 106 | - Fix a bug in DDPG, TD3 and SAC where warmup and random exploration actions would end up scaled in the replay buffer (@Antymon) 107 | 108 | Deprecations: 109 | ^^^^^^^^^^^^^ 110 | - `nprocs` (ACKTR) and `num_procs` (ACER) are deprecated in favor of `n_cpu_tf_sess` which is now common 111 | to all algorithms 112 | - `VecNormalize`: `load_running_average` and `save_running_average` are deprecated in favour of using pickle. 113 | 114 | Others: 115 | ^^^^^^^ 116 | - Add upper bound for Tensorflow version (<2.0.0). 117 | - Refactored test to remove duplicated code 118 | - Add pull request template 119 | - Replaced redundant code in load_results (@jbulow) 120 | - Minor PEP8 fixes in dqn.py (@justinkterry) 121 | - Add a message to the assert in `PPO2` 122 | - Update replay buffer doctring 123 | - Fix `VecEnv` docstrings 124 | 125 | Documentation: 126 | ^^^^^^^^^^^^^^ 127 | - Add plotting to the Monitor example (@rusu24edward) 128 | - Add Snake Game AI project (@pedrohbtp) 129 | - Add note on the support Tensorflow versions. 130 | - Remove unnecessary steps required for Windows installation. 131 | - Remove `DummyVecEnv` creation when not needed 132 | - Added `make_vec_env` to the examples to simplify VecEnv creation 133 | - Add QuaRL project (@srivatsankrishnan) 134 | - Add Pwnagotchi project (@evilsocket) 135 | - Fix multiprocessing example (@rusu24edward) 136 | - Fix `result_plotter` example 137 | - Add JNRR19 tutorial (by @edbeeching, @hill-a and @araffin) 138 | - Updated notebooks link 139 | - Fix typo in algos.rst, "containes" to "contains" (@SyllogismRXS) 140 | - Fix outdated source documentation for load_results 141 | - Add PPO_CPP project (@Antymon) 142 | - Add section on C++ portability of Tensorflow models (@Antymon) 143 | - Update custom env documentation to reflect new gym API for the `close()` method (@justinkterry) 144 | - Update custom env documentation to clarify what step and reset return (@justinkterry) 145 | - Add RL tips and tricks for doing RL experiments 146 | - Corrected lots of typos 147 | - Add spell check to documentation if available 148 | 149 | 150 | Release 2.8.0 (2019-09-29) 151 | -------------------------- 152 | 153 | **MPI dependency optional, new save format, ACKTR with continuous actions** 154 | 155 | Breaking Changes: 156 | ^^^^^^^^^^^^^^^^^ 157 | - OpenMPI-dependent algorithms (PPO1, TRPO, GAIL, DDPG) are disabled in the 158 | default installation of stable_baselines. `mpi4py` is now installed as an 159 | extra. When `mpi4py` is not available, stable-baselines skips imports of 160 | OpenMPI-dependent algorithms. 161 | See :ref:`installation notes ` and 162 | `Issue #430 `_. 163 | - SubprocVecEnv now defaults to a thread-safe start method, `forkserver` when 164 | available and otherwise `spawn`. This may require application code be 165 | wrapped in `if __name__ == '__main__'`. You can restore previous behavior 166 | by explicitly setting `start_method = 'fork'`. See 167 | `PR #428 `_. 
168 | - Updated dependencies: tensorflow v1.8.0 is now required 169 | - Removed `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used 170 | - Removed `bench/benchmark.py` that was not used 171 | - Removed several functions from `common/tf_util.py` that were not used 172 | - Removed `ppo1/run_humanoid.py` 173 | 174 | New Features: 175 | ^^^^^^^^^^^^^ 176 | - **important change** Switch to using zip-archived JSON and Numpy `savez` for 177 | storing models for better support across library/Python versions. (@Miffyli) 178 | - ACKTR now supports continuous actions 179 | - Add `double_q` argument to `DQN` constructor 180 | 181 | Bug Fixes: 182 | ^^^^^^^^^^ 183 | - Skip automatic imports of OpenMPI-dependent algorithms to avoid an issue 184 | where OpenMPI would cause stable-baselines to hang on Ubuntu installs. 185 | See :ref:`installation notes ` and 186 | `Issue #430 `_. 187 | - Fix a bug when calling `logger.configure()` with MPI enabled (@keshaviyengar) 188 | - set `allow_pickle=True` for numpy>=1.17.0 when loading expert dataset 189 | - Fix a bug when using VecCheckNan with numpy ndarray as state. `Issue #489 `_. (@ruifeng96150) 190 | 191 | Deprecations: 192 | ^^^^^^^^^^^^^ 193 | - Models saved with cloudpickle format (stable-baselines<=2.7.0) are now 194 | deprecated in favor of zip-archive format for better support across 195 | Python/Tensorflow versions. (@Miffyli) 196 | 197 | Others: 198 | ^^^^^^^ 199 | - Implementations of noise classes (`AdaptiveParamNoiseSpec`, `NormalActionNoise`, 200 | `OrnsteinUhlenbeckActionNoise`) were moved from `stable_baselines.ddpg.noise` 201 | to `stable_baselines.common.noise`. The API remains backward-compatible; 202 | for example `from stable_baselines.ddpg.noise import NormalActionNoise` is still 203 | okay. (@shwang) 204 | - Docker images were updated 205 | - Cleaned up files in `common/` folder and in `acktr/` folder that were only used by old ACKTR version 206 | (e.g. `filter.py`) 207 | - Renamed `acktr_disc.py` to `acktr.py` 208 | 209 | Documentation: 210 | ^^^^^^^^^^^^^^ 211 | - Add WaveRL project (@jaberkow) 212 | - Add Fenics-DRL project (@DonsetPG) 213 | - Fix and rename custom policy names (@eavelardev) 214 | - Add documentation on exporting models. 215 | - Update maintainers list (Welcome to @Miffyli) 216 | 217 | 218 | Release 2.7.0 (2019-07-31) 219 | -------------------------- 220 | 221 | **Twin Delayed DDPG (TD3) and GAE bug fix (TRPO, PPO1, GAIL)** 222 | 223 | Breaking Changes: 224 | ^^^^^^^^^^^^^^^^^ 225 | 226 | New Features: 227 | ^^^^^^^^^^^^^ 228 | - added Twin Delayed DDPG (TD3) algorithm, with HER support 229 | - added support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian 230 | policy in addition to the existing support for categorical stochastic policies. 231 | - added flag to `action_probability` to return log-probabilities. 232 | - added support for python lists and numpy arrays in ``logger.writekvs``. (@dwiel) 233 | - the info dict returned by VecEnvs now include a ``terminal_observation`` key providing access to the last observation in a trajectory. (@qxcv) 234 | 235 | Bug Fixes: 236 | ^^^^^^^^^^ 237 | - fixed a bug in ``traj_segment_generator`` where the ``episode_starts`` was wrongly recorded, 238 | resulting in wrong calculation of Generalized Advantage Estimation (GAE), this affects TRPO, PPO1 and GAIL (thanks to @miguelrass for spotting the bug) 239 | - added missing property `n_batch` in `BasePolicy`. 
240 | 241 | Deprecations: 242 | ^^^^^^^^^^^^^ 243 | 244 | Others: 245 | ^^^^^^^ 246 | - renamed some keys in ``traj_segment_generator`` to be more meaningful 247 | - retrieve unnormalized reward when using Monitor wrapper with TRPO, PPO1 and GAIL 248 | to display them in the logs (mean episode reward) 249 | - clean up DDPG code (renamed variables) 250 | 251 | Documentation: 252 | ^^^^^^^^^^^^^^ 253 | 254 | - doc fix for the hyperparameter tuning command in the rl zoo 255 | - added an example on how to log additional variable with tensorboard and a callback 256 | 257 | 258 | 259 | Release 2.6.0 (2019-06-12) 260 | -------------------------- 261 | 262 | **Hindsight Experience Replay (HER) - Reloaded | get/load parameters** 263 | 264 | Breaking Changes: 265 | ^^^^^^^^^^^^^^^^^ 266 | 267 | - **breaking change** removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` (see fix below) 268 | 269 | **Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result, 270 | when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error. 271 | You can fix that using: 272 | 273 | .. code-block:: python 274 | 275 | import sys 276 | import pkg_resources 277 | 278 | import stable_baselines 279 | 280 | # Fix for breaking change for DDPG buffer in v2.6.0 281 | if pkg_resources.get_distribution("stable_baselines").version >= "2.6.0": 282 | sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.deepq.replay_buffer 283 | stable_baselines.deepq.replay_buffer.Memory = stable_baselines.deepq.replay_buffer.ReplayBuffer 284 | 285 | 286 | We recommend you to save again the model afterward, so the fix won't be needed the next time the trained agent is loaded. 287 | 288 | 289 | New Features: 290 | ^^^^^^^^^^^^^ 291 | 292 | - **revamped HER implementation**: clean re-implementation from scratch, now supports DQN, SAC and DDPG 293 | - add ``action_noise`` param for SAC, it helps exploration for problem with deceptive reward 294 | - The parameter ``filter_size`` of the function ``conv`` in A2C utils now supports passing a list/tuple of two integers (height and width), in order to have non-squared kernel matrix. (@yutingsz) 295 | - add ``random_exploration`` parameter for DDPG and SAC, it may be useful when using HER + DDPG/SAC. This hack was present in the original OpenAI Baselines DDPG + HER implementation. 296 | - added ``load_parameters`` and ``get_parameters`` to base RL class. With these methods, users are able to load and get parameters to/from existing model, without touching tensorflow. (@Miffyli) 297 | - added specific hyperparameter for PPO2 to clip the value function (``cliprange_vf``) 298 | - added ``VecCheckNan`` wrapper 299 | 300 | Bug Fixes: 301 | ^^^^^^^^^^ 302 | 303 | - bugfix for ``VecEnvWrapper.__getattr__`` which enables access to class attributes inherited from parent classes. 304 | - fixed path splitting in ``TensorboardWriter._get_latest_run_id()`` on Windows machines (@PatrickWalter214) 305 | - fixed a bug where initial learning rate is logged instead of its placeholder in ``A2C.setup_model`` (@sc420) 306 | - fixed a bug where number of timesteps is incorrectly updated and logged in ``A2C.learn`` and ``A2C._train_step`` (@sc420) 307 | - fixed ``num_timesteps`` (total_timesteps) variable in PPO2 that was wrongly computed. 
308 | - fixed a bug in DDPG/DQN/SAC, when there were the number of samples in the replay buffer was lesser than the batch size 309 | (thanks to @dwiel for spotting the bug) 310 | - **removed** ``a2c.utils.find_trainable_params`` please use ``common.tf_util.get_trainable_vars`` instead. 311 | ``find_trainable_params`` was returning all trainable variables, discarding the scope argument. 312 | This bug was causing the model to save duplicated parameters (for DDPG and SAC) 313 | but did not affect the performance. 314 | 315 | Deprecations: 316 | ^^^^^^^^^^^^^ 317 | 318 | - **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x) 319 | 320 | Others: 321 | ^^^^^^^ 322 | 323 | - **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli) 324 | - removed unused dependencies (tdqm, dill, progressbar2, seaborn, glob2, click) 325 | - removed ``get_available_gpus`` function which hadn't been used anywhere (@Pastafarianist) 326 | 327 | Documentation: 328 | ^^^^^^^^^^^^^^ 329 | 330 | - added guide for managing ``NaN`` and ``inf`` 331 | - updated ven_env doc 332 | - misc doc updates 333 | 334 | Release 2.5.1 (2019-05-04) 335 | -------------------------- 336 | 337 | **Bug fixes + improvements in the VecEnv** 338 | 339 | **Warning: breaking changes when using custom policies** 340 | 341 | - doc update (fix example of result plotter + improve doc) 342 | - fixed logger issues when stdout lacks ``read`` function 343 | - fixed a bug in ``common.dataset.Dataset`` where shuffling was not disabled properly (it affects only PPO1 with recurrent policies) 344 | - fixed output layer name for DDPG q function, used in pop-art normalization and l2 regularization of the critic 345 | - added support for multi env recording to ``generate_expert_traj`` (@XMaster96) 346 | - added support for LSTM model recording to ``generate_expert_traj`` (@XMaster96) 347 | - ``GAIL``: remove mandatory matplotlib dependency and refactor as subclass of ``TRPO`` (@kantneel and @AdamGleave) 348 | - added ``get_attr()``, ``env_method()`` and ``set_attr()`` methods for all VecEnv. 349 | Those methods now all accept ``indices`` keyword to select a subset of envs. 350 | ``set_attr`` now returns ``None`` rather than a list of ``None``. (@kantneel) 351 | - ``GAIL``: ``gail.dataset.ExpertDataset`` supports loading from memory rather than file, and 352 | ``gail.dataset.record_expert`` supports returning in-memory rather than saving to file. 353 | - added support in ``VecEnvWrapper`` for accessing attributes of arbitrarily deeply nested 354 | instances of ``VecEnvWrapper`` and ``VecEnv``. This is allowed as long as the attribute belongs 355 | to exactly one of the nested instances i.e. it must be unambiguous. (@kantneel) 356 | - fixed bug where result plotter would crash on very short runs (@Pastafarianist) 357 | - added option to not trim output of result plotter by number of timesteps (@Pastafarianist) 358 | - clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``, 359 | and most placeholders were made private: e.g. ``self.value_fn`` is now ``self._value_fn`` 360 | - support for custom stateful policies. 
361 | - fixed episode length recording in ``trpo_mpi.utils.traj_segment_generator`` (@GerardMaggiolino) 362 | 363 | 364 | Release 2.5.0 (2019-03-28) 365 | -------------------------- 366 | 367 | **Working GAIL, pretrain RL models and hotfix for A2C with continuous actions** 368 | 369 | - fixed various bugs in GAIL 370 | - added scripts to generate dataset for gail 371 | - added tests for GAIL + data for Pendulum-v0 372 | - removed unused ``utils`` file in DQN folder 373 | - fixed a bug in A2C where actions were cast to ``int32`` even in the continuous case 374 | - added addional logging to A2C when Monitor wrapper is used 375 | - changed logging for PPO2: do not display NaN when reward info is not present 376 | - change default value of A2C lr schedule 377 | - removed behavior cloning script 378 | - added ``pretrain`` method to base class, in order to use behavior cloning on all models 379 | - fixed ``close()`` method for DummyVecEnv. 380 | - added support for Dict spaces in DummyVecEnv and SubprocVecEnv. (@AdamGleave) 381 | - added support for arbitrary multiprocessing start methods and added a warning about SubprocVecEnv that are not thread-safe by default. (@AdamGleave) 382 | - added support for Discrete actions for GAIL 383 | - fixed deprecation warning for tf: replaces ``tf.to_float()`` by ``tf.cast()`` 384 | - fixed bug in saving and loading ddpg model when using normalization of obs or returns (@tperol) 385 | - changed DDPG default buffer size from 100 to 50000. 386 | - fixed a bug in ``ddpg.py`` in ``combined_stats`` for eval. Computed mean on ``eval_episode_rewards`` and ``eval_qs`` (@keshaviyengar) 387 | - fixed a bug in ``setup.py`` that would error on non-GPU systems without TensorFlow installed 388 | 389 | 390 | Release 2.4.1 (2019-02-11) 391 | -------------------------- 392 | 393 | **Bug fixes and improvements** 394 | 395 | - fixed computation of training metrics in TRPO and PPO1 396 | - added ``reset_num_timesteps`` keyword when calling train() to continue tensorboard learning curves 397 | - reduced the size taken by tensorboard logs (added a ``full_tensorboard_log`` to enable full logging, which was the previous behavior) 398 | - fixed image detection for tensorboard logging 399 | - fixed ACKTR for recurrent policies 400 | - fixed gym breaking changes 401 | - fixed custom policy examples in the doc for DQN and DDPG 402 | - remove gym spaces patch for equality functions 403 | - fixed tensorflow dependency: cpu version was installed overwritting tensorflow-gpu when present. 404 | - fixed a bug in ``traj_segment_generator`` (used in ppo1 and trpo) where ``new`` was not updated. (spotted by @junhyeokahn) 405 | 406 | 407 | Release 2.4.0 (2019-01-17) 408 | -------------------------- 409 | 410 | **Soft Actor-Critic (SAC) and policy kwargs** 411 | 412 | - added Soft Actor-Critic (SAC) model 413 | - fixed a bug in DQN where prioritized_replay_beta_iters param was not used 414 | - fixed DDPG that did not save target network parameters 415 | - fixed bug related to shape of true_reward (@abhiskk) 416 | - fixed example code in documentation of tf_util:Function (@JohannesAck) 417 | - added learning rate schedule for SAC 418 | - fixed action probability for continuous actions with actor-critic models 419 | - added optional parameter to action_probability for likelihood calculation of given action being taken. 
420 | - added more flexible custom LSTM policies 421 | - added auto entropy coefficient optimization for SAC 422 | - clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed) 423 | - added a mean to pass kwargs to policy when creating a model (+ save those kwargs) 424 | - fixed DQN examples in DQN folder 425 | - added possibility to pass activation function for DDPG, DQN and SAC 426 | 427 | 428 | Release 2.3.0 (2018-12-05) 429 | -------------------------- 430 | 431 | - added support for storing model in file like object. (thanks to @erniejunior) 432 | - fixed wrong image detection when using tensorboard logging with DQN 433 | - fixed bug in ppo2 when passing non callable lr after loading 434 | - fixed tensorboard logging in ppo2 when nminibatches=1 435 | - added early stoppping via callback return value (@erniejunior) 436 | - added more flexible custom mlp policies (@erniejunior) 437 | 438 | 439 | Release 2.2.1 (2018-11-18) 440 | -------------------------- 441 | 442 | - added VecVideoRecorder to record mp4 videos from environment. 443 | 444 | 445 | Release 2.2.0 (2018-11-07) 446 | -------------------------- 447 | 448 | - Hotfix for ppo2, the wrong placeholder was used for the value function 449 | 450 | 451 | Release 2.1.2 (2018-11-06) 452 | -------------------------- 453 | 454 | - added ``async_eigen_decomp`` parameter for ACKTR and set it to ``False`` by default (remove deprecation warnings) 455 | - added methods for calling env methods/setting attributes inside a VecEnv (thanks to @bjmuld) 456 | - updated gym minimum version 457 | 458 | 459 | Release 2.1.1 (2018-10-20) 460 | -------------------------- 461 | 462 | - fixed MpiAdam synchronization issue in PPO1 (thanks to @brendenpetersen) issue #50 463 | - fixed dependency issues (new mujoco-py requires a mujoco license + gym broke MultiDiscrete space shape) 464 | 465 | 466 | Release 2.1.0 (2018-10-2) 467 | ------------------------- 468 | 469 | .. warning:: 470 | 471 | This version contains breaking changes for DQN policies, please read the full details 472 | 473 | **Bug fixes + doc update** 474 | 475 | 476 | - added patch fix for equal function using `gym.spaces.MultiDiscrete` and `gym.spaces.MultiBinary` 477 | - fixes for DQN action_probability 478 | - re-added double DQN + refactored DQN policies **breaking changes** 479 | - replaced `async` with `async_eigen_decomp` in ACKTR/KFAC for python 3.7 compatibility 480 | - removed action clipping for prediction of continuous actions (see issue #36) 481 | - fixed NaN issue due to clipping the continuous action in the wrong place (issue #36) 482 | - documentation was updated (policy + DDPG example hyperparameters) 483 | 484 | Release 2.0.0 (2018-09-18) 485 | -------------------------- 486 | 487 | .. warning:: 488 | 489 | This version contains breaking changes, please read the full details 490 | 491 | **Tensorboard, refactoring and bug fixes** 492 | 493 | 494 | - Renamed DeepQ to DQN **breaking changes** 495 | - Renamed DeepQPolicy to DQNPolicy **breaking changes** 496 | - fixed DDPG behavior **breaking changes** 497 | - changed default policies for DDPG, so that DDPG now works correctly **breaking changes** 498 | - added more documentation (some modules from common). 
499 | - added doc about using custom env 500 | - added Tensorboard support for A2C, ACER, ACKTR, DDPG, DeepQ, PPO1, PPO2 and TRPO 501 | - added episode reward to Tensorboard 502 | - added documentation for Tensorboard usage 503 | - added Identity for Box action space 504 | - fixed render function ignoring parameters when using wrapped environments 505 | - fixed PPO1 and TRPO done values for recurrent policies 506 | - fixed image normalization not occurring when using images 507 | - updated VecEnv objects for the new Gym version 508 | - added test for DDPG 509 | - refactored DQN policies 510 | - added registry for policies, can be passed as string to the agent 511 | - added documentation for custom policies + policy registration 512 | - fixed numpy warning when using DDPG Memory 513 | - fixed DummyVecEnv not copying the observation array when stepping and resetting 514 | - added pre-built docker images + installation instructions 515 | - added ``deterministic`` argument in the predict function 516 | - added assert in PPO2 for recurrent policies 517 | - fixed predict function to handle both vectorized and unwrapped environment 518 | - added input check to the predict function 519 | - refactored ActorCritic models to reduce code duplication 520 | - refactored Off Policy models (to begin HER and replay_buffer refactoring) 521 | - added tests for auto vectorization detection 522 | - fixed render function, to handle positional arguments 523 | 524 | 525 | Release 1.0.7 (2018-08-29) 526 | -------------------------- 527 | 528 | **Bug fixes and documentation** 529 | 530 | - added html documentation using sphinx + integration with read the docs 531 | - cleaned up README + typos 532 | - fixed normalization for DQN with images 533 | - fixed DQN identity test 534 | 535 | 536 | Release 1.0.1 (2018-08-20) 537 | -------------------------- 538 | 539 | **Refactored Stable Baselines** 540 | 541 | - refactored A2C, ACER, ACTKR, DDPG, DeepQ, GAIL, TRPO, PPO1 and PPO2 under a single constant class 542 | - added callback to refactored algorithm training 543 | - added saving and loading to refactored algorithms 544 | - refactored ACER, DDPG, GAIL, PPO1 and TRPO to fit with A2C, PPO2 and ACKTR policies 545 | - added new policies for most algorithms (Mlp, MlpLstm, MlpLnLstm, Cnn, CnnLstm and CnnLnLstm) 546 | - added dynamic environment switching (so continual RL learning is now feasible) 547 | - added prediction from observation and action probability from observation for all the algorithms 548 | - fixed graphs issues, so models wont collide in names 549 | - fixed behavior_clone weight loading for GAIL 550 | - fixed Tensorflow using all the GPU VRAM 551 | - fixed models so that they are all compatible with vectorized environments 552 | - fixed ```set_global_seed``` to update ```gym.spaces```'s random seed 553 | - fixed PPO1 and TRPO performance issues when learning identity function 554 | - added new tests for loading, saving, continuous actions and learning the identity function 555 | - fixed DQN wrapping for atari 556 | - added saving and loading for Vecnormalize wrapper 557 | - added automatic detection of action space (for the policy network) 558 | - fixed ACER buffer with constant values assuming n_stack=4 559 | - fixed some RL algorithms not clipping the action to be in the action_space, when using ```gym.spaces.Box``` 560 | - refactored algorithms can take either a ```gym.Environment``` or a ```str``` ([if the environment name is registered](https://github.com/openai/gym/wiki/Environments)) 561 | - Hoftix in 
ACER (compared to v1.0.0) 562 | 563 | Future Work : 564 | 565 | - Finish refactoring HER 566 | - Refactor ACKTR and ACER for continuous implementation 567 | 568 | 569 | 570 | Release 0.1.6 (2018-07-27) 571 | -------------------------- 572 | 573 | **Deobfuscation of the code base + pep8 and fixes** 574 | 575 | - Fixed ``tf.session().__enter__()`` being used, rather than 576 | ``sess = tf.session()`` and passing the session to the objects 577 | - Fixed uneven scoping of TensorFlow Sessions throughout the code 578 | - Fixed rolling vecwrapper to handle observations that are not only 579 | grayscale images 580 | - Fixed deepq saving the environment when trying to save itself 581 | - Fixed 582 | ``ValueError: Cannot take the length of Shape with unknown rank.`` in 583 | ``acktr``, when running ``run_atari.py`` script. 584 | - Fixed calling baselines sequentially no longer creates graph 585 | conflicts 586 | - Fixed mean on empty array warning with deepq 587 | - Fixed kfac eigen decomposition not cast to float64, when the 588 | parameter use_float64 is set to True 589 | - Fixed Dataset data loader, not correctly resetting id position if 590 | shuffling is disabled 591 | - Fixed ``EOFError`` when reading from connection in the ``worker`` in 592 | ``subproc_vec_env.py`` 593 | - Fixed ``behavior_clone`` weight loading and saving for GAIL 594 | - Avoid taking root square of negative number in ``trpo_mpi.py`` 595 | - Removed some duplicated code (a2cpolicy, trpo_mpi) 596 | - Removed unused, undocumented and crashing function ``reset_task`` in 597 | ``subproc_vec_env.py`` 598 | - Reformated code to PEP8 style 599 | - Documented all the codebase 600 | - Added atari tests 601 | - Added logger tests 602 | 603 | Missing: tests for acktr continuous (+ HER, rely on mujoco...) 604 | 605 | Maintainers 606 | ----------- 607 | 608 | Stable-Baselines is currently maintained by `Ashley Hill`_ (aka @hill-a), `Antonin Raffin`_ (aka `@araffin`_), 609 | `Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). 610 | 611 | .. _Ashley Hill: https://github.com/hill-a 612 | .. _Antonin Raffin: https://araffin.github.io/ 613 | .. _Maximilian Ernestus: https://github.com/erniejunior 614 | .. _Adam Gleave: https://gleave.me/ 615 | .. _@araffin: https://github.com/araffin 616 | .. _@AdamGleave: https://github.com/adamgleave 617 | .. _Anssi Kanervisto: https://github.com/Miffyli 618 | .. _@Miffyli: https://github.com/Miffyli 619 | 620 | 621 | Contributors (since v2.0.0): 622 | ---------------------------- 623 | In random order... 624 | 625 | Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck 626 | @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol 627 | @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs 628 | @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket 629 | @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching 630 | @flodorner 631 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file in the wheel. 3 | license_file = LICENSE 4 | 5 | [tool:pytest] 6 | # Deterministic ordering for tests; useful for pytest-xdist. 
7 | env = 8 | PYTHONHASHSEED=0 9 | filterwarnings = 10 | 11 | 12 | [pytype] 13 | inputs = stable_baselines 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages 3 | from distutils.version import LooseVersion 4 | 5 | if sys.version_info.major != 3: 6 | print('This Python is only compatible with Python 3, but you are running ' 7 | 'Python {}. The installation will likely fail.'.format(sys.version_info.major)) 8 | 9 | 10 | long_description = """ 11 | 12 | # Stable Baselines TF2 [Experimental] 13 | 14 | """ 15 | 16 | setup(name='stable_baselines', 17 | packages=[package for package in find_packages() 18 | if package.startswith('stable_baselines')], 19 | package_data={ 20 | 'stable_baselines': ['py.typed'], 21 | }, 22 | install_requires=[ 23 | 'gym[atari,classic_control]>=0.10.9', 24 | 'scipy', 25 | 'joblib', 26 | 'cloudpickle>=0.5.5', 27 | 'opencv-python', 28 | 'numpy', 29 | 'pandas', 30 | 'matplotlib', 31 | 'tensorflow-probability>=0.8.0', 32 | 'tensorflow>=2.1.0' 33 | ], 34 | extras_require={ 35 | 'tests': [ 36 | 'pytest', 37 | 'pytest-cov', 38 | 'pytest-env', 39 | 'pytest-xdist', 40 | 'pytype', 41 | ], 42 | 'docs': [ 43 | 'sphinx', 44 | 'sphinx-autobuild', 45 | 'sphinx-rtd-theme' 46 | ] 47 | }, 48 | description='A fork of OpenAI Baselines, implementations of reinforcement learning algorithms.', 49 | author='Ashley Hill', 50 | url='https://github.com/hill-a/stable-baselines', 51 | author_email='github@hill-a.me', 52 | keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " 53 | "gym openai baselines toolbox python data-science", 54 | license="MIT", 55 | long_description=long_description, 56 | long_description_content_type='text/markdown', 57 | version="3.0.0a0", 58 | ) 59 | 60 | # python setup.py sdist 61 | # python setup.py bdist_wheel 62 | # twine upload --repository-url https://test.pypi.org/legacy/ dist/* 63 | # twine upload dist/* 64 | -------------------------------------------------------------------------------- /stable_baselines/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.td3 import TD3 2 | from stable_baselines.ppo import PPO 3 | 4 | __version__ = "3.0.0a0" 5 | -------------------------------------------------------------------------------- /stable_baselines/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/stable_baselines/common/__init__.py -------------------------------------------------------------------------------- /stable_baselines/common/base_class.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | from collections import deque 4 | import os 5 | import io 6 | import zipfile 7 | 8 | import gym 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from stable_baselines.common import logger 13 | from stable_baselines.common.policies import get_policy_from_name 14 | from stable_baselines.common.utils import set_random_seed, get_schedule_fn 15 | from stable_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, sync_envs_normalization 16 | from stable_baselines.common.monitor import Monitor 17 | from 
stable_baselines.common.evaluation import evaluate_policy 18 | from stable_baselines.common.save_util import data_to_json, json_to_data 19 | 20 | 21 | class BaseRLModel(ABC): 22 | """ 23 | The base RL model 24 | 25 | :param policy: (BasePolicy) Policy object 26 | :param env: (Gym environment) The environment to learn from 27 | (if registered in Gym, can be str. Can be None for loading trained models) 28 | :param policy_base: (BasePolicy) the base policy used by this method 29 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation 30 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug 31 | :param support_multi_env: (bool) Whether the algorithm supports training 32 | with multiple environments (as in A2C) 33 | :param create_eval_env: (bool) Whether to create a second environment that will be 34 | used for evaluating the agent periodically. (Only available when passing string for the environment) 35 | :param monitor_wrapper: (bool) When creating an environment, whether to wrap it 36 | or not in a Monitor wrapper. 37 | :param seed: (int) Seed for the pseudo random generators 38 | """ 39 | def __init__(self, policy, env, policy_base, policy_kwargs=None, 40 | verbose=0, device='auto', support_multi_env=False, 41 | create_eval_env=False, monitor_wrapper=True, seed=None): 42 | if isinstance(policy, str) and policy_base is not None: 43 | self.policy_class = get_policy_from_name(policy_base, policy) 44 | else: 45 | self.policy_class = policy 46 | 47 | self.env = env 48 | # get VecNormalize object if needed 49 | self._vec_normalize_env = unwrap_vec_normalize(env) 50 | self.verbose = verbose 51 | self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs 52 | self.observation_space = None 53 | self.action_space = None 54 | self.n_envs = None 55 | self.num_timesteps = 0 56 | self.eval_env = None 57 | self.replay_buffer = None 58 | self.seed = seed 59 | self.action_noise = None 60 | 61 | # Track the training progress (from 1 to 0) 62 | # this is used to update the learning rate 63 | self._current_progress = 1 64 | 65 | # Create and wrap the env if needed 66 | if env is not None: 67 | if isinstance(env, str): 68 | if create_eval_env: 69 | eval_env = gym.make(env) 70 | if monitor_wrapper: 71 | eval_env = Monitor(eval_env, filename=None) 72 | self.eval_env = DummyVecEnv([lambda: eval_env]) 73 | if self.verbose >= 1: 74 | print("Creating environment from the given name, wrapped in a DummyVecEnv.") 75 | 76 | env = gym.make(env) 77 | if monitor_wrapper: 78 | env = Monitor(env, filename=None) 79 | env = DummyVecEnv([lambda: env]) 80 | 81 | self.observation_space = env.observation_space 82 | self.action_space = env.action_space 83 | if not isinstance(env, VecEnv): 84 | if self.verbose >= 1: 85 | print("Wrapping the env in a DummyVecEnv.") 86 | env = DummyVecEnv([lambda: env]) 87 | self.n_envs = env.num_envs 88 | self.env = env 89 | 90 | if not support_multi_env and self.n_envs > 1: 91 | raise ValueError("Error: the model does not support multiple envs requires a single vectorized" 92 | " environment.") 93 | 94 | def _get_eval_env(self, eval_env): 95 | """ 96 | Return the environment that will be used for evaluation. 
97 | 98 | :param eval_env: (gym.Env or VecEnv) 99 | :return: (VecEnv) 100 | """ 101 | if eval_env is None: 102 | eval_env = self.eval_env 103 | 104 | if eval_env is not None: 105 | if not isinstance(eval_env, VecEnv): 106 | eval_env = DummyVecEnv([lambda: eval_env]) 107 | assert eval_env.num_envs == 1 108 | return eval_env 109 | 110 | def scale_action(self, action): 111 | """ 112 | Rescale the action from [low, high] to [-1, 1] 113 | (no need for symmetric action space) 114 | 115 | :param action: (np.ndarray) 116 | :return: (np.ndarray) 117 | """ 118 | low, high = self.action_space.low, self.action_space.high 119 | return 2.0 * ((action - low) / (high - low)) - 1.0 120 | 121 | def unscale_action(self, scaled_action): 122 | """ 123 | Rescale the action from [-1, 1] to [low, high] 124 | (no need for symmetric action space) 125 | 126 | :param scaled_action: (np.ndarray) 127 | :return: (np.ndarray) 128 | """ 129 | low, high = self.action_space.low, self.action_space.high 130 | return low + (0.5 * (scaled_action + 1.0) * (high - low)) 131 | 132 | def _setup_learning_rate(self): 133 | """Transform to callable if needed.""" 134 | self.learning_rate = get_schedule_fn(self.learning_rate) 135 | 136 | def _update_current_progress(self, num_timesteps, total_timesteps): 137 | """ 138 | Compute current progress (from 1 to 0) 139 | 140 | :param num_timesteps: (int) current number of timesteps 141 | :param total_timesteps: (int) 142 | """ 143 | self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) 144 | 145 | def _update_learning_rate(self, optimizers): 146 | """ 147 | Update the optimizers learning rate using the current learning rate schedule 148 | and the current progress (from 1 to 0). 149 | 150 | :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer 151 | or a list of optimizer. 152 | """ 153 | # Log the current learning rate 154 | logger.logkv("learning_rate", self.learning_rate(self._current_progress)) 155 | 156 | # if not isinstance(optimizers, list): 157 | # optimizers = [optimizers] 158 | # for optimizer in optimizers: 159 | # update_learning_rate(optimizer, self.learning_rate(self._current_progress)) 160 | 161 | @staticmethod 162 | def safe_mean(arr): 163 | """ 164 | Compute the mean of an array if there is at least one element. 165 | For empty array, return nan. It is used for logging only. 166 | 167 | :param arr: (np.ndarray) 168 | :return: (float) 169 | """ 170 | return np.nan if len(arr) == 0 else np.mean(arr) 171 | 172 | def get_env(self): 173 | """ 174 | returns the current environment (can be None if not defined) 175 | 176 | :return: (gym.Env) The current environment 177 | """ 178 | return self.env 179 | 180 | def set_env(self, env): 181 | """ 182 | :param env: (gym.Env) The environment for learning a policy 183 | """ 184 | raise NotImplementedError() 185 | 186 | @abstractmethod 187 | def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run", 188 | eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True): 189 | """ 190 | Return a trained model. 191 | 192 | :param total_timesteps: (int) The total number of samples to train on 193 | :param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm. 194 | It takes the local and global variables. If it returns False, training is aborted. 195 | :param log_interval: (int) The number of timesteps before logging. 
196 | :param tb_log_name: (str) the name of the run for tensorboard log 197 | :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) 198 | :param eval_env: (gym.Env) Environment that will be used to evaluate the agent 199 | :param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little) 200 | :param n_eval_episodes: (int) Number of episode to evaluate the agent 201 | :return: (BaseRLModel) the trained model 202 | """ 203 | pass 204 | 205 | @abstractmethod 206 | def predict(self, observation, state=None, mask=None, deterministic=False): 207 | """ 208 | Get the model's action from an observation 209 | 210 | :param observation: (np.ndarray) the input observation 211 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies) 212 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) 213 | :param deterministic: (bool) Whether or not to return deterministic actions. 214 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) 215 | """ 216 | pass 217 | 218 | def set_random_seed(self, seed=None): 219 | """ 220 | Set the seed of the pseudo-random generators 221 | (python, numpy, pytorch, gym, action_space) 222 | 223 | :param seed: (int) 224 | """ 225 | if seed is None: 226 | return 227 | set_random_seed(seed) 228 | self.action_space.seed(seed) 229 | if self.env is not None: 230 | self.env.seed(seed) 231 | if self.eval_env is not None: 232 | self.eval_env.seed(seed) 233 | 234 | def _setup_learn(self, eval_env): 235 | """ 236 | Initialize different variables needed for training. 237 | 238 | :param eval_env: (gym.Env or VecEnv) 239 | :return: (int, int, [float], np.ndarray, VecEnv) 240 | """ 241 | self.start_time = time.time() 242 | self.ep_info_buffer = deque(maxlen=100) 243 | 244 | if self.action_noise is not None: 245 | self.action_noise.reset() 246 | 247 | timesteps_since_eval, episode_num = 0, 0 248 | evaluations = [] 249 | 250 | if eval_env is not None and self.seed is not None: 251 | eval_env.seed(self.seed) 252 | 253 | eval_env = self._get_eval_env(eval_env) 254 | obs = self.env.reset() 255 | return timesteps_since_eval, episode_num, evaluations, obs, eval_env 256 | 257 | def _update_info_buffer(self, infos): 258 | """ 259 | Retrieve reward and episode length and update the buffer 260 | if using Monitor wrapper. 261 | 262 | :param infos: ([dict]) 263 | """ 264 | for info in infos: 265 | maybe_ep_info = info.get('episode') 266 | if maybe_ep_info is not None: 267 | self.ep_info_buffer.extend([maybe_ep_info]) 268 | 269 | def _eval_policy(self, eval_freq, eval_env, n_eval_episodes, 270 | timesteps_since_eval, deterministic=True): 271 | """ 272 | Evaluate the current policy on a test environment. 
273 | 274 | :param eval_env: (gym.Env) Environment that will be used to evaluate the agent 275 | :param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little) 276 | :param n_eval_episodes: (int) Number of episode to evaluate the agent 277 | :parma timesteps_since_eval: (int) Number of timesteps since last evaluation 278 | :param deterministic: (bool) Whether to use deterministic or stochastic actions 279 | :return: (int) Number of timesteps since last evaluation 280 | """ 281 | if 0 < eval_freq <= timesteps_since_eval and eval_env is not None: 282 | timesteps_since_eval %= eval_freq 283 | # Synchronise the normalization stats if needed 284 | sync_envs_normalization(self.env, eval_env) 285 | mean_reward, std_reward = evaluate_policy(self, eval_env, n_eval_episodes, deterministic=deterministic) 286 | if self.verbose > 0: 287 | print("Eval num_timesteps={}, " 288 | "episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps, mean_reward, std_reward)) 289 | print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - self.start_time))) 290 | return timesteps_since_eval 291 | -------------------------------------------------------------------------------- /stable_baselines/common/buffers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BaseBuffer(object): 5 | """ 6 | Base class that represent a buffer (rollout or replay) 7 | 8 | :param buffer_size: (int) Max number of element in the buffer 9 | :param obs_dim: (int) Dimension of the observation 10 | :param action_dim: (int) Dimension of the action space 11 | :param n_envs: (int) Number of parallel environments 12 | """ 13 | def __init__(self, buffer_size, obs_dim, action_dim, n_envs=1): 14 | super(BaseBuffer, self).__init__() 15 | self.buffer_size = buffer_size 16 | self.obs_dim = obs_dim 17 | self.action_dim = action_dim 18 | self.pos = 0 19 | self.full = False 20 | self.n_envs = n_envs 21 | 22 | def size(self): 23 | """ 24 | :return: (int) The current size of the buffer 25 | """ 26 | if self.full: 27 | return self.buffer_size 28 | return self.pos 29 | 30 | def add(self, *args, **kwargs): 31 | """ 32 | Add elements to the buffer. 33 | """ 34 | raise NotImplementedError() 35 | 36 | def reset(self): 37 | """ 38 | Reset the buffer. 39 | """ 40 | self.pos = 0 41 | self.full = False 42 | 43 | def sample(self, batch_size): 44 | """ 45 | :param batch_size: (int) Number of element to sample 46 | """ 47 | upper_bound = self.buffer_size if self.full else self.pos 48 | batch_inds = np.random.randint(0, upper_bound, size=batch_size) 49 | return self._get_samples(batch_inds) 50 | 51 | def _get_samples(self, batch_inds): 52 | """ 53 | :param batch_inds: (np.ndarray) 54 | :return: ([np.ndarray]) 55 | """ 56 | raise NotImplementedError() 57 | 58 | 59 | class ReplayBuffer(BaseBuffer): 60 | """ 61 | Replay buffer used in off-policy algorithms like SAC/TD3. 
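# --- Editor's note: illustrative sketch, not part of buffers.py ---
# The pos/full bookkeeping above (completed by the wrap-around in
# ReplayBuffer.add below) implements a circular buffer: once `pos` wraps,
# `full` stays True and size()/sample() cover the whole capacity.
# A tiny standalone rehearsal of that logic:
capacity, pos, full = 3, 0, False
for _ in range(5):            # insert 5 items into a buffer of capacity 3
    pos += 1
    if pos == capacity:
        full, pos = True, 0
size = capacity if full else pos
assert (full, pos, size) == (True, 2, 3)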
62 | 63 | :param buffer_size: (int) Max number of element in the buffer 64 | :param obs_dim: (int) Dimension of the observation 65 | :param action_dim: (int) Dimension of the action space 66 | :param n_envs: (int) Number of parallel environments 67 | """ 68 | def __init__(self, buffer_size, obs_dim, action_dim, n_envs=1): 69 | super(ReplayBuffer, self).__init__(buffer_size, obs_dim, action_dim, n_envs=n_envs) 70 | 71 | assert n_envs == 1 72 | self.observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32) 73 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) 74 | self.next_observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32) 75 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 76 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 77 | 78 | def add(self, obs, next_obs, action, reward, done): 79 | # Copy to avoid modification by reference 80 | self.observations[self.pos] = np.array(obs).copy() 81 | self.next_observations[self.pos] = np.array(next_obs).copy() 82 | self.actions[self.pos] = np.array(action).copy() 83 | self.rewards[self.pos] = np.array(reward).copy() 84 | self.dones[self.pos] = np.array(done).copy() 85 | 86 | self.pos += 1 87 | if self.pos == self.buffer_size: 88 | self.full = True 89 | self.pos = 0 90 | 91 | def _get_samples(self, batch_inds): 92 | return (self.observations[batch_inds, 0, :], 93 | self.actions[batch_inds, 0, :], 94 | self.next_observations[batch_inds, 0, :], 95 | self.dones[batch_inds], 96 | self.rewards[batch_inds]) 97 | 98 | 99 | class RolloutBuffer(BaseBuffer): 100 | """ 101 | Rollout buffer used in on-policy algorithms like A2C/PPO. 102 | 103 | :param buffer_size: (int) Max number of element in the buffer 104 | :param obs_dim: (int) Dimension of the observation 105 | :param action_dim: (int) Dimension of the action space 106 | :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator 107 | Equivalent to classic advantage when set to 1. 
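# --- Editor's note: illustrative usage sketch, not part of buffers.py ---
# Minimal exercise of the ReplayBuffer defined above: a single environment with
# flat observations/actions, then a random batch as used by off-policy updates.
# The dimensions and the dummy transition values are made up for the example.
import numpy as np

buffer = ReplayBuffer(buffer_size=1000, obs_dim=3, action_dim=1)
obs = np.zeros(3, dtype=np.float32)
for _ in range(50):
    next_obs = np.random.randn(3).astype(np.float32)
    action = np.random.uniform(-1.0, 1.0, size=1).astype(np.float32)
    buffer.add(obs, next_obs, action, reward=0.0, done=False)
    obs = next_obs

obs_b, act_b, next_obs_b, dones_b, rewards_b = buffer.sample(batch_size=32)
assert obs_b.shape == (32, 3) and act_b.shape == (32, 1)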
108 | :param gamma: (float) Discount factor 109 | :param n_envs: (int) Number of parallel environments 110 | """ 111 | def __init__(self, buffer_size, obs_dim, action_dim, 112 | gae_lambda=1, gamma=0.99, n_envs=1): 113 | super(RolloutBuffer, self).__init__(buffer_size, obs_dim, action_dim, n_envs=n_envs) 114 | self.gae_lambda = gae_lambda 115 | self.gamma = gamma 116 | self.observations, self.actions, self.rewards, self.advantages = None, None, None, None 117 | self.returns, self.dones, self.values, self.log_probs = None, None, None, None 118 | self.generator_ready = False 119 | self.reset() 120 | 121 | def reset(self): 122 | self.observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32) 123 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) 124 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 125 | self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 126 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 127 | self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 128 | self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 129 | self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) 130 | self.generator_ready = False 131 | super(RolloutBuffer, self).reset() 132 | 133 | def compute_returns_and_advantage(self, last_value, dones=False, use_gae=True): 134 | """ 135 | Post-processing step: compute the returns (sum of discounted rewards) 136 | and advantage (A(s) = R - V(S)). 137 | 138 | :param last_value: (tf.Tensor) 139 | :param dones: ([bool]) 140 | :param use_gae: (bool) Whether to use Generalized Advantage Estimation 141 | or normal advantage for advantage computation. 142 | """ 143 | if use_gae: 144 | last_gae_lam = 0 145 | for step in reversed(range(self.buffer_size)): 146 | if step == self.buffer_size - 1: 147 | next_non_terminal = np.array(1.0 - dones) 148 | next_value = last_value.numpy().flatten() 149 | else: 150 | next_non_terminal = 1.0 - self.dones[step + 1] 151 | next_value = self.values[step + 1] 152 | delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step] 153 | last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam 154 | self.advantages[step] = last_gae_lam 155 | self.returns = self.advantages + self.values 156 | else: 157 | # Discounted return with value bootstrap 158 | # Note: this is equivalent to GAE computation 159 | # with gae_lambda = 1.0 160 | last_return = 0.0 161 | for step in reversed(range(self.buffer_size)): 162 | if step == self.buffer_size - 1: 163 | next_non_terminal = np.array(1.0 - dones) 164 | next_value = last_value.numpy().flatten() 165 | last_return = self.rewards[step] + next_non_terminal * next_value 166 | else: 167 | next_non_terminal = 1.0 - self.dones[step + 1] 168 | last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal 169 | self.returns[step] = last_return 170 | self.advantages = self.returns - self.values 171 | 172 | def add(self, obs, action, reward, done, value, log_prob): 173 | """ 174 | :param obs: (np.ndarray) Observation 175 | :param action: (np.ndarray) Action 176 | :param reward: (np.ndarray) 177 | :param done: (np.ndarray) End of episode signal. 178 | :param value: (np.Tensor) estimated value of the current state 179 | following the current policy. 
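# --- Editor's note: illustrative sketch, not part of buffers.py ---
# The GAE recursion used by compute_returns_and_advantage above, replayed on
# toy numbers (single environment, no episode end inside the rollout); all
# values are made up for the example.
import numpy as np

gamma, gae_lambda = 0.99, 0.95
rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
last_value, last_done = 0.5, 0.0          # bootstrap value and done flag after the rollout

advantages = np.zeros(3)
last_gae_lam = 0.0
for step in reversed(range(3)):
    next_value = last_value if step == 2 else values[step + 1]
    next_non_terminal = (1.0 - last_done) if step == 2 else 1.0   # no dones inside the rollout
    delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
    advantages[step] = last_gae_lam
returns = advantages + values             # same relation as in the buffer above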
180 | :param log_prob: (np.Tensor) log probability of the action 181 | following the current policy. 182 | """ 183 | if len(log_prob.shape) == 0: 184 | # Reshape 0-d tensor to avoid error 185 | log_prob = log_prob.reshape(-1, 1) 186 | 187 | self.observations[self.pos] = np.array(obs).copy() 188 | self.actions[self.pos] = np.array(action).copy() 189 | self.rewards[self.pos] = np.array(reward).copy() 190 | self.dones[self.pos] = np.array(done).copy() 191 | self.values[self.pos] = value.numpy().flatten().copy() 192 | self.log_probs[self.pos] = log_prob.numpy().copy() 193 | self.pos += 1 194 | if self.pos == self.buffer_size: 195 | self.full = True 196 | 197 | def get(self, batch_size=None): 198 | assert self.full 199 | indices = np.random.permutation(self.buffer_size * self.n_envs) 200 | # Prepare the data 201 | if not self.generator_ready: 202 | for tensor in ['observations', 'actions', 'values', 203 | 'log_probs', 'advantages', 'returns']: 204 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) 205 | self.generator_ready = True 206 | 207 | # Return everything, don't create minibatches 208 | if batch_size is None: 209 | batch_size = self.buffer_size * self.n_envs 210 | 211 | start_idx = 0 212 | while start_idx < self.buffer_size * self.n_envs: 213 | yield self._get_samples(indices[start_idx:start_idx + batch_size]) 214 | start_idx += batch_size 215 | 216 | def _get_samples(self, batch_inds): 217 | return (self.observations[batch_inds], 218 | self.actions[batch_inds], 219 | self.values[batch_inds].flatten(), 220 | self.log_probs[batch_inds].flatten(), 221 | self.advantages[batch_inds].flatten(), 222 | self.returns[batch_inds].flatten()) 223 | 224 | @staticmethod 225 | def swap_and_flatten(tensor): 226 | """ 227 | Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) 228 | to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) 229 | to [n_steps * n_envs, ...] (which maintain the order) 230 | 231 | :param tensor: (np.ndarray) 232 | :return: (np.ndarray) 233 | """ 234 | shape = tensor.shape 235 | if len(shape) < 3: 236 | shape = shape + (1,) 237 | return tensor.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) 238 | -------------------------------------------------------------------------------- /stable_baselines/common/distributions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.layers as layers 3 | from tensorflow.keras.models import Sequential 4 | import tensorflow_probability as tfp 5 | from gym import spaces 6 | 7 | 8 | class Distribution(object): 9 | def __init__(self): 10 | super(Distribution, self).__init__() 11 | 12 | def log_prob(self, x): 13 | """ 14 | returns the log likelihood 15 | 16 | :param x: (object) the taken action 17 | :return: (tf.Tensor) The log likelihood of the distribution 18 | """ 19 | raise NotImplementedError 20 | 21 | def entropy(self): 22 | """ 23 | Returns shannon's entropy of the probability 24 | 25 | :return: (tf.Tensor) the entropy 26 | """ 27 | raise NotImplementedError 28 | 29 | def sample(self): 30 | """ 31 | returns a sample from the probabilty distribution 32 | 33 | :return: (tf.Tensor) the stochastic action 34 | """ 35 | raise NotImplementedError 36 | 37 | 38 | class DiagGaussianDistribution(Distribution): 39 | """ 40 | Gaussian distribution with diagonal covariance matrix, 41 | for continuous actions. 
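# --- Editor's note: illustrative sketch, not part of distributions.py ---
# What RolloutBuffer.swap_and_flatten (end of buffers.py above) does to shapes:
# [n_steps, n_envs, ...] becomes [n_steps * n_envs, ...], keeping each env's
# steps contiguous. Toy dimensions chosen for the example.
import numpy as np

n_steps, n_envs, feat = 4, 2, 3
tensor = np.arange(n_steps * n_envs * feat).reshape(n_steps, n_envs, feat)
flat = tensor.swapaxes(0, 1).reshape(n_steps * n_envs, feat)
assert flat.shape == (8, 3)
assert np.array_equal(flat[:4], tensor[:, 0, :])   # env 0's rollout comes first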
42 | 43 | :param action_dim: (int) Number of continuous actions 44 | """ 45 | 46 | def __init__(self, action_dim): 47 | super(DiagGaussianDistribution, self).__init__() 48 | self.distribution = None 49 | self.action_dim = action_dim 50 | self.mean_actions = None 51 | self.log_std = None 52 | 53 | def proba_distribution_net(self, latent_dim, log_std_init=0.0): 54 | """ 55 | Create the layers and parameter that represent the distribution: 56 | one output will be the mean of the gaussian, the other parameter will be the 57 | standard deviation (log std in fact to allow negative values) 58 | 59 | :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer) 60 | :param log_std_init: (float) Initial value for the log standard deviation 61 | :return: (tf.keras.models.Sequential, tf.Variable) 62 | """ 63 | mean_actions = Sequential(layers.Dense(self.action_dim, input_shape=(latent_dim,), activation=None)) 64 | log_std = tf.Variable(tf.ones(self.action_dim) * log_std_init) 65 | return mean_actions, log_std 66 | 67 | def proba_distribution(self, mean_actions, log_std, deterministic=False): 68 | """ 69 | Create and sample for the distribution given its parameters (mean, std) 70 | 71 | :param mean_actions: (tf.Tensor) 72 | :param log_std: (tf.Tensor) 73 | :param deterministic: (bool) 74 | :return: (tf.Tensor) 75 | """ 76 | action_std = tf.ones_like(mean_actions) * tf.exp(log_std) 77 | self.distribution = tfp.distributions.Normal(mean_actions, action_std) 78 | if deterministic: 79 | action = self.mode() 80 | else: 81 | action = self.sample() 82 | return action, self 83 | 84 | def mode(self): 85 | return self.distribution.mode() 86 | 87 | def sample(self): 88 | return self.distribution.sample() 89 | 90 | def entropy(self): 91 | return self.distribution.entropy() 92 | 93 | def log_prob_from_params(self, mean_actions, log_std): 94 | """ 95 | Compute the log probabilty of taking an action 96 | given the distribution parameters. 97 | 98 | :param mean_actions: (tf.Tensor) 99 | :param log_std: (tf.Tensor) 100 | :return: (tf.Tensor, tf.Tensor) 101 | """ 102 | action, _ = self.proba_distribution(mean_actions, log_std) 103 | log_prob = self.log_prob(action) 104 | return action, log_prob 105 | 106 | def log_prob(self, action): 107 | """ 108 | Get the log probabilty of an action given a distribution. 109 | Note that you must call `proba_distribution()` method 110 | before. 111 | 112 | :param action: (tf.Tensor) 113 | :return: (tf.Tensor) 114 | """ 115 | log_prob = self.distribution.log_prob(action) 116 | if len(log_prob.shape) > 1: 117 | log_prob = tf.reduce_sum(log_prob, axis=1) 118 | else: 119 | log_prob = tf.reduce_sum(log_prob) 120 | return log_prob 121 | 122 | 123 | class CategoricalDistribution(Distribution): 124 | """ 125 | Categorical distribution for discrete actions. 126 | 127 | :param action_dim: (int) Number of discrete actions 128 | """ 129 | def __init__(self, action_dim): 130 | super(CategoricalDistribution, self).__init__() 131 | self.distribution = None 132 | self.action_dim = action_dim 133 | 134 | def proba_distribution_net(self, latent_dim): 135 | """ 136 | Create the layer that represents the distribution: 137 | it will be the logits of the Categorical distribution. 138 | You can then get probabilties using a softmax. 
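# --- Editor's note: illustrative sketch, not part of distributions.py ---
# What DiagGaussianDistribution.log_prob above computes: per-dimension Normal
# log-densities summed over the action dimensions, checked here against the
# closed form for a standard normal evaluated at 0.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

mean_actions = tf.zeros((1, 2))
log_std = tf.zeros(2)                                   # std = 1 in both dimensions
action_std = tf.ones_like(mean_actions) * tf.exp(log_std)
dist = tfp.distributions.Normal(mean_actions, action_std)
log_prob = tf.reduce_sum(dist.log_prob(tf.zeros((1, 2))), axis=1)
assert np.isclose(log_prob.numpy()[0], -np.log(2.0 * np.pi))   # 2 * (-0.5 * log(2*pi))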
139 | 140 | :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer) 141 | :return: (tf.keras.models.Sequential) 142 | """ 143 | action_logits = layers.Dense(self.action_dim, input_shape=(latent_dim,), activation=None) 144 | return Sequential([action_logits]) 145 | 146 | def proba_distribution(self, action_logits, deterministic=False): 147 | self.distribution = tfp.distributions.Categorical(logits=action_logits) 148 | if deterministic: 149 | action = self.mode() 150 | else: 151 | action = self.sample() 152 | return action, self 153 | 154 | def mode(self): 155 | return self.distribution.mode() 156 | 157 | def sample(self): 158 | return self.distribution.sample() 159 | 160 | def entropy(self): 161 | return self.distribution.entropy() 162 | 163 | def log_prob_from_params(self, action_logits): 164 | action, _ = self.proba_distribution(action_logits) 165 | log_prob = self.log_prob(action) 166 | return action, log_prob 167 | 168 | def log_prob(self, action): 169 | log_prob = self.distribution.log_prob(action) 170 | return log_prob 171 | 172 | 173 | def make_proba_distribution(action_space, dist_kwargs=None): 174 | """ 175 | Return an instance of Distribution for the correct type of action space 176 | 177 | :param action_space: (Gym Space) the input action space 178 | :param dist_kwargs: (dict) Keyword arguments to pass to the probabilty distribution 179 | :return: (Distribution) the approriate Distribution object 180 | """ 181 | if dist_kwargs is None: 182 | dist_kwargs = {} 183 | 184 | if isinstance(action_space, spaces.Box): 185 | assert len(action_space.shape) == 1, "Error: the action space must be a vector" 186 | return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs) 187 | elif isinstance(action_space, spaces.Discrete): 188 | return CategoricalDistribution(action_space.n, **dist_kwargs) 189 | # elif isinstance(action_space, spaces.MultiDiscrete): 190 | # return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) 191 | # elif isinstance(action_space, spaces.MultiBinary): 192 | # return BernoulliDistribution(action_space.n, **dist_kwargs) 193 | else: 194 | raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}." 195 | .format(type(action_space)) + 196 | " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.") 197 | -------------------------------------------------------------------------------- /stable_baselines/common/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copied from stable_baselines 2 | import numpy as np 3 | 4 | from stable_baselines.common.vec_env import VecEnv 5 | 6 | 7 | def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, 8 | render=False, callback=None, reward_threshold=None, 9 | return_episode_rewards=False): 10 | """ 11 | Runs policy for `n_eval_episodes` episodes and returns average reward. 12 | This is made to work only with one env. 13 | 14 | :param model: (BaseRLModel) The RL agent you want to evaluate. 15 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv` 16 | this must contain only one environment. 17 | :param n_eval_episodes: (int) Number of episode to evaluate the agent 18 | :param deterministic: (bool) Whether to use deterministic or stochastic actions 19 | :param render: (bool) Whether to render the environment or not 20 | :param callback: (callable) callback function to do additional checks, 21 | called after each step. 
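# --- Editor's note: illustrative sketch, not part of evaluation.py ---
# The dispatch performed by make_proba_distribution (end of distributions.py
# above): Box spaces get a diagonal Gaussian, Discrete spaces a Categorical.
import numpy as np
from gym import spaces

from stable_baselines.common.distributions import (
    make_proba_distribution, DiagGaussianDistribution, CategoricalDistribution)

box_dist = make_proba_distribution(spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32))
disc_dist = make_proba_distribution(spaces.Discrete(4))
assert isinstance(box_dist, DiagGaussianDistribution) and box_dist.action_dim == 3
assert isinstance(disc_dist, CategoricalDistribution) and disc_dist.action_dim == 4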
22 | :param reward_threshold: (float) Minimum expected reward per episode, 23 | this will raise an error if the performance is not met 24 | :param return_episode_rewards: (bool) If True, a list of reward per episode 25 | will be returned instead of the mean. 26 | :return: (float, float) Mean reward per episode, std of reward per episode 27 | returns ([float], int) when `return_episode_rewards` is True 28 | """ 29 | if isinstance(env, VecEnv): 30 | assert env.num_envs == 1, "You must pass only one environment when using this function" 31 | 32 | episode_rewards, n_steps = [], 0 33 | for _ in range(n_eval_episodes): 34 | obs = env.reset() 35 | done = False 36 | episode_reward = 0.0 37 | while not done: 38 | action = model.predict(obs, deterministic=deterministic) 39 | obs, reward, done, _info = env.step(action) 40 | episode_reward += reward 41 | if callback is not None: 42 | callback(locals(), globals()) 43 | n_steps += 1 44 | if render: 45 | env.render() 46 | episode_rewards.append(episode_reward) 47 | mean_reward = np.mean(episode_rewards) 48 | std_reward = np.std(episode_rewards) 49 | if reward_threshold is not None: 50 | assert mean_reward > reward_threshold, 'Mean reward below threshold: '\ 51 | '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold) 52 | if return_episode_rewards: 53 | return episode_rewards, n_steps 54 | return mean_reward, std_reward 55 | -------------------------------------------------------------------------------- /stable_baselines/common/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import datetime 6 | import tempfile 7 | import warnings 8 | from collections import defaultdict 9 | 10 | DEBUG = 10 11 | INFO = 20 12 | WARN = 30 13 | ERROR = 40 14 | DISABLED = 50 15 | 16 | 17 | class KVWriter(object): 18 | """ 19 | Key Value writer 20 | """ 21 | def writekvs(self, kvs): 22 | """ 23 | write a dictionary to file 24 | 25 | :param kvs: (dict) 26 | """ 27 | raise NotImplementedError 28 | 29 | 30 | class SeqWriter(object): 31 | """ 32 | sequence writer 33 | """ 34 | def writeseq(self, seq): 35 | """ 36 | write an array to file 37 | 38 | :param seq: (list) 39 | """ 40 | raise NotImplementedError 41 | 42 | 43 | class HumanOutputFormat(KVWriter, SeqWriter): 44 | def __init__(self, filename_or_file): 45 | """ 46 | log to a file, in a human readable format 47 | 48 | :param filename_or_file: (str or File) the file to write the log to 49 | """ 50 | if isinstance(filename_or_file, str): 51 | self.file = open(filename_or_file, 'wt') 52 | self.own_file = True 53 | else: 54 | assert hasattr(filename_or_file, 'write'), 'Expected file or str, got {}'.format(filename_or_file) 55 | self.file = filename_or_file 56 | self.own_file = False 57 | 58 | def writekvs(self, kvs): 59 | # Create strings for printing 60 | key2str = {} 61 | for (key, val) in sorted(kvs.items()): 62 | if isinstance(val, float): 63 | valstr = '%-8.3g' % (val,) 64 | else: 65 | valstr = str(val) 66 | key2str[self._truncate(key)] = self._truncate(valstr) 67 | 68 | # Find max widths 69 | if len(key2str) == 0: 70 | warnings.warn('Tried to write empty key-value dict') 71 | return 72 | else: 73 | keywidth = max(map(len, key2str.keys())) 74 | valwidth = max(map(len, key2str.values())) 75 | 76 | # Write out the data 77 | dashes = '-' * (keywidth + valwidth + 7) 78 | lines = [dashes] 79 | for (key, val) in sorted(key2str.items()): 80 | lines.append('| %s%s | %s%s |' % ( 81 | key, 82 | ' ' * (keywidth - len(key)), 83 | 
val, 84 | ' ' * (valwidth - len(val)), 85 | )) 86 | lines.append(dashes) 87 | self.file.write('\n'.join(lines) + '\n') 88 | 89 | # Flush the output to the file 90 | self.file.flush() 91 | 92 | @classmethod 93 | def _truncate(cls, string): 94 | return string[:20] + '...' if len(string) > 23 else string 95 | 96 | def writeseq(self, seq): 97 | seq = list(seq) 98 | for (i, elem) in enumerate(seq): 99 | self.file.write(elem) 100 | if i < len(seq) - 1: # add space unless this is the last one 101 | self.file.write(' ') 102 | self.file.write('\n') 103 | self.file.flush() 104 | 105 | def close(self): 106 | """ 107 | closes the file 108 | """ 109 | if self.own_file: 110 | self.file.close() 111 | 112 | 113 | class JSONOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | """ 116 | log to a file, in the JSON format 117 | 118 | :param filename: (str) the file to write the log to 119 | """ 120 | self.file = open(filename, 'wt') 121 | 122 | def writekvs(self, kvs): 123 | for key, value in sorted(kvs.items()): 124 | if hasattr(value, 'dtype'): 125 | if value.shape == () or len(value) == 1: 126 | # if value is a dimensionless numpy array or of length 1, serialize as a float 127 | kvs[key] = float(value) 128 | else: 129 | # otherwise, a value is a numpy array, serialize as a list or nested lists 130 | kvs[key] = value.tolist() 131 | self.file.write(json.dumps(kvs) + '\n') 132 | self.file.flush() 133 | 134 | def close(self): 135 | """ 136 | closes the file 137 | """ 138 | self.file.close() 139 | 140 | 141 | class CSVOutputFormat(KVWriter): 142 | def __init__(self, filename): 143 | """ 144 | log to a file, in a CSV format 145 | 146 | :param filename: (str) the file to write the log to 147 | """ 148 | self.file = open(filename, 'w+t') 149 | self.keys = [] 150 | self.sep = ',' 151 | 152 | def writekvs(self, kvs): 153 | # Add our current row to the history 154 | extra_keys = kvs.keys() - self.keys 155 | if extra_keys: 156 | self.keys.extend(extra_keys) 157 | self.file.seek(0) 158 | lines = self.file.readlines() 159 | self.file.seek(0) 160 | for (i, key) in enumerate(self.keys): 161 | if i > 0: 162 | self.file.write(',') 163 | self.file.write(key) 164 | self.file.write('\n') 165 | for line in lines[1:]: 166 | self.file.write(line[:-1]) 167 | self.file.write(self.sep * len(extra_keys)) 168 | self.file.write('\n') 169 | for i, key in enumerate(self.keys): 170 | if i > 0: 171 | self.file.write(',') 172 | value = kvs.get(key) 173 | if value is not None: 174 | self.file.write(str(value)) 175 | self.file.write('\n') 176 | self.file.flush() 177 | 178 | def close(self): 179 | """ 180 | closes the file 181 | """ 182 | self.file.close() 183 | 184 | 185 | def summary_val(key, value): 186 | """ 187 | :param key: (str) 188 | :param value: (float) 189 | """ 190 | kwargs = {'tag': key, 'simple_value': float(value)} 191 | return tf.Summary.Value(**kwargs) 192 | 193 | 194 | def valid_float_value(value): 195 | """ 196 | Returns True if the value can be successfully cast into a float 197 | 198 | :param value: (Any) the value to check 199 | :return: (bool) 200 | """ 201 | try: 202 | float(value) 203 | return True 204 | except TypeError: 205 | return False 206 | 207 | 208 | def make_output_format(_format, ev_dir, log_suffix=''): 209 | """ 210 | return a logger for the requested format 211 | 212 | :param _format: (str) the requested format to log to ('stdout', 'log', 'json' or 'csv') 213 | :param ev_dir: (str) the logging directory 214 | :param log_suffix: (str) the suffix for the log file 215 | :return: (KVWrite) the 
logger 216 | """ 217 | os.makedirs(ev_dir, exist_ok=True) 218 | if _format == 'stdout': 219 | return HumanOutputFormat(sys.stdout) 220 | elif _format == 'log': 221 | return HumanOutputFormat(os.path.join(ev_dir, 'log%s.txt' % log_suffix)) 222 | elif _format == 'json': 223 | return JSONOutputFormat(os.path.join(ev_dir, 'progress%s.json' % log_suffix)) 224 | elif _format == 'csv': 225 | return CSVOutputFormat(os.path.join(ev_dir, 'progress%s.csv' % log_suffix)) 226 | else: 227 | raise ValueError('Unknown format specified: %s' % (_format,)) 228 | 229 | 230 | # ================================================================ 231 | # API 232 | # ================================================================ 233 | 234 | def logkv(key, val): 235 | """ 236 | Log a value of some diagnostic 237 | Call this once for each diagnostic quantity, each iteration 238 | If called many times, last value will be used. 239 | 240 | :param key: (Any) save to log this key 241 | :param val: (Any) save to log this value 242 | """ 243 | Logger.CURRENT.logkv(key, val) 244 | 245 | 246 | def logkv_mean(key, val): 247 | """ 248 | The same as logkv(), but if called many times, values averaged. 249 | 250 | :param key: (Any) save to log this key 251 | :param val: (Number) save to log this value 252 | """ 253 | Logger.CURRENT.logkv_mean(key, val) 254 | 255 | 256 | def logkvs(key_values): 257 | """ 258 | Log a dictionary of key-value pairs 259 | 260 | :param key_values: (dict) the list of keys and values to save to log 261 | """ 262 | for key, value in key_values.items(): 263 | logkv(key, value) 264 | 265 | 266 | def dumpkvs(): 267 | """ 268 | Write all of the diagnostics from the current iteration 269 | """ 270 | Logger.CURRENT.dumpkvs() 271 | 272 | 273 | def getkvs(): 274 | """ 275 | get the key values logs 276 | 277 | :return: (dict) the logged values 278 | """ 279 | return Logger.CURRENT.name2val 280 | 281 | 282 | def log(*args, **kwargs): 283 | """ 284 | Write the sequence of args, with no separators, 285 | to the console and output files (if you've configured an output file). 286 | 287 | level: int. (see logger.py docs) If the global logger level is higher than 288 | the level argument here, don't print to stdout. 289 | 290 | :param args: (list) log the arguments 291 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) 292 | """ 293 | level = kwargs.get('level', INFO) 294 | Logger.CURRENT.log(*args, level=level) 295 | 296 | 297 | def debug(*args): 298 | """ 299 | Write the sequence of args, with no separators, 300 | to the console and output files (if you've configured an output file). 301 | Using the DEBUG level. 302 | 303 | :param args: (list) log the arguments 304 | """ 305 | log(*args, level=DEBUG) 306 | 307 | 308 | def info(*args): 309 | """ 310 | Write the sequence of args, with no separators, 311 | to the console and output files (if you've configured an output file). 312 | Using the INFO level. 313 | 314 | :param args: (list) log the arguments 315 | """ 316 | log(*args, level=INFO) 317 | 318 | 319 | def warn(*args): 320 | """ 321 | Write the sequence of args, with no separators, 322 | to the console and output files (if you've configured an output file). 323 | Using the WARN level. 324 | 325 | :param args: (list) log the arguments 326 | """ 327 | log(*args, level=WARN) 328 | 329 | 330 | def error(*args): 331 | """ 332 | Write the sequence of args, with no separators, 333 | to the console and output files (if you've configured an output file). 
334 | Using the ERROR level. 335 | 336 | :param args: (list) log the arguments 337 | """ 338 | log(*args, level=ERROR) 339 | 340 | 341 | def set_level(level): 342 | """ 343 | Set logging threshold on current logger. 344 | 345 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) 346 | """ 347 | Logger.CURRENT.set_level(level) 348 | 349 | 350 | def get_level(): 351 | """ 352 | Get logging threshold on current logger. 353 | :return: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) 354 | """ 355 | return Logger.CURRENT.level 356 | 357 | 358 | def get_dir(): 359 | """ 360 | Get directory that log files are being written to. 361 | will be None if there is no output directory (i.e., if you didn't call start) 362 | 363 | :return: (str) the logging directory 364 | """ 365 | return Logger.CURRENT.get_dir() 366 | 367 | 368 | record_tabular = logkv 369 | dump_tabular = dumpkvs 370 | 371 | 372 | # ================================================================ 373 | # Backend 374 | # ================================================================ 375 | 376 | class Logger(object): 377 | # A logger with no output files. (See right below class definition) 378 | # So that you can still log to the terminal without setting up any output files 379 | DEFAULT = None 380 | CURRENT = None # Current logger being used by the free functions above 381 | 382 | def __init__(self, folder, output_formats): 383 | """ 384 | the logger class 385 | 386 | :param folder: (str) the logging location 387 | :param output_formats: ([str]) the list of output format 388 | """ 389 | self.name2val = defaultdict(float) # values this iteration 390 | self.name2cnt = defaultdict(int) 391 | self.level = INFO 392 | self.dir = folder 393 | self.output_formats = output_formats 394 | 395 | # Logging API, forwarded 396 | # ---------------------------------------- 397 | def logkv(self, key, val): 398 | """ 399 | Log a value of some diagnostic 400 | Call this once for each diagnostic quantity, each iteration 401 | If called many times, last value will be used. 402 | 403 | :param key: (Any) save to log this key 404 | :param val: (Any) save to log this value 405 | """ 406 | self.name2val[key] = val 407 | 408 | def logkv_mean(self, key, val): 409 | """ 410 | The same as logkv(), but if called many times, values averaged. 411 | 412 | :param key: (Any) save to log this key 413 | :param val: (Number) save to log this value 414 | """ 415 | if val is None: 416 | self.name2val[key] = None 417 | return 418 | oldval, cnt = self.name2val[key], self.name2cnt[key] 419 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 420 | self.name2cnt[key] = cnt + 1 421 | 422 | def dumpkvs(self): 423 | """ 424 | Write all of the diagnostics from the current iteration 425 | """ 426 | if self.level == DISABLED: 427 | return 428 | for fmt in self.output_formats: 429 | if isinstance(fmt, KVWriter): 430 | fmt.writekvs(self.name2val) 431 | self.name2val.clear() 432 | self.name2cnt.clear() 433 | 434 | def log(self, *args, **kwargs): 435 | """ 436 | Write the sequence of args, with no separators, 437 | to the console and output files (if you've configured an output file). 438 | 439 | level: int. (see logger.py docs) If the global logger level is higher than 440 | the level argument here, don't print to stdout. 
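# --- Editor's note: illustrative usage sketch, not part of logger.py ---
# Typical use of the module-level helpers defined above; without an explicit
# configuration, the default Logger writes human-readable output to stdout.
from stable_baselines.common import logger

for step in range(3):
    logger.logkv('step', step)                 # last value wins within an iteration
    logger.logkv_mean('reward', float(step))   # running mean over repeated calls
logger.dumpkvs()                               # writes and clears the current key/values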
441 | 442 | :param args: (list) log the arguments 443 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) 444 | """ 445 | level = kwargs.get('level', INFO) 446 | if self.level <= level: 447 | self._do_log(args) 448 | 449 | # Configuration 450 | # ---------------------------------------- 451 | def set_level(self, level): 452 | """ 453 | Set logging threshold on current logger. 454 | 455 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) 456 | """ 457 | self.level = level 458 | 459 | def get_dir(self): 460 | """ 461 | Get directory that log files are being written to. 462 | will be None if there is no output directory (i.e., if you didn't call start) 463 | 464 | :return: (str) the logging directory 465 | """ 466 | return self.dir 467 | 468 | def close(self): 469 | """ 470 | closes the file 471 | """ 472 | for fmt in self.output_formats: 473 | fmt.close() 474 | 475 | # Misc 476 | # ---------------------------------------- 477 | def _do_log(self, args): 478 | """ 479 | log to the requested format outputs 480 | 481 | :param args: (list) the arguments to log 482 | """ 483 | for fmt in self.output_formats: 484 | if isinstance(fmt, SeqWriter): 485 | fmt.writeseq(map(str, args)) 486 | 487 | 488 | Logger.DEFAULT = Logger.CURRENT = Logger(folder=None, output_formats=[HumanOutputFormat(sys.stdout)]) 489 | 490 | 491 | def configure(folder=None, format_strs=None): 492 | """ 493 | configure the current logger 494 | 495 | :param folder: (str) the save location (if None, $BASELINES_LOGDIR, if still None, tempdir/baselines-[date & time]) 496 | :param format_strs: (list) the output logging format 497 | (if None, $BASELINES_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) 498 | """ 499 | if folder is None: 500 | folder = os.getenv('BASELINES_LOGDIR') 501 | if folder is None: 502 | folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("baselines-%Y-%m-%d-%H-%M-%S-%f")) 503 | assert isinstance(folder, str) 504 | os.makedirs(folder, exist_ok=True) 505 | 506 | log_suffix = '' 507 | if format_strs is None: 508 | format_strs = os.getenv('BASELINES_LOG_FORMAT', 'stdout,log,csv').split(',') 509 | 510 | format_strs = filter(None, format_strs) 511 | output_formats = [make_output_format(f, folder, log_suffix) for f in format_strs] 512 | 513 | Logger.CURRENT = Logger(folder=folder, output_formats=output_formats) 514 | log('Logging to %s' % folder) 515 | -------------------------------------------------------------------------------- /stable_baselines/common/monitor.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import time 5 | 6 | from gym.core import Wrapper 7 | 8 | 9 | class Monitor(Wrapper): 10 | EXT = "monitor.csv" 11 | file_handler = None 12 | 13 | def __init__(self, env, filename=None, allow_early_resets=True, reset_keywords=(), info_keywords=()): 14 | """ 15 | A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. 
16 | 17 | :param env: (Gym environment) The environment 18 | :param filename: (str) the location to save a log file, can be None for no log 19 | :param allow_early_resets: (bool) allows the reset of the environment before it is done 20 | :param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset 21 | :param info_keywords: (tuple) extra information to log, from the information return of environment.step 22 | """ 23 | Wrapper.__init__(self, env=env) 24 | self.t_start = time.time() 25 | if filename is None: 26 | self.file_handler = None 27 | self.logger = None 28 | else: 29 | if not filename.endswith(Monitor.EXT): 30 | if os.path.isdir(filename): 31 | filename = os.path.join(filename, Monitor.EXT) 32 | else: 33 | filename = filename + "." + Monitor.EXT 34 | self.file_handler = open(filename, "wt") 35 | self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id})) 36 | self.logger = csv.DictWriter(self.file_handler, 37 | fieldnames=('r', 'l', 't') + reset_keywords + info_keywords) 38 | self.logger.writeheader() 39 | self.file_handler.flush() 40 | 41 | self.reset_keywords = reset_keywords 42 | self.info_keywords = info_keywords 43 | self.allow_early_resets = allow_early_resets 44 | self.rewards = None 45 | self.needs_reset = True 46 | self.episode_rewards = [] 47 | self.episode_lengths = [] 48 | self.episode_times = [] 49 | self.total_steps = 0 50 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() 51 | 52 | def reset(self, **kwargs): 53 | """ 54 | Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True 55 | 56 | :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords 57 | :return: ([int] or [float]) the first observation of the environment 58 | """ 59 | if not self.allow_early_resets and not self.needs_reset: 60 | raise RuntimeError("Tried to reset an environment before done. 
If you want to allow early resets, " 61 | "wrap your env with Monitor(env, path, allow_early_resets=True)") 62 | self.rewards = [] 63 | self.needs_reset = False 64 | for key in self.reset_keywords: 65 | value = kwargs.get(key) 66 | if value is None: 67 | raise ValueError('Expected you to pass kwarg %s into reset' % key) 68 | self.current_reset_info[key] = value 69 | return self.env.reset(**kwargs) 70 | 71 | def step(self, action): 72 | """ 73 | Step the environment with the given action 74 | 75 | :param action: ([int] or [float]) the action 76 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 77 | """ 78 | if self.needs_reset: 79 | raise RuntimeError("Tried to step environment that needs reset") 80 | observation, reward, done, info = self.env.step(action) 81 | self.rewards.append(reward) 82 | if done: 83 | self.needs_reset = True 84 | ep_rew = sum(self.rewards) 85 | eplen = len(self.rewards) 86 | ep_info = {"r": round(ep_rew, 6), "l": eplen, "t": round(time.time() - self.t_start, 6)} 87 | for key in self.info_keywords: 88 | ep_info[key] = info[key] 89 | self.episode_rewards.append(ep_rew) 90 | self.episode_lengths.append(eplen) 91 | self.episode_times.append(time.time() - self.t_start) 92 | ep_info.update(self.current_reset_info) 93 | if self.logger: 94 | self.logger.writerow(ep_info) 95 | self.file_handler.flush() 96 | info['episode'] = ep_info 97 | self.total_steps += 1 98 | return observation, reward, done, info 99 | 100 | def close(self): 101 | """ 102 | Closes the environment 103 | """ 104 | if self.file_handler is not None: 105 | self.file_handler.close() 106 | 107 | def get_total_steps(self): 108 | """ 109 | Returns the total number of timesteps 110 | 111 | :return: (int) 112 | """ 113 | return self.total_steps 114 | 115 | def get_episode_rewards(self): 116 | """ 117 | Returns the rewards of all the episodes 118 | 119 | :return: ([float]) 120 | """ 121 | return self.episode_rewards 122 | 123 | def get_episode_lengths(self): 124 | """ 125 | Returns the number of timesteps of all the episodes 126 | 127 | :return: ([int]) 128 | """ 129 | return self.episode_lengths 130 | 131 | def get_episode_times(self): 132 | """ 133 | Returns the runtime in seconds of all the episodes 134 | 135 | :return: ([float]) 136 | """ 137 | return self.episode_times 138 | -------------------------------------------------------------------------------- /stable_baselines/common/noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveParamNoiseSpec(object): 5 | """ 6 | Implements adaptive parameter noise 7 | 8 | :param initial_stddev: (float) the initial value for the standard deviation of the noise 9 | :param desired_action_stddev: (float) the desired value for the standard deviation of the noise 10 | :param adoption_coefficient: (float) the update coefficient for the standard deviation of the noise 11 | """ 12 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01): 13 | self.initial_stddev = initial_stddev 14 | self.desired_action_stddev = desired_action_stddev 15 | self.adoption_coefficient = adoption_coefficient 16 | 17 | self.current_stddev = initial_stddev 18 | 19 | def adapt(self, distance): 20 | """ 21 | update the standard deviation for the parameter noise 22 | 23 | :param distance: (float) the noise distance applied to the parameters 24 | """ 25 | if distance > self.desired_action_stddev: 26 | # Decrease stddev. 
27 | self.current_stddev /= self.adoption_coefficient 28 | else: 29 | # Increase stddev. 30 | self.current_stddev *= self.adoption_coefficient 31 | 32 | def get_stats(self): 33 | """ 34 | return the standard deviation for the parameter noise 35 | 36 | :return: (dict) the stats of the noise 37 | """ 38 | return {'param_noise_stddev': self.current_stddev} 39 | 40 | def __repr__(self): 41 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})' 42 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient) 43 | 44 | 45 | class ActionNoise(object): 46 | """ 47 | The action noise base class 48 | """ 49 | def reset(self): 50 | """ 51 | call end of episode reset for the noise 52 | """ 53 | pass 54 | 55 | 56 | class NormalActionNoise(ActionNoise): 57 | """ 58 | A Gaussian action noise 59 | 60 | :param mean: (float) the mean value of the noise 61 | :param sigma: (float) the scale of the noise (std here) 62 | """ 63 | def __init__(self, mean, sigma): 64 | self._mu = mean 65 | self._sigma = sigma 66 | 67 | def __call__(self): 68 | return np.random.normal(self._mu, self._sigma) 69 | 70 | def __repr__(self): 71 | return 'NormalActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma) 72 | 73 | 74 | class OrnsteinUhlenbeckActionNoise(ActionNoise): 75 | """ 76 | A Ornstein Uhlenbeck action noise, this is designed to approximate brownian motion with friction. 77 | 78 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 79 | 80 | :param mean: (float) the mean of the noise 81 | :param sigma: (float) the scale of the noise 82 | :param theta: (float) the rate of mean reversion 83 | :param dt: (float) the timestep for the noise 84 | :param initial_noise: ([float]) the initial value for the noise output, (if None: 0) 85 | """ 86 | 87 | def __init__(self, mean, sigma, theta=.15, dt=1e-2, initial_noise=None): 88 | self._theta = theta 89 | self._mu = mean 90 | self._sigma = sigma 91 | self._dt = dt 92 | self.initial_noise = initial_noise 93 | self.noise_prev = None 94 | self.reset() 95 | 96 | def __call__(self): 97 | noise = self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt + \ 98 | self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) 99 | self.noise_prev = noise 100 | return noise 101 | 102 | def reset(self): 103 | """ 104 | reset the Ornstein Uhlenbeck noise, to the initial position 105 | """ 106 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu) 107 | 108 | def __repr__(self): 109 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma) 110 | -------------------------------------------------------------------------------- /stable_baselines/common/policies.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import tensorflow as tf 4 | import tensorflow.keras.layers as layers 5 | from tensorflow.keras import Model 6 | from tensorflow.keras.models import Sequential 7 | 8 | 9 | class BasePolicy(Model): 10 | """ 11 | The base policy object 12 | 13 | :param observation_space: (Gym Space) The observation space of the environment 14 | :param action_space: (Gym Space) The action space of the environment 15 | """ 16 | 17 | def __init__(self, observation_space, action_space): 18 | super(BasePolicy, self).__init__() 19 | self.observation_space = observation_space 20 | self.action_space = 
action_space 21 | 22 | def save(self, path): 23 | """ 24 | Save model to a given location. 25 | 26 | :param path: (str) 27 | """ 28 | raise NotImplementedError() 29 | 30 | def load(self, path): 31 | """ 32 | Load saved model from path. 33 | 34 | :param path: (str) 35 | """ 36 | raise NotImplementedError() 37 | 38 | @tf.function 39 | def soft_update(self, other_network, tau): 40 | other_variables = other_network.trainable_variables 41 | current_variables = self.trainable_variables 42 | 43 | for (current_var, other_var) in zip(current_variables, other_variables): 44 | current_var.assign((1. - tau) * current_var + tau * other_var) 45 | 46 | def hard_update(self, other_network): 47 | self.soft_update(other_network, tau=1.) 48 | 49 | def call(self, x): 50 | raise NotImplementedError() 51 | 52 | 53 | def create_mlp(input_dim, output_dim, net_arch, 54 | activation_fn=tf.nn.relu, squash_out=False): 55 | """ 56 | Create a multi layer perceptron (MLP), which is 57 | a collection of fully-connected layers each followed by an activation function. 58 | 59 | :param input_dim: (int) Dimension of the input vector 60 | :param output_dim: (int) 61 | :param net_arch: ([int]) Architecture of the neural net 62 | It represents the number of units per layer. 63 | The length of this list is the number of layers. 64 | :param activation_fn: (tf.activations or str) The activation function 65 | to use after each layer. 66 | :param squash_out: (bool) Whether to squash the output using a Tanh 67 | activation function 68 | """ 69 | modules = [layers.Flatten(input_shape=(input_dim,), dtype=tf.float32)] 70 | 71 | if len(net_arch) > 0: 72 | modules.append(layers.Dense(net_arch[0], activation=activation_fn)) 73 | 74 | for idx in range(len(net_arch) - 1): 75 | modules.append(layers.Dense(net_arch[idx + 1], activation=activation_fn)) 76 | 77 | if output_dim > 0: 78 | modules.append(layers.Dense(output_dim, activation=None)) 79 | if squash_out: 80 | modules.append(layers.Activation(activation='tanh')) 81 | return modules 82 | 83 | 84 | _policy_registry = dict() 85 | 86 | 87 | def get_policy_from_name(base_policy_type, name): 88 | """ 89 | returns the registed policy from the base type and name 90 | 91 | :param base_policy_type: (BasePolicy) the base policy object 92 | :param name: (str) the policy name 93 | :return: (base_policy_type) the policy 94 | """ 95 | if base_policy_type not in _policy_registry: 96 | raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type)) 97 | if name not in _policy_registry[base_policy_type]: 98 | raise ValueError("Error: unknown policy type {}, the only registed policy type are: {}!" 99 | .format(name, list(_policy_registry[base_policy_type].keys()))) 100 | return _policy_registry[base_policy_type][name] 101 | 102 | 103 | 104 | def get_policy_from_name(base_policy_type, name): 105 | """ 106 | returns the registed policy from the base type and name 107 | 108 | :param base_policy_type: (BasePolicy) the base policy object 109 | :param name: (str) the policy name 110 | :return: (base_policy_type) the policy 111 | """ 112 | if base_policy_type not in _policy_registry: 113 | raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type)) 114 | if name not in _policy_registry[base_policy_type]: 115 | raise ValueError("Error: unknown policy type {}, the only registed policy type are: {}!" 
116 | .format(name, list(_policy_registry[base_policy_type].keys()))) 117 | return _policy_registry[base_policy_type][name] 118 | 119 | 120 | def register_policy(name, policy): 121 | """ 122 | returns the registed policy from the base type and name 123 | 124 | :param name: (str) the policy name 125 | :param policy: (subclass of BasePolicy) the policy 126 | """ 127 | sub_class = None 128 | for cls in BasePolicy.__subclasses__(): 129 | if issubclass(policy, cls): 130 | sub_class = cls 131 | break 132 | if sub_class is None: 133 | raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy)) 134 | 135 | if sub_class not in _policy_registry: 136 | _policy_registry[sub_class] = {} 137 | if name in _policy_registry[sub_class]: 138 | raise ValueError("Error: the name {} is alreay registered for a different policy, will not override." 139 | .format(name)) 140 | _policy_registry[sub_class][name] = policy 141 | 142 | 143 | class MlpExtractor(Model): 144 | """ 145 | Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and 146 | a value network. The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many 147 | of them are shared between the policy network and the value network. It is assumed to be a list with the following 148 | structure: 149 | 150 | 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer. 151 | If the number of ints is zero, there will be no shared layers. 152 | 2. An optional dict, to specify the following non-shared layers for the value network and the policy network. 153 | It is formatted like ``dict(vf=[], pi=[])``. 154 | If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. 155 | 156 | For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value 157 | network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec 158 | would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128 159 | would be specified as [128, 128]. 160 | 161 | 162 | :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN) 163 | :param net_arch: ([int or dict]) The specification of the policy and value networks. 164 | See above for details on its formatting. 165 | :param activation_fn: (tf.nn.activation) The activation function to use for the networks. 
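# --- Editor's note: illustrative sketch, not part of policies.py ---
# The net_arch format described above, exercised on a tiny extractor: one shared
# layer of 64 units, then a 32-unit policy head and a 16/16 value head. The
# dimensions are made up for the example.
import tensorflow as tf

extractor = MlpExtractor(feature_dim=8,
                         net_arch=[64, dict(pi=[32], vf=[16, 16])],
                         activation_fn=tf.nn.relu)
assert extractor.latent_dim_pi == 32 and extractor.latent_dim_vf == 16
latent_pi, latent_vf = extractor(tf.zeros((1, 8)))
assert latent_pi.shape == (1, 32) and latent_vf.shape == (1, 16)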
166 | """ 167 | def __init__(self, feature_dim, net_arch, activation_fn): 168 | super(MlpExtractor, self).__init__() 169 | 170 | shared_net, policy_net, value_net = [], [], [] 171 | policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network 172 | value_only_layers = [] # Layer sizes of the network that only belongs to the value network 173 | last_layer_dim_shared = feature_dim 174 | 175 | # Iterate through the shared layers and build the shared parts of the network 176 | for idx, layer in enumerate(net_arch): 177 | if isinstance(layer, int): # Check that this is a shared layer 178 | layer_size = layer 179 | # TODO: give layer a meaningful name 180 | # shared_net.append(layers.Dense(layer_size, input_shape=(last_layer_dim_shared,), activation=activation_fn)) 181 | shared_net.append(layers.Dense(layer_size, activation=activation_fn)) 182 | last_layer_dim_shared = layer_size 183 | else: 184 | assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" 185 | if 'pi' in layer: 186 | assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." 187 | policy_only_layers = layer['pi'] 188 | 189 | if 'vf' in layer: 190 | assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." 191 | value_only_layers = layer['vf'] 192 | break # From here on the network splits up in policy and value network 193 | 194 | last_layer_dim_pi = last_layer_dim_shared 195 | last_layer_dim_vf = last_layer_dim_shared 196 | 197 | # Build the non-shared part of the network 198 | for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)): 199 | if pi_layer_size is not None: 200 | assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." 201 | policy_net.append(layers.Dense(pi_layer_size, input_shape=(last_layer_dim_pi,), activation=activation_fn)) 202 | last_layer_dim_pi = pi_layer_size 203 | 204 | if vf_layer_size is not None: 205 | assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." 206 | value_net.append(layers.Dense(vf_layer_size, input_shape=(last_layer_dim_vf,), activation=activation_fn)) 207 | last_layer_dim_vf = vf_layer_size 208 | 209 | # Save dim, used to create the distributions 210 | self.latent_dim_pi = last_layer_dim_pi 211 | self.latent_dim_vf = last_layer_dim_vf 212 | 213 | # Create networks 214 | # If the list of layers is empty, the network will just act as an Identity module 215 | self.shared_net = Sequential(shared_net) 216 | self.policy_net = Sequential(policy_net) 217 | self.value_net = Sequential(value_net) 218 | 219 | def call(self, features): 220 | """ 221 | :return: (tf.Tensor, tf.Tensor) latent_policy, latent_value of the specified network. 
222 | If all layers are shared, then ``latent_policy == latent_value`` 223 | """ 224 | shared_latent = self.shared_net(features) 225 | return self.policy_net(shared_latent), self.value_net(shared_latent) 226 | -------------------------------------------------------------------------------- /stable_baselines/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RunningMeanStd(object): 5 | def __init__(self, epsilon=1e-4, shape=()): 6 | """ 7 | calulates the running mean and std of a data stream 8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 9 | 10 | :param epsilon: (float) helps with arithmetic issues 11 | :param shape: (tuple) the shape of the data stream's output 12 | """ 13 | self.mean = np.zeros(shape, 'float64') 14 | self.var = np.ones(shape, 'float64') 15 | self.count = epsilon 16 | 17 | def update(self, arr): 18 | batch_mean = np.mean(arr, axis=0) 19 | batch_var = np.var(arr, axis=0) 20 | batch_count = arr.shape[0] 21 | self.update_from_moments(batch_mean, batch_var, batch_count) 22 | 23 | def update_from_moments(self, batch_mean, batch_var, batch_count): 24 | delta = batch_mean - self.mean 25 | tot_count = self.count + batch_count 26 | 27 | new_mean = self.mean + delta * batch_count / tot_count 28 | m_a = self.var * self.count 29 | m_b = batch_var * batch_count 30 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 31 | new_var = m_2 / (self.count + batch_count) 32 | 33 | new_count = batch_count + self.count 34 | 35 | self.mean = new_mean 36 | self.var = new_var 37 | self.count = new_count 38 | -------------------------------------------------------------------------------- /stable_baselines/common/save_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import base64 3 | import pickle 4 | import cloudpickle 5 | 6 | 7 | def is_json_serializable(item): 8 | """ 9 | Test if an object is serializable into JSON 10 | 11 | :param item: (object) The object to be tested for JSON serialization. 12 | :return: (bool) True if object is JSON serializable, false otherwise. 13 | """ 14 | # Try with try-except struct. 15 | json_serializable = True 16 | try: 17 | _ = json.dumps(item) 18 | except TypeError: 19 | json_serializable = False 20 | return json_serializable 21 | 22 | 23 | def data_to_json(data): 24 | """ 25 | Turn data (class parameters) into a JSON string for storing 26 | 27 | :param data: (Dict) Dictionary of class parameters to be 28 | stored. Items that are not JSON serializable will be 29 | pickled with Cloudpickle and stored as bytearray in 30 | the JSON file 31 | :return: (str) JSON string of the data serialized. 32 | """ 33 | # First, check what elements can not be JSONfied, 34 | # and turn them into byte-strings 35 | serializable_data = {} 36 | for data_key, data_item in data.items(): 37 | # See if object is JSON serializable 38 | if is_json_serializable(data_item): 39 | # All good, store as it is 40 | serializable_data[data_key] = data_item 41 | else: 42 | # Not serializable, cloudpickle it into 43 | # bytes and convert to base64 string for storing. 44 | # Also store type of the class for consumption 45 | # from other languages/humans, so we have an 46 | # idea what was being stored. 
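            # (Editor's note) Illustrative round trip of the encoding performed just
            # below, assuming a non-JSON-serializable item such as a lambda:
            #   blob = base64.b64encode(cloudpickle.dumps(lambda x: x + 1)).decode()
            #   restored = cloudpickle.loads(base64.b64decode(blob.encode()))
            # json_to_data() later in this file reverses it the same way.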
47 | base64_encoded = base64.b64encode( 48 | cloudpickle.dumps(data_item) 49 | ).decode() 50 | 51 | # Use ":" to make sure we do 52 | # not override these keys 53 | # when we include variables of the object later 54 | cloudpickle_serialization = { 55 | ":type:": str(type(data_item)), 56 | ":serialized:": base64_encoded 57 | } 58 | 59 | # Add first-level JSON-serializable items of the 60 | # object for further details (but not deeper than this to 61 | # avoid deep nesting). 62 | # First we check that object has attributes (not all do, 63 | # e.g. numpy scalars) 64 | if hasattr(data_item, "__dict__") or isinstance(data_item, dict): 65 | # Take elements from __dict__ for custom classes 66 | item_generator = ( 67 | data_item.items if isinstance(data_item, dict) else data_item.__dict__.items 68 | ) 69 | for variable_name, variable_item in item_generator(): 70 | # Check if serializable. If not, just include the 71 | # string-representation of the object. 72 | if is_json_serializable(variable_item): 73 | cloudpickle_serialization[variable_name] = variable_item 74 | else: 75 | cloudpickle_serialization[variable_name] = str(variable_item) 76 | 77 | serializable_data[data_key] = cloudpickle_serialization 78 | json_string = json.dumps(serializable_data, indent=4) 79 | return json_string 80 | 81 | 82 | def json_to_data(json_string, custom_objects=None): 83 | """ 84 | Turn JSON serialization of class-parameters back into dictionary. 85 | 86 | :param json_string: (str) JSON serialization of the class-parameters 87 | that should be loaded. 88 | :param custom_objects: (dict) Dictionary of objects to replace 89 | upon loading. If a variable is present in this dictionary as a 90 | key, it will not be deserialized and the corresponding item 91 | will be used instead. Similar to custom_objects in 92 | `keras.models.load_model`. Useful when you have an object in 93 | file that can not be deserialized. 94 | :return: (dict) Loaded class parameters. 95 | """ 96 | if custom_objects is not None and not isinstance(custom_objects, dict): 97 | raise ValueError("custom_objects argument must be a dict or None") 98 | 99 | json_dict = json.loads(json_string) 100 | # This will be filled with deserialized data 101 | return_data = {} 102 | for data_key, data_item in json_dict.items(): 103 | if custom_objects is not None and data_key in custom_objects.keys(): 104 | # If item is provided in custom_objects, replace 105 | # the one from JSON with the one in custom_objects 106 | return_data[data_key] = custom_objects[data_key] 107 | elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): 108 | # If item is dictionary with ":serialized:" 109 | # key, this means it is serialized with cloudpickle. 110 | serialization = data_item[":serialized:"] 111 | # Try-except deserialization in case we run into 112 | # errors. If so, we can tell bit more information to 113 | # user. 114 | try: 115 | deserialized_object = cloudpickle.loads( 116 | base64.b64decode(serialization.encode()) 117 | ) 118 | except pickle.UnpicklingError: 119 | raise RuntimeError( 120 | "Could not deserialize object {}. ".format(data_key) + 121 | "Consider using `custom_objects` argument to replace " + 122 | "this object." 
123 | ) 124 | return_data[data_key] = deserialized_object 125 | else: 126 | # Read as it is 127 | return_data[data_key] = data_item 128 | return return_data 129 | -------------------------------------------------------------------------------- /stable_baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def tile_images(img_nhwc): 5 | """ 6 | Tile N images into one big PxQ image 7 | (P,Q) are chosen to be as close as possible, and if N 8 | is square, then P=Q. 9 | 10 | :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc 11 | n = batch index, h = height, w = width, c = channel 12 | :return: (numpy float) img_HWc, ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | n_images, height, width, n_channels = img_nhwc.shape 16 | # new_height was named H before 17 | new_height = int(np.ceil(np.sqrt(n_images))) 18 | # new_width was named W before 19 | new_width = int(np.ceil(float(n_images) / new_height)) 20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) 21 | # img_HWhwc 22 | out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels) 23 | # img_HhWwc 24 | out_image = out_image.transpose(0, 2, 1, 3, 4) 25 | # img_Hh_Ww_c 26 | out_image = out_image.reshape(new_height * height, new_width * width, n_channels) 27 | return out_image 28 | 29 | -------------------------------------------------------------------------------- /stable_baselines/common/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def set_random_seed(seed): 8 | """ 9 | Seed the different random generators 10 | :param seed: (int) 11 | :param using_cuda: (bool) 12 | """ 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | tf.random.set_seed(seed) 16 | 17 | 18 | def explained_variance(y_pred, y_true): 19 | """ 20 | Computes fraction of variance that ypred explains about y. 21 | Returns 1 - Var[y-ypred] / Var[y] 22 | 23 | interpretation: 24 | ev=0 => might as well have predicted zero 25 | ev=1 => perfect prediction 26 | ev<0 => worse than just predicting zero 27 | 28 | :param y_pred: (np.ndarray) the prediction 29 | :param y_true: (np.ndarray) the expected value 30 | :return: (float) explained variance of ypred and y 31 | """ 32 | assert y_true.ndim == 1 and y_pred.ndim == 1 33 | var_y = np.var(y_true) 34 | return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 35 | 36 | 37 | def get_schedule_fn(value_schedule): 38 | """ 39 | Transform (if needed) learning rate and clip range (for PPO) 40 | to callable. 
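A minimal usage sketch (illustrative, not from the original docs)::

    clip_range = get_schedule_fn(0.2)
    clip_range(0.5)  # -> 0.2, a constant function of the remaining progress

    lr_schedule = get_schedule_fn(lambda progress: 3e-4 * progress)
    lr_schedule(0.5)  # -> 1.5e-4, callables are returned unchanged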
41 | 42 | :param value_schedule: (callable or float) 43 | :return: (function) 44 | """ 45 | # If the passed schedule is a float 46 | # create a constant function 47 | if isinstance(value_schedule, (float, int)): 48 | # Cast to float to avoid errors 49 | value_schedule = constant_fn(float(value_schedule)) 50 | else: 51 | assert callable(value_schedule) 52 | return value_schedule 53 | 54 | 55 | def constant_fn(val): 56 | """ 57 | Create a function that returns a constant 58 | It is useful for learning rate schedule (to avoid code duplication) 59 | 60 | :param val: (float) 61 | :return: (function) 62 | """ 63 | 64 | def func(_): 65 | return val 66 | 67 | return func 68 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | # flake8: noqa F401 4 | from stable_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \ 5 | CloudpickleWrapper 6 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 7 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 8 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack 9 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize 10 | from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder 11 | from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan 12 | 13 | 14 | 15 | def unwrap_vec_normalize(env): 16 | """ 17 | :param env: (gym.Env) 18 | :return: (VecNormalize) 19 | """ 20 | env_tmp = env 21 | while isinstance(env_tmp, VecEnvWrapper): 22 | if isinstance(env_tmp, VecNormalize): 23 | return env_tmp 24 | env_tmp = env_tmp.venv 25 | return None 26 | 27 | 28 | # Define here to avoid circular import 29 | def sync_envs_normalization(env, eval_env): 30 | """ 31 | Sync eval env and train env when using VecNormalize 32 | 33 | :param env: (gym.Env) 34 | :param eval_env: (gym.Env) 35 | """ 36 | env_tmp, eval_env_tmp = env, eval_env 37 | while isinstance(env_tmp, VecEnvWrapper): 38 | if isinstance(env_tmp, VecNormalize): 39 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 40 | env_tmp = env_tmp.venv 41 | eval_env_tmp = eval_env_tmp.venv 42 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/base_vec_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import inspect 3 | import pickle 4 | 5 | import cloudpickle 6 | from stable_baselines.common import logger 7 | 8 | 9 | class AlreadySteppingError(Exception): 10 | """ 11 | Raised when an asynchronous step is running while 12 | step_async() is called again. 13 | """ 14 | 15 | def __init__(self): 16 | msg = 'already running an async step' 17 | Exception.__init__(self, msg) 18 | 19 | 20 | class NotSteppingError(Exception): 21 | """ 22 | Raised when an asynchronous step is not running but 23 | step_wait() is called. 24 | """ 25 | 26 | def __init__(self): 27 | msg = 'not running an async step' 28 | Exception.__init__(self, msg) 29 | 30 | 31 | class VecEnv(ABC): 32 | """ 33 | An abstract asynchronous, vectorized environment. 
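A minimal interaction sketch (illustrative, assuming a concrete subclass such as
``DummyVecEnv`` wrapping ``num_envs`` environments)::

    obs = vec_env.reset()                                # batched: (num_envs,) + obs shape
    obs, rewards, dones, infos = vec_env.step(actions)   # one action per environment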
34 | 35 | :param num_envs: (int) the number of environments 36 | :param observation_space: (Gym Space) the observation space 37 | :param action_space: (Gym Space) the action space 38 | """ 39 | metadata = { 40 | 'render.modes': ['human', 'rgb_array'] 41 | } 42 | 43 | def __init__(self, num_envs, observation_space, action_space): 44 | self.num_envs = num_envs 45 | self.observation_space = observation_space 46 | self.action_space = action_space 47 | 48 | @abstractmethod 49 | def reset(self): 50 | """ 51 | Reset all the environments and return an array of 52 | observations, or a tuple of observation arrays. 53 | 54 | If step_async is still doing work, that work will 55 | be cancelled and step_wait() should not be called 56 | until step_async() is invoked again. 57 | 58 | :return: ([int] or [float]) observation 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def step_async(self, actions): 64 | """ 65 | Tell all the environments to start taking a step 66 | with the given actions. 67 | Call step_wait() to get the results of the step. 68 | 69 | You should not call this if a step_async run is 70 | already pending. 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def step_wait(self): 76 | """ 77 | Wait for the step taken with step_async(). 78 | 79 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 80 | """ 81 | pass 82 | 83 | @abstractmethod 84 | def close(self): 85 | """ 86 | Clean up the environment's resources. 87 | """ 88 | pass 89 | 90 | @abstractmethod 91 | def get_attr(self, attr_name, indices=None): 92 | """ 93 | Return attribute from vectorized environment. 94 | 95 | :param attr_name: (str) The name of the attribute whose value to return 96 | :param indices: (list,int) Indices of envs to get attribute from 97 | :return: (list) List of values of 'attr_name' in all environments 98 | """ 99 | pass 100 | 101 | @abstractmethod 102 | def set_attr(self, attr_name, value, indices=None): 103 | """ 104 | Set attribute inside vectorized environments. 105 | 106 | :param attr_name: (str) The name of attribute to assign new value 107 | :param value: (obj) Value to assign to `attr_name` 108 | :param indices: (list,int) Indices of envs to assign value 109 | :return: (NoneType) 110 | """ 111 | pass 112 | 113 | @abstractmethod 114 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 115 | """ 116 | Call instance methods of vectorized environments. 117 | 118 | :param method_name: (str) The name of the environment method to invoke. 
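For example (illustrative), ``vec_env.env_method('seed', 123)`` calls ``env.seed(123)``
on every wrapped environment and returns the list of results.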
119 | :param indices: (list,int) Indices of envs whose method to call 120 | :param method_args: (tuple) Any positional arguments to provide in the call 121 | :param method_kwargs: (dict) Any keyword arguments to provide in the call 122 | :return: (list) List of items returned by the environment's method call 123 | """ 124 | pass 125 | 126 | def step(self, actions): 127 | """ 128 | Step the environments with the given action 129 | 130 | :param actions: ([int] or [float]) the action 131 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 132 | """ 133 | self.step_async(actions) 134 | return self.step_wait() 135 | 136 | def get_images(self): 137 | """ 138 | Return RGB images from each environment 139 | """ 140 | raise NotImplementedError 141 | 142 | def render(self, *args, **kwargs): 143 | """ 144 | Gym environment rendering 145 | 146 | :param mode: (str) the rendering type 147 | """ 148 | logger.warn('Render not defined for %s' % self) 149 | 150 | @property 151 | def unwrapped(self): 152 | if isinstance(self, VecEnvWrapper): 153 | return self.venv.unwrapped 154 | else: 155 | return self 156 | 157 | def getattr_depth_check(self, name, already_found): 158 | """Check if an attribute reference is being hidden in a recursive call to __getattr__ 159 | 160 | :param name: (str) name of attribute to check for 161 | :param already_found: (bool) whether this attribute has already been found in a wrapper 162 | :return: (str or None) name of module whose attribute is being shadowed, if any. 163 | """ 164 | if hasattr(self, name) and already_found: 165 | return "{0}.{1}".format(type(self).__module__, type(self).__name__) 166 | else: 167 | return None 168 | 169 | def _get_indices(self, indices): 170 | """ 171 | Convert a flexibly-typed reference to environment indices to an implied list of indices. 172 | 173 | :param indices: (None,int,Iterable) refers to indices of envs. 174 | :return: (list) the implied list of indices. 
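For example (illustrative): ``None`` -> ``range(self.num_envs)``, ``2`` -> ``[2]``,
and an iterable such as ``[0, 2]`` is returned unchanged.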
175 | """ 176 | if indices is None: 177 | indices = range(self.num_envs) 178 | elif isinstance(indices, int): 179 | indices = [indices] 180 | return indices 181 | 182 | 183 | class VecEnvWrapper(VecEnv): 184 | """ 185 | Vectorized environment base class 186 | 187 | :param venv: (VecEnv) the vectorized environment to wrap 188 | :param observation_space: (Gym Space) the observation space (can be None to load from venv) 189 | :param action_space: (Gym Space) the action space (can be None to load from venv) 190 | """ 191 | 192 | def __init__(self, venv, observation_space=None, action_space=None): 193 | self.venv = venv 194 | VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, 195 | action_space=action_space or venv.action_space) 196 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 197 | 198 | def step_async(self, actions): 199 | self.venv.step_async(actions) 200 | 201 | @abstractmethod 202 | def reset(self): 203 | pass 204 | 205 | @abstractmethod 206 | def step_wait(self): 207 | pass 208 | 209 | def close(self): 210 | return self.venv.close() 211 | 212 | def render(self, *args, **kwargs): 213 | return self.venv.render(*args, **kwargs) 214 | 215 | def get_images(self): 216 | return self.venv.get_images() 217 | 218 | def get_attr(self, attr_name, indices=None): 219 | return self.venv.get_attr(attr_name, indices) 220 | 221 | def set_attr(self, attr_name, value, indices=None): 222 | return self.venv.set_attr(attr_name, value, indices) 223 | 224 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 225 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) 226 | 227 | def __getattr__(self, name): 228 | """Find attribute from wrapped venv(s) if this wrapper does not have it. 229 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers 230 | which have unique attributes of interest. 231 | """ 232 | blocked_class = self.getattr_depth_check(name, already_found=False) 233 | if blocked_class is not None: 234 | own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 235 | format_str = ("Error: Recursive attribute lookup for {0} from {1} is " 236 | "ambiguous and hides attribute from {2}") 237 | raise AttributeError(format_str.format(name, own_class, blocked_class)) 238 | 239 | return self.getattr_recursive(name) 240 | 241 | def _get_all_attributes(self): 242 | """Get all (inherited) instance and class attributes 243 | 244 | :return: (dict) all_attributes 245 | """ 246 | all_attributes = self.__dict__.copy() 247 | all_attributes.update(self.class_attributes) 248 | return all_attributes 249 | 250 | def getattr_recursive(self, name): 251 | """Recursively check wrappers to find attribute. 252 | 253 | :param name (str) name of attribute to look for 254 | :return: (object) attribute 255 | """ 256 | all_attributes = self._get_all_attributes() 257 | if name in all_attributes: # attribute is present in this wrapper 258 | attr = getattr(self, name) 259 | elif hasattr(self.venv, 'getattr_recursive'): 260 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr 261 | # to avoid a duplicate call to getattr_depth_check. 262 | attr = self.venv.getattr_recursive(name) 263 | else: # attribute not present, child is an unwrapped VecEnv 264 | attr = getattr(self.venv, name) 265 | 266 | return attr 267 | 268 | def getattr_depth_check(self, name, already_found): 269 | """See base class. 
270 | 271 | :return: (str or None) name of module whose attribute is being shadowed, if any. 272 | """ 273 | all_attributes = self._get_all_attributes() 274 | if name in all_attributes and already_found: 275 | # this venv's attribute is being hidden because of a higher venv. 276 | shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 277 | elif name in all_attributes and not already_found: 278 | # we have found the first reference to the attribute. Now check for duplicates. 279 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) 280 | else: 281 | # this wrapper does not have the attribute. Keep searching. 282 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) 283 | 284 | return shadowed_wrapper_class 285 | 286 | 287 | class CloudpickleWrapper(object): 288 | def __init__(self, var): 289 | """ 290 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 291 | 292 | :param var: (Any) the variable you wish to wrap for pickling with cloudpickle 293 | """ 294 | self.var = var 295 | 296 | def __getstate__(self): 297 | return cloudpickle.dumps(self.var) 298 | 299 | def __setstate__(self, obs): 300 | self.var = pickle.loads(obs) 301 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | 4 | from stable_baselines.common.vec_env.base_vec_env import VecEnv 5 | from stable_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 6 | 7 | 8 | class DummyVecEnv(VecEnv): 9 | """ 10 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 11 | Python process. This is useful for computationally simple environment such as ``cartpole-v1``, as the overhead of 12 | multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that 13 | require a vectorized environment, but that you want a single environments to train with. 14 | 15 | :param env_fns: ([callable]) A list of functions that will create the environments 16 | (each callable returns a `Gym.Env` instance when called). 
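Example (an illustrative sketch, not taken from the original docs)::

    import gym

    env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(2)])
    obs = env.reset()  # batched observations, shape (2,) + observation shape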
17 | """ 18 | 19 | def __init__(self, env_fns): 20 | self.envs = [fn() for fn in env_fns] 21 | env = self.envs[0] 22 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 23 | obs_space = env.observation_space 24 | self.keys, shapes, dtypes = obs_space_info(obs_space) 25 | 26 | self.buf_obs = OrderedDict([ 27 | (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) 28 | for k in self.keys]) 29 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) 30 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 31 | self.buf_infos = [{} for _ in range(self.num_envs)] 32 | self.actions = None 33 | self.metadata = env.metadata 34 | 35 | def step_async(self, actions): 36 | self.actions = actions 37 | 38 | def step_wait(self): 39 | for env_idx in range(self.num_envs): 40 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\ 41 | self.envs[env_idx].step(self.actions[env_idx]) 42 | if self.buf_dones[env_idx]: 43 | # save final observation where user can get it, then reset 44 | self.buf_infos[env_idx]['terminal_observation'] = obs 45 | obs = self.envs[env_idx].reset() 46 | self._save_obs(env_idx, obs) 47 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), 48 | self.buf_infos.copy()) 49 | 50 | def reset(self): 51 | for env_idx in range(self.num_envs): 52 | obs = self.envs[env_idx].reset() 53 | self._save_obs(env_idx, obs) 54 | return self._obs_from_buf() 55 | 56 | def close(self): 57 | for env in self.envs: 58 | env.close() 59 | 60 | def seed(self, seed, indices=None): 61 | """ 62 | :param seed: (int or [int]) 63 | :param indices: ([int]) 64 | """ 65 | indices = self._get_indices(indices) 66 | if not hasattr(seed, 'len'): 67 | seed = [seed] * len(indices) 68 | assert len(seed) == len(indices) 69 | return [self.envs[i].seed(seed[i]) for i in indices] 70 | 71 | def get_images(self): 72 | return [env.render(mode='rgb_array') for env in self.envs] 73 | 74 | def render(self, *args, **kwargs): 75 | if self.num_envs == 1: 76 | return self.envs[0].render(*args, **kwargs) 77 | else: 78 | return super().render(*args, **kwargs) 79 | 80 | def _save_obs(self, env_idx, obs): 81 | for key in self.keys: 82 | if key is None: 83 | self.buf_obs[key][env_idx] = obs 84 | else: 85 | self.buf_obs[key][env_idx] = obs[key] 86 | 87 | def _obs_from_buf(self): 88 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 89 | 90 | def get_attr(self, attr_name, indices=None): 91 | """Return attribute from vectorized environment (see base class).""" 92 | target_envs = self._get_target_envs(indices) 93 | return [getattr(env_i, attr_name) for env_i in target_envs] 94 | 95 | def set_attr(self, attr_name, value, indices=None): 96 | """Set attribute inside vectorized environments (see base class).""" 97 | target_envs = self._get_target_envs(indices) 98 | for env_i in target_envs: 99 | setattr(env_i, attr_name, value) 100 | 101 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 102 | """Call instance methods of vectorized environments.""" 103 | target_envs = self._get_target_envs(indices) 104 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 105 | 106 | def _get_target_envs(self, indices): 107 | indices = self._get_indices(indices) 108 | return [self.envs[i] for i in indices] 109 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/subproc_vec_env.py: 
-------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from collections import OrderedDict 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from stable_baselines.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper 8 | from stable_baselines.common.tile_images import tile_images 9 | 10 | 11 | def _worker(remote, parent_remote, env_fn_wrapper): 12 | parent_remote.close() 13 | env = env_fn_wrapper.var() 14 | while True: 15 | try: 16 | cmd, data = remote.recv() 17 | if cmd == 'step': 18 | observation, reward, done, info = env.step(data) 19 | if done: 20 | # save final observation where user can get it, then reset 21 | info['terminal_observation'] = observation 22 | observation = env.reset() 23 | remote.send((observation, reward, done, info)) 24 | elif cmd == 'reset': 25 | observation = env.reset() 26 | remote.send(observation) 27 | elif cmd == 'render': 28 | remote.send(env.render(*data[0], **data[1])) 29 | elif cmd == 'close': 30 | remote.close() 31 | break 32 | elif cmd == 'get_spaces': 33 | remote.send((env.observation_space, env.action_space)) 34 | elif cmd == 'env_method': 35 | method = getattr(env, data[0]) 36 | remote.send(method(*data[1], **data[2])) 37 | elif cmd == 'get_attr': 38 | remote.send(getattr(env, data)) 39 | elif cmd == 'set_attr': 40 | remote.send(setattr(env, data[0], data[1])) 41 | else: 42 | raise NotImplementedError 43 | except EOFError: 44 | break 45 | 46 | 47 | class SubprocVecEnv(VecEnv): 48 | """ 49 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own 50 | process, allowing significant speed up when the environment is computationally complex. 51 | 52 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the 53 | number of logical cores on your CPU. 54 | 55 | .. warning:: 56 | 57 | Only 'forkserver' and 'spawn' start methods are thread-safe, 58 | which is important when TensorFlow sessions or other non thread-safe 59 | libraries are used in the parent (see issue #217). However, compared to 60 | 'fork' they incur a small start-up cost and have restrictions on 61 | global variables. With those methods, users must wrap the code in an 62 | ``if __name__ == "__main__":`` block. 63 | For more information, see the multiprocessing documentation. 64 | 65 | :param env_fns: ([callable]) A list of functions that will create the environments 66 | (each callable returns a `Gym.Env` instance when called). 67 | :param start_method: (str) method used to start the subprocesses. 68 | Must be one of the methods returned by multiprocessing.get_all_start_methods(). 69 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 
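Example (an illustrative sketch, not taken from the original docs; note the
``if __name__ == "__main__":`` guard required by 'forkserver'/'spawn')::

    import gym

    if __name__ == "__main__":
        env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)],
                            start_method='spawn')
        obs = env.reset()
        env.close()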
70 | """ 71 | 72 | def __init__(self, env_fns, start_method=None): 73 | self.waiting = False 74 | self.closed = False 75 | n_envs = len(env_fns) 76 | 77 | if start_method is None: 78 | # Fork is not a thread safe method (see issue #217) 79 | # but is more user friendly (does not require to wrap the code in 80 | # a `if __name__ == "__main__":`) 81 | forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods() 82 | start_method = 'forkserver' if forkserver_available else 'spawn' 83 | ctx = multiprocessing.get_context(start_method) 84 | 85 | self.remotes, self.work_remotes = zip(*[ctx.Pipe(duplex=True) for _ in range(n_envs)]) 86 | self.processes = [] 87 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): 88 | args = (work_remote, remote, CloudpickleWrapper(env_fn)) 89 | # daemon=True: if the main process crashes, we should not cause things to hang 90 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error 91 | process.start() 92 | self.processes.append(process) 93 | work_remote.close() 94 | 95 | self.remotes[0].send(('get_spaces', None)) 96 | observation_space, action_space = self.remotes[0].recv() 97 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 98 | 99 | def step_async(self, actions): 100 | for remote, action in zip(self.remotes, actions): 101 | remote.send(('step', action)) 102 | self.waiting = True 103 | 104 | def step_wait(self): 105 | results = [remote.recv() for remote in self.remotes] 106 | self.waiting = False 107 | obs, rews, dones, infos = zip(*results) 108 | return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos 109 | 110 | def reset(self): 111 | for remote in self.remotes: 112 | remote.send(('reset', None)) 113 | obs = [remote.recv() for remote in self.remotes] 114 | return _flatten_obs(obs, self.observation_space) 115 | 116 | def close(self): 117 | if self.closed: 118 | return 119 | if self.waiting: 120 | for remote in self.remotes: 121 | remote.recv() 122 | for remote in self.remotes: 123 | remote.send(('close', None)) 124 | for process in self.processes: 125 | process.join() 126 | self.closed = True 127 | 128 | def render(self, mode='human', *args, **kwargs): 129 | for pipe in self.remotes: 130 | # gather images from subprocesses 131 | # `mode` will be taken into account later 132 | pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs}))) 133 | imgs = [pipe.recv() for pipe in self.remotes] 134 | # Create a big image by tiling images from subprocesses 135 | bigimg = tile_images(imgs) 136 | if mode == 'human': 137 | import cv2 # pytype:disable=import-error 138 | cv2.imshow('vecenv', bigimg[:, :, ::-1]) 139 | cv2.waitKey(1) 140 | elif mode == 'rgb_array': 141 | return bigimg 142 | else: 143 | raise NotImplementedError 144 | 145 | def get_images(self): 146 | for pipe in self.remotes: 147 | pipe.send(('render', {"mode": 'rgb_array'})) 148 | imgs = [pipe.recv() for pipe in self.remotes] 149 | return imgs 150 | 151 | def get_attr(self, attr_name, indices=None): 152 | """Return attribute from vectorized environment (see base class).""" 153 | target_remotes = self._get_target_remotes(indices) 154 | for remote in target_remotes: 155 | remote.send(('get_attr', attr_name)) 156 | return [remote.recv() for remote in target_remotes] 157 | 158 | def set_attr(self, attr_name, value, indices=None): 159 | """Set attribute inside vectorized environments (see base class).""" 160 | target_remotes = self._get_target_remotes(indices) 161 | for 
remote in target_remotes: 162 | remote.send(('set_attr', (attr_name, value))) 163 | for remote in target_remotes: 164 | remote.recv() 165 | 166 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 167 | """Call instance methods of vectorized environments.""" 168 | target_remotes = self._get_target_remotes(indices) 169 | for remote in target_remotes: 170 | remote.send(('env_method', (method_name, method_args, method_kwargs))) 171 | return [remote.recv() for remote in target_remotes] 172 | 173 | def _get_target_remotes(self, indices): 174 | """ 175 | Get the connection object needed to communicate with the wanted 176 | envs that are in subprocesses. 177 | 178 | :param indices: (None,int,Iterable) refers to indices of envs. 179 | :return: ([multiprocessing.Connection]) Connection object to communicate between processes. 180 | """ 181 | indices = self._get_indices(indices) 182 | return [self.remotes[i] for i in indices] 183 | 184 | 185 | def _flatten_obs(obs, space): 186 | """ 187 | Flatten observations, depending on the observation space. 188 | 189 | :param obs: (list or tuple where X is dict, tuple or ndarray) observations. 190 | A list or tuple of observations, one per environment. 191 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. 192 | :return (OrderedDict, tuple or ndarray) flattened observations. 193 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. 194 | Each NumPy array has the environment index as its first axis. 195 | """ 196 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" 197 | assert len(obs) > 0, "need observations from at least one environment" 198 | 199 | if isinstance(space, gym.spaces.Dict): 200 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" 201 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" 202 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) 203 | elif isinstance(space, gym.spaces.Tuple): 204 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" 205 | obs_len = len(space.spaces) 206 | return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len))) 207 | else: 208 | return np.stack(obs) 209 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy a dict of numpy arrays. 14 | 15 | :param obs: (OrderedDict): a dict of numpy arrays. 16 | :return (OrderedDict) a dict of copied numpy arrays. 17 | """ 18 | assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs)) 19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 20 | 21 | 22 | def dict_to_obs(space, obs_dict): 23 | """ 24 | Convert an internal representation raw_obs into the appropriate type 25 | specified by space. 26 | 27 | :param space: (gym.spaces.Space) an observation space. 28 | :param obs_dict: (OrderedDict) a dict of numpy arrays. 29 | :return (ndarray, tuple or dict): returns an observation 30 | of the same type as space. 
If space is Dict, function is identity; 31 | if space is Tuple, converts dict to Tuple; otherwise, space is 32 | unstructured and returns the value raw_obs[None]. 33 | """ 34 | if isinstance(space, gym.spaces.Dict): 35 | return obs_dict 36 | elif isinstance(space, gym.spaces.Tuple): 37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space" 38 | return tuple((obs_dict[i] for i in range(len(space.spaces)))) 39 | else: 40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 41 | return obs_dict[None] 42 | 43 | 44 | def obs_space_info(obs_space): 45 | """ 46 | Get dict-structured information about a gym.Space. 47 | 48 | Dict spaces are represented directly by their dict of subspaces. 49 | Tuple spaces are converted into a dict with keys indexing into the tuple. 50 | Unstructured spaces are represented by {None: obs_space}. 51 | 52 | :param obs_space: (gym.spaces.Space) an observation space 53 | :return (tuple) A tuple (keys, shapes, dtypes): 54 | keys: a list of dict keys. 55 | shapes: a dict mapping keys to shapes. 56 | dtypes: a dict mapping keys to dtypes. 57 | """ 58 | if isinstance(obs_space, gym.spaces.Dict): 59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 60 | subspaces = obs_space.spaces 61 | elif isinstance(obs_space, gym.spaces.Tuple): 62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 63 | else: 64 | assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space)) 65 | subspaces = {None: obs_space} 66 | keys = [] 67 | shapes = {} 68 | dtypes = {} 69 | for key, box in subspaces.items(): 70 | keys.append(key) 71 | shapes[key] = box.shape 72 | dtypes[key] = box.dtype 73 | return keys, shapes, dtypes 74 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: (bool) Whether or not to only warn once. 
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions): 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self): 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self): 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step, **kwargs): 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += "found {} in {}".format(type_val, name) 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions) 80 | else: 81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations) 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 7 | 8 | 9 | class VecFrameStack(VecEnvWrapper): 10 | """ 11 | Frame stacking wrapper for vectorized environment 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param n_stack: (int) Number of frames to stack 15 | """ 16 | 17 | def __init__(self, venv, n_stack): 18 | self.venv = venv 19 | self.n_stack = n_stack 20 | wrapped_obs_space = venv.observation_space 21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) 22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) 23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 26 | 27 | def step_wait(self): 28 | observations, rewards, dones, infos = self.venv.step_wait() 29 | last_ax_size = observations.shape[-1] 30 | self.stackedobs = 
np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) 31 | for i, done in enumerate(dones): 32 | if done: 33 | if 'terminal_observation' in infos[i]: 34 | old_terminal = infos[i]['terminal_observation'] 35 | new_terminal = np.concatenate( 36 | (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) 37 | infos[i]['terminal_observation'] = new_terminal 38 | else: 39 | warnings.warn( 40 | "VecFrameStack wrapping a VecEnv without terminal_observation info") 41 | self.stackedobs[i] = 0 42 | self.stackedobs[..., -observations.shape[-1]:] = observations 43 | return self.stackedobs, rewards, dones, infos 44 | 45 | def reset(self): 46 | """ 47 | Reset all environments 48 | """ 49 | obs = self.venv.reset() 50 | self.stackedobs[...] = 0 51 | self.stackedobs[..., -obs.shape[-1]:] = obs 52 | return self.stackedobs 53 | 54 | def close(self): 55 | self.venv.close() 56 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_normalize.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | 4 | import numpy as np 5 | 6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 7 | from stable_baselines.common.running_mean_std import RunningMeanStd 8 | 9 | 10 | class VecNormalize(VecEnvWrapper): 11 | """ 12 | A moving average, normalizing wrapper for vectorized environment. 13 | 14 | It is pickleable which will save moving averages and configuration parameters. 15 | The wrapped environment `venv` is not saved, and must be restored manually with 16 | `set_venv` after being unpickled. 17 | 18 | :param venv: (VecEnv) the vectorized environment to wrap 19 | :param training: (bool) Whether to update or not the moving average 20 | :param norm_obs: (bool) Whether to normalize observation or not (default: True) 21 | :param norm_reward: (bool) Whether to normalize rewards or not (default: True) 22 | :param clip_obs: (float) Max absolute value for observation 23 | :param clip_reward: (float) Max value absolute for discounted reward 24 | :param gamma: (float) discount factor 25 | :param epsilon: (float) To avoid division by zero 26 | """ 27 | 28 | def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, 29 | clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): 30 | VecEnvWrapper.__init__(self, venv) 31 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) 32 | self.ret_rms = RunningMeanStd(shape=()) 33 | self.clip_obs = clip_obs 34 | self.clip_reward = clip_reward 35 | # Returns: discounted rewards 36 | self.ret = np.zeros(self.num_envs) 37 | self.gamma = gamma 38 | self.epsilon = epsilon 39 | self.training = training 40 | self.norm_obs = norm_obs 41 | self.norm_reward = norm_reward 42 | self.old_obs = None 43 | self.old_rews = None 44 | 45 | def __getstate__(self): 46 | """ 47 | Gets state for pickling. 48 | 49 | Excludes self.venv, as in general VecEnv's may not be pickleable.""" 50 | state = self.__dict__.copy() 51 | # these attributes are not pickleable 52 | del state['venv'] 53 | del state['class_attributes'] 54 | # these attributes depend on the above and so we would prefer not to pickle 55 | del state['ret'] 56 | return state 57 | 58 | def __setstate__(self, state): 59 | """ 60 | Restores pickled state. 61 | 62 | User must call set_venv() after unpickling before using. 
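An illustrative sketch of that workflow (hypothetical file name; see also the
``load`` helper further down)::

    with open('vec_normalize.pkl', 'rb') as file_handler:
        vec_normalize = pickle.load(file_handler)
    vec_normalize.set_venv(venv)  # re-attach a VecEnv, which is never pickled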
63 | 64 | :param state: (dict)""" 65 | self.__dict__.update(state) 66 | assert 'venv' not in state 67 | self.venv = None 68 | 69 | def set_venv(self, venv): 70 | """ 71 | Sets the vector environment to wrap to venv. 72 | 73 | Also sets attributes derived from this such as `num_env`. 74 | 75 | :param venv: (VecEnv) 76 | """ 77 | if self.venv is not None: 78 | raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") 79 | VecEnvWrapper.__init__(self, venv) 80 | if self.obs_rms.mean.shape != self.observation_space.shape: 81 | raise ValueError("venv is incompatible with current statistics.") 82 | self.ret = np.zeros(self.num_envs) 83 | 84 | def step_wait(self): 85 | """ 86 | Apply sequence of actions to sequence of environments 87 | actions -> (observations, rewards, news) 88 | 89 | where 'news' is a boolean vector indicating whether each element is new. 90 | """ 91 | obs, rews, news, infos = self.venv.step_wait() 92 | self.old_obs = obs 93 | self.old_rews = rews 94 | 95 | if self.training: 96 | self.obs_rms.update(obs) 97 | obs = self.normalize_obs(obs) 98 | 99 | if self.training: 100 | self._update_reward(rews) 101 | rews = self.normalize_reward(rews) 102 | 103 | self.ret[news] = 0 104 | return obs, rews, news, infos 105 | 106 | def _update_reward(self, reward: np.ndarray) -> None: 107 | """Update reward normalization statistics.""" 108 | self.ret = self.ret * self.gamma + reward 109 | self.ret_rms.update(self.ret) 110 | 111 | def normalize_obs(self, obs: np.ndarray) -> np.ndarray: 112 | """ 113 | Normalize observations using this VecNormalize's observations statistics. 114 | Calling this method does not update statistics. 115 | """ 116 | if self.norm_obs: 117 | obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), 118 | -self.clip_obs, 119 | self.clip_obs) 120 | return obs 121 | 122 | def normalize_reward(self, reward: np.ndarray) -> np.ndarray: 123 | """ 124 | Normalize rewards using this VecNormalize's rewards statistics. 125 | Calling this method does not update statistics. 126 | """ 127 | if self.norm_reward: 128 | reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), 129 | -self.clip_reward, self.clip_reward) 130 | return reward 131 | 132 | def get_original_obs(self) -> np.ndarray: 133 | """ 134 | Returns an unnormalized version of the observations from the most recent 135 | step or reset. 136 | """ 137 | return self.old_obs.copy() 138 | 139 | def get_original_reward(self) -> np.ndarray: 140 | """ 141 | Returns an unnormalized version of the rewards from the most recent step. 142 | """ 143 | return self.old_rews.copy() 144 | 145 | def reset(self): 146 | """ 147 | Reset all environments 148 | """ 149 | obs = self.venv.reset() 150 | self.old_obs = obs 151 | self.ret = np.zeros(self.num_envs) 152 | if self.training: 153 | self._update_reward(self.ret) 154 | return self.normalize_obs(obs) 155 | 156 | @staticmethod 157 | def load(load_path, venv): 158 | """ 159 | Loads a saved VecNormalize object. 160 | 161 | :param load_path: the path to load from. 162 | :param venv: the VecEnv to wrap. 163 | :return: (VecNormalize) 164 | """ 165 | with open(load_path, "rb") as file_handler: 166 | vec_normalize = pickle.load(file_handler) 167 | vec_normalize.set_venv(venv) 168 | return vec_normalize 169 | 170 | def save(self, save_path): 171 | with open(save_path, "wb") as file_handler: 172 | pickle.dump(self, file_handler) 173 | 174 | def save_running_average(self, path): 175 | """ 176 | :param path: (str) path to log dir 177 | 178 | .. 
deprecated:: 2.9.0 179 | This function will be removed in a future version 180 | """ 181 | warnings.warn("Usage of `save_running_average` is deprecated. Please " 182 | "use `save` or pickle instead.", DeprecationWarning) 183 | for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']): 184 | with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: 185 | pickle.dump(rms, file_handler) 186 | 187 | def load_running_average(self, path): 188 | """ 189 | :param path: (str) path to log dir 190 | 191 | .. deprecated:: 2.9.0 192 | This function will be removed in a future version 193 | """ 194 | warnings.warn("Usage of `load_running_average` is deprecated. Please " 195 | "use `load` or pickle instead.", DeprecationWarning) 196 | for name in ['obs_rms', 'ret_rms']: 197 | with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: 198 | setattr(self, name, pickle.load(file_handler)) 199 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.wrappers.monitoring import video_recorder 4 | 5 | from stable_baselines.common import logger 6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 7 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack 10 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize 11 | 12 | 13 | class VecVideoRecorder(VecEnvWrapper): 14 | """ 15 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 16 | It requires ffmpeg or avconv to be installed on the machine. 17 | 18 | :param venv: (VecEnv or VecEnvWrapper) 19 | :param video_folder: (str) Where to save videos 20 | :param record_video_trigger: (func) Function that defines when to start recording. 21 | The function takes the current number of step, 22 | and returns whether we should start recording or not. 
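For example (illustrative), ``record_video_trigger=lambda step: step % 10000 == 0``
starts a new recording every 10000 steps.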
23 | :param video_length: (int) Length of recorded videos 24 | :param name_prefix: (str) Prefix to the video name 25 | """ 26 | 27 | def __init__(self, venv, video_folder, record_video_trigger, 28 | video_length=200, name_prefix='rl-video'): 29 | 30 | VecEnvWrapper.__init__(self, venv) 31 | 32 | self.env = venv 33 | # Temp variable to retrieve metadata 34 | temp_env = venv 35 | 36 | # Unwrap to retrieve metadata dict 37 | # that will be used by gym recorder 38 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack): 39 | temp_env = temp_env.venv 40 | 41 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 42 | metadata = temp_env.get_attr('metadata')[0] 43 | else: 44 | metadata = temp_env.metadata 45 | 46 | self.env.metadata = metadata 47 | 48 | self.record_video_trigger = record_video_trigger 49 | self.video_recorder = None 50 | 51 | self.video_folder = os.path.abspath(video_folder) 52 | # Create output folder if needed 53 | os.makedirs(self.video_folder, exist_ok=True) 54 | 55 | self.name_prefix = name_prefix 56 | self.step_id = 0 57 | self.video_length = video_length 58 | 59 | self.recording = False 60 | self.recorded_frames = 0 61 | 62 | def reset(self): 63 | obs = self.venv.reset() 64 | self.start_video_recorder() 65 | return obs 66 | 67 | def start_video_recorder(self): 68 | self.close_video_recorder() 69 | 70 | video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id, 71 | self.step_id + self.video_length) 72 | base_path = os.path.join(self.video_folder, video_name) 73 | self.video_recorder = video_recorder.VideoRecorder( 74 | env=self.env, 75 | base_path=base_path, 76 | metadata={'step_id': self.step_id} 77 | ) 78 | 79 | self.video_recorder.capture_frame() 80 | self.recorded_frames = 1 81 | self.recording = True 82 | 83 | def _video_enabled(self): 84 | return self.record_video_trigger(self.step_id) 85 | 86 | def step_wait(self): 87 | obs, rews, dones, infos = self.venv.step_wait() 88 | 89 | self.step_id += 1 90 | if self.recording: 91 | self.video_recorder.capture_frame() 92 | self.recorded_frames += 1 93 | if self.recorded_frames > self.video_length: 94 | logger.info("Saving video to ", self.video_recorder.path) 95 | self.close_video_recorder() 96 | elif self._video_enabled(): 97 | self.start_video_recorder() 98 | 99 | return obs, rews, dones, infos 100 | 101 | def close_video_recorder(self): 102 | if self.recording: 103 | self.video_recorder.close() 104 | self.recording = False 105 | self.recorded_frames = 1 106 | 107 | def close(self): 108 | VecEnvWrapper.close(self) 109 | self.close_video_recorder() 110 | 111 | def __del__(self): 112 | self.close() 113 | -------------------------------------------------------------------------------- /stable_baselines/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.ppo.ppo import PPO 2 | from stable_baselines.ppo.policies import MlpPolicy 3 | -------------------------------------------------------------------------------- /stable_baselines/ppo/policies.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import tensorflow as tf 4 | import tensorflow.keras.layers as layers 5 | from tensorflow.keras.models import Sequential 6 | import numpy as np 7 | 8 | from stable_baselines.common.policies import BasePolicy, register_policy, MlpExtractor 9 | from stable_baselines.common.distributions import make_proba_distribution,\ 10 | 
DiagGaussianDistribution, CategoricalDistribution 11 | 12 | 13 | class PPOPolicy(BasePolicy): 14 | """ 15 | Policy class (with both actor and critic) for A2C and derivates (PPO). 16 | 17 | :param observation_space: (gym.spaces.Space) Observation space 18 | :param action_space: (gym.spaces.Space) Action space 19 | :param learning_rate: (callable) Learning rate schedule (could be constant) 20 | :param net_arch: ([int or dict]) The specification of the policy and value networks. 21 | :param activation_fn: (nn.Module) Activation function 22 | :param adam_epsilon: (float) Small values to avoid NaN in ADAM optimizer 23 | :param ortho_init: (bool) Whether to use or not orthogonal initialization 24 | :param log_std_init: (float) Initial value for the log standard deviation 25 | """ 26 | def __init__(self, observation_space, action_space, 27 | learning_rate, net_arch=None, 28 | activation_fn=tf.nn.tanh, adam_epsilon=1e-5, 29 | ortho_init=True, log_std_init=0.0): 30 | super(PPOPolicy, self).__init__(observation_space, action_space) 31 | self.obs_dim = self.observation_space.shape[0] 32 | 33 | # Default network architecture, from stable-baselines 34 | if net_arch is None: 35 | net_arch = [dict(pi=[64, 64], vf=[64, 64])] 36 | 37 | self.net_arch = net_arch 38 | self.activation_fn = activation_fn 39 | self.adam_epsilon = adam_epsilon 40 | self.ortho_init = ortho_init 41 | self.net_args = { 42 | 'input_dim': self.obs_dim, 43 | 'output_dim': -1, 44 | 'net_arch': self.net_arch, 45 | 'activation_fn': self.activation_fn 46 | } 47 | self.shared_net = None 48 | self.pi_net, self.vf_net = None, None 49 | # In the future, feature_extractor will be replaced with a CNN 50 | self.features_extractor = Sequential(layers.Flatten(input_shape=(self.obs_dim,), dtype=tf.float32)) 51 | self.features_dim = self.obs_dim 52 | self.log_std_init = log_std_init 53 | dist_kwargs = None 54 | 55 | # Action distribution 56 | self.action_dist = make_proba_distribution(action_space, dist_kwargs=dist_kwargs) 57 | 58 | self._build(learning_rate) 59 | 60 | def _build(self, learning_rate): 61 | self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, 62 | activation_fn=self.activation_fn) 63 | 64 | latent_dim_pi = self.mlp_extractor.latent_dim_pi 65 | 66 | if isinstance(self.action_dist, DiagGaussianDistribution): 67 | self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi, 68 | log_std_init=self.log_std_init) 69 | elif isinstance(self.action_dist, CategoricalDistribution): 70 | self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) 71 | 72 | self.value_net = Sequential(layers.Dense(1, input_shape=(self.mlp_extractor.latent_dim_vf,))) 73 | 74 | self.features_extractor.build() 75 | self.action_net.build() 76 | self.value_net.build() 77 | # Init weights: use orthogonal initialization 78 | # with small initial weight for the output 79 | if self.ortho_init: 80 | pass 81 | # for module in [self.mlp_extractor, self.action_net, self.value_net]: 82 | # # Values from stable-baselines, TODO: check why 83 | # gain = { 84 | # self.mlp_extractor: np.sqrt(2), 85 | # self.action_net: 0.01, 86 | # self.value_net: 1 87 | # }[module] 88 | # module.apply(partial(self.init_weights, gain=gain)) 89 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1), epsilon=self.adam_epsilon) 90 | 91 | @tf.function 92 | def call(self, obs, deterministic=False): 93 | latent_pi, latent_vf = self._get_latent(obs) 94 | value = self.value_net(latent_vf) 95 | action, 
action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) 96 | log_prob = action_distribution.log_prob(action) 97 | return action, value, log_prob 98 | 99 | def _get_latent(self, obs): 100 | features = self.features_extractor(obs) 101 | latent_pi, latent_vf = self.mlp_extractor(features) 102 | return latent_pi, latent_vf 103 | 104 | def _get_action_dist_from_latent(self, latent_pi, deterministic=False): 105 | mean_actions = self.action_net(latent_pi) 106 | 107 | if isinstance(self.action_dist, DiagGaussianDistribution): 108 | return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic) 109 | 110 | elif isinstance(self.action_dist, CategoricalDistribution): 111 | # Here mean_actions are the logits before the softmax 112 | return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic) 113 | 114 | def actor_forward(self, obs, deterministic=False): 115 | latent_pi, _ = self._get_latent(obs) 116 | action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) 117 | return tf.stop_gradient(action).numpy() 118 | 119 | @tf.function 120 | def evaluate_actions(self, obs, action, deterministic=False): 121 | """ 122 | Evaluate actions according to the current policy, 123 | given the observations. 124 | 125 | :param obs: (th.Tensor) 126 | :param action: (th.Tensor) 127 | :param deterministic: (bool) 128 | :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions 129 | and entropy of the action distribution. 130 | """ 131 | latent_pi, latent_vf = self._get_latent(obs) 132 | _, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic) 133 | log_prob = action_distribution.log_prob(action) 134 | value = self.value_net(latent_vf) 135 | return value, log_prob, action_distribution.entropy() 136 | 137 | def value_forward(self, obs): 138 | _, latent_vf, _ = self._get_latent(obs) 139 | return self.value_net(latent_vf) 140 | 141 | 142 | MlpPolicy = PPOPolicy 143 | 144 | register_policy("MlpPolicy", MlpPolicy) 145 | -------------------------------------------------------------------------------- /stable_baselines/ppo/ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import gym 5 | from gym import spaces 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from stable_baselines.common.base_class import BaseRLModel 10 | from stable_baselines.common.buffers import RolloutBuffer 11 | from stable_baselines.common.utils import explained_variance, get_schedule_fn 12 | from stable_baselines.common import logger 13 | from stable_baselines.ppo.policies import PPOPolicy 14 | 15 | 16 | class PPO(BaseRLModel): 17 | """ 18 | Proximal Policy Optimization algorithm (PPO) (clip version) 19 | 20 | Paper: https://arxiv.org/abs/1707.06347 21 | Code: This implementation borrows code from OpenAI spinningup (https://github.com/openai/spinningup/) 22 | https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and 23 | and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines) 24 | 25 | Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html 26 | 27 | :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) 
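For example (illustrative), ``PPO('MlpPolicy', env)`` looks up the registered
``MlpPolicy`` class by name.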
28 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) 29 | :param learning_rate: (float or callable) The learning rate, it can be a function 30 | of the current progress (from 1 to 0) 31 | :param n_steps: (int) The number of steps to run for each environment per update 32 | (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) 33 | :param batch_size: (int) Minibatch size 34 | :param n_epochs: (int) Number of epochs when optimizing the surrogate loss 35 | :param gamma: (float) Discount factor 36 | :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator 37 | :param clip_range: (float or callable) Clipping parameter, it can be a function of the current progress 38 | (from 1 to 0). 39 | :param clip_range_vf: (float or callable) Clipping parameter for the value function, 40 | it can be a function of the current progress (from 1 to 0). 41 | This is a parameter specific to the OpenAI implementation. If None is passed (default), 42 | no clipping will be done on the value function. 43 | IMPORTANT: this clipping depends on the reward scaling. 44 | :param ent_coef: (float) Entropy coefficient for the loss calculation 45 | :param vf_coef: (float) Value function coefficient for the loss calculation 46 | :param max_grad_norm: (float) The maximum value for the gradient clipping 47 | :param target_kl: (float) Limit the KL divergence between updates, 48 | because the clipping is not enough to prevent large updates 49 | see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) 50 | By default, there is no limit on the KL divergence. 51 | :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) 52 | :param create_eval_env: (bool) Whether to create a second environment that will be 53 | used for evaluating the agent periodically.
(Only available when passing a string for the environment) 54 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation 55 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 56 | :param seed: (int) Seed for the pseudo-random generators 57 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance 58 | """ 59 | def __init__(self, policy, env, learning_rate=3e-4, 60 | n_steps=2048, batch_size=64, n_epochs=10, 61 | gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None, 62 | ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, 63 | target_kl=None, tensorboard_log=None, create_eval_env=False, 64 | policy_kwargs=None, verbose=0, seed=0, 65 | _init_setup_model=True): 66 | 67 | super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs, 68 | verbose=verbose, create_eval_env=create_eval_env, support_multi_env=True, seed=seed) 69 | 70 | self.learning_rate = learning_rate 71 | self.batch_size = batch_size 72 | self.n_epochs = n_epochs 73 | self.n_steps = n_steps 74 | self.gamma = gamma 75 | self.gae_lambda = gae_lambda 76 | self.clip_range = clip_range 77 | self.clip_range_vf = clip_range_vf 78 | self.ent_coef = ent_coef 79 | self.vf_coef = vf_coef 80 | self.max_grad_norm = max_grad_norm 81 | self.rollout_buffer = None 82 | self.target_kl = target_kl 83 | self.tensorboard_log = tensorboard_log 84 | self.tb_writer = None 85 | 86 | if _init_setup_model: 87 | self._setup_model() 88 | 89 | def _setup_model(self): 90 | self._setup_learning_rate() 91 | # TODO: preprocessing: one-hot vector for discrete obs 92 | state_dim = self.observation_space.shape[0] 93 | if isinstance(self.action_space, spaces.Box): 94 | # Action is a 1D vector 95 | action_dim = self.action_space.shape[0] 96 | elif isinstance(self.action_space, spaces.Discrete): 97 | # Action is a scalar 98 | action_dim = 1 99 | 100 | # TODO: different seed for each env when n_envs > 1 101 | if self.n_envs == 1: 102 | self.set_random_seed(self.seed) 103 | 104 | self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, 105 | gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) 106 | self.policy = self.policy_class(self.observation_space, self.action_space, 107 | self.learning_rate, **self.policy_kwargs) 108 | 109 | self.clip_range = get_schedule_fn(self.clip_range) 110 | if self.clip_range_vf is not None: 111 | self.clip_range_vf = get_schedule_fn(self.clip_range_vf) 112 | 113 | def predict(self, observation, state=None, mask=None, deterministic=False): 114 | """ 115 | Get the model's action from an observation 116 | 117 | :param observation: (np.ndarray) the input observation 118 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies) 119 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) 120 | :param deterministic: (bool) Whether or not to return deterministic actions.
121 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) 122 | """ 123 | clipped_actions = self.policy.actor_forward(np.array(observation).reshape(1, -1), deterministic=deterministic) 124 | if isinstance(self.action_space, gym.spaces.Box): 125 | clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high) 126 | return clipped_actions 127 | 128 | def collect_rollouts(self, env, rollout_buffer, n_rollout_steps=256, callback=None, 129 | obs=None): 130 | 131 | n_steps = 0 132 | rollout_buffer.reset() 133 | 134 | while n_steps < n_rollout_steps: 135 | actions, values, log_probs = self.policy.call(obs) 136 | actions = actions.numpy() 137 | 138 | # Rescale and perform action 139 | clipped_actions = actions 140 | # Clip the actions to avoid an out-of-bounds error 141 | if isinstance(self.action_space, gym.spaces.Box): 142 | clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) 143 | new_obs, rewards, dones, infos = env.step(clipped_actions) 144 | 145 | self._update_info_buffer(infos) 146 | n_steps += 1 147 | if isinstance(self.action_space, gym.spaces.Discrete): 148 | # Reshape in case of discrete action 149 | actions = actions.reshape(-1, 1) 150 | rollout_buffer.add(obs, actions, rewards, dones, values, log_probs) 151 | obs = new_obs 152 | 153 | rollout_buffer.compute_returns_and_advantage(values, dones=dones) 154 | 155 | return obs 156 | 157 | @tf.function 158 | def policy_loss(self, advantage, log_prob, old_log_prob, clip_range): 159 | # Normalize advantage 160 | advantage = (advantage - tf.reduce_mean(advantage)) / (tf.math.reduce_std(advantage) + 1e-8) 161 | 162 | # ratio between old and new policy, should be one at the first iteration 163 | ratio = tf.exp(log_prob - old_log_prob) 164 | # clipped surrogate loss 165 | policy_loss_1 = advantage * ratio 166 | policy_loss_2 = advantage * tf.clip_by_value(ratio, 1 - clip_range, 1 + clip_range) 167 | return - tf.reduce_mean(tf.minimum(policy_loss_1, policy_loss_2)) 168 | 169 | @tf.function 170 | def value_loss(self, values, old_values, return_batch, clip_range_vf): 171 | if clip_range_vf is None: 172 | # No clipping 173 | values_pred = values 174 | else: 175 | # Clip the difference between old and new value 176 | # NOTE: this depends on the reward scaling 177 | values_pred = old_values + tf.clip_by_value(values - old_values, -clip_range_vf, clip_range_vf) 178 | # Value loss using the TD(gae_lambda) target 179 | return tf.keras.losses.MSE(return_batch, values_pred) 180 | 181 | def train(self, gradient_steps, batch_size=64): 182 | # Update optimizer learning rate 183 | # self._update_learning_rate(self.policy.optimizer) 184 | 185 | # Compute current clip range 186 | clip_range = self.clip_range(self._current_progress) 187 | if self.clip_range_vf is not None: 188 | clip_range_vf = self.clip_range_vf(self._current_progress) 189 | else: 190 | clip_range_vf = None 191 | 192 | for gradient_step in range(gradient_steps): 193 | approx_kl_divs = [] 194 | # Sample the rollout buffer 195 | for replay_data in self.rollout_buffer.get(batch_size): 196 | # Unpack 197 | obs, action, old_values, old_log_prob, advantage, return_batch = replay_data 198 | 199 | if isinstance(self.action_space, spaces.Discrete): 200 | # Convert discrete action from float to long 201 | action = action.astype(np.int64).flatten() 202 | 203 | with tf.GradientTape() as tape: 204 | tape.watch(self.policy.trainable_variables) 205 | values, log_prob, entropy = self.policy.evaluate_actions(obs,
action) 206 | # Flatten 207 | values = tf.reshape(values, [-1]) 208 | 209 | policy_loss = self.policy_loss(advantage, log_prob, old_log_prob, clip_range) 210 | value_loss = self.value_loss(values, old_values, return_batch, clip_range_vf) 211 | 212 | # Entropy loss favors exploration 213 | entropy_loss = -tf.reduce_mean(entropy) 214 | 215 | loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss 216 | 217 | # Optimization step 218 | gradients = tape.gradient(loss, self.policy.trainable_variables) 219 | # Clip grad norm 220 | # gradients = tf.clip_by_norm(gradients, self.max_grad_norm) 221 | self.policy.optimizer.apply_gradients(zip(gradients, self.policy.trainable_variables)) 222 | approx_kl_divs.append(tf.reduce_mean(old_log_prob - log_prob).numpy()) 223 | 224 | if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl: 225 | print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step, 226 | np.mean(approx_kl_divs))) 227 | break 228 | 229 | explained_var = explained_variance(self.rollout_buffer.returns.flatten(), 230 | self.rollout_buffer.values.flatten()) 231 | 232 | logger.logkv("clip_range", clip_range) 233 | if self.clip_range_vf is not None: 234 | logger.logkv("clip_range_vf", clip_range_vf) 235 | 236 | logger.logkv("explained_variance", explained_var) 237 | # TODO: gather stats for the entropy and other losses? 238 | logger.logkv("entropy", entropy.numpy().mean()) 239 | logger.logkv("policy_loss", policy_loss.numpy()) 240 | logger.logkv("value_loss", value_loss.numpy()) 241 | if hasattr(self.policy, 'log_std'): 242 | logger.logkv("std", tf.exp(self.policy.log_std).numpy().mean()) 243 | 244 | def learn(self, total_timesteps, callback=None, log_interval=1, 245 | eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True): 246 | 247 | timesteps_since_eval, iteration, evaluations, obs, eval_env = self._setup_learn(eval_env) 248 | 249 | if self.tensorboard_log is not None: 250 | self.tb_writer = tf.summary.create_file_writer(os.path.join(self.tensorboard_log, tb_log_name)) 251 | 252 | while self.num_timesteps < total_timesteps: 253 | 254 | if callback is not None: 255 | # Only stop training if return value is False, not when it is None.
256 | if callback(locals(), globals()) is False: 257 | break 258 | 259 | obs = self.collect_rollouts(self.env, self.rollout_buffer, n_rollout_steps=self.n_steps, 260 | obs=obs) 261 | iteration += 1 262 | self.num_timesteps += self.n_steps * self.n_envs 263 | timesteps_since_eval += self.n_steps * self.n_envs 264 | self._update_current_progress(self.num_timesteps, total_timesteps) 265 | 266 | # Display training infos 267 | if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0: 268 | fps = int(self.num_timesteps / (time.time() - self.start_time)) 269 | logger.logkv("iterations", iteration) 270 | if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: 271 | logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer])) 272 | logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer])) 273 | logger.logkv("fps", fps) 274 | logger.logkv('time_elapsed', int(time.time() - self.start_time)) 275 | logger.logkv("total timesteps", self.num_timesteps) 276 | logger.dumpkvs() 277 | 278 | self.train(self.n_epochs, batch_size=self.batch_size) 279 | 280 | # Evaluate the agent 281 | timesteps_since_eval = self._eval_policy(eval_freq, eval_env, n_eval_episodes, 282 | timesteps_since_eval, deterministic=True) 283 | # For tensorboard integration 284 | # if self.tb_writer is not None: 285 | # with self.tb_writer.as_default(): 286 | # tf.summary.scalar('Eval/reward', mean_reward, self.num_timesteps) 287 | 288 | 289 | return self 290 | -------------------------------------------------------------------------------- /stable_baselines/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/stable_baselines/py.typed -------------------------------------------------------------------------------- /stable_baselines/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.td3.td3 import TD3 2 | from stable_baselines.td3.policies import MlpPolicy 3 | -------------------------------------------------------------------------------- /stable_baselines/td3/policies.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.models import Sequential 3 | 4 | from stable_baselines.common.policies import BasePolicy, register_policy, create_mlp 5 | 6 | 7 | class Actor(BasePolicy): 8 | """ 9 | Actor network (policy) for TD3. 
10 | 11 | :param obs_dim: (int) Dimension of the observation 12 | :param action_dim: (int) Dimension of the action space 13 | :param net_arch: ([int]) Network architecture 14 | :param activation_fn: (str or tf.activation) Activation function 15 | """ 16 | def __init__(self, obs_dim, action_dim, net_arch, activation_fn=tf.nn.relu): 17 | super(Actor, self).__init__(None, None) 18 | 19 | actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True) 20 | self.mu = Sequential(actor_net) 21 | self.mu.build() 22 | 23 | @tf.function 24 | def call(self, obs): 25 | return self.mu(obs) 26 | 27 | 28 | class Critic(BasePolicy): 29 | """ 30 | Critic network for TD3; 31 | it represents the state-action value function (Q-value function). 32 | 33 | :param obs_dim: (int) Dimension of the observation 34 | :param action_dim: (int) Dimension of the action space 35 | :param net_arch: ([int]) Network architecture 36 | :param activation_fn: (str or tf.activation) Activation function 37 | """ 38 | def __init__(self, obs_dim, action_dim, 39 | net_arch, activation_fn=tf.nn.relu): 40 | super(Critic, self).__init__(None, None) 41 | 42 | q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn) 43 | self.q1_net = Sequential(q1_net) 44 | 45 | q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn) 46 | self.q2_net = Sequential(q2_net) 47 | 48 | self.q_networks = [self.q1_net, self.q2_net] 49 | 50 | for q_net in self.q_networks: 51 | q_net.build() 52 | 53 | @tf.function 54 | def call(self, obs, action): 55 | qvalue_input = tf.concat([obs, action], axis=1) 56 | return [q_net(qvalue_input) for q_net in self.q_networks] 57 | 58 | @tf.function 59 | def q1_forward(self, obs, action): 60 | return self.q_networks[0](tf.concat([obs, action], axis=1)) 61 | 62 | 63 | class TD3Policy(BasePolicy): 64 | """ 65 | Policy class (with both actor and critic) for TD3. 66 | 67 | :param observation_space: (gym.spaces.Space) Observation space 68 | :param action_space: (gym.spaces.Space) Action space 69 | :param learning_rate: (callable) Learning rate schedule (could be constant) 70 | :param net_arch: ([int]) The specification of the actor and critic networks.
71 | :param activation_fn: (str or tf.nn.activation) Activation function 72 | """ 73 | def __init__(self, observation_space, action_space, 74 | learning_rate, net_arch=None, 75 | activation_fn=tf.nn.relu): 76 | super(TD3Policy, self).__init__(observation_space, action_space) 77 | 78 | # Default network architecture, from the original paper 79 | if net_arch is None: 80 | net_arch = [400, 300] 81 | 82 | self.obs_dim = self.observation_space.shape[0] 83 | self.action_dim = self.action_space.shape[0] 84 | self.net_arch = net_arch 85 | self.activation_fn = activation_fn 86 | self.net_args = { 87 | 'obs_dim': self.obs_dim, 88 | 'action_dim': self.action_dim, 89 | 'net_arch': self.net_arch, 90 | 'activation_fn': self.activation_fn 91 | } 92 | self.actor_kwargs = self.net_args.copy() 93 | 94 | self.actor, self.actor_target = None, None 95 | self.critic, self.critic_target = None, None 96 | self._build(learning_rate) 97 | 98 | def _build(self, learning_rate): 99 | self.actor = self.make_actor() 100 | self.actor_target = self.make_actor() 101 | self.actor_target.hard_update(self.actor) 102 | self.actor.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1)) 103 | 104 | self.critic = self.make_critic() 105 | self.critic_target = self.make_critic() 106 | self.critic_target.hard_update(self.critic) 107 | self.critic.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1)) 108 | 109 | def make_actor(self): 110 | return Actor(**self.actor_kwargs) 111 | 112 | def make_critic(self): 113 | return Critic(**self.net_args) 114 | 115 | @tf.function 116 | def call(self, obs): 117 | return self.actor(obs) 118 | 119 | 120 | MlpPolicy = TD3Policy 121 | 122 | register_policy("MlpPolicy", MlpPolicy) 123 | -------------------------------------------------------------------------------- /stable_baselines/td3/td3.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | from stable_baselines.common.base_class import BaseRLModel 7 | from stable_baselines.common.buffers import ReplayBuffer 8 | from stable_baselines.common.vec_env import VecEnv 9 | from stable_baselines.common import logger 10 | from stable_baselines.td3.policies import TD3Policy 11 | 12 | 13 | class TD3(BaseRLModel): 14 | """ 15 | Twin Delayed DDPG (TD3) 16 | Addressing Function Approximation Error in Actor-Critic Methods. 17 | 18 | Original implementation: https://github.com/sfujim/TD3 19 | Paper: https://arxiv.org/abs/1802.09477 20 | Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html 21 | 22 | :param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) 23 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) 24 | :param buffer_size: (int) size of the replay buffer 25 | :param learning_rate: (float or callable) learning rate for the Adam optimizer, 26 | the same learning rate will be used for all networks (Q-Values and Actor networks), 27 | it can be a function of the current progress (from 1 to 0) 28 | :param policy_delay: (int) Policy and target networks will only be updated once every policy_delay training steps. 29 | The Q-values will be updated policy_delay times more often (i.e. once per training step).
30 | :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts 31 | :param gamma: (float) the discount factor 32 | :param batch_size: (int) Minibatch size for each gradient update 33 | :param train_freq: (int) Update the model every `train_freq` steps. 34 | :param gradient_steps: (int) How many gradient updates after each step 35 | :param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1) 36 | :param action_noise: (ActionNoise) the action noise type. See common.noise for the different action noise types. 37 | :param target_policy_noise: (float) Standard deviation of Gaussian noise added to target policy 38 | (smoothing noise) 39 | :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise. 40 | :param create_eval_env: (bool) Whether to create a second environment that will be 41 | used for evaluating the agent periodically. (Only available when passing a string for the environment) 42 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation 43 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 44 | :param seed: (int) Seed for the pseudo-random generators 45 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance 46 | """ 47 | 48 | def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3, 49 | policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, 50 | train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, 51 | tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5, 52 | tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, 53 | seed=None, _init_setup_model=True): 54 | 55 | super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, 56 | create_eval_env=create_eval_env, seed=seed) 57 | 58 | self.buffer_size = buffer_size 59 | self.learning_rate = learning_rate 60 | self.learning_starts = learning_starts 61 | self.train_freq = train_freq 62 | self.gradient_steps = gradient_steps 63 | self.n_episodes_rollout = n_episodes_rollout 64 | self.batch_size = batch_size 65 | self.tau = tau 66 | self.gamma = gamma 67 | self.action_noise = action_noise 68 | self.policy_delay = policy_delay 69 | self.target_noise_clip = target_noise_clip 70 | self.target_policy_noise = target_policy_noise 71 | 72 | if _init_setup_model: 73 | self._setup_model() 74 | 75 | def _setup_model(self): 76 | self._setup_learning_rate() 77 | obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] 78 | self.set_random_seed(self.seed) 79 | self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim) 80 | self.policy = self.policy_class(self.observation_space, self.action_space, 81 | self.learning_rate, **self.policy_kwargs) 82 | self._create_aliases() 83 | 84 | def _create_aliases(self): 85 | self.actor = self.policy.actor 86 | self.actor_target = self.policy.actor_target 87 | self.critic = self.policy.critic 88 | self.critic_target = self.policy.critic_target 89 | 90 | def predict(self, observation, state=None, mask=None, deterministic=True): 91 | """ 92 | Get the model's action from an observation 93 | 94 | :param observation: (np.ndarray) the input observation 95 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies) 96 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) 97 | :param deterministic:
(bool) Whether or not to return deterministic actions. 98 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) 99 | """ 100 | return self.unscale_action(self.actor(np.array(observation).reshape(1, -1)).numpy()) 101 | 102 | @tf.function 103 | def critic_loss(self, obs, action, next_obs, done, reward): 104 | # Select action according to policy and add clipped noise 105 | noise = tf.random.normal(shape=action.shape) * self.target_policy_noise 106 | noise = tf.clip_by_value(noise, -self.target_noise_clip, self.target_noise_clip) 107 | next_action = tf.clip_by_value(self.actor_target(next_obs) + noise, -1., 1.) 108 | 109 | # Compute the target Q value 110 | target_q1, target_q2 = self.critic_target(next_obs, next_action) 111 | target_q = tf.minimum(target_q1, target_q2) 112 | target_q = reward + tf.stop_gradient((1 - done) * self.gamma * target_q) 113 | 114 | # Get current Q estimates 115 | current_q1, current_q2 = self.critic(obs, action) 116 | 117 | # Compute critic loss 118 | return tf.keras.losses.MSE(current_q1, target_q) + tf.keras.losses.MSE(current_q2, target_q) 119 | 120 | @tf.function 121 | def actor_loss(self, obs): 122 | return - tf.reduce_mean(self.critic.q1_forward(obs, self.actor(obs))) 123 | 124 | @tf.function 125 | def update_targets(self): 126 | self.critic_target.soft_update(self.critic, self.tau) 127 | self.actor_target.soft_update(self.actor, self.tau) 128 | 129 | @tf.function 130 | def _train_critic(self, obs, action, next_obs, done, reward): 131 | with tf.GradientTape() as critic_tape: 132 | critic_tape.watch(self.critic.trainable_variables) 133 | critic_loss = self.critic_loss(obs, action, next_obs, done, reward) 134 | 135 | # Optimize the critic 136 | grads_critic = critic_tape.gradient(critic_loss, self.critic.trainable_variables) 137 | self.critic.optimizer.apply_gradients(zip(grads_critic, self.critic.trainable_variables)) 138 | 139 | @tf.function 140 | def _train_actor(self, obs): 141 | with tf.GradientTape() as actor_tape: 142 | actor_tape.watch(self.actor.trainable_variables) 143 | # Compute actor loss 144 | actor_loss = self.actor_loss(obs) 145 | 146 | # Optimize the actor 147 | grads_actor = actor_tape.gradient(actor_loss, self.actor.trainable_variables) 148 | self.actor.optimizer.apply_gradients(zip(grads_actor, self.actor.trainable_variables)) 149 | 150 | def train(self, gradient_steps, batch_size=100, policy_delay=2): 151 | # self._update_learning_rate() 152 | 153 | for gradient_step in range(gradient_steps): 154 | 155 | # Sample replay buffer 156 | obs, action, next_obs, done, reward = self.replay_buffer.sample(batch_size) 157 | 158 | self._train_critic(obs, action, next_obs, done, reward) 159 | 160 | # Delayed policy updates 161 | if gradient_step % policy_delay == 0: 162 | 163 | self._train_actor(obs) 164 | 165 | # Update the frozen target models 166 | self.update_targets() 167 | 168 | 169 | def learn(self, total_timesteps, callback=None, log_interval=4, 170 | eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True): 171 | 172 | timesteps_since_eval, episode_num, evaluations, obs, eval_env = self._setup_learn(eval_env) 173 | 174 | while self.num_timesteps < total_timesteps: 175 | 176 | if callback is not None: 177 | # Only stop training if return value is False, not when it is None. 
178 | if callback(locals(), globals()) is False: 179 | break 180 | 181 | rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout, 182 | n_steps=self.train_freq, action_noise=self.action_noise, 183 | deterministic=False, callback=None, 184 | learning_starts=self.learning_starts, 185 | num_timesteps=self.num_timesteps, 186 | replay_buffer=self.replay_buffer, 187 | obs=obs, episode_num=episode_num, 188 | log_interval=log_interval) 189 | # Unpack 190 | episode_reward, episode_timesteps, n_episodes, obs = rollout 191 | 192 | episode_num += n_episodes 193 | self.num_timesteps += episode_timesteps 194 | timesteps_since_eval += episode_timesteps 195 | self._update_current_progress(self.num_timesteps, total_timesteps) 196 | 197 | if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: 198 | 199 | gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps 200 | self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay) 201 | 202 | # Evaluate the agent 203 | timesteps_since_eval = self._eval_policy(eval_freq, eval_env, n_eval_episodes, 204 | timesteps_since_eval, deterministic=True) 205 | 206 | return self 207 | 208 | def collect_rollouts(self, env, n_episodes=1, n_steps=-1, action_noise=None, 209 | deterministic=False, callback=None, 210 | learning_starts=0, num_timesteps=0, 211 | replay_buffer=None, obs=None, 212 | episode_num=0, log_interval=None): 213 | """ 214 | Collect rollout using the current policy (and possibly fill the replay buffer) 215 | TODO: move this method to off-policy base class. 216 | 217 | :param env: (VecEnv) 218 | :param n_episodes: (int) 219 | :param n_steps: (int) 220 | :param action_noise: (ActionNoise) 221 | :param deterministic: (bool) 222 | :param callback: (callable) 223 | :param learning_starts: (int) 224 | :param num_timesteps: (int) 225 | :param replay_buffer: (ReplayBuffer) 226 | :param obs: (np.ndarray) 227 | :param episode_num: (int) 228 | :param log_interval: (int) 229 | """ 230 | episode_rewards = [] 231 | total_timesteps = [] 232 | total_steps, total_episodes = 0, 0 233 | assert isinstance(env, VecEnv) 234 | assert env.num_envs == 1 235 | 236 | while total_steps < n_steps or total_episodes < n_episodes: 237 | done = False 238 | # Reset environment: not needed for VecEnv 239 | # obs = env.reset() 240 | episode_reward, episode_timesteps = 0.0, 0 241 | 242 | while not done: 243 | 244 | # Select action randomly or according to policy 245 | if num_timesteps < learning_starts: 246 | # Warmup phase 247 | unscaled_action = np.array([self.action_space.sample()]) 248 | # Rescale the action from [low, high] to [-1, 1] 249 | scaled_action = self.scale_action(unscaled_action) 250 | else: 251 | scaled_action = self.policy.call(obs) 252 | 253 | # Add noise to the action (improve exploration) 254 | if action_noise is not None: 255 | scaled_action = np.clip(scaled_action + action_noise(), -1, 1) 256 | 257 | # Rescale and perform action 258 | new_obs, reward, done, infos = env.step(self.unscale_action(scaled_action)) 259 | 260 | done_bool = [float(done[0])] 261 | episode_reward += reward 262 | 263 | # Retrieve reward and episode length if using Monitor wrapper 264 | self._update_info_buffer(infos) 265 | 266 | # Store data in replay buffer 267 | if replay_buffer is not None: 268 | replay_buffer.add(obs, new_obs, scaled_action, reward, done_bool) 269 | 270 | obs = new_obs 271 | 272 | num_timesteps += 1 273 | episode_timesteps += 1 274 | total_steps += 1 275 | if 0 < n_steps <= total_steps: 
276 | break 277 | 278 | if done: 279 | total_episodes += 1 280 | episode_rewards.append(episode_reward) 281 | total_timesteps.append(episode_timesteps) 282 | if action_noise is not None: 283 | action_noise.reset() 284 | 285 | # Display training infos 286 | if self.verbose >= 1 and log_interval is not None and ( 287 | episode_num + total_episodes) % log_interval == 0: 288 | fps = int(num_timesteps / (time.time() - self.start_time)) 289 | logger.logkv("episodes", episode_num + total_episodes) 290 | if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: 291 | logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer])) 292 | logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer])) 293 | # logger.logkv("n_updates", n_updates) 294 | logger.logkv("fps", fps) 295 | logger.logkv('time_elapsed', int(time.time() - self.start_time)) 296 | logger.logkv("total timesteps", num_timesteps) 297 | logger.dumpkvs() 298 | 299 | mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0 300 | 301 | return mean_reward, total_steps, total_episodes, obs 302 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import TD3, PPO 5 | from stable_baselines.common.noise import NormalActionNoise 6 | 7 | action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) 8 | 9 | 10 | def test_td3(): 11 | model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), seed=0, 12 | learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise) 13 | model.learn(total_timesteps=10000, eval_freq=5000) 14 | # model.save("test_save") 15 | # model.load("test_save") 16 | # os.remove("test_save.zip") 17 | 18 | @pytest.mark.parametrize("model_class", [PPO]) 19 | @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0']) 20 | def test_onpolicy(model_class, env_id): 21 | model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) 22 | model.learn(total_timesteps=1000, eval_freq=500) 23 | # model.save("test_save") 24 | # model.load("test_save") 25 | # os.remove("test_save.zip") 26 | --------------------------------------------------------------------------------
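For quick reference, below is a minimal end-to-end usage sketch of the two algorithms implemented in this repository, closely following the API exercised in `tests/test_run.py`. The environment ids, timestep budgets, and noise parameters are illustrative only, and the `gym.make`/`env.reset` calls assume the classic Gym API used elsewhere in the code.

```python
import gym
import numpy as np

from stable_baselines import PPO, TD3
from stable_baselines.common.noise import NormalActionNoise

# On-policy: PPO with the registered MlpPolicy on a discrete-action task
ppo_model = PPO('MlpPolicy', 'CartPole-v1', verbose=1)
ppo_model.learn(total_timesteps=10000)

# Off-policy: TD3 on a continuous-action task, with Gaussian exploration noise
action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
td3_model = TD3('MlpPolicy', 'Pendulum-v0', action_noise=action_noise,
                learning_starts=100, verbose=1)
td3_model.learn(total_timesteps=10000)

# Query the trained TD3 policy for a single deterministic action
env = gym.make('Pendulum-v0')
obs = env.reset()
action = td3_model.predict(obs, deterministic=True)
```

Note that, although the `predict()` docstrings above describe returning both the action and the next state, the current PPO and TD3 implementations return only the action array.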