├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── issue-template.md
│   └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── conftest.py
├── docs
│   ├── _static
│   │   ├── css
│   │   │   └── baselines_theme.css
│   │   └── img
│   │       └── logo.png
│   └── misc
│       └── changelog.rst
├── setup.cfg
├── setup.py
├── stable_baselines
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── base_class.py
│   │   ├── buffers.py
│   │   ├── distributions.py
│   │   ├── evaluation.py
│   │   ├── logger.py
│   │   ├── monitor.py
│   │   ├── noise.py
│   │   ├── policies.py
│   │   ├── running_mean_std.py
│   │   ├── save_util.py
│   │   ├── tile_images.py
│   │   ├── utils.py
│   │   └── vec_env
│   │       ├── __init__.py
│   │       ├── base_vec_env.py
│   │       ├── dummy_vec_env.py
│   │       ├── subproc_vec_env.py
│   │       ├── util.py
│   │       ├── vec_check_nan.py
│   │       ├── vec_frame_stack.py
│   │       ├── vec_normalize.py
│   │       └── vec_video_recorder.py
│   ├── ppo
│   │   ├── __init__.py
│   │   ├── policies.py
│   │   └── ppo.py
│   ├── py.typed
│   └── td3
│       ├── __init__.py
│       ├── policies.py
│       └── td3.py
└── tests
    ├── __init__.py
    └── test_run.py
/.github/ISSUE_TEMPLATE/issue-template.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Issue Template
3 | about: How to create an issue for this repository
4 |
5 | ---
6 |
7 | **Important Note: We do not do technical support or consulting** and do not answer personal questions by email.
8 |
9 | If you have any questions, feel free to create an issue with the tag [question].
10 | If you wish to suggest an enhancement or feature request, add the tag [feature request].
11 | If you are submitting a bug report, please fill in the following details.
12 |
13 | If your issue is related to a custom gym environment, please check it first using:
14 |
15 | ```python
16 | from stable_baselines.common.env_checker import check_env
17 |
18 | env = CustomEnv(arg1, ...)
19 | # It will check your custom environment and output additional warnings if needed
20 | check_env(env)
21 | ```
22 |
23 | **Describe the bug**
24 | A clear and concise description of what the bug is.
25 |
26 | **Code example**
27 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
28 |
29 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
30 | for both code and stack traces.
31 |
32 | ```python
33 | from stable_baselines import ...
34 |
35 | ```
36 |
37 | ```bash
38 | Traceback (most recent call last): File ...
39 |
40 | ```
41 |
42 | **System Info**
43 | Describe the characteristics of your environment:
44 | * Describe how the library was installed (pip, docker, source, ...)
45 | * GPU models and configuration
46 | * Python version
47 |  * TensorFlow version
48 | * Versions of any other relevant libraries
49 |
50 | **Additional context**
51 | Add any other context about the problem here.
52 |
--------------------------------------------------------------------------------
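For reference, a self-contained sketch of the `check_env` call mentioned in the template above. `CustomEnv` is a hypothetical placeholder, and `stable_baselines.common.env_checker` is assumed to be provided by the TF1 Stable-Baselines package (an `env_checker` module does not appear in this repository's file tree):

```python
import numpy as np
import gym
from gym import spaces

from stable_baselines.common.env_checker import check_env


class CustomEnv(gym.Env):
    """Hypothetical minimal custom environment, used only to illustrate check_env."""

    def __init__(self):
        super(CustomEnv, self).__init__()
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self):
        return np.zeros(3, dtype=np.float32)

    def step(self, action):
        obs = np.zeros(3, dtype=np.float32)
        reward, done, info = 0.0, True, {}
        return obs, reward, done, info


# Checks spaces, reset()/step() signatures and dtypes, printing warnings if needed.
check_env(CustomEnv())
```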
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Description
4 |
5 |
6 | ## Motivation and Context
7 |
8 |
9 |
10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes)
11 |
12 | ## Types of changes
13 |
14 | - [ ] Bug fix (non-breaking change which fixes an issue)
15 | - [ ] New feature (non-breaking change which adds functionality)
16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change)
17 | - [ ] Documentation (update in the documentation)
18 |
19 | ## Checklist:
20 |
21 |
22 | - [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**)
23 | - [ ] I have updated the [changelog](https://github.com/hill-a/stable-baselines/blob/master/docs/misc/changelog.rst) accordingly (**required**).
24 | - [ ] My change requires a change to the documentation.
25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
26 | - [ ] I have updated the documentation accordingly.
27 | - [ ] I have ensured `pytest` and `pytype` both pass.
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *.pkl
4 | *.py~
5 | *.bak
6 | .pytest_cache
7 | .pytype
8 | .DS_Store
9 | .idea
10 | .coverage
11 | .coverage.*
12 | __pycache__/
13 | _build/
14 | *.npz
15 | *.zip
16 | *.lprof
17 |
18 | # Setuptools distribution and build folders.
19 | /dist/
20 | /build
21 | keys/
22 |
23 | # Virtualenv
24 | /env
25 | /venv
26 |
27 | *.sublime-project
28 | *.sublime-workspace
29 |
30 | logs/
31 |
32 | .ipynb_checkpoints
33 | ghostdriver.log
34 |
35 | htmlcov
36 |
37 | junk
38 | src
39 |
40 | *.egg-info
41 | .cache
42 |
43 | MUJOCO_LOG.TXT
44 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing to Stable-Baselines
2 |
3 | If you are interested in contributing to Stable-Baselines, your contributions will fall
4 | into two categories:
5 | 1. You want to propose a new Feature and implement it
6 | - Create an issue about your intended feature, and we shall discuss the design and
7 | implementation. Once we agree that the plan looks good, go ahead and implement it.
8 | 2. You want to implement a feature or bug-fix for an outstanding issue
9 | - Look at the outstanding issues here: https://github.com/hill-a/stable-baselines/issues
10 | - Look at the roadmap here: https://github.com/hill-a/stable-baselines/projects/1
11 | - Pick an issue or feature and comment on the task to let us know that you want to work on it.
12 | - If you need more context on a particular issue, please ask and we shall provide.
13 |
14 | Once you finish implementing a feature or bug-fix, please send a Pull Request to
15 | https://github.com/hill-a/stable-baselines/
16 |
17 |
18 | If you are not familiar with creating a Pull Request, here are some guides:
19 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
20 | - https://help.github.com/articles/creating-a-pull-request/
21 |
22 |
23 | ## Developing Stable-Baselines
24 |
25 | To develop Stable-Baselines on your machine, here are some tips:
26 |
27 | 1. Clone a copy of Stable-Baselines from source:
28 |
29 | ```bash
30 | git clone https://github.com/hill-a/stable-baselines/
31 | cd stable-baselines
32 | ```
33 |
34 | 2. Install Stable-Baselines in develop mode, with support for building the docs and running tests:
35 |
36 | ```bash
37 | pip install -e .[docs,tests]
38 | ```
39 |
40 | ## Codestyle
41 |
42 | We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows:
43 |
44 | 1. built-in
45 | 2. packages
46 | 3. current module
47 |
48 | with one blank line between each group, which gives for instance:
49 | ```python
50 | import os
51 | import warnings
52 |
53 | import numpy as np
54 |
55 | from stable_baselines import PPO2
56 | ```
57 |
58 | In general, we recommend using PyCharm to format everything efficiently.
59 |
60 | Please document each function/method using the following template:
61 |
62 | ```python
63 |
64 | def my_function(arg1, arg2):
65 | """
66 | Short description of the function.
67 |
68 | :param arg1: (arg1 type) describe what is arg1
69 | :param arg2: (arg2 type) describe what is arg2
70 | :return: (return type) describe what is returned
71 | """
72 | ...
73 | return my_variable
74 | ```
75 |
76 | ## Pull Request (PR)
77 |
78 | Before proposing a PR, please open an issue where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process.
79 |
80 | Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin or @erniejunior).
81 | A PR must pass the Continuous Integration tests (Travis + Codacy) to be merged with the master branch.
82 |
83 | Note: in rare cases, we can make an exception for a Codacy failure.
84 |
85 | ## Test
86 |
87 | All new features must include tests in the `tests/` folder to ensure that everything works as expected.
88 | We use [pytest](https://pytest.org/).
89 | Also, when a bug fix is proposed, tests should be added to avoid regression.
90 |
91 | To run tests with `pytest` and type checking with `pytype`:
92 |
93 | ```
94 | ./scripts/run_tests.sh
95 | ```
96 |
97 | ## Changelog and Documentation
98 |
99 | Please do not forget to update the changelog and add documentation if needed.
100 | A README in the `docs/` folder gives instructions on how to build the documentation.
101 |
102 |
103 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
104 |
--------------------------------------------------------------------------------
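As a companion to the Test section of the guide above, here is a hypothetical `tests/test_example.py` sketch; the algorithms, policy name `'MlpPolicy'`, environment ids and timestep budget are illustrative assumptions rather than conventions taken from this repository:

```python
import pytest

from stable_baselines import PPO, TD3


@pytest.mark.parametrize("model_class,env_id", [(PPO, 'CartPole-v1'), (TD3, 'Pendulum-v0')])
def test_short_training_run(model_class, env_id):
    """
    Smoke test: train each algorithm for a handful of timesteps
    and make sure learn() completes without raising.

    :param model_class: (BaseRLModel) the algorithm class under test
    :param env_id: (str) a registered gym environment id
    """
    model = model_class('MlpPolicy', env_id)
    model.learn(total_timesteps=300)
```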
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2017 OpenAI (http://openai.com)
4 | Copyright (c) 2018-2020 Stable-Baselines Team
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Stable Baselines TF2 [Experimental]
4 |
5 | [Stable-Baselines TF1 version](https://github.com/hill-a/stable-baselines/).
6 |
7 |
8 | ## Citing the Project
9 |
10 | To cite this repository in publications:
11 |
12 | ```
13 | @misc{stable-baselines,
14 | author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai},
15 | title = {Stable Baselines},
16 | year = {2018},
17 | publisher = {GitHub},
18 | journal = {GitHub repository},
19 | howpublished = {\url{https://github.com/hill-a/stable-baselines}},
20 | }
21 | ```
22 |
23 | ## Maintainers
24 |
25 | Stable-Baselines is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/erniejunior) (aka @erniejunior), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli).
26 |
27 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
28 |
29 |
30 | ## How To Contribute
31 |
32 | For anyone interested in making the baselines better, there is still some documentation that needs to be done.
33 | If you want to contribute, please read the **CONTRIBUTING.md** guide first.
34 |
35 |
36 | ## Acknowledgments
37 |
38 | Stable Baselines was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).
39 |
40 | Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
41 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | """Configures pytest to ignore certain unit tests unless the appropriate flag is passed.
2 |
3 | --rungpu: tests that require GPU.
4 | --expensive: tests that take a long time to run (e.g. training an RL algorithm for many timesteps)."""
5 |
6 | import pytest
7 |
8 |
9 | def pytest_addoption(parser):
10 | parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests")
11 | parser.addoption("--expensive", action="store_true",
12 | help="run expensive tests (which are otherwise skipped).")
13 |
14 |
15 | def pytest_collection_modifyitems(config, items):
16 | flags = {'gpu': '--rungpu', 'expensive': '--expensive'}
17 | skips = {keyword: pytest.mark.skip(reason="need {} option to run".format(flag))
18 | for keyword, flag in flags.items() if not config.getoption(flag)}
19 | for item in items:
20 | for keyword, skip in skips.items():
21 | if keyword in item.keywords:
22 | item.add_marker(skip)
23 |
--------------------------------------------------------------------------------
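A hypothetical sketch of how tests could opt into the markers that `conftest.py` above skips by default (registering the markers, e.g. in `setup.cfg`, is assumed to be handled separately):

```python
import pytest


@pytest.mark.expensive
def test_long_training_run():
    # Collected but skipped unless pytest is invoked with --expensive.
    pass


@pytest.mark.gpu
def test_gpu_specific_behaviour():
    # Collected but skipped unless pytest is invoked with --rungpu.
    pass

# Example invocations from the repository root:
#   pytest                        # both tests above are skipped
#   pytest --expensive --rungpu   # both tests above are run
```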
/docs/_static/css/baselines_theme.css:
--------------------------------------------------------------------------------
1 | /* Main colors from https://color.adobe.com/fr/Copy-of-NOUEBO-Original-color-theme-11116609 */
2 | :root{
3 | --main-bg-color: #324D5C;
4 | --link-color: #14B278;
5 | }
6 |
7 | /* Header fonts */
8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
10 | }
11 |
12 |
13 | /* Docs background */
14 | .wy-side-nav-search{
15 | background-color: var(--main-bg-color);
16 | }
17 |
18 | /* Mobile version */
19 | .wy-nav-top{
20 | background-color: var(--main-bg-color);
21 | }
22 |
23 | /* Change link colors (except for the menu) */
24 | a {
25 | color: var(--link-color);
26 | }
27 |
28 | a:hover {
29 | color: #4F778F;
30 | }
31 |
32 | .wy-menu a {
33 | color: #b3b3b3;
34 | }
35 |
36 | .wy-menu a:hover {
37 | color: #b3b3b3;
38 | }
39 |
40 | a.icon.icon-home {
41 | color: #b3b3b3;
42 | }
43 |
44 | .version{
45 | color: var(--link-color) !important;
46 | }
47 |
48 |
49 | /* Make code blocks have a background */
50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
51 | background: #f8f8f8;
52 | }
53 |
--------------------------------------------------------------------------------
/docs/_static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/docs/_static/img/logo.png
--------------------------------------------------------------------------------
/docs/misc/changelog.rst:
--------------------------------------------------------------------------------
1 | .. _changelog:
2 |
3 | Changelog
4 | ==========
5 |
6 | For download links, please look at the `GitHub release page <https://github.com/hill-a/stable-baselines/releases>`_.
7 |
8 | Pre-Release 3.0.0a0 (WIP)
9 | --------------------------
10 |
11 | **TensorFlow 2 Version**
12 |
13 | Breaking Changes:
14 | ^^^^^^^^^^^^^^^^^
15 | - Drop support for TensorFlow 1.x; TensorFlow >=2.1.0 is now required
16 | - New dependency: tensorflow-probability>=0.8.0 is now required
17 | - Drop support for pretrain, in favor of https://github.com/HumanCompatibleAI/imitation
18 | - Drop support of GAIL, in favor of https://github.com/HumanCompatibleAI/imitation
19 | - Drop support of MPI
20 |
21 | New Features:
22 | ^^^^^^^^^^^^^
23 |
24 | Bug Fixes:
25 | ^^^^^^^^^^
26 |
27 | Deprecations:
28 | ^^^^^^^^^^^^^
29 |
30 | Others:
31 | ^^^^^^^
32 |
33 | Documentation:
34 | ^^^^^^^^^^^^^^
35 |
36 |
37 | Pre-Release 2.10.0a0 (WIP)
38 | --------------------------
39 |
40 | Breaking Changes:
41 | ^^^^^^^^^^^^^^^^^
42 |
43 | New Features:
44 | ^^^^^^^^^^^^^
45 | - Parallelized updating and sampling from the replay buffer in DQN. (@flodorner)
46 | - Docker build script, `scripts/build_docker.sh`, can push images automatically.
47 | - Added callback collection
48 |
49 | Bug Fixes:
50 | ^^^^^^^^^^
51 |
52 | - Fixed Docker build script, `scripts/build_docker.sh`, to pass `USE_GPU` build argument.
53 |
54 | Deprecations:
55 | ^^^^^^^^^^^^^
56 |
57 | Others:
58 | ^^^^^^^
59 | - Removed redundant return value from `a2c.utils::total_episode_reward_logger`. (@shwang)
60 |
61 | Documentation:
62 | ^^^^^^^^^^^^^^
63 | - Add dedicated page for callbacks
64 |
65 |
66 | Release 2.9.0 (2019-12-20)
67 | --------------------------
68 |
69 | *Reproducible results, automatic `VecEnv` wrapping, env checker and more usability improvements*
70 |
71 | Breaking Changes:
72 | ^^^^^^^^^^^^^^^^^
73 | - The `seed` argument has been moved from `learn()` method to model constructor
74 | in order to have reproducible results
75 | - `allow_early_resets` of the `Monitor` wrapper now defaults to `True`
76 | - `make_atari_env` now returns a `DummyVecEnv` by default (instead of a `SubprocVecEnv`);
77 | this usually improves performance.
78 | - Fix inconsistency of sample type, so that mode/sample function returns tensor of tf.int64 in CategoricalProbabilityDistribution/MultiCategoricalProbabilityDistribution (@seheevic)
79 |
80 | New Features:
81 | ^^^^^^^^^^^^^
82 | - Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
83 | - Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor
84 | - Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation
85 | - Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation
86 | - `VecNormalize` changes:
87 |
88 | - Now supports being pickled and unpickled (@AdamGleave).
89 | - New methods `.normalize_obs(obs)` and `normalize_reward(rews)` apply normalization
90 | to arbitrary observation or rewards without updating statistics (@shwang)
91 | - `.get_original_reward()` returns the unnormalized rewards from the most recent timestep
92 | - `.reset()` now collects observation statistics (used to only apply normalization)
93 |
94 | - Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)
95 | - Add type checking and PEP 561 compliance.
96 | Note: most functions are still not annotated, this will be a gradual process.
97 | - DDPG, TD3 and SAC accept non-symmetric action spaces. (@Antymon)
98 | - Add `check_env` util to check if a custom environment follows the gym interface (@araffin and @justinkterry)
99 |
100 | Bug Fixes:
101 | ^^^^^^^^^^
102 | - Fix seeding, so it is now possible to have deterministic results on cpu
103 | - Fix a bug in DDPG where `predict` method with `deterministic=False` would fail
104 | - Fix a bug in TRPO: mean_losses was not initialized, causing the logger to crash when there were no gradients (@MarvineGothic)
105 | - Fix a bug in `cmd_util` from API change in recent Gym versions
106 | - Fix a bug in DDPG, TD3 and SAC where warmup and random exploration actions would end up scaled in the replay buffer (@Antymon)
107 |
108 | Deprecations:
109 | ^^^^^^^^^^^^^
110 | - `nprocs` (ACKTR) and `num_procs` (ACER) are deprecated in favor of `n_cpu_tf_sess` which is now common
111 | to all algorithms
112 | - `VecNormalize`: `load_running_average` and `save_running_average` are deprecated in favour of using pickle.
113 |
114 | Others:
115 | ^^^^^^^
116 | - Add upper bound for Tensorflow version (<2.0.0).
117 | - Refactored test to remove duplicated code
118 | - Add pull request template
119 | - Replaced redundant code in load_results (@jbulow)
120 | - Minor PEP8 fixes in dqn.py (@justinkterry)
121 | - Add a message to the assert in `PPO2`
122 | - Update replay buffer docstring
123 | - Fix `VecEnv` docstrings
124 |
125 | Documentation:
126 | ^^^^^^^^^^^^^^
127 | - Add plotting to the Monitor example (@rusu24edward)
128 | - Add Snake Game AI project (@pedrohbtp)
129 | - Add note on the supported TensorFlow versions.
130 | - Remove unnecessary steps required for Windows installation.
131 | - Remove `DummyVecEnv` creation when not needed
132 | - Added `make_vec_env` to the examples to simplify VecEnv creation
133 | - Add QuaRL project (@srivatsankrishnan)
134 | - Add Pwnagotchi project (@evilsocket)
135 | - Fix multiprocessing example (@rusu24edward)
136 | - Fix `result_plotter` example
137 | - Add JNRR19 tutorial (by @edbeeching, @hill-a and @araffin)
138 | - Updated notebooks link
139 | - Fix typo in algos.rst, "containes" to "contains" (@SyllogismRXS)
140 | - Fix outdated source documentation for load_results
141 | - Add PPO_CPP project (@Antymon)
142 | - Add section on C++ portability of Tensorflow models (@Antymon)
143 | - Update custom env documentation to reflect new gym API for the `close()` method (@justinkterry)
144 | - Update custom env documentation to clarify what step and reset return (@justinkterry)
145 | - Add RL tips and tricks for doing RL experiments
146 | - Corrected lots of typos
147 | - Add spell check to documentation if available
148 |
149 |
150 | Release 2.8.0 (2019-09-29)
151 | --------------------------
152 |
153 | **MPI dependency optional, new save format, ACKTR with continuous actions**
154 |
155 | Breaking Changes:
156 | ^^^^^^^^^^^^^^^^^
157 | - OpenMPI-dependent algorithms (PPO1, TRPO, GAIL, DDPG) are disabled in the
158 | default installation of stable_baselines. `mpi4py` is now installed as an
159 | extra. When `mpi4py` is not available, stable-baselines skips imports of
160 | OpenMPI-dependent algorithms.
161 | See :ref:`installation notes ` and
162 | `Issue #430 <https://github.com/hill-a/stable-baselines/issues/430>`_.
163 | - SubprocVecEnv now defaults to a thread-safe start method, `forkserver` when
164 | available and otherwise `spawn`. This may require application code be
165 | wrapped in `if __name__ == '__main__'`. You can restore previous behavior
166 | by explicitly setting `start_method = 'fork'`. See
167 | `PR #428 <https://github.com/hill-a/stable-baselines/pull/428>`_.
168 | - Updated dependencies: tensorflow v1.8.0 is now required
169 | - Removed `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used
170 | - Removed `bench/benchmark.py` that was not used
171 | - Removed several functions from `common/tf_util.py` that were not used
172 | - Removed `ppo1/run_humanoid.py`
173 |
174 | New Features:
175 | ^^^^^^^^^^^^^
176 | - **important change** Switch to using zip-archived JSON and Numpy `savez` for
177 | storing models for better support across library/Python versions. (@Miffyli)
178 | - ACKTR now supports continuous actions
179 | - Add `double_q` argument to `DQN` constructor
180 |
181 | Bug Fixes:
182 | ^^^^^^^^^^
183 | - Skip automatic imports of OpenMPI-dependent algorithms to avoid an issue
184 | where OpenMPI would cause stable-baselines to hang on Ubuntu installs.
185 | See :ref:`installation notes ` and
186 | `Issue #430 <https://github.com/hill-a/stable-baselines/issues/430>`_.
187 | - Fix a bug when calling `logger.configure()` with MPI enabled (@keshaviyengar)
188 | - set `allow_pickle=True` for numpy>=1.17.0 when loading expert dataset
189 | - Fix a bug when using VecCheckNan with numpy ndarray as state. `Issue #489 <https://github.com/hill-a/stable-baselines/issues/489>`_. (@ruifeng96150)
190 |
191 | Deprecations:
192 | ^^^^^^^^^^^^^
193 | - Models saved with cloudpickle format (stable-baselines<=2.7.0) are now
194 | deprecated in favor of zip-archive format for better support across
195 | Python/Tensorflow versions. (@Miffyli)
196 |
197 | Others:
198 | ^^^^^^^
199 | - Implementations of noise classes (`AdaptiveParamNoiseSpec`, `NormalActionNoise`,
200 | `OrnsteinUhlenbeckActionNoise`) were moved from `stable_baselines.ddpg.noise`
201 | to `stable_baselines.common.noise`. The API remains backward-compatible;
202 | for example `from stable_baselines.ddpg.noise import NormalActionNoise` is still
203 | okay. (@shwang)
204 | - Docker images were updated
205 | - Cleaned up files in `common/` folder and in `acktr/` folder that were only used by old ACKTR version
206 | (e.g. `filter.py`)
207 | - Renamed `acktr_disc.py` to `acktr.py`
208 |
209 | Documentation:
210 | ^^^^^^^^^^^^^^
211 | - Add WaveRL project (@jaberkow)
212 | - Add Fenics-DRL project (@DonsetPG)
213 | - Fix and rename custom policy names (@eavelardev)
214 | - Add documentation on exporting models.
215 | - Update maintainers list (Welcome to @Miffyli)
216 |
217 |
218 | Release 2.7.0 (2019-07-31)
219 | --------------------------
220 |
221 | **Twin Delayed DDPG (TD3) and GAE bug fix (TRPO, PPO1, GAIL)**
222 |
223 | Breaking Changes:
224 | ^^^^^^^^^^^^^^^^^
225 |
226 | New Features:
227 | ^^^^^^^^^^^^^
228 | - added Twin Delayed DDPG (TD3) algorithm, with HER support
229 | - added support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian
230 | policy in addition to the existing support for categorical stochastic policies.
231 | - added flag to `action_probability` to return log-probabilities.
232 | - added support for python lists and numpy arrays in ``logger.writekvs``. (@dwiel)
233 | - the info dict returned by VecEnvs now include a ``terminal_observation`` key providing access to the last observation in a trajectory. (@qxcv)
234 |
235 | Bug Fixes:
236 | ^^^^^^^^^^
237 | - fixed a bug in ``traj_segment_generator`` where the ``episode_starts`` was wrongly recorded,
238 | resulting in wrong calculation of Generalized Advantage Estimation (GAE), this affects TRPO, PPO1 and GAIL (thanks to @miguelrass for spotting the bug)
239 | - added missing property `n_batch` in `BasePolicy`.
240 |
241 | Deprecations:
242 | ^^^^^^^^^^^^^
243 |
244 | Others:
245 | ^^^^^^^
246 | - renamed some keys in ``traj_segment_generator`` to be more meaningful
247 | - retrieve unnormalized reward when using Monitor wrapper with TRPO, PPO1 and GAIL
248 | to display them in the logs (mean episode reward)
249 | - clean up DDPG code (renamed variables)
250 |
251 | Documentation:
252 | ^^^^^^^^^^^^^^
253 |
254 | - doc fix for the hyperparameter tuning command in the rl zoo
255 | - added an example on how to log additional variable with tensorboard and a callback
256 |
257 |
258 |
259 | Release 2.6.0 (2019-06-12)
260 | --------------------------
261 |
262 | **Hindsight Experience Replay (HER) - Reloaded | get/load parameters**
263 |
264 | Breaking Changes:
265 | ^^^^^^^^^^^^^^^^^
266 |
267 | - **breaking change** removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` (see fix below)
268 |
269 | **Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
270 | when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
271 | You can fix that using:
272 |
273 | .. code-block:: python
274 |
275 | import sys
276 | import pkg_resources
277 |
278 | import stable_baselines
279 |
280 | # Fix for breaking change for DDPG buffer in v2.6.0
281 | if pkg_resources.get_distribution("stable_baselines").version >= "2.6.0":
282 | sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.deepq.replay_buffer
283 | stable_baselines.deepq.replay_buffer.Memory = stable_baselines.deepq.replay_buffer.ReplayBuffer
284 |
285 |
286 | We recommend you save the model again afterward, so that the fix won't be needed the next time the trained agent is loaded.
287 |
288 |
289 | New Features:
290 | ^^^^^^^^^^^^^
291 |
292 | - **revamped HER implementation**: clean re-implementation from scratch, now supports DQN, SAC and DDPG
293 | - add ``action_noise`` param for SAC; it helps exploration for problems with deceptive rewards
294 | - The parameter ``filter_size`` of the function ``conv`` in A2C utils now supports passing a list/tuple of two integers (height and width), in order to have non-squared kernel matrix. (@yutingsz)
295 | - add ``random_exploration`` parameter for DDPG and SAC, it may be useful when using HER + DDPG/SAC. This hack was present in the original OpenAI Baselines DDPG + HER implementation.
296 | - added ``load_parameters`` and ``get_parameters`` to base RL class. With these methods, users are able to load and get parameters to/from existing model, without touching tensorflow. (@Miffyli)
297 | - added specific hyperparameter for PPO2 to clip the value function (``cliprange_vf``)
298 | - added ``VecCheckNan`` wrapper
299 |
300 | Bug Fixes:
301 | ^^^^^^^^^^
302 |
303 | - bugfix for ``VecEnvWrapper.__getattr__`` which enables access to class attributes inherited from parent classes.
304 | - fixed path splitting in ``TensorboardWriter._get_latest_run_id()`` on Windows machines (@PatrickWalter214)
305 | - fixed a bug where initial learning rate is logged instead of its placeholder in ``A2C.setup_model`` (@sc420)
306 | - fixed a bug where number of timesteps is incorrectly updated and logged in ``A2C.learn`` and ``A2C._train_step`` (@sc420)
307 | - fixed ``num_timesteps`` (total_timesteps) variable in PPO2 that was wrongly computed.
308 | - fixed a bug in DDPG/DQN/SAC when the number of samples in the replay buffer was less than the batch size
309 | (thanks to @dwiel for spotting the bug)
310 | - **removed** ``a2c.utils.find_trainable_params`` please use ``common.tf_util.get_trainable_vars`` instead.
311 | ``find_trainable_params`` was returning all trainable variables, discarding the scope argument.
312 | This bug was causing the model to save duplicated parameters (for DDPG and SAC)
313 | but did not affect the performance.
314 |
315 | Deprecations:
316 | ^^^^^^^^^^^^^
317 |
318 | - **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x)
319 |
320 | Others:
321 | ^^^^^^^
322 |
323 | - **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli)
324 | - removed unused dependencies (tdqm, dill, progressbar2, seaborn, glob2, click)
325 | - removed ``get_available_gpus`` function which hadn't been used anywhere (@Pastafarianist)
326 |
327 | Documentation:
328 | ^^^^^^^^^^^^^^
329 |
330 | - added guide for managing ``NaN`` and ``inf``
331 | - updated vec_env doc
332 | - misc doc updates
333 |
334 | Release 2.5.1 (2019-05-04)
335 | --------------------------
336 |
337 | **Bug fixes + improvements in the VecEnv**
338 |
339 | **Warning: breaking changes when using custom policies**
340 |
341 | - doc update (fix example of result plotter + improve doc)
342 | - fixed logger issues when stdout lacks ``read`` function
343 | - fixed a bug in ``common.dataset.Dataset`` where shuffling was not disabled properly (it affects only PPO1 with recurrent policies)
344 | - fixed output layer name for DDPG q function, used in pop-art normalization and l2 regularization of the critic
345 | - added support for multi env recording to ``generate_expert_traj`` (@XMaster96)
346 | - added support for LSTM model recording to ``generate_expert_traj`` (@XMaster96)
347 | - ``GAIL``: remove mandatory matplotlib dependency and refactor as subclass of ``TRPO`` (@kantneel and @AdamGleave)
348 | - added ``get_attr()``, ``env_method()`` and ``set_attr()`` methods for all VecEnv.
349 | Those methods now all accept ``indices`` keyword to select a subset of envs.
350 | ``set_attr`` now returns ``None`` rather than a list of ``None``. (@kantneel)
351 | - ``GAIL``: ``gail.dataset.ExpertDataset`` supports loading from memory rather than file, and
352 | ``gail.dataset.record_expert`` supports returning in-memory rather than saving to file.
353 | - added support in ``VecEnvWrapper`` for accessing attributes of arbitrarily deeply nested
354 | instances of ``VecEnvWrapper`` and ``VecEnv``. This is allowed as long as the attribute belongs
355 | to exactly one of the nested instances i.e. it must be unambiguous. (@kantneel)
356 | - fixed bug where result plotter would crash on very short runs (@Pastafarianist)
357 | - added option to not trim output of result plotter by number of timesteps (@Pastafarianist)
358 | - clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``,
359 | and most placeholders were made private: e.g. ``self.value_fn`` is now ``self._value_fn``
360 | - support for custom stateful policies.
361 | - fixed episode length recording in ``trpo_mpi.utils.traj_segment_generator`` (@GerardMaggiolino)
362 |
363 |
364 | Release 2.5.0 (2019-03-28)
365 | --------------------------
366 |
367 | **Working GAIL, pretrain RL models and hotfix for A2C with continuous actions**
368 |
369 | - fixed various bugs in GAIL
370 | - added scripts to generate dataset for gail
371 | - added tests for GAIL + data for Pendulum-v0
372 | - removed unused ``utils`` file in DQN folder
373 | - fixed a bug in A2C where actions were cast to ``int32`` even in the continuous case
374 | - added additional logging to A2C when Monitor wrapper is used
375 | - changed logging for PPO2: do not display NaN when reward info is not present
376 | - changed default value of A2C lr schedule
377 | - removed behavior cloning script
378 | - added ``pretrain`` method to base class, in order to use behavior cloning on all models
379 | - fixed ``close()`` method for DummyVecEnv.
380 | - added support for Dict spaces in DummyVecEnv and SubprocVecEnv. (@AdamGleave)
381 | - added support for arbitrary multiprocessing start methods and added a warning about SubprocVecEnv that are not thread-safe by default. (@AdamGleave)
382 | - added support for Discrete actions for GAIL
383 | - fixed deprecation warning for tf: replaces ``tf.to_float()`` by ``tf.cast()``
384 | - fixed bug in saving and loading ddpg model when using normalization of obs or returns (@tperol)
385 | - changed DDPG default buffer size from 100 to 50000.
386 | - fixed a bug in ``ddpg.py`` in ``combined_stats`` for eval. Computed mean on ``eval_episode_rewards`` and ``eval_qs`` (@keshaviyengar)
387 | - fixed a bug in ``setup.py`` that would error on non-GPU systems without TensorFlow installed
388 |
389 |
390 | Release 2.4.1 (2019-02-11)
391 | --------------------------
392 |
393 | **Bug fixes and improvements**
394 |
395 | - fixed computation of training metrics in TRPO and PPO1
396 | - added ``reset_num_timesteps`` keyword when calling train() to continue tensorboard learning curves
397 | - reduced the size taken by tensorboard logs (added a ``full_tensorboard_log`` to enable full logging, which was the previous behavior)
398 | - fixed image detection for tensorboard logging
399 | - fixed ACKTR for recurrent policies
400 | - fixed gym breaking changes
401 | - fixed custom policy examples in the doc for DQN and DDPG
402 | - remove gym spaces patch for equality functions
403 | - fixed tensorflow dependency: cpu version was installed overwriting tensorflow-gpu when present.
404 | - fixed a bug in ``traj_segment_generator`` (used in ppo1 and trpo) where ``new`` was not updated. (spotted by @junhyeokahn)
405 |
406 |
407 | Release 2.4.0 (2019-01-17)
408 | --------------------------
409 |
410 | **Soft Actor-Critic (SAC) and policy kwargs**
411 |
412 | - added Soft Actor-Critic (SAC) model
413 | - fixed a bug in DQN where prioritized_replay_beta_iters param was not used
414 | - fixed DDPG that did not save target network parameters
415 | - fixed bug related to shape of true_reward (@abhiskk)
416 | - fixed example code in documentation of tf_util:Function (@JohannesAck)
417 | - added learning rate schedule for SAC
418 | - fixed action probability for continuous actions with actor-critic models
419 | - added optional parameter to action_probability for likelihood calculation of given action being taken.
420 | - added more flexible custom LSTM policies
421 | - added auto entropy coefficient optimization for SAC
422 | - clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed)
423 | - added a means to pass kwargs to the policy when creating a model (+ save those kwargs)
424 | - fixed DQN examples in DQN folder
425 | - added possibility to pass activation function for DDPG, DQN and SAC
426 |
427 |
428 | Release 2.3.0 (2018-12-05)
429 | --------------------------
430 |
431 | - added support for storing model in file like object. (thanks to @erniejunior)
432 | - fixed wrong image detection when using tensorboard logging with DQN
433 | - fixed bug in ppo2 when passing non callable lr after loading
434 | - fixed tensorboard logging in ppo2 when nminibatches=1
435 | - added early stopping via callback return value (@erniejunior)
436 | - added more flexible custom mlp policies (@erniejunior)
437 |
438 |
439 | Release 2.2.1 (2018-11-18)
440 | --------------------------
441 |
442 | - added VecVideoRecorder to record mp4 videos from environment.
443 |
444 |
445 | Release 2.2.0 (2018-11-07)
446 | --------------------------
447 |
448 | - Hotfix for ppo2, the wrong placeholder was used for the value function
449 |
450 |
451 | Release 2.1.2 (2018-11-06)
452 | --------------------------
453 |
454 | - added ``async_eigen_decomp`` parameter for ACKTR and set it to ``False`` by default (remove deprecation warnings)
455 | - added methods for calling env methods/setting attributes inside a VecEnv (thanks to @bjmuld)
456 | - updated gym minimum version
457 |
458 |
459 | Release 2.1.1 (2018-10-20)
460 | --------------------------
461 |
462 | - fixed MpiAdam synchronization issue in PPO1 (thanks to @brendenpetersen) issue #50
463 | - fixed dependency issues (new mujoco-py requires a mujoco license + gym broke MultiDiscrete space shape)
464 |
465 |
466 | Release 2.1.0 (2018-10-2)
467 | -------------------------
468 |
469 | .. warning::
470 |
471 | This version contains breaking changes for DQN policies, please read the full details
472 |
473 | **Bug fixes + doc update**
474 |
475 |
476 | - added patch fix for equal function using `gym.spaces.MultiDiscrete` and `gym.spaces.MultiBinary`
477 | - fixes for DQN action_probability
478 | - re-added double DQN + refactored DQN policies **breaking changes**
479 | - replaced `async` with `async_eigen_decomp` in ACKTR/KFAC for python 3.7 compatibility
480 | - removed action clipping for prediction of continuous actions (see issue #36)
481 | - fixed NaN issue due to clipping the continuous action in the wrong place (issue #36)
482 | - documentation was updated (policy + DDPG example hyperparameters)
483 |
484 | Release 2.0.0 (2018-09-18)
485 | --------------------------
486 |
487 | .. warning::
488 |
489 | This version contains breaking changes, please read the full details
490 |
491 | **Tensorboard, refactoring and bug fixes**
492 |
493 |
494 | - Renamed DeepQ to DQN **breaking changes**
495 | - Renamed DeepQPolicy to DQNPolicy **breaking changes**
496 | - fixed DDPG behavior **breaking changes**
497 | - changed default policies for DDPG, so that DDPG now works correctly **breaking changes**
498 | - added more documentation (some modules from common).
499 | - added doc about using custom env
500 | - added Tensorboard support for A2C, ACER, ACKTR, DDPG, DeepQ, PPO1, PPO2 and TRPO
501 | - added episode reward to Tensorboard
502 | - added documentation for Tensorboard usage
503 | - added Identity for Box action space
504 | - fixed render function ignoring parameters when using wrapped environments
505 | - fixed PPO1 and TRPO done values for recurrent policies
506 | - fixed image normalization not occurring when using images
507 | - updated VecEnv objects for the new Gym version
508 | - added test for DDPG
509 | - refactored DQN policies
510 | - added registry for policies, can be passed as string to the agent
511 | - added documentation for custom policies + policy registration
512 | - fixed numpy warning when using DDPG Memory
513 | - fixed DummyVecEnv not copying the observation array when stepping and resetting
514 | - added pre-built docker images + installation instructions
515 | - added ``deterministic`` argument in the predict function
516 | - added assert in PPO2 for recurrent policies
517 | - fixed predict function to handle both vectorized and unwrapped environment
518 | - added input check to the predict function
519 | - refactored ActorCritic models to reduce code duplication
520 | - refactored Off Policy models (to begin HER and replay_buffer refactoring)
521 | - added tests for auto vectorization detection
522 | - fixed render function, to handle positional arguments
523 |
524 |
525 | Release 1.0.7 (2018-08-29)
526 | --------------------------
527 |
528 | **Bug fixes and documentation**
529 |
530 | - added html documentation using sphinx + integration with read the docs
531 | - cleaned up README + typos
532 | - fixed normalization for DQN with images
533 | - fixed DQN identity test
534 |
535 |
536 | Release 1.0.1 (2018-08-20)
537 | --------------------------
538 |
539 | **Refactored Stable Baselines**
540 |
541 | - refactored A2C, ACER, ACKTR, DDPG, DeepQ, GAIL, TRPO, PPO1 and PPO2 under a single constant class
542 | - added callback to refactored algorithm training
543 | - added saving and loading to refactored algorithms
544 | - refactored ACER, DDPG, GAIL, PPO1 and TRPO to fit with A2C, PPO2 and ACKTR policies
545 | - added new policies for most algorithms (Mlp, MlpLstm, MlpLnLstm, Cnn, CnnLstm and CnnLnLstm)
546 | - added dynamic environment switching (so continual RL learning is now feasible)
547 | - added prediction from observation and action probability from observation for all the algorithms
548 | - fixed graph issues, so models won't collide in names
549 | - fixed behavior_clone weight loading for GAIL
550 | - fixed Tensorflow using all the GPU VRAM
551 | - fixed models so that they are all compatible with vectorized environments
552 | - fixed ```set_global_seed``` to update ```gym.spaces```'s random seed
553 | - fixed PPO1 and TRPO performance issues when learning identity function
554 | - added new tests for loading, saving, continuous actions and learning the identity function
555 | - fixed DQN wrapping for atari
556 | - added saving and loading for Vecnormalize wrapper
557 | - added automatic detection of action space (for the policy network)
558 | - fixed ACER buffer with constant values assuming n_stack=4
559 | - fixed some RL algorithms not clipping the action to be in the action_space, when using ```gym.spaces.Box```
560 | - refactored algorithms can take either a ```gym.Environment``` or a ```str``` ([if the environment name is registered](https://github.com/openai/gym/wiki/Environments))
561 | - Hotfix in ACER (compared to v1.0.0)
562 |
563 | Future Work :
564 |
565 | - Finish refactoring HER
566 | - Refactor ACKTR and ACER for continuous implementation
567 |
568 |
569 |
570 | Release 0.1.6 (2018-07-27)
571 | --------------------------
572 |
573 | **Deobfuscation of the code base + pep8 and fixes**
574 |
575 | - Fixed ``tf.session().__enter__()`` being used, rather than
576 | ``sess = tf.session()`` and passing the session to the objects
577 | - Fixed uneven scoping of TensorFlow Sessions throughout the code
578 | - Fixed rolling vecwrapper to handle observations that are not only
579 | grayscale images
580 | - Fixed deepq saving the environment when trying to save itself
581 | - Fixed
582 | ``ValueError: Cannot take the length of Shape with unknown rank.`` in
583 | ``acktr``, when running ``run_atari.py`` script.
584 | - Fixed calling baselines sequentially no longer creates graph
585 | conflicts
586 | - Fixed mean on empty array warning with deepq
587 | - Fixed kfac eigen decomposition not cast to float64, when the
588 | parameter use_float64 is set to True
589 | - Fixed Dataset data loader, not correctly resetting id position if
590 | shuffling is disabled
591 | - Fixed ``EOFError`` when reading from connection in the ``worker`` in
592 | ``subproc_vec_env.py``
593 | - Fixed ``behavior_clone`` weight loading and saving for GAIL
594 | - Avoid taking root square of negative number in ``trpo_mpi.py``
595 | - Removed some duplicated code (a2cpolicy, trpo_mpi)
596 | - Removed unused, undocumented and crashing function ``reset_task`` in
597 | ``subproc_vec_env.py``
598 | - Reformatted code to PEP8 style
599 | - Documented all the codebase
600 | - Added atari tests
601 | - Added logger tests
602 |
603 | Missing: tests for acktr continuous (+ HER, rely on mujoco...)
604 |
605 | Maintainers
606 | -----------
607 |
608 | Stable-Baselines is currently maintained by `Ashley Hill`_ (aka @hill-a), `Antonin Raffin`_ (aka `@araffin`_),
609 | `Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
610 |
611 | .. _Ashley Hill: https://github.com/hill-a
612 | .. _Antonin Raffin: https://araffin.github.io/
613 | .. _Maximilian Ernestus: https://github.com/erniejunior
614 | .. _Adam Gleave: https://gleave.me/
615 | .. _@araffin: https://github.com/araffin
616 | .. _@AdamGleave: https://github.com/adamgleave
617 | .. _Anssi Kanervisto: https://github.com/Miffyli
618 | .. _@Miffyli: https://github.com/Miffyli
619 |
620 |
621 | Contributors (since v2.0.0):
622 | ----------------------------
623 | In random order...
624 |
625 | Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
626 | @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
627 | @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
628 | @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
629 | @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
630 | @flodorner
631 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | # This includes the license file in the wheel.
3 | license_file = LICENSE
4 |
5 | [tool:pytest]
6 | # Deterministic ordering for tests; useful for pytest-xdist.
7 | env =
8 | PYTHONHASHSEED=0
9 | filterwarnings =
10 |
11 |
12 | [pytype]
13 | inputs = stable_baselines
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from setuptools import setup, find_packages
3 | from distutils.version import LooseVersion
4 |
5 | if sys.version_info.major != 3:
6 | print('This Python is only compatible with Python 3, but you are running '
7 | 'Python {}. The installation will likely fail.'.format(sys.version_info.major))
8 |
9 |
10 | long_description = """
11 |
12 | # Stable Baselines TF2 [Experimental]
13 |
14 | """
15 |
16 | setup(name='stable_baselines',
17 | packages=[package for package in find_packages()
18 | if package.startswith('stable_baselines')],
19 | package_data={
20 | 'stable_baselines': ['py.typed'],
21 | },
22 | install_requires=[
23 | 'gym[atari,classic_control]>=0.10.9',
24 | 'scipy',
25 | 'joblib',
26 | 'cloudpickle>=0.5.5',
27 | 'opencv-python',
28 | 'numpy',
29 | 'pandas',
30 | 'matplotlib',
31 | 'tensorflow-probability>=0.8.0',
32 | 'tensorflow>=2.1.0'
33 | ],
34 | extras_require={
35 | 'tests': [
36 | 'pytest',
37 | 'pytest-cov',
38 | 'pytest-env',
39 | 'pytest-xdist',
40 | 'pytype',
41 | ],
42 | 'docs': [
43 | 'sphinx',
44 | 'sphinx-autobuild',
45 | 'sphinx-rtd-theme'
46 | ]
47 | },
48 | description='A fork of OpenAI Baselines, implementations of reinforcement learning algorithms.',
49 | author='Ashley Hill',
50 | url='https://github.com/hill-a/stable-baselines',
51 | author_email='github@hill-a.me',
52 | keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
53 | "gym openai baselines toolbox python data-science",
54 | license="MIT",
55 | long_description=long_description,
56 | long_description_content_type='text/markdown',
57 | version="3.0.0a0",
58 | )
59 |
60 | # python setup.py sdist
61 | # python setup.py bdist_wheel
62 | # twine upload --repository-url https://test.pypi.org/legacy/ dist/*
63 | # twine upload dist/*
64 |
--------------------------------------------------------------------------------
/stable_baselines/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.td3 import TD3
2 | from stable_baselines.ppo import PPO
3 |
4 | __version__ = "3.0.0a0"
5 |
--------------------------------------------------------------------------------
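A minimal usage sketch for the two algorithms exported above; the policy name `'MlpPolicy'`, the environment ids and the timestep budgets are assumptions for illustration, not values taken from this dump:

```python
from stable_baselines import PPO, TD3

# On-policy algorithm on a discrete-action task; passing a string env id lets
# the base class create and wrap the environment in a DummyVecEnv.
ppo_model = PPO('MlpPolicy', 'CartPole-v1', verbose=1)
ppo_model.learn(total_timesteps=10000)

# Off-policy algorithm on a continuous-action task.
td3_model = TD3('MlpPolicy', 'Pendulum-v0', verbose=1)
td3_model.learn(total_timesteps=5000)
```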
/stable_baselines/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/stable_baselines/common/__init__.py
--------------------------------------------------------------------------------
/stable_baselines/common/base_class.py:
--------------------------------------------------------------------------------
1 | import time
2 | from abc import ABC, abstractmethod
3 | from collections import deque
4 | import os
5 | import io
6 | import zipfile
7 |
8 | import gym
9 | import tensorflow as tf
10 | import numpy as np
11 |
12 | from stable_baselines.common import logger
13 | from stable_baselines.common.policies import get_policy_from_name
14 | from stable_baselines.common.utils import set_random_seed, get_schedule_fn
15 | from stable_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, sync_envs_normalization
16 | from stable_baselines.common.monitor import Monitor
17 | from stable_baselines.common.evaluation import evaluate_policy
18 | from stable_baselines.common.save_util import data_to_json, json_to_data
19 |
20 |
21 | class BaseRLModel(ABC):
22 | """
23 | The base RL model
24 |
25 | :param policy: (BasePolicy) Policy object
26 | :param env: (Gym environment) The environment to learn from
27 | (if registered in Gym, can be str. Can be None for loading trained models)
28 | :param policy_base: (BasePolicy) the base policy used by this method
29 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
30 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
31 | :param support_multi_env: (bool) Whether the algorithm supports training
32 | with multiple environments (as in A2C)
33 | :param create_eval_env: (bool) Whether to create a second environment that will be
34 | used for evaluating the agent periodically. (Only available when passing string for the environment)
35 | :param monitor_wrapper: (bool) When creating an environment, whether to wrap it
36 | or not in a Monitor wrapper.
37 | :param seed: (int) Seed for the pseudo random generators
38 | """
39 | def __init__(self, policy, env, policy_base, policy_kwargs=None,
40 | verbose=0, device='auto', support_multi_env=False,
41 | create_eval_env=False, monitor_wrapper=True, seed=None):
42 | if isinstance(policy, str) and policy_base is not None:
43 | self.policy_class = get_policy_from_name(policy_base, policy)
44 | else:
45 | self.policy_class = policy
46 |
47 | self.env = env
48 | # get VecNormalize object if needed
49 | self._vec_normalize_env = unwrap_vec_normalize(env)
50 | self.verbose = verbose
51 | self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
52 | self.observation_space = None
53 | self.action_space = None
54 | self.n_envs = None
55 | self.num_timesteps = 0
56 | self.eval_env = None
57 | self.replay_buffer = None
58 | self.seed = seed
59 | self.action_noise = None
60 |
61 | # Track the training progress (from 1 to 0)
62 | # this is used to update the learning rate
63 | self._current_progress = 1
64 |
65 | # Create and wrap the env if needed
66 | if env is not None:
67 | if isinstance(env, str):
68 | if create_eval_env:
69 | eval_env = gym.make(env)
70 | if monitor_wrapper:
71 | eval_env = Monitor(eval_env, filename=None)
72 | self.eval_env = DummyVecEnv([lambda: eval_env])
73 | if self.verbose >= 1:
74 | print("Creating environment from the given name, wrapped in a DummyVecEnv.")
75 |
76 | env = gym.make(env)
77 | if monitor_wrapper:
78 | env = Monitor(env, filename=None)
79 | env = DummyVecEnv([lambda: env])
80 |
81 | self.observation_space = env.observation_space
82 | self.action_space = env.action_space
83 | if not isinstance(env, VecEnv):
84 | if self.verbose >= 1:
85 | print("Wrapping the env in a DummyVecEnv.")
86 | env = DummyVecEnv([lambda: env])
87 | self.n_envs = env.num_envs
88 | self.env = env
89 |
90 | if not support_multi_env and self.n_envs > 1:
91 | raise ValueError("Error: the model does not support multiple envs; it requires a single vectorized"
92 | " environment.")
93 |
94 | def _get_eval_env(self, eval_env):
95 | """
96 | Return the environment that will be used for evaluation.
97 |
98 | :param eval_env: (gym.Env or VecEnv)
99 | :return: (VecEnv)
100 | """
101 | if eval_env is None:
102 | eval_env = self.eval_env
103 |
104 | if eval_env is not None:
105 | if not isinstance(eval_env, VecEnv):
106 | eval_env = DummyVecEnv([lambda: eval_env])
107 | assert eval_env.num_envs == 1
108 | return eval_env
109 |
110 | def scale_action(self, action):
111 | """
112 | Rescale the action from [low, high] to [-1, 1]
113 | (no need for symmetric action space)
114 |
115 | :param action: (np.ndarray)
116 | :return: (np.ndarray)
117 | """
118 | low, high = self.action_space.low, self.action_space.high
119 | return 2.0 * ((action - low) / (high - low)) - 1.0
120 |
121 | def unscale_action(self, scaled_action):
122 | """
123 | Rescale the action from [-1, 1] to [low, high]
124 | (no need for symmetric action space)
125 |
126 | :param scaled_action: (np.ndarray)
127 | :return: (np.ndarray)
128 | """
129 | low, high = self.action_space.low, self.action_space.high
130 | return low + (0.5 * (scaled_action + 1.0) * (high - low))
131 |
132 | def _setup_learning_rate(self):
133 | """Transform to callable if needed."""
134 | self.learning_rate = get_schedule_fn(self.learning_rate)
135 |
136 | def _update_current_progress(self, num_timesteps, total_timesteps):
137 | """
138 | Compute current progress (from 1 to 0)
139 |
140 | :param num_timesteps: (int) current number of timesteps
141 | :param total_timesteps: (int)
142 | """
143 | self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps)
144 |
145 | def _update_learning_rate(self, optimizers):
146 | """
147 | Update the optimizers learning rate using the current learning rate schedule
148 | and the current progress (from 1 to 0).
149 |
150 | :param optimizers: ([tf.keras.optimizers.Optimizer] or tf.keras.optimizers.Optimizer) An optimizer
151 | or a list of optimizers.
152 | """
153 | # Log the current learning rate
154 | logger.logkv("learning_rate", self.learning_rate(self._current_progress))
155 |
156 | # if not isinstance(optimizers, list):
157 | # optimizers = [optimizers]
158 | # for optimizer in optimizers:
159 | # update_learning_rate(optimizer, self.learning_rate(self._current_progress))
160 |
161 | @staticmethod
162 | def safe_mean(arr):
163 | """
164 | Compute the mean of an array if there is at least one element.
165 | For empty array, return nan. It is used for logging only.
166 |
167 | :param arr: (np.ndarray)
168 | :return: (float)
169 | """
170 | return np.nan if len(arr) == 0 else np.mean(arr)
171 |
172 | def get_env(self):
173 | """
174 | returns the current environment (can be None if not defined)
175 |
176 | :return: (gym.Env) The current environment
177 | """
178 | return self.env
179 |
180 | def set_env(self, env):
181 | """
182 | :param env: (gym.Env) The environment for learning a policy
183 | """
184 | raise NotImplementedError()
185 |
186 | @abstractmethod
187 | def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
188 | eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True):
189 | """
190 | Return a trained model.
191 |
192 | :param total_timesteps: (int) The total number of samples to train on
193 | :param callback: (function (dict, dict)) -> boolean function called at every step with the state of the algorithm.
194 | It takes the local and global variables. If it returns False, training is aborted.
195 | :param log_interval: (int) The number of timesteps before logging.
196 | :param tb_log_name: (str) the name of the run for tensorboard log
197 | :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
198 | :param eval_env: (gym.Env) Environment that will be used to evaluate the agent
199 | :param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little)
200 |         :param n_eval_episodes: (int) Number of episodes to evaluate the agent
201 | :return: (BaseRLModel) the trained model
202 | """
203 | pass
204 |
205 | @abstractmethod
206 | def predict(self, observation, state=None, mask=None, deterministic=False):
207 | """
208 | Get the model's action from an observation
209 |
210 | :param observation: (np.ndarray) the input observation
211 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
212 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
213 | :param deterministic: (bool) Whether or not to return deterministic actions.
214 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
215 | """
216 | pass
217 |
218 | def set_random_seed(self, seed=None):
219 | """
220 | Set the seed of the pseudo-random generators
221 | (python, numpy, pytorch, gym, action_space)
222 |
223 | :param seed: (int)
224 | """
225 | if seed is None:
226 | return
227 | set_random_seed(seed)
228 | self.action_space.seed(seed)
229 | if self.env is not None:
230 | self.env.seed(seed)
231 | if self.eval_env is not None:
232 | self.eval_env.seed(seed)
233 |
234 | def _setup_learn(self, eval_env):
235 | """
236 | Initialize different variables needed for training.
237 |
238 | :param eval_env: (gym.Env or VecEnv)
239 | :return: (int, int, [float], np.ndarray, VecEnv)
240 | """
241 | self.start_time = time.time()
242 | self.ep_info_buffer = deque(maxlen=100)
243 |
244 | if self.action_noise is not None:
245 | self.action_noise.reset()
246 |
247 | timesteps_since_eval, episode_num = 0, 0
248 | evaluations = []
249 |
250 | if eval_env is not None and self.seed is not None:
251 | eval_env.seed(self.seed)
252 |
253 | eval_env = self._get_eval_env(eval_env)
254 | obs = self.env.reset()
255 | return timesteps_since_eval, episode_num, evaluations, obs, eval_env
256 |
257 | def _update_info_buffer(self, infos):
258 | """
259 | Retrieve reward and episode length and update the buffer
260 | if using Monitor wrapper.
261 |
262 | :param infos: ([dict])
263 | """
264 | for info in infos:
265 | maybe_ep_info = info.get('episode')
266 | if maybe_ep_info is not None:
267 | self.ep_info_buffer.extend([maybe_ep_info])
268 |
269 | def _eval_policy(self, eval_freq, eval_env, n_eval_episodes,
270 | timesteps_since_eval, deterministic=True):
271 | """
272 | Evaluate the current policy on a test environment.
273 |
274 | :param eval_env: (gym.Env) Environment that will be used to evaluate the agent
275 | :param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little)
276 |         :param n_eval_episodes: (int) Number of episodes to evaluate the agent
277 |         :param timesteps_since_eval: (int) Number of timesteps since last evaluation
278 | :param deterministic: (bool) Whether to use deterministic or stochastic actions
279 | :return: (int) Number of timesteps since last evaluation
280 | """
281 | if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
282 | timesteps_since_eval %= eval_freq
283 | # Synchronise the normalization stats if needed
284 | sync_envs_normalization(self.env, eval_env)
285 | mean_reward, std_reward = evaluate_policy(self, eval_env, n_eval_episodes, deterministic=deterministic)
286 | if self.verbose > 0:
287 | print("Eval num_timesteps={}, "
288 | "episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps, mean_reward, std_reward))
289 | print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - self.start_time)))
290 | return timesteps_since_eval
291 |
--------------------------------------------------------------------------------
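
A minimal standalone sketch of the affine rescaling implemented by `scale_action`/`unscale_action` above, assuming a (possibly non-symmetric) `gym.spaces.Box` action space; the bounds and the sample action below are made up for illustration.

```python
import numpy as np
from gym import spaces

# Illustrative, non-symmetric continuous action space
action_space = spaces.Box(low=np.array([0.0, -2.0]), high=np.array([1.0, 2.0]), dtype=np.float32)
low, high = action_space.low, action_space.high

action = np.array([0.25, 1.0], dtype=np.float32)

# [low, high] -> [-1, 1], as in scale_action()
scaled = 2.0 * ((action - low) / (high - low)) - 1.0
# [-1, 1] -> [low, high], as in unscale_action()
unscaled = low + 0.5 * (scaled + 1.0) * (high - low)

assert np.allclose(unscaled, action)  # the round trip recovers the original action
```
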
/stable_baselines/common/buffers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class BaseBuffer(object):
5 | """
6 | Base class that represent a buffer (rollout or replay)
7 |
8 |     :param buffer_size: (int) Max number of elements in the buffer
9 | :param obs_dim: (int) Dimension of the observation
10 | :param action_dim: (int) Dimension of the action space
11 | :param n_envs: (int) Number of parallel environments
12 | """
13 | def __init__(self, buffer_size, obs_dim, action_dim, n_envs=1):
14 | super(BaseBuffer, self).__init__()
15 | self.buffer_size = buffer_size
16 | self.obs_dim = obs_dim
17 | self.action_dim = action_dim
18 | self.pos = 0
19 | self.full = False
20 | self.n_envs = n_envs
21 |
22 | def size(self):
23 | """
24 | :return: (int) The current size of the buffer
25 | """
26 | if self.full:
27 | return self.buffer_size
28 | return self.pos
29 |
30 | def add(self, *args, **kwargs):
31 | """
32 | Add elements to the buffer.
33 | """
34 | raise NotImplementedError()
35 |
36 | def reset(self):
37 | """
38 | Reset the buffer.
39 | """
40 | self.pos = 0
41 | self.full = False
42 |
43 | def sample(self, batch_size):
44 | """
45 |         :param batch_size: (int) Number of elements to sample
46 | """
47 | upper_bound = self.buffer_size if self.full else self.pos
48 | batch_inds = np.random.randint(0, upper_bound, size=batch_size)
49 | return self._get_samples(batch_inds)
50 |
51 | def _get_samples(self, batch_inds):
52 | """
53 | :param batch_inds: (np.ndarray)
54 | :return: ([np.ndarray])
55 | """
56 | raise NotImplementedError()
57 |
58 |
59 | class ReplayBuffer(BaseBuffer):
60 | """
61 | Replay buffer used in off-policy algorithms like SAC/TD3.
62 |
63 |     :param buffer_size: (int) Max number of elements in the buffer
64 | :param obs_dim: (int) Dimension of the observation
65 | :param action_dim: (int) Dimension of the action space
66 | :param n_envs: (int) Number of parallel environments
67 | """
68 | def __init__(self, buffer_size, obs_dim, action_dim, n_envs=1):
69 | super(ReplayBuffer, self).__init__(buffer_size, obs_dim, action_dim, n_envs=n_envs)
70 |
71 | assert n_envs == 1
72 | self.observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32)
73 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
74 | self.next_observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32)
75 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
76 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
77 |
78 | def add(self, obs, next_obs, action, reward, done):
79 | # Copy to avoid modification by reference
80 | self.observations[self.pos] = np.array(obs).copy()
81 | self.next_observations[self.pos] = np.array(next_obs).copy()
82 | self.actions[self.pos] = np.array(action).copy()
83 | self.rewards[self.pos] = np.array(reward).copy()
84 | self.dones[self.pos] = np.array(done).copy()
85 |
86 | self.pos += 1
87 | if self.pos == self.buffer_size:
88 | self.full = True
89 | self.pos = 0
90 |
91 | def _get_samples(self, batch_inds):
92 | return (self.observations[batch_inds, 0, :],
93 | self.actions[batch_inds, 0, :],
94 | self.next_observations[batch_inds, 0, :],
95 | self.dones[batch_inds],
96 | self.rewards[batch_inds])
97 |
98 |
99 | class RolloutBuffer(BaseBuffer):
100 | """
101 | Rollout buffer used in on-policy algorithms like A2C/PPO.
102 |
103 |     :param buffer_size: (int) Max number of elements in the buffer
104 | :param obs_dim: (int) Dimension of the observation
105 | :param action_dim: (int) Dimension of the action space
106 | :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
107 | Equivalent to classic advantage when set to 1.
108 | :param gamma: (float) Discount factor
109 | :param n_envs: (int) Number of parallel environments
110 | """
111 | def __init__(self, buffer_size, obs_dim, action_dim,
112 | gae_lambda=1, gamma=0.99, n_envs=1):
113 | super(RolloutBuffer, self).__init__(buffer_size, obs_dim, action_dim, n_envs=n_envs)
114 | self.gae_lambda = gae_lambda
115 | self.gamma = gamma
116 | self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
117 | self.returns, self.dones, self.values, self.log_probs = None, None, None, None
118 | self.generator_ready = False
119 | self.reset()
120 |
121 | def reset(self):
122 | self.observations = np.zeros((self.buffer_size, self.n_envs, self.obs_dim), dtype=np.float32)
123 | self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
124 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
125 | self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
126 | self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
127 | self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
128 | self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
129 | self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
130 | self.generator_ready = False
131 | super(RolloutBuffer, self).reset()
132 |
133 | def compute_returns_and_advantage(self, last_value, dones=False, use_gae=True):
134 | """
135 | Post-processing step: compute the returns (sum of discounted rewards)
136 |         and advantage (A(s) = R - V(s)).
137 |
138 | :param last_value: (tf.Tensor)
139 | :param dones: ([bool])
140 | :param use_gae: (bool) Whether to use Generalized Advantage Estimation
141 | or normal advantage for advantage computation.
142 | """
143 | if use_gae:
144 | last_gae_lam = 0
145 | for step in reversed(range(self.buffer_size)):
146 | if step == self.buffer_size - 1:
147 | next_non_terminal = np.array(1.0 - dones)
148 | next_value = last_value.numpy().flatten()
149 | else:
150 | next_non_terminal = 1.0 - self.dones[step + 1]
151 | next_value = self.values[step + 1]
152 | delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
153 | last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
154 | self.advantages[step] = last_gae_lam
155 | self.returns = self.advantages + self.values
156 | else:
157 | # Discounted return with value bootstrap
158 | # Note: this is equivalent to GAE computation
159 | # with gae_lambda = 1.0
160 | last_return = 0.0
161 | for step in reversed(range(self.buffer_size)):
162 | if step == self.buffer_size - 1:
163 | next_non_terminal = np.array(1.0 - dones)
164 | next_value = last_value.numpy().flatten()
165 |                     last_return = self.rewards[step] + self.gamma * next_non_terminal * next_value
166 | else:
167 | next_non_terminal = 1.0 - self.dones[step + 1]
168 | last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal
169 | self.returns[step] = last_return
170 | self.advantages = self.returns - self.values
171 |
172 | def add(self, obs, action, reward, done, value, log_prob):
173 | """
174 | :param obs: (np.ndarray) Observation
175 | :param action: (np.ndarray) Action
176 | :param reward: (np.ndarray)
177 | :param done: (np.ndarray) End of episode signal.
178 |         :param value: (tf.Tensor) estimated value of the current state
179 | following the current policy.
180 |         :param log_prob: (tf.Tensor) log probability of the action
181 | following the current policy.
182 | """
183 | if len(log_prob.shape) == 0:
184 | # Reshape 0-d tensor to avoid error
185 | log_prob = log_prob.reshape(-1, 1)
186 |
187 | self.observations[self.pos] = np.array(obs).copy()
188 | self.actions[self.pos] = np.array(action).copy()
189 | self.rewards[self.pos] = np.array(reward).copy()
190 | self.dones[self.pos] = np.array(done).copy()
191 | self.values[self.pos] = value.numpy().flatten().copy()
192 | self.log_probs[self.pos] = log_prob.numpy().copy()
193 | self.pos += 1
194 | if self.pos == self.buffer_size:
195 | self.full = True
196 |
197 | def get(self, batch_size=None):
198 | assert self.full
199 | indices = np.random.permutation(self.buffer_size * self.n_envs)
200 | # Prepare the data
201 | if not self.generator_ready:
202 | for tensor in ['observations', 'actions', 'values',
203 | 'log_probs', 'advantages', 'returns']:
204 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
205 | self.generator_ready = True
206 |
207 | # Return everything, don't create minibatches
208 | if batch_size is None:
209 | batch_size = self.buffer_size * self.n_envs
210 |
211 | start_idx = 0
212 | while start_idx < self.buffer_size * self.n_envs:
213 | yield self._get_samples(indices[start_idx:start_idx + batch_size])
214 | start_idx += batch_size
215 |
216 | def _get_samples(self, batch_inds):
217 | return (self.observations[batch_inds],
218 | self.actions[batch_inds],
219 | self.values[batch_inds].flatten(),
220 | self.log_probs[batch_inds].flatten(),
221 | self.advantages[batch_inds].flatten(),
222 | self.returns[batch_inds].flatten())
223 |
224 | @staticmethod
225 | def swap_and_flatten(tensor):
226 | """
227 | Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
228 |         to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
229 |         to [n_steps * n_envs, ...] (which maintains the order)
230 |
231 | :param tensor: (np.ndarray)
232 | :return: (np.ndarray)
233 | """
234 | shape = tensor.shape
235 | if len(shape) < 3:
236 | shape = shape + (1,)
237 | return tensor.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
238 |
--------------------------------------------------------------------------------
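
A toy numpy re-derivation of the GAE recursion used by `RolloutBuffer.compute_returns_and_advantage` above, for a single environment; all rewards and value estimates are made-up numbers, only the recursion itself mirrors the buffer code.

```python
import numpy as np

gamma, gae_lambda = 0.99, 0.95
rewards = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)
values = np.array([0.5, 0.4, 0.6, 0.3], dtype=np.float32)  # V(s_t) from the critic
dones = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)   # dones[t + 1] signals an episode end
last_value, last_done = 0.2, 0.0                           # bootstrap values for the final step

advantages = np.zeros_like(rewards)
last_gae_lam = 0.0
for step in reversed(range(len(rewards))):
    if step == len(rewards) - 1:
        next_non_terminal, next_value = 1.0 - last_done, last_value
    else:
        next_non_terminal, next_value = 1.0 - dones[step + 1], values[step + 1]
    # TD error, then the exponentially weighted GAE accumulator
    delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
    advantages[step] = last_gae_lam

returns = advantages + values  # same convention as the buffer: returns = advantages + values
```
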
/stable_baselines/common/distributions.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.keras.layers as layers
3 | from tensorflow.keras.models import Sequential
4 | import tensorflow_probability as tfp
5 | from gym import spaces
6 |
7 |
8 | class Distribution(object):
9 | def __init__(self):
10 | super(Distribution, self).__init__()
11 |
12 | def log_prob(self, x):
13 | """
14 | returns the log likelihood
15 |
16 | :param x: (object) the taken action
17 | :return: (tf.Tensor) The log likelihood of the distribution
18 | """
19 | raise NotImplementedError
20 |
21 | def entropy(self):
22 | """
23 |         Returns the Shannon entropy of the probability distribution
24 |
25 | :return: (tf.Tensor) the entropy
26 | """
27 | raise NotImplementedError
28 |
29 | def sample(self):
30 | """
31 |         returns a sample from the probability distribution
32 |
33 | :return: (tf.Tensor) the stochastic action
34 | """
35 | raise NotImplementedError
36 |
37 |
38 | class DiagGaussianDistribution(Distribution):
39 | """
40 | Gaussian distribution with diagonal covariance matrix,
41 | for continuous actions.
42 |
43 | :param action_dim: (int) Number of continuous actions
44 | """
45 |
46 | def __init__(self, action_dim):
47 | super(DiagGaussianDistribution, self).__init__()
48 | self.distribution = None
49 | self.action_dim = action_dim
50 | self.mean_actions = None
51 | self.log_std = None
52 |
53 | def proba_distribution_net(self, latent_dim, log_std_init=0.0):
54 | """
55 |         Create the layers and parameters that represent the distribution:
56 |         one output will be the mean of the Gaussian, the other parameter will be the
57 |         standard deviation (log std in fact, to allow negative values)
58 | 
59 |         :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
60 | :param log_std_init: (float) Initial value for the log standard deviation
61 | :return: (tf.keras.models.Sequential, tf.Variable)
62 | """
63 |         mean_actions = Sequential([layers.Dense(self.action_dim, input_shape=(latent_dim,), activation=None)])
64 | log_std = tf.Variable(tf.ones(self.action_dim) * log_std_init)
65 | return mean_actions, log_std
66 |
67 | def proba_distribution(self, mean_actions, log_std, deterministic=False):
68 | """
69 | Create and sample for the distribution given its parameters (mean, std)
70 |
71 | :param mean_actions: (tf.Tensor)
72 | :param log_std: (tf.Tensor)
73 | :param deterministic: (bool)
74 | :return: (tf.Tensor)
75 | """
76 | action_std = tf.ones_like(mean_actions) * tf.exp(log_std)
77 | self.distribution = tfp.distributions.Normal(mean_actions, action_std)
78 | if deterministic:
79 | action = self.mode()
80 | else:
81 | action = self.sample()
82 | return action, self
83 |
84 | def mode(self):
85 | return self.distribution.mode()
86 |
87 | def sample(self):
88 | return self.distribution.sample()
89 |
90 | def entropy(self):
91 | return self.distribution.entropy()
92 |
93 | def log_prob_from_params(self, mean_actions, log_std):
94 | """
95 |         Compute the log probability of taking an action
96 | given the distribution parameters.
97 |
98 | :param mean_actions: (tf.Tensor)
99 | :param log_std: (tf.Tensor)
100 | :return: (tf.Tensor, tf.Tensor)
101 | """
102 | action, _ = self.proba_distribution(mean_actions, log_std)
103 | log_prob = self.log_prob(action)
104 | return action, log_prob
105 |
106 | def log_prob(self, action):
107 | """
108 |         Get the log probability of an action given a distribution.
109 | Note that you must call `proba_distribution()` method
110 | before.
111 |
112 | :param action: (tf.Tensor)
113 | :return: (tf.Tensor)
114 | """
115 | log_prob = self.distribution.log_prob(action)
116 | if len(log_prob.shape) > 1:
117 | log_prob = tf.reduce_sum(log_prob, axis=1)
118 | else:
119 | log_prob = tf.reduce_sum(log_prob)
120 | return log_prob
121 |
122 |
123 | class CategoricalDistribution(Distribution):
124 | """
125 | Categorical distribution for discrete actions.
126 |
127 | :param action_dim: (int) Number of discrete actions
128 | """
129 | def __init__(self, action_dim):
130 | super(CategoricalDistribution, self).__init__()
131 | self.distribution = None
132 | self.action_dim = action_dim
133 |
134 | def proba_distribution_net(self, latent_dim):
135 | """
136 | Create the layer that represents the distribution:
137 | it will be the logits of the Categorical distribution.
138 |         You can then get probabilities using a softmax.
139 |
140 |         :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
141 | :return: (tf.keras.models.Sequential)
142 | """
143 | action_logits = layers.Dense(self.action_dim, input_shape=(latent_dim,), activation=None)
144 | return Sequential([action_logits])
145 |
146 | def proba_distribution(self, action_logits, deterministic=False):
147 | self.distribution = tfp.distributions.Categorical(logits=action_logits)
148 | if deterministic:
149 | action = self.mode()
150 | else:
151 | action = self.sample()
152 | return action, self
153 |
154 | def mode(self):
155 | return self.distribution.mode()
156 |
157 | def sample(self):
158 | return self.distribution.sample()
159 |
160 | def entropy(self):
161 | return self.distribution.entropy()
162 |
163 | def log_prob_from_params(self, action_logits):
164 | action, _ = self.proba_distribution(action_logits)
165 | log_prob = self.log_prob(action)
166 | return action, log_prob
167 |
168 | def log_prob(self, action):
169 | log_prob = self.distribution.log_prob(action)
170 | return log_prob
171 |
172 |
173 | def make_proba_distribution(action_space, dist_kwargs=None):
174 | """
175 | Return an instance of Distribution for the correct type of action space
176 |
177 | :param action_space: (Gym Space) the input action space
178 |     :param dist_kwargs: (dict) Keyword arguments to pass to the probability distribution
179 |     :return: (Distribution) the appropriate Distribution object
180 | """
181 | if dist_kwargs is None:
182 | dist_kwargs = {}
183 |
184 | if isinstance(action_space, spaces.Box):
185 | assert len(action_space.shape) == 1, "Error: the action space must be a vector"
186 | return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs)
187 | elif isinstance(action_space, spaces.Discrete):
188 | return CategoricalDistribution(action_space.n, **dist_kwargs)
189 | # elif isinstance(action_space, spaces.MultiDiscrete):
190 | # return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
191 | # elif isinstance(action_space, spaces.MultiBinary):
192 | # return BernoulliDistribution(action_space.n, **dist_kwargs)
193 | else:
194 | raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}."
195 | .format(type(action_space)) +
196 | " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.")
197 |
--------------------------------------------------------------------------------
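
A short sketch, using `tensorflow_probability` directly, of why `DiagGaussianDistribution.log_prob` sums over the last axis: for a diagonal Gaussian the joint log-probability of an action is the sum of the per-dimension log-probabilities. The batch size and action dimension below are arbitrary.

```python
import tensorflow as tf
import tensorflow_probability as tfp

batch_size, action_dim = 4, 2
mean_actions = tf.zeros((batch_size, action_dim))
log_std = tf.Variable(tf.zeros(action_dim))
action_std = tf.ones_like(mean_actions) * tf.exp(log_std)

dist = tfp.distributions.Normal(mean_actions, action_std)
actions = dist.sample()

per_dim_log_prob = dist.log_prob(actions)                 # shape (4, 2): one value per action dimension
joint_log_prob = tf.reduce_sum(per_dim_log_prob, axis=1)  # shape (4,), as in DiagGaussianDistribution.log_prob
```
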
/stable_baselines/common/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copied from stable_baselines
2 | import numpy as np
3 |
4 | from stable_baselines.common.vec_env import VecEnv
5 |
6 |
7 | def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
8 | render=False, callback=None, reward_threshold=None,
9 | return_episode_rewards=False):
10 | """
11 | Runs policy for `n_eval_episodes` episodes and returns average reward.
12 | This is made to work only with one env.
13 |
14 | :param model: (BaseRLModel) The RL agent you want to evaluate.
15 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
16 | this must contain only one environment.
17 |     :param n_eval_episodes: (int) Number of episodes to evaluate the agent
18 | :param deterministic: (bool) Whether to use deterministic or stochastic actions
19 | :param render: (bool) Whether to render the environment or not
20 | :param callback: (callable) callback function to do additional checks,
21 | called after each step.
22 | :param reward_threshold: (float) Minimum expected reward per episode,
23 | this will raise an error if the performance is not met
24 | :param return_episode_rewards: (bool) If True, a list of reward per episode
25 | will be returned instead of the mean.
26 | :return: (float, float) Mean reward per episode, std of reward per episode
27 | returns ([float], int) when `return_episode_rewards` is True
28 | """
29 | if isinstance(env, VecEnv):
30 | assert env.num_envs == 1, "You must pass only one environment when using this function"
31 |
32 | episode_rewards, n_steps = [], 0
33 | for _ in range(n_eval_episodes):
34 | obs = env.reset()
35 | done = False
36 | episode_reward = 0.0
37 | while not done:
38 | action = model.predict(obs, deterministic=deterministic)
39 | obs, reward, done, _info = env.step(action)
40 | episode_reward += reward
41 | if callback is not None:
42 | callback(locals(), globals())
43 | n_steps += 1
44 | if render:
45 | env.render()
46 | episode_rewards.append(episode_reward)
47 | mean_reward = np.mean(episode_rewards)
48 | std_reward = np.std(episode_rewards)
49 | if reward_threshold is not None:
50 | assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
51 | '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
52 | if return_episode_rewards:
53 | return episode_rewards, n_steps
54 | return mean_reward, std_reward
55 |
--------------------------------------------------------------------------------
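
A minimal usage sketch for `evaluate_policy`. It only needs an object exposing `predict` (whose return value is used directly as the action here), so the hypothetical `RandomPolicy` below stands in for a trained model; it assumes the classic 4-tuple `gym` step API used throughout this module.

```python
import gym

from stable_baselines.common.evaluation import evaluate_policy


class RandomPolicy:
    """Hypothetical stand-in with the minimal interface evaluate_policy relies on."""

    def __init__(self, action_space):
        self.action_space = action_space

    def predict(self, observation, state=None, mask=None, deterministic=False):
        return self.action_space.sample()


env = gym.make('CartPole-v1')
model = RandomPolicy(env.action_space)
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
print('mean_reward={:.2f} +/- {:.2f}'.format(mean_reward, std_reward))
```
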
/stable_baselines/common/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import datetime
6 | import tempfile
7 | import warnings
8 | from collections import defaultdict
9 |
10 | DEBUG = 10
11 | INFO = 20
12 | WARN = 30
13 | ERROR = 40
14 | DISABLED = 50
15 |
16 |
17 | class KVWriter(object):
18 | """
19 | Key Value writer
20 | """
21 | def writekvs(self, kvs):
22 | """
23 | write a dictionary to file
24 |
25 | :param kvs: (dict)
26 | """
27 | raise NotImplementedError
28 |
29 |
30 | class SeqWriter(object):
31 | """
32 | sequence writer
33 | """
34 | def writeseq(self, seq):
35 | """
36 | write an array to file
37 |
38 | :param seq: (list)
39 | """
40 | raise NotImplementedError
41 |
42 |
43 | class HumanOutputFormat(KVWriter, SeqWriter):
44 | def __init__(self, filename_or_file):
45 | """
46 | log to a file, in a human readable format
47 |
48 | :param filename_or_file: (str or File) the file to write the log to
49 | """
50 | if isinstance(filename_or_file, str):
51 | self.file = open(filename_or_file, 'wt')
52 | self.own_file = True
53 | else:
54 | assert hasattr(filename_or_file, 'write'), 'Expected file or str, got {}'.format(filename_or_file)
55 | self.file = filename_or_file
56 | self.own_file = False
57 |
58 | def writekvs(self, kvs):
59 | # Create strings for printing
60 | key2str = {}
61 | for (key, val) in sorted(kvs.items()):
62 | if isinstance(val, float):
63 | valstr = '%-8.3g' % (val,)
64 | else:
65 | valstr = str(val)
66 | key2str[self._truncate(key)] = self._truncate(valstr)
67 |
68 | # Find max widths
69 | if len(key2str) == 0:
70 | warnings.warn('Tried to write empty key-value dict')
71 | return
72 | else:
73 | keywidth = max(map(len, key2str.keys()))
74 | valwidth = max(map(len, key2str.values()))
75 |
76 | # Write out the data
77 | dashes = '-' * (keywidth + valwidth + 7)
78 | lines = [dashes]
79 | for (key, val) in sorted(key2str.items()):
80 | lines.append('| %s%s | %s%s |' % (
81 | key,
82 | ' ' * (keywidth - len(key)),
83 | val,
84 | ' ' * (valwidth - len(val)),
85 | ))
86 | lines.append(dashes)
87 | self.file.write('\n'.join(lines) + '\n')
88 |
89 | # Flush the output to the file
90 | self.file.flush()
91 |
92 | @classmethod
93 | def _truncate(cls, string):
94 | return string[:20] + '...' if len(string) > 23 else string
95 |
96 | def writeseq(self, seq):
97 | seq = list(seq)
98 | for (i, elem) in enumerate(seq):
99 | self.file.write(elem)
100 | if i < len(seq) - 1: # add space unless this is the last one
101 | self.file.write(' ')
102 | self.file.write('\n')
103 | self.file.flush()
104 |
105 | def close(self):
106 | """
107 | closes the file
108 | """
109 | if self.own_file:
110 | self.file.close()
111 |
112 |
113 | class JSONOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | """
116 | log to a file, in the JSON format
117 |
118 | :param filename: (str) the file to write the log to
119 | """
120 | self.file = open(filename, 'wt')
121 |
122 | def writekvs(self, kvs):
123 | for key, value in sorted(kvs.items()):
124 | if hasattr(value, 'dtype'):
125 | if value.shape == () or len(value) == 1:
126 | # if value is a dimensionless numpy array or of length 1, serialize as a float
127 | kvs[key] = float(value)
128 | else:
129 | # otherwise, a value is a numpy array, serialize as a list or nested lists
130 | kvs[key] = value.tolist()
131 | self.file.write(json.dumps(kvs) + '\n')
132 | self.file.flush()
133 |
134 | def close(self):
135 | """
136 | closes the file
137 | """
138 | self.file.close()
139 |
140 |
141 | class CSVOutputFormat(KVWriter):
142 | def __init__(self, filename):
143 | """
144 | log to a file, in a CSV format
145 |
146 | :param filename: (str) the file to write the log to
147 | """
148 | self.file = open(filename, 'w+t')
149 | self.keys = []
150 | self.sep = ','
151 |
152 | def writekvs(self, kvs):
153 | # Add our current row to the history
154 | extra_keys = kvs.keys() - self.keys
155 | if extra_keys:
156 | self.keys.extend(extra_keys)
157 | self.file.seek(0)
158 | lines = self.file.readlines()
159 | self.file.seek(0)
160 | for (i, key) in enumerate(self.keys):
161 | if i > 0:
162 | self.file.write(',')
163 | self.file.write(key)
164 | self.file.write('\n')
165 | for line in lines[1:]:
166 | self.file.write(line[:-1])
167 | self.file.write(self.sep * len(extra_keys))
168 | self.file.write('\n')
169 | for i, key in enumerate(self.keys):
170 | if i > 0:
171 | self.file.write(',')
172 | value = kvs.get(key)
173 | if value is not None:
174 | self.file.write(str(value))
175 | self.file.write('\n')
176 | self.file.flush()
177 |
178 | def close(self):
179 | """
180 | closes the file
181 | """
182 | self.file.close()
183 |
184 |
185 | def summary_val(key, value):
186 | """
187 | :param key: (str)
188 | :param value: (float)
189 | """
190 | kwargs = {'tag': key, 'simple_value': float(value)}
191 | return tf.Summary.Value(**kwargs)
192 |
193 |
194 | def valid_float_value(value):
195 | """
196 | Returns True if the value can be successfully cast into a float
197 |
198 | :param value: (Any) the value to check
199 | :return: (bool)
200 | """
201 | try:
202 | float(value)
203 | return True
204 |     except (TypeError, ValueError):
205 | return False
206 |
207 |
208 | def make_output_format(_format, ev_dir, log_suffix=''):
209 | """
210 | return a logger for the requested format
211 |
212 | :param _format: (str) the requested format to log to ('stdout', 'log', 'json' or 'csv')
213 | :param ev_dir: (str) the logging directory
214 | :param log_suffix: (str) the suffix for the log file
215 |     :return: (KVWriter) the logger
216 | """
217 | os.makedirs(ev_dir, exist_ok=True)
218 | if _format == 'stdout':
219 | return HumanOutputFormat(sys.stdout)
220 | elif _format == 'log':
221 | return HumanOutputFormat(os.path.join(ev_dir, 'log%s.txt' % log_suffix))
222 | elif _format == 'json':
223 | return JSONOutputFormat(os.path.join(ev_dir, 'progress%s.json' % log_suffix))
224 | elif _format == 'csv':
225 | return CSVOutputFormat(os.path.join(ev_dir, 'progress%s.csv' % log_suffix))
226 | else:
227 | raise ValueError('Unknown format specified: %s' % (_format,))
228 |
229 |
230 | # ================================================================
231 | # API
232 | # ================================================================
233 |
234 | def logkv(key, val):
235 | """
236 | Log a value of some diagnostic
237 | Call this once for each diagnostic quantity, each iteration
238 | If called many times, last value will be used.
239 |
240 | :param key: (Any) save to log this key
241 | :param val: (Any) save to log this value
242 | """
243 | Logger.CURRENT.logkv(key, val)
244 |
245 |
246 | def logkv_mean(key, val):
247 | """
248 | The same as logkv(), but if called many times, values averaged.
249 |
250 | :param key: (Any) save to log this key
251 | :param val: (Number) save to log this value
252 | """
253 | Logger.CURRENT.logkv_mean(key, val)
254 |
255 |
256 | def logkvs(key_values):
257 | """
258 | Log a dictionary of key-value pairs
259 |
260 | :param key_values: (dict) the list of keys and values to save to log
261 | """
262 | for key, value in key_values.items():
263 | logkv(key, value)
264 |
265 |
266 | def dumpkvs():
267 | """
268 | Write all of the diagnostics from the current iteration
269 | """
270 | Logger.CURRENT.dumpkvs()
271 |
272 |
273 | def getkvs():
274 | """
275 | get the key values logs
276 |
277 | :return: (dict) the logged values
278 | """
279 | return Logger.CURRENT.name2val
280 |
281 |
282 | def log(*args, **kwargs):
283 | """
284 | Write the sequence of args, with no separators,
285 | to the console and output files (if you've configured an output file).
286 |
287 | level: int. (see logger.py docs) If the global logger level is higher than
288 | the level argument here, don't print to stdout.
289 |
290 | :param args: (list) log the arguments
291 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
292 | """
293 | level = kwargs.get('level', INFO)
294 | Logger.CURRENT.log(*args, level=level)
295 |
296 |
297 | def debug(*args):
298 | """
299 | Write the sequence of args, with no separators,
300 | to the console and output files (if you've configured an output file).
301 | Using the DEBUG level.
302 |
303 | :param args: (list) log the arguments
304 | """
305 | log(*args, level=DEBUG)
306 |
307 |
308 | def info(*args):
309 | """
310 | Write the sequence of args, with no separators,
311 | to the console and output files (if you've configured an output file).
312 | Using the INFO level.
313 |
314 | :param args: (list) log the arguments
315 | """
316 | log(*args, level=INFO)
317 |
318 |
319 | def warn(*args):
320 | """
321 | Write the sequence of args, with no separators,
322 | to the console and output files (if you've configured an output file).
323 | Using the WARN level.
324 |
325 | :param args: (list) log the arguments
326 | """
327 | log(*args, level=WARN)
328 |
329 |
330 | def error(*args):
331 | """
332 | Write the sequence of args, with no separators,
333 | to the console and output files (if you've configured an output file).
334 | Using the ERROR level.
335 |
336 | :param args: (list) log the arguments
337 | """
338 | log(*args, level=ERROR)
339 |
340 |
341 | def set_level(level):
342 | """
343 | Set logging threshold on current logger.
344 |
345 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
346 | """
347 | Logger.CURRENT.set_level(level)
348 |
349 |
350 | def get_level():
351 | """
352 | Get logging threshold on current logger.
353 | :return: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
354 | """
355 | return Logger.CURRENT.level
356 |
357 |
358 | def get_dir():
359 | """
360 | Get directory that log files are being written to.
361 |     will be None if there is no output directory (i.e., if you didn't call configure)
362 |
363 | :return: (str) the logging directory
364 | """
365 | return Logger.CURRENT.get_dir()
366 |
367 |
368 | record_tabular = logkv
369 | dump_tabular = dumpkvs
370 |
371 |
372 | # ================================================================
373 | # Backend
374 | # ================================================================
375 |
376 | class Logger(object):
377 | # A logger with no output files. (See right below class definition)
378 | # So that you can still log to the terminal without setting up any output files
379 | DEFAULT = None
380 | CURRENT = None # Current logger being used by the free functions above
381 |
382 | def __init__(self, folder, output_formats):
383 | """
384 | the logger class
385 |
386 | :param folder: (str) the logging location
387 | :param output_formats: ([str]) the list of output format
388 | """
389 | self.name2val = defaultdict(float) # values this iteration
390 | self.name2cnt = defaultdict(int)
391 | self.level = INFO
392 | self.dir = folder
393 | self.output_formats = output_formats
394 |
395 | # Logging API, forwarded
396 | # ----------------------------------------
397 | def logkv(self, key, val):
398 | """
399 | Log a value of some diagnostic
400 | Call this once for each diagnostic quantity, each iteration
401 | If called many times, last value will be used.
402 |
403 | :param key: (Any) save to log this key
404 | :param val: (Any) save to log this value
405 | """
406 | self.name2val[key] = val
407 |
408 | def logkv_mean(self, key, val):
409 | """
410 | The same as logkv(), but if called many times, values averaged.
411 |
412 | :param key: (Any) save to log this key
413 | :param val: (Number) save to log this value
414 | """
415 | if val is None:
416 | self.name2val[key] = None
417 | return
418 | oldval, cnt = self.name2val[key], self.name2cnt[key]
419 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
420 | self.name2cnt[key] = cnt + 1
421 |
422 | def dumpkvs(self):
423 | """
424 | Write all of the diagnostics from the current iteration
425 | """
426 | if self.level == DISABLED:
427 | return
428 | for fmt in self.output_formats:
429 | if isinstance(fmt, KVWriter):
430 | fmt.writekvs(self.name2val)
431 | self.name2val.clear()
432 | self.name2cnt.clear()
433 |
434 | def log(self, *args, **kwargs):
435 | """
436 | Write the sequence of args, with no separators,
437 | to the console and output files (if you've configured an output file).
438 |
439 | level: int. (see logger.py docs) If the global logger level is higher than
440 | the level argument here, don't print to stdout.
441 |
442 | :param args: (list) log the arguments
443 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
444 | """
445 | level = kwargs.get('level', INFO)
446 | if self.level <= level:
447 | self._do_log(args)
448 |
449 | # Configuration
450 | # ----------------------------------------
451 | def set_level(self, level):
452 | """
453 | Set logging threshold on current logger.
454 |
455 | :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
456 | """
457 | self.level = level
458 |
459 | def get_dir(self):
460 | """
461 | Get directory that log files are being written to.
462 |         will be None if there is no output directory (i.e., if you didn't call configure)
463 |
464 | :return: (str) the logging directory
465 | """
466 | return self.dir
467 |
468 | def close(self):
469 | """
470 | closes the file
471 | """
472 | for fmt in self.output_formats:
473 | fmt.close()
474 |
475 | # Misc
476 | # ----------------------------------------
477 | def _do_log(self, args):
478 | """
479 | log to the requested format outputs
480 |
481 | :param args: (list) the arguments to log
482 | """
483 | for fmt in self.output_formats:
484 | if isinstance(fmt, SeqWriter):
485 | fmt.writeseq(map(str, args))
486 |
487 |
488 | Logger.DEFAULT = Logger.CURRENT = Logger(folder=None, output_formats=[HumanOutputFormat(sys.stdout)])
489 |
490 |
491 | def configure(folder=None, format_strs=None):
492 | """
493 | configure the current logger
494 |
495 | :param folder: (str) the save location (if None, $BASELINES_LOGDIR, if still None, tempdir/baselines-[date & time])
496 | :param format_strs: (list) the output logging format
497 | (if None, $BASELINES_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
498 | """
499 | if folder is None:
500 | folder = os.getenv('BASELINES_LOGDIR')
501 | if folder is None:
502 | folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("baselines-%Y-%m-%d-%H-%M-%S-%f"))
503 | assert isinstance(folder, str)
504 | os.makedirs(folder, exist_ok=True)
505 |
506 | log_suffix = ''
507 | if format_strs is None:
508 | format_strs = os.getenv('BASELINES_LOG_FORMAT', 'stdout,log,csv').split(',')
509 |
510 | format_strs = filter(None, format_strs)
511 | output_formats = [make_output_format(f, folder, log_suffix) for f in format_strs]
512 |
513 | Logger.CURRENT = Logger(folder=folder, output_formats=output_formats)
514 | log('Logging to %s' % folder)
515 |
--------------------------------------------------------------------------------
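
A short sketch of the logging API defined above: `configure` selects the output folder and formats, `logkv`/`logkv_mean` stage key-value pairs, and `dumpkvs` flushes them to every writer. The folder path is arbitrary.

```python
from stable_baselines.common import logger

logger.configure(folder='/tmp/sb_logger_example', format_strs=['stdout', 'csv'])

for iteration in range(3):
    logger.logkv('iteration', iteration)
    logger.logkv_mean('dummy_value', 0.5 * iteration)  # averaged if logged several times per iteration
    logger.dumpkvs()  # writes to stdout and to progress.csv in the chosen folder
```
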
/stable_baselines/common/monitor.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 | import time
5 |
6 | from gym.core import Wrapper
7 |
8 |
9 | class Monitor(Wrapper):
10 | EXT = "monitor.csv"
11 | file_handler = None
12 |
13 | def __init__(self, env, filename=None, allow_early_resets=True, reset_keywords=(), info_keywords=()):
14 | """
15 |         A monitor wrapper for Gym environments, used to record the episode reward, length, time and other data.
16 |
17 | :param env: (Gym environment) The environment
18 | :param filename: (str) the location to save a log file, can be None for no log
19 | :param allow_early_resets: (bool) allows the reset of the environment before it is done
20 | :param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset
21 |         :param info_keywords: (tuple) extra information to log, taken from the information returned by environment.step()
22 | """
23 | Wrapper.__init__(self, env=env)
24 | self.t_start = time.time()
25 | if filename is None:
26 | self.file_handler = None
27 | self.logger = None
28 | else:
29 | if not filename.endswith(Monitor.EXT):
30 | if os.path.isdir(filename):
31 | filename = os.path.join(filename, Monitor.EXT)
32 | else:
33 | filename = filename + "." + Monitor.EXT
34 | self.file_handler = open(filename, "wt")
35 | self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id}))
36 | self.logger = csv.DictWriter(self.file_handler,
37 | fieldnames=('r', 'l', 't') + reset_keywords + info_keywords)
38 | self.logger.writeheader()
39 | self.file_handler.flush()
40 |
41 | self.reset_keywords = reset_keywords
42 | self.info_keywords = info_keywords
43 | self.allow_early_resets = allow_early_resets
44 | self.rewards = None
45 | self.needs_reset = True
46 | self.episode_rewards = []
47 | self.episode_lengths = []
48 | self.episode_times = []
49 | self.total_steps = 0
50 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
51 |
52 | def reset(self, **kwargs):
53 | """
54 |         Calls the Gym environment reset. Can only be called if the episode is over, or if allow_early_resets is True
55 | 
56 |         :param kwargs: Extra keywords saved for the next episode, only if defined by reset_keywords
57 | :return: ([int] or [float]) the first observation of the environment
58 | """
59 | if not self.allow_early_resets and not self.needs_reset:
60 | raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, "
61 | "wrap your env with Monitor(env, path, allow_early_resets=True)")
62 | self.rewards = []
63 | self.needs_reset = False
64 | for key in self.reset_keywords:
65 | value = kwargs.get(key)
66 | if value is None:
67 | raise ValueError('Expected you to pass kwarg %s into reset' % key)
68 | self.current_reset_info[key] = value
69 | return self.env.reset(**kwargs)
70 |
71 | def step(self, action):
72 | """
73 | Step the environment with the given action
74 |
75 | :param action: ([int] or [float]) the action
76 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
77 | """
78 | if self.needs_reset:
79 | raise RuntimeError("Tried to step environment that needs reset")
80 | observation, reward, done, info = self.env.step(action)
81 | self.rewards.append(reward)
82 | if done:
83 | self.needs_reset = True
84 | ep_rew = sum(self.rewards)
85 | eplen = len(self.rewards)
86 | ep_info = {"r": round(ep_rew, 6), "l": eplen, "t": round(time.time() - self.t_start, 6)}
87 | for key in self.info_keywords:
88 | ep_info[key] = info[key]
89 | self.episode_rewards.append(ep_rew)
90 | self.episode_lengths.append(eplen)
91 | self.episode_times.append(time.time() - self.t_start)
92 | ep_info.update(self.current_reset_info)
93 | if self.logger:
94 | self.logger.writerow(ep_info)
95 | self.file_handler.flush()
96 | info['episode'] = ep_info
97 | self.total_steps += 1
98 | return observation, reward, done, info
99 |
100 | def close(self):
101 | """
102 | Closes the environment
103 | """
104 | if self.file_handler is not None:
105 | self.file_handler.close()
106 |
107 | def get_total_steps(self):
108 | """
109 | Returns the total number of timesteps
110 |
111 | :return: (int)
112 | """
113 | return self.total_steps
114 |
115 | def get_episode_rewards(self):
116 | """
117 | Returns the rewards of all the episodes
118 |
119 | :return: ([float])
120 | """
121 | return self.episode_rewards
122 |
123 | def get_episode_lengths(self):
124 | """
125 | Returns the number of timesteps of all the episodes
126 |
127 | :return: ([int])
128 | """
129 | return self.episode_lengths
130 |
131 | def get_episode_times(self):
132 | """
133 | Returns the runtime in seconds of all the episodes
134 |
135 | :return: ([float])
136 | """
137 | return self.episode_times
138 |
--------------------------------------------------------------------------------
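
A usage sketch for the `Monitor` wrapper above on a toy environment, assuming the classic 4-tuple `gym` step API; the log path is arbitrary.

```python
import gym

from stable_baselines.common.monitor import Monitor

env = Monitor(gym.make('CartPole-v1'), filename='/tmp/cartpole_monitor')

obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())

# At the end of an episode the wrapper stores the episode statistics in info['episode'],
# which is what BaseRLModel._update_info_buffer reads.
print(info['episode'])            # e.g. {'r': ..., 'l': ..., 't': ...}
print(env.get_episode_rewards())  # rewards of all completed episodes so far
```
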
/stable_baselines/common/noise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class AdaptiveParamNoiseSpec(object):
5 | """
6 | Implements adaptive parameter noise
7 |
8 | :param initial_stddev: (float) the initial value for the standard deviation of the noise
9 | :param desired_action_stddev: (float) the desired value for the standard deviation of the noise
10 | :param adoption_coefficient: (float) the update coefficient for the standard deviation of the noise
11 | """
12 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01):
13 | self.initial_stddev = initial_stddev
14 | self.desired_action_stddev = desired_action_stddev
15 | self.adoption_coefficient = adoption_coefficient
16 |
17 | self.current_stddev = initial_stddev
18 |
19 | def adapt(self, distance):
20 | """
21 | update the standard deviation for the parameter noise
22 |
23 | :param distance: (float) the noise distance applied to the parameters
24 | """
25 | if distance > self.desired_action_stddev:
26 | # Decrease stddev.
27 | self.current_stddev /= self.adoption_coefficient
28 | else:
29 | # Increase stddev.
30 | self.current_stddev *= self.adoption_coefficient
31 |
32 | def get_stats(self):
33 | """
34 | return the standard deviation for the parameter noise
35 |
36 | :return: (dict) the stats of the noise
37 | """
38 | return {'param_noise_stddev': self.current_stddev}
39 |
40 | def __repr__(self):
41 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})'
42 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient)
43 |
44 |
45 | class ActionNoise(object):
46 | """
47 | The action noise base class
48 | """
49 | def reset(self):
50 | """
51 | call end of episode reset for the noise
52 | """
53 | pass
54 |
55 |
56 | class NormalActionNoise(ActionNoise):
57 | """
58 | A Gaussian action noise
59 |
60 | :param mean: (float) the mean value of the noise
61 | :param sigma: (float) the scale of the noise (std here)
62 | """
63 | def __init__(self, mean, sigma):
64 | self._mu = mean
65 | self._sigma = sigma
66 |
67 | def __call__(self):
68 | return np.random.normal(self._mu, self._sigma)
69 |
70 | def __repr__(self):
71 | return 'NormalActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma)
72 |
73 |
74 | class OrnsteinUhlenbeckActionNoise(ActionNoise):
75 | """
76 |     An Ornstein Uhlenbeck action noise, designed to approximate Brownian motion with friction.
77 |
78 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
79 |
80 | :param mean: (float) the mean of the noise
81 | :param sigma: (float) the scale of the noise
82 | :param theta: (float) the rate of mean reversion
83 | :param dt: (float) the timestep for the noise
84 | :param initial_noise: ([float]) the initial value for the noise output, (if None: 0)
85 | """
86 |
87 | def __init__(self, mean, sigma, theta=.15, dt=1e-2, initial_noise=None):
88 | self._theta = theta
89 | self._mu = mean
90 | self._sigma = sigma
91 | self._dt = dt
92 | self.initial_noise = initial_noise
93 | self.noise_prev = None
94 | self.reset()
95 |
96 | def __call__(self):
97 | noise = self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt + \
98 | self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
99 | self.noise_prev = noise
100 | return noise
101 |
102 | def reset(self):
103 | """
104 | reset the Ornstein Uhlenbeck noise, to the initial position
105 | """
106 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
107 |
108 | def __repr__(self):
109 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma)
110 |
--------------------------------------------------------------------------------
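
A small sketch of the two action-noise classes above, for a 2-dimensional continuous action; the means, scales and clipping range are illustrative.

```python
import numpy as np

from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

n_actions = 2
normal_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
ou_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

deterministic_action = np.zeros(n_actions)
noisy_action = np.clip(deterministic_action + normal_noise(), -1.0, 1.0)

ou_noise.reset()  # the OU noise is stateful; reset() is typically called at the end of an episode
print(noisy_action, ou_noise())
```
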
/stable_baselines/common/policies.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 |
3 | import tensorflow as tf
4 | import tensorflow.keras.layers as layers
5 | from tensorflow.keras import Model
6 | from tensorflow.keras.models import Sequential
7 |
8 |
9 | class BasePolicy(Model):
10 | """
11 | The base policy object
12 |
13 | :param observation_space: (Gym Space) The observation space of the environment
14 | :param action_space: (Gym Space) The action space of the environment
15 | """
16 |
17 | def __init__(self, observation_space, action_space):
18 | super(BasePolicy, self).__init__()
19 | self.observation_space = observation_space
20 | self.action_space = action_space
21 |
22 | def save(self, path):
23 | """
24 | Save model to a given location.
25 |
26 | :param path: (str)
27 | """
28 | raise NotImplementedError()
29 |
30 | def load(self, path):
31 | """
32 | Load saved model from path.
33 |
34 | :param path: (str)
35 | """
36 | raise NotImplementedError()
37 |
38 | @tf.function
39 | def soft_update(self, other_network, tau):
40 | other_variables = other_network.trainable_variables
41 | current_variables = self.trainable_variables
42 |
43 | for (current_var, other_var) in zip(current_variables, other_variables):
44 | current_var.assign((1. - tau) * current_var + tau * other_var)
45 |
46 | def hard_update(self, other_network):
47 | self.soft_update(other_network, tau=1.)
48 |
49 | def call(self, x):
50 | raise NotImplementedError()
51 |
52 |
53 | def create_mlp(input_dim, output_dim, net_arch,
54 | activation_fn=tf.nn.relu, squash_out=False):
55 | """
56 | Create a multi layer perceptron (MLP), which is
57 | a collection of fully-connected layers each followed by an activation function.
58 |
59 | :param input_dim: (int) Dimension of the input vector
60 | :param output_dim: (int)
61 | :param net_arch: ([int]) Architecture of the neural net
62 | It represents the number of units per layer.
63 | The length of this list is the number of layers.
64 | :param activation_fn: (tf.activations or str) The activation function
65 | to use after each layer.
66 | :param squash_out: (bool) Whether to squash the output using a Tanh
67 | activation function
68 | """
69 | modules = [layers.Flatten(input_shape=(input_dim,), dtype=tf.float32)]
70 |
71 | if len(net_arch) > 0:
72 | modules.append(layers.Dense(net_arch[0], activation=activation_fn))
73 |
74 | for idx in range(len(net_arch) - 1):
75 | modules.append(layers.Dense(net_arch[idx + 1], activation=activation_fn))
76 |
77 | if output_dim > 0:
78 | modules.append(layers.Dense(output_dim, activation=None))
79 | if squash_out:
80 | modules.append(layers.Activation(activation='tanh'))
81 | return modules
82 |
83 |
84 | _policy_registry = dict()
85 |
86 |
87 | def get_policy_from_name(base_policy_type, name):
88 | """
89 |     returns the registered policy from the base type and name
90 |
91 | :param base_policy_type: (BasePolicy) the base policy object
92 | :param name: (str) the policy name
93 | :return: (base_policy_type) the policy
94 | """
95 | if base_policy_type not in _policy_registry:
96 | raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type))
97 | if name not in _policy_registry[base_policy_type]:
98 |         raise ValueError("Error: unknown policy type {}, the only registered policy types are: {}!"
99 | .format(name, list(_policy_registry[base_policy_type].keys())))
100 | return _policy_registry[base_policy_type][name]
101 |
119 |
120 | def register_policy(name, policy):
121 | """
122 |     Register a policy, so that it can be retrieved by name.
123 |
124 | :param name: (str) the policy name
125 | :param policy: (subclass of BasePolicy) the policy
126 | """
127 | sub_class = None
128 | for cls in BasePolicy.__subclasses__():
129 | if issubclass(policy, cls):
130 | sub_class = cls
131 | break
132 | if sub_class is None:
133 | raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))
134 |
135 | if sub_class not in _policy_registry:
136 | _policy_registry[sub_class] = {}
137 | if name in _policy_registry[sub_class]:
138 |         raise ValueError("Error: the name {} is already registered for a different policy, will not override."
139 | .format(name))
140 | _policy_registry[sub_class][name] = policy
141 |
142 |
143 | class MlpExtractor(Model):
144 | """
145 | Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and
146 |     a value network. The ``net_arch`` parameter allows you to specify the number and size of the hidden layers and how many
147 | of them are shared between the policy network and the value network. It is assumed to be a list with the following
148 | structure:
149 |
150 | 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
151 | If the number of ints is zero, there will be no shared layers.
152 | 2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
153 | It is formatted like ``dict(vf=[], pi=[])``.
154 |        If it is missing any of the keys (pi or vf), zero non-shared layers (empty list) are assumed for that network.
155 |
156 |     For example, to construct a network with one shared layer of size 55 followed by two non-shared layers of size 255 for the value
157 |     network and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
158 | would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
159 | would be specified as [128, 128].
160 |
161 |
162 | :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN)
163 | :param net_arch: ([int or dict]) The specification of the policy and value networks.
164 | See above for details on its formatting.
165 | :param activation_fn: (tf.nn.activation) The activation function to use for the networks.
166 | """
167 | def __init__(self, feature_dim, net_arch, activation_fn):
168 | super(MlpExtractor, self).__init__()
169 |
170 | shared_net, policy_net, value_net = [], [], []
171 | policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
172 | value_only_layers = [] # Layer sizes of the network that only belongs to the value network
173 | last_layer_dim_shared = feature_dim
174 |
175 | # Iterate through the shared layers and build the shared parts of the network
176 | for idx, layer in enumerate(net_arch):
177 | if isinstance(layer, int): # Check that this is a shared layer
178 | layer_size = layer
179 | # TODO: give layer a meaningful name
180 | # shared_net.append(layers.Dense(layer_size, input_shape=(last_layer_dim_shared,), activation=activation_fn))
181 | shared_net.append(layers.Dense(layer_size, activation=activation_fn))
182 | last_layer_dim_shared = layer_size
183 | else:
184 | assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
185 | if 'pi' in layer:
186 | assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
187 | policy_only_layers = layer['pi']
188 |
189 | if 'vf' in layer:
190 | assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
191 | value_only_layers = layer['vf']
192 | break # From here on the network splits up in policy and value network
193 |
194 | last_layer_dim_pi = last_layer_dim_shared
195 | last_layer_dim_vf = last_layer_dim_shared
196 |
197 | # Build the non-shared part of the network
198 | for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)):
199 | if pi_layer_size is not None:
200 | assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
201 | policy_net.append(layers.Dense(pi_layer_size, input_shape=(last_layer_dim_pi,), activation=activation_fn))
202 | last_layer_dim_pi = pi_layer_size
203 |
204 | if vf_layer_size is not None:
205 | assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
206 | value_net.append(layers.Dense(vf_layer_size, input_shape=(last_layer_dim_vf,), activation=activation_fn))
207 | last_layer_dim_vf = vf_layer_size
208 |
209 | # Save dim, used to create the distributions
210 | self.latent_dim_pi = last_layer_dim_pi
211 | self.latent_dim_vf = last_layer_dim_vf
212 |
213 | # Create networks
214 | # If the list of layers is empty, the network will just act as an Identity module
215 | self.shared_net = Sequential(shared_net)
216 | self.policy_net = Sequential(policy_net)
217 | self.value_net = Sequential(value_net)
218 |
219 | def call(self, features):
220 | """
221 | :return: (tf.Tensor, tf.Tensor) latent_policy, latent_value of the specified network.
222 | If all layers are shared, then ``latent_policy == latent_value``
223 | """
224 | shared_latent = self.shared_net(features)
225 | return self.policy_net(shared_latent), self.value_net(shared_latent)
226 |
--------------------------------------------------------------------------------
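
A sketch of the `net_arch` format documented in `MlpExtractor`: one shared layer of size 55, then value layers `[255, 255]` and a policy layer `[128]`; the feature dimension is arbitrary.

```python
import tensorflow as tf

from stable_baselines.common.policies import MlpExtractor

extractor = MlpExtractor(feature_dim=16,
                         net_arch=[55, dict(vf=[255, 255], pi=[128])],
                         activation_fn=tf.nn.relu)

features = tf.zeros((1, 16))
latent_pi, latent_vf = extractor(features)
print(latent_pi.shape, latent_vf.shape)                  # (1, 128) and (1, 255)
print(extractor.latent_dim_pi, extractor.latent_dim_vf)  # 128 and 255, used to size the action/value heads
```
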
/stable_baselines/common/running_mean_std.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RunningMeanStd(object):
5 | def __init__(self, epsilon=1e-4, shape=()):
6 | """
7 |         calculates the running mean and std of a data stream
8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
9 |
10 | :param epsilon: (float) helps with arithmetic issues
11 | :param shape: (tuple) the shape of the data stream's output
12 | """
13 | self.mean = np.zeros(shape, 'float64')
14 | self.var = np.ones(shape, 'float64')
15 | self.count = epsilon
16 |
17 | def update(self, arr):
18 | batch_mean = np.mean(arr, axis=0)
19 | batch_var = np.var(arr, axis=0)
20 | batch_count = arr.shape[0]
21 | self.update_from_moments(batch_mean, batch_var, batch_count)
22 |
23 | def update_from_moments(self, batch_mean, batch_var, batch_count):
24 | delta = batch_mean - self.mean
25 | tot_count = self.count + batch_count
26 |
27 | new_mean = self.mean + delta * batch_count / tot_count
28 | m_a = self.var * self.count
29 | m_b = batch_var * batch_count
30 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
31 | new_var = m_2 / (self.count + batch_count)
32 |
33 | new_count = batch_count + self.count
34 |
35 | self.mean = new_mean
36 | self.var = new_var
37 | self.count = new_count
38 |
--------------------------------------------------------------------------------
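As a rough illustration of the parallel-variance update above, here is a minimal sketch (not from the repository) that feeds batches into `RunningMeanStd`; the batch size and distribution parameters are arbitrary.

```python
import numpy as np

from stable_baselines.common.running_mean_std import RunningMeanStd

rms = RunningMeanStd(shape=(3,))
for _ in range(100):
    # Arbitrary synthetic data: mean ~5, std ~2 per component
    batch = 5.0 + 2.0 * np.random.randn(32, 3)
    rms.update(batch)

print(rms.mean)  # approximately [5, 5, 5]
print(rms.var)   # approximately [4, 4, 4]
```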
/stable_baselines/common/save_util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import base64
3 | import pickle
4 | import cloudpickle
5 |
6 |
7 | def is_json_serializable(item):
8 | """
9 | Test if an object is serializable into JSON
10 |
11 | :param item: (object) The object to be tested for JSON serialization.
12 | :return: (bool) True if the object is JSON serializable, False otherwise.
13 | """
14 | # Use a try-except block to test whether json.dumps succeeds.
15 | json_serializable = True
16 | try:
17 | _ = json.dumps(item)
18 | except TypeError:
19 | json_serializable = False
20 | return json_serializable
21 |
22 |
23 | def data_to_json(data):
24 | """
25 | Turn data (class parameters) into a JSON string for storing
26 |
27 | :param data: (Dict) Dictionary of class parameters to be
28 | stored. Items that are not JSON serializable will be
29 | pickled with Cloudpickle and stored as bytearray in
30 | the JSON file
31 | :return: (str) JSON string of the data serialized.
32 | """
33 | # First, check what elements can not be JSONfied,
34 | # and turn them into byte-strings
35 | serializable_data = {}
36 | for data_key, data_item in data.items():
37 | # See if object is JSON serializable
38 | if is_json_serializable(data_item):
39 | # All good, store as it is
40 | serializable_data[data_key] = data_item
41 | else:
42 | # Not serializable, cloudpickle it into
43 | # bytes and convert to base64 string for storing.
44 | # Also store type of the class for consumption
45 | # from other languages/humans, so we have an
46 | # idea what was being stored.
47 | base64_encoded = base64.b64encode(
48 | cloudpickle.dumps(data_item)
49 | ).decode()
50 |
51 | # Use ":" to make sure we do
52 | # not override these keys
53 | # when we include variables of the object later
54 | cloudpickle_serialization = {
55 | ":type:": str(type(data_item)),
56 | ":serialized:": base64_encoded
57 | }
58 |
59 | # Add first-level JSON-serializable items of the
60 | # object for further details (but not deeper than this to
61 | # avoid deep nesting).
62 | # First we check that object has attributes (not all do,
63 | # e.g. numpy scalars)
64 | if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
65 | # Take elements from __dict__ for custom classes
66 | item_generator = (
67 | data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
68 | )
69 | for variable_name, variable_item in item_generator():
70 | # Check if serializable. If not, just include the
71 | # string-representation of the object.
72 | if is_json_serializable(variable_item):
73 | cloudpickle_serialization[variable_name] = variable_item
74 | else:
75 | cloudpickle_serialization[variable_name] = str(variable_item)
76 |
77 | serializable_data[data_key] = cloudpickle_serialization
78 | json_string = json.dumps(serializable_data, indent=4)
79 | return json_string
80 |
81 |
82 | def json_to_data(json_string, custom_objects=None):
83 | """
84 | Turn JSON serialization of class-parameters back into dictionary.
85 |
86 | :param json_string: (str) JSON serialization of the class-parameters
87 | that should be loaded.
88 | :param custom_objects: (dict) Dictionary of objects to replace
89 | upon loading. If a variable is present in this dictionary as a
90 | key, it will not be deserialized and the corresponding item
91 | will be used instead. Similar to custom_objects in
92 | `keras.models.load_model`. Useful when you have an object in
93 | file that can not be deserialized.
94 | :return: (dict) Loaded class parameters.
95 | """
96 | if custom_objects is not None and not isinstance(custom_objects, dict):
97 | raise ValueError("custom_objects argument must be a dict or None")
98 |
99 | json_dict = json.loads(json_string)
100 | # This will be filled with deserialized data
101 | return_data = {}
102 | for data_key, data_item in json_dict.items():
103 | if custom_objects is not None and data_key in custom_objects.keys():
104 | # If item is provided in custom_objects, replace
105 | # the one from JSON with the one in custom_objects
106 | return_data[data_key] = custom_objects[data_key]
107 | elif isinstance(data_item, dict) and ":serialized:" in data_item.keys():
108 | # If item is dictionary with ":serialized:"
109 | # key, this means it is serialized with cloudpickle.
110 | serialization = data_item[":serialized:"]
111 | # Try-except deserialization in case we run into
112 | # errors. If so, we can give the user a bit more
113 | # information.
114 | try:
115 | deserialized_object = cloudpickle.loads(
116 | base64.b64decode(serialization.encode())
117 | )
118 | except pickle.UnpicklingError:
119 | raise RuntimeError(
120 | "Could not deserialize object {}. ".format(data_key) +
121 | "Consider using `custom_objects` argument to replace " +
122 | "this object."
123 | )
124 | return_data[data_key] = deserialized_object
125 | else:
126 | # Read as it is
127 | return_data[data_key] = data_item
128 | return return_data
129 |
--------------------------------------------------------------------------------
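A small round-trip sketch of the two helpers above (it uses nothing beyond them): non-JSON-serializable entries are cloudpickled and base64-encoded, and `custom_objects` can replace an entry at load time.

```python
from stable_baselines.common.save_util import data_to_json, json_to_data

data = {
    "gamma": 0.99,                 # JSON serializable, stored as-is
    "clip_range": lambda _: 0.2,   # not serializable: cloudpickled + base64-encoded
}
json_string = data_to_json(data)

loaded = json_to_data(json_string)
print(loaded["gamma"], loaded["clip_range"](1.0))  # 0.99 0.2

# If a stored object cannot be deserialized (e.g. its code changed),
# it can be substituted at load time instead:
loaded = json_to_data(json_string, custom_objects={"clip_range": lambda _: 0.1})
```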
/stable_baselines/common/tile_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def tile_images(img_nhwc):
5 | """
6 | Tile N images into one big PxQ image.
7 | (P, Q) are chosen to be as close as possible, and if N
8 | is a perfect square, then P = Q.
9 |
10 | :param img_nhwc: (list) list or array of images, ndim=4 once turned into an array, laid out as
11 | n = batch index, h = height, w = width, c = channel
12 | :return: (np.ndarray) the tiled image img_HWc, ndim=3
13 | """
14 | img_nhwc = np.asarray(img_nhwc)
15 | n_images, height, width, n_channels = img_nhwc.shape
16 | # new_height was named H before
17 | new_height = int(np.ceil(np.sqrt(n_images)))
18 | # new_width was named W before
19 | new_width = int(np.ceil(float(n_images) / new_height))
20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
21 | # img_HWhwc
22 | out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
23 | # img_HhWwc
24 | out_image = out_image.transpose(0, 2, 1, 3, 4)
25 | # img_Hh_Ww_c
26 | out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
27 | return out_image
28 |
29 |
--------------------------------------------------------------------------------
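For instance (a minimal sketch with arbitrary frame sizes), five 64x64 RGB frames are padded with one blank frame and tiled into a 3x2 grid:

```python
import numpy as np

from stable_baselines.common.tile_images import tile_images

frames = np.random.randint(0, 255, size=(5, 64, 64, 3), dtype=np.uint8)
big_image = tile_images(frames)
print(big_image.shape)  # (192, 128, 3): a 3x2 grid of 64x64 tiles
```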
/stable_baselines/common/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 | def set_random_seed(seed):
8 | """
9 | Seed the Python, NumPy and TensorFlow random generators
10 |
11 | :param seed: (int)
12 | """
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | tf.random.set_seed(seed)
16 |
17 |
18 | def explained_variance(y_pred, y_true):
19 | """
20 | Computes fraction of variance that ypred explains about y.
21 | Returns 1 - Var[y-ypred] / Var[y]
22 |
23 | interpretation:
24 | ev=0 => might as well have predicted zero
25 | ev=1 => perfect prediction
26 | ev<0 => worse than just predicting zero
27 |
28 | :param y_pred: (np.ndarray) the prediction
29 | :param y_true: (np.ndarray) the expected value
30 | :return: (float) explained variance of ypred and y
31 | """
32 | assert y_true.ndim == 1 and y_pred.ndim == 1
33 | var_y = np.var(y_true)
34 | return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
35 |
36 |
37 | def get_schedule_fn(value_schedule):
38 | """
39 | Transform (if needed) learning rate and clip range (for PPO)
40 | to callable.
41 |
42 | :param value_schedule: (callable or float)
43 | :return: (function)
44 | """
45 | # If the passed schedule is a float
46 | # create a constant function
47 | if isinstance(value_schedule, (float, int)):
48 | # Cast to float to avoid errors
49 | value_schedule = constant_fn(float(value_schedule))
50 | else:
51 | assert callable(value_schedule)
52 | return value_schedule
53 |
54 |
55 | def constant_fn(val):
56 | """
57 | Create a function that returns a constant
58 | It is useful for learning rate schedule (to avoid code duplication)
59 |
60 | :param val: (float)
61 | :return: (function)
62 | """
63 |
64 | def func(_):
65 | return val
66 |
67 | return func
68 |
--------------------------------------------------------------------------------
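A short sketch (values are arbitrary) of the two helpers that are easiest to misread: `get_schedule_fn` wraps a constant into a schedule and passes callables through, while `explained_variance` compares predictions against targets.

```python
import numpy as np

from stable_baselines.common.utils import get_schedule_fn, explained_variance

lr_schedule = get_schedule_fn(3e-4)        # a float becomes a constant schedule
print(lr_schedule(0.5))                    # 0.0003, whatever the progress argument

decay = get_schedule_fn(lambda progress: 3e-4 * progress)
print(decay(0.5))                          # 0.00015

y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.1, 1.9, 3.2])
print(explained_variance(y_pred, y_true))  # close to 1: most of the variance is explained
```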
/stable_baselines/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | # flake8: noqa F401
4 | from stable_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \
5 | CloudpickleWrapper
6 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
7 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
8 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
9 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize
10 | from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
11 | from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan
12 |
13 |
14 |
15 | def unwrap_vec_normalize(env):
16 | """
17 | :param env: (gym.Env)
18 | :return: (VecNormalize)
19 | """
20 | env_tmp = env
21 | while isinstance(env_tmp, VecEnvWrapper):
22 | if isinstance(env_tmp, VecNormalize):
23 | return env_tmp
24 | env_tmp = env_tmp.venv
25 | return None
26 |
27 |
28 | # Define here to avoid circular import
29 | def sync_envs_normalization(env, eval_env):
30 | """
31 | Sync eval env and train env when using VecNormalize
32 |
33 | :param env: (gym.Env)
34 | :param eval_env: (gym.Env)
35 | """
36 | env_tmp, eval_env_tmp = env, eval_env
37 | while isinstance(env_tmp, VecEnvWrapper):
38 | if isinstance(env_tmp, VecNormalize):
39 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
40 | env_tmp = env_tmp.venv
41 | eval_env_tmp = eval_env_tmp.venv
42 |
--------------------------------------------------------------------------------
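A hedged sketch of how these two helpers are typically used around evaluation (assuming gym and the `Pendulum-v0` environment are available; the environment id is only an example):

```python
import gym

from stable_baselines.common.vec_env import (DummyVecEnv, VecNormalize,
                                              sync_envs_normalization,
                                              unwrap_vec_normalize)

env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]))
eval_env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]), training=False)

# Copy the observation statistics from the training env to the eval env
sync_envs_normalization(env, eval_env)
assert unwrap_vec_normalize(eval_env) is eval_env
```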
/stable_baselines/common/vec_env/base_vec_env.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import inspect
3 | import pickle
4 |
5 | import cloudpickle
6 | from stable_baselines.common import logger
7 |
8 |
9 | class AlreadySteppingError(Exception):
10 | """
11 | Raised when an asynchronous step is running while
12 | step_async() is called again.
13 | """
14 |
15 | def __init__(self):
16 | msg = 'already running an async step'
17 | Exception.__init__(self, msg)
18 |
19 |
20 | class NotSteppingError(Exception):
21 | """
22 | Raised when an asynchronous step is not running but
23 | step_wait() is called.
24 | """
25 |
26 | def __init__(self):
27 | msg = 'not running an async step'
28 | Exception.__init__(self, msg)
29 |
30 |
31 | class VecEnv(ABC):
32 | """
33 | An abstract asynchronous, vectorized environment.
34 |
35 | :param num_envs: (int) the number of environments
36 | :param observation_space: (Gym Space) the observation space
37 | :param action_space: (Gym Space) the action space
38 | """
39 | metadata = {
40 | 'render.modes': ['human', 'rgb_array']
41 | }
42 |
43 | def __init__(self, num_envs, observation_space, action_space):
44 | self.num_envs = num_envs
45 | self.observation_space = observation_space
46 | self.action_space = action_space
47 |
48 | @abstractmethod
49 | def reset(self):
50 | """
51 | Reset all the environments and return an array of
52 | observations, or a tuple of observation arrays.
53 |
54 | If step_async is still doing work, that work will
55 | be cancelled and step_wait() should not be called
56 | until step_async() is invoked again.
57 |
58 | :return: ([int] or [float]) observation
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def step_async(self, actions):
64 | """
65 | Tell all the environments to start taking a step
66 | with the given actions.
67 | Call step_wait() to get the results of the step.
68 |
69 | You should not call this if a step_async run is
70 | already pending.
71 | """
72 | pass
73 |
74 | @abstractmethod
75 | def step_wait(self):
76 | """
77 | Wait for the step taken with step_async().
78 |
79 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
80 | """
81 | pass
82 |
83 | @abstractmethod
84 | def close(self):
85 | """
86 | Clean up the environment's resources.
87 | """
88 | pass
89 |
90 | @abstractmethod
91 | def get_attr(self, attr_name, indices=None):
92 | """
93 | Return attribute from vectorized environment.
94 |
95 | :param attr_name: (str) The name of the attribute whose value to return
96 | :param indices: (list,int) Indices of envs to get attribute from
97 | :return: (list) List of values of 'attr_name' in all environments
98 | """
99 | pass
100 |
101 | @abstractmethod
102 | def set_attr(self, attr_name, value, indices=None):
103 | """
104 | Set attribute inside vectorized environments.
105 |
106 | :param attr_name: (str) The name of attribute to assign new value
107 | :param value: (obj) Value to assign to `attr_name`
108 | :param indices: (list,int) Indices of envs to assign value
109 | :return: (NoneType)
110 | """
111 | pass
112 |
113 | @abstractmethod
114 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
115 | """
116 | Call instance methods of vectorized environments.
117 |
118 | :param method_name: (str) The name of the environment method to invoke.
119 | :param indices: (list,int) Indices of envs whose method to call
120 | :param method_args: (tuple) Any positional arguments to provide in the call
121 | :param method_kwargs: (dict) Any keyword arguments to provide in the call
122 | :return: (list) List of items returned by the environment's method call
123 | """
124 | pass
125 |
126 | def step(self, actions):
127 | """
128 | Step the environments with the given action
129 |
130 | :param actions: ([int] or [float]) the action
131 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
132 | """
133 | self.step_async(actions)
134 | return self.step_wait()
135 |
136 | def get_images(self):
137 | """
138 | Return RGB images from each environment
139 | """
140 | raise NotImplementedError
141 |
142 | def render(self, *args, **kwargs):
143 | """
144 | Gym environment rendering
145 |
146 | :param mode: (str) the rendering type
147 | """
148 | logger.warn('Render not defined for %s' % self)
149 |
150 | @property
151 | def unwrapped(self):
152 | if isinstance(self, VecEnvWrapper):
153 | return self.venv.unwrapped
154 | else:
155 | return self
156 |
157 | def getattr_depth_check(self, name, already_found):
158 | """Check if an attribute reference is being hidden in a recursive call to __getattr__
159 |
160 | :param name: (str) name of attribute to check for
161 | :param already_found: (bool) whether this attribute has already been found in a wrapper
162 | :return: (str or None) name of module whose attribute is being shadowed, if any.
163 | """
164 | if hasattr(self, name) and already_found:
165 | return "{0}.{1}".format(type(self).__module__, type(self).__name__)
166 | else:
167 | return None
168 |
169 | def _get_indices(self, indices):
170 | """
171 | Convert a flexibly-typed reference to environment indices to an implied list of indices.
172 |
173 | :param indices: (None,int,Iterable) refers to indices of envs.
174 | :return: (list) the implied list of indices.
175 | """
176 | if indices is None:
177 | indices = range(self.num_envs)
178 | elif isinstance(indices, int):
179 | indices = [indices]
180 | return indices
181 |
182 |
183 | class VecEnvWrapper(VecEnv):
184 | """
185 | Vectorized environment base class
186 |
187 | :param venv: (VecEnv) the vectorized environment to wrap
188 | :param observation_space: (Gym Space) the observation space (can be None to load from venv)
189 | :param action_space: (Gym Space) the action space (can be None to load from venv)
190 | """
191 |
192 | def __init__(self, venv, observation_space=None, action_space=None):
193 | self.venv = venv
194 | VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
195 | action_space=action_space or venv.action_space)
196 | self.class_attributes = dict(inspect.getmembers(self.__class__))
197 |
198 | def step_async(self, actions):
199 | self.venv.step_async(actions)
200 |
201 | @abstractmethod
202 | def reset(self):
203 | pass
204 |
205 | @abstractmethod
206 | def step_wait(self):
207 | pass
208 |
209 | def close(self):
210 | return self.venv.close()
211 |
212 | def render(self, *args, **kwargs):
213 | return self.venv.render(*args, **kwargs)
214 |
215 | def get_images(self):
216 | return self.venv.get_images()
217 |
218 | def get_attr(self, attr_name, indices=None):
219 | return self.venv.get_attr(attr_name, indices)
220 |
221 | def set_attr(self, attr_name, value, indices=None):
222 | return self.venv.set_attr(attr_name, value, indices)
223 |
224 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
225 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
226 |
227 | def __getattr__(self, name):
228 | """Find attribute from wrapped venv(s) if this wrapper does not have it.
229 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers
230 | which have unique attributes of interest.
231 | """
232 | blocked_class = self.getattr_depth_check(name, already_found=False)
233 | if blocked_class is not None:
234 | own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
235 | format_str = ("Error: Recursive attribute lookup for {0} from {1} is "
236 | "ambiguous and hides attribute from {2}")
237 | raise AttributeError(format_str.format(name, own_class, blocked_class))
238 |
239 | return self.getattr_recursive(name)
240 |
241 | def _get_all_attributes(self):
242 | """Get all (inherited) instance and class attributes
243 |
244 | :return: (dict) all_attributes
245 | """
246 | all_attributes = self.__dict__.copy()
247 | all_attributes.update(self.class_attributes)
248 | return all_attributes
249 |
250 | def getattr_recursive(self, name):
251 | """Recursively check wrappers to find attribute.
252 |
253 | :param name: (str) name of attribute to look for
254 | :return: (object) attribute
255 | """
256 | all_attributes = self._get_all_attributes()
257 | if name in all_attributes: # attribute is present in this wrapper
258 | attr = getattr(self, name)
259 | elif hasattr(self.venv, 'getattr_recursive'):
260 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
261 | # to avoid a duplicate call to getattr_depth_check.
262 | attr = self.venv.getattr_recursive(name)
263 | else: # attribute not present, child is an unwrapped VecEnv
264 | attr = getattr(self.venv, name)
265 |
266 | return attr
267 |
268 | def getattr_depth_check(self, name, already_found):
269 | """See base class.
270 |
271 | :return: (str or None) name of module whose attribute is being shadowed, if any.
272 | """
273 | all_attributes = self._get_all_attributes()
274 | if name in all_attributes and already_found:
275 | # this venv's attribute is being hidden because of a higher venv.
276 | shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
277 | elif name in all_attributes and not already_found:
278 | # we have found the first reference to the attribute. Now check for duplicates.
279 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
280 | else:
281 | # this wrapper does not have the attribute. Keep searching.
282 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
283 |
284 | return shadowed_wrapper_class
285 |
286 |
287 | class CloudpickleWrapper(object):
288 | def __init__(self, var):
289 | """
290 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
291 |
292 | :param var: (Any) the variable you wish to wrap for pickling with cloudpickle
293 | """
294 | self.var = var
295 |
296 | def __getstate__(self):
297 | return cloudpickle.dumps(self.var)
298 |
299 | def __setstate__(self, obs):
300 | self.var = pickle.loads(obs)
301 |
--------------------------------------------------------------------------------
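Since `VecEnvWrapper` only requires `reset()` and `step_wait()`, a custom vectorized wrapper is short to write. Below is a minimal sketch; the `VecRewardScale` name and its reward-scaling behaviour are illustrative, not part of the library, and it assumes gym and `CartPole-v1` are available.

```python
import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv, VecEnvWrapper


class VecRewardScale(VecEnvWrapper):
    """Hypothetical wrapper that rescales rewards by a constant factor."""

    def __init__(self, venv, scale=0.1):
        super(VecRewardScale, self).__init__(venv)
        self.scale = scale

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs, rewards * self.scale, dones, infos


env = VecRewardScale(DummyVecEnv([lambda: gym.make('CartPole-v1')]))
obs = env.reset()
obs, rewards, dones, infos = env.step(np.array([env.action_space.sample()]))
```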
/stable_baselines/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import numpy as np
3 |
4 | from stable_baselines.common.vec_env.base_vec_env import VecEnv
5 | from stable_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
6 |
7 |
8 | class DummyVecEnv(VecEnv):
9 | """
10 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
11 | Python process. This is useful for computationally simple environments such as ``CartPole-v1``, as the overhead of
12 | multiprocessing or multithreading outweighs the environment computation time. This can also be used for RL methods that
13 | require a vectorized environment, but with which you want to train on a single environment.
14 |
15 | :param env_fns: ([callable]) A list of functions that will create the environments
16 | (each callable returns a `Gym.Env` instance when called).
17 | """
18 |
19 | def __init__(self, env_fns):
20 | self.envs = [fn() for fn in env_fns]
21 | env = self.envs[0]
22 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
23 | obs_space = env.observation_space
24 | self.keys, shapes, dtypes = obs_space_info(obs_space)
25 |
26 | self.buf_obs = OrderedDict([
27 | (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]))
28 | for k in self.keys])
29 | self.buf_dones = np.zeros((self.num_envs,), dtype=bool)  # builtin bool: np.bool is deprecated
30 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
31 | self.buf_infos = [{} for _ in range(self.num_envs)]
32 | self.actions = None
33 | self.metadata = env.metadata
34 |
35 | def step_async(self, actions):
36 | self.actions = actions
37 |
38 | def step_wait(self):
39 | for env_idx in range(self.num_envs):
40 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
41 | self.envs[env_idx].step(self.actions[env_idx])
42 | if self.buf_dones[env_idx]:
43 | # save final observation where user can get it, then reset
44 | self.buf_infos[env_idx]['terminal_observation'] = obs
45 | obs = self.envs[env_idx].reset()
46 | self._save_obs(env_idx, obs)
47 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
48 | self.buf_infos.copy())
49 |
50 | def reset(self):
51 | for env_idx in range(self.num_envs):
52 | obs = self.envs[env_idx].reset()
53 | self._save_obs(env_idx, obs)
54 | return self._obs_from_buf()
55 |
56 | def close(self):
57 | for env in self.envs:
58 | env.close()
59 |
60 | def seed(self, seed, indices=None):
61 | """
62 | :param seed: (int or [int])
63 | :param indices: ([int])
64 | """
65 | indices = self._get_indices(indices)
66 | if not hasattr(seed, '__len__'):  # a single seed was given: broadcast it to all indices
67 | seed = [seed] * len(indices)
68 | assert len(seed) == len(indices)
69 | return [self.envs[i].seed(seed[i]) for i in indices]
70 |
71 | def get_images(self):
72 | return [env.render(mode='rgb_array') for env in self.envs]
73 |
74 | def render(self, *args, **kwargs):
75 | if self.num_envs == 1:
76 | return self.envs[0].render(*args, **kwargs)
77 | else:
78 | return super().render(*args, **kwargs)
79 |
80 | def _save_obs(self, env_idx, obs):
81 | for key in self.keys:
82 | if key is None:
83 | self.buf_obs[key][env_idx] = obs
84 | else:
85 | self.buf_obs[key][env_idx] = obs[key]
86 |
87 | def _obs_from_buf(self):
88 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
89 |
90 | def get_attr(self, attr_name, indices=None):
91 | """Return attribute from vectorized environment (see base class)."""
92 | target_envs = self._get_target_envs(indices)
93 | return [getattr(env_i, attr_name) for env_i in target_envs]
94 |
95 | def set_attr(self, attr_name, value, indices=None):
96 | """Set attribute inside vectorized environments (see base class)."""
97 | target_envs = self._get_target_envs(indices)
98 | for env_i in target_envs:
99 | setattr(env_i, attr_name, value)
100 |
101 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
102 | """Call instance methods of vectorized environments."""
103 | target_envs = self._get_target_envs(indices)
104 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
105 |
106 | def _get_target_envs(self, indices):
107 | indices = self._get_indices(indices)
108 | return [self.envs[i] for i in indices]
109 |
--------------------------------------------------------------------------------
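A minimal usage sketch (assuming gym and `CartPole-v1` are installed): two environments stepped sequentially in the current process.

```python
import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(2)])
obs = env.reset()  # shape (2, 4): one row per environment

actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
obs, rewards, dones, infos = env.step(actions)
env.close()
```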
/stable_baselines/common/vec_env/subproc_vec_env.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from collections import OrderedDict
3 |
4 | import gym
5 | import numpy as np
6 |
7 | from stable_baselines.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper
8 | from stable_baselines.common.tile_images import tile_images
9 |
10 |
11 | def _worker(remote, parent_remote, env_fn_wrapper):
12 | parent_remote.close()
13 | env = env_fn_wrapper.var()
14 | while True:
15 | try:
16 | cmd, data = remote.recv()
17 | if cmd == 'step':
18 | observation, reward, done, info = env.step(data)
19 | if done:
20 | # save final observation where user can get it, then reset
21 | info['terminal_observation'] = observation
22 | observation = env.reset()
23 | remote.send((observation, reward, done, info))
24 | elif cmd == 'reset':
25 | observation = env.reset()
26 | remote.send(observation)
27 | elif cmd == 'render':
28 | remote.send(env.render(*data[0], **data[1]))
29 | elif cmd == 'close':
30 | remote.close()
31 | break
32 | elif cmd == 'get_spaces':
33 | remote.send((env.observation_space, env.action_space))
34 | elif cmd == 'env_method':
35 | method = getattr(env, data[0])
36 | remote.send(method(*data[1], **data[2]))
37 | elif cmd == 'get_attr':
38 | remote.send(getattr(env, data))
39 | elif cmd == 'set_attr':
40 | remote.send(setattr(env, data[0], data[1]))
41 | else:
42 | raise NotImplementedError
43 | except EOFError:
44 | break
45 |
46 |
47 | class SubprocVecEnv(VecEnv):
48 | """
49 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
50 | process, allowing significant speed up when the environment is computationally complex.
51 |
52 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
53 | number of logical cores on your CPU.
54 |
55 | .. warning::
56 |
57 | Only 'forkserver' and 'spawn' start methods are thread-safe,
58 | which is important when TensorFlow sessions or other non thread-safe
59 | libraries are used in the parent (see issue #217). However, compared to
60 | 'fork' they incur a small start-up cost and have restrictions on
61 | global variables. With those methods, users must wrap the code in an
62 | ``if __name__ == "__main__":`` block.
63 | For more information, see the multiprocessing documentation.
64 |
65 | :param env_fns: ([callable]) A list of functions that will create the environments
66 | (each callable returns a `Gym.Env` instance when called).
67 | :param start_method: (str) method used to start the subprocesses.
68 | Must be one of the methods returned by multiprocessing.get_all_start_methods().
69 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
70 | """
71 |
72 | def __init__(self, env_fns, start_method=None):
73 | self.waiting = False
74 | self.closed = False
75 | n_envs = len(env_fns)
76 |
77 | if start_method is None:
78 | # 'fork' is not a thread-safe start method (see issue #217)
79 | # but is more user friendly (it does not require wrapping the code
80 | # in an `if __name__ == "__main__":` block)
81 | forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
82 | start_method = 'forkserver' if forkserver_available else 'spawn'
83 | ctx = multiprocessing.get_context(start_method)
84 |
85 | self.remotes, self.work_remotes = zip(*[ctx.Pipe(duplex=True) for _ in range(n_envs)])
86 | self.processes = []
87 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
88 | args = (work_remote, remote, CloudpickleWrapper(env_fn))
89 | # daemon=True: if the main process crashes, we should not cause things to hang
90 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
91 | process.start()
92 | self.processes.append(process)
93 | work_remote.close()
94 |
95 | self.remotes[0].send(('get_spaces', None))
96 | observation_space, action_space = self.remotes[0].recv()
97 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
98 |
99 | def step_async(self, actions):
100 | for remote, action in zip(self.remotes, actions):
101 | remote.send(('step', action))
102 | self.waiting = True
103 |
104 | def step_wait(self):
105 | results = [remote.recv() for remote in self.remotes]
106 | self.waiting = False
107 | obs, rews, dones, infos = zip(*results)
108 | return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
109 |
110 | def reset(self):
111 | for remote in self.remotes:
112 | remote.send(('reset', None))
113 | obs = [remote.recv() for remote in self.remotes]
114 | return _flatten_obs(obs, self.observation_space)
115 |
116 | def close(self):
117 | if self.closed:
118 | return
119 | if self.waiting:
120 | for remote in self.remotes:
121 | remote.recv()
122 | for remote in self.remotes:
123 | remote.send(('close', None))
124 | for process in self.processes:
125 | process.join()
126 | self.closed = True
127 |
128 | def render(self, mode='human', *args, **kwargs):
129 | for pipe in self.remotes:
130 | # gather images from subprocesses
131 | # `mode` will be taken into account later
132 | pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
133 | imgs = [pipe.recv() for pipe in self.remotes]
134 | # Create a big image by tiling images from subprocesses
135 | bigimg = tile_images(imgs)
136 | if mode == 'human':
137 | import cv2 # pytype:disable=import-error
138 | cv2.imshow('vecenv', bigimg[:, :, ::-1])
139 | cv2.waitKey(1)
140 | elif mode == 'rgb_array':
141 | return bigimg
142 | else:
143 | raise NotImplementedError
144 |
145 | def get_images(self):
146 | for pipe in self.remotes:
147 | pipe.send(('render', ((), {'mode': 'rgb_array'})))  # the worker expects (args, kwargs) for 'render'
148 | imgs = [pipe.recv() for pipe in self.remotes]
149 | return imgs
150 |
151 | def get_attr(self, attr_name, indices=None):
152 | """Return attribute from vectorized environment (see base class)."""
153 | target_remotes = self._get_target_remotes(indices)
154 | for remote in target_remotes:
155 | remote.send(('get_attr', attr_name))
156 | return [remote.recv() for remote in target_remotes]
157 |
158 | def set_attr(self, attr_name, value, indices=None):
159 | """Set attribute inside vectorized environments (see base class)."""
160 | target_remotes = self._get_target_remotes(indices)
161 | for remote in target_remotes:
162 | remote.send(('set_attr', (attr_name, value)))
163 | for remote in target_remotes:
164 | remote.recv()
165 |
166 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
167 | """Call instance methods of vectorized environments."""
168 | target_remotes = self._get_target_remotes(indices)
169 | for remote in target_remotes:
170 | remote.send(('env_method', (method_name, method_args, method_kwargs)))
171 | return [remote.recv() for remote in target_remotes]
172 |
173 | def _get_target_remotes(self, indices):
174 | """
175 | Get the connection object needed to communicate with the wanted
176 | envs that are in subprocesses.
177 |
178 | :param indices: (None,int,Iterable) refers to indices of envs.
179 | :return: ([multiprocessing.Connection]) Connection object to communicate between processes.
180 | """
181 | indices = self._get_indices(indices)
182 | return [self.remotes[i] for i in indices]
183 |
184 |
185 | def _flatten_obs(obs, space):
186 | """
187 | Flatten observations, depending on the observation space.
188 |
189 | :param obs: (list or tuple) observations.
190 | A list or tuple of observations, one per environment.
191 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
192 | :return (OrderedDict, tuple or ndarray) flattened observations.
193 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
194 | Each NumPy array has the environment index as its first axis.
195 | """
196 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
197 | assert len(obs) > 0, "need observations from at least one environment"
198 |
199 | if isinstance(space, gym.spaces.Dict):
200 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
201 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
202 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
203 | elif isinstance(space, gym.spaces.Tuple):
204 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
205 | obs_len = len(space.spaces)
206 | return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
207 | else:
208 | return np.stack(obs)
209 |
--------------------------------------------------------------------------------
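The same pattern with subprocesses; as the docstring warns, the code must live under an `if __name__ == "__main__":` guard when using the 'forkserver' or 'spawn' start methods (sketch assuming gym and `CartPole-v1`):

```python
import gym
import numpy as np

from stable_baselines.common.vec_env import SubprocVecEnv


def make_env():
    return gym.make('CartPole-v1')


if __name__ == '__main__':
    env = SubprocVecEnv([make_env for _ in range(4)], start_method='spawn')
    obs = env.reset()
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)
    env.close()
```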
/stable_baselines/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 |
5 | from collections import OrderedDict
6 |
7 | import gym
8 | import numpy as np
9 |
10 |
11 | def copy_obs_dict(obs):
12 | """
13 | Deep-copy a dict of numpy arrays.
14 |
15 | :param obs: (OrderedDict): a dict of numpy arrays.
16 | :return (OrderedDict) a dict of copied numpy arrays.
17 | """
18 | assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs))
19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
20 |
21 |
22 | def dict_to_obs(space, obs_dict):
23 | """
24 | Convert the internal dict representation obs_dict into an observation
25 | of the type specified by space.
26 |
27 | :param space: (gym.spaces.Space) an observation space.
28 | :param obs_dict: (OrderedDict) a dict of numpy arrays.
29 | :return (ndarray, tuple or dict): returns an observation
30 | of the same type as space. If space is Dict, the function is the identity;
31 | if space is Tuple, it converts the dict to a tuple; otherwise, space is
32 | unstructured and the value obs_dict[None] is returned.
33 | """
34 | if isinstance(space, gym.spaces.Dict):
35 | return obs_dict
36 | elif isinstance(space, gym.spaces.Tuple):
37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
38 | return tuple((obs_dict[i] for i in range(len(space.spaces))))
39 | else:
40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
41 | return obs_dict[None]
42 |
43 |
44 | def obs_space_info(obs_space):
45 | """
46 | Get dict-structured information about a gym.Space.
47 |
48 | Dict spaces are represented directly by their dict of subspaces.
49 | Tuple spaces are converted into a dict with keys indexing into the tuple.
50 | Unstructured spaces are represented by {None: obs_space}.
51 |
52 | :param obs_space: (gym.spaces.Space) an observation space
53 | :return (tuple) A tuple (keys, shapes, dtypes):
54 | keys: a list of dict keys.
55 | shapes: a dict mapping keys to shapes.
56 | dtypes: a dict mapping keys to dtypes.
57 | """
58 | if isinstance(obs_space, gym.spaces.Dict):
59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
60 | subspaces = obs_space.spaces
61 | elif isinstance(obs_space, gym.spaces.Tuple):
62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
63 | else:
64 | assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space))
65 | subspaces = {None: obs_space}
66 | keys = []
67 | shapes = {}
68 | dtypes = {}
69 | for key, box in subspaces.items():
70 | keys.append(key)
71 | shapes[key] = box.shape
72 | dtypes[key] = box.dtype
73 | return keys, shapes, dtypes
74 |
--------------------------------------------------------------------------------
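For example (a minimal sketch; the Box shape is arbitrary), an unstructured space is reported under the `None` key:

```python
import gym

from stable_baselines.common.vec_env.util import obs_space_info

keys, shapes, dtypes = obs_space_info(gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
print(keys)    # [None]
print(shapes)  # {None: (3,)}
print(dtypes)  # {None: dtype('float32')}
```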
/stable_baselines/common/vec_env/vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
6 |
7 |
8 | class VecCheckNan(VecEnvWrapper):
9 | """
10 | NaN and inf checking wrapper for vectorized environments; it raises a warning by default,
11 | allowing you to know where the NaN or inf originated from.
12 |
13 | :param venv: (VecEnv) the vectorized environment to wrap
14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning
15 | :param warn_once: (bool) Whether or not to only warn once.
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well
17 | """
18 |
19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True):
20 | VecEnvWrapper.__init__(self, venv)
21 | self.raise_exception = raise_exception
22 | self.warn_once = warn_once
23 | self.check_inf = check_inf
24 | self._actions = None
25 | self._observations = None
26 | self._user_warned = False
27 |
28 | def step_async(self, actions):
29 | self._check_val(async_step=True, actions=actions)
30 |
31 | self._actions = actions
32 | self.venv.step_async(actions)
33 |
34 | def step_wait(self):
35 | observations, rewards, news, infos = self.venv.step_wait()
36 |
37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
38 |
39 | self._observations = observations
40 | return observations, rewards, news, infos
41 |
42 | def reset(self):
43 | observations = self.venv.reset()
44 | self._actions = None
45 |
46 | self._check_val(async_step=False, observations=observations)
47 |
48 | self._observations = observations
49 | return observations
50 |
51 | def _check_val(self, *, async_step, **kwargs):
52 | # if warn and warn once and have warned once: then stop checking
53 | if not self.raise_exception and self.warn_once and self._user_warned:
54 | return
55 |
56 | found = []
57 | for name, val in kwargs.items():
58 | has_nan = np.any(np.isnan(val))
59 | has_inf = self.check_inf and np.any(np.isinf(val))
60 | if has_inf:
61 | found.append((name, "inf"))
62 | if has_nan:
63 | found.append((name, "nan"))
64 |
65 | if found:
66 | self._user_warned = True
67 | msg = ""
68 | for i, (name, type_val) in enumerate(found):
69 | msg += "found {} in {}".format(type_val, name)
70 | if i != len(found) - 1:
71 | msg += ", "
72 |
73 | msg += ".\r\nOriginated from the "
74 |
75 | if not async_step:
76 | if self._actions is None:
77 | msg += "environment observation (at reset)"
78 | else:
79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions)
80 | else:
81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations)
82 |
83 | if self.raise_exception:
84 | raise ValueError(msg)
85 | else:
86 | warnings.warn(msg, UserWarning)
87 |
--------------------------------------------------------------------------------
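Typical use is to wrap an environment while debugging NaN issues (a sketch assuming gym and `Pendulum-v0`; with `raise_exception=True` the check fails loudly instead of only warning):

```python
import gym

from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
env = VecCheckNan(env, raise_exception=True)  # raise a ValueError on the first NaN/inf
```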
/stable_baselines/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
7 |
8 |
9 | class VecFrameStack(VecEnvWrapper):
10 | """
11 | Frame stacking wrapper for vectorized environment
12 |
13 | :param venv: (VecEnv) the vectorized environment to wrap
14 | :param n_stack: (int) Number of frames to stack
15 | """
16 |
17 | def __init__(self, venv, n_stack):
18 | self.venv = venv
19 | self.n_stack = n_stack
20 | wrapped_obs_space = venv.observation_space
21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
26 |
27 | def step_wait(self):
28 | observations, rewards, dones, infos = self.venv.step_wait()
29 | last_ax_size = observations.shape[-1]
30 | self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
31 | for i, done in enumerate(dones):
32 | if done:
33 | if 'terminal_observation' in infos[i]:
34 | old_terminal = infos[i]['terminal_observation']
35 | new_terminal = np.concatenate(
36 | (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
37 | infos[i]['terminal_observation'] = new_terminal
38 | else:
39 | warnings.warn(
40 | "VecFrameStack wrapping a VecEnv without terminal_observation info")
41 | self.stackedobs[i] = 0
42 | self.stackedobs[..., -observations.shape[-1]:] = observations
43 | return self.stackedobs, rewards, dones, infos
44 |
45 | def reset(self):
46 | """
47 | Reset all environments
48 | """
49 | obs = self.venv.reset()
50 | self.stackedobs[...] = 0
51 | self.stackedobs[..., -obs.shape[-1]:] = obs
52 | return self.stackedobs
53 |
54 | def close(self):
55 | self.venv.close()
56 |
--------------------------------------------------------------------------------
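A quick sketch (assuming gym and `CartPole-v1`): stacking 4 frames multiplies the last observation dimension by 4.

```python
import gym

from stable_baselines.common.vec_env import DummyVecEnv, VecFrameStack

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
env = VecFrameStack(env, n_stack=4)
obs = env.reset()
print(obs.shape)  # (1, 16): 4 stacked copies of the 4-dimensional observation
```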
/stable_baselines/common/vec_env/vec_normalize.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import warnings
3 |
4 | import numpy as np
5 |
6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
7 | from stable_baselines.common.running_mean_std import RunningMeanStd
8 |
9 |
10 | class VecNormalize(VecEnvWrapper):
11 | """
12 | A moving average, normalizing wrapper for vectorized environment.
13 |
14 | It is pickleable which will save moving averages and configuration parameters.
15 | The wrapped environment `venv` is not saved, and must be restored manually with
16 | `set_venv` after being unpickled.
17 |
18 | :param venv: (VecEnv) the vectorized environment to wrap
19 | :param training: (bool) Whether or not to update the moving averages
20 | :param norm_obs: (bool) Whether to normalize observation or not (default: True)
21 | :param norm_reward: (bool) Whether to normalize rewards or not (default: True)
22 | :param clip_obs: (float) Max absolute value for observation
23 | :param clip_reward: (float) Max absolute value for discounted reward
24 | :param gamma: (float) discount factor
25 | :param epsilon: (float) To avoid division by zero
26 | """
27 |
28 | def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
29 | clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8):
30 | VecEnvWrapper.__init__(self, venv)
31 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
32 | self.ret_rms = RunningMeanStd(shape=())
33 | self.clip_obs = clip_obs
34 | self.clip_reward = clip_reward
35 | # Returns: discounted rewards
36 | self.ret = np.zeros(self.num_envs)
37 | self.gamma = gamma
38 | self.epsilon = epsilon
39 | self.training = training
40 | self.norm_obs = norm_obs
41 | self.norm_reward = norm_reward
42 | self.old_obs = None
43 | self.old_rews = None
44 |
45 | def __getstate__(self):
46 | """
47 | Gets state for pickling.
48 |
49 | Excludes self.venv, as in general VecEnvs may not be pickleable."""
50 | state = self.__dict__.copy()
51 | # these attributes are not pickleable
52 | del state['venv']
53 | del state['class_attributes']
54 | # these attributes depend on the above and so we would prefer not to pickle
55 | del state['ret']
56 | return state
57 |
58 | def __setstate__(self, state):
59 | """
60 | Restores pickled state.
61 |
62 | User must call set_venv() after unpickling before using.
63 |
64 | :param state: (dict)"""
65 | self.__dict__.update(state)
66 | assert 'venv' not in state
67 | self.venv = None
68 |
69 | def set_venv(self, venv):
70 | """
71 | Sets the vector environment to wrap to venv.
72 |
73 | Also sets attributes derived from this such as `num_envs`.
74 |
75 | :param venv: (VecEnv)
76 | """
77 | if self.venv is not None:
78 | raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
79 | VecEnvWrapper.__init__(self, venv)
80 | if self.obs_rms.mean.shape != self.observation_space.shape:
81 | raise ValueError("venv is incompatible with current statistics.")
82 | self.ret = np.zeros(self.num_envs)
83 |
84 | def step_wait(self):
85 | """
86 | Apply sequence of actions to sequence of environments
87 | actions -> (observations, rewards, news)
88 |
89 | where 'news' is a boolean vector indicating whether each element is new.
90 | """
91 | obs, rews, news, infos = self.venv.step_wait()
92 | self.old_obs = obs
93 | self.old_rews = rews
94 |
95 | if self.training:
96 | self.obs_rms.update(obs)
97 | obs = self.normalize_obs(obs)
98 |
99 | if self.training:
100 | self._update_reward(rews)
101 | rews = self.normalize_reward(rews)
102 |
103 | self.ret[news] = 0
104 | return obs, rews, news, infos
105 |
106 | def _update_reward(self, reward: np.ndarray) -> None:
107 | """Update reward normalization statistics."""
108 | self.ret = self.ret * self.gamma + reward
109 | self.ret_rms.update(self.ret)
110 |
111 | def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
112 | """
113 | Normalize observations using this VecNormalize's observations statistics.
114 | Calling this method does not update statistics.
115 | """
116 | if self.norm_obs:
117 | obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon),
118 | -self.clip_obs,
119 | self.clip_obs)
120 | return obs
121 |
122 | def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
123 | """
124 | Normalize rewards using this VecNormalize's rewards statistics.
125 | Calling this method does not update statistics.
126 | """
127 | if self.norm_reward:
128 | reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon),
129 | -self.clip_reward, self.clip_reward)
130 | return reward
131 |
132 | def get_original_obs(self) -> np.ndarray:
133 | """
134 | Returns an unnormalized version of the observations from the most recent
135 | step or reset.
136 | """
137 | return self.old_obs.copy()
138 |
139 | def get_original_reward(self) -> np.ndarray:
140 | """
141 | Returns an unnormalized version of the rewards from the most recent step.
142 | """
143 | return self.old_rews.copy()
144 |
145 | def reset(self):
146 | """
147 | Reset all environments
148 | """
149 | obs = self.venv.reset()
150 | self.old_obs = obs
151 | self.ret = np.zeros(self.num_envs)
152 | if self.training:
153 | self._update_reward(self.ret)
154 | return self.normalize_obs(obs)
155 |
156 | @staticmethod
157 | def load(load_path, venv):
158 | """
159 | Loads a saved VecNormalize object.
160 |
161 | :param load_path: the path to load from.
162 | :param venv: the VecEnv to wrap.
163 | :return: (VecNormalize)
164 | """
165 | with open(load_path, "rb") as file_handler:
166 | vec_normalize = pickle.load(file_handler)
167 | vec_normalize.set_venv(venv)
168 | return vec_normalize
169 |
170 | def save(self, save_path):
171 | with open(save_path, "wb") as file_handler:
172 | pickle.dump(self, file_handler)
173 |
174 | def save_running_average(self, path):
175 | """
176 | :param path: (str) path to log dir
177 |
178 | .. deprecated:: 2.9.0
179 | This function will be removed in a future version
180 | """
181 | warnings.warn("Usage of `save_running_average` is deprecated. Please "
182 | "use `save` or pickle instead.", DeprecationWarning)
183 | for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
184 | with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
185 | pickle.dump(rms, file_handler)
186 |
187 | def load_running_average(self, path):
188 | """
189 | :param path: (str) path to log dir
190 |
191 | .. deprecated:: 2.9.0
192 | This function will be removed in a future version
193 | """
194 | warnings.warn("Usage of `load_running_average` is deprecated. Please "
195 | "use `load` or pickle instead.", DeprecationWarning)
196 | for name in ['obs_rms', 'ret_rms']:
197 | with open("{}/{}.pkl".format(path, name), 'rb') as file_handler:
198 | setattr(self, name, pickle.load(file_handler))
199 |
--------------------------------------------------------------------------------
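Because the wrapper is pickleable but does not save the wrapped `venv`, the usual pattern is `save()` after training and `load()` around a fresh env for evaluation (a sketch assuming gym and `Pendulum-v0`; the file name is arbitrary):

```python
import gym

from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.)

# ... train a model with `env` here ...

env.save('vec_normalize.pkl')  # stores running statistics and configuration

# Later, wrap a fresh environment with the saved statistics for evaluation
eval_env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
eval_env = VecNormalize.load('vec_normalize.pkl', eval_env)
eval_env.training = False     # do not update statistics at test time
eval_env.norm_reward = False  # report unnormalized rewards
```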
/stable_baselines/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.wrappers.monitoring import video_recorder
4 |
5 | from stable_baselines.common import logger
6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
7 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
9 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
10 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize
11 |
12 |
13 | class VecVideoRecorder(VecEnvWrapper):
14 | """
15 | Wraps a VecEnv or VecEnvWrapper object to record rendered images as an mp4 video.
16 | It requires ffmpeg or avconv to be installed on the machine.
17 |
18 | :param venv: (VecEnv or VecEnvWrapper)
19 | :param video_folder: (str) Where to save videos
20 | :param record_video_trigger: (func) Function that defines when to start recording.
21 | The function takes the current step number,
22 | and returns whether recording should start or not.
23 | :param video_length: (int) Length of recorded videos
24 | :param name_prefix: (str) Prefix to the video name
25 | """
26 |
27 | def __init__(self, venv, video_folder, record_video_trigger,
28 | video_length=200, name_prefix='rl-video'):
29 |
30 | VecEnvWrapper.__init__(self, venv)
31 |
32 | self.env = venv
33 | # Temp variable to retrieve metadata
34 | temp_env = venv
35 |
36 | # Unwrap to retrieve metadata dict
37 | # that will be used by gym recorder
38 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack):
39 | temp_env = temp_env.venv
40 |
41 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
42 | metadata = temp_env.get_attr('metadata')[0]
43 | else:
44 | metadata = temp_env.metadata
45 |
46 | self.env.metadata = metadata
47 |
48 | self.record_video_trigger = record_video_trigger
49 | self.video_recorder = None
50 |
51 | self.video_folder = os.path.abspath(video_folder)
52 | # Create output folder if needed
53 | os.makedirs(self.video_folder, exist_ok=True)
54 |
55 | self.name_prefix = name_prefix
56 | self.step_id = 0
57 | self.video_length = video_length
58 |
59 | self.recording = False
60 | self.recorded_frames = 0
61 |
62 | def reset(self):
63 | obs = self.venv.reset()
64 | self.start_video_recorder()
65 | return obs
66 |
67 | def start_video_recorder(self):
68 | self.close_video_recorder()
69 |
70 | video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
71 | self.step_id + self.video_length)
72 | base_path = os.path.join(self.video_folder, video_name)
73 | self.video_recorder = video_recorder.VideoRecorder(
74 | env=self.env,
75 | base_path=base_path,
76 | metadata={'step_id': self.step_id}
77 | )
78 |
79 | self.video_recorder.capture_frame()
80 | self.recorded_frames = 1
81 | self.recording = True
82 |
83 | def _video_enabled(self):
84 | return self.record_video_trigger(self.step_id)
85 |
86 | def step_wait(self):
87 | obs, rews, dones, infos = self.venv.step_wait()
88 |
89 | self.step_id += 1
90 | if self.recording:
91 | self.video_recorder.capture_frame()
92 | self.recorded_frames += 1
93 | if self.recorded_frames > self.video_length:
94 | logger.info("Saving video to ", self.video_recorder.path)
95 | self.close_video_recorder()
96 | elif self._video_enabled():
97 | self.start_video_recorder()
98 |
99 | return obs, rews, dones, infos
100 |
101 | def close_video_recorder(self):
102 | if self.recording:
103 | self.video_recorder.close()
104 | self.recording = False
105 | self.recorded_frames = 1
106 |
107 | def close(self):
108 | VecEnvWrapper.close(self)
109 | self.close_video_recorder()
110 |
111 | def __del__(self):
112 | self.close()
113 |
--------------------------------------------------------------------------------
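A short sketch (assuming gym, `CartPole-v1` and ffmpeg are available; the folder name and trigger are arbitrary): record a 200-step video every 1000 steps.

```python
import gym

from stable_baselines.common.vec_env import DummyVecEnv, VecVideoRecorder

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
env = VecVideoRecorder(env, video_folder='videos/',
                       record_video_trigger=lambda step: step % 1000 == 0,
                       video_length=200)
obs = env.reset()  # starts the first recording
```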
/stable_baselines/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.ppo.ppo import PPO
2 | from stable_baselines.ppo.policies import MlpPolicy
3 |
--------------------------------------------------------------------------------
/stable_baselines/ppo/policies.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import tensorflow as tf
4 | import tensorflow.keras.layers as layers
5 | from tensorflow.keras.models import Sequential
6 | import numpy as np
7 |
8 | from stable_baselines.common.policies import BasePolicy, register_policy, MlpExtractor
9 | from stable_baselines.common.distributions import make_proba_distribution,\
10 | DiagGaussianDistribution, CategoricalDistribution
11 |
12 |
13 | class PPOPolicy(BasePolicy):
14 | """
15 | Policy class (with both actor and critic) for A2C and derivative algorithms (e.g. PPO).
16 |
17 | :param observation_space: (gym.spaces.Space) Observation space
18 | :param action_space: (gym.spaces.Space) Action space
19 | :param learning_rate: (callable) Learning rate schedule (could be constant)
20 | :param net_arch: ([int or dict]) The specification of the policy and value networks.
21 | :param activation_fn: (callable) Activation function (e.g. tf.nn.tanh)
22 | :param adam_epsilon: (float) Small value to avoid NaN in the Adam optimizer
23 | :param ortho_init: (bool) Whether or not to use orthogonal initialization
24 | :param log_std_init: (float) Initial value for the log standard deviation
25 | """
26 | def __init__(self, observation_space, action_space,
27 | learning_rate, net_arch=None,
28 | activation_fn=tf.nn.tanh, adam_epsilon=1e-5,
29 | ortho_init=True, log_std_init=0.0):
30 | super(PPOPolicy, self).__init__(observation_space, action_space)
31 | self.obs_dim = self.observation_space.shape[0]
32 |
33 | # Default network architecture, from stable-baselines
34 | if net_arch is None:
35 | net_arch = [dict(pi=[64, 64], vf=[64, 64])]
36 |
37 | self.net_arch = net_arch
38 | self.activation_fn = activation_fn
39 | self.adam_epsilon = adam_epsilon
40 | self.ortho_init = ortho_init
41 | self.net_args = {
42 | 'input_dim': self.obs_dim,
43 | 'output_dim': -1,
44 | 'net_arch': self.net_arch,
45 | 'activation_fn': self.activation_fn
46 | }
47 | self.shared_net = None
48 | self.pi_net, self.vf_net = None, None
49 | # In the future, feature_extractor will be replaced with a CNN
50 | self.features_extractor = Sequential(layers.Flatten(input_shape=(self.obs_dim,), dtype=tf.float32))
51 | self.features_dim = self.obs_dim
52 | self.log_std_init = log_std_init
53 | dist_kwargs = None
54 |
55 | # Action distribution
56 | self.action_dist = make_proba_distribution(action_space, dist_kwargs=dist_kwargs)
57 |
58 | self._build(learning_rate)
59 |
60 | def _build(self, learning_rate):
61 | self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
62 | activation_fn=self.activation_fn)
63 |
64 | latent_dim_pi = self.mlp_extractor.latent_dim_pi
65 |
66 | if isinstance(self.action_dist, DiagGaussianDistribution):
67 | self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi,
68 | log_std_init=self.log_std_init)
69 | elif isinstance(self.action_dist, CategoricalDistribution):
70 | self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
71 |
72 | self.value_net = Sequential(layers.Dense(1, input_shape=(self.mlp_extractor.latent_dim_vf,)))
73 |
74 | self.features_extractor.build()
75 | self.action_net.build()
76 | self.value_net.build()
77 | # Init weights: use orthogonal initialization
78 | # with small initial weight for the output
79 | if self.ortho_init:
80 | pass
81 | # for module in [self.mlp_extractor, self.action_net, self.value_net]:
82 | # # Values from stable-baselines, TODO: check why
83 | # gain = {
84 | # self.mlp_extractor: np.sqrt(2),
85 | # self.action_net: 0.01,
86 | # self.value_net: 1
87 | # }[module]
88 | # module.apply(partial(self.init_weights, gain=gain))
89 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1), epsilon=self.adam_epsilon)
90 |
91 | @tf.function
92 | def call(self, obs, deterministic=False):
93 | latent_pi, latent_vf = self._get_latent(obs)
94 | value = self.value_net(latent_vf)
95 | action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
96 | log_prob = action_distribution.log_prob(action)
97 | return action, value, log_prob
98 |
99 | def _get_latent(self, obs):
100 | features = self.features_extractor(obs)
101 | latent_pi, latent_vf = self.mlp_extractor(features)
102 | return latent_pi, latent_vf
103 |
104 | def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
105 | mean_actions = self.action_net(latent_pi)
106 |
107 | if isinstance(self.action_dist, DiagGaussianDistribution):
108 | return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
109 |
110 | elif isinstance(self.action_dist, CategoricalDistribution):
111 | # Here mean_actions are the logits before the softmax
112 | return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
113 |
114 | def actor_forward(self, obs, deterministic=False):
115 | latent_pi, _ = self._get_latent(obs)
116 | action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
117 | return tf.stop_gradient(action).numpy()
118 |
119 | @tf.function
120 | def evaluate_actions(self, obs, action, deterministic=False):
121 | """
122 | Evaluate actions according to the current policy,
123 | given the observations.
124 |
125 | :param obs: (tf.Tensor)
126 | :param action: (tf.Tensor)
127 | :param deterministic: (bool)
128 | :return: (tf.Tensor, tf.Tensor, tf.Tensor) estimated value, log likelihood of taking those actions
129 | and entropy of the action distribution.
130 | """
131 | latent_pi, latent_vf = self._get_latent(obs)
132 | _, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
133 | log_prob = action_distribution.log_prob(action)
134 | value = self.value_net(latent_vf)
135 | return value, log_prob, action_distribution.entropy()
136 |
137 | def value_forward(self, obs):
138 | _, latent_vf = self._get_latent(obs)
139 | return self.value_net(latent_vf)
140 |
141 |
142 | MlpPolicy = PPOPolicy
143 |
144 | register_policy("MlpPolicy", MlpPolicy)
145 |
--------------------------------------------------------------------------------
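Since PPOPolicy forwards net_arch straight to MlpExtractor, the architecture can be overridden from the algorithm side via policy_kwargs. A hedged sketch; the layer sizes are illustrative and the trailing-dict form mirrors the default [dict(pi=[64, 64], vf=[64, 64])] above:

```python
from stable_baselines import PPO

# Separate 16-unit policy and value heads instead of the 64-64 default
model = PPO('MlpPolicy', 'CartPole-v1',
            policy_kwargs=dict(net_arch=[dict(pi=[16], vf=[16])]),
            verbose=1)
model.learn(total_timesteps=1000)
```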
/stable_baselines/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import gym
5 | from gym import spaces
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | from stable_baselines.common.base_class import BaseRLModel
10 | from stable_baselines.common.buffers import RolloutBuffer
11 | from stable_baselines.common.utils import explained_variance, get_schedule_fn
12 | from stable_baselines.common import logger
13 | from stable_baselines.ppo.policies import PPOPolicy
14 |
15 |
16 | class PPO(BaseRLModel):
17 | """
18 | Proximal Policy Optimization algorithm (PPO) (clip version)
19 |
20 | Paper: https://arxiv.org/abs/1707.06347
21 | Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/),
22 | https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
23 | and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
24 |
25 | Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
26 |
27 | :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
28 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
29 | :param learning_rate: (float or callable) The learning rate, it can be a function
30 | of the current progress (from 1 to 0)
31 | :param n_steps: (int) The number of steps to run for each environment per update
32 | (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
33 | :param batch_size: (int) Minibatch size
34 | :param n_epochs: (int) Number of epochs when optimizing the surrogate loss
35 | :param gamma: (float) Discount factor
36 | :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
37 | :param clip_range: (float or callable) Clipping parameter, it can be a function of the current progress
38 | (from 1 to 0).
39 | :param clip_range_vf: (float or callable) Clipping parameter for the value function,
40 | it can be a function of the current progress (from 1 to 0).
41 | This is a parameter specific to the OpenAI implementation. If None is passed (default),
42 | no clipping will be done on the value function.
43 | IMPORTANT: this clipping depends on the reward scaling.
44 | :param ent_coef: (float) Entropy coefficient for the loss calculation
45 | :param vf_coef: (float) Value function coefficient for the loss calculation
46 | :param max_grad_norm: (float) The maximum value for the gradient clipping
47 | :param target_kl: (float) Limit the KL divergence between updates,
48 | because the clipping is not enough to prevent large updates,
49 | see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213).
50 | By default, there is no limit on the KL divergence.
51 | :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
52 | :param create_eval_env: (bool) Whether to create a second environment that will be
53 | used for evaluating the agent periodically. (Only available when passing string for the environment)
54 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
55 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
56 | :param seed: (int) Seed for the pseudo random generators
57 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
58 | """
59 | def __init__(self, policy, env, learning_rate=3e-4,
60 | n_steps=2048, batch_size=64, n_epochs=10,
61 | gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
62 | ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
63 | target_kl=None, tensorboard_log=None, create_eval_env=False,
64 | policy_kwargs=None, verbose=0, seed=0,
65 | _init_setup_model=True):
66 |
67 | super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
68 | verbose=verbose, create_eval_env=create_eval_env, support_multi_env=True, seed=seed)
69 |
70 | self.learning_rate = learning_rate
71 | self.batch_size = batch_size
72 | self.n_epochs = n_epochs
73 | self.n_steps = n_steps
74 | self.gamma = gamma
75 | self.gae_lambda = gae_lambda
76 | self.clip_range = clip_range
77 | self.clip_range_vf = clip_range_vf
78 | self.ent_coef = ent_coef
79 | self.vf_coef = vf_coef
80 | self.max_grad_norm = max_grad_norm
81 | self.rollout_buffer = None
82 | self.target_kl = target_kl
83 | self.tensorboard_log = tensorboard_log
84 | self.tb_writer = None
85 |
86 | if _init_setup_model:
87 | self._setup_model()
88 |
89 | def _setup_model(self):
90 | self._setup_learning_rate()
91 | # TODO: preprocessing: one-hot vector for discrete observations
92 | state_dim = self.observation_space.shape[0]
93 | if isinstance(self.action_space, spaces.Box):
94 | # Action is a 1D vector
95 | action_dim = self.action_space.shape[0]
96 | elif isinstance(self.action_space, spaces.Discrete):
97 | # Action is a scalar
98 | action_dim = 1
99 |
100 | # TODO: different seed for each env when n_envs > 1
101 | if self.n_envs == 1:
102 | self.set_random_seed(self.seed)
103 |
104 | self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim,
105 | gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
106 | self.policy = self.policy_class(self.observation_space, self.action_space,
107 | self.learning_rate, **self.policy_kwargs)
108 |
109 | self.clip_range = get_schedule_fn(self.clip_range)
110 | if self.clip_range_vf is not None:
111 | self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
112 |
113 | def predict(self, observation, state=None, mask=None, deterministic=False):
114 | """
115 | Get the model's action from an observation
116 |
117 | :param observation: (np.ndarray) the input observation
118 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
119 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
120 | :param deterministic: (bool) Whether or not to return deterministic actions.
121 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
122 | """
123 | clipped_actions = self.policy.actor_forward(np.array(observation).reshape(1, -1), deterministic=deterministic)
124 | if isinstance(self.action_space, gym.spaces.Box):
125 | clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high)
126 | return clipped_actions
127 |
128 | def collect_rollouts(self, env, rollout_buffer, n_rollout_steps=256, callback=None,
129 | obs=None):
130 |
131 | n_steps = 0
132 | rollout_buffer.reset()
133 |
134 | while n_steps < n_rollout_steps:
135 | actions, values, log_probs = self.policy.call(obs)
136 | actions = actions.numpy()
137 |
138 | # Rescale and perform action
139 | clipped_actions = actions
140 | # Clip the actions to avoid out of bounds errors
141 | if isinstance(self.action_space, gym.spaces.Box):
142 | clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
143 | new_obs, rewards, dones, infos = env.step(clipped_actions)
144 |
145 | self._update_info_buffer(infos)
146 | n_steps += 1
147 | if isinstance(self.action_space, gym.spaces.Discrete):
148 | # Reshape in case of discrete action
149 | actions = actions.reshape(-1, 1)
150 | rollout_buffer.add(obs, actions, rewards, dones, values, log_probs)
151 | obs = new_obs
152 |
153 | rollout_buffer.compute_returns_and_advantage(values, dones=dones)
154 |
155 | return obs
156 |
157 | @tf.function
158 | def policy_loss(self, advantage, log_prob, old_log_prob, clip_range):
159 | # Normalize advantage
160 | advantage = (advantage - tf.reduce_mean(advantage)) / (tf.math.reduce_std(advantage) + 1e-8)
161 |
162 | # ratio between old and new policy, should be one at the first iteration
163 | ratio = tf.exp(log_prob - old_log_prob)
164 | # clipped surrogate loss
165 | policy_loss_1 = advantage * ratio
166 | policy_loss_2 = advantage * tf.clip_by_value(ratio, 1 - clip_range, 1 + clip_range)
167 | return - tf.reduce_mean(tf.minimum(policy_loss_1, policy_loss_2))
168 |
169 | @tf.function
170 | def value_loss(self, values, old_values, return_batch, clip_range_vf):
171 | if clip_range_vf is None:
172 | # No clipping
173 | values_pred = values
174 | else:
175 | # Clip the difference between the old and new value
176 | # NOTE: this depends on the reward scaling
177 | values_pred = old_values + tf.clip_by_value(values - old_values, -clip_range_vf, clip_range_vf)
178 | # Value loss using the TD(gae_lambda) target
179 | return tf.keras.losses.MSE(return_batch, values_pred)
180 |
181 | def train(self, gradient_steps, batch_size=64):
182 | # Update optimizer learning rate
183 | # self._update_learning_rate(self.policy.optimizer)
184 |
185 | # Compute current clip range
186 | clip_range = self.clip_range(self._current_progress)
187 | if self.clip_range_vf is not None:
188 | clip_range_vf = self.clip_range_vf(self._current_progress)
189 | else:
190 | clip_range_vf = None
191 |
192 | for gradient_step in range(gradient_steps):
193 | approx_kl_divs = []
194 | # Sample replay buffer
195 | for replay_data in self.rollout_buffer.get(batch_size):
196 | # Unpack
197 | obs, action, old_values, old_log_prob, advantage, return_batch = replay_data
198 |
199 | if isinstance(self.action_space, spaces.Discrete):
200 | # Convert discrete action from float to long
201 | action = action.astype(np.int64).flatten()
202 |
203 | with tf.GradientTape() as tape:
204 | tape.watch(self.policy.trainable_variables)
205 | values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
206 | # Flatten
207 | values = tf.reshape(values, [-1])
208 |
209 | policy_loss = self.policy_loss(advantage, log_prob, old_log_prob, clip_range)
210 | value_loss = self.value_loss(values, old_values, return_batch, clip_range_vf)
211 |
212 | # Entropy loss favors exploration
213 | entropy_loss = -tf.reduce_mean(entropy)
214 |
215 | loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
216 |
217 | # Optimization step
218 | gradients = tape.gradient(loss, self.policy.trainable_variables)
219 | # Clip grad norm
220 | # gradients = tf.clip_by_norm(gradients, self.max_grad_norm)
221 | self.policy.optimizer.apply_gradients(zip(gradients, self.policy.trainable_variables))
222 | approx_kl_divs.append(tf.reduce_mean(old_log_prob - log_prob).numpy())
223 |
224 | if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
225 | print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,
226 | np.mean(approx_kl_divs)))
227 | break
228 |
229 | explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
230 | self.rollout_buffer.values.flatten())
231 |
232 | logger.logkv("clip_range", clip_range)
233 | if self.clip_range_vf is not None:
234 | logger.logkv("clip_range_vf", clip_range_vf)
235 |
236 | logger.logkv("explained_variance", explained_var)
237 | # TODO: gather stats for the entropy and other losses?
238 | logger.logkv("entropy", entropy.numpy().mean())
239 | logger.logkv("policy_loss", policy_loss.numpy())
240 | logger.logkv("value_loss", value_loss.numpy())
241 | if hasattr(self.policy, 'log_std'):
242 | logger.logkv("std", tf.exp(self.policy.log_std).numpy().mean())
243 |
244 | def learn(self, total_timesteps, callback=None, log_interval=1,
245 | eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):
246 |
247 | timesteps_since_eval, iteration, evaluations, obs, eval_env = self._setup_learn(eval_env)
248 |
249 | if self.tensorboard_log is not None:
250 | self.tb_writer = tf.summary.create_file_writer(os.path.join(self.tensorboard_log, tb_log_name))
251 |
252 | while self.num_timesteps < total_timesteps:
253 |
254 | if callback is not None:
255 | # Only stop training if return value is False, not when it is None.
256 | if callback(locals(), globals()) is False:
257 | break
258 |
259 | obs = self.collect_rollouts(self.env, self.rollout_buffer, n_rollout_steps=self.n_steps,
260 | obs=obs)
261 | iteration += 1
262 | self.num_timesteps += self.n_steps * self.n_envs
263 | timesteps_since_eval += self.n_steps * self.n_envs
264 | self._update_current_progress(self.num_timesteps, total_timesteps)
265 |
266 | # Display training infos
267 | if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0:
268 | fps = int(self.num_timesteps / (time.time() - self.start_time))
269 | logger.logkv("iterations", iteration)
270 | if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
271 | logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
272 | logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
273 | logger.logkv("fps", fps)
274 | logger.logkv('time_elapsed', int(time.time() - self.start_time))
275 | logger.logkv("total timesteps", self.num_timesteps)
276 | logger.dumpkvs()
277 |
278 | self.train(self.n_epochs, batch_size=self.batch_size)
279 |
280 | # Evaluate the agent
281 | timesteps_since_eval = self._eval_policy(eval_freq, eval_env, n_eval_episodes,
282 | timesteps_since_eval, deterministic=True)
283 | # For tensorboard integration
284 | # if self.tb_writer is not None:
285 | # with self.tb_writer.as_default():
286 | # tf.summary.scalar('Eval/reward', mean_reward, self.num_timesteps)
287 |
288 |
289 | return self
290 |
--------------------------------------------------------------------------------
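For readers skimming the training loop above, here is a standalone illustration of the clipped surrogate objective that PPO.policy_loss computes; the tensors are made-up toy values, not library API:

```python
import tensorflow as tf

advantage = tf.constant([1.0, -0.5, 2.0])
log_prob = tf.constant([-0.9, -1.2, -0.7])      # log prob under the current policy
old_log_prob = tf.constant([-1.0, -1.0, -1.0])  # log prob under the policy that collected the data
clip_range = 0.2

# Normalize advantages, then take the pessimistic (unclipped vs. clipped) surrogate
advantage = (advantage - tf.reduce_mean(advantage)) / (tf.math.reduce_std(advantage) + 1e-8)
ratio = tf.exp(log_prob - old_log_prob)
policy_loss = -tf.reduce_mean(tf.minimum(
    advantage * ratio,
    advantage * tf.clip_by_value(ratio, 1 - clip_range, 1 + clip_range)))
print(float(policy_loss))
```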
/stable_baselines/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/stable_baselines/py.typed
--------------------------------------------------------------------------------
/stable_baselines/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.td3.td3 import TD3
2 | from stable_baselines.td3.policies import MlpPolicy
3 |
--------------------------------------------------------------------------------
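As with PPO, the td3 package exposes TD3 and its MlpPolicy at the top level. A minimal usage sketch with Gaussian exploration noise; the hyperparameters are illustrative and mirror tests/test_run.py below:

```python
import numpy as np

from stable_baselines import TD3
from stable_baselines.common.noise import NormalActionNoise

# Zero-mean Gaussian noise with std 0.1 on the 1-dimensional Pendulum action
action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
model = TD3('MlpPolicy', 'Pendulum-v0', action_noise=action_noise,
            learning_starts=100, verbose=1)
model.learn(total_timesteps=10000)
```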
/stable_baselines/td3/policies.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.models import Sequential
3 |
4 | from stable_baselines.common.policies import BasePolicy, register_policy, create_mlp
5 |
6 |
7 | class Actor(BasePolicy):
8 | """
9 | Actor network (policy) for TD3.
10 |
11 | :param obs_dim: (int) Dimension of the observation
12 | :param action_dim: (int) Dimension of the action space
13 | :param net_arch: ([int]) Network architecture
14 | :param activation_fn: (str or tf.activation) Activation function
15 | """
16 | def __init__(self, obs_dim, action_dim, net_arch, activation_fn=tf.nn.relu):
17 | super(Actor, self).__init__(None, None)
18 |
19 | actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
20 | self.mu = Sequential(actor_net)
21 | self.mu.build()
22 |
23 | @tf.function
24 | def call(self, obs):
25 | return self.mu(obs)
26 |
27 |
28 | class Critic(BasePolicy):
29 | """
30 | Critic network for TD3,
31 | in fact it represents the action-state value function (Q-value function)
32 |
33 | :param obs_dim: (int) Dimension of the observation
34 | :param action_dim: (int) Dimension of the action space
35 | :param net_arch: ([int]) Network architecture
36 | :param activation_fn: (tf.nn activation) Activation function
37 | """
38 | def __init__(self, obs_dim, action_dim,
39 | net_arch, activation_fn=tf.nn.relu):
40 | super(Critic, self).__init__(None, None)
41 |
42 | q1_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
43 | self.q1_net = Sequential(q1_net)
44 |
45 | q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
46 | self.q2_net = Sequential(q2_net)
47 |
48 | self.q_networks = [self.q1_net, self.q2_net]
49 |
50 | for q_net in self.q_networks:
51 | q_net.build()
52 |
53 | @tf.function
54 | def call(self, obs, action):
55 | qvalue_input = tf.concat([obs, action], axis=1)
56 | return [q_net(qvalue_input) for q_net in self.q_networks]
57 |
58 | @tf.function
59 | def q1_forward(self, obs, action):
60 | return self.q_networks[0](tf.concat([obs, action], axis=1))
61 |
62 |
63 | class TD3Policy(BasePolicy):
64 | """
65 | Policy class (with both actor and critic) for TD3.
66 |
67 | :param observation_space: (gym.spaces.Space) Observation space
68 | :param action_space: (gym.spaces.Space) Action space
69 | :param learning_rate: (callable) Learning rate schedule (could be constant)
70 | :param net_arch: ([int]) The specification of the actor and critic networks.
71 | :param activation_fn: (str or tf.nn.activation) Activation function
72 | """
73 | def __init__(self, observation_space, action_space,
74 | learning_rate, net_arch=None,
75 | activation_fn=tf.nn.relu):
76 | super(TD3Policy, self).__init__(observation_space, action_space)
77 |
78 | # Default network architecture, from the original paper
79 | if net_arch is None:
80 | net_arch = [400, 300]
81 |
82 | self.obs_dim = self.observation_space.shape[0]
83 | self.action_dim = self.action_space.shape[0]
84 | self.net_arch = net_arch
85 | self.activation_fn = activation_fn
86 | self.net_args = {
87 | 'obs_dim': self.obs_dim,
88 | 'action_dim': self.action_dim,
89 | 'net_arch': self.net_arch,
90 | 'activation_fn': self.activation_fn
91 | }
92 | self.actor_kwargs = self.net_args.copy()
93 |
94 | self.actor, self.actor_target = None, None
95 | self.critic, self.critic_target = None, None
96 | self._build(learning_rate)
97 |
98 | def _build(self, learning_rate):
99 | self.actor = self.make_actor()
100 | self.actor_target = self.make_actor()
101 | self.actor_target.hard_update(self.actor)
102 | self.actor.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1))
103 |
104 | self.critic = self.make_critic()
105 | self.critic_target = self.make_critic()
106 | self.critic_target.hard_update(self.critic)
107 | self.critic.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate(1))
108 |
109 | def make_actor(self):
110 | return Actor(**self.actor_kwargs)
111 |
112 | def make_critic(self):
113 | return Critic(**self.net_args)
114 |
115 | @tf.function
116 | def call(self, obs):
117 | return self.actor(obs)
118 |
119 |
120 | MlpPolicy = TD3Policy
121 |
122 | register_policy("MlpPolicy", MlpPolicy)
123 |
--------------------------------------------------------------------------------
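Because the actor above ends in a squashing layer (squash_out=True), it emits actions in [-1, 1]; TD3 then relies on scale_action/unscale_action helpers from the algorithm's base class (not shown in this dump) to map between that range and the environment bounds. A hedged standalone sketch of the usual affine mapping, with illustrative helper names:

```python
import numpy as np

def scale_action(action, low, high):
    """Map an action from [low, high] to [-1, 1]."""
    return 2.0 * (action - low) / (high - low) - 1.0

def unscale_action(scaled_action, low, high):
    """Map an action from [-1, 1] back to [low, high]."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

# Pendulum-v0 has a Box(-2, 2, (1,)) action space
low, high = np.array([-2.0]), np.array([2.0])
assert np.allclose(unscale_action(scale_action(np.array([1.5]), low, high), low, high), [1.5])
```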
/stable_baselines/td3/td3.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | from stable_baselines.common.base_class import BaseRLModel
7 | from stable_baselines.common.buffers import ReplayBuffer
8 | from stable_baselines.common.vec_env import VecEnv
9 | from stable_baselines.common import logger
10 | from stable_baselines.td3.policies import TD3Policy
11 |
12 |
13 | class TD3(BaseRLModel):
14 | """
15 | Twin Delayed DDPG (TD3)
16 | Addressing Function Approximation Error in Actor-Critic Methods.
17 |
18 | Original implementation: https://github.com/sfujim/TD3
19 | Paper: https://arxiv.org/abs/1802.09477
20 | Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
21 |
22 | :param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
23 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
24 | :param buffer_size: (int) size of the replay buffer
25 | :param learning_rate: (float or callable) learning rate for adam optimizer,
26 | the same learning rate will be used for all networks (Q-Values and Actor networks)
27 | it can be a function of the current progress (from 1 to 0)
28 | :param policy_delay: (int) Policy and target networks will only be updated once every policy_delay
29 | training steps. The Q values will be updated policy_delay times more often (every training step).
30 | :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
31 | :param gamma: (float) the discount factor
32 | :param batch_size: (int) Minibatch size for each gradient update
33 | :param train_freq: (int) Update the model every `train_freq` steps.
34 | :param gradient_steps: (int) How many gradient updates to perform after each rollout
35 | :param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1)
36 | :param action_noise: (ActionNoise) the action noise type. See common.noise for the different action noise types.
37 | :param target_policy_noise: (float) Standard deviation of Gaussian noise added to the target policy
38 | (smoothing noise)
39 | :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
40 | :param create_eval_env: (bool) Whether to create a second environment that will be
41 | used for evaluating the agent periodically. (Only available when passing string for the environment)
42 | :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
43 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
44 | :param seed: (int) Seed for the pseudo random generators
45 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
46 | """
47 |
48 | def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,
49 | policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
50 | train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
51 | tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
52 | tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
53 | seed=None, _init_setup_model=True):
54 |
55 | super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose,
56 | create_eval_env=create_eval_env, seed=seed)
57 |
58 | self.buffer_size = buffer_size
59 | self.learning_rate = learning_rate
60 | self.learning_starts = learning_starts
61 | self.train_freq = train_freq
62 | self.gradient_steps = gradient_steps
63 | self.n_episodes_rollout = n_episodes_rollout
64 | self.batch_size = batch_size
65 | self.tau = tau
66 | self.gamma = gamma
67 | self.action_noise = action_noise
68 | self.policy_delay = policy_delay
69 | self.target_noise_clip = target_noise_clip
70 | self.target_policy_noise = target_policy_noise
71 |
72 | if _init_setup_model:
73 | self._setup_model()
74 |
75 | def _setup_model(self):
76 | self._setup_learning_rate()
77 | obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
78 | self.set_random_seed(self.seed)
79 | self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim)
80 | self.policy = self.policy_class(self.observation_space, self.action_space,
81 | self.learning_rate, **self.policy_kwargs)
82 | self._create_aliases()
83 |
84 | def _create_aliases(self):
85 | self.actor = self.policy.actor
86 | self.actor_target = self.policy.actor_target
87 | self.critic = self.policy.critic
88 | self.critic_target = self.policy.critic_target
89 |
90 | def predict(self, observation, state=None, mask=None, deterministic=True):
91 | """
92 | Get the model's action from an observation
93 |
94 | :param observation: (np.ndarray) the input observation
95 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
96 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
97 | :param deterministic: (bool) Whether or not to return deterministic actions.
98 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
99 | """
100 | return self.unscale_action(self.actor(np.array(observation).reshape(1, -1)).numpy())
101 |
102 | @tf.function
103 | def critic_loss(self, obs, action, next_obs, done, reward):
104 | # Select action according to policy and add clipped noise
105 | noise = tf.random.normal(shape=action.shape) * self.target_policy_noise
106 | noise = tf.clip_by_value(noise, -self.target_noise_clip, self.target_noise_clip)
107 | next_action = tf.clip_by_value(self.actor_target(next_obs) + noise, -1., 1.)
108 |
109 | # Compute the target Q value
110 | target_q1, target_q2 = self.critic_target(next_obs, next_action)
111 | target_q = tf.minimum(target_q1, target_q2)
112 | target_q = reward + tf.stop_gradient((1 - done) * self.gamma * target_q)
113 |
114 | # Get current Q estimates
115 | current_q1, current_q2 = self.critic(obs, action)
116 |
117 | # Compute critic loss
118 | return tf.keras.losses.MSE(current_q1, target_q) + tf.keras.losses.MSE(current_q2, target_q)
119 |
120 | @tf.function
121 | def actor_loss(self, obs):
122 | return - tf.reduce_mean(self.critic.q1_forward(obs, self.actor(obs)))
123 |
124 | @tf.function
125 | def update_targets(self):
126 | self.critic_target.soft_update(self.critic, self.tau)
127 | self.actor_target.soft_update(self.actor, self.tau)
128 |
129 | @tf.function
130 | def _train_critic(self, obs, action, next_obs, done, reward):
131 | with tf.GradientTape() as critic_tape:
132 | critic_tape.watch(self.critic.trainable_variables)
133 | critic_loss = self.critic_loss(obs, action, next_obs, done, reward)
134 |
135 | # Optimize the critic
136 | grads_critic = critic_tape.gradient(critic_loss, self.critic.trainable_variables)
137 | self.critic.optimizer.apply_gradients(zip(grads_critic, self.critic.trainable_variables))
138 |
139 | @tf.function
140 | def _train_actor(self, obs):
141 | with tf.GradientTape() as actor_tape:
142 | actor_tape.watch(self.actor.trainable_variables)
143 | # Compute actor loss
144 | actor_loss = self.actor_loss(obs)
145 |
146 | # Optimize the actor
147 | grads_actor = actor_tape.gradient(actor_loss, self.actor.trainable_variables)
148 | self.actor.optimizer.apply_gradients(zip(grads_actor, self.actor.trainable_variables))
149 |
150 | def train(self, gradient_steps, batch_size=100, policy_delay=2):
151 | # self._update_learning_rate()
152 |
153 | for gradient_step in range(gradient_steps):
154 |
155 | # Sample replay buffer
156 | obs, action, next_obs, done, reward = self.replay_buffer.sample(batch_size)
157 |
158 | self._train_critic(obs, action, next_obs, done, reward)
159 |
160 | # Delayed policy updates
161 | if gradient_step % policy_delay == 0:
162 |
163 | self._train_actor(obs)
164 |
165 | # Update the frozen target models
166 | self.update_targets()
167 |
168 |
169 | def learn(self, total_timesteps, callback=None, log_interval=4,
170 | eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True):
171 |
172 | timesteps_since_eval, episode_num, evaluations, obs, eval_env = self._setup_learn(eval_env)
173 |
174 | while self.num_timesteps < total_timesteps:
175 |
176 | if callback is not None:
177 | # Only stop training if return value is False, not when it is None.
178 | if callback(locals(), globals()) is False:
179 | break
180 |
181 | rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
182 | n_steps=self.train_freq, action_noise=self.action_noise,
183 | deterministic=False, callback=None,
184 | learning_starts=self.learning_starts,
185 | num_timesteps=self.num_timesteps,
186 | replay_buffer=self.replay_buffer,
187 | obs=obs, episode_num=episode_num,
188 | log_interval=log_interval)
189 | # Unpack
190 | episode_reward, episode_timesteps, n_episodes, obs = rollout
191 |
192 | episode_num += n_episodes
193 | self.num_timesteps += episode_timesteps
194 | timesteps_since_eval += episode_timesteps
195 | self._update_current_progress(self.num_timesteps, total_timesteps)
196 |
197 | if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
198 |
199 | gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
200 | self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
201 |
202 | # Evaluate the agent
203 | timesteps_since_eval = self._eval_policy(eval_freq, eval_env, n_eval_episodes,
204 | timesteps_since_eval, deterministic=True)
205 |
206 | return self
207 |
208 | def collect_rollouts(self, env, n_episodes=1, n_steps=-1, action_noise=None,
209 | deterministic=False, callback=None,
210 | learning_starts=0, num_timesteps=0,
211 | replay_buffer=None, obs=None,
212 | episode_num=0, log_interval=None):
213 | """
214 | Collect rollout using the current policy (and possibly fill the replay buffer)
215 | TODO: move this method to off-policy base class.
216 |
217 | :param env: (VecEnv)
218 | :param n_episodes: (int)
219 | :param n_steps: (int)
220 | :param action_noise: (ActionNoise)
221 | :param deterministic: (bool)
222 | :param callback: (callable)
223 | :param learning_starts: (int)
224 | :param num_timesteps: (int)
225 | :param replay_buffer: (ReplayBuffer)
226 | :param obs: (np.ndarray)
227 | :param episode_num: (int)
228 | :param log_interval: (int)
229 | """
230 | episode_rewards = []
231 | total_timesteps = []
232 | total_steps, total_episodes = 0, 0
233 | assert isinstance(env, VecEnv)
234 | assert env.num_envs == 1
235 |
236 | while total_steps < n_steps or total_episodes < n_episodes:
237 | done = False
238 | # Reset environment: not needed for VecEnv
239 | # obs = env.reset()
240 | episode_reward, episode_timesteps = 0.0, 0
241 |
242 | while not done:
243 |
244 | # Select action randomly or according to policy
245 | if num_timesteps < learning_starts:
246 | # Warmup phase
247 | unscaled_action = np.array([self.action_space.sample()])
248 | # Rescale the action from [low, high] to [-1, 1]
249 | scaled_action = self.scale_action(unscaled_action)
250 | else:
251 | scaled_action = self.policy.call(obs)
252 |
253 | # Add noise to the action (improves exploration)
254 | if action_noise is not None:
255 | scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
256 |
257 | # Rescale and perform action
258 | new_obs, reward, done, infos = env.step(self.unscale_action(scaled_action))
259 |
260 | done_bool = [float(done[0])]
261 | episode_reward += reward
262 |
263 | # Retrieve reward and episode length if using Monitor wrapper
264 | self._update_info_buffer(infos)
265 |
266 | # Store data in replay buffer
267 | if replay_buffer is not None:
268 | replay_buffer.add(obs, new_obs, scaled_action, reward, done_bool)
269 |
270 | obs = new_obs
271 |
272 | num_timesteps += 1
273 | episode_timesteps += 1
274 | total_steps += 1
275 | if 0 < n_steps <= total_steps:
276 | break
277 |
278 | if done:
279 | total_episodes += 1
280 | episode_rewards.append(episode_reward)
281 | total_timesteps.append(episode_timesteps)
282 | if action_noise is not None:
283 | action_noise.reset()
284 |
285 | # Display training infos
286 | if self.verbose >= 1 and log_interval is not None and (
287 | episode_num + total_episodes) % log_interval == 0:
288 | fps = int(num_timesteps / (time.time() - self.start_time))
289 | logger.logkv("episodes", episode_num + total_episodes)
290 | if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
291 | logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
292 | logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
293 | # logger.logkv("n_updates", n_updates)
294 | logger.logkv("fps", fps)
295 | logger.logkv('time_elapsed', int(time.time() - self.start_time))
296 | logger.logkv("total timesteps", num_timesteps)
297 | logger.dumpkvs()
298 |
299 | mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
300 |
301 | return mean_reward, total_steps, total_episodes, obs
302 |
--------------------------------------------------------------------------------
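update_targets above calls soft_update, which lives on the base policy class and is not shown in this dump. A hedged standalone sketch of the polyak averaging such an update usually performs with tau=0.005 (the function and variable names here are illustrative, not the library's API):

```python
import numpy as np

def soft_update(target_params, source_params, tau):
    # Polyak update: target <- tau * source + (1 - tau) * target
    return [tau * src + (1.0 - tau) * tgt
            for tgt, src in zip(target_params, source_params)]

target = [np.zeros(3)]
source = [np.ones(3)]
print(soft_update(target, source, tau=0.005))  # -> [array([0.005, 0.005, 0.005])]
```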
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-Baselines-Team/stable-baselines-tf2/769b03e091067108a01a6778173b6cf90ec375ce/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_run.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import TD3, PPO
5 | from stable_baselines.common.noise import NormalActionNoise
6 |
7 | action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
8 |
9 |
10 | def test_td3():
11 | model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), seed=0,
12 | learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise)
13 | model.learn(total_timesteps=10000, eval_freq=5000)
14 | # model.save("test_save")
15 | # model.load("test_save")
16 | # os.remove("test_save.zip")
17 |
18 | @pytest.mark.parametrize("model_class", [PPO])
19 | @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
20 | def test_onpolicy(model_class, env_id):
21 | model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
22 | model.learn(total_timesteps=1000, eval_freq=500)
23 | # model.save("test_save")
24 | # model.load("test_save")
25 | # os.remove("test_save.zip")
26 |
--------------------------------------------------------------------------------