├── .coveragerc ├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ └── issue-template.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── ci.yml ├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── conftest.py ├── data └── logo.jpg ├── docs ├── Makefile ├── README.md ├── _static │ ├── css │ │ └── baselines_theme.css │ └── img │ │ ├── Tensorboard_example_1.png │ │ ├── Tensorboard_example_2.png │ │ ├── Tensorboard_example_3.png │ │ ├── breakout.gif │ │ ├── colab.svg │ │ ├── learning_curve.png │ │ ├── logo.png │ │ ├── mistake.png │ │ └── try_it.png ├── common │ ├── cmd_utils.rst │ ├── distributions.rst │ ├── env_checker.rst │ ├── evaluation.rst │ ├── monitor.rst │ ├── schedules.rst │ └── tf_utils.rst ├── conf.py ├── guide │ ├── algos.rst │ ├── callbacks.rst │ ├── checking_nan.rst │ ├── custom_env.rst │ ├── custom_policy.rst │ ├── examples.rst │ ├── export.rst │ ├── install.rst │ ├── pretrain.rst │ ├── quickstart.rst │ ├── rl.rst │ ├── rl_tips.rst │ ├── rl_zoo.rst │ ├── save_format.rst │ ├── tensorboard.rst │ └── vec_envs.rst ├── index.rst ├── make.bat ├── misc │ ├── changelog.rst │ ├── projects.rst │ └── results_plotter.rst ├── modules │ ├── a2c.rst │ ├── acer.rst │ ├── acktr.rst │ ├── base.rst │ ├── ddpg.rst │ ├── dqn.rst │ ├── gail.rst │ ├── her.rst │ ├── policies.rst │ ├── ppo1.rst │ ├── ppo2.rst │ ├── sac.rst │ ├── td3.rst │ └── trpo.rst ├── requirements.txt └── spelling_wordlist.txt ├── scripts ├── build_docker.sh ├── run_docker_cpu.sh ├── run_docker_gpu.sh ├── run_tests.sh └── run_tests_travis.sh ├── setup.cfg ├── setup.py ├── stable_baselines ├── __init__.py ├── a2c │ ├── __init__.py │ ├── a2c.py │ └── run_atari.py ├── acer │ ├── __init__.py │ ├── acer_simple.py │ ├── buffer.py │ └── run_atari.py ├── acktr │ ├── __init__.py │ ├── acktr.py │ ├── kfac.py │ ├── kfac_utils.py │ └── run_atari.py ├── bench │ ├── __init__.py │ └── monitor.py ├── common │ ├── __init__.py │ ├── atari_wrappers.py │ ├── base_class.py │ ├── bit_flipping_env.py │ ├── buffers.py │ ├── callbacks.py │ ├── cg.py │ ├── cmd_util.py │ ├── console_util.py │ ├── dataset.py │ ├── distributions.py │ ├── env_checker.py │ ├── evaluation.py │ ├── identity_env.py │ ├── input.py │ ├── math_util.py │ ├── misc_util.py │ ├── mpi_adam.py │ ├── mpi_moments.py │ ├── mpi_running_mean_std.py │ ├── noise.py │ ├── policies.py │ ├── runners.py │ ├── running_mean_std.py │ ├── save_util.py │ ├── schedules.py │ ├── segment_tree.py │ ├── tf_layers.py │ ├── tf_util.py │ ├── tile_images.py │ └── vec_env │ │ ├── __init__.py │ │ ├── base_vec_env.py │ │ ├── dummy_vec_env.py │ │ ├── subproc_vec_env.py │ │ ├── util.py │ │ ├── vec_check_nan.py │ │ ├── vec_frame_stack.py │ │ ├── vec_normalize.py │ │ └── vec_video_recorder.py ├── ddpg │ ├── __init__.py │ ├── ddpg.py │ ├── main.py │ ├── noise.py │ └── policies.py ├── deepq │ ├── __init__.py │ ├── build_graph.py │ ├── dqn.py │ ├── experiments │ │ ├── __init__.py │ │ ├── enjoy_cartpole.py │ │ ├── enjoy_mountaincar.py │ │ ├── run_atari.py │ │ ├── train_cartpole.py │ │ └── train_mountaincar.py │ └── policies.py ├── gail │ ├── __init__.py │ ├── adversary.py │ ├── dataset │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── expert_cartpole.npz │ │ ├── expert_pendulum.npz │ │ └── record_expert.py │ └── model.py ├── her │ ├── __init__.py │ ├── her.py │ ├── replay_buffer.py │ └── utils.py ├── logger.py ├── ppo1 │ ├── __init__.py │ ├── experiments │ │ └── train_cartpole.py │ ├── pposgd_simple.py │ ├── run_atari.py │ ├── run_mujoco.py │ └── 
run_robotics.py ├── ppo2 │ ├── __init__.py │ ├── ppo2.py │ ├── run_atari.py │ └── run_mujoco.py ├── py.typed ├── results_plotter.py ├── sac │ ├── __init__.py │ ├── policies.py │ └── sac.py ├── td3 │ ├── __init__.py │ ├── policies.py │ └── td3.py ├── trpo_mpi │ ├── __init__.py │ ├── run_atari.py │ ├── run_mujoco.py │ ├── trpo_mpi.py │ └── utils.py └── version.txt └── tests ├── __init__.py ├── test_0deterministic.py ├── test_a2c.py ├── test_a2c_conv.py ├── test_action_scaling.py ├── test_action_space.py ├── test_atari.py ├── test_auto_vec_detection.py ├── test_callbacks.py ├── test_common.py ├── test_continuous.py ├── test_custom_policy.py ├── test_deepq.py ├── test_distri.py ├── test_envs.py ├── test_gail.py ├── test_her.py ├── test_identity.py ├── test_load_parameters.py ├── test_log_prob.py ├── test_logger.py ├── test_lstm_policy.py ├── test_math_util.py ├── test_monitor.py ├── test_mpi_adam.py ├── test_multiple_learn.py ├── test_no_mpi.py ├── test_ppo2.py ├── test_replay_buffer.py ├── test_save.py ├── test_schedules.py ├── test_segment_tree.py ├── test_tensorboard.py ├── test_tf_util.py ├── test_utils.py ├── test_vec_check_nan.py ├── test_vec_envs.py └── test_vec_normalize.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = False 3 | omit = 4 | # Mujoco requires a licence 5 | stable_baselines/*/run_mujoco.py 6 | stable_baselines/ppo1/run_humanoid.py 7 | stable_baselines/ppo1/run_robotics.py 8 | # HER requires mpi and Mujoco 9 | stable_baselines/her/experiment/* 10 | tests/* 11 | setup.py 12 | 13 | [report] 14 | exclude_lines = 15 | pragma: no cover 16 | raise NotImplementedError() 17 | if KFAC_DEBUG: 18 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .gitignore -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/issue-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue Template 3 | about: How to create an issue for this repository 4 | 5 | --- 6 | 7 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. 8 | 9 | If you have any questions, feel free to create an issue with the tag [question]. 10 | If you wish to suggest an enhancement or feature request, add the tag [feature request]. 11 | If you are submitting a bug report, please fill in the following details. 12 | 13 | If your issue is related to a custom gym environment, please check it first using: 14 | 15 | ```python 16 | from stable_baselines.common.env_checker import check_env 17 | 18 | env = CustomEnv(arg1, ...) 19 | # It will check your custom environment and output additional warnings if needed 20 | check_env(env) 21 | ``` 22 | 23 | **Describe the bug** 24 | A clear and concise description of what the bug is. 25 | 26 | **Code example** 27 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 28 | 29 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) 30 | for both code and stack traces. 31 | 32 | ```python 33 | from stable_baselines import ... 34 | 35 | ``` 36 | 37 | ```bash 38 | Traceback (most recent call last): File ... 
39 | 40 | ``` 41 | 42 | **System Info** 43 | Describe the characteristic of your environment: 44 | * Describe how the library was installed (pip, docker, source, ...) 45 | * GPU models and configuration 46 | * Python version 47 | * Tensorflow version 48 | * Versions of any other relevant libraries 49 | 50 | **Additional context** 51 | Add any other context about the problem here. 52 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | ## Motivation and Context 7 | 8 | 9 | 10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes) 11 | 12 | ## Types of changes 13 | 14 | - [ ] Bug fix (non-breaking change which fixes an issue) 15 | - [ ] New feature (non-breaking change which adds functionality) 16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 17 | - [ ] Documentation (update in the documentation) 18 | 19 | ## Checklist: 20 | 21 | 22 | - [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**) 23 | - [ ] I have updated the [changelog](https://github.com/hill-a/stable-baselines/blob/master/docs/misc/changelog.rst) accordingly (**required**). 24 | - [ ] My change requires a change to the documentation. 25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 26 | - [ ] I have updated the documentation accordingly. 27 | - [ ] I have ensured `pytest` and `pytype` both pass (by running `make pytest` and `make type`). 28 | 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | # Skip CI if [ci skip] in the commit message 15 | if: "! 
contains(toJSON(github.event.commits.*.message), '[ci skip]')" 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: [3.6] # Deactivate 3.5 build as it is not longer maintained 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | sudo apt-get install libopenmpi-dev 30 | python -m pip install --upgrade pip 31 | pip install wheel 32 | pip install .[mpi,tests,docs] 33 | # Use headless version 34 | pip install opencv-python-headless 35 | # Tmp fix: ROM missing in the newest atari-py version 36 | pip install atari-py==0.2.5 37 | - name: MPI 38 | run: | 39 | # check MPI 40 | mpirun -h 41 | python -c "import mpi4py; print(mpi4py.__version__)" 42 | mpirun --allow-run-as-root -np 2 python -m stable_baselines.common.mpi_adam 43 | mpirun --allow-run-as-root -np 2 python -m stable_baselines.ppo1.experiments.train_cartpole 44 | mpirun --allow-run-as-root -np 2 python -m stable_baselines.common.mpi_running_mean_std 45 | # MPI requires 3 processes to run the following code 46 | # but will throw an error on GitHub CI as there is only two threads 47 | # mpirun --allow-run-as-root -np 3 python -c "from stable_baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()" 48 | 49 | - name: Build the doc 50 | run: | 51 | make doc 52 | - name: Type check 53 | run: | 54 | make type 55 | - name: Test with pytest 56 | run: | 57 | # Prevent issues with multiprocessing 58 | DEFAULT_START_METHOD=fork make pytest 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.pkl 4 | *.py~ 5 | *.bak 6 | .pytest_cache 7 | .pytype 8 | .DS_Store 9 | .idea 10 | .coverage 11 | .coverage.* 12 | __pycache__/ 13 | _build/ 14 | *.npz 15 | *.zip 16 | 17 | # Setuptools distribution and build folders. 
18 | /dist/ 19 | /build 20 | keys/ 21 | 22 | # Virtualenv 23 | /env 24 | /venv 25 | 26 | *.sublime-project 27 | *.sublime-workspace 28 | 29 | logs/ 30 | 31 | .ipynb_checkpoints 32 | ghostdriver.log 33 | 34 | htmlcov 35 | 36 | junk 37 | src 38 | 39 | *.egg-info 40 | .cache 41 | 42 | MUJOCO_LOG.TXT 43 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | # Optionally build your docs in additional formats such as PDF and ePub 12 | formats: all 13 | 14 | # Optionally set the version of Python and requirements required to build your docs 15 | python: 16 | version: 3.7 17 | install: 18 | - requirements: docs/requirements.txt 19 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | 5 | env: 6 | global: 7 | - DOCKER_IMAGE=stablebaselines/stable-baselines-cpu:v2.10.0 8 | 9 | notifications: 10 | email: false 11 | 12 | services: 13 | - docker 14 | 15 | install: 16 | - docker pull ${DOCKER_IMAGE} 17 | 18 | script: 19 | - ./scripts/run_tests_travis.sh "${TEST_GLOB}" 20 | 21 | jobs: 22 | include: 23 | # Big test suite. Run in parallel to decrease wall-clock time, and to avoid OOM error from leaks 24 | - stage: Test 25 | name: "Unit Tests a-h" 26 | env: TEST_GLOB="[a-h]*" 27 | 28 | - name: "Unit Tests i-l" 29 | env: TEST_GLOB="[i-l]*" 30 | 31 | - name: "Unit Tests m-sa" 32 | env: TEST_GLOB="{[m-r]*,sa*}" 33 | 34 | - name: "Unit Tests sb-z" 35 | env: TEST_GLOB="{s[b-z]*,[t-z]*}" 36 | 37 | - name: "Unit Tests determinism" 38 | env: TEST_GLOB="0deterministic.py" 39 | 40 | - name: "Sphinx Documentation" 41 | script: 42 | - 'docker run -it --rm --mount src=$(pwd),target=/root/code/stable-baselines,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && pushd docs/ && make clean && make html"' 43 | 44 | - name: "Type Checking" 45 | script: 46 | - 'docker run --rm --mount src=$(pwd),target=/root/code/stable-baselines,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && pytype --version && pytype"' 47 | 48 | - stage: Codacy Trigger 49 | if: type != pull_request 50 | script: 51 | # When all test coverage reports have been uploaded, instruct Codacy to start analysis. 52 | - 'docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines,type=bind --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && java -jar /root/code/codacy-coverage-reporter.jar final"' 53 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to Stable-Baselines 2 | 3 | If you are interested in contributing to Stable-Baselines, your contributions will fall 4 | into two categories: 5 | 1. You want to propose a new Feature and implement it 6 | - Create an issue about your intended feature, and we shall discuss the design and 7 | implementation. Once we agree that the plan looks good, go ahead and implement it. 
8 | 2. You want to implement a feature or bug-fix for an outstanding issue 9 | - Look at the outstanding issues here: https://github.com/hill-a/stable-baselines/issues 10 | - Look at the roadmap here: https://github.com/hill-a/stable-baselines/projects/1 11 | - Pick an issue or feature and comment on the task that you want to work on this feature. 12 | - If you need more context on a particular issue, please ask and we shall provide. 13 | 14 | Once you finish implementing a feature or bug-fix, please send a Pull Request to 15 | https://github.com/hill-a/stable-baselines/ 16 | 17 | 18 | If you are not familiar with creating a Pull Request, here are some guides: 19 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request 20 | - https://help.github.com/articles/creating-a-pull-request/ 21 | 22 | 23 | ## Developing Stable-Baselines 24 | 25 | To develop Stable-Baselines on your machine, here are some tips: 26 | 27 | 1. Clone a copy of Stable-Baselines from source: 28 | 29 | ```bash 30 | git clone https://github.com/hill-a/stable-baselines/ 31 | cd stable-baselines 32 | ``` 33 | 34 | 2. Install Stable-Baselines in develop mode, with support for building the docs and running tests: 35 | 36 | ```bash 37 | pip install -e .[docs,tests] 38 | ``` 39 | 40 | ## Codestyle 41 | 42 | We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows: 43 | 44 | 1. built-in 45 | 2. packages 46 | 3. current module 47 | 48 | with one space between each, that gives for instance: 49 | ```python 50 | import os 51 | import warnings 52 | 53 | import numpy as np 54 | 55 | from stable_baselines import PPO2 56 | ``` 57 | 58 | In general, we recommend using pycharm to format everything in an efficient way. 59 | 60 | Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template: 61 | 62 | ```python 63 | 64 | def my_function(arg1: type1, arg2: type2) -> returntype: 65 | """ 66 | Short description of the function. 67 | 68 | :param arg1: (type1) describe what is arg1 69 | :param arg2: (type2) describe what is arg2 70 | :return: (returntype) describe what is returned 71 | """ 72 | ... 73 | return my_variable 74 | ``` 75 | 76 | ## Pull Request (PR) 77 | 78 | Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process. 79 | 80 | Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli). 81 | A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch. 82 | 83 | Note: in rare cases, we can create exception for codacy failure. 84 | 85 | ## Test 86 | 87 | All new features must add tests in the `tests/` folder ensuring that everything works fine. 88 | We use [pytest](https://pytest.org/). 89 | Also, when a bug fix is proposed, tests should be added to avoid regression. 
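As a rough illustration (not an actual test from the suite), a new test file could follow the pattern below — the file name `tests/test_my_feature.py`, the choice of algorithm and the hyperparameters are placeholders:

```python
# tests/test_my_feature.py -- hypothetical example of a minimal regression test
import pytest

from stable_baselines import A2C
from stable_baselines.common.identity_env import IdentityEnv
from stable_baselines.common.vec_env import DummyVecEnv


@pytest.mark.parametrize("policy", ["MlpPolicy"])
def test_my_feature(policy):
    # Use a trivial environment so the test stays fast
    env = DummyVecEnv([lambda: IdentityEnv(10)])
    model = A2C(policy, env, verbose=0)
    # A short call to learn() is usually enough to catch crashes and regressions
    model.learn(total_timesteps=500)
```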
90 | 91 | To run tests with `pytest`: 92 | 93 | ``` 94 | make pytest 95 | ``` 96 | 97 | Type checking with `pytype`: 98 | 99 | ``` 100 | make type 101 | ``` 102 | 103 | Build the documentation: 104 | 105 | ``` 106 | make doc 107 | ``` 108 | 109 | Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that): 110 | 111 | ``` 112 | make spelling 113 | ``` 114 | 115 | 116 | ## Changelog and Documentation 117 | 118 | Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed. 119 | A README is present in the `docs/` folder for instructions on how to build the documentation. 120 | 121 | 122 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one. 123 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PARENT_IMAGE 2 | FROM $PARENT_IMAGE 3 | ARG USE_GPU 4 | 5 | RUN apt-get -y update \ 6 | && apt-get -y install \ 7 | curl \ 8 | cmake \ 9 | default-jre \ 10 | git \ 11 | jq \ 12 | python-dev \ 13 | python-pip \ 14 | python3-dev \ 15 | libfontconfig1 \ 16 | libglib2.0-0 \ 17 | libsm6 \ 18 | libxext6 \ 19 | libxrender1 \ 20 | libopenmpi-dev \ 21 | zlib1g-dev \ 22 | && apt-get clean \ 23 | && rm -rf /var/lib/apt/lists/* 24 | 25 | ENV CODE_DIR /root/code 26 | ENV VENV /root/venv 27 | 28 | COPY ./stable_baselines/version.txt ${CODE_DIR}/stable-baselines/stable_baselines/version.txt 29 | COPY ./setup.py ${CODE_DIR}/stable-baselines/setup.py 30 | 31 | RUN \ 32 | pip install pip --upgrade && \ 33 | pip install virtualenv && \ 34 | virtualenv $VENV --python=python3 && \ 35 | . $VENV/bin/activate && \ 36 | pip install --upgrade pip && \ 37 | cd ${CODE_DIR}/stable-baselines && \ 38 | pip install -e .[mpi,tests,docs] && \ 39 | rm -rf $HOME/.cache/pip 40 | 41 | ENV PATH=$VENV/bin:$PATH 42 | 43 | # Codacy code coverage report: used for partial code coverage reporting 44 | RUN cd $CODE_DIR && \ 45 | curl -Ls -o codacy-coverage-reporter.jar "$(curl -Ls https://api.github.com/repos/codacy/codacy-coverage-reporter/releases/latest | jq -r '.assets | map({name, browser_download_url} | select(.name | (startswith("codacy-coverage-reporter") and contains("assembly") and endswith(".jar")))) | .[0].browser_download_url')" 46 | 47 | CMD /bin/bash 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | Copyright (c) 2018-2019 Stable-Baselines Team 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Run pytest and coverage report 2 | pytest: 3 | ./scripts/run_tests.sh 4 | 5 | # Type check 6 | type: 7 | pytype -j auto 8 | 9 | # Build the doc 10 | doc: 11 | cd docs && make html 12 | 13 | # Check the spelling in the doc 14 | spelling: 15 | cd docs && make spelling 16 | 17 | # Clean the doc build folder 18 | clean: 19 | cd docs && make clean 20 | 21 | # Build docker images 22 | # If you do export RELEASE=True, it will also push them 23 | docker: docker-cpu docker-gpu 24 | 25 | docker-cpu: 26 | ./scripts/build_docker.sh 27 | 28 | docker-gpu: 29 | USE_GPU=True ./scripts/build_docker.sh 30 | 31 | # PyPi package release 32 | release: 33 | python setup.py sdist 34 | python setup.py bdist_wheel 35 | twine upload dist/* 36 | 37 | # Test PyPi package release 38 | test-release: 39 | python setup.py sdist 40 | python setup.py bdist_wheel 41 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 42 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | """Configures pytest to ignore certain unit tests unless the appropriate flag is passed. 2 | 3 | --rungpu: tests that require GPU. 4 | --expensive: tests that take a long time to run (e.g. training an RL algorithm for many timestesps).""" 5 | 6 | import pytest 7 | 8 | 9 | def pytest_addoption(parser): 10 | parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests") 11 | parser.addoption("--expensive", action="store_true", 12 | help="run expensive tests (which are otherwise skipped).") 13 | 14 | 15 | def pytest_collection_modifyitems(config, items): 16 | flags = {'gpu': '--rungpu', 'expensive': '--expensive'} 17 | skips = {keyword: pytest.mark.skip(reason="need {} option to run".format(flag)) 18 | for keyword, flag in flags.items() if not config.getoption(flag)} 19 | for item in items: 20 | for keyword, skip in skips.items(): 21 | if keyword in item.keywords: 22 | item.add_marker(skip) 23 | -------------------------------------------------------------------------------- /data/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/data/logo.jpg -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W # make warnings fatal 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = StableBaselines 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Stable Baselines Documentation 2 | 3 | This folder contains documentation for the RL baselines. 4 | 5 | 6 | ### Build the Documentation 7 | 8 | #### Install Sphinx and Theme 9 | 10 | ``` 11 | pip install sphinx sphinx-autobuild sphinx-rtd-theme 12 | ``` 13 | 14 | #### Building the Docs 15 | 16 | In the `docs/` folder: 17 | ``` 18 | make html 19 | ``` 20 | 21 | if you want to building each time a file is changed: 22 | 23 | ``` 24 | sphinx-autobuild . _build/html 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/_static/css/baselines_theme.css: -------------------------------------------------------------------------------- 1 | /* Main colors from https://color.adobe.com/fr/Copy-of-NOUEBO-Original-color-theme-11116609 */ 2 | :root{ 3 | --main-bg-color: #324D5C; 4 | --link-color: #14B278; 5 | } 6 | 7 | /* Header fonts y */ 8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 10 | } 11 | 12 | 13 | /* Docs background */ 14 | .wy-side-nav-search{ 15 | background-color: var(--main-bg-color); 16 | } 17 | 18 | /* Mobile version */ 19 | .wy-nav-top{ 20 | background-color: var(--main-bg-color); 21 | } 22 | 23 | /* Change link colors (except for the menu) */ 24 | a { 25 | color: var(--link-color); 26 | } 27 | 28 | a:hover { 29 | color: #4F778F; 30 | } 31 | 32 | .wy-menu a { 33 | color: #b3b3b3; 34 | } 35 | 36 | .wy-menu a:hover { 37 | color: #b3b3b3; 38 | } 39 | 40 | a.icon.icon-home { 41 | color: #b3b3b3; 42 | } 43 | 44 | .version{ 45 | color: var(--link-color) !important; 46 | } 47 | 48 | 49 | /* Make code blocks have a background */ 50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { 51 | background: #f8f8f8;; 52 | } 53 | -------------------------------------------------------------------------------- /docs/_static/img/Tensorboard_example_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_1.png -------------------------------------------------------------------------------- /docs/_static/img/Tensorboard_example_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_2.png -------------------------------------------------------------------------------- /docs/_static/img/Tensorboard_example_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_3.png -------------------------------------------------------------------------------- /docs/_static/img/breakout.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/breakout.gif -------------------------------------------------------------------------------- /docs/_static/img/colab.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /docs/_static/img/learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/learning_curve.png -------------------------------------------------------------------------------- /docs/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/logo.png -------------------------------------------------------------------------------- /docs/_static/img/mistake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/mistake.png -------------------------------------------------------------------------------- /docs/_static/img/try_it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/try_it.png -------------------------------------------------------------------------------- /docs/common/cmd_utils.rst: -------------------------------------------------------------------------------- 1 | .. _cmd_utils: 2 | 3 | Command Utils 4 | ========================= 5 | 6 | .. automodule:: stable_baselines.common.cmd_util 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/common/distributions.rst: -------------------------------------------------------------------------------- 1 | .. _distributions: 2 | 3 | Probability Distributions 4 | ========================= 5 | 6 | Probability distributions used for the different action spaces: 7 | 8 | - ``CategoricalProbabilityDistribution`` -> Discrete 9 | - ``DiagGaussianProbabilityDistribution`` -> Box (continuous actions) 10 | - ``MultiCategoricalProbabilityDistribution`` -> MultiDiscrete 11 | - ``BernoulliProbabilityDistribution`` -> MultiBinary 12 | 13 | The policy networks output parameters for the distributions (named ``flat`` in the methods). 14 | Actions are then sampled from those distributions. 15 | 16 | For instance, in the case of discrete actions. The policy network outputs probability 17 | of taking each action. The ``CategoricalProbabilityDistribution`` allows to sample from it, 18 | computes the entropy, the negative log probability (``neglogp``) and backpropagate the gradient. 19 | 20 | In the case of continuous actions, a Gaussian distribution is used. The policy network outputs 21 | mean and (log) std of the distribution (assumed to be a ``DiagGaussianProbabilityDistribution``). 22 | 23 | .. automodule:: stable_baselines.common.distributions 24 | :members: 25 | -------------------------------------------------------------------------------- /docs/common/env_checker.rst: -------------------------------------------------------------------------------- 1 | .. 
_env_checker: 2 | 3 | Gym Environment Checker 4 | ======================== 5 | 6 | .. automodule:: stable_baselines.common.env_checker 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/common/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _eval: 2 | 3 | Evaluation Helper 4 | ================= 5 | 6 | .. automodule:: stable_baselines.common.evaluation 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/common/monitor.rst: -------------------------------------------------------------------------------- 1 | .. _monitor: 2 | 3 | Monitor Wrapper 4 | =============== 5 | 6 | .. automodule:: stable_baselines.bench.monitor 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/common/schedules.rst: -------------------------------------------------------------------------------- 1 | .. _schedules: 2 | 3 | Schedules 4 | ========= 5 | 6 | Schedules are used as hyperparameter for most of the algorithms, 7 | in order to change value of a parameter over time (usually the learning rate). 8 | 9 | 10 | .. automodule:: stable_baselines.common.schedules 11 | :members: 12 | -------------------------------------------------------------------------------- /docs/common/tf_utils.rst: -------------------------------------------------------------------------------- 1 | .. _tf_utils: 2 | 3 | Tensorflow Utils 4 | ========================= 5 | 6 | .. automodule:: stable_baselines.common.tf_util 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/guide/algos.rst: -------------------------------------------------------------------------------- 1 | RL Algorithms 2 | ============= 3 | 4 | This table displays the rl algorithms that are implemented in the stable baselines project, 5 | along with some useful characteristics: support for recurrent policies, discrete/continuous actions, multiprocessing. 6 | 7 | .. Table too large 8 | .. ===== ======================== ========= ======= ============ ================= =============== ================ 9 | .. Name Refactored \ :sup:`(1)`\ Recurrent ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing 10 | .. ===== ======================== ========= ======= ============ ================= =============== ================ 11 | .. A2C ✔️ 12 | .. ===== ======================== ========= ======= ============ ================= =============== ================ 13 | 14 | 15 | ============ ======================== ========= =========== ============ ================ 16 | Name Refactored [#f1]_ Recurrent ``Box`` ``Discrete`` Multi Processing 17 | ============ ======================== ========= =========== ============ ================ 18 | A2C ✔️ ✔️ ✔️ ✔️ ✔️ 19 | ACER ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️ 20 | ACKTR ✔️ ✔️ ✔️ ✔️ ✔️ 21 | DDPG ✔️ ❌ ✔️ ❌ ✔️ [#f3]_ 22 | DQN ✔️ ❌ ❌ ✔️ ❌ 23 | HER ✔️ ❌ ✔️ ✔️ ❌ 24 | GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️ ✔️ [#f3]_ 25 | PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_ 26 | PPO2 ✔️ ✔️ ✔️ ✔️ ✔️ 27 | SAC ✔️ ❌ ✔️ ❌ ❌ 28 | TD3 ✔️ ❌ ✔️ ❌ ❌ 29 | TRPO ✔️ ❌ ✔️ ✔ ✔️ [#f3]_ 30 | ============ ======================== ========= =========== ============ ================ 31 | 32 | .. [#f1] Whether or not the algorithm has be refactored to fit the ``BaseRLModel`` class. 33 | .. [#f2] Only implemented for TRPO. 34 | .. [#f3] Multi Processing with `MPI`_. 35 | .. [#f4] TODO, in project scope. 36 | 37 | .. 
note:: 38 | Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm, 39 | except HER for dict when working with ``gym.GoalEnv`` 40 | 41 | Actions ``gym.spaces``: 42 | 43 | - ``Box``: A N-dimensional box that contains every point in the action 44 | space. 45 | - ``Discrete``: A list of possible actions, where each timestep only 46 | one of the actions can be used. 47 | - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used. 48 | - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. 49 | 50 | .. _MPI: https://mpi4py.readthedocs.io/en/stable/ 51 | 52 | .. note:: 53 | 54 | Some logging values (like ``ep_rewmean``, ``eplenmean``) are only available when using a Monitor wrapper 55 | See `Issue #339 `_ for more info. 56 | 57 | 58 | Reproducibility 59 | --------------- 60 | 61 | Completely reproducible results are not guaranteed across Tensorflow releases or different platforms. 62 | Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds. 63 | 64 | In order to make computations deterministic on CPU, on your specific problem on one specific platform, 65 | you need to pass a ``seed`` argument at the creation of a model and set `n_cpu_tf_sess=1` (number of cpu for Tensorflow session). 66 | If you pass an environment to the model using `set_env()`, then you also need to seed the environment first. 67 | 68 | .. note:: 69 | 70 | Because of the current limits of Tensorflow 1.x, we cannot ensure reproducible results on the GPU yet. This issue is solved in `Stable-Baselines3 "PyTorch edition" `_ 71 | 72 | 73 | .. note:: 74 | 75 | TD3 sometimes fail to have reproducible results for obscure reasons, even when following the previous steps (cf `PR #492 `_). If you find the reason then please open an issue ;) 76 | 77 | 78 | Credit: part of the *Reproducibility* section comes from `PyTorch Documentation `_ 79 | -------------------------------------------------------------------------------- /docs/guide/custom_env.rst: -------------------------------------------------------------------------------- 1 | .. _custom_env: 2 | 3 | Using Custom Environments 4 | ========================== 5 | 6 | To use the rl baselines with custom environments, they just need to follow the *gym* interface. 7 | That is to say, your environment must implement the following methods (and inherits from OpenAI Gym Class): 8 | 9 | 10 | .. note:: 11 | If you are using images as input, the input values must be in [0, 255] as the observation 12 | is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. 13 | 14 | 15 | 16 | .. code-block:: python 17 | 18 | import gym 19 | from gym import spaces 20 | 21 | class CustomEnv(gym.Env): 22 | """Custom Environment that follows gym interface""" 23 | metadata = {'render.modes': ['human']} 24 | 25 | def __init__(self, arg1, arg2, ...): 26 | super(CustomEnv, self).__init__() 27 | # Define action and observation space 28 | # They must be gym.spaces objects 29 | # Example when using discrete actions: 30 | self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS) 31 | # Example for using image as input: 32 | self.observation_space = spaces.Box(low=0, high=255, 33 | shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8) 34 | 35 | def step(self, action): 36 | ... 37 | return observation, reward, done, info 38 | def reset(self): 39 | ... 
40 | return observation # reward, done, info can't be included 41 | def render(self, mode='human'): 42 | ... 43 | def close (self): 44 | ... 45 | 46 | 47 | Then you can define and train a RL agent with: 48 | 49 | .. code-block:: python 50 | 51 | # Instantiate the env 52 | env = CustomEnv(arg1, ...) 53 | # Define and Train the agent 54 | model = A2C('CnnPolicy', env).learn(total_timesteps=1000) 55 | 56 | 57 | To check that your environment follows the gym interface, please use: 58 | 59 | .. code-block:: python 60 | 61 | from stable_baselines.common.env_checker import check_env 62 | 63 | env = CustomEnv(arg1, ...) 64 | # It will check your custom environment and output additional warnings if needed 65 | check_env(env) 66 | 67 | 68 | 69 | We have created a `colab notebook `_ for 70 | a concrete example of creating a custom environment. 71 | 72 | You can also find a `complete guide online `_ 73 | on creating a custom Gym environment. 74 | 75 | 76 | Optionally, you can also register the environment with gym, 77 | that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env). 78 | 79 | 80 | In the project, for testing purposes, we use a custom environment named ``IdentityEnv`` 81 | defined `in this file `_. 82 | An example of how to use it can be found `here `_. 83 | -------------------------------------------------------------------------------- /docs/guide/quickstart.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | =============== 4 | Getting Started 5 | =============== 6 | 7 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. 8 | 9 | Here is a quick example of how to train and run PPO2 on a cartpole environment: 10 | 11 | .. code-block:: python 12 | 13 | import gym 14 | 15 | from stable_baselines.common.policies import MlpPolicy 16 | from stable_baselines.common.vec_env import DummyVecEnv 17 | from stable_baselines import PPO2 18 | 19 | env = gym.make('CartPole-v1') 20 | # Optional: PPO2 requires a vectorized environment to run 21 | # the env is now wrapped automatically when passing it to the constructor 22 | # env = DummyVecEnv([lambda: env]) 23 | 24 | model = PPO2(MlpPolicy, env, verbose=1) 25 | model.learn(total_timesteps=10000) 26 | 27 | obs = env.reset() 28 | for i in range(1000): 29 | action, _states = model.predict(obs) 30 | obs, rewards, dones, info = env.step(action) 31 | env.render() 32 | 33 | 34 | Or just train a model with a one liner if 35 | `the environment is registered in Gym `_ and if 36 | `the policy is registered `_: 37 | 38 | .. code-block:: python 39 | 40 | from stable_baselines import PPO2 41 | 42 | model = PPO2('MlpPolicy', 'CartPole-v1').learn(10000) 43 | 44 | 45 | .. figure:: https://cdn-images-1.medium.com/max/960/1*R_VMmdgKAY0EDhEjHVelzw.gif 46 | 47 | Define and train a RL agent in one line of code! 48 | -------------------------------------------------------------------------------- /docs/guide/rl.rst: -------------------------------------------------------------------------------- 1 | .. _rl: 2 | 3 | ================================ 4 | Reinforcement Learning Resources 5 | ================================ 6 | 7 | 8 | Stable-Baselines assumes that you already understand the basic concepts of Reinforcement Learning (RL). 
9 | 10 | However, if you want to learn about RL, there are several good resources to get started: 11 | 12 | - `OpenAI Spinning Up `_ 13 | - `David Silver's course `_ 14 | - `Lilian Weng's blog `_ 15 | - `Berkeley's Deep RL Bootcamp `_ 16 | - `Berkeley's Deep Reinforcement Learning course `_ 17 | - `More resources `_ 18 | -------------------------------------------------------------------------------- /docs/guide/rl_zoo.rst: -------------------------------------------------------------------------------- 1 | .. _rl_zoo: 2 | 3 | ================= 4 | RL Baselines Zoo 5 | ================= 6 | 7 | `RL Baselines Zoo `_. is a collection of pre-trained Reinforcement Learning agents using 8 | Stable-Baselines. 9 | It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos. 10 | 11 | Goals of this repository: 12 | 13 | 1. Provide a simple interface to train and enjoy RL agents 14 | 2. Benchmark the different Reinforcement Learning algorithms 15 | 3. Provide tuned hyperparameters for each environment and RL algorithm 16 | 4. Have fun with the trained agents! 17 | 18 | Installation 19 | ------------ 20 | 21 | 1. Install dependencies 22 | :: 23 | 24 | apt-get install swig cmake libopenmpi-dev zlib1g-dev ffmpeg 25 | pip install stable-baselines box2d box2d-kengz pyyaml pybullet optuna pytablewriter 26 | 27 | 2. Clone the repository: 28 | 29 | :: 30 | 31 | git clone https://github.com/araffin/rl-baselines-zoo 32 | 33 | 34 | Train an Agent 35 | -------------- 36 | 37 | The hyperparameters for each environment are defined in 38 | ``hyperparameters/algo_name.yml``. 39 | 40 | If the environment exists in this file, then you can train an agent 41 | using: 42 | 43 | :: 44 | 45 | python train.py --algo algo_name --env env_id 46 | 47 | For example (with tensorboard support): 48 | 49 | :: 50 | 51 | python train.py --algo ppo2 --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/ 52 | 53 | Train for multiple environments (with one call) and with tensorboard 54 | logging: 55 | 56 | :: 57 | 58 | python train.py --algo a2c --env MountainCar-v0 CartPole-v1 --tensorboard-log /tmp/stable-baselines/ 59 | 60 | Continue training (here, load pretrained agent for Breakout and continue 61 | training for 5000 steps): 62 | 63 | :: 64 | 65 | python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4.pkl -n 5000 66 | 67 | 68 | Enjoy a Trained Agent 69 | --------------------- 70 | 71 | If the trained agent exists, then you can see it in action using: 72 | 73 | :: 74 | 75 | python enjoy.py --algo algo_name --env env_id 76 | 77 | For example, enjoy A2C on Breakout during 5000 timesteps: 78 | 79 | :: 80 | 81 | python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder trained_agents/ -n 5000 82 | 83 | 84 | Hyperparameter Optimization 85 | --------------------------- 86 | 87 | We use `Optuna `_ for optimizing the hyperparameters. 88 | 89 | 90 | Tune the hyperparameters for PPO2, using a random sampler and median pruner, 2 parallels jobs, 91 | with a budget of 1000 trials and a maximum of 50000 steps: 92 | 93 | :: 94 | 95 | python train.py --algo ppo2 --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \ 96 | --sampler random --pruner median 97 | 98 | 99 | Colab Notebook: Try it Online! 100 | ------------------------------ 101 | 102 | You can train agents online using Google `colab notebook `_. 103 | 104 | 105 | .. note:: 106 | 107 | You can find more information about the rl baselines zoo in the repo `README `_. 
For instance, how to record a video of a trained agent. 108 | -------------------------------------------------------------------------------- /docs/guide/save_format.rst: -------------------------------------------------------------------------------- 1 | .. _save_format: 2 | 3 | 4 | On saving and loading 5 | ===================== 6 | 7 | Stable baselines stores both neural network parameters and algorithm-related parameters such as 8 | exploration schedule, number of environments and observation/action space. This allows continual learning and easy 9 | use of trained agents without training, but it is not without its issues. Following describes two formats 10 | used to save agents in stable baselines, their pros and shortcomings. 11 | 12 | Terminology used in this page: 13 | 14 | - *parameters* refer to neural network parameters (also called "weights"). This is a dictionary 15 | mapping Tensorflow variable name to a NumPy array. 16 | - *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space. 17 | These depend on the algorithm used. This is a dictionary mapping classes variable names their values. 18 | 19 | 20 | Cloudpickle (stable-baselines<=2.7.0) 21 | ------------------------------------- 22 | 23 | Original stable baselines save format. Data and parameters are bundled up into a tuple ``(data, parameters)`` 24 | and then serialized with ``cloudpickle`` library (essentially the same as ``pickle``). 25 | 26 | This save format is still available via an argument in model save function in stable-baselines versions above 27 | v2.7.0 for backwards compatibility reasons, but its usage is discouraged. 28 | 29 | Pros: 30 | 31 | - Easy to implement and use. 32 | - Works with almost any type of Python object, including functions. 33 | 34 | 35 | Cons: 36 | 37 | - Pickle/Cloudpickle is not designed for long-term storage or sharing between Python version. 38 | - If one object in file is not readable (e.g. wrong library version), then reading the rest of the 39 | file is difficult. 40 | - Python-specific format, hard to read stored files from other languages. 41 | 42 | 43 | If part of a saved model becomes unreadable for any reason (e.g. different Tensorflow versions), then 44 | it may be tricky to restore any of the model. For this reason another save format was designed. 45 | 46 | 47 | Zip-archive (stable-baselines>2.7.0) 48 | ------------------------------------- 49 | 50 | A zip-archived JSON dump and NumPy zip archive of the arrays. The data dictionary (class parameters) 51 | is stored as a JSON file, model parameters are serialized with ``numpy.savez`` function and these two files 52 | are stored under a single .zip archive. 53 | 54 | Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded 55 | string in the JSON file, along with some information that was stored in the serialization. This allows 56 | inspecting stored objects without deserializing the object itself. 57 | 58 | This format allows skipping elements in the file, i.e. we can skip deserializing objects that are 59 | broken/non-serializable. This can be done via ``custom_objects`` argument to load functions. 60 | 61 | This is the default save format in stable baselines versions after v2.7.0. 
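For instance, a stored entry that cannot (or should not) be deserialized can be overridden at load time.
The snippet below is only a sketch: the archive name ``ppo2_cartpole.zip`` and the replacement value are placeholders.

.. code-block:: python

    from stable_baselines import PPO2

    # Keys listed in ``custom_objects`` are not deserialized from the archive;
    # the provided value is used instead (useful when an object is broken or unpicklable).
    model = PPO2.load("ppo2_cartpole.zip", custom_objects={"learning_rate": 2.5e-4})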
62 | 63 | File structure: 64 | 65 | :: 66 | 67 | saved_model.zip/ 68 | ├── data JSON file of class-parameters (dictionary) 69 | ├── parameter_list JSON file of model parameters and their ordering (list) 70 | ├── parameters Bytes from numpy.savez (a zip file of the numpy arrays). ... 71 | ├── ... Being a zip-archive itself, this object can also be opened ... 72 | ├── ... as a zip-archive and browsed. 73 | 74 | 75 | Pros: 76 | 77 | 78 | - More robust to unserializable objects (one bad object does not break everything). 79 | - Saved file can be inspected/extracted with zip-archive explorers and by other 80 | languages. 81 | 82 | 83 | Cons: 84 | 85 | - More complex implementation. 86 | - Still relies partly on cloudpickle for complex objects (e.g. custom functions). 87 | -------------------------------------------------------------------------------- /docs/guide/vec_envs.rst: -------------------------------------------------------------------------------- 1 | .. _vec_env: 2 | 3 | .. automodule:: stable_baselines.common.vec_env 4 | 5 | Vectorized Environments 6 | ======================= 7 | 8 | Vectorized Environments are a method for stacking multiple independent environments into a single environment. 9 | Instead of training an RL agent on 1 environment per step, it allows us to train it on ``n`` environments per step. 10 | Because of this, ``actions`` passed to the environment are now a vector (of dimension ``n``). 11 | It is the same for ``observations``, ``rewards`` and end of episode signals (``dones``). 12 | In the case of non-array observation spaces such as ``Dict`` or ``Tuple``, where different sub-spaces 13 | may have different shapes, the sub-observations are vectors (of dimension ``n``). 14 | 15 | ============= ======= ============ ======== ========= ================ 16 | Name ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing 17 | ============= ======= ============ ======== ========= ================ 18 | DummyVecEnv ✔️ ✔️ ✔️ ✔️ ❌️ 19 | SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ 20 | ============= ======= ============ ======== ========= ================ 21 | 22 | .. note:: 23 | 24 | Vectorized environments are required when using wrappers for frame-stacking or normalization. 25 | 26 | .. note:: 27 | 28 | When using vectorized environments, the environments are automatically reset at the end of each episode. 29 | Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated. 30 | You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv. 31 | 32 | .. warning:: 33 | 34 | When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows). 35 | On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks. 36 | 37 | For more information, see Python's `multiprocessing guidelines `_. 38 | 39 | VecEnv 40 | ------ 41 | 42 | .. autoclass:: VecEnv 43 | :members: 44 | 45 | DummyVecEnv 46 | ----------- 47 | 48 | .. autoclass:: DummyVecEnv 49 | :members: 50 | 51 | SubprocVecEnv 52 | ------------- 53 | 54 | .. autoclass:: SubprocVecEnv 55 | :members: 56 | 57 | Wrappers 58 | -------- 59 | 60 | VecFrameStack 61 | ~~~~~~~~~~~~~ 62 | 63 | .. 
autoclass:: VecFrameStack 64 | :members: 65 | 66 | 67 | VecNormalize 68 | ~~~~~~~~~~~~ 69 | 70 | .. autoclass:: VecNormalize 71 | :members: 72 | 73 | 74 | VecVideoRecorder 75 | ~~~~~~~~~~~~~~~~ 76 | 77 | .. autoclass:: VecVideoRecorder 78 | :members: 79 | 80 | 81 | VecCheckNan 82 | ~~~~~~~~~~~~~~~~ 83 | 84 | .. autoclass:: VecCheckNan 85 | :members: 86 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Stable Baselines documentation master file, created by 2 | sphinx-quickstart on Sat Aug 25 10:33:54 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Stable Baselines docs! - RL Baselines Made Easy 7 | =========================================================== 8 | 9 | `Stable Baselines `_ is a set of improved implementations 10 | of Reinforcement Learning (RL) algorithms based on OpenAI `Baselines `_. 11 | 12 | 13 | Github repository: https://github.com/hill-a/stable-baselines 14 | 15 | RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo 16 | 17 | RL Baselines zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning. 18 | 19 | You can read a detailed presentation of Stable Baselines in the 20 | Medium article: `link `_ 21 | 22 | 23 | Main differences with OpenAI Baselines 24 | -------------------------------------- 25 | 26 | This toolset is a fork of OpenAI Baselines, with a major structural refactoring, and code cleanups: 27 | 28 | - Unified structure for all algorithms 29 | - PEP8 compliant (unified code style) 30 | - Documented functions and classes 31 | - More tests & more code coverage 32 | - Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3) 33 | 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | :caption: User Guide 38 | 39 | guide/install 40 | guide/quickstart 41 | guide/rl_tips 42 | guide/rl 43 | guide/algos 44 | guide/examples 45 | guide/vec_envs 46 | guide/custom_env 47 | guide/custom_policy 48 | guide/callbacks 49 | guide/tensorboard 50 | guide/rl_zoo 51 | guide/pretrain 52 | guide/checking_nan 53 | guide/save_format 54 | guide/export 55 | 56 | 57 | .. toctree:: 58 | :maxdepth: 1 59 | :caption: RL Algorithms 60 | 61 | modules/base 62 | modules/policies 63 | modules/a2c 64 | modules/acer 65 | modules/acktr 66 | modules/ddpg 67 | modules/dqn 68 | modules/gail 69 | modules/her 70 | modules/ppo1 71 | modules/ppo2 72 | modules/sac 73 | modules/td3 74 | modules/trpo 75 | 76 | .. toctree:: 77 | :maxdepth: 1 78 | :caption: Common 79 | 80 | common/distributions 81 | common/tf_utils 82 | common/cmd_utils 83 | common/schedules 84 | common/evaluation 85 | common/env_checker 86 | common/monitor 87 | 88 | .. toctree:: 89 | :maxdepth: 1 90 | :caption: Misc 91 | 92 | misc/changelog 93 | misc/projects 94 | misc/results_plotter 95 | 96 | 97 | Citing Stable Baselines 98 | ----------------------- 99 | To cite this project in publications: 100 | 101 | .. 
code-block:: bibtex 102 | 103 | @misc{stable-baselines, 104 | author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, 105 | title = {Stable Baselines}, 106 | year = {2018}, 107 | publisher = {GitHub}, 108 | journal = {GitHub repository}, 109 | howpublished = {\url{https://github.com/hill-a/stable-baselines}}, 110 | } 111 | 112 | Contributing 113 | ------------ 114 | 115 | To any interested in making the rl baselines better, there are still some improvements 116 | that need to be done. 117 | A full TODO list is available in the `roadmap `_. 118 | 119 | If you want to contribute, please read `CONTRIBUTING.md `_ first. 120 | 121 | Indices and tables 122 | ------------------- 123 | 124 | * :ref:`genindex` 125 | * :ref:`search` 126 | * :ref:`modindex` 127 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=StableBaselines 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/misc/results_plotter.rst: -------------------------------------------------------------------------------- 1 | .. _results_plotter: 2 | 3 | 4 | Plotting Results 5 | ================ 6 | 7 | .. automodule:: stable_baselines.results_plotter 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/modules/base.rst: -------------------------------------------------------------------------------- 1 | .. _base_algo: 2 | 3 | .. automodule:: stable_baselines.common.base_class 4 | 5 | 6 | Base RL Class 7 | ============= 8 | 9 | Common interface for all the RL algorithms 10 | 11 | .. autoclass:: BaseRLModel 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/modules/gail.rst: -------------------------------------------------------------------------------- 1 | .. _gail: 2 | 3 | .. automodule:: stable_baselines.gail 4 | 5 | 6 | GAIL 7 | ==== 8 | 9 | The `Generative Adversarial Imitation Learning (GAIL) `_ uses expert trajectories 10 | to recover a cost function and then learn a policy. 11 | 12 | Learning a cost function from expert demonstrations is called Inverse Reinforcement Learning (IRL). 
13 | The connection between GAIL and Generative Adversarial Networks (GANs) is that it uses a discriminator that tries 14 | to separate expert trajectory from trajectories of the learned policy, which has the role of the generator here. 15 | 16 | .. note:: 17 | 18 | GAIL requires :ref:`OpenMPI `. If OpenMPI isn't enabled, then GAIL isn't 19 | imported into the ``stable_baselines`` module. 20 | 21 | 22 | Notes 23 | ----- 24 | 25 | - Original paper: https://arxiv.org/abs/1606.03476 26 | 27 | .. warning:: 28 | 29 | Images are not yet handled properly by the current implementation 30 | 31 | 32 | 33 | If you want to train an imitation learning agent 34 | ------------------------------------------------ 35 | 36 | 37 | Step 1: Generate expert data 38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | 40 | You can either train a RL algorithm in a classic setting, use another controller (e.g. a PID controller) 41 | or human demonstrations. 42 | 43 | We recommend you to take a look at :ref:`pre-training ` section 44 | or directly look at ``stable_baselines/gail/dataset/`` folder to learn more about the expected format for the dataset. 45 | 46 | Here is an example of training a Soft Actor-Critic model to generate expert trajectories for GAIL: 47 | 48 | 49 | .. code-block:: python 50 | 51 | from stable_baselines import SAC 52 | from stable_baselines.gail import generate_expert_traj 53 | 54 | # Generate expert trajectories (train expert) 55 | model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1) 56 | # Train for 60000 timesteps and record 10 trajectories 57 | # all the data will be saved in 'expert_pendulum.npz' file 58 | generate_expert_traj(model, 'expert_pendulum', n_timesteps=60000, n_episodes=10) 59 | 60 | 61 | 62 | Step 2: Run GAIL 63 | ~~~~~~~~~~~~~~~~ 64 | 65 | 66 | **In case you want to run Behavior Cloning (BC)** 67 | 68 | Use the ``.pretrain()`` method (cf guide). 69 | 70 | 71 | **Others** 72 | 73 | Thanks to the open source: 74 | 75 | - @openai/imitation 76 | - @carpedm20/deep-rl-tensorflow 77 | 78 | 79 | Can I use? 80 | ---------- 81 | 82 | - Recurrent policies: ❌ 83 | - Multi processing: ✔️ (using MPI) 84 | - Gym spaces: 85 | 86 | 87 | ============= ====== =========== 88 | Space Action Observation 89 | ============= ====== =========== 90 | Discrete ✔️ ✔️ 91 | Box ✔️ ✔️ 92 | MultiDiscrete ❌ ✔️ 93 | MultiBinary ❌ ✔️ 94 | ============= ====== =========== 95 | 96 | 97 | Example 98 | ------- 99 | 100 | .. code-block:: python 101 | 102 | import gym 103 | 104 | from stable_baselines import GAIL, SAC 105 | from stable_baselines.gail import ExpertDataset, generate_expert_traj 106 | 107 | # Generate expert trajectories (train expert) 108 | model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1) 109 | generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10) 110 | 111 | # Load the expert dataset 112 | dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1) 113 | 114 | model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1) 115 | # Note: in practice, you need to train for 1M steps to have a working policy 116 | model.learn(total_timesteps=1000) 117 | model.save("gail_pendulum") 118 | 119 | del model # remove to demonstrate saving and loading 120 | 121 | model = GAIL.load("gail_pendulum") 122 | 123 | env = gym.make('Pendulum-v0') 124 | obs = env.reset() 125 | while True: 126 | action, _states = model.predict(obs) 127 | obs, rewards, dones, info = env.step(action) 128 | env.render() 129 | 130 | 131 | Parameters 132 | ---------- 133 | 134 | .. 
autoclass:: GAIL 135 | :members: 136 | :inherited-members: 137 | -------------------------------------------------------------------------------- /docs/modules/her.rst: -------------------------------------------------------------------------------- 1 | .. _her: 2 | 3 | .. automodule:: stable_baselines.her 4 | 5 | 6 | HER 7 | ==== 8 | 9 | `Hindsight Experience Replay (HER) `_ 10 | 11 | HER is a method wrapper that works with Off policy methods (DQN, SAC, TD3 and DDPG for example). 12 | 13 | .. note:: 14 | 15 | HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI baselines. 16 | If you want to reproduce results from the paper, please use the rl baselines zoo 17 | in order to have the correct hyperparameters and at least 8 MPI workers with DDPG. 18 | 19 | .. warning:: 20 | 21 | HER requires the environment to inherits from `gym.GoalEnv `_ 22 | 23 | 24 | .. warning:: 25 | 26 | you must pass an environment or wrap it with ``HERGoalEnvWrapper`` in order to use the predict method 27 | 28 | 29 | Notes 30 | ----- 31 | 32 | - Original paper: https://arxiv.org/abs/1707.01495 33 | - OpenAI paper: `Plappert et al. (2018)`_ 34 | - OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/ 35 | 36 | 37 | .. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464 38 | 39 | Can I use? 40 | ---------- 41 | 42 | Please refer to the wrapped model (DQN, SAC, TD3 or DDPG) for that section. 43 | 44 | Example 45 | ------- 46 | 47 | .. code-block:: python 48 | 49 | from stable_baselines import HER, DQN, SAC, DDPG, TD3 50 | from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper 51 | from stable_baselines.common.bit_flipping_env import BitFlippingEnv 52 | 53 | model_class = DQN # works also with SAC, DDPG and TD3 54 | 55 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) 56 | 57 | # Available strategies (cf paper): future, final, episode, random 58 | goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE 59 | 60 | # Wrap the model 61 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, 62 | verbose=1) 63 | # Train the model 64 | model.learn(1000) 65 | 66 | model.save("./her_bit_env") 67 | 68 | # WARNING: you must pass an env 69 | # or wrap your environment with HERGoalEnvWrapper to use the predict method 70 | model = HER.load('./her_bit_env', env=env) 71 | 72 | obs = env.reset() 73 | for _ in range(100): 74 | action, _ = model.predict(obs) 75 | obs, reward, done, _ = env.step(action) 76 | 77 | if done: 78 | obs = env.reset() 79 | 80 | 81 | Parameters 82 | ---------- 83 | 84 | .. autoclass:: HER 85 | :members: 86 | 87 | Goal Selection Strategies 88 | ------------------------- 89 | 90 | .. autoclass:: GoalSelectionStrategy 91 | :members: 92 | :inherited-members: 93 | :undoc-members: 94 | 95 | 96 | Goal Env Wrapper 97 | ---------------- 98 | 99 | .. autoclass:: HERGoalEnvWrapper 100 | :members: 101 | :inherited-members: 102 | :undoc-members: 103 | 104 | 105 | Replay Wrapper 106 | -------------- 107 | 108 | .. autoclass:: HindsightExperienceReplayWrapper 109 | :members: 110 | :inherited-members: 111 | -------------------------------------------------------------------------------- /docs/modules/policies.rst: -------------------------------------------------------------------------------- 1 | .. _policies: 2 | 3 | .. 
automodule:: stable_baselines.common.policies 4 | 5 | Policy Networks 6 | =============== 7 | 8 | Stable-baselines provides a set of default policies, that can be used with most action spaces. 9 | To customize the default policies, you can specify the ``policy_kwargs`` parameter to the model class you use. 10 | Those kwargs are then passed to the policy on instantiation (see :ref:`custom_policy` for an example). 11 | If you need more control on the policy architecture, you can also create a custom policy (see :ref:`custom_policy`). 12 | 13 | .. note:: 14 | 15 | CnnPolicies are for images only. MlpPolicies are made for other type of features (e.g. robot joints) 16 | 17 | .. warning:: 18 | For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing 19 | (to avoid out of bound error). 20 | 21 | 22 | .. rubric:: Available Policies 23 | 24 | .. autosummary:: 25 | :nosignatures: 26 | 27 | MlpPolicy 28 | MlpLstmPolicy 29 | MlpLnLstmPolicy 30 | CnnPolicy 31 | CnnLstmPolicy 32 | CnnLnLstmPolicy 33 | 34 | 35 | Base Classes 36 | ------------ 37 | 38 | .. autoclass:: BasePolicy 39 | :members: 40 | 41 | .. autoclass:: ActorCriticPolicy 42 | :members: 43 | 44 | .. autoclass:: FeedForwardPolicy 45 | :members: 46 | 47 | .. autoclass:: LstmPolicy 48 | :members: 49 | 50 | MLP Policies 51 | ------------ 52 | 53 | .. autoclass:: MlpPolicy 54 | :members: 55 | 56 | .. autoclass:: MlpLstmPolicy 57 | :members: 58 | 59 | .. autoclass:: MlpLnLstmPolicy 60 | :members: 61 | 62 | 63 | CNN Policies 64 | ------------ 65 | 66 | .. autoclass:: CnnPolicy 67 | :members: 68 | 69 | .. autoclass:: CnnLstmPolicy 70 | :members: 71 | 72 | .. autoclass:: CnnLnLstmPolicy 73 | :members: 74 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | pandas 3 | -------------------------------------------------------------------------------- /docs/spelling_wordlist.txt: -------------------------------------------------------------------------------- 1 | py 2 | env 3 | atari 4 | argparse 5 | Argparse 6 | TensorFlow 7 | feedforward 8 | envs 9 | VecEnv 10 | pretrain 11 | petrained 12 | tf 13 | np 14 | mujoco 15 | cpu 16 | ndarray 17 | ndarrays 18 | timestep 19 | timesteps 20 | stepsize 21 | dataset 22 | adam 23 | fn 24 | normalisation 25 | Kullback 26 | Leibler 27 | boolean 28 | deserialized 29 | pretrained 30 | minibatch 31 | subprocesses 32 | ArgumentParser 33 | Tensorflow 34 | Gaussian 35 | approximator 36 | minibatches 37 | hyperparameters 38 | hyperparameter 39 | vectorized 40 | rl 41 | colab 42 | dataloader 43 | npz 44 | datasets 45 | vf 46 | logits 47 | num 48 | Utils 49 | backpropagate 50 | prepend 51 | NaN 52 | preprocessing 53 | Cloudpickle 54 | async 55 | multiprocess 56 | tensorflow 57 | mlp 58 | cnn 59 | neglogp 60 | tanh 61 | coef 62 | repo 63 | Huber 64 | params 65 | ppo 66 | arxiv 67 | Arxiv 68 | func 69 | DQN 70 | Uhlenbeck 71 | Ornstein 72 | multithread 73 | cancelled 74 | Tensorboard 75 | parallelize 76 | customising 77 | serializable 78 | Multiprocessed 79 | cartpole 80 | toolset 81 | lstm 82 | rescale 83 | ffmpeg 84 | avconv 85 | unnormalized 86 | Github 87 | pre 88 | preprocess 89 | backend 90 | attr 91 | preprocess 92 | Antonin 93 | Raffin 94 | araffin 95 | Homebrew 96 | Numpy 97 | Theano 98 | rollout 99 | kfac 100 | Piecewise 101 | csv 102 | nvidia 103 | visdom 104 | tensorboard 105 | preprocessed 106 | namespace 107 | sklearn 108 | 
GoalEnv 109 | BaseCallback 110 | Keras 111 | -------------------------------------------------------------------------------- /scripts/build_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CPU_PARENT=ubuntu:16.04 4 | GPU_PARENT=nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04 5 | 6 | TAG=stablebaselines/stable-baselines 7 | VERSION=$(cat ./stable_baselines/version.txt) 8 | 9 | if [[ ${USE_GPU} == "True" ]]; then 10 | PARENT=${GPU_PARENT} 11 | else 12 | PARENT=${CPU_PARENT} 13 | TAG="${TAG}-cpu" 14 | fi 15 | 16 | docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg USE_GPU=${USE_GPU} -t ${TAG}:${VERSION} . 17 | docker tag ${TAG}:${VERSION} ${TAG}:latest 18 | 19 | if [[ ${RELEASE} == "True" ]]; then 20 | docker push ${TAG}:${VERSION} 21 | docker push ${TAG}:latest 22 | fi 23 | -------------------------------------------------------------------------------- /scripts/run_docker_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Launch an experiment using the docker cpu image 3 | 4 | cmd_line="$@" 5 | 6 | echo "Executing in the docker (cpu image):" 7 | echo $cmd_line 8 | 9 | docker run -it --rm --network host --ipc=host \ 10 | --mount src=$(pwd),target=/root/code/stable-baselines,type=bind stablebaselines/stable-baselines-cpu:v2.10.0 \ 11 | bash -c "cd /root/code/stable-baselines/ && $cmd_line" 12 | -------------------------------------------------------------------------------- /scripts/run_docker_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Launch an experiment using the docker gpu image 3 | 4 | cmd_line="$@" 5 | 6 | echo "Executing in the docker (gpu image):" 7 | echo $cmd_line 8 | 9 | # TODO: always use new-style once sufficiently widely used (probably 2021 onwards) 10 | if [ -x "$(which nvidia-docker)" ]; then 11 | # old-style nvidia-docker2 12 | NVIDIA_ARG="--runtime=nvidia" 13 | else 14 | NVIDIA_ARG="--gpus all" 15 | fi 16 | 17 | docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ 18 | --mount src=$(pwd),target=/root/code/stable-baselines,type=bind stablebaselines/stable-baselines:v2.10.0 \ 19 | bash -c "cd /root/code/stable-baselines/ && $cmd_line" 20 | -------------------------------------------------------------------------------- /scripts/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v 3 | -------------------------------------------------------------------------------- /scripts/run_tests_travis.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DOCKER_CMD="docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines,type=bind" 4 | BASH_CMD="cd /root/code/stable-baselines/" 5 | 6 | if [[ $# -ne 1 ]]; then 7 | echo "usage: $0 " 8 | exit 1 9 | fi 10 | 11 | if [[ ${DOCKER_IMAGE} = "" ]]; then 12 | echo "Need DOCKER_IMAGE environment variable to be set." 13 | exit 1 14 | fi 15 | 16 | TEST_GLOB=$1 17 | 18 | set -e # exit immediately on any error 19 | 20 | # For pull requests from fork, Codacy token is not available, leading to build failure 21 | if [[ ${CODACY_PROJECT_TOKEN} = "" ]]; then 22 | echo "WARNING: CODACY_PROJECT_TOKEN not set. Skipping Codacy upload." 
23 | echo "(This is normal when building in a fork and can be ignored.)" 24 | ${DOCKER_CMD} ${DOCKER_IMAGE} \ 25 | bash -c "${BASH_CMD} && \ 26 | pytest --cov-config .coveragerc --cov-report term --cov=. -v tests/test_${TEST_GLOB}" 27 | else 28 | ${DOCKER_CMD} --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} \ 29 | bash -c "${BASH_CMD} && \ 30 | pytest --cov-config .coveragerc --cov-report term --cov-report xml --cov=. -v tests/test_${TEST_GLOB} && \ 31 | java -jar /root/code/codacy-coverage-reporter.jar report -l python -r coverage.xml --partial" 32 | fi 33 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file in the wheel. 3 | license_file = LICENSE 4 | 5 | [tool:pytest] 6 | # Deterministic ordering for tests; useful for pytest-xdist. 7 | env = 8 | PYTHONHASHSEED=0 9 | filterwarnings = 10 | ignore:inspect.getargspec:DeprecationWarning:tensorflow 11 | ignore::pytest.PytestUnknownMarkWarning 12 | # Tensorflow internal warnings 13 | ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning 14 | ignore:The binary mode of fromstring is deprecated:DeprecationWarning 15 | ignore::FutureWarning:tensorflow 16 | # Gym warnings 17 | ignore:Parameters to load are deprecated.:DeprecationWarning 18 | ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning 19 | 20 | [pytype] 21 | inputs = stable_baselines 22 | ; python_version = 3.5 23 | -------------------------------------------------------------------------------- /stable_baselines/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | from stable_baselines.a2c import A2C 5 | from stable_baselines.acer import ACER 6 | from stable_baselines.acktr import ACKTR 7 | from stable_baselines.deepq import DQN 8 | from stable_baselines.her import HER 9 | from stable_baselines.ppo2 import PPO2 10 | from stable_baselines.td3 import TD3 11 | from stable_baselines.sac import SAC 12 | 13 | # Load mpi4py-dependent algorithms only if mpi is installed. 14 | try: 15 | import mpi4py 16 | except ImportError: 17 | mpi4py = None 18 | 19 | if mpi4py is not None: 20 | from stable_baselines.ddpg import DDPG 21 | from stable_baselines.gail import GAIL 22 | from stable_baselines.ppo1 import PPO1 23 | from stable_baselines.trpo_mpi import TRPO 24 | del mpi4py 25 | 26 | # Read version from file 27 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 28 | with open(version_file, "r") as file_handler: 29 | __version__ = file_handler.read().strip() 30 | 31 | 32 | warnings.warn( 33 | "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation." 
34 | ) 35 | -------------------------------------------------------------------------------- /stable_baselines/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.a2c.a2c import A2C 2 | -------------------------------------------------------------------------------- /stable_baselines/a2c/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from stable_baselines import logger, A2C 4 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser 5 | from stable_baselines.common.vec_env import VecFrameStack 6 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy 7 | 8 | 9 | def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env): 10 | """ 11 | Train A2C model for atari environment, for testing purposes 12 | 13 | :param env_id: (str) Environment ID 14 | :param num_timesteps: (int) The total number of samples 15 | :param seed: (int) The initial seed for training 16 | :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...) 17 | :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant', 18 | 'double_linear_con', 'middle_drop' or 'double_middle_drop') 19 | :param num_env: (int) The number of environments 20 | """ 21 | policy_fn = None 22 | if policy == 'cnn': 23 | policy_fn = CnnPolicy 24 | elif policy == 'lstm': 25 | policy_fn = CnnLstmPolicy 26 | elif policy == 'lnlstm': 27 | policy_fn = CnnLnLstmPolicy 28 | if policy_fn is None: 29 | raise ValueError("Error: policy {} not implemented".format(policy)) 30 | 31 | env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4) 32 | 33 | model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed) 34 | model.learn(total_timesteps=int(num_timesteps * 1.1)) 35 | env.close() 36 | 37 | 38 | def main(): 39 | """ 40 | Runs the test 41 | """ 42 | parser = atari_arg_parser() 43 | parser.add_argument('--policy', choices=['cnn', 'lstm', 'lnlstm'], default='cnn', help='Policy architecture') 44 | parser.add_argument('--lr_schedule', choices=['constant', 'linear'], default='constant', 45 | help='Learning rate schedule') 46 | args = parser.parse_args() 47 | logger.configure() 48 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, policy=args.policy, lr_schedule=args.lr_schedule, 49 | num_env=16) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /stable_baselines/acer/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.acer.acer_simple import ACER 2 | -------------------------------------------------------------------------------- /stable_baselines/acer/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import warnings 3 | 4 | from stable_baselines import logger, ACER 5 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy 6 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser 7 | from stable_baselines.common.vec_env import VecFrameStack 8 | 9 | 10 | def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu): 11 | """ 12 | train an ACER model on atari 13 | 14 | :param env_id: (str) Environment ID 15 | :param num_timesteps: (int) The total number of samples 16 | :param seed: (int) The initial seed 
for training 17 | :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...) 18 | :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant', 19 | 'double_linear_con', 'middle_drop' or 'double_middle_drop') 20 | :param num_cpu: (int) The number of cpu to train on 21 | """ 22 | env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) 23 | if policy == 'cnn': 24 | policy_fn = CnnPolicy 25 | elif policy == 'lstm': 26 | policy_fn = CnnLstmPolicy 27 | else: 28 | warnings.warn("Policy {} not implemented".format(policy)) 29 | return 30 | 31 | model = ACER(policy_fn, env, lr_schedule=lr_schedule, buffer_size=5000, seed=seed) 32 | model.learn(total_timesteps=int(num_timesteps * 1.1)) 33 | env.close() 34 | # Free memory 35 | del model 36 | 37 | 38 | def main(): 39 | """ 40 | Runs the test 41 | """ 42 | parser = atari_arg_parser() 43 | parser.add_argument('--policy', choices=['cnn', 'lstm', 'lnlstm'], default='cnn', help='Policy architecture') 44 | parser.add_argument('--lr_schedule', choices=['constant', 'linear'], default='constant', 45 | help='Learning rate schedule') 46 | parser.add_argument('--logdir', help='Directory for logging') 47 | args = parser.parse_args() 48 | logger.configure(args.logdir) 49 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, 50 | policy=args.policy, lr_schedule=args.lr_schedule, num_cpu=16) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /stable_baselines/acktr/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.acktr.acktr import ACKTR 2 | -------------------------------------------------------------------------------- /stable_baselines/acktr/run_atari.py: -------------------------------------------------------------------------------- 1 | from stable_baselines import logger, ACKTR 2 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser 3 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack 4 | from stable_baselines.common.policies import CnnPolicy 5 | 6 | 7 | def train(env_id, num_timesteps, seed, num_cpu): 8 | """ 9 | train an ACKTR model on atari 10 | 11 | :param env_id: (str) Environment ID 12 | :param num_timesteps: (int) The total number of samples 13 | :param seed: (int) The initial seed for training 14 | :param num_cpu: (int) The number of cpu to train on 15 | """ 16 | env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) 17 | model = ACKTR(CnnPolicy, env, nprocs=num_cpu, seed=seed) 18 | model.learn(total_timesteps=int(num_timesteps * 1.1)) 19 | env.close() 20 | 21 | 22 | def main(): 23 | """ 24 | Runs the test 25 | """ 26 | args = atari_arg_parser().parse_args() 27 | logger.configure() 28 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /stable_baselines/bench/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.bench.monitor import Monitor, load_results 2 | -------------------------------------------------------------------------------- /stable_baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from stable_baselines.common.console_util import fmt_row, 
fmt_item, colorize 3 | from stable_baselines.common.dataset import Dataset 4 | from stable_baselines.common.math_util import discount, discount_with_boundaries, explained_variance, \ 5 | explained_variance_2d, flatten_arrays, unflatten_vector 6 | from stable_baselines.common.misc_util import zipsame, set_global_seeds, boolean_flag 7 | from stable_baselines.common.base_class import BaseRLModel, ActorCriticRLModel, OffPolicyRLModel, SetVerbosity, \ 8 | TensorboardWriter 9 | from stable_baselines.common.cmd_util import make_vec_env 10 | -------------------------------------------------------------------------------- /stable_baselines/common/cg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def conjugate_gradient(f_ax, b_vec, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10): 5 | """ 6 | conjugate gradient calculation (Ax = b), bases on 7 | https://epubs.siam.org/doi/book/10.1137/1.9781611971446 Demmel p 312 8 | 9 | :param f_ax: (function) The function describing the Matrix A dot the vector x 10 | (x being the input parameter of the function) 11 | :param b_vec: (numpy float) vector b, where Ax = b 12 | :param cg_iters: (int) the maximum number of iterations for converging 13 | :param callback: (function) callback the values of x while converging 14 | :param verbose: (bool) print extra information 15 | :param residual_tol: (float) the break point if the residual is below this value 16 | :return: (numpy float) vector x, where Ax = b 17 | """ 18 | first_basis_vect = b_vec.copy() # the first basis vector 19 | residual = b_vec.copy() # the residual 20 | x_var = np.zeros_like(b_vec) # vector x, where Ax = b 21 | residual_dot_residual = residual.dot(residual) # L2 norm of the residual 22 | 23 | fmt_str = "%10i %10.3g %10.3g" 24 | title_str = "%10s %10s %10s" 25 | if verbose: 26 | print(title_str % ("iter", "residual norm", "soln norm")) 27 | 28 | for i in range(cg_iters): 29 | if callback is not None: 30 | callback(x_var) 31 | if verbose: 32 | print(fmt_str % (i, residual_dot_residual, np.linalg.norm(x_var))) 33 | z_var = f_ax(first_basis_vect) 34 | v_var = residual_dot_residual / first_basis_vect.dot(z_var) 35 | x_var += v_var * first_basis_vect 36 | residual -= v_var * z_var 37 | new_residual_dot_residual = residual.dot(residual) 38 | mu_val = new_residual_dot_residual / residual_dot_residual 39 | first_basis_vect = residual + mu_val * first_basis_vect 40 | 41 | residual_dot_residual = new_residual_dot_residual 42 | if residual_dot_residual < residual_tol: 43 | break 44 | 45 | if callback is not None: 46 | callback(x_var) 47 | if verbose: 48 | print(fmt_str % (i + 1, residual_dot_residual, np.linalg.norm(x_var))) 49 | return x_var 50 | -------------------------------------------------------------------------------- /stable_baselines/common/console_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | 5 | 6 | # ================================================================ 7 | # Misc 8 | # ================================================================ 9 | 10 | 11 | def fmt_row(width, row, header=False): 12 | """ 13 | fits a list of items to at least a certain length 14 | 15 | :param width: (int) the minimum width of the string 16 | :param row: ([Any]) a list of object you wish to get the string representation 17 | :param header: (bool) whether or not to return the string as a header 18 | :return: (str) 
the string representation of all the elements in 'row', of length >= 'width' 19 | """ 20 | out = " | ".join(fmt_item(x, width) for x in row) 21 | if header: 22 | out = out + "\n" + "-" * len(out) 23 | return out 24 | 25 | 26 | def fmt_item(item, min_width): 27 | """ 28 | fits items to a given string length 29 | 30 | :param item: (Any) the item you wish to get the string representation 31 | :param min_width: (int) the minimum width of the string 32 | :return: (str) the string representation of 'x' of length >= 'l' 33 | """ 34 | if isinstance(item, np.ndarray): 35 | assert item.ndim == 0 36 | item = item.item() 37 | if isinstance(item, (float, np.float32, np.float64)): 38 | value = abs(item) 39 | if (value < 1e-4 or value > 1e+4) and value > 0: 40 | rep = "%7.2e" % item 41 | else: 42 | rep = "%7.5f" % item 43 | else: 44 | rep = str(item) 45 | return " " * (min_width - len(rep)) + rep 46 | 47 | 48 | COLOR_TO_NUM = dict( 49 | gray=30, 50 | red=31, 51 | green=32, 52 | yellow=33, 53 | blue=34, 54 | magenta=35, 55 | cyan=36, 56 | white=37, 57 | crimson=38 58 | ) 59 | 60 | 61 | def colorize(string, color, bold=False, highlight=False): 62 | """ 63 | Colorize, bold and/or highlight a string for terminal print 64 | 65 | :param string: (str) input string 66 | :param color: (str) the color, the lookup table is the dict at console_util.color2num 67 | :param bold: (bool) if the string should be bold or not 68 | :param highlight: (bool) if the string should be highlighted or not 69 | :return: (str) the stylized output string 70 | """ 71 | attr = [] 72 | num = COLOR_TO_NUM[color] 73 | if highlight: 74 | num += 10 75 | attr.append(str(num)) 76 | if bold: 77 | attr.append('1') 78 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 79 | -------------------------------------------------------------------------------- /stable_baselines/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Dataset(object): 5 | def __init__(self, data_map, shuffle=True): 6 | """ 7 | Data loader that handles batches and shuffling. 
8 | WARNING: this will alter the given data_map ordering, as dicts are mutable 9 | 10 | :param data_map: (dict) the input data, where every column is a key 11 | :param shuffle: (bool) Whether to shuffle or not the dataset 12 | Important: this should be disabled for recurrent policies 13 | """ 14 | self.data_map = data_map 15 | self.shuffle = shuffle 16 | self.n_samples = next(iter(data_map.values())).shape[0] 17 | self._next_id = 0 18 | if self.shuffle: 19 | self.shuffle_dataset() 20 | 21 | def shuffle_dataset(self): 22 | """ 23 | Shuffles the data_map 24 | """ 25 | perm = np.arange(self.n_samples) 26 | np.random.shuffle(perm) 27 | 28 | for key in self.data_map: 29 | self.data_map[key] = self.data_map[key][perm] 30 | 31 | def next_batch(self, batch_size): 32 | """ 33 | returns a batch of data of a given size 34 | 35 | :param batch_size: (int) the size of the batch 36 | :return: (dict) a batch of the input data of size 'batch_size' 37 | """ 38 | if self._next_id >= self.n_samples: 39 | self._next_id = 0 40 | if self.shuffle: 41 | self.shuffle_dataset() 42 | 43 | cur_id = self._next_id 44 | cur_batch_size = min(batch_size, self.n_samples - self._next_id) 45 | self._next_id += cur_batch_size 46 | 47 | data_map = dict() 48 | for key in self.data_map: 49 | data_map[key] = self.data_map[key][cur_id:cur_id + cur_batch_size] 50 | return data_map 51 | 52 | def iterate_once(self, batch_size): 53 | """ 54 | generator that iterates over the dataset 55 | 56 | :param batch_size: (int) the size of the batch 57 | :return: (dict) a batch of the input data of size 'batch_size' 58 | """ 59 | if self.shuffle: 60 | self.shuffle_dataset() 61 | 62 | while self._next_id <= self.n_samples - batch_size: 63 | yield self.next_batch(batch_size) 64 | self._next_id = 0 65 | 66 | def subset(self, num_elements, shuffle=True): 67 | """ 68 | Return a subset of the current dataset 69 | 70 | :param num_elements: (int) the number of element you wish to have in the subset 71 | :param shuffle: (bool) Whether to shuffle or not the dataset 72 | :return: (Dataset) a new subset of the current Dataset object 73 | """ 74 | data_map = dict() 75 | for key in self.data_map: 76 | data_map[key] = self.data_map[key][:num_elements] 77 | return Dataset(data_map, shuffle) 78 | 79 | 80 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 81 | """ 82 | Iterates over arrays in batches, must provide either num_batches or batch_size, the other must be None. 
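For example, iterating over iterbatches((observations, actions), batch_size=32) yields tuples (obs_batch, act_batch) of aligned mini-batches, shuffled by default.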
83 | 84 | :param arrays: (tuple) a tuple of arrays 85 | :param num_batches: (int) the number of batches, must be None is batch_size is defined 86 | :param batch_size: (int) the size of the batch, must be None is num_batches is defined 87 | :param shuffle: (bool) enable auto shuffle 88 | :param include_final_partial_batch: (bool) add the last batch if not the same size as the batch_size 89 | :return: (tuples) a tuple of a batch of the arrays 90 | """ 91 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 92 | arrays = tuple(map(np.asarray, arrays)) 93 | n_samples = arrays[0].shape[0] 94 | assert all(a.shape[0] == n_samples for a in arrays[1:]) 95 | inds = np.arange(n_samples) 96 | if shuffle: 97 | np.random.shuffle(inds) 98 | sections = np.arange(0, n_samples, batch_size)[1:] if num_batches is None else num_batches 99 | for batch_inds in np.array_split(inds, sections): 100 | if include_final_partial_batch or len(batch_inds) == batch_size: 101 | yield tuple(a[batch_inds] for a in arrays) 102 | -------------------------------------------------------------------------------- /stable_baselines/common/evaluation.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Callable, List, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from stable_baselines.common.vec_env import VecEnv 8 | 9 | if typing.TYPE_CHECKING: 10 | from stable_baselines.common.base_class import BaseRLModel 11 | 12 | 13 | def evaluate_policy( 14 | model: "BaseRLModel", 15 | env: Union[gym.Env, VecEnv], 16 | n_eval_episodes: int = 10, 17 | deterministic: bool = True, 18 | render: bool = False, 19 | callback: Optional[Callable] = None, 20 | reward_threshold: Optional[float] = None, 21 | return_episode_rewards: bool = False, 22 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: 23 | """ 24 | Runs policy for ``n_eval_episodes`` episodes and returns average reward. 25 | This is made to work only with one env. 26 | 27 | :param model: (BaseRLModel) The RL agent you want to evaluate. 28 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv`` 29 | this must contain only one environment. 30 | :param n_eval_episodes: (int) Number of episode to evaluate the agent 31 | :param deterministic: (bool) Whether to use deterministic or stochastic actions 32 | :param render: (bool) Whether to render the environment or not 33 | :param callback: (callable) callback function to do additional checks, 34 | called after each step. 35 | :param reward_threshold: (float) Minimum expected reward per episode, 36 | this will raise an error if the performance is not met 37 | :param return_episode_rewards: (Optional[float]) If True, a list of reward per episode 38 | will be returned instead of the mean. 
39 | :return: (float, float) Mean reward per episode, std of reward per episode 40 | returns ([float], [int]) when ``return_episode_rewards`` is True 41 | """ 42 | if isinstance(env, VecEnv): 43 | assert env.num_envs == 1, "You must pass only one environment when using this function" 44 | 45 | is_recurrent = model.policy.recurrent 46 | 47 | episode_rewards, episode_lengths = [], [] 48 | for i in range(n_eval_episodes): 49 | # Avoid double reset, as VecEnv are reset automatically 50 | if not isinstance(env, VecEnv) or i == 0: 51 | obs = env.reset() 52 | # Because recurrent policies need the same observation space during training and evaluation, we need to pad 53 | # observation to match training shape. See https://github.com/hill-a/stable-baselines/issues/1015 54 | if is_recurrent: 55 | zero_completed_obs = np.zeros((model.n_envs,) + model.observation_space.shape) 56 | zero_completed_obs[0, :] = obs 57 | obs = zero_completed_obs 58 | done, state = False, None 59 | episode_reward = 0.0 60 | episode_length = 0 61 | while not done: 62 | action, state = model.predict(obs, state=state, deterministic=deterministic) 63 | new_obs, reward, done, _info = env.step(action) 64 | if is_recurrent: 65 | obs[0, :] = new_obs 66 | else: 67 | obs = new_obs 68 | episode_reward += reward 69 | if callback is not None: 70 | callback(locals(), globals()) 71 | episode_length += 1 72 | if render: 73 | env.render() 74 | episode_rewards.append(episode_reward) 75 | episode_lengths.append(episode_length) 76 | mean_reward = np.mean(episode_rewards) 77 | std_reward = np.std(episode_rewards) 78 | if reward_threshold is not None: 79 | assert mean_reward > reward_threshold, "Mean reward below threshold: {:.2f} < {:.2f}".format(mean_reward, reward_threshold) 80 | if return_episode_rewards: 81 | return episode_rewards, episode_lengths 82 | return mean_reward, std_reward 83 | -------------------------------------------------------------------------------- /stable_baselines/common/identity_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional 3 | 4 | from gym import Env, Space 5 | from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box 6 | 7 | 8 | class IdentityEnv(Env): 9 | def __init__(self, 10 | dim: Optional[int] = None, 11 | space: Optional[Space] = None, 12 | ep_length: int = 100): 13 | """ 14 | Identity environment for testing purposes 15 | 16 | :param dim: the size of the action and observation dimension you want 17 | to learn. Provide at most one of `dim` and `space`. If both are 18 | None, then initialization proceeds with `dim=1` and `space=None`. 19 | :param space: the action and observation space. Provide at most one of 20 | `dim` and `space`. 21 | :param ep_length: the length of each episode in timesteps 22 | """ 23 | if space is None: 24 | if dim is None: 25 | dim = 1 26 | space = Discrete(dim) 27 | else: 28 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed" 29 | 30 | self.action_space = self.observation_space = space 31 | self.ep_length = ep_length 32 | self.current_step = 0 33 | self.num_resets = -1 # Becomes 0 after __init__ exits. 
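# The initial reset() call below draws the first state via _choose_next_state() and bumps num_resets to 0.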
34 | self.reset() 35 | 36 | def reset(self): 37 | self.current_step = 0 38 | self.num_resets += 1 39 | self._choose_next_state() 40 | return self.state 41 | 42 | def step(self, action): 43 | reward = self._get_reward(action) 44 | self._choose_next_state() 45 | self.current_step += 1 46 | done = self.current_step >= self.ep_length 47 | return self.state, reward, done, {} 48 | 49 | def _choose_next_state(self): 50 | self.state = self.action_space.sample() 51 | 52 | def _get_reward(self, action): 53 | return 1 if np.all(self.state == action) else 0 54 | 55 | def render(self, mode='human'): 56 | pass 57 | 58 | 59 | class IdentityEnvBox(IdentityEnv): 60 | def __init__(self, low=-1, high=1, eps=0.05, ep_length=100): 61 | """ 62 | Identity environment for testing purposes 63 | 64 | :param low: (float) the lower bound of the box dim 65 | :param high: (float) the upper bound of the box dim 66 | :param eps: (float) the epsilon bound for correct value 67 | :param ep_length: (int) the length of each episode in timesteps 68 | """ 69 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32) 70 | super().__init__(ep_length=ep_length, space=space) 71 | self.eps = eps 72 | 73 | def step(self, action): 74 | reward = self._get_reward(action) 75 | self._choose_next_state() 76 | self.current_step += 1 77 | done = self.current_step >= self.ep_length 78 | return self.state, reward, done, {} 79 | 80 | def _get_reward(self, action): 81 | return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0 82 | 83 | 84 | class IdentityEnvMultiDiscrete(IdentityEnv): 85 | def __init__(self, dim=1, ep_length=100): 86 | """ 87 | Identity environment for testing purposes 88 | 89 | :param dim: (int) the size of the dimensions you want to learn 90 | :param ep_length: (int) the length of each episode in timesteps 91 | """ 92 | space = MultiDiscrete([dim, dim]) 93 | super().__init__(ep_length=ep_length, space=space) 94 | 95 | 96 | class IdentityEnvMultiBinary(IdentityEnv): 97 | def __init__(self, dim=1, ep_length=100): 98 | """ 99 | Identity environment for testing purposes 100 | 101 | :param dim: (int) the size of the dimensions you want to learn 102 | :param ep_length: (int) the length of each episode in timesteps 103 | """ 104 | space = MultiBinary(dim) 105 | super().__init__(ep_length=ep_length, space=space) 106 | -------------------------------------------------------------------------------- /stable_baselines/common/input.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete 4 | 5 | 6 | def observation_input(ob_space, batch_size=None, name='Ob', scale=False): 7 | """ 8 | Build observation input with encoding depending on the observation space type 9 | 10 | When using Box ob_space, the input will be normalized between [1, 0] on the bounds ob_space.low and ob_space.high. 
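When scale is True and the bounds are, for instance, 0 and 255 (images), this normalization amounts to dividing the input by 255.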
11 | 12 | :param ob_space: (Gym Space) The observation space 13 | :param batch_size: (int) batch size for input 14 | (default is None, so that resulting input placeholder can take tensors with any batch size) 15 | :param name: (str) tensorflow variable name for input placeholder 16 | :param scale: (bool) whether or not to scale the input 17 | :return: (TensorFlow Tensor, TensorFlow Tensor) input_placeholder, processed_input_tensor 18 | """ 19 | if isinstance(ob_space, Discrete): 20 | observation_ph = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name) 21 | processed_observations = tf.cast(tf.one_hot(observation_ph, ob_space.n), tf.float32) 22 | return observation_ph, processed_observations 23 | 24 | elif isinstance(ob_space, Box): 25 | observation_ph = tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name) 26 | processed_observations = tf.cast(observation_ph, tf.float32) 27 | # rescale to [1, 0] if the bounds are defined 28 | if (scale and 29 | not np.any(np.isinf(ob_space.low)) and not np.any(np.isinf(ob_space.high)) and 30 | np.any((ob_space.high - ob_space.low) != 0)): 31 | 32 | # equivalent to processed_observations / 255.0 when bounds are set to [255, 0] 33 | processed_observations = ((processed_observations - ob_space.low) / (ob_space.high - ob_space.low)) 34 | return observation_ph, processed_observations 35 | 36 | elif isinstance(ob_space, MultiBinary): 37 | observation_ph = tf.placeholder(shape=(batch_size, ob_space.n), dtype=tf.int32, name=name) 38 | processed_observations = tf.cast(observation_ph, tf.float32) 39 | return observation_ph, processed_observations 40 | 41 | elif isinstance(ob_space, MultiDiscrete): 42 | observation_ph = tf.placeholder(shape=(batch_size, len(ob_space.nvec)), dtype=tf.int32, name=name) 43 | processed_observations = tf.concat([ 44 | tf.cast(tf.one_hot(input_split, ob_space.nvec[i]), tf.float32) for i, input_split 45 | in enumerate(tf.split(observation_ph, len(ob_space.nvec), axis=-1)) 46 | ], axis=-1) 47 | return observation_ph, processed_observations 48 | 49 | else: 50 | raise NotImplementedError("Error: the model does not support input space of type {}".format( 51 | type(ob_space).__name__)) 52 | -------------------------------------------------------------------------------- /stable_baselines/common/misc_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import gym 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | def zipsame(*seqs): 9 | """ 10 | Performes a zip function, but asserts that all zipped elements are of the same size 11 | 12 | :param seqs: a list of arrays that are zipped together 13 | :return: the zipped arguments 14 | """ 15 | length = len(seqs[0]) 16 | assert all(len(seq) == length for seq in seqs[1:]) 17 | return zip(*seqs) 18 | 19 | 20 | def set_global_seeds(seed): 21 | """ 22 | set the seed for python random, tensorflow, numpy and gym spaces 23 | 24 | :param seed: (int) the seed 25 | """ 26 | tf.set_random_seed(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | # prng was removed in latest gym version 30 | if hasattr(gym.spaces, 'prng'): 31 | gym.spaces.prng.seed(seed) 32 | 33 | 34 | def boolean_flag(parser, name, default=False, help_msg=None): 35 | """ 36 | Add a boolean flag to argparse parser. 
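For example, boolean_flag(parser, 'render') adds a '--render' flag (store_true) and a '--no-render' flag (store_false), both writing to args.render.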
37 | 38 | :param parser: (argparse.Parser) parser to add the flag to 39 | :param name: (str) -- will enable the flag, while --no- will disable it 40 | :param default: (bool) default value of the flag 41 | :param help_msg: (str) help string for the flag 42 | """ 43 | dest = name.replace('-', '_') 44 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help_msg) 45 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 46 | 47 | 48 | def mpi_rank_or_zero(): 49 | """ 50 | Return the MPI rank if mpi is installed. Otherwise, return 0. 51 | :return: (int) 52 | """ 53 | try: 54 | import mpi4py 55 | return mpi4py.MPI.COMM_WORLD.Get_rank() 56 | except ImportError: 57 | return 0 58 | 59 | 60 | def flatten_lists(listoflists): 61 | """ 62 | Flatten a python list of list 63 | 64 | :param listoflists: (list(list)) 65 | :return: (list) 66 | """ 67 | return [el for list_ in listoflists for el in list_] 68 | -------------------------------------------------------------------------------- /stable_baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | 4 | from stable_baselines.common.misc_util import zipsame 5 | 6 | 7 | def mpi_mean(arr, axis=0, comm=None, keepdims=False): 8 | """ 9 | calculates the mean of an array, using MPI 10 | 11 | :param arr: (np.ndarray) 12 | :param axis: (int or tuple or list) the axis to run the means over 13 | :param comm: (MPI Communicators) if None, MPI.COMM_WORLD 14 | :param keepdims: (bool) keep the other dimensions intact 15 | :return: (np.ndarray or Number) the result of the sum 16 | """ 17 | arr = np.asarray(arr) 18 | assert arr.ndim > 0 19 | if comm is None: 20 | comm = MPI.COMM_WORLD 21 | xsum = arr.sum(axis=axis, keepdims=keepdims) 22 | size = xsum.size 23 | localsum = np.zeros(size + 1, arr.dtype) 24 | localsum[:size] = xsum.ravel() 25 | localsum[size] = arr.shape[axis] 26 | globalsum = np.zeros_like(localsum) 27 | comm.Allreduce(localsum, globalsum, op=MPI.SUM) 28 | return globalsum[:size].reshape(xsum.shape) / globalsum[size], globalsum[size] 29 | 30 | 31 | def mpi_moments(arr, axis=0, comm=None, keepdims=False): 32 | """ 33 | calculates the mean and std of an array, using MPI 34 | 35 | :param arr: (np.ndarray) 36 | :param axis: (int or tuple or list) the axis to run the moments over 37 | :param comm: (MPI Communicators) if None, MPI.COMM_WORLD 38 | :param keepdims: (bool) keep the other dimensions intact 39 | :return: (np.ndarray or Number) the result of the moments 40 | """ 41 | arr = np.asarray(arr) 42 | assert arr.ndim > 0 43 | mean, count = mpi_mean(arr, axis=axis, comm=comm, keepdims=True) 44 | sqdiffs = np.square(arr - mean) 45 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 46 | assert count1 == count 47 | std = np.sqrt(meansqdiff) 48 | if not keepdims: 49 | newshape = mean.shape[:axis] + mean.shape[axis + 1:] 50 | mean = mean.reshape(newshape) 51 | std = std.reshape(newshape) 52 | return mean, std, count 53 | 54 | 55 | def _helper_runningmeanstd(): 56 | comm = MPI.COMM_WORLD 57 | np.random.seed(0) 58 | for (triple, axis) in [ 59 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)), 0), 60 | ((np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)), 0), 61 | ((np.random.randn(2, 3), np.random.randn(2, 4), np.random.randn(2, 4)), 1)]: 62 | 63 | arr = np.concatenate(triple, axis=axis) 64 | ms1 = [arr.mean(axis=axis), arr.std(axis=axis), 
arr.shape[axis]] 65 | 66 | ms2 = mpi_moments(triple[comm.Get_rank()], axis=axis) 67 | 68 | for (res_1, res_2) in zipsame(ms1, ms2): 69 | print(res_1, res_2) 70 | assert np.allclose(res_1, res_2) 71 | print("ok!") 72 | -------------------------------------------------------------------------------- /stable_baselines/common/mpi_running_mean_std.py: -------------------------------------------------------------------------------- 1 | import mpi4py 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | import stable_baselines.common.tf_util as tf_util 6 | 7 | 8 | class RunningMeanStd(object): 9 | def __init__(self, epsilon=1e-2, shape=()): 10 | """ 11 | calulates the running mean and std of a data stream 12 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 13 | 14 | :param epsilon: (float) helps with arithmetic issues 15 | :param shape: (tuple) the shape of the data stream's output 16 | """ 17 | self._sum = tf.get_variable( 18 | dtype=tf.float64, 19 | shape=shape, 20 | initializer=tf.constant_initializer(0.0), 21 | name="runningsum", trainable=False) 22 | self._sumsq = tf.get_variable( 23 | dtype=tf.float64, 24 | shape=shape, 25 | initializer=tf.constant_initializer(epsilon), 26 | name="runningsumsq", trainable=False) 27 | self._count = tf.get_variable( 28 | dtype=tf.float64, 29 | shape=(), 30 | initializer=tf.constant_initializer(epsilon), 31 | name="count", trainable=False) 32 | self.shape = shape 33 | 34 | self.mean = tf.cast(self._sum / self._count, tf.float32) 35 | self.std = tf.sqrt(tf.maximum(tf.cast(self._sumsq / self._count, tf.float32) - tf.square(self.mean), 36 | 1e-2)) 37 | 38 | newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum') 39 | newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var') 40 | newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') 41 | self.incfiltparams = tf_util.function([newsum, newsumsq, newcount], [], 42 | updates=[tf.assign_add(self._sum, newsum), 43 | tf.assign_add(self._sumsq, newsumsq), 44 | tf.assign_add(self._count, newcount)]) 45 | 46 | def update(self, data): 47 | """ 48 | update the running mean and std 49 | 50 | :param data: (np.ndarray) the data 51 | """ 52 | data = data.astype('float64') 53 | data_size = int(np.prod(self.shape)) 54 | totalvec = np.zeros(data_size * 2 + 1, 'float64') 55 | addvec = np.concatenate([data.sum(axis=0).ravel(), np.square(data).sum(axis=0).ravel(), 56 | np.array([len(data)], dtype='float64')]) 57 | mpi4py.MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=mpi4py.MPI.SUM) 58 | self.incfiltparams(totalvec[0: data_size].reshape(self.shape), 59 | totalvec[data_size: 2 * data_size].reshape(self.shape), totalvec[2 * data_size]) 60 | 61 | 62 | @tf_util.in_session 63 | def test_dist(): 64 | """ 65 | test the running mean std 66 | """ 67 | np.random.seed(0) 68 | p_1, p_2, p_3 = (np.random.randn(3, 1), np.random.randn(4, 1), np.random.randn(5, 1)) 69 | q_1, q_2, q_3 = (np.random.randn(6, 1), np.random.randn(7, 1), np.random.randn(8, 1)) 70 | 71 | comm = mpi4py.MPI.COMM_WORLD 72 | assert comm.Get_size() == 2 73 | if comm.Get_rank() == 0: 74 | x_1, x_2, x_3 = p_1, p_2, p_3 75 | elif comm.Get_rank() == 1: 76 | x_1, x_2, x_3 = q_1, q_2, q_3 77 | else: 78 | assert False 79 | 80 | rms = RunningMeanStd(epsilon=0.0, shape=(1,)) 81 | tf_util.initialize() 82 | 83 | rms.update(x_1) 84 | rms.update(x_2) 85 | rms.update(x_3) 86 | 87 | bigvec = np.concatenate([p_1, p_2, p_3, q_1, q_2, q_3]) 88 | 89 | def checkallclose(var_1, var_2): 90 | print(var_1, var_2) 91 
| return np.allclose(var_1, var_2) 92 | 93 | assert checkallclose( 94 | bigvec.mean(axis=0), 95 | rms.mean.eval(), 96 | ) 97 | assert checkallclose( 98 | bigvec.std(axis=0), 99 | rms.std.eval(), 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | # Run with mpirun -np 2 python 105 | test_dist() 106 | -------------------------------------------------------------------------------- /stable_baselines/common/noise.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | 5 | 6 | class AdaptiveParamNoiseSpec(object): 7 | """ 8 | Implements adaptive parameter noise 9 | 10 | :param initial_stddev: (float) the initial value for the standard deviation of the noise 11 | :param desired_action_stddev: (float) the desired value for the standard deviation of the noise 12 | :param adoption_coefficient: (float) the update coefficient for the standard deviation of the noise 13 | """ 14 | 15 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01): 16 | self.initial_stddev = initial_stddev 17 | self.desired_action_stddev = desired_action_stddev 18 | self.adoption_coefficient = adoption_coefficient 19 | 20 | self.current_stddev = initial_stddev 21 | 22 | def adapt(self, distance): 23 | """ 24 | update the standard deviation for the parameter noise 25 | 26 | :param distance: (float) the noise distance applied to the parameters 27 | """ 28 | if distance > self.desired_action_stddev: 29 | # Decrease stddev. 30 | self.current_stddev /= self.adoption_coefficient 31 | else: 32 | # Increase stddev. 33 | self.current_stddev *= self.adoption_coefficient 34 | 35 | def get_stats(self): 36 | """ 37 | return the standard deviation for the parameter noise 38 | 39 | :return: (dict) the stats of the noise 40 | """ 41 | return {'param_noise_stddev': self.current_stddev} 42 | 43 | def __repr__(self): 44 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})' 45 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient) 46 | 47 | 48 | class ActionNoise(ABC): 49 | """ 50 | The action noise base class 51 | """ 52 | 53 | def __init__(self): 54 | super(ActionNoise, self).__init__() 55 | 56 | def reset(self) -> None: 57 | """ 58 | call end of episode reset for the noise 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def __call__(self) -> np.ndarray: 64 | raise NotImplementedError() 65 | 66 | 67 | class NormalActionNoise(ActionNoise): 68 | """ 69 | A Gaussian action noise 70 | 71 | :param mean: (float) the mean value of the noise 72 | :param sigma: (float) the scale of the noise (std here) 73 | """ 74 | 75 | def __init__(self, mean, sigma): 76 | super().__init__() 77 | self._mu = mean 78 | self._sigma = sigma 79 | 80 | def __call__(self) -> np.ndarray: 81 | return np.random.normal(self._mu, self._sigma) 82 | 83 | def __repr__(self) -> str: 84 | return 'NormalActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma) 85 | 86 | 87 | class OrnsteinUhlenbeckActionNoise(ActionNoise): 88 | """ 89 | A Ornstein Uhlenbeck action noise, this is designed to approximate brownian motion with friction. 
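Each call returns noise_prev + theta * (mu - noise_prev) * dt + sigma * sqrt(dt) * N(0, 1), an Euler discretisation of the OU process.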
90 | 91 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 92 | 93 | :param mean: (float) the mean of the noise 94 | :param sigma: (float) the scale of the noise 95 | :param theta: (float) the rate of mean reversion 96 | :param dt: (float) the timestep for the noise 97 | :param initial_noise: ([float]) the initial value for the noise output, (if None: 0) 98 | """ 99 | 100 | def __init__(self, mean, sigma, theta=.15, dt=1e-2, initial_noise=None): 101 | super().__init__() 102 | self._theta = theta 103 | self._mu = mean 104 | self._sigma = sigma 105 | self._dt = dt 106 | self.initial_noise = initial_noise 107 | self.noise_prev = None 108 | self.reset() 109 | 110 | def __call__(self) -> np.ndarray: 111 | noise = self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt + \ 112 | self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) 113 | self.noise_prev = noise 114 | return noise 115 | 116 | def reset(self) -> None: 117 | """ 118 | reset the Ornstein Uhlenbeck noise, to the initial position 119 | """ 120 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu) 121 | 122 | def __repr__(self) -> str: 123 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma) 124 | -------------------------------------------------------------------------------- /stable_baselines/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RunningMeanStd(object): 5 | def __init__(self, epsilon=1e-4, shape=()): 6 | """ 7 | calulates the running mean and std of a data stream 8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 9 | 10 | :param epsilon: (float) helps with arithmetic issues 11 | :param shape: (tuple) the shape of the data stream's output 12 | """ 13 | self.mean = np.zeros(shape, 'float64') 14 | self.var = np.ones(shape, 'float64') 15 | self.count = epsilon 16 | 17 | def update(self, arr): 18 | batch_mean = np.mean(arr, axis=0) 19 | batch_var = np.var(arr, axis=0) 20 | batch_count = arr.shape[0] 21 | self.update_from_moments(batch_mean, batch_var, batch_count) 22 | 23 | def update_from_moments(self, batch_mean, batch_var, batch_count): 24 | delta = batch_mean - self.mean 25 | tot_count = self.count + batch_count 26 | 27 | new_mean = self.mean + delta * batch_count / tot_count 28 | m_a = self.var * self.count 29 | m_b = batch_var * batch_count 30 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 31 | new_var = m_2 / (self.count + batch_count) 32 | 33 | new_count = batch_count + self.count 34 | 35 | self.mean = new_mean 36 | self.var = new_var 37 | self.count = new_count 38 | -------------------------------------------------------------------------------- /stable_baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def tile_images(img_nhwc): 5 | """ 6 | Tile N images into one big PxQ image 7 | (P,Q) are chosen to be as close as possible, and if N 8 | is square, then P=Q. 9 | 10 | :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. 
img nhwc 11 | n = batch index, h = height, w = width, c = channel 12 | :return: (numpy float) img_HWc, ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | n_images, height, width, n_channels = img_nhwc.shape 16 | # new_height was named H before 17 | new_height = int(np.ceil(np.sqrt(n_images))) 18 | # new_width was named W before 19 | new_width = int(np.ceil(float(n_images) / new_height)) 20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) 21 | # img_HWhwc 22 | out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels) 23 | # img_HhWwc 24 | out_image = out_image.transpose(0, 2, 1, 3, 4) 25 | # img_Hh_Ww_c 26 | out_image = out_image.reshape(new_height * height, new_width * width, n_channels) 27 | return out_image 28 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from copy import deepcopy 3 | 4 | import gym 5 | 6 | # flake8: noqa F401 7 | from stable_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \ 8 | CloudpickleWrapper 9 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 10 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 11 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack 12 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize 13 | from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder 14 | from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan 15 | 16 | 17 | def unwrap_vec_normalize(env: Union[gym.Env, VecEnv]) -> Union[VecNormalize, None]: 18 | """ 19 | :param env: (Union[gym.Env, VecEnv]) 20 | :return: (VecNormalize) 21 | """ 22 | env_tmp = env 23 | while isinstance(env_tmp, VecEnvWrapper): 24 | if isinstance(env_tmp, VecNormalize): 25 | return env_tmp 26 | env_tmp = env_tmp.venv 27 | return None 28 | 29 | 30 | # Define here to avoid circular import 31 | def sync_envs_normalization(env: Union[gym.Env, VecEnv], eval_env: Union[gym.Env, VecEnv]) -> None: 32 | """ 33 | Sync eval and train environments when using VecNormalize 34 | 35 | :param env: (Union[gym.Env, VecEnv])) 36 | :param eval_env: (Union[gym.Env, VecEnv])) 37 | """ 38 | env_tmp, eval_env_tmp = env, eval_env 39 | # Special case for the _UnvecWrapper 40 | # Avoid circular import 41 | from stable_baselines.common.base_class import _UnvecWrapper 42 | if isinstance(env_tmp, _UnvecWrapper): 43 | return 44 | while isinstance(env_tmp, VecEnvWrapper): 45 | if isinstance(env_tmp, VecNormalize): 46 | # sync reward and observation scaling 47 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 48 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 49 | env_tmp = env_tmp.venv 50 | # Make pytype happy, in theory env and eval_env have the same type 51 | assert isinstance(eval_env_tmp, VecEnvWrapper), "the second env differs from the first env" 52 | eval_env_tmp = eval_env_tmp.venv 53 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 
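These helpers store Dict, Tuple and unstructured observation spaces in a single key -> array mapping (see obs_space_info below) and are used by the in-process vectorized environments such as DummyVecEnv.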
3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy a dict of numpy arrays. 14 | 15 | :param obs: (OrderedDict): a dict of numpy arrays. 16 | :return (OrderedDict) a dict of copied numpy arrays. 17 | """ 18 | assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs)) 19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 20 | 21 | 22 | def dict_to_obs(space, obs_dict): 23 | """ 24 | Convert an internal representation raw_obs into the appropriate type 25 | specified by space. 26 | 27 | :param space: (gym.spaces.Space) an observation space. 28 | :param obs_dict: (OrderedDict) a dict of numpy arrays. 29 | :return (ndarray, tuple or dict): returns an observation 30 | of the same type as space. If space is Dict, function is identity; 31 | if space is Tuple, converts dict to Tuple; otherwise, space is 32 | unstructured and returns the value raw_obs[None]. 33 | """ 34 | if isinstance(space, gym.spaces.Dict): 35 | return obs_dict 36 | elif isinstance(space, gym.spaces.Tuple): 37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space" 38 | return tuple((obs_dict[i] for i in range(len(space.spaces)))) 39 | else: 40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 41 | return obs_dict[None] 42 | 43 | 44 | def obs_space_info(obs_space): 45 | """ 46 | Get dict-structured information about a gym.Space. 47 | 48 | Dict spaces are represented directly by their dict of subspaces. 49 | Tuple spaces are converted into a dict with keys indexing into the tuple. 50 | Unstructured spaces are represented by {None: obs_space}. 51 | 52 | :param obs_space: (gym.spaces.Space) an observation space 53 | :return (tuple) A tuple (keys, shapes, dtypes): 54 | keys: a list of dict keys. 55 | shapes: a dict mapping keys to shapes. 56 | dtypes: a dict mapping keys to dtypes. 57 | """ 58 | if isinstance(obs_space, gym.spaces.Dict): 59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 60 | subspaces = obs_space.spaces 61 | elif isinstance(obs_space, gym.spaces.Tuple): 62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 63 | else: 64 | assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space)) 65 | subspaces = {None: obs_space} 66 | keys = [] 67 | shapes = {} 68 | dtypes = {} 69 | for key, box in subspaces.items(): 70 | keys.append(key) 71 | shapes[key] = box.shape 72 | dtypes[key] = box.dtype 73 | return keys, shapes, dtypes 74 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: (bool) Whether or not to only warn once. 
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions): 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self): 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self): 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step, **kwargs): 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += "found {} in {}".format(type_val, name) 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions) 80 | else: 81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations) 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 7 | 8 | 9 | class VecFrameStack(VecEnvWrapper): 10 | """ 11 | Frame stacking wrapper for vectorized environment 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param n_stack: (int) Number of frames to stack 15 | """ 16 | 17 | def __init__(self, venv, n_stack): 18 | self.venv = venv 19 | self.n_stack = n_stack 20 | wrapped_obs_space = venv.observation_space 21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) 22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) 23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 26 | 27 | def step_wait(self): 28 | observations, rewards, dones, infos = self.venv.step_wait() 29 | last_ax_size = observations.shape[-1] 30 | self.stackedobs = 
np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) 31 | for i, done in enumerate(dones): 32 | if done: 33 | if 'terminal_observation' in infos[i]: 34 | old_terminal = infos[i]['terminal_observation'] 35 | new_terminal = np.concatenate( 36 | (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) 37 | infos[i]['terminal_observation'] = new_terminal 38 | else: 39 | warnings.warn( 40 | "VecFrameStack wrapping a VecEnv without terminal_observation info") 41 | self.stackedobs[i] = 0 42 | self.stackedobs[..., -observations.shape[-1]:] = observations 43 | return self.stackedobs, rewards, dones, infos 44 | 45 | def reset(self): 46 | """ 47 | Reset all environments 48 | """ 49 | obs = self.venv.reset() 50 | self.stackedobs[...] = 0 51 | self.stackedobs[..., -obs.shape[-1]:] = obs 52 | return self.stackedobs 53 | 54 | def close(self): 55 | self.venv.close() 56 | -------------------------------------------------------------------------------- /stable_baselines/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.wrappers.monitoring import video_recorder 4 | 5 | from stable_baselines import logger 6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper 7 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack 10 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize 11 | 12 | 13 | class VecVideoRecorder(VecEnvWrapper): 14 | """ 15 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 16 | It requires ffmpeg or avconv to be installed on the machine. 17 | 18 | :param venv: (VecEnv or VecEnvWrapper) 19 | :param video_folder: (str) Where to save videos 20 | :param record_video_trigger: (func) Function that defines when to start recording. 21 | The function takes the current number of step, 22 | and returns whether we should start recording or not. 
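For example, lambda step: step == 0 records a single video at the very beginning of training, while lambda step: step % 2000 == 0 starts a new recording every 2000 steps (the interval is illustrative).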
23 | :param video_length: (int) Length of recorded videos 24 | :param name_prefix: (str) Prefix to the video name 25 | """ 26 | 27 | def __init__(self, venv, video_folder, record_video_trigger, 28 | video_length=200, name_prefix='rl-video'): 29 | 30 | VecEnvWrapper.__init__(self, venv) 31 | 32 | self.env = venv 33 | # Temp variable to retrieve metadata 34 | temp_env = venv 35 | 36 | # Unwrap to retrieve metadata dict 37 | # that will be used by gym recorder 38 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack): 39 | temp_env = temp_env.venv 40 | 41 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 42 | metadata = temp_env.get_attr('metadata')[0] 43 | else: 44 | metadata = temp_env.metadata 45 | 46 | self.env.metadata = metadata 47 | 48 | self.record_video_trigger = record_video_trigger 49 | self.video_recorder = None 50 | 51 | self.video_folder = os.path.abspath(video_folder) 52 | # Create output folder if needed 53 | os.makedirs(self.video_folder, exist_ok=True) 54 | 55 | self.name_prefix = name_prefix 56 | self.step_id = 0 57 | self.video_length = video_length 58 | 59 | self.recording = False 60 | self.recorded_frames = 0 61 | 62 | def reset(self): 63 | obs = self.venv.reset() 64 | self.start_video_recorder() 65 | return obs 66 | 67 | def start_video_recorder(self): 68 | self.close_video_recorder() 69 | 70 | video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id, 71 | self.step_id + self.video_length) 72 | base_path = os.path.join(self.video_folder, video_name) 73 | self.video_recorder = video_recorder.VideoRecorder( 74 | env=self.env, 75 | base_path=base_path, 76 | metadata={'step_id': self.step_id} 77 | ) 78 | 79 | self.video_recorder.capture_frame() 80 | self.recorded_frames = 1 81 | self.recording = True 82 | 83 | def _video_enabled(self): 84 | return self.record_video_trigger(self.step_id) 85 | 86 | def step_wait(self): 87 | obs, rews, dones, infos = self.venv.step_wait() 88 | 89 | self.step_id += 1 90 | if self.recording: 91 | self.video_recorder.capture_frame() 92 | self.recorded_frames += 1 93 | if self.recorded_frames > self.video_length: 94 | logger.info("Saving video to ", self.video_recorder.path) 95 | self.close_video_recorder() 96 | elif self._video_enabled(): 97 | self.start_video_recorder() 98 | 99 | return obs, rews, dones, infos 100 | 101 | def close_video_recorder(self): 102 | if self.recording: 103 | self.video_recorder.close() 104 | self.recording = False 105 | self.recorded_frames = 1 106 | 107 | def close(self): 108 | VecEnvWrapper.close(self) 109 | self.close_video_recorder() 110 | 111 | def __del__(self): 112 | self.close() 113 | -------------------------------------------------------------------------------- /stable_baselines/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.common.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise 2 | from stable_baselines.ddpg.ddpg import DDPG 3 | from stable_baselines.ddpg.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy 4 | -------------------------------------------------------------------------------- /stable_baselines/ddpg/noise.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.common.noise import NormalActionNoise, AdaptiveParamNoiseSpec, OrnsteinUhlenbeckActionNoise # pylint: disable=unused-import 2 | 
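Example usage of these noise classes with DDPG (an illustrative sketch, not part of the repository source; the environment id, sigma and timestep budget are arbitrary example values):

```python
import gym
import numpy as np

from stable_baselines import DDPG
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.ddpg.noise import OrnsteinUhlenbeckActionNoise

env = gym.make('Pendulum-v0')
n_actions = env.action_space.shape[-1]

# Temporally correlated exploration noise; NormalActionNoise offers the same
# interface with uncorrelated Gaussian samples instead.
action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                            sigma=0.1 * np.ones(n_actions))

model = DDPG(MlpPolicy, env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000)
```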
-------------------------------------------------------------------------------- /stable_baselines/deepq/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.deepq.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy 2 | from stable_baselines.deepq.build_graph import build_act, build_train # noqa 3 | from stable_baselines.deepq.dqn import DQN 4 | from stable_baselines.common.buffers import ReplayBuffer, PrioritizedReplayBuffer # noqa 5 | 6 | 7 | def wrap_atari_dqn(env): 8 | """ 9 | wrap the environment in atari wrappers for DQN 10 | 11 | :param env: (Gym Environment) the environment 12 | :return: (Gym Environment) the wrapped environment 13 | """ 14 | from stable_baselines.common.atari_wrappers import wrap_deepmind 15 | return wrap_deepmind(env, frame_stack=True, scale=False) 16 | -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/deepq/experiments/__init__.py -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/enjoy_cartpole.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from stable_baselines.deepq import DQN 6 | 7 | 8 | def main(args): 9 | """ 10 | Run a trained model for the cartpole problem 11 | 12 | :param args: (ArgumentParser) the input arguments 13 | """ 14 | env = gym.make("CartPole-v0") 15 | model = DQN.load("cartpole_model.zip", env) 16 | 17 | while True: 18 | obs, done = env.reset(), False 19 | episode_rew = 0 20 | while not done: 21 | if not args.no_render: 22 | env.render() 23 | action, _ = model.predict(obs) 24 | obs, rew, done, _ = env.step(action) 25 | episode_rew += rew 26 | print("Episode reward", episode_rew) 27 | # No render is only used for automatic testing 28 | if args.no_render: 29 | break 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser(description="Enjoy trained DQN on cartpole") 34 | parser.add_argument('--no-render', default=False, action="store_true", help="Disable rendering") 35 | args = parser.parse_args() 36 | main(args) 37 | -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/enjoy_mountaincar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from stable_baselines.deepq import DQN 7 | 8 | 9 | def main(args): 10 | """ 11 | Run a trained model for the mountain car problem 12 | 13 | :param args: (ArgumentParser) the input arguments 14 | """ 15 | env = gym.make("MountainCar-v0") 16 | model = DQN.load("mountaincar_model.zip", env) 17 | 18 | while True: 19 | obs, done = env.reset(), False 20 | episode_rew = 0 21 | while not done: 22 | if not args.no_render: 23 | env.render() 24 | # Epsilon-greedy 25 | if np.random.random() < 0.02: 26 | action = env.action_space.sample() 27 | else: 28 | action, _ = model.predict(obs, deterministic=True) 29 | obs, rew, done, _ = env.step(action) 30 | episode_rew += rew 31 | print("Episode reward", episode_rew) 32 | # No render is only used for automatic testing 33 | if args.no_render: 34 | break 35 | 36 | 37 | if __name__ == '__main__': 38 
| parser = argparse.ArgumentParser(description="Enjoy trained DQN on MountainCar") 39 | parser.add_argument('--no-render', default=False, action="store_true", help="Disable rendering") 40 | args = parser.parse_args() 41 | main(args) 42 | -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/run_atari.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | from stable_baselines import bench, logger 5 | from stable_baselines.common import set_global_seeds 6 | from stable_baselines.common.atari_wrappers import make_atari 7 | from stable_baselines.deepq import DQN, wrap_atari_dqn, CnnPolicy 8 | 9 | 10 | def main(): 11 | """ 12 | Run the atari test 13 | """ 14 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') 16 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 17 | parser.add_argument('--prioritized', type=int, default=1) 18 | parser.add_argument('--dueling', type=int, default=1) 19 | parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) 20 | parser.add_argument('--num-timesteps', type=int, default=int(10e6)) 21 | 22 | args = parser.parse_args() 23 | logger.configure() 24 | set_global_seeds(args.seed) 25 | env = make_atari(args.env) 26 | env = bench.Monitor(env, logger.get_dir()) 27 | env = wrap_atari_dqn(env) 28 | policy = partial(CnnPolicy, dueling=args.dueling == 1) 29 | 30 | model = DQN( 31 | env=env, 32 | policy=policy, 33 | learning_rate=1e-4, 34 | buffer_size=10000, 35 | exploration_fraction=0.1, 36 | exploration_final_eps=0.01, 37 | train_freq=4, 38 | learning_starts=10000, 39 | target_network_update_freq=1000, 40 | gamma=0.99, 41 | prioritized_replay=bool(args.prioritized), 42 | prioritized_replay_alpha=args.prioritized_replay_alpha, 43 | ) 44 | model.learn(total_timesteps=args.num_timesteps) 45 | 46 | env.close() 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/train_cartpole.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from stable_baselines.deepq import DQN, MlpPolicy 7 | 8 | 9 | def callback(lcl, _glb): 10 | """ 11 | The callback function for logging and saving 12 | 13 | :param lcl: (dict) the local variables 14 | :param _glb: (dict) the global variables 15 | :return: (bool) is solved 16 | """ 17 | # stop training if reward exceeds 199 18 | if len(lcl['episode_rewards'][-101:-1]) == 0: 19 | mean_100ep_reward = -np.inf 20 | else: 21 | mean_100ep_reward = round(float(np.mean(lcl['episode_rewards'][-101:-1])), 1) 22 | is_solved = lcl['self'].num_timesteps > 100 and mean_100ep_reward >= 199 23 | return not is_solved 24 | 25 | 26 | def main(args): 27 | """ 28 | Train and save the DQN model, for the cartpole problem 29 | 30 | :param args: (ArgumentParser) the input arguments 31 | """ 32 | env = gym.make("CartPole-v0") 33 | model = DQN( 34 | env=env, 35 | policy=MlpPolicy, 36 | learning_rate=1e-3, 37 | buffer_size=50000, 38 | exploration_fraction=0.1, 39 | exploration_final_eps=0.02, 40 | ) 41 | model.learn(total_timesteps=args.max_timesteps, callback=callback) 42 | 43 | print("Saving model to cartpole_model.zip") 44 | 
model.save("cartpole_model.zip") 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser(description="Train DQN on cartpole") 49 | parser.add_argument('--max-timesteps', default=100000, type=int, help="Maximum number of timesteps") 50 | args = parser.parse_args() 51 | main(args) 52 | -------------------------------------------------------------------------------- /stable_baselines/deepq/experiments/train_mountaincar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from stable_baselines.deepq import DQN 6 | 7 | 8 | def main(args): 9 | """ 10 | Train and save the DQN model, for the mountain car problem 11 | 12 | :param args: (ArgumentParser) the input arguments 13 | """ 14 | env = gym.make("MountainCar-v0") 15 | 16 | # using layer norm policy here is important for parameter space noise! 17 | model = DQN( 18 | policy="LnMlpPolicy", 19 | env=env, 20 | learning_rate=1e-3, 21 | buffer_size=50000, 22 | exploration_fraction=0.1, 23 | exploration_final_eps=0.1, 24 | param_noise=True, 25 | policy_kwargs=dict(layers=[64]) 26 | ) 27 | model.learn(total_timesteps=args.max_timesteps) 28 | 29 | print("Saving model to mountaincar_model.zip") 30 | model.save("mountaincar_model") 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser(description="Train DQN on MountainCar") 35 | parser.add_argument('--max-timesteps', default=100000, type=int, help="Maximum number of timesteps") 36 | args = parser.parse_args() 37 | main(args) 38 | -------------------------------------------------------------------------------- /stable_baselines/gail/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.gail.model import GAIL 2 | from stable_baselines.gail.dataset.dataset import ExpertDataset, DataLoader 3 | from stable_baselines.gail.dataset.record_expert import generate_expert_traj 4 | -------------------------------------------------------------------------------- /stable_baselines/gail/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/__init__.py -------------------------------------------------------------------------------- /stable_baselines/gail/dataset/expert_cartpole.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/expert_cartpole.npz -------------------------------------------------------------------------------- /stable_baselines/gail/dataset/expert_pendulum.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/expert_pendulum.npz -------------------------------------------------------------------------------- /stable_baselines/gail/model.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.trpo_mpi import TRPO 2 | 3 | 4 | class GAIL(TRPO): 5 | """ 6 | Generative Adversarial Imitation Learning (GAIL) 7 | 8 | .. 
warning:: 9 | 10 | Images are not yet handled properly by the current implementation 11 | 12 | 13 | :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...) 14 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) 15 | :param expert_dataset: (ExpertDataset) the dataset manager 16 | :param gamma: (float) the discount value 17 | :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon) 18 | :param max_kl: (float) the Kullback-Leibler loss threshold 19 | :param cg_iters: (int) the number of iterations for the conjugate gradient calculation 20 | :param lam: (float) GAE factor 21 | :param entcoeff: (float) the weight for the entropy loss 22 | :param cg_damping: (float) the compute gradient dampening factor 23 | :param vf_stepsize: (float) the value function stepsize 24 | :param vf_iters: (int) the value function's number iterations for learning 25 | :param hidden_size: ([int]) the hidden dimension for the MLP 26 | :param g_step: (int) number of steps to train policy in each epoch 27 | :param d_step: (int) number of steps to train discriminator in each epoch 28 | :param d_stepsize: (float) the reward giver stepsize 29 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 30 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance 31 | :param full_tensorboard_log: (bool) enable additional logging when using tensorboard 32 | WARNING: this logging can take a lot of space quickly 33 | """ 34 | 35 | def __init__(self, policy, env, expert_dataset=None, 36 | hidden_size_adversary=100, adversary_entcoeff=1e-3, 37 | g_step=3, d_step=1, d_stepsize=3e-4, verbose=0, 38 | _init_setup_model=True, **kwargs): 39 | super().__init__(policy, env, verbose=verbose, _init_setup_model=False, **kwargs) 40 | self.using_gail = True 41 | self.expert_dataset = expert_dataset 42 | self.g_step = g_step 43 | self.d_step = d_step 44 | self.d_stepsize = d_stepsize 45 | self.hidden_size_adversary = hidden_size_adversary 46 | self.adversary_entcoeff = adversary_entcoeff 47 | 48 | if _init_setup_model: 49 | self.setup_model() 50 | 51 | def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="GAIL", 52 | reset_num_timesteps=True): 53 | assert self.expert_dataset is not None, "You must pass an expert dataset to GAIL for training" 54 | return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps) 55 | -------------------------------------------------------------------------------- /stable_baselines/her/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.her.her import HER 2 | from stable_baselines.her.replay_buffer import GoalSelectionStrategy, HindsightExperienceReplayWrapper 3 | from stable_baselines.her.utils import HERGoalEnvWrapper 4 | -------------------------------------------------------------------------------- /stable_baselines/her/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | # Important: gym mixes up ordered and unordered keys 7 | # and the Dict space may return a different order of keys that the actual one 8 | KEY_ORDER = ['observation', 'achieved_goal', 'desired_goal'] 9 | 10 | 11 | class HERGoalEnvWrapper(object): 12 | """ 13 | A wrapper 
that allow to use dict observation space (coming from GoalEnv) with 14 | the RL algorithms. 15 | It assumes that all the spaces of the dict space are of the same type. 16 | 17 | :param env: (gym.GoalEnv) 18 | """ 19 | 20 | def __init__(self, env): 21 | super(HERGoalEnvWrapper, self).__init__() 22 | self.env = env 23 | self.metadata = self.env.metadata 24 | self.action_space = env.action_space 25 | self.spaces = list(env.observation_space.spaces.values()) 26 | # Check that all spaces are of the same type 27 | # (current limitation of the wrapper) 28 | space_types = [type(env.observation_space.spaces[key]) for key in KEY_ORDER] 29 | assert len(set(space_types)) == 1, "The spaces for goal and observation"\ 30 | " must be of the same type" 31 | 32 | if isinstance(self.spaces[0], spaces.Discrete): 33 | self.obs_dim = 1 34 | self.goal_dim = 1 35 | else: 36 | goal_space_shape = env.observation_space.spaces['achieved_goal'].shape 37 | self.obs_dim = env.observation_space.spaces['observation'].shape[0] 38 | self.goal_dim = goal_space_shape[0] 39 | 40 | if len(goal_space_shape) == 2: 41 | assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet" 42 | else: 43 | assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet" 44 | 45 | if isinstance(self.spaces[0], spaces.MultiBinary): 46 | total_dim = self.obs_dim + 2 * self.goal_dim 47 | self.observation_space = spaces.MultiBinary(total_dim) 48 | 49 | elif isinstance(self.spaces[0], spaces.Box): 50 | lows = np.concatenate([space.low for space in self.spaces]) 51 | highs = np.concatenate([space.high for space in self.spaces]) 52 | self.observation_space = spaces.Box(lows, highs, dtype=np.float32) 53 | 54 | elif isinstance(self.spaces[0], spaces.Discrete): 55 | dimensions = [env.observation_space.spaces[key].n for key in KEY_ORDER] 56 | self.observation_space = spaces.MultiDiscrete(dimensions) 57 | 58 | else: 59 | raise NotImplementedError("{} space is not supported".format(type(self.spaces[0]))) 60 | 61 | def convert_dict_to_obs(self, obs_dict): 62 | """ 63 | :param obs_dict: (dict) 64 | :return: (np.ndarray) 65 | """ 66 | # Note: achieved goal is not removed from the observation 67 | # this is helpful to have a revertible transformation 68 | if isinstance(self.observation_space, spaces.MultiDiscrete): 69 | # Special case for multidiscrete 70 | return np.concatenate([[int(obs_dict[key])] for key in KEY_ORDER]) 71 | return np.concatenate([obs_dict[key] for key in KEY_ORDER]) 72 | 73 | def convert_obs_to_dict(self, observations): 74 | """ 75 | Inverse operation of convert_dict_to_obs 76 | 77 | :param observations: (np.ndarray) 78 | :return: (OrderedDict) 79 | """ 80 | return OrderedDict([ 81 | ('observation', observations[:self.obs_dim]), 82 | ('achieved_goal', observations[self.obs_dim:self.obs_dim + self.goal_dim]), 83 | ('desired_goal', observations[self.obs_dim + self.goal_dim:]), 84 | ]) 85 | 86 | def step(self, action): 87 | obs, reward, done, info = self.env.step(action) 88 | return self.convert_dict_to_obs(obs), reward, done, info 89 | 90 | def seed(self, seed=None): 91 | return self.env.seed(seed) 92 | 93 | def reset(self): 94 | return self.convert_dict_to_obs(self.env.reset()) 95 | 96 | def compute_reward(self, achieved_goal, desired_goal, info): 97 | return self.env.compute_reward(achieved_goal, desired_goal, info) 98 | 99 | def render(self, mode='human'): 100 | return self.env.render(mode) 101 | 102 | def close(self): 103 | return self.env.close() 104 | 
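Example of training with HER (an illustrative sketch, not part of the repository source; hyperparameters are arbitrary). The HER model class applies HERGoalEnvWrapper to the goal environment internally, so the dict observation is flattened into a single vector for the wrapped off-policy model:

```python
from stable_baselines import HER, SAC
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

# Toy GoalEnv shipped with the library; it returns dict observations with
# 'observation', 'achieved_goal' and 'desired_goal' keys.
env = BitFlippingEnv(n_bits=10, continuous=True, max_steps=10)

# HER wraps `env` in HERGoalEnvWrapper and relabels stored transitions with
# 'future' goals sampled from the same episode.
model = HER('MlpPolicy', env, SAC, n_sampled_goal=4,
            goal_selection_strategy='future', verbose=1)
model.learn(total_timesteps=5000)
```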
-------------------------------------------------------------------------------- /stable_baselines/ppo1/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.ppo1.pposgd_simple import PPO1 2 | -------------------------------------------------------------------------------- /stable_baselines/ppo1/experiments/train_cartpole.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple test to check that PPO1 is running with no errors (see issue #50) 3 | """ 4 | from stable_baselines import PPO1 5 | 6 | 7 | if __name__ == '__main__': 8 | model = PPO1('MlpPolicy', 'CartPole-v1', schedule='linear', verbose=0) 9 | model.learn(total_timesteps=1000) 10 | -------------------------------------------------------------------------------- /stable_baselines/ppo1/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | from mpi4py import MPI 5 | 6 | from stable_baselines.common import set_global_seeds 7 | from stable_baselines import bench, logger, PPO1 8 | from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind 9 | from stable_baselines.common.cmd_util import atari_arg_parser 10 | from stable_baselines.common.policies import CnnPolicy 11 | 12 | 13 | def train(env_id, num_timesteps, seed): 14 | """ 15 | Train PPO1 model for Atari environments, for testing purposes 16 | 17 | :param env_id: (str) Environment ID 18 | :param num_timesteps: (int) The total number of samples 19 | :param seed: (int) The initial seed for training 20 | """ 21 | rank = MPI.COMM_WORLD.Get_rank() 22 | 23 | if rank == 0: 24 | logger.configure() 25 | else: 26 | logger.configure(format_strs=[]) 27 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() 28 | set_global_seeds(workerseed) 29 | env = make_atari(env_id) 30 | 31 | env = bench.Monitor(env, logger.get_dir() and 32 | os.path.join(logger.get_dir(), str(rank))) 33 | env.seed(workerseed) 34 | 35 | env = wrap_deepmind(env) 36 | env.seed(workerseed) 37 | 38 | model = PPO1(CnnPolicy, env, timesteps_per_actorbatch=256, clip_param=0.2, entcoeff=0.01, optim_epochs=4, 39 | optim_stepsize=1e-3, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear', verbose=2) 40 | model.learn(total_timesteps=num_timesteps) 41 | env.close() 42 | del env 43 | 44 | 45 | def main(): 46 | """ 47 | Runs the test 48 | """ 49 | args = atari_arg_parser().parse_args() 50 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /stable_baselines/ppo1/run_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from stable_baselines.ppo1 import PPO1 4 | from stable_baselines.common.policies import MlpPolicy 5 | from stable_baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser 6 | from stable_baselines import logger 7 | 8 | 9 | def train(env_id, num_timesteps, seed): 10 | """ 11 | Train PPO1 model for the Mujoco environment, for testing purposes 12 | 13 | :param env_id: (str) Environment ID 14 | :param num_timesteps: (int) The total number of samples 15 | :param seed: (int) The initial seed for training 16 | """ 17 | env = make_mujoco_env(env_id, seed) 18 | model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=10, 19 | 
optim_stepsize=3e-4, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear') 20 | model.learn(total_timesteps=num_timesteps) 21 | env.close() 22 | 23 | 24 | def main(): 25 | """ 26 | Runs the test 27 | """ 28 | args = mujoco_arg_parser().parse_args() 29 | logger.configure() 30 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /stable_baselines/ppo1/run_robotics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from mpi4py import MPI 4 | import mujoco_py # pytype:disable=import-error 5 | 6 | from stable_baselines.common import set_global_seeds 7 | from stable_baselines.common.policies import MlpPolicy 8 | from stable_baselines.common.cmd_util import make_robotics_env, robotics_arg_parser 9 | from stable_baselines.ppo1 import PPO1 10 | 11 | 12 | def train(env_id, num_timesteps, seed): 13 | """ 14 | Train PPO1 model for Robotics environment, for testing purposes 15 | 16 | :param env_id: (str) Environment ID 17 | :param num_timesteps: (int) The total number of samples 18 | :param seed: (int) The initial seed for training 19 | """ 20 | 21 | rank = MPI.COMM_WORLD.Get_rank() 22 | with mujoco_py.ignore_mujoco_warnings(): 23 | workerseed = seed + 10000 * rank 24 | set_global_seeds(workerseed) 25 | env = make_robotics_env(env_id, workerseed, rank=rank) 26 | 27 | model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=5, 28 | optim_stepsize=3e-4, optim_batchsize=256, gamma=0.99, lam=0.95, schedule='linear') 29 | model.learn(total_timesteps=num_timesteps) 30 | env.close() 31 | 32 | 33 | def main(): 34 | """ 35 | Runs the test 36 | """ 37 | args = robotics_arg_parser().parse_args() 38 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /stable_baselines/ppo2/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.ppo2.ppo2 import PPO2 2 | -------------------------------------------------------------------------------- /stable_baselines/ppo2/run_atari.py: -------------------------------------------------------------------------------- 1 | from stable_baselines import PPO2, logger 2 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser 3 | from stable_baselines.common.vec_env import VecFrameStack 4 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy, MlpPolicy 5 | 6 | 7 | def train(env_id, num_timesteps, seed, policy, 8 | n_envs=8, nminibatches=4, n_steps=128): 9 | """ 10 | Train PPO2 model for atari environment, for testing purposes 11 | 12 | :param env_id: (str) the environment id string 13 | :param num_timesteps: (int) the number of timesteps to run 14 | :param seed: (int) Used to seed the random generator. 15 | :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...) 16 | :param n_envs: (int) Number of parallel environments 17 | :param nminibatches: (int) Number of training minibatches per update. For recurrent policies, 18 | the number of environments run in parallel should be a multiple of nminibatches. 19 | :param n_steps: (int) The number of steps to run for each environment per update 20 | (i.e. 
batch size is n_steps * n_env where n_env is number of environment copies running in parallel) 21 | """ 22 | 23 | env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4) 24 | policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy] 25 | model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches, 26 | lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01, 27 | learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1) 28 | model.learn(total_timesteps=num_timesteps) 29 | 30 | env.close() 31 | # Free memory 32 | del model 33 | 34 | 35 | def main(): 36 | """ 37 | Runs the test 38 | """ 39 | parser = atari_arg_parser() 40 | parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn') 41 | args = parser.parse_args() 42 | logger.configure() 43 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, 44 | policy=args.policy) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /stable_baselines/ppo2/run_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | import gym 4 | 5 | from stable_baselines.common.cmd_util import mujoco_arg_parser 6 | from stable_baselines import bench, logger 7 | from stable_baselines.common import set_global_seeds 8 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize 9 | from stable_baselines.ppo2 import PPO2 10 | from stable_baselines.common.policies import MlpPolicy 11 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv 12 | 13 | 14 | def train(env_id, num_timesteps, seed): 15 | """ 16 | Train PPO2 model for Mujoco environment, for testing purposes 17 | 18 | :param env_id: (str) the environment id string 19 | :param num_timesteps: (int) the number of timesteps to run 20 | :param seed: (int) Used to seed the random generator. 
21 | """ 22 | def make_env(): 23 | env_out = gym.make(env_id) 24 | env_out = bench.Monitor(env_out, logger.get_dir(), allow_early_resets=True) 25 | return env_out 26 | 27 | env = DummyVecEnv([make_env]) 28 | env = VecNormalize(env) 29 | 30 | set_global_seeds(seed) 31 | policy = MlpPolicy 32 | model = PPO2(policy=policy, env=env, n_steps=2048, nminibatches=32, lam=0.95, gamma=0.99, noptepochs=10, 33 | ent_coef=0.0, learning_rate=3e-4, cliprange=0.2) 34 | model.learn(total_timesteps=num_timesteps) 35 | 36 | return model, env 37 | 38 | 39 | def main(): 40 | """ 41 | Runs the test 42 | """ 43 | args = mujoco_arg_parser().parse_args() 44 | logger.configure() 45 | model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 46 | 47 | if args.play: 48 | logger.log("Running trained model") 49 | obs = np.zeros((env.num_envs,) + env.observation_space.shape) 50 | obs[:] = env.reset() 51 | while True: 52 | actions = model.step(obs)[0] 53 | obs[:] = env.step(actions)[0] 54 | env.render() 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /stable_baselines/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/py.typed -------------------------------------------------------------------------------- /stable_baselines/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.sac.sac import SAC 2 | from stable_baselines.sac.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy 3 | -------------------------------------------------------------------------------- /stable_baselines/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 2 | from stable_baselines.td3.td3 import TD3 3 | from stable_baselines.td3.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy 4 | -------------------------------------------------------------------------------- /stable_baselines/trpo_mpi/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.trpo_mpi.trpo_mpi import TRPO 2 | -------------------------------------------------------------------------------- /stable_baselines/trpo_mpi/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | from mpi4py import MPI 5 | 6 | from stable_baselines.common import set_global_seeds 7 | from stable_baselines import bench, logger, TRPO 8 | from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind 9 | from stable_baselines.common.cmd_util import atari_arg_parser 10 | from stable_baselines.common.policies import CnnPolicy 11 | 12 | 13 | def train(env_id, num_timesteps, seed): 14 | """ 15 | Train TRPO model for the atari environment, for testing purposes 16 | 17 | :param env_id: (str) Environment ID 18 | :param num_timesteps: (int) The total number of samples 19 | :param seed: (int) The initial seed for training 20 | """ 21 | rank = MPI.COMM_WORLD.Get_rank() 22 | 23 | if rank == 0: 24 | logger.configure() 25 | else: 26 | logger.configure(format_strs=[]) 27 | 28 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() 29 | set_global_seeds(workerseed) 30 | 
env = make_atari(env_id) 31 | 32 | env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) 33 | env.seed(workerseed) 34 | 35 | env = wrap_deepmind(env) 36 | env.seed(workerseed) 37 | 38 | model = TRPO(CnnPolicy, env, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3, entcoeff=0.0, 39 | gamma=0.98, lam=1, vf_iters=3, vf_stepsize=1e-4) 40 | model.learn(total_timesteps=int(num_timesteps * 1.1)) 41 | env.close() 42 | # Free memory 43 | del env 44 | 45 | 46 | def main(): 47 | """ 48 | Runs the test 49 | """ 50 | args = atari_arg_parser().parse_args() 51 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /stable_baselines/trpo_mpi/run_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # noinspection PyUnresolvedReferences 3 | from mpi4py import MPI 4 | 5 | from stable_baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser 6 | from stable_baselines.common.policies import MlpPolicy 7 | from stable_baselines import logger 8 | from stable_baselines.trpo_mpi import TRPO 9 | import stable_baselines.common.tf_util as tf_util 10 | 11 | 12 | def train(env_id, num_timesteps, seed): 13 | """ 14 | Train TRPO model for the mujoco environment, for testing purposes 15 | 16 | :param env_id: (str) Environment ID 17 | :param num_timesteps: (int) The total number of samples 18 | :param seed: (int) The initial seed for training 19 | """ 20 | with tf_util.single_threaded_session(): 21 | rank = MPI.COMM_WORLD.Get_rank() 22 | if rank == 0: 23 | logger.configure() 24 | else: 25 | logger.configure(format_strs=[]) 26 | logger.set_level(logger.DISABLED) 27 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() 28 | 29 | env = make_mujoco_env(env_id, workerseed) 30 | model = TRPO(MlpPolicy, env, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1, entcoeff=0.0, 31 | gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3) 32 | model.learn(total_timesteps=num_timesteps) 33 | env.close() 34 | 35 | 36 | def main(): 37 | """ 38 | Runs the test 39 | """ 40 | args = mujoco_arg_parser().parse_args() 41 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /stable_baselines/trpo_mpi/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def add_vtarg_and_adv(seg, gamma, lam): 5 | """ 6 | Compute target value using TD(lambda) estimator, and advantage with GAE(lambda) 7 | 8 | :param seg: (dict) the current segment of the trajectory (see traj_segment_generator return for more information) 9 | :param gamma: (float) Discount factor 10 | :param lam: (float) GAE factor 11 | """ 12 | # last element is only used for last vtarg, but we already zeroed it if last new = 1 13 | episode_starts = np.append(seg["episode_starts"], False) 14 | vpred = np.append(seg["vpred"], seg["nextvpred"]) 15 | rew_len = len(seg["rewards"]) 16 | seg["adv"] = np.empty(rew_len, 'float32') 17 | rewards = seg["rewards"] 18 | lastgaelam = 0 19 | for step in reversed(range(rew_len)): 20 | nonterminal = 1 - float(episode_starts[step + 1]) 21 | delta = rewards[step] + gamma * vpred[step + 1] * nonterminal - vpred[step] 22 | seg["adv"][step] = 
lastgaelam = delta + gamma * lam * nonterminal * lastgaelam 23 | seg["tdlamret"] = seg["adv"] + seg["vpred"] 24 | -------------------------------------------------------------------------------- /stable_baselines/version.txt: -------------------------------------------------------------------------------- 1 | 2.10.3a0 (WIP) 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_0deterministic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, SAC, TRPO, TD3 5 | from stable_baselines.common.noise import NormalActionNoise 6 | 7 | N_STEPS_TRAINING = 300 8 | SEED = 0 9 | 10 | 11 | # Weird stuff: TD3 would fail if another algorithm is tested before 12 | # with n_cpu_tf_sess > 1 13 | @pytest.mark.xfail(reason="TD3 deterministic randomly fail when run with others...", strict=False) 14 | def test_deterministic_td3(): 15 | results = [[], []] 16 | rewards = [[], []] 17 | kwargs = {'n_cpu_tf_sess': 1} 18 | env_id = 'Pendulum-v0' 19 | kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)}) 20 | 21 | for i in range(2): 22 | model = TD3('MlpPolicy', env_id, seed=SEED, **kwargs) 23 | model.learn(N_STEPS_TRAINING) 24 | env = model.get_env() 25 | obs = env.reset() 26 | for _ in range(20): 27 | action, _ = model.predict(obs, deterministic=True) 28 | obs, reward, _, _ = env.step(action) 29 | results[i].append(action) 30 | rewards[i].append(reward) 31 | # without the extended tolerance, test fails for unknown reasons on Github... 
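# (np.allclose defaults to rtol=1e-5; rtol=1e-2 below tolerates roughly 1% relative difference between the two runs)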
32 | assert np.allclose(results[0], results[1], rtol=1e-2), results 33 | assert np.allclose(rewards[0], rewards[1], rtol=1e-2), rewards 34 | 35 | 36 | @pytest.mark.parametrize("algo", [A2C, ACKTR, ACER, DDPG, DQN, PPO1, PPO2, SAC, TRPO]) 37 | def test_deterministic_training_common(algo): 38 | results = [[], []] 39 | rewards = [[], []] 40 | kwargs = {'n_cpu_tf_sess': 1} 41 | if algo in [DDPG, TD3, SAC]: 42 | env_id = 'Pendulum-v0' 43 | kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)}) 44 | else: 45 | env_id = 'CartPole-v1' 46 | if algo == DQN: 47 | kwargs.update({'learning_starts': 100}) 48 | 49 | for i in range(2): 50 | model = algo('MlpPolicy', env_id, seed=SEED, **kwargs) 51 | model.learn(N_STEPS_TRAINING) 52 | env = model.get_env() 53 | obs = env.reset() 54 | for _ in range(20): 55 | action, _ = model.predict(obs, deterministic=False) 56 | obs, reward, _, _ = env.step(action) 57 | results[i].append(action) 58 | rewards[i].append(reward) 59 | assert sum(results[0]) == sum(results[1]), results 60 | assert sum(rewards[0]) == sum(rewards[1]), rewards 61 | -------------------------------------------------------------------------------- /tests/test_a2c.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import gym 5 | 6 | from stable_baselines import A2C 7 | from stable_baselines.common import make_vec_env 8 | from stable_baselines.common.vec_env import DummyVecEnv 9 | 10 | 11 | def test_a2c_update_n_batch_on_load(tmp_path): 12 | env = make_vec_env("CartPole-v1", n_envs=2) 13 | model = A2C("MlpPolicy", env, n_steps=10) 14 | 15 | model.learn(total_timesteps=100) 16 | model.save(os.path.join(str(tmp_path), "a2c_cartpole.zip")) 17 | 18 | del model 19 | 20 | model = A2C.load(os.path.join(str(tmp_path), "a2c_cartpole.zip")) 21 | test_env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) 22 | 23 | model.set_env(test_env) 24 | assert model.n_batch == 10 25 | model.learn(100) 26 | os.remove(os.path.join(str(tmp_path), "a2c_cartpole.zip")) 27 | -------------------------------------------------------------------------------- /tests/test_a2c_conv.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from stable_baselines.common.tf_layers import conv 6 | from stable_baselines.common.input import observation_input 7 | 8 | 9 | ENV_ID = 'BreakoutNoFrameskip-v4' 10 | SEED = 3 11 | 12 | 13 | def test_conv_kernel(): 14 | """Test convolution kernel with various input formats.""" 15 | filter_size_1 = 4 # The size of squared filter for the first layer 16 | filter_size_2 = (3, 5) # The size of non-squared filter for the second layer 17 | target_shape_1 = [2, 52, 40, 32] # The desired shape of the first layer 18 | target_shape_2 = [2, 13, 9, 32] # The desired shape of the second layer 19 | kwargs = {} 20 | n_envs = 1 21 | n_steps = 2 22 | n_batch = n_envs * n_steps 23 | scale = False 24 | env = gym.make(ENV_ID) 25 | ob_space = env.observation_space 26 | 27 | with tf.Graph().as_default(): 28 | _, scaled_images = observation_input(ob_space, n_batch, scale=scale) 29 | activ = tf.nn.relu 30 | layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=filter_size_1, 31 | stride=4, init_scale=np.sqrt(2), **kwargs)) 32 | layer_2 = activ(conv(layer_1, 'c2', n_filters=32, filter_size=filter_size_2, 33 | stride=4, init_scale=np.sqrt(2), **kwargs)) 34 | assert layer_1.shape == target_shape_1, \ 35 | "The shape of layer based on the 
squared kernel matrix is not correct. " \ 36 | "The current shape is {} and the desired shape is {}".format(layer_1.shape, target_shape_1) 37 | assert layer_2.shape == target_shape_2, \ 38 | "The shape of layer based on the non-squared kernel matrix is not correct. " \ 39 | "The current shape is {} and the desired shape is {}".format(layer_2.shape, target_shape_2) 40 | env.close() 41 | -------------------------------------------------------------------------------- /tests/test_action_scaling.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import DDPG, TD3, SAC 5 | from stable_baselines.common.identity_env import IdentityEnvBox 6 | 7 | ROLLOUT_STEPS = 100 8 | 9 | MODEL_LIST = [ 10 | (DDPG, dict(nb_train_steps=0, nb_rollout_steps=ROLLOUT_STEPS)), 11 | (TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)), 12 | (SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)), 13 | (TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS)), 14 | (SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS)) 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("model_class, model_kwargs", MODEL_LIST) 19 | def test_buffer_actions_scaling(model_class, model_kwargs): 20 | """ 21 | Test if actions are scaled to tanh co-domain before being put in a buffer 22 | for algorithms that use tanh-squashing, i.e., DDPG, TD3, SAC 23 | 24 | :param model_class: (BaseRLModel) A RL Model 25 | :param model_kwargs: (dict) Dictionary containing named arguments to the given algorithm 26 | """ 27 | 28 | # check random and inferred actions as they possibly have different flows 29 | for random_coeff in [0.0, 1.0]: 30 | 31 | env = IdentityEnvBox(-2000, 1000) 32 | 33 | model = model_class("MlpPolicy", env, seed=1, random_exploration=random_coeff, **model_kwargs) 34 | model.learn(total_timesteps=ROLLOUT_STEPS) 35 | 36 | assert hasattr(model, 'replay_buffer') 37 | 38 | buffer = model.replay_buffer 39 | 40 | assert buffer.can_sample(ROLLOUT_STEPS) 41 | 42 | _, actions, _, _, _ = buffer.sample(ROLLOUT_STEPS) 43 | 44 | assert not np.any(actions > np.ones_like(actions)) 45 | assert not np.any(actions < -np.ones_like(actions)) 46 | -------------------------------------------------------------------------------- /tests/test_action_space.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import A2C, PPO1, PPO2, TRPO 5 | from stable_baselines.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete 6 | from stable_baselines.common.vec_env import DummyVecEnv 7 | from stable_baselines.common.evaluation import evaluate_policy 8 | 9 | MODEL_LIST = [ 10 | A2C, 11 | PPO1, 12 | PPO2, 13 | TRPO 14 | ] 15 | 16 | 17 | @pytest.mark.slow 18 | @pytest.mark.parametrize("model_class", MODEL_LIST) 19 | def test_identity_multidiscrete(model_class): 20 | """ 21 | Test if the algorithm (with a given policy) 22 | can learn an identity transformation (i.e. 
return observation as an action) 23 | with a multidiscrete action space 24 | 25 | :param model_class: (BaseRLModel) A RL Model 26 | """ 27 | env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(10)]) 28 | 29 | model = model_class("MlpPolicy", env) 30 | model.learn(total_timesteps=1000) 31 | evaluate_policy(model, env, n_eval_episodes=5) 32 | obs = env.reset() 33 | 34 | assert np.array(model.action_probability(obs)).shape == (2, 1, 10), \ 35 | "Error: action_probability not returning correct shape" 36 | assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \ 37 | "Error: not scalar probability" 38 | 39 | 40 | @pytest.mark.slow 41 | @pytest.mark.parametrize("model_class", MODEL_LIST) 42 | def test_identity_multibinary(model_class): 43 | """ 44 | Test if the algorithm (with a given policy) 45 | can learn an identity transformation (i.e. return observation as an action) 46 | with a multibinary action space 47 | 48 | :param model_class: (BaseRLModel) A RL Model 49 | """ 50 | env = DummyVecEnv([lambda: IdentityEnvMultiBinary(10)]) 51 | 52 | model = model_class("MlpPolicy", env) 53 | model.learn(total_timesteps=1000) 54 | evaluate_policy(model, env, n_eval_episodes=5) 55 | obs = env.reset() 56 | 57 | assert model.action_probability(obs).shape == (1, 10), \ 58 | "Error: action_probability not returning correct shape" 59 | assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \ 60 | "Error: not scalar probability" 61 | -------------------------------------------------------------------------------- /tests/test_atari.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stable_baselines import bench, logger 4 | from stable_baselines.deepq import DQN, wrap_atari_dqn, CnnPolicy 5 | from stable_baselines.common import set_global_seeds 6 | from stable_baselines.common.atari_wrappers import make_atari 7 | import stable_baselines.a2c.run_atari as a2c_atari 8 | import stable_baselines.acer.run_atari as acer_atari 9 | import stable_baselines.acktr.run_atari as acktr_atari 10 | import stable_baselines.ppo1.run_atari as ppo1_atari 11 | import stable_baselines.ppo2.run_atari as ppo2_atari 12 | import stable_baselines.trpo_mpi.run_atari as trpo_atari 13 | 14 | 15 | ENV_ID = 'BreakoutNoFrameskip-v4' 16 | SEED = 3 17 | NUM_TIMESTEPS = 300 18 | NUM_CPU = 2 19 | 20 | 21 | @pytest.mark.slow 22 | @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm']) 23 | def test_a2c(policy): 24 | """ 25 | test A2C on atari 26 | 27 | :param policy: (str) the policy to test for A2C 28 | """ 29 | a2c_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, 30 | policy=policy, lr_schedule='constant', num_env=NUM_CPU) 31 | 32 | 33 | @pytest.mark.slow 34 | @pytest.mark.parametrize("policy", ['cnn', 'lstm']) 35 | def test_acer(policy): 36 | """ 37 | test ACER on atari 38 | 39 | :param policy: (str) the policy to test for ACER 40 | """ 41 | acer_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, 42 | policy=policy, lr_schedule='constant', num_cpu=NUM_CPU) 43 | 44 | 45 | @pytest.mark.slow 46 | def test_acktr(): 47 | """ 48 | test ACKTR on atari 49 | """ 50 | acktr_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, num_cpu=NUM_CPU) 51 | 52 | 53 | @pytest.mark.slow 54 | def test_deepq(): 55 | """ 56 | test DeepQ on atari 57 | """ 58 | logger.configure() 59 | set_global_seeds(SEED) 60 | env = make_atari(ENV_ID) 61 | env = bench.Monitor(env, logger.get_dir()) 
62 | env = wrap_atari_dqn(env) 63 | 64 | model = DQN(env=env, policy=CnnPolicy, learning_rate=1e-4, buffer_size=10000, exploration_fraction=0.1, 65 | exploration_final_eps=0.01, train_freq=4, learning_starts=100, target_network_update_freq=100, 66 | gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6) 67 | model.learn(total_timesteps=NUM_TIMESTEPS) 68 | 69 | env.close() 70 | del model, env 71 | 72 | 73 | @pytest.mark.slow 74 | def test_ppo1(): 75 | """ 76 | test PPO1 on atari 77 | """ 78 | ppo1_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED) 79 | 80 | 81 | @pytest.mark.slow 82 | @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm', 'mlp']) 83 | def test_ppo2(policy): 84 | """ 85 | test PPO2 on atari 86 | 87 | :param policy: (str) the policy to test for PPO2 88 | """ 89 | ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, 90 | seed=SEED, policy=policy, n_envs=NUM_CPU, 91 | nminibatches=NUM_CPU, n_steps=16) 92 | 93 | 94 | @pytest.mark.slow 95 | def test_trpo(): 96 | """ 97 | test TRPO on atari 98 | """ 99 | trpo_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED) 100 | -------------------------------------------------------------------------------- /tests/test_auto_vec_detection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import A2C, ACER, ACKTR, DDPG, DQN, PPO1, PPO2, SAC, TRPO, TD3 5 | from stable_baselines.common.vec_env import DummyVecEnv 6 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, \ 7 | IdentityEnvMultiDiscrete 8 | from stable_baselines.common.evaluation import evaluate_policy 9 | 10 | 11 | def check_shape(make_env, model_class, shape_1, shape_2): 12 | model = model_class(policy="MlpPolicy", env=DummyVecEnv([make_env])) 13 | 14 | env0 = make_env() 15 | env1 = DummyVecEnv([make_env]) 16 | 17 | for env, expected_shape in [(env0, shape_1), (env1, shape_2)]: 18 | def callback(locals_, _globals): 19 | assert np.array(locals_['action']).shape == expected_shape 20 | evaluate_policy(model, env, n_eval_episodes=5, callback=callback) 21 | 22 | 23 | @pytest.mark.slow 24 | @pytest.mark.parametrize("model_class", [A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO]) 25 | def test_identity(model_class): 26 | """ 27 | test the Discrete environment vectorisation detection 28 | 29 | :param model_class: (BaseRLModel) the RL model 30 | """ 31 | check_shape(lambda: IdentityEnv(dim=10), model_class, (), (1,)) 32 | 33 | 34 | @pytest.mark.slow 35 | @pytest.mark.parametrize("model_class", [A2C, DDPG, PPO1, PPO2, SAC, TRPO, TD3]) 36 | def test_identity_box(model_class): 37 | """ 38 | test the Box environment vectorisation detection 39 | 40 | :param model_class: (BaseRLModel) the RL model 41 | """ 42 | check_shape(lambda: IdentityEnvBox(eps=0.5), model_class, (1,), (1, 1)) 43 | 44 | 45 | @pytest.mark.slow 46 | @pytest.mark.parametrize("model_class", [A2C, PPO1, PPO2, TRPO]) 47 | def test_identity_multi_binary(model_class): 48 | """ 49 | test the MultiBinary environment vectorisation detection 50 | 51 | :param model_class: (BaseRLModel) the RL model 52 | """ 53 | check_shape(lambda: IdentityEnvMultiBinary(dim=10), model_class, (10,), (1, 10)) 54 | 55 | 56 | @pytest.mark.slow 57 | @pytest.mark.parametrize("model_class", [A2C, PPO1, PPO2, TRPO]) 58 | def test_identity_multi_discrete(model_class): 59 | """ 60 | test the MultiDiscrete environment vectorisation detection 61 | 62 | :param
model_class: (BaseRLModel) the RL model 63 | """ 64 | check_shape(lambda: IdentityEnvMultiDiscrete(dim=10), model_class, (2,), (1, 2)) 65 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import sys 3 | 4 | 5 | def _assert_eq(left, right): 6 | assert left == right, '{} != {}'.format(left, right) 7 | 8 | 9 | def _assert_neq(left, right): 10 | assert left != right, '{} == {}'.format(left, right) 11 | 12 | 13 | @contextmanager 14 | def _maybe_disable_mpi(mpi_disabled): 15 | """A context that can temporarily remove the mpi4py import. 16 | 17 | Useful for testing whether non-MPI algorithms work as intended when 18 | mpi4py isn't installed. 19 | 20 | Args: 21 | mpi_disabled (bool): If True, then this context temporarily removes 22 | the mpi4py import from `sys.modules` 23 | """ 24 | if mpi_disabled and "mpi4py" in sys.modules: 25 | temp = sys.modules["mpi4py"] 26 | try: 27 | sys.modules["mpi4py"] = None 28 | yield 29 | finally: 30 | sys.modules["mpi4py"] = temp 31 | else: 32 | yield 33 | -------------------------------------------------------------------------------- /tests/test_deepq.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.deepq.experiments.train_cartpole import main as train_cartpole 2 | from stable_baselines.deepq.experiments.enjoy_cartpole import main as enjoy_cartpole 3 | from stable_baselines.deepq.experiments.train_mountaincar import main as train_mountaincar 4 | from stable_baselines.deepq.experiments.enjoy_mountaincar import main as enjoy_mountaincar 5 | 6 | 7 | class DummyObject(object): 8 | """ 9 | Dummy object to create fake Parsed Arguments object 10 | """ 11 | pass 12 | 13 | 14 | args = DummyObject() 15 | args.no_render = True 16 | args.max_timesteps = 200 17 | 18 | 19 | def test_cartpole(): 20 | train_cartpole(args) 21 | enjoy_cartpole(args) 22 | 23 | 24 | def test_mountaincar(): 25 | train_mountaincar(args) 26 | enjoy_mountaincar(args) 27 | -------------------------------------------------------------------------------- /tests/test_distri.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import stable_baselines.common.tf_util as tf_util 5 | from stable_baselines.common.distributions import DiagGaussianProbabilityDistributionType,\ 6 | CategoricalProbabilityDistributionType, \ 7 | MultiCategoricalProbabilityDistributionType, BernoulliProbabilityDistributionType 8 | 9 | 10 | @tf_util.in_session 11 | def test_probtypes(): 12 | """ 13 | test probability distribution types 14 | """ 15 | np.random.seed(0) 16 | 17 | pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8]) 18 | diag_gauss = DiagGaussianProbabilityDistributionType(pdparam_diag_gauss.size // 2) 19 | validate_probtype(diag_gauss, pdparam_diag_gauss) 20 | 21 | pdparam_categorical = np.array([-.2, .3, .5]) 22 | categorical = CategoricalProbabilityDistributionType(pdparam_categorical.size) 23 | validate_probtype(categorical, pdparam_categorical) 24 | 25 | nvec = np.array([1, 2, 3]) 26 | pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1]) 27 | multicategorical = MultiCategoricalProbabilityDistributionType(nvec) 28 | validate_probtype(multicategorical, pdparam_multicategorical) 29 | 30 | pdparam_bernoulli = np.array([-.2, .3, .5]) 31 | bernoulli = 
BernoulliProbabilityDistributionType(pdparam_bernoulli.size) 32 | validate_probtype(bernoulli, pdparam_bernoulli) 33 | 34 | 35 | def validate_probtype(probtype, pdparam): 36 | """ 37 | validate probability distribution types 38 | 39 | :param probtype: (ProbabilityDistributionType) the type to validate 40 | :param pdparam: ([float]) the flat probabilities to test 41 | """ 42 | number_samples = 100000 43 | # Check to see if mean negative log likelihood == differential entropy 44 | mval = np.repeat(pdparam[None, :], number_samples, axis=0) 45 | mval_ph = probtype.param_placeholder([number_samples]) 46 | xval_ph = probtype.sample_placeholder([number_samples]) 47 | proba_distribution = probtype.proba_distribution_from_flat(mval_ph) 48 | calcloglik = tf_util.function([xval_ph, mval_ph], proba_distribution.logp(xval_ph)) 49 | calcent = tf_util.function([mval_ph], proba_distribution.entropy()) 50 | xval = tf.get_default_session().run(proba_distribution.sample(), feed_dict={mval_ph: mval}) 51 | logliks = calcloglik(xval, mval) 52 | entval_ll = - logliks.mean() 53 | entval_ll_stderr = logliks.std() / np.sqrt(number_samples) 54 | entval = calcent(mval).mean() 55 | assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas 56 | 57 | # Check to see if kldiv[p,q] = - ent[p] - E_p[log q] 58 | mval2_ph = probtype.param_placeholder([number_samples]) 59 | pd2 = probtype.proba_distribution_from_flat(mval2_ph) 60 | tmp = pdparam + np.random.randn(pdparam.size) * 0.1 61 | mval2 = np.repeat(tmp[None, :], number_samples, axis=0) 62 | calckl = tf_util.function([mval_ph, mval2_ph], proba_distribution.kl(pd2)) 63 | klval = calckl(mval, mval2).mean() 64 | logliks = calcloglik(xval, mval2) 65 | klval_ll = - entval - logliks.mean() 66 | klval_ll_stderr = logliks.std() / np.sqrt(number_samples) 67 | assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas 68 | print('ok on', probtype, pdparam) 69 | -------------------------------------------------------------------------------- /tests/test_her.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from stable_baselines import HER, DQN, SAC, DDPG, TD3 6 | from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper 7 | from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY 8 | from stable_baselines.common.bit_flipping_env import BitFlippingEnv 9 | from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize 10 | 11 | N_BITS = 10 12 | 13 | 14 | def model_predict(model, env, n_steps, additional_check=None): 15 | """ 16 | Test helper 17 | :param model: (rl model) 18 | :param env: (gym.Env) 19 | :param n_steps: (int) 20 | :param additional_check: (callable) 21 | """ 22 | obs = env.reset() 23 | for _ in range(n_steps): 24 | action, _ = model.predict(obs) 25 | obs, reward, done, _ = env.step(action) 26 | 27 | if additional_check is not None: 28 | additional_check(obs, action, reward, done) 29 | 30 | if done: 31 | obs = env.reset() 32 | 33 | 34 | @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) 35 | @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3]) 36 | @pytest.mark.parametrize('discrete_obs_space', [False, True]) 37 | def test_her(model_class, goal_selection_strategy, discrete_obs_space): 38 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], 39 | max_steps=N_BITS, discrete_obs_space=discrete_obs_space) 40 | 41 | # Take random actions 10% of the time 42 | kwargs = 
{'random_exploration': 0.1} if model_class in [DDPG, SAC, TD3] else {} 43 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, 44 | verbose=0, **kwargs) 45 | model.learn(150) 46 | 47 | 48 | @pytest.mark.parametrize('model_class', [DDPG, SAC, DQN, TD3]) 49 | def test_long_episode(model_class): 50 | """ 51 | Check that the model does not break when the replay buffer is still empty 52 | after the first rollout (because the episode is not over). 53 | """ 54 | # n_bits > nb_rollout_steps 55 | n_bits = 10 56 | env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3], 57 | max_steps=n_bits) 58 | kwargs = {} 59 | if model_class == DDPG: 60 | kwargs['nb_rollout_steps'] = 9 # < n_bits 61 | elif model_class in [DQN, SAC, TD3]: 62 | kwargs['batch_size'] = 8 # < n_bits 63 | kwargs['learning_starts'] = 0 64 | 65 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future', 66 | verbose=0, **kwargs) 67 | model.learn(100) 68 | 69 | 70 | @pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]]) 71 | @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3]) 72 | def test_model_manipulation(model_class, goal_selection_strategy): 73 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) 74 | env = DummyVecEnv([lambda: env]) 75 | 76 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, 77 | verbose=0) 78 | model.learn(150) 79 | 80 | model_predict(model, env, n_steps=20, additional_check=None) 81 | 82 | model.save('./test_her.zip') 83 | del model 84 | 85 | # NOTE: HER does not support VecEnvWrapper yet 86 | with pytest.raises(AssertionError): 87 | model = HER.load('./test_her.zip', env=VecNormalize(env)) 88 | 89 | model = HER.load('./test_her.zip') 90 | 91 | # Check that the model raises an error when the env 92 | # is not wrapped (or no env passed to the model) 93 | with pytest.raises(ValueError): 94 | model.predict(env.reset()) 95 | 96 | env_ = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) 97 | env_ = HERGoalEnvWrapper(env_) 98 | 99 | model_predict(model, env_, n_steps=20, additional_check=None) 100 | 101 | model.set_env(env) 102 | model.learn(150) 103 | 104 | model_predict(model, env_, n_steps=20, additional_check=None) 105 | 106 | assert model.n_sampled_goal == 3 107 | 108 | del model 109 | 110 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) 111 | model = HER.load('./test_her', env=env) 112 | model.learn(150) 113 | 114 | model_predict(model, env_, n_steps=20, additional_check=None) 115 | 116 | assert model.n_sampled_goal == 3 117 | 118 | if os.path.isfile('./test_her.zip'): 119 | os.remove('./test_her.zip') 120 | -------------------------------------------------------------------------------- /tests/test_identity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, SAC, PPO1, PPO2, TD3, TRPO 5 | from stable_baselines.ddpg import NormalActionNoise 6 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox 7 | from stable_baselines.common.vec_env import DummyVecEnv 8 | from stable_baselines.common.evaluation import evaluate_policy 9 | 10 | 11 | # Hyperparameters for learning identity for each RL model 12 | LEARN_FUNC_DICT = { 13 
| "a2c": lambda e: A2C( 14 | policy="MlpPolicy", 15 | learning_rate=1e-3, 16 | n_steps=4, 17 | gamma=0.4, 18 | ent_coef=0.0, 19 | env=e, 20 | seed=0, 21 | ).learn(total_timesteps=4000), 22 | "acer": lambda e: ACER( 23 | policy="MlpPolicy", 24 | env=e, 25 | seed=0, 26 | n_steps=4, 27 | replay_ratio=1, 28 | ent_coef=0.0, 29 | ).learn(total_timesteps=4000), 30 | "acktr": lambda e: ACKTR( 31 | policy="MlpPolicy", env=e, seed=0, learning_rate=5e-4, ent_coef=0.0, n_steps=4 32 | ).learn(total_timesteps=4000), 33 | "dqn": lambda e: DQN( 34 | policy="MlpPolicy", 35 | batch_size=32, 36 | gamma=0.1, 37 | learning_starts=0, 38 | exploration_final_eps=0.05, 39 | exploration_fraction=0.1, 40 | env=e, 41 | seed=0, 42 | ).learn(total_timesteps=4000), 43 | "ppo1": lambda e: PPO1( 44 | policy="MlpPolicy", 45 | env=e, 46 | seed=0, 47 | lam=0.5, 48 | entcoeff=0.0, 49 | optim_batchsize=16, 50 | gamma=0.4, 51 | optim_stepsize=1e-3, 52 | ).learn(total_timesteps=3000), 53 | "ppo2": lambda e: PPO2( 54 | policy="MlpPolicy", 55 | env=e, 56 | seed=0, 57 | learning_rate=1.5e-3, 58 | lam=0.8, 59 | ent_coef=0.0, 60 | gamma=0.4, 61 | ).learn(total_timesteps=3000), 62 | "trpo": lambda e: TRPO( 63 | policy="MlpPolicy", 64 | env=e, 65 | gamma=0.4, 66 | seed=0, 67 | max_kl=0.05, 68 | lam=0.7, 69 | timesteps_per_batch=256, 70 | ).learn(total_timesteps=4000), 71 | } 72 | 73 | 74 | @pytest.mark.slow 75 | @pytest.mark.parametrize( 76 | "model_name", ["a2c", "acer", "acktr", "dqn", "ppo1", "ppo2", "trpo"] 77 | ) 78 | def test_identity_discrete(model_name): 79 | """ 80 | Test if the algorithm (with a given policy) 81 | can learn an identity transformation (i.e. return observation as an action) 82 | 83 | :param model_name: (str) Name of the RL model 84 | """ 85 | env = DummyVecEnv([lambda: IdentityEnv(10)]) 86 | 87 | model = LEARN_FUNC_DICT[model_name](env) 88 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) 89 | 90 | obs = env.reset() 91 | assert model.action_probability(obs).shape == ( 92 | 1, 93 | 10, 94 | ), "Error: action_probability not returning correct shape" 95 | action = env.action_space.sample() 96 | action_prob = model.action_probability(obs, actions=action) 97 | assert np.prod(action_prob.shape) == 1, "Error: not scalar probability" 98 | action_logprob = model.action_probability(obs, actions=action, logp=True) 99 | assert np.allclose(action_prob, np.exp(action_logprob)), ( 100 | action_prob, 101 | action_logprob, 102 | ) 103 | 104 | # Free memory 105 | del model, env 106 | 107 | 108 | @pytest.mark.slow 109 | @pytest.mark.parametrize("model_class", [DDPG, TD3, SAC]) 110 | def test_identity_continuous(model_class): 111 | """ 112 | Test if the algorithm (with a given policy) 113 | can learn an identity transformation (i.e. 
return observation as an action) 114 | """ 115 | env = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)]) 116 | 117 | n_steps = {SAC: 700, TD3: 500, DDPG: 2000}[model_class] 118 | 119 | kwargs = dict(seed=0, gamma=0.95, buffer_size=1e5) 120 | if model_class in [DDPG, TD3]: 121 | n_actions = 1 122 | action_noise = NormalActionNoise( 123 | mean=np.zeros(n_actions), sigma=0.05 * np.ones(n_actions) 124 | ) 125 | kwargs["action_noise"] = action_noise 126 | 127 | if model_class == DDPG: 128 | kwargs["actor_lr"] = 1e-3 129 | kwargs["batch_size"] = 100 130 | 131 | model = model_class("MlpPolicy", env, **kwargs) 132 | model.learn(total_timesteps=n_steps) 133 | 134 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) 135 | # Free memory 136 | del model, env 137 | -------------------------------------------------------------------------------- /tests/test_log_prob.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines import A2C, ACKTR, PPO1, PPO2, TRPO 5 | from stable_baselines.common.identity_env import IdentityEnvBox 6 | 7 | 8 | class Helper: 9 | @staticmethod 10 | def proba_vals(obs, state, mask): 11 | # Return fixed mean, std 12 | return np.array([-0.4]), np.array([[0.1]]) 13 | 14 | 15 | @pytest.mark.parametrize("model_class", [A2C, ACKTR, PPO1, PPO2, TRPO]) 16 | def test_log_prob_calcuation(model_class): 17 | model = model_class("MlpPolicy", IdentityEnvBox()) 18 | # Fixed mean/std 19 | model.proba_step = Helper.proba_vals 20 | # Check that the log probability is the one expected for the given mean/std 21 | logprob = model.action_probability(observation=np.array([[0.5], [0.5]]), actions=0.2, logp=True) 22 | assert np.allclose(logprob, np.array([-16.616353440210627])), "Calculation failed for {}".format(model_class) 23 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from stable_baselines.logger import make_output_format, read_tb, read_csv, read_json, _demo 5 | from .test_common import _maybe_disable_mpi 6 | 7 | 8 | KEY_VALUES = { 9 | "test": 1, 10 | "b": -3.14, 11 | "8": 9.9, 12 | "l": [1, 2], 13 | "a": np.array([1, 2, 3]), 14 | "f": np.array(1), 15 | "g": np.array([[[1]]]), 16 | } 17 | LOG_DIR = '/tmp/openai_baselines/' 18 | 19 | 20 | def test_main(): 21 | """ 22 | Dry-run python -m stable_baselines.logger 23 | """ 24 | _demo() 25 | 26 | 27 | @pytest.mark.parametrize('_format', ['tensorboard', 'stdout', 'log', 'json', 'csv']) 28 | @pytest.mark.parametrize('mpi_disabled', [False, True]) 29 | def test_make_output(_format, mpi_disabled): 30 | """ 31 | test make output 32 | 33 | :param _format: (str) output format 34 | """ 35 | with _maybe_disable_mpi(mpi_disabled): 36 | writer = make_output_format(_format, LOG_DIR) 37 | writer.writekvs(KEY_VALUES) 38 | if _format == 'tensorboard': 39 | read_tb(LOG_DIR) 40 | elif _format == "csv": 41 | read_csv(LOG_DIR + 'progress.csv') 42 | elif _format == 'json': 43 | read_json(LOG_DIR + 'progress.json') 44 | writer.close() 45 | 46 | 47 | def test_make_output_fail(): 48 | """ 49 | test value error on logger 50 | """ 51 | with pytest.raises(ValueError): 52 | make_output_format('dummy_format', LOG_DIR) 53 | -------------------------------------------------------------------------------- /tests/test_math_util.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from gym.spaces.box import Box 4 | 5 | from stable_baselines.common.math_util import discount_with_boundaries, scale_action, unscale_action 6 | 7 | 8 | def test_discount_with_boundaries(): 9 | """ 10 | test the discount_with_boundaries function 11 | """ 12 | gamma = 0.9 13 | rewards = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 14 | episode_starts = [1.0, 0.0, 0.0, 1.0] 15 | discounted_rewards = discount_with_boundaries(rewards, episode_starts, gamma) 16 | assert np.allclose(discounted_rewards, [1 + gamma * 2 + gamma ** 2 * 3, 2 + gamma * 3, 3, 4]) 17 | return 18 | 19 | 20 | def test_scaling_action(): 21 | """ 22 | test scaling of scalar, 1d and 2d vectors of finite non-NaN real numbers to and from tanh co-domain (per component) 23 | """ 24 | test_ranges = [(-1, 1), (-10, 10), (-10, 5), (-10, 0), (-10, -5), (0, 10), (5, 10)] 25 | 26 | # scalars 27 | for (range_low, range_high) in test_ranges: 28 | check_scaled_actions_from_range(range_low, range_high, scalar=True) 29 | 30 | # 1d vectors: wrapped scalars 31 | for test_range in test_ranges: 32 | check_scaled_actions_from_range(*test_range) 33 | 34 | # 2d vectors: all combinations of ranges above 35 | for (r1_low, r1_high) in test_ranges: 36 | for (r2_low, r2_high) in test_ranges: 37 | check_scaled_actions_from_range(np.array([r1_low, r2_low], dtype=np.float), 38 | np.array([r1_high, r2_high], dtype=np.float)) 39 | 40 | 41 | def check_scaled_actions_from_range(low, high, scalar=False): 42 | """ 43 | helper method which creates dummy action space spanning between respective components of low and high 44 | and then checks scaling to and from tanh co-domain for low, middle and high value from that action space 45 | :param low: (np.ndarray), (int) or (float) 46 | :param high: (np.ndarray), (int) or (float) 47 | :param scalar: (bool) Whether consider scalar range or wrap it into 1d vector 48 | """ 49 | 50 | if scalar and (isinstance(low, float) or isinstance(low, int)): 51 | ones = 1. 52 | action_space = Box(low, high, shape=(1,)) 53 | else: 54 | low = np.atleast_1d(low) 55 | high = np.atleast_1d(high) 56 | ones = np.ones_like(low) 57 | action_space = Box(low, high) 58 | 59 | mid = 0.5 * (low + high) 60 | 61 | expected_mapping = [(low, -ones), (mid, 0. 
* ones), (high, ones)] 62 | 63 | for (not_scaled, scaled) in expected_mapping: 64 | assert np.allclose(scale_action(action_space, not_scaled), scaled) 65 | assert np.allclose(unscale_action(action_space, scaled), not_scaled) 66 | 67 | 68 | def test_batch_shape_invariant_to_scaling(): 69 | """ 70 | test that scaling deals well with batches as tensors and numpy matrices in terms of shape 71 | """ 72 | action_space = Box(np.array([-10., -5., -1.]), np.array([10., 3., 2.])) 73 | 74 | tensor = tf.constant(1., shape=[2, 3]) 75 | matrix = np.ones((2, 3)) 76 | 77 | assert scale_action(action_space, tensor).shape == (2, 3) 78 | assert scale_action(action_space, matrix).shape == (2, 3) 79 | 80 | assert unscale_action(action_space, tensor).shape == (2, 3) 81 | assert unscale_action(action_space, matrix).shape == (2, 3) 82 | -------------------------------------------------------------------------------- /tests/test_monitor.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import json 3 | import os 4 | 5 | import pandas 6 | import gym 7 | 8 | from stable_baselines.bench import Monitor 9 | from stable_baselines.bench.monitor import get_monitor_files, load_results 10 | 11 | 12 | def test_monitor(): 13 | """ 14 | test the monitor wrapper 15 | """ 16 | env = gym.make("CartPole-v1") 17 | env.seed(0) 18 | mon_file = "/tmp/stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()) 19 | menv = Monitor(env, mon_file) 20 | menv.reset() 21 | for _ in range(1000): 22 | _, _, done, _ = menv.step(0) 23 | if done: 24 | menv.reset() 25 | 26 | file_handler = open(mon_file, 'rt') 27 | 28 | firstline = file_handler.readline() 29 | assert firstline.startswith('#') 30 | metadata = json.loads(firstline[1:]) 31 | assert metadata['env_id'] == "CartPole-v1" 32 | assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata" 33 | 34 | last_logline = pandas.read_csv(file_handler, index_col=None) 35 | assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" 36 | file_handler.close() 37 | os.remove(mon_file) 38 | 39 | 40 | def test_monitor_load_results(tmp_path): 41 | """ 42 | test load_results on log files produced by the monitor wrapper 43 | """ 44 | tmp_path = str(tmp_path) 45 | env1 = gym.make("CartPole-v1") 46 | env1.seed(0) 47 | monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) 48 | monitor_env1 = Monitor(env1, monitor_file1) 49 | 50 | monitor_files = get_monitor_files(tmp_path) 51 | assert len(monitor_files) == 1 52 | assert monitor_file1 in monitor_files 53 | 54 | monitor_env1.reset() 55 | episode_count1 = 0 56 | for _ in range(1000): 57 | _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) 58 | if done: 59 | episode_count1 += 1 60 | monitor_env1.reset() 61 | 62 | results_size1 = len(load_results(os.path.join(tmp_path)).index) 63 | assert results_size1 == episode_count1 64 | 65 | env2 = gym.make("CartPole-v1") 66 | env2.seed(0) 67 | monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) 68 | monitor_env2 = Monitor(env2, monitor_file2) 69 | monitor_files = get_monitor_files(tmp_path) 70 | assert len(monitor_files) == 2 71 | assert monitor_file1 in monitor_files 72 | assert monitor_file2 in monitor_files 73 | 74 | monitor_env2.reset() 75 | episode_count2 = 0 76 | for _ in range(1000): 77 | _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) 78 | if done: 79 | episode_count2 += 1 80 | 
monitor_env2.reset() 81 | 82 | results_size2 = len(load_results(os.path.join(tmp_path)).index) 83 | 84 | assert results_size2 == (results_size1 + episode_count2) 85 | 86 | os.remove(monitor_file1) 87 | os.remove(monitor_file2) 88 | -------------------------------------------------------------------------------- /tests/test_mpi_adam.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pytest 4 | 5 | from .test_common import _assert_eq 6 | 7 | 8 | def test_mpi_adam(): 9 | """Test the MpiAdam optimizer under MPI""" 10 | # Test will be run in CI before pytest is run 11 | pytest.skip() 12 | return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', 13 | 'python', '-m', 'stable_baselines.common.mpi_adam']) 14 | _assert_eq(return_code, 0) 15 | 16 | 17 | def test_mpi_adam_ppo1(): 18 | """Run PPO1 CartPole training under MPI""" 19 | # Test will be run in CI before pytest is run 20 | pytest.skip() 21 | return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', 22 | 'python', '-m', 23 | 'stable_baselines.ppo1.experiments.train_cartpole']) 24 | _assert_eq(return_code, 0) 25 | -------------------------------------------------------------------------------- /tests/test_multiple_learn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stable_baselines import A2C, ACER, ACKTR, PPO2 4 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox 5 | from stable_baselines.common.vec_env import DummyVecEnv 6 | 7 | # TODO: Fix multiple-learn on commented-out models (Issue #619). 8 | MODEL_LIST = [ 9 | A2C, 10 | ACER, 11 | ACKTR, 12 | PPO2, 13 | 14 | # MPI-based models, which use traj_segment_generator instead of Runner. 15 | # 16 | # PPO1, 17 | # TRPO, 18 | 19 | # Off-policy models, which don't use Runner but reset every .learn() anyway. 20 | # 21 | # DDPG, 22 | # SAC, 23 | # TD3, 24 | ] 25 | 26 | 27 | @pytest.mark.parametrize("model_class", MODEL_LIST) 28 | def test_model_multiple_learn_no_reset(model_class): 29 | """Check that when we call learn multiple times, we don't unnecessarily 30 | reset the environment. 31 | """ 32 | if model_class is ACER: 33 | def make_env(): 34 | return IdentityEnv(ep_length=1e10, dim=2) 35 | else: 36 | def make_env(): 37 | return IdentityEnvBox(ep_length=1e10) 38 | env = make_env() 39 | venv = DummyVecEnv([lambda: env]) 40 | model = model_class(policy="MlpPolicy", env=venv) 41 | _check_reset_count(model, env) 42 | 43 | # Try again following a `set_env`. 44 | env = make_env() 45 | venv = DummyVecEnv([lambda: env]) 46 | assert env.num_resets == 0 47 | 48 | model.set_env(venv) 49 | _check_reset_count(model, env) 50 | 51 | 52 | def _check_reset_count(model, env: IdentityEnv): 53 | assert env.num_resets == 0 54 | _prev_runner = None 55 | for _ in range(2): 56 | model.learn(total_timesteps=300) 57 | # Lazy constructor for Runner fires upon the first call to learn. 
58 | assert env.num_resets == 1 59 | if _prev_runner is not None: 60 | assert _prev_runner is model.runner, "Runner shouldn't change" 61 | _prev_runner = model.runner 62 | -------------------------------------------------------------------------------- /tests/test_no_mpi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from .test_common import _maybe_disable_mpi 4 | 5 | 6 | def test_no_mpi_no_crash(): 7 | with _maybe_disable_mpi(True): 8 | # Temporarily delete previously imported stable baselines 9 | old_modules = {} 10 | sb_modules = [name for name in sys.modules.keys() 11 | if name.startswith('stable_baselines')] 12 | for name in sb_modules: 13 | old_modules[name] = sys.modules.pop(name) 14 | 15 | # Re-import (with mpi disabled) 16 | import stable_baselines 17 | del stable_baselines # appease Codacy 18 | 19 | # Restore old version of stable baselines (with MPI imported) 20 | for name, mod in old_modules.items(): 21 | sys.modules[name] = mod 22 | -------------------------------------------------------------------------------- /tests/test_ppo2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import gym 5 | 6 | from stable_baselines import PPO2 7 | from stable_baselines.common import make_vec_env 8 | from stable_baselines.common.vec_env import DummyVecEnv 9 | 10 | 11 | @pytest.mark.parametrize("cliprange", [0.2, lambda x: 0.1 * x]) 12 | @pytest.mark.parametrize("cliprange_vf", [None, 0.2, lambda x: 0.3 * x, -1.0]) 13 | def test_clipping(tmp_path, cliprange, cliprange_vf): 14 | """Test the different clipping (policy and vf)""" 15 | model = PPO2( 16 | "MlpPolicy", 17 | "CartPole-v1", 18 | cliprange=cliprange, 19 | cliprange_vf=cliprange_vf, 20 | noptepochs=2, 21 | n_steps=64, 22 | ).learn(100) 23 | save_path = os.path.join(str(tmp_path), "ppo2_clip.zip") 24 | model.save(save_path) 25 | env = model.get_env() 26 | model = PPO2.load(save_path, env=env) 27 | model.learn(100) 28 | 29 | if os.path.exists(save_path): 30 | os.remove(save_path) 31 | 32 | 33 | def test_ppo2_update_n_batch_on_load(tmp_path): 34 | env = make_vec_env("CartPole-v1", n_envs=2) 35 | model = PPO2("MlpPolicy", env, n_steps=10, nminibatches=1) 36 | save_path = os.path.join(str(tmp_path), "ppo2_cartpole.zip") 37 | 38 | model.learn(total_timesteps=100) 39 | model.save(save_path) 40 | 41 | del model 42 | 43 | model = PPO2.load(save_path) 44 | test_env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) 45 | 46 | model.set_env(test_env) 47 | model.learn(total_timesteps=100) 48 | -------------------------------------------------------------------------------- /tests/test_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines.common.buffers import ReplayBuffer, PrioritizedReplayBuffer 4 | 5 | 6 | def test_extend_uniform(): 7 | nvals = 16 8 | states = [np.random.rand(2, 2) for _ in range(nvals)] 9 | actions = [np.random.rand(2) for _ in range(nvals)] 10 | rewards = [np.random.rand() for _ in range(nvals)] 11 | newstate = [np.random.rand(2, 2) for _ in range(nvals)] 12 | done = [np.random.randint(0, 2) for _ in range(nvals)] 13 | 14 | size = 32 15 | baseline = ReplayBuffer(size) 16 | ext = ReplayBuffer(size) 17 | for data in zip(states, actions, rewards, newstate, done): 18 | baseline.add(*data) 19 | 20 | states, actions, rewards, newstates, done = map( 21 | np.array, [states, actions, rewards, newstate, done]) 
22 | 23 | ext.extend(states, actions, rewards, newstates, done) 24 | assert len(baseline) == len(ext) 25 | 26 | # Check buffers have same values 27 | for i in range(nvals): 28 | for j in range(5): 29 | condition = (baseline.storage[i][j] == ext.storage[i][j]) 30 | if isinstance(condition, np.ndarray): 31 | # for obs, obs_t1 32 | assert np.all(condition) 33 | else: 34 | # for done, reward action 35 | assert condition 36 | 37 | 38 | def test_extend_prioritized(): 39 | nvals = 16 40 | states = [np.random.rand(2, 2) for _ in range(nvals)] 41 | actions = [np.random.rand(2) for _ in range(nvals)] 42 | rewards = [np.random.rand() for _ in range(nvals)] 43 | newstate = [np.random.rand(2, 2) for _ in range(nvals)] 44 | done = [np.random.randint(0, 2) for _ in range(nvals)] 45 | 46 | size = 32 47 | alpha = 0.99 48 | baseline = PrioritizedReplayBuffer(size, alpha) 49 | ext = PrioritizedReplayBuffer(size, alpha) 50 | for data in zip(states, actions, rewards, newstate, done): 51 | baseline.add(*data) 52 | 53 | states, actions, rewards, newstates, done = map( 54 | np.array, [states, actions, rewards, newstate, done]) 55 | 56 | ext.extend(states, actions, rewards, newstates, done) 57 | assert len(baseline) == len(ext) 58 | 59 | # Check buffers have same values 60 | for i in range(nvals): 61 | for j in range(5): 62 | condition = (baseline.storage[i][j] == ext.storage[i][j]) 63 | if isinstance(condition, np.ndarray): 64 | # for obs, obs_t1 65 | assert np.all(condition) 66 | else: 67 | # for done, reward action 68 | assert condition 69 | 70 | # assert priorities 71 | assert (baseline._it_min._value == ext._it_min._value).all() 72 | assert (baseline._it_sum._value == ext._it_sum._value).all() 73 | -------------------------------------------------------------------------------- /tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines.common.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule 4 | 5 | 6 | def test_piecewise_schedule(): 7 | """ 8 | test PiecewiseSchedule 9 | """ 10 | piecewise_sched = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], 11 | outside_value=500) 12 | 13 | assert np.isclose(piecewise_sched.value(-10), 500) 14 | assert np.isclose(piecewise_sched.value(0), 150) 15 | assert np.isclose(piecewise_sched.value(5), 200) 16 | assert np.isclose(piecewise_sched.value(9), 80) 17 | assert np.isclose(piecewise_sched.value(50), 50) 18 | assert np.isclose(piecewise_sched.value(80), 50) 19 | assert np.isclose(piecewise_sched.value(150), 0) 20 | assert np.isclose(piecewise_sched.value(175), -25) 21 | assert np.isclose(piecewise_sched.value(201), 500) 22 | assert np.isclose(piecewise_sched.value(500), 500) 23 | 24 | assert np.isclose(piecewise_sched.value(200 - 1e-10), -50) 25 | 26 | 27 | def test_constant_schedule(): 28 | """ 29 | test ConstantSchedule 30 | """ 31 | constant_sched = ConstantSchedule(5) 32 | for i in range(-100, 100): 33 | assert np.isclose(constant_sched.value(i), 5) 34 | 35 | 36 | def test_linear_schedule(): 37 | """ 38 | test LinearSchedule 39 | """ 40 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=0.8) 41 | assert np.isclose(linear_sched.value(50), 0.5) 42 | assert np.isclose(linear_sched.value(0), 0.2) 43 | assert np.isclose(linear_sched.value(100), 0.8) 44 | 45 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.8, final_p=0.2) 46 | assert np.isclose(linear_sched.value(50), 0.5) 47 | assert 
np.isclose(linear_sched.value(0), 0.8) 48 | assert np.isclose(linear_sched.value(100), 0.2) 49 | 50 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=-0.6, final_p=0.2) 51 | assert np.isclose(linear_sched.value(50), -0.2) 52 | assert np.isclose(linear_sched.value(0), -0.6) 53 | assert np.isclose(linear_sched.value(100), 0.2) 54 | 55 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=-0.6) 56 | assert np.isclose(linear_sched.value(50), -0.2) 57 | assert np.isclose(linear_sched.value(0), 0.2) 58 | assert np.isclose(linear_sched.value(100), -0.6) 59 | -------------------------------------------------------------------------------- /tests/test_tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import pytest 5 | 6 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, SAC, TD3, TRPO 7 | 8 | TENSORBOARD_DIR = '/tmp/tb_dir/' 9 | 10 | if os.path.isdir(TENSORBOARD_DIR): 11 | shutil.rmtree(TENSORBOARD_DIR) 12 | 13 | MODEL_DICT = { 14 | 'a2c': (A2C, 'CartPole-v1'), 15 | 'acer': (ACER, 'CartPole-v1'), 16 | 'acktr': (ACKTR, 'CartPole-v1'), 17 | 'dqn': (DQN, 'CartPole-v1'), 18 | 'ddpg': (DDPG, 'Pendulum-v0'), 19 | 'ppo1': (PPO1, 'CartPole-v1'), 20 | 'ppo2': (PPO2, 'CartPole-v1'), 21 | 'sac': (SAC, 'Pendulum-v0'), 22 | 'td3': (TD3, 'Pendulum-v0'), 23 | 'trpo': (TRPO, 'CartPole-v1'), 24 | } 25 | 26 | N_STEPS = 300 27 | 28 | 29 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys()) 30 | def test_tensorboard(model_name): 31 | logname = model_name.upper() 32 | algo, env_id = MODEL_DICT[model_name] 33 | model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR) 34 | model.learn(N_STEPS) 35 | model.learn(N_STEPS, reset_num_timesteps=False) 36 | 37 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_1") 38 | assert not os.path.isdir(TENSORBOARD_DIR + logname + "_2") 39 | 40 | 41 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys()) 42 | def test_multiple_runs(model_name): 43 | logname = "tb_multiple_runs_" + model_name 44 | algo, env_id = MODEL_DICT[model_name] 45 | model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR) 46 | model.learn(N_STEPS, tb_log_name=logname) 47 | model.learn(N_STEPS, tb_log_name=logname) 48 | 49 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_1") 50 | # Check that the log dir name increments correctly 51 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_2") 52 | -------------------------------------------------------------------------------- /tests/test_tf_util.py: -------------------------------------------------------------------------------- 1 | # tests for tf_util 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from stable_baselines.common.tf_util import function, initialize, single_threaded_session, is_image 6 | 7 | 8 | def test_function(): 9 | """ 10 | test the function function in tf_util 11 | """ 12 | with tf.Graph().as_default(): 13 | x_ph = tf.placeholder(tf.int32, (), name="x") 14 | y_ph = tf.placeholder(tf.int32, (), name="y") 15 | z_ph = 3 * x_ph + 2 * y_ph 16 | linear_fn = function([x_ph, y_ph], z_ph, givens={y_ph: 0}) 17 | 18 | with single_threaded_session(): 19 | initialize() 20 | 21 | assert linear_fn(2) == 6 22 | assert linear_fn(2, 2) == 10 23 | 24 | 25 | def test_multikwargs(): 26 | """ 27 | test the function function in tf_util 28 | """ 29 | with tf.Graph().as_default(): 30 | x_ph = tf.placeholder(tf.int32, (), name="x") 31 | with 
tf.variable_scope("other"): 32 | x2_ph = tf.placeholder(tf.int32, (), name="x") 33 | z_ph = 3 * x_ph + 2 * x2_ph 34 | 35 | linear_fn = function([x_ph, x2_ph], z_ph, givens={x2_ph: 0}) 36 | with single_threaded_session(): 37 | initialize() 38 | assert linear_fn(2) == 6 39 | assert linear_fn(2, 2) == 10 40 | 41 | 42 | def test_image_detection(): 43 | rgb = (32, 64, 3) 44 | gray = (43, 23, 1) 45 | rgbd = (12, 32, 4) 46 | invalid_1 = (32, 12) 47 | invalid_2 = (12, 32, 6) 48 | 49 | # TF checks 50 | for shape in (rgb, gray, rgbd): 51 | assert is_image(tf.placeholder(tf.uint8, shape=shape)) 52 | 53 | for shape in (invalid_1, invalid_2): 54 | assert not is_image(tf.placeholder(tf.uint8, shape=shape)) 55 | 56 | # Numpy checks 57 | for shape in (rgb, gray, rgbd): 58 | assert is_image(np.ones(shape)) 59 | 60 | for shape in (invalid_1, invalid_2): 61 | assert not is_image(np.ones(shape)) 62 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import pytest 5 | import gym 6 | 7 | from stable_baselines import A2C 8 | from stable_baselines.bench.monitor import Monitor 9 | from stable_baselines.common.evaluation import evaluate_policy 10 | from stable_baselines.common.cmd_util import make_vec_env 11 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv 12 | 13 | 14 | @pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')]) 15 | @pytest.mark.parametrize("n_envs", [1, 2]) 16 | @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv]) 17 | @pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit]) 18 | def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class): 19 | env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, 20 | wrapper_class=wrapper_class, monitor_dir=None, seed=0) 21 | 22 | assert env.num_envs == n_envs 23 | 24 | if vec_env_cls is None: 25 | assert isinstance(env, DummyVecEnv) 26 | if wrapper_class is not None: 27 | assert isinstance(env.envs[0], wrapper_class) 28 | else: 29 | assert isinstance(env.envs[0], Monitor) 30 | else: 31 | assert isinstance(env, SubprocVecEnv) 32 | # Kill subprocesses 33 | env.close() 34 | 35 | 36 | def test_custom_vec_env(): 37 | """ 38 | Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests. 
39 | """ 40 | monitor_dir = 'logs/test_make_vec_env/' 41 | env = make_vec_env('CartPole-v1', n_envs=1, 42 | monitor_dir=monitor_dir, seed=0, 43 | vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None}) 44 | 45 | assert env.num_envs == 1 46 | assert isinstance(env, SubprocVecEnv) 47 | assert os.path.isdir('logs/test_make_vec_env/') 48 | # Kill subprocess 49 | env.close() 50 | # Cleanup folder 51 | shutil.rmtree(monitor_dir) 52 | 53 | # This should fail because DummyVecEnv does not have any keyword argument 54 | with pytest.raises(TypeError): 55 | make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False}) 56 | 57 | 58 | def test_evaluate_policy(): 59 | model = A2C('MlpPolicy', 'Pendulum-v0', seed=0) 60 | n_steps_per_episode, n_eval_episodes = 200, 2 61 | model.n_callback_calls = 0 62 | 63 | def dummy_callback(locals_, _globals): 64 | locals_['model'].n_callback_calls += 1 65 | 66 | _, episode_lengths = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True, 67 | render=False, callback=dummy_callback, reward_threshold=None, 68 | return_episode_rewards=True) 69 | 70 | n_steps = sum(episode_lengths) 71 | assert n_steps == n_steps_per_episode * n_eval_episodes 72 | assert n_steps == model.n_callback_calls 73 | 74 | # Reaching a mean reward of zero is impossible with the Pendulum env 75 | with pytest.raises(AssertionError): 76 | evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0) 77 | 78 | episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True) 79 | assert len(episode_rewards) == n_eval_episodes 80 | -------------------------------------------------------------------------------- /tests/test_vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | 5 | from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan 6 | 7 | 8 | class NanAndInfEnv(gym.Env): 9 | """Custom Environment that raises NaNs and Infs""" 10 | metadata = {'render.modes': ['human']} 11 | 12 | def __init__(self): 13 | super(NanAndInfEnv, self).__init__() 14 | self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) 15 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) 16 | 17 | @staticmethod 18 | def step(action): 19 | if np.all(np.array(action) > 0): 20 | obs = float('NaN') 21 | elif np.all(np.array(action) < 0): 22 | obs = float('inf') 23 | else: 24 | obs = 0 25 | return [obs], 0.0, False, {} 26 | 27 | @staticmethod 28 | def reset(): 29 | return [0.0] 30 | 31 | def render(self, mode='human', close=False): 32 | pass 33 | 34 | 35 | def test_check_nan(): 36 | """Test VecCheckNan Object""" 37 | 38 | env = DummyVecEnv([NanAndInfEnv]) 39 | env = VecCheckNan(env, raise_exception=True) 40 | 41 | env.step([[0]]) 42 | 43 | try: 44 | env.step([[float('NaN')]]) 45 | except ValueError: 46 | pass 47 | else: 48 | assert False 49 | 50 | try: 51 | env.step([[float('inf')]]) 52 | except ValueError: 53 | pass 54 | else: 55 | assert False 56 | 57 | try: 58 | env.step([[-1]]) 59 | except ValueError: 60 | pass 61 | else: 62 | assert False 63 | 64 | try: 65 | env.step([[1]]) 66 | except ValueError: 67 | pass 68 | else: 69 | assert False 70 | 71 | env.step(np.array([[0, 1], [0, 1]])) 72 | --------------------------------------------------------------------------------
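The `try`/`except ValueError`/`assert False` pattern in `test_check_nan` above works, but `pytest.raises` expresses the same intent more idiomatically and reports a clearer failure when no exception is thrown. A minimal sketch, not taken from the repository, assuming the same `NanAndInfEnv` class and `VecCheckNan`/`DummyVecEnv` imports shown in `tests/test_vec_check_nan.py` above:

```python
import pytest

from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan


def test_check_nan_with_pytest_raises():
    # NanAndInfEnv (defined above) returns a NaN observation for positive actions
    # and an inf observation for negative actions.
    env = VecCheckNan(DummyVecEnv([NanAndInfEnv]), raise_exception=True)

    # A zero action keeps the observation finite, so no exception is expected.
    env.step([[0]])

    # NaN/inf actions, and actions that trigger NaN/inf observations,
    # should all be reported by VecCheckNan as a ValueError.
    for bad_action in ([[float('NaN')]], [[float('inf')]], [[-1]], [[1]]):
        with pytest.raises(ValueError):
            env.step(bad_action)
```

This keeps the same coverage as the original test while letting pytest report exactly which action failed to raise.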