├── .coveragerc
├── .dockerignore
├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── issue-template.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .readthedocs.yml
├── .travis.yml
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── conftest.py
├── data
│   └── logo.jpg
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── _static
│   │   ├── css
│   │   │   └── baselines_theme.css
│   │   └── img
│   │       ├── Tensorboard_example_1.png
│   │       ├── Tensorboard_example_2.png
│   │       ├── Tensorboard_example_3.png
│   │       ├── breakout.gif
│   │       ├── colab.svg
│   │       ├── learning_curve.png
│   │       ├── logo.png
│   │       ├── mistake.png
│   │       └── try_it.png
│   ├── common
│   │   ├── cmd_utils.rst
│   │   ├── distributions.rst
│   │   ├── env_checker.rst
│   │   ├── evaluation.rst
│   │   ├── monitor.rst
│   │   ├── schedules.rst
│   │   └── tf_utils.rst
│   ├── conf.py
│   ├── guide
│   │   ├── algos.rst
│   │   ├── callbacks.rst
│   │   ├── checking_nan.rst
│   │   ├── custom_env.rst
│   │   ├── custom_policy.rst
│   │   ├── examples.rst
│   │   ├── export.rst
│   │   ├── install.rst
│   │   ├── pretrain.rst
│   │   ├── quickstart.rst
│   │   ├── rl.rst
│   │   ├── rl_tips.rst
│   │   ├── rl_zoo.rst
│   │   ├── save_format.rst
│   │   ├── tensorboard.rst
│   │   └── vec_envs.rst
│   ├── index.rst
│   ├── make.bat
│   ├── misc
│   │   ├── changelog.rst
│   │   ├── projects.rst
│   │   └── results_plotter.rst
│   ├── modules
│   │   ├── a2c.rst
│   │   ├── acer.rst
│   │   ├── acktr.rst
│   │   ├── base.rst
│   │   ├── ddpg.rst
│   │   ├── dqn.rst
│   │   ├── gail.rst
│   │   ├── her.rst
│   │   ├── policies.rst
│   │   ├── ppo1.rst
│   │   ├── ppo2.rst
│   │   ├── sac.rst
│   │   ├── td3.rst
│   │   └── trpo.rst
│   ├── requirements.txt
│   └── spelling_wordlist.txt
├── scripts
│   ├── build_docker.sh
│   ├── run_docker_cpu.sh
│   ├── run_docker_gpu.sh
│   ├── run_tests.sh
│   └── run_tests_travis.sh
├── setup.cfg
├── setup.py
├── stable_baselines
│   ├── __init__.py
│   ├── a2c
│   │   ├── __init__.py
│   │   ├── a2c.py
│   │   └── run_atari.py
│   ├── acer
│   │   ├── __init__.py
│   │   ├── acer_simple.py
│   │   ├── buffer.py
│   │   └── run_atari.py
│   ├── acktr
│   │   ├── __init__.py
│   │   ├── acktr.py
│   │   ├── kfac.py
│   │   ├── kfac_utils.py
│   │   └── run_atari.py
│   ├── bench
│   │   ├── __init__.py
│   │   └── monitor.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── atari_wrappers.py
│   │   ├── base_class.py
│   │   ├── bit_flipping_env.py
│   │   ├── buffers.py
│   │   ├── callbacks.py
│   │   ├── cg.py
│   │   ├── cmd_util.py
│   │   ├── console_util.py
│   │   ├── dataset.py
│   │   ├── distributions.py
│   │   ├── env_checker.py
│   │   ├── evaluation.py
│   │   ├── identity_env.py
│   │   ├── input.py
│   │   ├── math_util.py
│   │   ├── misc_util.py
│   │   ├── mpi_adam.py
│   │   ├── mpi_moments.py
│   │   ├── mpi_running_mean_std.py
│   │   ├── noise.py
│   │   ├── policies.py
│   │   ├── runners.py
│   │   ├── running_mean_std.py
│   │   ├── save_util.py
│   │   ├── schedules.py
│   │   ├── segment_tree.py
│   │   ├── tf_layers.py
│   │   ├── tf_util.py
│   │   ├── tile_images.py
│   │   └── vec_env
│   │       ├── __init__.py
│   │       ├── base_vec_env.py
│   │       ├── dummy_vec_env.py
│   │       ├── subproc_vec_env.py
│   │       ├── util.py
│   │       ├── vec_check_nan.py
│   │       ├── vec_frame_stack.py
│   │       ├── vec_normalize.py
│   │       └── vec_video_recorder.py
│   ├── ddpg
│   │   ├── __init__.py
│   │   ├── ddpg.py
│   │   ├── main.py
│   │   ├── noise.py
│   │   └── policies.py
│   ├── deepq
│   │   ├── __init__.py
│   │   ├── build_graph.py
│   │   ├── dqn.py
│   │   ├── experiments
│   │   │   ├── __init__.py
│   │   │   ├── enjoy_cartpole.py
│   │   │   ├── enjoy_mountaincar.py
│   │   │   ├── run_atari.py
│   │   │   ├── train_cartpole.py
│   │   │   └── train_mountaincar.py
│   │   └── policies.py
│   ├── gail
│   │   ├── __init__.py
│   │   ├── adversary.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── dataset.py
│   │   │   ├── expert_cartpole.npz
│   │   │   ├── expert_pendulum.npz
│   │   │   └── record_expert.py
│   │   └── model.py
│   ├── her
│   │   ├── __init__.py
│   │   ├── her.py
│   │   ├── replay_buffer.py
│   │   └── utils.py
│   ├── logger.py
│   ├── ppo1
│   │   ├── __init__.py
│   │   ├── experiments
│   │   │   └── train_cartpole.py
│   │   ├── pposgd_simple.py
│   │   ├── run_atari.py
│   │   ├── run_mujoco.py
│   │   └── run_robotics.py
│   ├── ppo2
│   │   ├── __init__.py
│   │   ├── ppo2.py
│   │   ├── run_atari.py
│   │   └── run_mujoco.py
│   ├── py.typed
│   ├── results_plotter.py
│   ├── sac
│   │   ├── __init__.py
│   │   ├── policies.py
│   │   └── sac.py
│   ├── td3
│   │   ├── __init__.py
│   │   ├── policies.py
│   │   └── td3.py
│   ├── trpo_mpi
│   │   ├── __init__.py
│   │   ├── run_atari.py
│   │   ├── run_mujoco.py
│   │   ├── trpo_mpi.py
│   │   └── utils.py
│   └── version.txt
└── tests
    ├── __init__.py
    ├── test_0deterministic.py
    ├── test_a2c.py
    ├── test_a2c_conv.py
    ├── test_action_scaling.py
    ├── test_action_space.py
    ├── test_atari.py
    ├── test_auto_vec_detection.py
    ├── test_callbacks.py
    ├── test_common.py
    ├── test_continuous.py
    ├── test_custom_policy.py
    ├── test_deepq.py
    ├── test_distri.py
    ├── test_envs.py
    ├── test_gail.py
    ├── test_her.py
    ├── test_identity.py
    ├── test_load_parameters.py
    ├── test_log_prob.py
    ├── test_logger.py
    ├── test_lstm_policy.py
    ├── test_math_util.py
    ├── test_monitor.py
    ├── test_mpi_adam.py
    ├── test_multiple_learn.py
    ├── test_no_mpi.py
    ├── test_ppo2.py
    ├── test_replay_buffer.py
    ├── test_save.py
    ├── test_schedules.py
    ├── test_segment_tree.py
    ├── test_tensorboard.py
    ├── test_tf_util.py
    ├── test_utils.py
    ├── test_vec_check_nan.py
    ├── test_vec_envs.py
    └── test_vec_normalize.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = False
3 | omit =
4 |     # Mujoco requires a licence
5 |     stable_baselines/*/run_mujoco.py
6 |     stable_baselines/ppo1/run_humanoid.py
7 |     stable_baselines/ppo1/run_robotics.py
8 |     # HER requires mpi and Mujoco
9 |     stable_baselines/her/experiment/*
10 |     tests/*
11 |     setup.py
12 |
13 | [report]
14 | exclude_lines =
15 |     pragma: no cover
16 |     raise NotImplementedError()
17 |     if KFAC_DEBUG:
18 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .gitignore
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/issue-template.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Issue Template
3 | about: How to create an issue for this repository
4 |
5 | ---
6 |
7 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions via email.
8 |
9 | If you have any questions, feel free to create an issue with the tag [question].
10 | If you wish to suggest an enhancement or feature request, add the tag [feature request].
11 | If you are submitting a bug report, please fill in the following details.
12 |
13 | If your issue is related to a custom gym environment, please check it first using:
14 |
15 | ```python
16 | from stable_baselines.common.env_checker import check_env
17 |
18 | env = CustomEnv(arg1, ...)
19 | # It will check your custom environment and output additional warnings if needed
20 | check_env(env)
21 | ```
22 |
23 | **Describe the bug**
24 | A clear and concise description of what the bug is.
25 |
26 | **Code example**
27 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
28 |
29 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
30 | for both code and stack traces.
31 |
32 | ```python
33 | from stable_baselines import ...
34 |
35 | ```
36 |
37 | ```bash
38 | Traceback (most recent call last): File ...
39 |
40 | ```
41 |
42 | **System Info**
43 | Describe the characteristic of your environment:
44 | * Describe how the library was installed (pip, docker, source, ...)
45 | * GPU models and configuration
46 | * Python version
47 | * Tensorflow version
48 | * Versions of any other relevant libraries
49 |
50 | **Additional context**
51 | Add any other context about the problem here.
52 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Description
4 |
5 |
6 | ## Motivation and Context
7 |
8 |
9 |
10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes)
11 |
12 | ## Types of changes
13 |
14 | - [ ] Bug fix (non-breaking change which fixes an issue)
15 | - [ ] New feature (non-breaking change which adds functionality)
16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change)
17 | - [ ] Documentation (update in the documentation)
18 |
19 | ## Checklist:
20 |
21 |
22 | - [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**)
23 | - [ ] I have updated the [changelog](https://github.com/hill-a/stable-baselines/blob/master/docs/misc/changelog.rst) accordingly (**required**).
24 | - [ ] My change requires a change to the documentation.
25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
26 | - [ ] I have updated the documentation accordingly.
27 | - [ ] I have ensured `pytest` and `pytype` both pass (by running `make pytest` and `make type`).
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 |   push:
8 |     branches: [ master ]
9 |   pull_request:
10 |     branches: [ master ]
11 |
12 | jobs:
13 |   build:
14 |     # Skip CI if [ci skip] in the commit message
15 |     if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
16 |     runs-on: ubuntu-latest
17 |     strategy:
18 |       matrix:
19 |         python-version: [3.6] # Deactivate 3.5 build as it is no longer maintained
20 |
21 |     steps:
22 |       - uses: actions/checkout@v2
23 |       - name: Set up Python ${{ matrix.python-version }}
24 |         uses: actions/setup-python@v2
25 |         with:
26 |           python-version: ${{ matrix.python-version }}
27 |       - name: Install dependencies
28 |         run: |
29 |           sudo apt-get install libopenmpi-dev
30 |           python -m pip install --upgrade pip
31 |           pip install wheel
32 |           pip install .[mpi,tests,docs]
33 |           # Use headless version
34 |           pip install opencv-python-headless
35 |           # Tmp fix: ROM missing in the newest atari-py version
36 |           pip install atari-py==0.2.5
37 |       - name: MPI
38 |         run: |
39 |           # check MPI
40 |           mpirun -h
41 |           python -c "import mpi4py; print(mpi4py.__version__)"
42 |           mpirun --allow-run-as-root -np 2 python -m stable_baselines.common.mpi_adam
43 |           mpirun --allow-run-as-root -np 2 python -m stable_baselines.ppo1.experiments.train_cartpole
44 |           mpirun --allow-run-as-root -np 2 python -m stable_baselines.common.mpi_running_mean_std
45 |           # MPI requires 3 processes to run the following code
46 |           # but will throw an error on GitHub CI as there are only two threads
47 |           # mpirun --allow-run-as-root -np 3 python -c "from stable_baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()"
48 |
49 |       - name: Build the doc
50 |         run: |
51 |           make doc
52 |       - name: Type check
53 |         run: |
54 |           make type
55 |       - name: Test with pytest
56 |         run: |
57 |           # Prevent issues with multiprocessing
58 |           DEFAULT_START_METHOD=fork make pytest
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *.pkl
4 | *.py~
5 | *.bak
6 | .pytest_cache
7 | .pytype
8 | .DS_Store
9 | .idea
10 | .coverage
11 | .coverage.*
12 | __pycache__/
13 | _build/
14 | *.npz
15 | *.zip
16 |
17 | # Setuptools distribution and build folders.
18 | /dist/
19 | /build
20 | keys/
21 |
22 | # Virtualenv
23 | /env
24 | /venv
25 |
26 | *.sublime-project
27 | *.sublime-workspace
28 |
29 | logs/
30 |
31 | .ipynb_checkpoints
32 | ghostdriver.log
33 |
34 | htmlcov
35 |
36 | junk
37 | src
38 |
39 | *.egg-info
40 | .cache
41 |
42 | MUJOCO_LOG.TXT
43 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Build documentation in the docs/ directory with Sphinx
8 | sphinx:
9 |   configuration: docs/conf.py
10 |
11 | # Optionally build your docs in additional formats such as PDF and ePub
12 | formats: all
13 |
14 | # Optionally set the version of Python and requirements required to build your docs
15 | python:
16 |   version: 3.7
17 |   install:
18 |     - requirements: docs/requirements.txt
19 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - "3.5"
4 |
5 | env:
6 |   global:
7 |     - DOCKER_IMAGE=stablebaselines/stable-baselines-cpu:v2.10.0
8 |
9 | notifications:
10 |   email: false
11 |
12 | services:
13 |   - docker
14 |
15 | install:
16 |   - docker pull ${DOCKER_IMAGE}
17 |
18 | script:
19 |   - ./scripts/run_tests_travis.sh "${TEST_GLOB}"
20 |
21 | jobs:
22 |   include:
23 |     # Big test suite. Run in parallel to decrease wall-clock time, and to avoid OOM error from leaks
24 |     - stage: Test
25 |       name: "Unit Tests a-h"
26 |       env: TEST_GLOB="[a-h]*"
27 |
28 |     - name: "Unit Tests i-l"
29 |       env: TEST_GLOB="[i-l]*"
30 |
31 |     - name: "Unit Tests m-sa"
32 |       env: TEST_GLOB="{[m-r]*,sa*}"
33 |
34 |     - name: "Unit Tests sb-z"
35 |       env: TEST_GLOB="{s[b-z]*,[t-z]*}"
36 |
37 |     - name: "Unit Tests determinism"
38 |       env: TEST_GLOB="0deterministic.py"
39 |
40 |     - name: "Sphinx Documentation"
41 |       script:
42 |         - 'docker run -it --rm --mount src=$(pwd),target=/root/code/stable-baselines,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && pushd docs/ && make clean && make html"'
43 |
44 |     - name: "Type Checking"
45 |       script:
46 |         - 'docker run --rm --mount src=$(pwd),target=/root/code/stable-baselines,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && pytype --version && pytype"'
47 |
48 |     - stage: Codacy Trigger
49 |       if: type != pull_request
50 |       script:
51 |         # When all test coverage reports have been uploaded, instruct Codacy to start analysis.
52 |         - 'docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines,type=bind --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines/ && java -jar /root/code/codacy-coverage-reporter.jar final"'
53 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing to Stable-Baselines
2 |
3 | If you are interested in contributing to Stable-Baselines, your contributions will fall
4 | into two categories:
5 | 1. You want to propose a new Feature and implement it
6 | - Create an issue about your intended feature, and we shall discuss the design and
7 | implementation. Once we agree that the plan looks good, go ahead and implement it.
8 | 2. You want to implement a feature or bug-fix for an outstanding issue
9 | - Look at the outstanding issues here: https://github.com/hill-a/stable-baselines/issues
10 | - Look at the roadmap here: https://github.com/hill-a/stable-baselines/projects/1
11 | - Pick an issue or feature and comment on the task that you want to work on this feature.
12 | - If you need more context on a particular issue, please ask and we shall provide.
13 |
14 | Once you finish implementing a feature or bug-fix, please send a Pull Request to
15 | https://github.com/hill-a/stable-baselines/
16 |
17 |
18 | If you are not familiar with creating a Pull Request, here are some guides:
19 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
20 | - https://help.github.com/articles/creating-a-pull-request/
21 |
22 |
23 | ## Developing Stable-Baselines
24 |
25 | To develop Stable-Baselines on your machine, here are some tips:
26 |
27 | 1. Clone a copy of Stable-Baselines from source:
28 |
29 | ```bash
30 | git clone https://github.com/hill-a/stable-baselines/
31 | cd stable-baselines
32 | ```
33 |
34 | 2. Install Stable-Baselines in develop mode, with support for building the docs and running tests:
35 |
36 | ```bash
37 | pip install -e .[docs,tests]
38 | ```
39 |
40 | ## Codestyle
41 |
42 | We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows:
43 |
44 | 1. built-in
45 | 2. packages
46 | 3. current module
47 |
48 | with one blank line between each group, which gives for instance:
49 | ```python
50 | import os
51 | import warnings
52 |
53 | import numpy as np
54 |
55 | from stable_baselines import PPO2
56 | ```
57 |
58 | In general, we recommend using PyCharm to format everything efficiently.
59 |
60 | Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:
61 |
62 | ```python
63 |
64 | def my_function(arg1: type1, arg2: type2) -> returntype:
65 |     """
66 |     Short description of the function.
67 |
68 |     :param arg1: (type1) describe what is arg1
69 |     :param arg2: (type2) describe what is arg2
70 |     :return: (returntype) describe what is returned
71 |     """
72 |     ...
73 |     return my_variable
74 | ```
75 |
76 | ## Pull Request (PR)
77 |
78 | Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process.
79 |
80 | Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli).
81 | A PR must pass the Continuous Integration tests (travis + codacy) to be merged into the master branch.
82 |
83 | Note: in rare cases, we can make an exception for a codacy failure.
84 |
85 | ## Test
86 |
87 | All new features must add tests in the `tests/` folder to ensure that everything works as expected.
88 | We use [pytest](https://pytest.org/).
89 | Also, when a bug fix is proposed, tests should be added to avoid regression.
90 |
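A new test can be as simple as a small "smoke test" that exercises the feature end-to-end. A minimal sketch (the file name and test below are hypothetical):

```python
# tests/test_my_feature.py (hypothetical)
from stable_baselines import PPO2


def test_ppo2_cartpole_smoke():
    # The model should train for a few steps without raising an error
    model = PPO2('MlpPolicy', 'CartPole-v1', verbose=0)
    model.learn(total_timesteps=128)
```
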
91 | To run tests with `pytest`:
92 |
93 | ```
94 | make pytest
95 | ```
96 |
97 | Type checking with `pytype`:
98 |
99 | ```
100 | make type
101 | ```
102 |
103 | Build the documentation:
104 |
105 | ```
106 | make doc
107 | ```
108 |
109 | Check documentation spelling (you need to install the `sphinxcontrib.spelling` package for that):
110 |
111 | ```
112 | make spelling
113 | ```
114 |
115 |
116 | ## Changelog and Documentation
117 |
118 | Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed.
119 | A README is present in the `docs/` folder for instructions on how to build the documentation.
120 |
121 |
122 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
123 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PARENT_IMAGE
2 | FROM $PARENT_IMAGE
3 | ARG USE_GPU
4 |
5 | RUN apt-get -y update \
6 | && apt-get -y install \
7 | curl \
8 | cmake \
9 | default-jre \
10 | git \
11 | jq \
12 | python-dev \
13 | python-pip \
14 | python3-dev \
15 | libfontconfig1 \
16 | libglib2.0-0 \
17 | libsm6 \
18 | libxext6 \
19 | libxrender1 \
20 | libopenmpi-dev \
21 | zlib1g-dev \
22 | && apt-get clean \
23 | && rm -rf /var/lib/apt/lists/*
24 |
25 | ENV CODE_DIR /root/code
26 | ENV VENV /root/venv
27 |
28 | COPY ./stable_baselines/version.txt ${CODE_DIR}/stable-baselines/stable_baselines/version.txt
29 | COPY ./setup.py ${CODE_DIR}/stable-baselines/setup.py
30 |
31 | RUN \
32 | pip install pip --upgrade && \
33 | pip install virtualenv && \
34 | virtualenv $VENV --python=python3 && \
35 | . $VENV/bin/activate && \
36 | pip install --upgrade pip && \
37 | cd ${CODE_DIR}/stable-baselines && \
38 | pip install -e .[mpi,tests,docs] && \
39 | rm -rf $HOME/.cache/pip
40 |
41 | ENV PATH=$VENV/bin:$PATH
42 |
43 | # Codacy code coverage report: used for partial code coverage reporting
44 | RUN cd $CODE_DIR && \
45 | curl -Ls -o codacy-coverage-reporter.jar "$(curl -Ls https://api.github.com/repos/codacy/codacy-coverage-reporter/releases/latest | jq -r '.assets | map({name, browser_download_url} | select(.name | (startswith("codacy-coverage-reporter") and contains("assembly") and endswith(".jar")))) | .[0].browser_download_url')"
46 |
47 | CMD /bin/bash
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2017 OpenAI (http://openai.com)
4 | Copyright (c) 2018-2019 Stable-Baselines Team
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Run pytest and coverage report
2 | pytest:
3 | 	./scripts/run_tests.sh
4 |
5 | # Type check
6 | type:
7 | 	pytype -j auto
8 |
9 | # Build the doc
10 | doc:
11 | 	cd docs && make html
12 |
13 | # Check the spelling in the doc
14 | spelling:
15 | 	cd docs && make spelling
16 |
17 | # Clean the doc build folder
18 | clean:
19 | 	cd docs && make clean
20 |
21 | # Build docker images
22 | # If you do export RELEASE=True, it will also push them
23 | docker: docker-cpu docker-gpu
24 |
25 | docker-cpu:
26 | 	./scripts/build_docker.sh
27 |
28 | docker-gpu:
29 | 	USE_GPU=True ./scripts/build_docker.sh
30 |
31 | # PyPi package release
32 | release:
33 | 	python setup.py sdist
34 | 	python setup.py bdist_wheel
35 | 	twine upload dist/*
36 |
37 | # Test PyPi package release
38 | test-release:
39 | 	python setup.py sdist
40 | 	python setup.py bdist_wheel
41 | 	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
42 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | """Configures pytest to ignore certain unit tests unless the appropriate flag is passed.
2 |
3 | --rungpu: tests that require GPU.
4 | --expensive: tests that take a long time to run (e.g. training an RL algorithm for many timesteps)."""
5 |
6 | import pytest
7 |
8 |
9 | def pytest_addoption(parser):
10 |     parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests")
11 |     parser.addoption("--expensive", action="store_true",
12 |                      help="run expensive tests (which are otherwise skipped).")
13 |
14 |
15 | def pytest_collection_modifyitems(config, items):
16 |     flags = {'gpu': '--rungpu', 'expensive': '--expensive'}
17 |     skips = {keyword: pytest.mark.skip(reason="need {} option to run".format(flag))
18 |              for keyword, flag in flags.items() if not config.getoption(flag)}
19 |     for item in items:
20 |         for keyword, skip in skips.items():
21 |             if keyword in item.keywords:
22 |                 item.add_marker(skip)
23 |
--------------------------------------------------------------------------------
/data/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/data/logo.jpg
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W # make warnings fatal
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = StableBaselines
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Stable Baselines Documentation
2 |
3 | This folder contains documentation for the RL baselines.
4 |
5 |
6 | ### Build the Documentation
7 |
8 | #### Install Sphinx and Theme
9 |
10 | ```
11 | pip install sphinx sphinx-autobuild sphinx-rtd-theme
12 | ```
13 |
14 | #### Building the Docs
15 |
16 | In the `docs/` folder:
17 | ```
18 | make html
19 | ```
20 |
21 | If you want to rebuild the documentation each time a file is changed:
22 |
23 | ```
24 | sphinx-autobuild . _build/html
25 | ```
26 |
--------------------------------------------------------------------------------
/docs/_static/css/baselines_theme.css:
--------------------------------------------------------------------------------
1 | /* Main colors from https://color.adobe.com/fr/Copy-of-NOUEBO-Original-color-theme-11116609 */
2 | :root{
3 | --main-bg-color: #324D5C;
4 | --link-color: #14B278;
5 | }
6 |
7 | /* Header fonts y */
8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
10 | }
11 |
12 |
13 | /* Docs background */
14 | .wy-side-nav-search{
15 | background-color: var(--main-bg-color);
16 | }
17 |
18 | /* Mobile version */
19 | .wy-nav-top{
20 | background-color: var(--main-bg-color);
21 | }
22 |
23 | /* Change link colors (except for the menu) */
24 | a {
25 | color: var(--link-color);
26 | }
27 |
28 | a:hover {
29 | color: #4F778F;
30 | }
31 |
32 | .wy-menu a {
33 | color: #b3b3b3;
34 | }
35 |
36 | .wy-menu a:hover {
37 | color: #b3b3b3;
38 | }
39 |
40 | a.icon.icon-home {
41 | color: #b3b3b3;
42 | }
43 |
44 | .version{
45 | color: var(--link-color) !important;
46 | }
47 |
48 |
49 | /* Make code blocks have a background */
50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
51 | background: #f8f8f8;
52 | }
53 |
--------------------------------------------------------------------------------
/docs/_static/img/Tensorboard_example_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_1.png
--------------------------------------------------------------------------------
/docs/_static/img/Tensorboard_example_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_2.png
--------------------------------------------------------------------------------
/docs/_static/img/Tensorboard_example_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/Tensorboard_example_3.png
--------------------------------------------------------------------------------
/docs/_static/img/breakout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/breakout.gif
--------------------------------------------------------------------------------
/docs/_static/img/colab.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_static/img/learning_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/learning_curve.png
--------------------------------------------------------------------------------
/docs/_static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/logo.png
--------------------------------------------------------------------------------
/docs/_static/img/mistake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/mistake.png
--------------------------------------------------------------------------------
/docs/_static/img/try_it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/docs/_static/img/try_it.png
--------------------------------------------------------------------------------
/docs/common/cmd_utils.rst:
--------------------------------------------------------------------------------
1 | .. _cmd_utils:
2 |
3 | Command Utils
4 | =========================
5 |
6 | .. automodule:: stable_baselines.common.cmd_util
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/common/distributions.rst:
--------------------------------------------------------------------------------
1 | .. _distributions:
2 |
3 | Probability Distributions
4 | =========================
5 |
6 | Probability distributions used for the different action spaces:
7 |
8 | - ``CategoricalProbabilityDistribution`` -> Discrete
9 | - ``DiagGaussianProbabilityDistribution`` -> Box (continuous actions)
10 | - ``MultiCategoricalProbabilityDistribution`` -> MultiDiscrete
11 | - ``BernoulliProbabilityDistribution`` -> MultiBinary
12 |
13 | The policy networks output parameters for the distributions (named ``flat`` in the methods).
14 | Actions are then sampled from those distributions.
15 |
16 | For instance, in the case of discrete actions, the policy network outputs the probability
17 | of taking each action. The ``CategoricalProbabilityDistribution`` then allows sampling from it,
18 | computing the entropy and the negative log probability (``neglogp``), and backpropagating the gradient.
19 |
20 | In the case of continuous actions, a Gaussian distribution is used. The policy network outputs
21 | mean and (log) std of the distribution (assumed to be a ``DiagGaussianProbabilityDistribution``).
22 |
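As an illustration, here is a minimal sketch of using ``CategoricalProbabilityDistribution`` directly (the logits below are hypothetical; they would normally be produced by the policy network, and a TensorFlow 1.x session is assumed):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines.common.distributions import CategoricalProbabilityDistribution

    # Hypothetical logits for one observation and 4 discrete actions
    logits = tf.constant([[1.0, 0.5, -1.0, 0.2]])
    dist = CategoricalProbabilityDistribution(logits)

    action_op = dist.sample()             # sample an action
    neglogp_op = dist.neglogp(action_op)  # negative log probability of that action
    entropy_op = dist.entropy()           # entropy of the distribution

    with tf.Session() as sess:
        action, neglogp, entropy = sess.run([action_op, neglogp_op, entropy_op])
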
23 | .. automodule:: stable_baselines.common.distributions
24 | :members:
25 |
--------------------------------------------------------------------------------
/docs/common/env_checker.rst:
--------------------------------------------------------------------------------
1 | .. _env_checker:
2 |
3 | Gym Environment Checker
4 | ========================
5 |
6 | .. automodule:: stable_baselines.common.env_checker
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/common/evaluation.rst:
--------------------------------------------------------------------------------
1 | .. _eval:
2 |
3 | Evaluation Helper
4 | =================
5 |
6 | .. automodule:: stable_baselines.common.evaluation
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/common/monitor.rst:
--------------------------------------------------------------------------------
1 | .. _monitor:
2 |
3 | Monitor Wrapper
4 | ===============
5 |
6 | .. automodule:: stable_baselines.bench.monitor
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/common/schedules.rst:
--------------------------------------------------------------------------------
1 | .. _schedules:
2 |
3 | Schedules
4 | =========
5 |
6 | Schedules are used as hyperparameters for most of the algorithms,
7 | in order to change the value of a parameter over time (usually the learning rate).
8 |
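For example, a minimal sketch of a linear schedule (assuming the ``LinearSchedule`` helper, here annealing a value from 1.0 to 0.05 over 10000 steps):

.. code-block:: python

    from stable_baselines.common.schedules import LinearSchedule

    schedule = LinearSchedule(schedule_timesteps=10000, initial_p=1.0, final_p=0.05)

    print(schedule.value(0))      # 1.0
    print(schedule.value(5000))   # 0.525
    print(schedule.value(10000))  # 0.05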
9 |
10 | .. automodule:: stable_baselines.common.schedules
11 | :members:
12 |
--------------------------------------------------------------------------------
/docs/common/tf_utils.rst:
--------------------------------------------------------------------------------
1 | .. _tf_utils:
2 |
3 | Tensorflow Utils
4 | =========================
5 |
6 | .. automodule:: stable_baselines.common.tf_util
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/guide/algos.rst:
--------------------------------------------------------------------------------
1 | RL Algorithms
2 | =============
3 |
4 | This table displays the RL algorithms that are implemented in the Stable Baselines project,
5 | along with some useful characteristics: support for recurrent policies, discrete/continuous actions, and multiprocessing.
6 |
7 | .. Table too large
8 | .. ===== ======================== ========= ======= ============ ================= =============== ================
9 | .. Name Refactored \ :sup:`(1)`\ Recurrent ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
10 | .. ===== ======================== ========= ======= ============ ================= =============== ================
11 | .. A2C ✔️
12 | .. ===== ======================== ========= ======= ============ ================= =============== ================
13 |
14 |
15 | ============ ======================== ========= =========== ============ ================
16 | Name Refactored [#f1]_ Recurrent ``Box`` ``Discrete`` Multi Processing
17 | ============ ======================== ========= =========== ============ ================
18 | A2C ✔️ ✔️ ✔️ ✔️ ✔️
19 | ACER ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️
20 | ACKTR ✔️ ✔️ ✔️ ✔️ ✔️
21 | DDPG ✔️ ❌ ✔️ ❌ ✔️ [#f3]_
22 | DQN ✔️ ❌ ❌ ✔️ ❌
23 | HER ✔️ ❌ ✔️ ✔️ ❌
24 | GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️ ✔️ [#f3]_
25 | PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_
26 | PPO2 ✔️ ✔️ ✔️ ✔️ ✔️
27 | SAC ✔️ ❌ ✔️ ❌ ❌
28 | TD3 ✔️ ❌ ✔️ ❌ ❌
29 | TRPO ✔️ ❌ ✔️ ✔ ✔️ [#f3]_
30 | ============ ======================== ========= =========== ============ ================
31 |
32 | .. [#f1] Whether or not the algorithm has been refactored to fit the ``BaseRLModel`` class.
33 | .. [#f2] Only implemented for TRPO.
34 | .. [#f3] Multi Processing with `MPI`_.
35 | .. [#f4] TODO, in project scope.
36 |
37 | .. note::
38 | Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm,
39 | except HER for dict when working with ``gym.GoalEnv``.
40 |
41 | Actions ``gym.spaces`` (a short construction sketch follows the list):
42 |
43 | - ``Box``: A N-dimensional box that contains every point in the action
44 | space.
45 | - ``Discrete``: A list of possible actions, where each timestep only
46 | one of the actions can be used.
47 | - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
48 | - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
49 |
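A minimal sketch of constructing each action space with gym (the sizes below are arbitrary):

.. code-block:: python

    from gym import spaces

    spaces.Box(low=-1.0, high=1.0, shape=(2,))  # 2D continuous actions
    spaces.Discrete(3)                          # one of 3 actions per step
    spaces.MultiDiscrete([3, 2])                # one action from each discrete set
    spaces.MultiBinary(4)                       # any combination of 4 binary actions
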
50 | .. _MPI: https://mpi4py.readthedocs.io/en/stable/
51 |
52 | .. note::
53 |
54 | Some logging values (like ``ep_rewmean``, ``eplenmean``) are only available when using a Monitor wrapper.
55 | See `Issue #339 `_ for more info.
56 |
57 |
58 | Reproducibility
59 | ---------------
60 |
61 | Completely reproducible results are not guaranteed across Tensorflow releases or different platforms.
62 | Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds.
63 |
64 | In order to make computations deterministic on CPU, on your specific problem on one specific platform,
65 | you need to pass a ``seed`` argument at the creation of a model and set `n_cpu_tf_sess=1` (number of cpu for Tensorflow session).
66 | If you pass an environment to the model using `set_env()`, then you also need to seed the environment first.
67 |
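A minimal sketch of the seeding described above (``seed`` and ``n_cpu_tf_sess`` are passed at model creation):

.. code-block:: python

    from stable_baselines import PPO2

    # Single-threaded Tensorflow session + fixed seed for deterministic CPU runs
    model = PPO2('MlpPolicy', 'CartPole-v1', seed=0, n_cpu_tf_sess=1)
    model.learn(total_timesteps=10000)
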
68 | .. note::
69 |
70 | Because of the current limits of Tensorflow 1.x, we cannot ensure reproducible results on the GPU yet. This issue is solved in `Stable-Baselines3 "PyTorch edition" `_
71 |
72 |
73 | .. note::
74 |
75 | TD3 sometimes fail to have reproducible results for obscure reasons, even when following the previous steps (cf `PR #492 `_). If you find the reason then please open an issue ;)
76 |
77 |
78 | Credit: part of the *Reproducibility* section comes from `PyTorch Documentation `_
79 |
--------------------------------------------------------------------------------
/docs/guide/custom_env.rst:
--------------------------------------------------------------------------------
1 | .. _custom_env:
2 |
3 | Using Custom Environments
4 | ==========================
5 |
6 | To use the RL baselines with custom environments, the environments just need to follow the *gym* interface.
7 | That is to say, your environment must implement the following methods (and inherit from the OpenAI Gym class):
8 |
9 |
10 | .. note::
11 | If you are using images as input, the input values must be in [0, 255] as the observation
12 | is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies.
13 |
14 |
15 |
16 | .. code-block:: python
17 |
18 |     import gym
19 |     import numpy as np
20 |     from gym import spaces
21 |     class CustomEnv(gym.Env):
22 |         """Custom Environment that follows gym interface"""
23 |         metadata = {'render.modes': ['human']}
24 |
25 |         def __init__(self, arg1, arg2, ...):
26 |             super(CustomEnv, self).__init__()
27 |             # Define action and observation space
28 |             # They must be gym.spaces objects
29 |             # Example when using discrete actions:
30 |             self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
31 |             # Example for using image as input:
32 |             self.observation_space = spaces.Box(low=0, high=255,
33 |                                                 shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
34 |
35 |         def step(self, action):
36 |             ...
37 |             return observation, reward, done, info
38 |         def reset(self):
39 |             ...
40 |             return observation  # reward, done, info can't be included
41 |         def render(self, mode='human'):
42 |             ...
43 |         def close(self):
44 |             ...
45 |
46 |
47 | Then you can define and train a RL agent with:
48 |
49 | .. code-block:: python
50 |
51 | # Instantiate the env
52 | env = CustomEnv(arg1, ...)
53 | # Define and Train the agent
54 | model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
55 |
56 |
57 | To check that your environment follows the gym interface, please use:
58 |
59 | .. code-block:: python
60 |
61 | from stable_baselines.common.env_checker import check_env
62 |
63 | env = CustomEnv(arg1, ...)
64 | # It will check your custom environment and output additional warnings if needed
65 | check_env(env)
66 |
67 |
68 |
69 | We have created a `colab notebook `_ for
70 | a concrete example of creating a custom environment.
71 |
72 | You can also find a `complete guide online `_
73 | on creating a custom Gym environment.
74 |
75 |
76 | Optionally, you can also register the environment with gym,
77 | that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env).
78 |
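For instance, a minimal sketch of registering the environment with gym (the id and entry point below are hypothetical):

.. code-block:: python

    from gym.envs.registration import register

    register(
        id='CustomEnv-v0',
        entry_point='my_package.envs:CustomEnv',  # hypothetical module path
    )

    # The agent can now be created in one line
    from stable_baselines import A2C

    model = A2C('MlpPolicy', 'CustomEnv-v0').learn(total_timesteps=1000)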
79 |
80 | In the project, for testing purposes, we use a custom environment named ``IdentityEnv``
81 | defined `in this file `_.
82 | An example of how to use it can be found `here `_.
83 |
--------------------------------------------------------------------------------
/docs/guide/quickstart.rst:
--------------------------------------------------------------------------------
1 | .. _quickstart:
2 |
3 | ===============
4 | Getting Started
5 | ===============
6 |
7 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
8 |
9 | Here is a quick example of how to train and run PPO2 on a cartpole environment:
10 |
11 | .. code-block:: python
12 |
13 |     import gym
14 |
15 |     from stable_baselines.common.policies import MlpPolicy
16 |     from stable_baselines.common.vec_env import DummyVecEnv
17 |     from stable_baselines import PPO2
18 |
19 |     env = gym.make('CartPole-v1')
20 |     # Optional: PPO2 requires a vectorized environment to run
21 |     # the env is now wrapped automatically when passing it to the constructor
22 |     # env = DummyVecEnv([lambda: env])
23 |
24 |     model = PPO2(MlpPolicy, env, verbose=1)
25 |     model.learn(total_timesteps=10000)
26 |
27 |     obs = env.reset()
28 |     for i in range(1000):
29 |         action, _states = model.predict(obs)
30 |         obs, rewards, dones, info = env.step(action)
31 |         env.render()
32 |
33 |
34 | Or just train a model with a one liner if
35 | `the environment is registered in Gym `_ and if
36 | `the policy is registered `_:
37 |
38 | .. code-block:: python
39 |
40 |     from stable_baselines import PPO2
41 |
42 |     model = PPO2('MlpPolicy', 'CartPole-v1').learn(10000)
43 |
44 |
45 | .. figure:: https://cdn-images-1.medium.com/max/960/1*R_VMmdgKAY0EDhEjHVelzw.gif
46 |
47 | Define and train a RL agent in one line of code!
48 |
--------------------------------------------------------------------------------
/docs/guide/rl.rst:
--------------------------------------------------------------------------------
1 | .. _rl:
2 |
3 | ================================
4 | Reinforcement Learning Resources
5 | ================================
6 |
7 |
8 | Stable-Baselines assumes that you already understand the basic concepts of Reinforcement Learning (RL).
9 |
10 | However, if you want to learn about RL, there are several good resources to get started:
11 |
12 | - `OpenAI Spinning Up `_
13 | - `David Silver's course `_
14 | - `Lilian Weng's blog `_
15 | - `Berkeley's Deep RL Bootcamp `_
16 | - `Berkeley's Deep Reinforcement Learning course `_
17 | - `More resources `_
18 |
--------------------------------------------------------------------------------
/docs/guide/rl_zoo.rst:
--------------------------------------------------------------------------------
1 | .. _rl_zoo:
2 |
3 | =================
4 | RL Baselines Zoo
5 | =================
6 |
7 | `RL Baselines Zoo `_ is a collection of pre-trained Reinforcement Learning agents using
8 | Stable-Baselines.
9 | It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.
10 |
11 | Goals of this repository:
12 |
13 | 1. Provide a simple interface to train and enjoy RL agents
14 | 2. Benchmark the different Reinforcement Learning algorithms
15 | 3. Provide tuned hyperparameters for each environment and RL algorithm
16 | 4. Have fun with the trained agents!
17 |
18 | Installation
19 | ------------
20 |
21 | 1. Install dependencies
22 | ::
23 |
24 | apt-get install swig cmake libopenmpi-dev zlib1g-dev ffmpeg
25 | pip install stable-baselines box2d box2d-kengz pyyaml pybullet optuna pytablewriter
26 |
27 | 2. Clone the repository:
28 |
29 | ::
30 |
31 | git clone https://github.com/araffin/rl-baselines-zoo
32 |
33 |
34 | Train an Agent
35 | --------------
36 |
37 | The hyperparameters for each environment are defined in
38 | ``hyperparameters/algo_name.yml``.
39 |
40 | If the environment exists in this file, then you can train an agent
41 | using:
42 |
43 | ::
44 |
45 | python train.py --algo algo_name --env env_id
46 |
47 | For example (with tensorboard support):
48 |
49 | ::
50 |
51 | python train.py --algo ppo2 --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/
52 |
53 | Train for multiple environments (with one call) and with tensorboard
54 | logging:
55 |
56 | ::
57 |
58 | python train.py --algo a2c --env MountainCar-v0 CartPole-v1 --tensorboard-log /tmp/stable-baselines/
59 |
60 | Continue training (here, load pretrained agent for Breakout and continue
61 | training for 5000 steps):
62 |
63 | ::
64 |
65 | python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4.pkl -n 5000
66 |
67 |
68 | Enjoy a Trained Agent
69 | ---------------------
70 |
71 | If the trained agent exists, then you can see it in action using:
72 |
73 | ::
74 |
75 | python enjoy.py --algo algo_name --env env_id
76 |
77 | For example, enjoy A2C on Breakout during 5000 timesteps:
78 |
79 | ::
80 |
81 | python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder trained_agents/ -n 5000
82 |
83 |
84 | Hyperparameter Optimization
85 | ---------------------------
86 |
87 | We use `Optuna `_ for optimizing the hyperparameters.
88 |
89 |
90 | Tune the hyperparameters for PPO2, using a random sampler and median pruner, 2 parallels jobs,
91 | with a budget of 1000 trials and a maximum of 50000 steps:
92 |
93 | ::
94 |
95 | python train.py --algo ppo2 --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
96 | --sampler random --pruner median
97 |
98 |
99 | Colab Notebook: Try it Online!
100 | ------------------------------
101 |
102 | You can train agents online using Google `colab notebook `_.
103 |
104 |
105 | .. note::
106 |
107 | You can find more information about the rl baselines zoo in the repo `README `_. For instance, how to record a video of a trained agent.
108 |
--------------------------------------------------------------------------------
/docs/guide/save_format.rst:
--------------------------------------------------------------------------------
1 | .. _save_format:
2 |
3 |
4 | On saving and loading
5 | =====================
6 |
7 | Stable baselines stores both neural network parameters and algorithm-related parameters such as
8 | exploration schedule, number of environments and observation/action space. This allows continual learning and easy
9 | use of trained agents without training, but it is not without its issues. The following describes the two formats
10 | used to save agents in stable baselines, along with their pros and shortcomings.
11 |
12 | Terminology used in this page:
13 |
14 | - *parameters* refer to neural network parameters (also called "weights"). This is a dictionary
15 | mapping Tensorflow variable name to a NumPy array.
16 | - *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space.
17 | These depend on the algorithm used. This is a dictionary mapping class variable names to their values.
18 |
19 |
20 | Cloudpickle (stable-baselines<=2.7.0)
21 | -------------------------------------
22 |
23 | Original stable baselines save format. Data and parameters are bundled up into a tuple ``(data, parameters)``
24 | and then serialized with ``cloudpickle`` library (essentially the same as ``pickle``).
25 |
26 | This save format is still available via an argument in model save function in stable-baselines versions above
27 | v2.7.0 for backwards compatibility reasons, but its usage is discouraged.
28 |
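A sketch of how the old format can still be requested when saving (assuming the ``cloudpickle`` argument of ``save``):

.. code-block:: python

    from stable_baselines import PPO2

    model = PPO2('MlpPolicy', 'CartPole-v1')
    # Save using the legacy cloudpickle format instead of the zip-archive one
    model.save("ppo2_cartpole", cloudpickle=True)
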
29 | Pros:
30 |
31 | - Easy to implement and use.
32 | - Works with almost any type of Python object, including functions.
33 |
34 |
35 | Cons:
36 |
37 | - Pickle/Cloudpickle is not designed for long-term storage or sharing between Python versions.
38 | - If one object in file is not readable (e.g. wrong library version), then reading the rest of the
39 | file is difficult.
40 | - Python-specific format, hard to read stored files from other languages.
41 |
42 |
43 | If part of a saved model becomes unreadable for any reason (e.g. different Tensorflow versions), then
44 | it may be tricky to restore any of the model. For this reason another save format was designed.
45 |
46 |
47 | Zip-archive (stable-baselines>2.7.0)
48 | -------------------------------------
49 |
50 | A zip-archived JSON dump and NumPy zip archive of the arrays. The data dictionary (class parameters)
51 | is stored as a JSON file, model parameters are serialized with ``numpy.savez`` function and these two files
52 | are stored under a single .zip archive.
53 |
54 | Any objects that are not JSON serializable are serialized with cloudpickle and stored as a base64-encoded
55 | string in the JSON file, along with some information about the serialized object. This allows
56 | inspecting stored objects without deserializing the object itself.
57 |
58 | This format allows skipping elements in the file, i.e. we can skip deserializing objects that are
59 | broken/non-serializable. This can be done via ``custom_objects`` argument to load functions.
60 |
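For example, a minimal sketch of skipping a broken element with ``custom_objects`` (the replacement value below is hypothetical):

.. code-block:: python

    from stable_baselines import PPO2

    # Any key present in custom_objects is used as-is instead of being deserialized from the file
    model = PPO2.load("saved_model", custom_objects={'learning_rate': 1e-4})
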
61 | This is the default save format in stable baselines versions after v2.7.0.
62 |
63 | File structure:
64 |
65 | ::
66 |
67 | saved_model.zip/
68 | ├── data JSON file of class-parameters (dictionary)
69 | ├── parameter_list JSON file of model parameters and their ordering (list)
70 | ├── parameters Bytes from numpy.savez (a zip file of the numpy arrays). ...
71 | ├── ... Being a zip-archive itself, this object can also be opened ...
72 | ├── ... as a zip-archive and browsed.
73 |
74 |
75 | Pros:
76 |
77 |
78 | - More robust to unserializable objects (one bad object does not break everything).
79 | - Saved file can be inspected/extracted with zip-archive explorers and by other
80 | languages.
81 |
82 |
83 | Cons:
84 |
85 | - More complex implementation.
86 | - Still relies partly on cloudpickle for complex objects (e.g. custom functions).
87 |
--------------------------------------------------------------------------------
/docs/guide/vec_envs.rst:
--------------------------------------------------------------------------------
1 | .. _vec_env:
2 |
3 | .. automodule:: stable_baselines.common.vec_env
4 |
5 | Vectorized Environments
6 | =======================
7 |
8 | Vectorized Environments are a method for stacking multiple independent environments into a single environment.
9 | Instead of training an RL agent on 1 environment per step, it allows us to train it on ``n`` environments per step.
10 | Because of this, ``actions`` passed to the environment are now a vector (of dimension ``n``).
11 | It is the same for ``observations``, ``rewards`` and end of episode signals (``dones``).
12 | In the case of non-array observation spaces such as ``Dict`` or ``Tuple``, where different sub-spaces
13 | may have different shapes, the sub-observations are vectors (of dimension ``n``).
14 |
15 | ============= ======= ============ ======== ========= ================
16 | Name ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing
17 | ============= ======= ============ ======== ========= ================
18 | DummyVecEnv ✔️ ✔️ ✔️ ✔️ ❌️
19 | SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
20 | ============= ======= ============ ======== ========= ================
21 |
22 | .. note::
23 |
24 | Vectorized environments are required when using wrappers for frame-stacking or normalization.
25 |
26 | .. note::
27 |
28 | When using vectorized environments, the environments are automatically reset at the end of each episode.
29 | Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
30 | You can access the "real" final observation of the terminated episode (the one that accompanied the ``done`` event provided by the underlying environment) using the ``terminal_observation`` key in the info dicts returned by the vecenv.
31 |
32 | .. warning::
33 |
34 | When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` block if using the ``forkserver`` or ``spawn`` start method (default on Windows).
35 | On Linux, the default start method is ``fork``, which is not thread safe and can create deadlocks.
36 |
37 | For more information, see Python's `multiprocessing guidelines `_.
38 |
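A minimal sketch of creating a ``SubprocVecEnv`` (guarded by ``if __name__ == "__main__":`` as recommended above):

.. code-block:: python

    import gym

    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import SubprocVecEnv

    def make_env(env_id, rank):
        def _init():
            env = gym.make(env_id)
            env.seed(rank)
            return env
        return _init

    if __name__ == "__main__":
        # 4 environments running in separate processes
        env = SubprocVecEnv([make_env('CartPole-v1', i) for i in range(4)])
        model = PPO2('MlpPolicy', env, verbose=1)
        model.learn(total_timesteps=10000)
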
39 | VecEnv
40 | ------
41 |
42 | .. autoclass:: VecEnv
43 | :members:
44 |
45 | DummyVecEnv
46 | -----------
47 |
48 | .. autoclass:: DummyVecEnv
49 | :members:
50 |
51 | SubprocVecEnv
52 | -------------
53 |
54 | .. autoclass:: SubprocVecEnv
55 | :members:
56 |
57 | Wrappers
58 | --------
59 |
60 | VecFrameStack
61 | ~~~~~~~~~~~~~
62 |
63 | .. autoclass:: VecFrameStack
64 | :members:
65 |
66 |
67 | VecNormalize
68 | ~~~~~~~~~~~~
69 |
70 | .. autoclass:: VecNormalize
71 | :members:
72 |
73 |
74 | VecVideoRecorder
75 | ~~~~~~~~~~~~~~~~
76 |
77 | .. autoclass:: VecVideoRecorder
78 | :members:
79 |
80 |
81 | VecCheckNan
82 | ~~~~~~~~~~~~~~~~
83 |
84 | .. autoclass:: VecCheckNan
85 | :members:
86 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Stable Baselines documentation master file, created by
2 | sphinx-quickstart on Sat Aug 25 10:33:54 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Stable Baselines docs! - RL Baselines Made Easy
7 | ===========================================================
8 |
9 | `Stable Baselines `_ is a set of improved implementations
10 | of Reinforcement Learning (RL) algorithms based on OpenAI `Baselines `_.
11 |
12 |
13 | Github repository: https://github.com/hill-a/stable-baselines
14 |
15 | RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo
16 |
17 | RL Baselines zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.
18 |
19 | You can read a detailed presentation of Stable Baselines in the
20 | Medium article: `link `_
21 |
22 |
23 | Main differences with OpenAI Baselines
24 | --------------------------------------
25 |
26 | This toolset is a fork of OpenAI Baselines, with a major structural refactoring, and code cleanups:
27 |
28 | - Unified structure for all algorithms
29 | - PEP8 compliant (unified code style)
30 | - Documented functions and classes
31 | - More tests & more code coverage
32 | - Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)
33 |
34 |
35 | .. toctree::
36 | :maxdepth: 2
37 | :caption: User Guide
38 |
39 | guide/install
40 | guide/quickstart
41 | guide/rl_tips
42 | guide/rl
43 | guide/algos
44 | guide/examples
45 | guide/vec_envs
46 | guide/custom_env
47 | guide/custom_policy
48 | guide/callbacks
49 | guide/tensorboard
50 | guide/rl_zoo
51 | guide/pretrain
52 | guide/checking_nan
53 | guide/save_format
54 | guide/export
55 |
56 |
57 | .. toctree::
58 | :maxdepth: 1
59 | :caption: RL Algorithms
60 |
61 | modules/base
62 | modules/policies
63 | modules/a2c
64 | modules/acer
65 | modules/acktr
66 | modules/ddpg
67 | modules/dqn
68 | modules/gail
69 | modules/her
70 | modules/ppo1
71 | modules/ppo2
72 | modules/sac
73 | modules/td3
74 | modules/trpo
75 |
76 | .. toctree::
77 | :maxdepth: 1
78 | :caption: Common
79 |
80 | common/distributions
81 | common/tf_utils
82 | common/cmd_utils
83 | common/schedules
84 | common/evaluation
85 | common/env_checker
86 | common/monitor
87 |
88 | .. toctree::
89 | :maxdepth: 1
90 | :caption: Misc
91 |
92 | misc/changelog
93 | misc/projects
94 | misc/results_plotter
95 |
96 |
97 | Citing Stable Baselines
98 | -----------------------
99 | To cite this project in publications:
100 |
101 | .. code-block:: bibtex
102 |
103 | @misc{stable-baselines,
104 | author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai},
105 | title = {Stable Baselines},
106 | year = {2018},
107 | publisher = {GitHub},
108 | journal = {GitHub repository},
109 | howpublished = {\url{https://github.com/hill-a/stable-baselines}},
110 | }
111 |
112 | Contributing
113 | ------------
114 |
115 | For anyone interested in making the RL baselines better, there are still some improvements
116 | that need to be done.
117 | A full TODO list is available in the `roadmap `_.
118 |
119 | If you want to contribute, please read `CONTRIBUTING.md `_ first.
120 |
121 | Indices and tables
122 | -------------------
123 |
124 | * :ref:`genindex`
125 | * :ref:`search`
126 | * :ref:`modindex`
127 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=StableBaselines
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/misc/results_plotter.rst:
--------------------------------------------------------------------------------
1 | .. _results_plotter:
2 |
3 |
4 | Plotting Results
5 | ================
6 |
7 | .. automodule:: stable_baselines.results_plotter
8 | :members:
9 |
--------------------------------------------------------------------------------
/docs/modules/base.rst:
--------------------------------------------------------------------------------
1 | .. _base_algo:
2 |
3 | .. automodule:: stable_baselines.common.base_class
4 |
5 |
6 | Base RL Class
7 | =============
8 |
9 | Common interface for all the RL algorithms
10 |
11 | .. autoclass:: BaseRLModel
12 | :members:
13 |
--------------------------------------------------------------------------------
/docs/modules/gail.rst:
--------------------------------------------------------------------------------
1 | .. _gail:
2 |
3 | .. automodule:: stable_baselines.gail
4 |
5 |
6 | GAIL
7 | ====
8 |
9 | `Generative Adversarial Imitation Learning (GAIL) <https://arxiv.org/abs/1606.03476>`_ uses expert trajectories
10 | to recover a cost function and then learn a policy.
11 |
12 | Learning a cost function from expert demonstrations is called Inverse Reinforcement Learning (IRL).
13 | The connection between GAIL and Generative Adversarial Networks (GANs) is that GAIL uses a discriminator that tries
14 | to separate expert trajectories from trajectories of the learned policy, which plays the role of the generator here.
15 |
16 | .. note::
17 |
18 | GAIL requires :ref:`OpenMPI <openmpi>`. If OpenMPI isn't enabled, then GAIL isn't
19 | imported into the ``stable_baselines`` module.
20 |
21 |
22 | Notes
23 | -----
24 |
25 | - Original paper: https://arxiv.org/abs/1606.03476
26 |
27 | .. warning::
28 |
29 | Images are not yet handled properly by the current implementation
30 |
31 |
32 |
33 | If you want to train an imitation learning agent
34 | ------------------------------------------------
35 |
36 |
37 | Step 1: Generate expert data
38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39 |
40 | You can either train an RL algorithm in a classic setting, use another controller (e.g. a PID controller),
41 | or use human demonstrations.
42 |
43 | We recommend you take a look at the :ref:`pre-training <pretrain>` section,
44 | or look directly at the ``stable_baselines/gail/dataset/`` folder, to learn more about the expected format for the dataset.
45 |
46 | Here is an example of training a Soft Actor-Critic model to generate expert trajectories for GAIL:
47 |
48 |
49 | .. code-block:: python
50 |
51 | from stable_baselines import SAC
52 | from stable_baselines.gail import generate_expert_traj
53 |
54 | # Generate expert trajectories (train expert)
55 | model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
56 | # Train for 60000 timesteps and record 10 trajectories
57 | # all the data will be saved in 'expert_pendulum.npz' file
58 | generate_expert_traj(model, 'expert_pendulum', n_timesteps=60000, n_episodes=10)
59 |
60 |
61 |
62 | Step 2: Run GAIL
63 | ~~~~~~~~~~~~~~~~
64 |
65 |
66 | **In case you want to run Behavior Cloning (BC)**
67 |
68 | Use the ``.pretrain()`` method (see the :ref:`pre-training <pretrain>` guide).
69 |
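A minimal behavior cloning sketch, reusing the ``expert_pendulum.npz`` dataset generated in Step 1
(the hyperparameter values below are illustrative only):

.. code-block:: python

  from stable_baselines import PPO2
  from stable_baselines.gail import ExpertDataset

  # Load the expert dataset recorded in Step 1
  dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, batch_size=128)

  model = PPO2('MlpPolicy', 'Pendulum-v0', verbose=1)
  # Pretrain the policy with behavior cloning (supervised learning on the expert data)
  model.pretrain(dataset, n_epochs=1000)
  # Optionally continue training with RL afterwards
  model.learn(total_timesteps=10000)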
70 |
71 | **Others**
72 |
73 | Thanks to the following open source implementations:
74 |
75 | - @openai/imitation
76 | - @carpedm20/deep-rl-tensorflow
77 |
78 |
79 | Can I use?
80 | ----------
81 |
82 | - Recurrent policies: ❌
83 | - Multi processing: ✔️ (using MPI)
84 | - Gym spaces:
85 |
86 |
87 | ============= ====== ===========
88 | Space Action Observation
89 | ============= ====== ===========
90 | Discrete ✔️ ✔️
91 | Box ✔️ ✔️
92 | MultiDiscrete ❌ ✔️
93 | MultiBinary ❌ ✔️
94 | ============= ====== ===========
95 |
96 |
97 | Example
98 | -------
99 |
100 | .. code-block:: python
101 |
102 | import gym
103 |
104 | from stable_baselines import GAIL, SAC
105 | from stable_baselines.gail import ExpertDataset, generate_expert_traj
106 |
107 | # Generate expert trajectories (train expert)
108 | model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
109 | generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10)
110 |
111 | # Load the expert dataset
112 | dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)
113 |
114 | model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
115 | # Note: in practice, you need to train for 1M steps to have a working policy
116 | model.learn(total_timesteps=1000)
117 | model.save("gail_pendulum")
118 |
119 | del model # remove to demonstrate saving and loading
120 |
121 | model = GAIL.load("gail_pendulum")
122 |
123 | env = gym.make('Pendulum-v0')
124 | obs = env.reset()
125 | while True:
126 | action, _states = model.predict(obs)
127 | obs, rewards, dones, info = env.step(action)
128 | env.render()
129 |
130 |
131 | Parameters
132 | ----------
133 |
134 | .. autoclass:: GAIL
135 | :members:
136 | :inherited-members:
137 |
--------------------------------------------------------------------------------
/docs/modules/her.rst:
--------------------------------------------------------------------------------
1 | .. _her:
2 |
3 | .. automodule:: stable_baselines.her
4 |
5 |
6 | HER
7 | ====
8 |
9 | `Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_
10 |
11 | HER is a method wrapper that works with off-policy methods (for example DQN, SAC, TD3 and DDPG).
12 |
13 | .. note::
14 |
15 | HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI Baselines.
16 | If you want to reproduce results from the paper, please use the RL Baselines Zoo
17 | in order to have the correct hyperparameters and at least 8 MPI workers with DDPG.
18 |
19 | .. warning::
20 |
21 | HER requires the environment to inherit from ``gym.GoalEnv``
22 |
23 |
24 | .. warning::
25 |
26 | You must pass an environment, or wrap it with ``HERGoalEnvWrapper``, in order to use the ``predict()`` method.
27 |
28 |
29 | Notes
30 | -----
31 |
32 | - Original paper: https://arxiv.org/abs/1707.01495
33 | - OpenAI paper: `Plappert et al. (2018)`_
34 | - OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/
35 |
36 |
37 | .. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464
38 |
39 | Can I use?
40 | ----------
41 |
42 | Please refer to the wrapped model (DQN, SAC, TD3 or DDPG) for that section.
43 |
44 | Example
45 | -------
46 |
47 | .. code-block:: python
48 |
49 | from stable_baselines import HER, DQN, SAC, DDPG, TD3
50 | from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper
51 | from stable_baselines.common.bit_flipping_env import BitFlippingEnv
52 |
53 | model_class = DQN  # works also with SAC, DDPG and TD3
54 | N_BITS = 15  # number of bits in the bit-flipping environment (illustrative value)
55 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
56 |
57 | # Available strategies (cf paper): future, final, episode, random
58 | goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
59 |
60 | # Wrap the model
61 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy,
62 | verbose=1)
63 | # Train the model
64 | model.learn(1000)
65 |
66 | model.save("./her_bit_env")
67 |
68 | # WARNING: you must pass an env
69 | # or wrap your environment with HERGoalEnvWrapper to use the predict method
70 | model = HER.load('./her_bit_env', env=env)
71 |
72 | obs = env.reset()
73 | for _ in range(100):
74 | action, _ = model.predict(obs)
75 | obs, reward, done, _ = env.step(action)
76 |
77 | if done:
78 | obs = env.reset()
79 |
80 |
81 | Parameters
82 | ----------
83 |
84 | .. autoclass:: HER
85 | :members:
86 |
87 | Goal Selection Strategies
88 | -------------------------
89 |
90 | .. autoclass:: GoalSelectionStrategy
91 | :members:
92 | :inherited-members:
93 | :undoc-members:
94 |
95 |
96 | Goal Env Wrapper
97 | ----------------
98 |
99 | .. autoclass:: HERGoalEnvWrapper
100 | :members:
101 | :inherited-members:
102 | :undoc-members:
103 |
104 |
105 | Replay Wrapper
106 | --------------
107 |
108 | .. autoclass:: HindsightExperienceReplayWrapper
109 | :members:
110 | :inherited-members:
111 |
--------------------------------------------------------------------------------
/docs/modules/policies.rst:
--------------------------------------------------------------------------------
1 | .. _policies:
2 |
3 | .. automodule:: stable_baselines.common.policies
4 |
5 | Policy Networks
6 | ===============
7 |
8 | Stable-Baselines provides a set of default policies that can be used with most action spaces.
9 | To customize a default policy, you can specify the ``policy_kwargs`` parameter to the model class you use.
10 | Those kwargs are then passed to the policy on instantiation (see :ref:`custom_policy` for an example, and the short sketch below).
11 | If you need more control over the policy architecture, you can also create a custom policy (see :ref:`custom_policy`).
12 |
13 | .. note::
14 |
15 | CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints).
16 |
17 | .. warning::
18 | For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing
19 | (to avoid out-of-bound errors).
20 |
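For instance, a feedforward policy's architecture can be adjusted through ``policy_kwargs``
(the accepted keys depend on the policy class; the values below are only an illustrative sketch):

.. code-block:: python

  from stable_baselines import PPO2

  # Two hidden layers of 64 units each, passed to the policy at instantiation
  model = PPO2('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[64, 64]), verbose=1)
  model.learn(total_timesteps=10000)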
21 |
22 | .. rubric:: Available Policies
23 |
24 | .. autosummary::
25 | :nosignatures:
26 |
27 | MlpPolicy
28 | MlpLstmPolicy
29 | MlpLnLstmPolicy
30 | CnnPolicy
31 | CnnLstmPolicy
32 | CnnLnLstmPolicy
33 |
34 |
35 | Base Classes
36 | ------------
37 |
38 | .. autoclass:: BasePolicy
39 | :members:
40 |
41 | .. autoclass:: ActorCriticPolicy
42 | :members:
43 |
44 | .. autoclass:: FeedForwardPolicy
45 | :members:
46 |
47 | .. autoclass:: LstmPolicy
48 | :members:
49 |
50 | MLP Policies
51 | ------------
52 |
53 | .. autoclass:: MlpPolicy
54 | :members:
55 |
56 | .. autoclass:: MlpLstmPolicy
57 | :members:
58 |
59 | .. autoclass:: MlpLnLstmPolicy
60 | :members:
61 |
62 |
63 | CNN Policies
64 | ------------
65 |
66 | .. autoclass:: CnnPolicy
67 | :members:
68 |
69 | .. autoclass:: CnnLstmPolicy
70 | :members:
71 |
72 | .. autoclass:: CnnLnLstmPolicy
73 | :members:
74 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | gym
2 | pandas
3 |
--------------------------------------------------------------------------------
/docs/spelling_wordlist.txt:
--------------------------------------------------------------------------------
1 | py
2 | env
3 | atari
4 | argparse
5 | Argparse
6 | TensorFlow
7 | feedforward
8 | envs
9 | VecEnv
10 | pretrain
11 | petrained
12 | tf
13 | np
14 | mujoco
15 | cpu
16 | ndarray
17 | ndarrays
18 | timestep
19 | timesteps
20 | stepsize
21 | dataset
22 | adam
23 | fn
24 | normalisation
25 | Kullback
26 | Leibler
27 | boolean
28 | deserialized
29 | pretrained
30 | minibatch
31 | subprocesses
32 | ArgumentParser
33 | Tensorflow
34 | Gaussian
35 | approximator
36 | minibatches
37 | hyperparameters
38 | hyperparameter
39 | vectorized
40 | rl
41 | colab
42 | dataloader
43 | npz
44 | datasets
45 | vf
46 | logits
47 | num
48 | Utils
49 | backpropagate
50 | prepend
51 | NaN
52 | preprocessing
53 | Cloudpickle
54 | async
55 | multiprocess
56 | tensorflow
57 | mlp
58 | cnn
59 | neglogp
60 | tanh
61 | coef
62 | repo
63 | Huber
64 | params
65 | ppo
66 | arxiv
67 | Arxiv
68 | func
69 | DQN
70 | Uhlenbeck
71 | Ornstein
72 | multithread
73 | cancelled
74 | Tensorboard
75 | parallelize
76 | customising
77 | serializable
78 | Multiprocessed
79 | cartpole
80 | toolset
81 | lstm
82 | rescale
83 | ffmpeg
84 | avconv
85 | unnormalized
86 | Github
87 | pre
88 | preprocess
89 | backend
90 | attr
91 | preprocess
92 | Antonin
93 | Raffin
94 | araffin
95 | Homebrew
96 | Numpy
97 | Theano
98 | rollout
99 | kfac
100 | Piecewise
101 | csv
102 | nvidia
103 | visdom
104 | tensorboard
105 | preprocessed
106 | namespace
107 | sklearn
108 | GoalEnv
109 | BaseCallback
110 | Keras
111 |
--------------------------------------------------------------------------------
/scripts/build_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CPU_PARENT=ubuntu:16.04
4 | GPU_PARENT=nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04
5 |
6 | TAG=stablebaselines/stable-baselines
7 | VERSION=$(cat ./stable_baselines/version.txt)
8 |
9 | if [[ ${USE_GPU} == "True" ]]; then
10 | PARENT=${GPU_PARENT}
11 | else
12 | PARENT=${CPU_PARENT}
13 | TAG="${TAG}-cpu"
14 | fi
15 |
16 | docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg USE_GPU=${USE_GPU} -t ${TAG}:${VERSION} .
17 | docker tag ${TAG}:${VERSION} ${TAG}:latest
18 |
19 | if [[ ${RELEASE} == "True" ]]; then
20 | docker push ${TAG}:${VERSION}
21 | docker push ${TAG}:latest
22 | fi
23 |
--------------------------------------------------------------------------------
/scripts/run_docker_cpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Launch an experiment using the docker cpu image
3 |
4 | cmd_line="$@"
5 |
6 | echo "Executing in the docker (cpu image):"
7 | echo $cmd_line
8 |
9 | docker run -it --rm --network host --ipc=host \
10 | --mount src=$(pwd),target=/root/code/stable-baselines,type=bind stablebaselines/stable-baselines-cpu:v2.10.0 \
11 | bash -c "cd /root/code/stable-baselines/ && $cmd_line"
12 |
--------------------------------------------------------------------------------
/scripts/run_docker_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Launch an experiment using the docker gpu image
3 |
4 | cmd_line="$@"
5 |
6 | echo "Executing in the docker (gpu image):"
7 | echo $cmd_line
8 |
9 | # TODO: always use new-style once sufficiently widely used (probably 2021 onwards)
10 | if [ -x "$(which nvidia-docker)" ]; then
11 | # old-style nvidia-docker2
12 | NVIDIA_ARG="--runtime=nvidia"
13 | else
14 | NVIDIA_ARG="--gpus all"
15 | fi
16 |
17 | docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \
18 | --mount src=$(pwd),target=/root/code/stable-baselines,type=bind stablebaselines/stable-baselines:v2.10.0 \
19 | bash -c "cd /root/code/stable-baselines/ && $cmd_line"
20 |
--------------------------------------------------------------------------------
/scripts/run_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v
3 |
--------------------------------------------------------------------------------
/scripts/run_tests_travis.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DOCKER_CMD="docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines,type=bind"
4 | BASH_CMD="cd /root/code/stable-baselines/"
5 |
6 | if [[ $# -ne 1 ]]; then
7 | echo "usage: $0 <test glob>"
8 | exit 1
9 | fi
10 |
11 | if [[ ${DOCKER_IMAGE} = "" ]]; then
12 | echo "Need DOCKER_IMAGE environment variable to be set."
13 | exit 1
14 | fi
15 |
16 | TEST_GLOB=$1
17 |
18 | set -e # exit immediately on any error
19 |
20 | # For pull requests from fork, Codacy token is not available, leading to build failure
21 | if [[ ${CODACY_PROJECT_TOKEN} = "" ]]; then
22 | echo "WARNING: CODACY_PROJECT_TOKEN not set. Skipping Codacy upload."
23 | echo "(This is normal when building in a fork and can be ignored.)"
24 | ${DOCKER_CMD} ${DOCKER_IMAGE} \
25 | bash -c "${BASH_CMD} && \
26 | pytest --cov-config .coveragerc --cov-report term --cov=. -v tests/test_${TEST_GLOB}"
27 | else
28 | ${DOCKER_CMD} --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} \
29 | bash -c "${BASH_CMD} && \
30 | pytest --cov-config .coveragerc --cov-report term --cov-report xml --cov=. -v tests/test_${TEST_GLOB} && \
31 | java -jar /root/code/codacy-coverage-reporter.jar report -l python -r coverage.xml --partial"
32 | fi
33 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | # This includes the license file in the wheel.
3 | license_file = LICENSE
4 |
5 | [tool:pytest]
6 | # Deterministic ordering for tests; useful for pytest-xdist.
7 | env =
8 | PYTHONHASHSEED=0
9 | filterwarnings =
10 | ignore:inspect.getargspec:DeprecationWarning:tensorflow
11 | ignore::pytest.PytestUnknownMarkWarning
12 | # Tensorflow internal warnings
13 | ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
14 | ignore:The binary mode of fromstring is deprecated:DeprecationWarning
15 | ignore::FutureWarning:tensorflow
16 | # Gym warnings
17 | ignore:Parameters to load are deprecated.:DeprecationWarning
18 | ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
19 |
20 | [pytype]
21 | inputs = stable_baselines
22 | ; python_version = 3.5
23 |
--------------------------------------------------------------------------------
/stable_baselines/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | from stable_baselines.a2c import A2C
5 | from stable_baselines.acer import ACER
6 | from stable_baselines.acktr import ACKTR
7 | from stable_baselines.deepq import DQN
8 | from stable_baselines.her import HER
9 | from stable_baselines.ppo2 import PPO2
10 | from stable_baselines.td3 import TD3
11 | from stable_baselines.sac import SAC
12 |
13 | # Load mpi4py-dependent algorithms only if mpi is installed.
14 | try:
15 | import mpi4py
16 | except ImportError:
17 | mpi4py = None
18 |
19 | if mpi4py is not None:
20 | from stable_baselines.ddpg import DDPG
21 | from stable_baselines.gail import GAIL
22 | from stable_baselines.ppo1 import PPO1
23 | from stable_baselines.trpo_mpi import TRPO
24 | del mpi4py
25 |
26 | # Read version from file
27 | version_file = os.path.join(os.path.dirname(__file__), "version.txt")
28 | with open(version_file, "r") as file_handler:
29 | __version__ = file_handler.read().strip()
30 |
31 |
32 | warnings.warn(
33 | "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."
34 | )
35 |
--------------------------------------------------------------------------------
/stable_baselines/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.a2c.a2c import A2C
2 |
--------------------------------------------------------------------------------
/stable_baselines/a2c/run_atari.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from stable_baselines import logger, A2C
4 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser
5 | from stable_baselines.common.vec_env import VecFrameStack
6 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
7 |
8 |
9 | def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
10 | """
11 | Train A2C model for atari environment, for testing purposes
12 |
13 | :param env_id: (str) Environment ID
14 | :param num_timesteps: (int) The total number of samples
15 | :param seed: (int) The initial seed for training
16 | :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
17 | :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
18 | 'double_linear_con', 'middle_drop' or 'double_middle_drop')
19 | :param num_env: (int) The number of environments
20 | """
21 | policy_fn = None
22 | if policy == 'cnn':
23 | policy_fn = CnnPolicy
24 | elif policy == 'lstm':
25 | policy_fn = CnnLstmPolicy
26 | elif policy == 'lnlstm':
27 | policy_fn = CnnLnLstmPolicy
28 | if policy_fn is None:
29 | raise ValueError("Error: policy {} not implemented".format(policy))
30 |
31 | env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
32 |
33 | model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed)
34 | model.learn(total_timesteps=int(num_timesteps * 1.1))
35 | env.close()
36 |
37 |
38 | def main():
39 | """
40 | Runs the test
41 | """
42 | parser = atari_arg_parser()
43 | parser.add_argument('--policy', choices=['cnn', 'lstm', 'lnlstm'], default='cnn', help='Policy architecture')
44 | parser.add_argument('--lr_schedule', choices=['constant', 'linear'], default='constant',
45 | help='Learning rate schedule')
46 | args = parser.parse_args()
47 | logger.configure()
48 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, policy=args.policy, lr_schedule=args.lr_schedule,
49 | num_env=16)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
/stable_baselines/acer/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.acer.acer_simple import ACER
2 |
--------------------------------------------------------------------------------
/stable_baselines/acer/run_atari.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import warnings
3 |
4 | from stable_baselines import logger, ACER
5 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy
6 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser
7 | from stable_baselines.common.vec_env import VecFrameStack
8 |
9 |
10 | def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu):
11 | """
12 | train an ACER model on atari
13 |
14 | :param env_id: (str) Environment ID
15 | :param num_timesteps: (int) The total number of samples
16 | :param seed: (int) The initial seed for training
17 | :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
18 | :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
19 | 'double_linear_con', 'middle_drop' or 'double_middle_drop')
20 | :param num_cpu: (int) The number of cpu to train on
21 | """
22 | env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
23 | if policy == 'cnn':
24 | policy_fn = CnnPolicy
25 | elif policy == 'lstm':
26 | policy_fn = CnnLstmPolicy
27 | else:
28 | warnings.warn("Policy {} not implemented".format(policy))
29 | return
30 |
31 | model = ACER(policy_fn, env, lr_schedule=lr_schedule, buffer_size=5000, seed=seed)
32 | model.learn(total_timesteps=int(num_timesteps * 1.1))
33 | env.close()
34 | # Free memory
35 | del model
36 |
37 |
38 | def main():
39 | """
40 | Runs the test
41 | """
42 | parser = atari_arg_parser()
43 | parser.add_argument('--policy', choices=['cnn', 'lstm', 'lnlstm'], default='cnn', help='Policy architecture')
44 | parser.add_argument('--lr_schedule', choices=['constant', 'linear'], default='constant',
45 | help='Learning rate schedule')
46 | parser.add_argument('--logdir', help='Directory for logging')
47 | args = parser.parse_args()
48 | logger.configure(args.logdir)
49 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
50 | policy=args.policy, lr_schedule=args.lr_schedule, num_cpu=16)
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
55 |
--------------------------------------------------------------------------------
/stable_baselines/acktr/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.acktr.acktr import ACKTR
2 |
--------------------------------------------------------------------------------
/stable_baselines/acktr/run_atari.py:
--------------------------------------------------------------------------------
1 | from stable_baselines import logger, ACKTR
2 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser
3 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
4 | from stable_baselines.common.policies import CnnPolicy
5 |
6 |
7 | def train(env_id, num_timesteps, seed, num_cpu):
8 | """
9 | train an ACKTR model on atari
10 |
11 | :param env_id: (str) Environment ID
12 | :param num_timesteps: (int) The total number of samples
13 | :param seed: (int) The initial seed for training
14 | :param num_cpu: (int) The number of cpu to train on
15 | """
16 | env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
17 | model = ACKTR(CnnPolicy, env, nprocs=num_cpu, seed=seed)
18 | model.learn(total_timesteps=int(num_timesteps * 1.1))
19 | env.close()
20 |
21 |
22 | def main():
23 | """
24 | Runs the test
25 | """
26 | args = atari_arg_parser().parse_args()
27 | logger.configure()
28 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32)
29 |
30 |
31 | if __name__ == '__main__':
32 | main()
33 |
--------------------------------------------------------------------------------
/stable_baselines/bench/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.bench.monitor import Monitor, load_results
2 |
--------------------------------------------------------------------------------
/stable_baselines/common/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa F403
2 | from stable_baselines.common.console_util import fmt_row, fmt_item, colorize
3 | from stable_baselines.common.dataset import Dataset
4 | from stable_baselines.common.math_util import discount, discount_with_boundaries, explained_variance, \
5 | explained_variance_2d, flatten_arrays, unflatten_vector
6 | from stable_baselines.common.misc_util import zipsame, set_global_seeds, boolean_flag
7 | from stable_baselines.common.base_class import BaseRLModel, ActorCriticRLModel, OffPolicyRLModel, SetVerbosity, \
8 | TensorboardWriter
9 | from stable_baselines.common.cmd_util import make_vec_env
10 |
--------------------------------------------------------------------------------
/stable_baselines/common/cg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def conjugate_gradient(f_ax, b_vec, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
5 | """
6 | conjugate gradient calculation (Ax = b), based on
7 | https://epubs.siam.org/doi/book/10.1137/1.9781611971446 Demmel p 312
8 |
9 | :param f_ax: (function) The function describing the Matrix A dot the vector x
10 | (x being the input parameter of the function)
11 | :param b_vec: (numpy float) vector b, where Ax = b
12 | :param cg_iters: (int) the maximum number of iterations for converging
13 | :param callback: (function) callback the values of x while converging
14 | :param verbose: (bool) print extra information
15 | :param residual_tol: (float) the break point if the residual is below this value
16 | :return: (numpy float) vector x, where Ax = b
17 | """
18 | first_basis_vect = b_vec.copy() # the first basis vector
19 | residual = b_vec.copy() # the residual
20 | x_var = np.zeros_like(b_vec) # vector x, where Ax = b
21 | residual_dot_residual = residual.dot(residual) # L2 norm of the residual
22 |
23 | fmt_str = "%10i %10.3g %10.3g"
24 | title_str = "%10s %10s %10s"
25 | if verbose:
26 | print(title_str % ("iter", "residual norm", "soln norm"))
27 |
28 | for i in range(cg_iters):
29 | if callback is not None:
30 | callback(x_var)
31 | if verbose:
32 | print(fmt_str % (i, residual_dot_residual, np.linalg.norm(x_var)))
33 | z_var = f_ax(first_basis_vect)
34 | v_var = residual_dot_residual / first_basis_vect.dot(z_var)
35 | x_var += v_var * first_basis_vect
36 | residual -= v_var * z_var
37 | new_residual_dot_residual = residual.dot(residual)
38 | mu_val = new_residual_dot_residual / residual_dot_residual
39 | first_basis_vect = residual + mu_val * first_basis_vect
40 |
41 | residual_dot_residual = new_residual_dot_residual
42 | if residual_dot_residual < residual_tol:
43 | break
44 |
45 | if callback is not None:
46 | callback(x_var)
47 | if verbose:
48 | print(fmt_str % (i + 1, residual_dot_residual, np.linalg.norm(x_var)))
49 | return x_var
50 |
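51 |
52 | # Illustrative usage sketch (not part of the original module): solve A x = b for a small
53 | # symmetric positive-definite matrix by passing the matrix-vector product as f_ax.
54 | if __name__ == "__main__":
55 |     a_mat = np.array([[4.0, 1.0], [1.0, 3.0]])
56 |     b = np.array([1.0, 2.0])
57 |     x_sol = conjugate_gradient(lambda vec: a_mat.dot(vec), b, cg_iters=10, verbose=True)
58 |     assert np.allclose(a_mat.dot(x_sol), b)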
--------------------------------------------------------------------------------
/stable_baselines/common/console_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 |
5 |
6 | # ================================================================
7 | # Misc
8 | # ================================================================
9 |
10 |
11 | def fmt_row(width, row, header=False):
12 | """
13 | fits a list of items to at least a certain length
14 |
15 | :param width: (int) the minimum width of the string
16 | :param row: ([Any]) a list of object you wish to get the string representation
17 | :param header: (bool) whether or not to return the string as a header
18 | :return: (str) the string representation of all the elements in 'row', of length >= 'width'
19 | """
20 | out = " | ".join(fmt_item(x, width) for x in row)
21 | if header:
22 | out = out + "\n" + "-" * len(out)
23 | return out
24 |
25 |
26 | def fmt_item(item, min_width):
27 | """
28 | fits items to a given string length
29 |
30 | :param item: (Any) the item you wish to get the string representation
31 | :param min_width: (int) the minimum width of the string
32 | :return: (str) the string representation of 'item', of length >= 'min_width'
33 | """
34 | if isinstance(item, np.ndarray):
35 | assert item.ndim == 0
36 | item = item.item()
37 | if isinstance(item, (float, np.float32, np.float64)):
38 | value = abs(item)
39 | if (value < 1e-4 or value > 1e+4) and value > 0:
40 | rep = "%7.2e" % item
41 | else:
42 | rep = "%7.5f" % item
43 | else:
44 | rep = str(item)
45 | return " " * (min_width - len(rep)) + rep
46 |
47 |
48 | COLOR_TO_NUM = dict(
49 | gray=30,
50 | red=31,
51 | green=32,
52 | yellow=33,
53 | blue=34,
54 | magenta=35,
55 | cyan=36,
56 | white=37,
57 | crimson=38
58 | )
59 |
60 |
61 | def colorize(string, color, bold=False, highlight=False):
62 | """
63 | Colorize, bold and/or highlight a string for terminal print
64 |
65 | :param string: (str) input string
66 | :param color: (str) the color name; the lookup table is the COLOR_TO_NUM dict in console_util
67 | :param bold: (bool) if the string should be bold or not
68 | :param highlight: (bool) if the string should be highlighted or not
69 | :return: (str) the stylized output string
70 | """
71 | attr = []
72 | num = COLOR_TO_NUM[color]
73 | if highlight:
74 | num += 10
75 | attr.append(str(num))
76 | if bold:
77 | attr.append('1')
78 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
79 |
--------------------------------------------------------------------------------
/stable_baselines/common/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Dataset(object):
5 | def __init__(self, data_map, shuffle=True):
6 | """
7 | Data loader that handles batches and shuffling.
8 | WARNING: this will alter the given data_map ordering, as dicts are mutable
9 |
10 | :param data_map: (dict) the input data, where every column is a key
11 | :param shuffle: (bool) Whether to shuffle or not the dataset
12 | Important: this should be disabled for recurrent policies
13 | """
14 | self.data_map = data_map
15 | self.shuffle = shuffle
16 | self.n_samples = next(iter(data_map.values())).shape[0]
17 | self._next_id = 0
18 | if self.shuffle:
19 | self.shuffle_dataset()
20 |
21 | def shuffle_dataset(self):
22 | """
23 | Shuffles the data_map
24 | """
25 | perm = np.arange(self.n_samples)
26 | np.random.shuffle(perm)
27 |
28 | for key in self.data_map:
29 | self.data_map[key] = self.data_map[key][perm]
30 |
31 | def next_batch(self, batch_size):
32 | """
33 | returns a batch of data of a given size
34 |
35 | :param batch_size: (int) the size of the batch
36 | :return: (dict) a batch of the input data of size 'batch_size'
37 | """
38 | if self._next_id >= self.n_samples:
39 | self._next_id = 0
40 | if self.shuffle:
41 | self.shuffle_dataset()
42 |
43 | cur_id = self._next_id
44 | cur_batch_size = min(batch_size, self.n_samples - self._next_id)
45 | self._next_id += cur_batch_size
46 |
47 | data_map = dict()
48 | for key in self.data_map:
49 | data_map[key] = self.data_map[key][cur_id:cur_id + cur_batch_size]
50 | return data_map
51 |
52 | def iterate_once(self, batch_size):
53 | """
54 | generator that iterates over the dataset
55 |
56 | :param batch_size: (int) the size of the batch
57 | :return: (dict) a batch of the input data of size 'batch_size'
58 | """
59 | if self.shuffle:
60 | self.shuffle_dataset()
61 |
62 | while self._next_id <= self.n_samples - batch_size:
63 | yield self.next_batch(batch_size)
64 | self._next_id = 0
65 |
66 | def subset(self, num_elements, shuffle=True):
67 | """
68 | Return a subset of the current dataset
69 |
70 | :param num_elements: (int) the number of element you wish to have in the subset
71 | :param shuffle: (bool) Whether to shuffle or not the dataset
72 | :return: (Dataset) a new subset of the current Dataset object
73 | """
74 | data_map = dict()
75 | for key in self.data_map:
76 | data_map[key] = self.data_map[key][:num_elements]
77 | return Dataset(data_map, shuffle)
78 |
79 |
80 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
81 | """
82 | Iterates over arrays in batches, must provide either num_batches or batch_size, the other must be None.
83 |
84 | :param arrays: (tuple) a tuple of arrays
85 | :param num_batches: (int) the number of batches, must be None if batch_size is defined
86 | :param batch_size: (int) the size of the batch, must be None if num_batches is defined
87 | :param shuffle: (bool) enable auto shuffle
88 | :param include_final_partial_batch: (bool) add the last batch if not the same size as the batch_size
89 | :return: (tuples) a tuple of a batch of the arrays
90 | """
91 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
92 | arrays = tuple(map(np.asarray, arrays))
93 | n_samples = arrays[0].shape[0]
94 | assert all(a.shape[0] == n_samples for a in arrays[1:])
95 | inds = np.arange(n_samples)
96 | if shuffle:
97 | np.random.shuffle(inds)
98 | sections = np.arange(0, n_samples, batch_size)[1:] if num_batches is None else num_batches
99 | for batch_inds in np.array_split(inds, sections):
100 | if include_final_partial_batch or len(batch_inds) == batch_size:
101 | yield tuple(a[batch_inds] for a in arrays)
102 |
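103 |
104 | # Illustrative usage sketch (not part of the original module): iterate over two aligned
105 | # arrays in shuffled minibatches of 32 samples each.
106 | if __name__ == "__main__":
107 |     observations = np.random.randn(100, 4)
108 |     returns = np.random.randn(100)
109 |     for batch_obs, batch_returns in iterbatches((observations, returns), batch_size=32):
110 |         print(batch_obs.shape, batch_returns.shape)  # e.g. (32, 4) (32,)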
--------------------------------------------------------------------------------
/stable_baselines/common/evaluation.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import Callable, List, Optional, Tuple, Union
3 |
4 | import gym
5 | import numpy as np
6 |
7 | from stable_baselines.common.vec_env import VecEnv
8 |
9 | if typing.TYPE_CHECKING:
10 | from stable_baselines.common.base_class import BaseRLModel
11 |
12 |
13 | def evaluate_policy(
14 | model: "BaseRLModel",
15 | env: Union[gym.Env, VecEnv],
16 | n_eval_episodes: int = 10,
17 | deterministic: bool = True,
18 | render: bool = False,
19 | callback: Optional[Callable] = None,
20 | reward_threshold: Optional[float] = None,
21 | return_episode_rewards: bool = False,
22 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
23 | """
24 | Runs policy for ``n_eval_episodes`` episodes and returns average reward.
25 | This is made to work only with one env.
26 |
27 | :param model: (BaseRLModel) The RL agent you want to evaluate.
28 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv``
29 | this must contain only one environment.
30 | :param n_eval_episodes: (int) Number of episode to evaluate the agent
31 | :param deterministic: (bool) Whether to use deterministic or stochastic actions
32 | :param render: (bool) Whether to render the environment or not
33 | :param callback: (callable) callback function to do additional checks,
34 | called after each step.
35 | :param reward_threshold: (float) Minimum expected reward per episode,
36 | this will raise an error if the performance is not met
37 | :param return_episode_rewards: (bool) If True, a list of rewards per episode
38 | will be returned instead of the mean.
39 | :return: (float, float) Mean reward per episode, std of reward per episode
40 | returns ([float], [int]) when ``return_episode_rewards`` is True
41 | """
42 | if isinstance(env, VecEnv):
43 | assert env.num_envs == 1, "You must pass only one environment when using this function"
44 |
45 | is_recurrent = model.policy.recurrent
46 |
47 | episode_rewards, episode_lengths = [], []
48 | for i in range(n_eval_episodes):
49 | # Avoid double reset, as VecEnv are reset automatically
50 | if not isinstance(env, VecEnv) or i == 0:
51 | obs = env.reset()
52 | # Because recurrent policies need the same observation space during training and evaluation, we need to pad
53 | # observation to match training shape. See https://github.com/hill-a/stable-baselines/issues/1015
54 | if is_recurrent:
55 | zero_completed_obs = np.zeros((model.n_envs,) + model.observation_space.shape)
56 | zero_completed_obs[0, :] = obs
57 | obs = zero_completed_obs
58 | done, state = False, None
59 | episode_reward = 0.0
60 | episode_length = 0
61 | while not done:
62 | action, state = model.predict(obs, state=state, deterministic=deterministic)
63 | new_obs, reward, done, _info = env.step(action)
64 | if is_recurrent:
65 | obs[0, :] = new_obs
66 | else:
67 | obs = new_obs
68 | episode_reward += reward
69 | if callback is not None:
70 | callback(locals(), globals())
71 | episode_length += 1
72 | if render:
73 | env.render()
74 | episode_rewards.append(episode_reward)
75 | episode_lengths.append(episode_length)
76 | mean_reward = np.mean(episode_rewards)
77 | std_reward = np.std(episode_rewards)
78 | if reward_threshold is not None:
79 | assert mean_reward > reward_threshold, "Mean reward below threshold: {:.2f} < {:.2f}".format(mean_reward, reward_threshold)
80 | if return_episode_rewards:
81 | return episode_rewards, episode_lengths
82 | return mean_reward, std_reward
83 |
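84 |
85 | if __name__ == "__main__":
86 |     # Illustrative usage sketch (not part of the original module):
87 |     # evaluate an (untrained) PPO2 agent on CartPole over a few episodes.
88 |     from stable_baselines import PPO2
89 |
90 |     model = PPO2("MlpPolicy", "CartPole-v1")
91 |     mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=5)
92 |     print("mean reward: {:.2f} +/- {:.2f}".format(mean_reward, std_reward))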
--------------------------------------------------------------------------------
/stable_baselines/common/identity_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Optional
3 |
4 | from gym import Env, Space
5 | from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box
6 |
7 |
8 | class IdentityEnv(Env):
9 | def __init__(self,
10 | dim: Optional[int] = None,
11 | space: Optional[Space] = None,
12 | ep_length: int = 100):
13 | """
14 | Identity environment for testing purposes
15 |
16 | :param dim: the size of the action and observation dimension you want
17 | to learn. Provide at most one of `dim` and `space`. If both are
18 | None, then initialization proceeds with `dim=1` and `space=None`.
19 | :param space: the action and observation space. Provide at most one of
20 | `dim` and `space`.
21 | :param ep_length: the length of each episode in timesteps
22 | """
23 | if space is None:
24 | if dim is None:
25 | dim = 1
26 | space = Discrete(dim)
27 | else:
28 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
29 |
30 | self.action_space = self.observation_space = space
31 | self.ep_length = ep_length
32 | self.current_step = 0
33 | self.num_resets = -1 # Becomes 0 after __init__ exits.
34 | self.reset()
35 |
36 | def reset(self):
37 | self.current_step = 0
38 | self.num_resets += 1
39 | self._choose_next_state()
40 | return self.state
41 |
42 | def step(self, action):
43 | reward = self._get_reward(action)
44 | self._choose_next_state()
45 | self.current_step += 1
46 | done = self.current_step >= self.ep_length
47 | return self.state, reward, done, {}
48 |
49 | def _choose_next_state(self):
50 | self.state = self.action_space.sample()
51 |
52 | def _get_reward(self, action):
53 | return 1 if np.all(self.state == action) else 0
54 |
55 | def render(self, mode='human'):
56 | pass
57 |
58 |
59 | class IdentityEnvBox(IdentityEnv):
60 | def __init__(self, low=-1, high=1, eps=0.05, ep_length=100):
61 | """
62 | Identity environment for testing purposes
63 |
64 | :param low: (float) the lower bound of the box dim
65 | :param high: (float) the upper bound of the box dim
66 | :param eps: (float) the epsilon bound for correct value
67 | :param ep_length: (int) the length of each episode in timesteps
68 | """
69 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
70 | super().__init__(ep_length=ep_length, space=space)
71 | self.eps = eps
72 |
73 | def step(self, action):
74 | reward = self._get_reward(action)
75 | self._choose_next_state()
76 | self.current_step += 1
77 | done = self.current_step >= self.ep_length
78 | return self.state, reward, done, {}
79 |
80 | def _get_reward(self, action):
81 | return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0
82 |
83 |
84 | class IdentityEnvMultiDiscrete(IdentityEnv):
85 | def __init__(self, dim=1, ep_length=100):
86 | """
87 | Identity environment for testing purposes
88 |
89 | :param dim: (int) the size of the dimensions you want to learn
90 | :param ep_length: (int) the length of each episode in timesteps
91 | """
92 | space = MultiDiscrete([dim, dim])
93 | super().__init__(ep_length=ep_length, space=space)
94 |
95 |
96 | class IdentityEnvMultiBinary(IdentityEnv):
97 | def __init__(self, dim=1, ep_length=100):
98 | """
99 | Identity environment for testing purposes
100 |
101 | :param dim: (int) the size of the dimensions you want to learn
102 | :param ep_length: (int) the length of each episode in timesteps
103 | """
104 | space = MultiBinary(dim)
105 | super().__init__(ep_length=ep_length, space=space)
106 |
--------------------------------------------------------------------------------
/stable_baselines/common/input.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete
4 |
5 |
6 | def observation_input(ob_space, batch_size=None, name='Ob', scale=False):
7 | """
8 | Build observation input with encoding depending on the observation space type
9 |
10 | When using Box ob_space, the input will be normalized to [0, 1] using the bounds ob_space.low and ob_space.high.
11 |
12 | :param ob_space: (Gym Space) The observation space
13 | :param batch_size: (int) batch size for input
14 | (default is None, so that resulting input placeholder can take tensors with any batch size)
15 | :param name: (str) tensorflow variable name for input placeholder
16 | :param scale: (bool) whether or not to scale the input
17 | :return: (TensorFlow Tensor, TensorFlow Tensor) input_placeholder, processed_input_tensor
18 | """
19 | if isinstance(ob_space, Discrete):
20 | observation_ph = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
21 | processed_observations = tf.cast(tf.one_hot(observation_ph, ob_space.n), tf.float32)
22 | return observation_ph, processed_observations
23 |
24 | elif isinstance(ob_space, Box):
25 | observation_ph = tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)
26 | processed_observations = tf.cast(observation_ph, tf.float32)
27 | # rescale to [0, 1] if the bounds are defined
28 | if (scale and
29 | not np.any(np.isinf(ob_space.low)) and not np.any(np.isinf(ob_space.high)) and
30 | np.any((ob_space.high - ob_space.low) != 0)):
31 |
32 | # equivalent to processed_observations / 255.0 when bounds are set to [0, 255]
33 | processed_observations = ((processed_observations - ob_space.low) / (ob_space.high - ob_space.low))
34 | return observation_ph, processed_observations
35 |
36 | elif isinstance(ob_space, MultiBinary):
37 | observation_ph = tf.placeholder(shape=(batch_size, ob_space.n), dtype=tf.int32, name=name)
38 | processed_observations = tf.cast(observation_ph, tf.float32)
39 | return observation_ph, processed_observations
40 |
41 | elif isinstance(ob_space, MultiDiscrete):
42 | observation_ph = tf.placeholder(shape=(batch_size, len(ob_space.nvec)), dtype=tf.int32, name=name)
43 | processed_observations = tf.concat([
44 | tf.cast(tf.one_hot(input_split, ob_space.nvec[i]), tf.float32) for i, input_split
45 | in enumerate(tf.split(observation_ph, len(ob_space.nvec), axis=-1))
46 | ], axis=-1)
47 | return observation_ph, processed_observations
48 |
49 | else:
50 | raise NotImplementedError("Error: the model does not support input space of type {}".format(
51 | type(ob_space).__name__))
52 |
--------------------------------------------------------------------------------
/stable_baselines/common/misc_util.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import gym
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 |
8 | def zipsame(*seqs):
9 | """
10 | Performs a zip, but asserts that all zipped sequences are of the same length
11 |
12 | :param seqs: a list of arrays that are zipped together
13 | :return: the zipped arguments
14 | """
15 | length = len(seqs[0])
16 | assert all(len(seq) == length for seq in seqs[1:])
17 | return zip(*seqs)
18 |
19 |
20 | def set_global_seeds(seed):
21 | """
22 | set the seed for python random, tensorflow, numpy and gym spaces
23 |
24 | :param seed: (int) the seed
25 | """
26 | tf.set_random_seed(seed)
27 | np.random.seed(seed)
28 | random.seed(seed)
29 | # prng was removed in latest gym version
30 | if hasattr(gym.spaces, 'prng'):
31 | gym.spaces.prng.seed(seed)
32 |
33 |
34 | def boolean_flag(parser, name, default=False, help_msg=None):
35 | """
36 | Add a boolean flag to argparse parser.
37 |
38 | :param parser: (argparse.Parser) parser to add the flag to
39 | :param name: (str) --<name> will enable the flag, while --no-<name> will disable it
40 | :param default: (bool) default value of the flag
41 | :param help_msg: (str) help string for the flag
42 | """
43 | dest = name.replace('-', '_')
44 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help_msg)
45 | parser.add_argument("--no-" + name, action="store_false", dest=dest)
46 |
47 |
48 | def mpi_rank_or_zero():
49 | """
50 | Return the MPI rank if mpi is installed. Otherwise, return 0.
51 | :return: (int)
52 | """
53 | try:
54 | import mpi4py
55 | return mpi4py.MPI.COMM_WORLD.Get_rank()
56 | except ImportError:
57 | return 0
58 |
59 |
60 | def flatten_lists(listoflists):
61 | """
62 | Flatten a python list of list
63 |
64 | :param listoflists: (list(list))
65 | :return: (list)
66 | """
67 | return [el for list_ in listoflists for el in list_]
68 |
--------------------------------------------------------------------------------
/stable_baselines/common/mpi_moments.py:
--------------------------------------------------------------------------------
1 | from mpi4py import MPI
2 | import numpy as np
3 |
4 | from stable_baselines.common.misc_util import zipsame
5 |
6 |
7 | def mpi_mean(arr, axis=0, comm=None, keepdims=False):
8 | """
9 | calculates the mean of an array, using MPI
10 |
11 | :param arr: (np.ndarray)
12 | :param axis: (int or tuple or list) the axis to run the means over
13 | :param comm: (MPI Communicators) if None, MPI.COMM_WORLD
14 | :param keepdims: (bool) keep the other dimensions intact
15 | :return: (np.ndarray or Number, int) the mean of the array and the total number of elements
16 | """
17 | arr = np.asarray(arr)
18 | assert arr.ndim > 0
19 | if comm is None:
20 | comm = MPI.COMM_WORLD
21 | xsum = arr.sum(axis=axis, keepdims=keepdims)
22 | size = xsum.size
23 | localsum = np.zeros(size + 1, arr.dtype)
24 | localsum[:size] = xsum.ravel()
25 | localsum[size] = arr.shape[axis]
26 | globalsum = np.zeros_like(localsum)
27 | comm.Allreduce(localsum, globalsum, op=MPI.SUM)
28 | return globalsum[:size].reshape(xsum.shape) / globalsum[size], globalsum[size]
29 |
30 |
31 | def mpi_moments(arr, axis=0, comm=None, keepdims=False):
32 | """
33 | calculates the mean and std of an array, using MPI
34 |
35 | :param arr: (np.ndarray)
36 | :param axis: (int or tuple or list) the axis to run the moments over
37 | :param comm: (MPI Communicators) if None, MPI.COMM_WORLD
38 | :param keepdims: (bool) keep the other dimensions intact
39 | :return: (np.ndarray or Number, np.ndarray or Number, int) the mean, the standard deviation and the total count
40 | """
41 | arr = np.asarray(arr)
42 | assert arr.ndim > 0
43 | mean, count = mpi_mean(arr, axis=axis, comm=comm, keepdims=True)
44 | sqdiffs = np.square(arr - mean)
45 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True)
46 | assert count1 == count
47 | std = np.sqrt(meansqdiff)
48 | if not keepdims:
49 | newshape = mean.shape[:axis] + mean.shape[axis + 1:]
50 | mean = mean.reshape(newshape)
51 | std = std.reshape(newshape)
52 | return mean, std, count
53 |
54 |
55 | def _helper_runningmeanstd():
56 | comm = MPI.COMM_WORLD
57 | np.random.seed(0)
58 | for (triple, axis) in [
59 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)), 0),
60 | ((np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)), 0),
61 | ((np.random.randn(2, 3), np.random.randn(2, 4), np.random.randn(2, 4)), 1)]:
62 |
63 | arr = np.concatenate(triple, axis=axis)
64 | ms1 = [arr.mean(axis=axis), arr.std(axis=axis), arr.shape[axis]]
65 |
66 | ms2 = mpi_moments(triple[comm.Get_rank()], axis=axis)
67 |
68 | for (res_1, res_2) in zipsame(ms1, ms2):
69 | print(res_1, res_2)
70 | assert np.allclose(res_1, res_2)
71 | print("ok!")
72 |
--------------------------------------------------------------------------------
/stable_baselines/common/mpi_running_mean_std.py:
--------------------------------------------------------------------------------
1 | import mpi4py
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | import stable_baselines.common.tf_util as tf_util
6 |
7 |
8 | class RunningMeanStd(object):
9 | def __init__(self, epsilon=1e-2, shape=()):
10 | """
11 | calculates the running mean and std of a data stream
12 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
13 |
14 | :param epsilon: (float) helps with arithmetic issues
15 | :param shape: (tuple) the shape of the data stream's output
16 | """
17 | self._sum = tf.get_variable(
18 | dtype=tf.float64,
19 | shape=shape,
20 | initializer=tf.constant_initializer(0.0),
21 | name="runningsum", trainable=False)
22 | self._sumsq = tf.get_variable(
23 | dtype=tf.float64,
24 | shape=shape,
25 | initializer=tf.constant_initializer(epsilon),
26 | name="runningsumsq", trainable=False)
27 | self._count = tf.get_variable(
28 | dtype=tf.float64,
29 | shape=(),
30 | initializer=tf.constant_initializer(epsilon),
31 | name="count", trainable=False)
32 | self.shape = shape
33 |
34 | self.mean = tf.cast(self._sum / self._count, tf.float32)
35 | self.std = tf.sqrt(tf.maximum(tf.cast(self._sumsq / self._count, tf.float32) - tf.square(self.mean),
36 | 1e-2))
37 |
38 | newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
39 | newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
40 | newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
41 | self.incfiltparams = tf_util.function([newsum, newsumsq, newcount], [],
42 | updates=[tf.assign_add(self._sum, newsum),
43 | tf.assign_add(self._sumsq, newsumsq),
44 | tf.assign_add(self._count, newcount)])
45 |
46 | def update(self, data):
47 | """
48 | update the running mean and std
49 |
50 | :param data: (np.ndarray) the data
51 | """
52 | data = data.astype('float64')
53 | data_size = int(np.prod(self.shape))
54 | totalvec = np.zeros(data_size * 2 + 1, 'float64')
55 | addvec = np.concatenate([data.sum(axis=0).ravel(), np.square(data).sum(axis=0).ravel(),
56 | np.array([len(data)], dtype='float64')])
57 | mpi4py.MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=mpi4py.MPI.SUM)
58 | self.incfiltparams(totalvec[0: data_size].reshape(self.shape),
59 | totalvec[data_size: 2 * data_size].reshape(self.shape), totalvec[2 * data_size])
60 |
61 |
62 | @tf_util.in_session
63 | def test_dist():
64 | """
65 | test the running mean std
66 | """
67 | np.random.seed(0)
68 | p_1, p_2, p_3 = (np.random.randn(3, 1), np.random.randn(4, 1), np.random.randn(5, 1))
69 | q_1, q_2, q_3 = (np.random.randn(6, 1), np.random.randn(7, 1), np.random.randn(8, 1))
70 |
71 | comm = mpi4py.MPI.COMM_WORLD
72 | assert comm.Get_size() == 2
73 | if comm.Get_rank() == 0:
74 | x_1, x_2, x_3 = p_1, p_2, p_3
75 | elif comm.Get_rank() == 1:
76 | x_1, x_2, x_3 = q_1, q_2, q_3
77 | else:
78 | assert False
79 |
80 | rms = RunningMeanStd(epsilon=0.0, shape=(1,))
81 | tf_util.initialize()
82 |
83 | rms.update(x_1)
84 | rms.update(x_2)
85 | rms.update(x_3)
86 |
87 | bigvec = np.concatenate([p_1, p_2, p_3, q_1, q_2, q_3])
88 |
89 | def checkallclose(var_1, var_2):
90 | print(var_1, var_2)
91 | return np.allclose(var_1, var_2)
92 |
93 | assert checkallclose(
94 | bigvec.mean(axis=0),
95 | rms.mean.eval(),
96 | )
97 | assert checkallclose(
98 | bigvec.std(axis=0),
99 | rms.std.eval(),
100 | )
101 |
102 |
103 | if __name__ == "__main__":
104 | # Run with: mpirun -np 2 python <path to this file>
105 | test_dist()
106 |
--------------------------------------------------------------------------------
/stable_baselines/common/noise.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 |
5 |
6 | class AdaptiveParamNoiseSpec(object):
7 | """
8 | Implements adaptive parameter noise
9 |
10 | :param initial_stddev: (float) the initial value for the standard deviation of the noise
11 | :param desired_action_stddev: (float) the desired value for the standard deviation of the noise
12 | :param adoption_coefficient: (float) the update coefficient for the standard deviation of the noise
13 | """
14 |
15 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01):
16 | self.initial_stddev = initial_stddev
17 | self.desired_action_stddev = desired_action_stddev
18 | self.adoption_coefficient = adoption_coefficient
19 |
20 | self.current_stddev = initial_stddev
21 |
22 | def adapt(self, distance):
23 | """
24 | update the standard deviation for the parameter noise
25 |
26 | :param distance: (float) the noise distance applied to the parameters
27 | """
28 | if distance > self.desired_action_stddev:
29 | # Decrease stddev.
30 | self.current_stddev /= self.adoption_coefficient
31 | else:
32 | # Increase stddev.
33 | self.current_stddev *= self.adoption_coefficient
34 |
35 | def get_stats(self):
36 | """
37 | return the standard deviation for the parameter noise
38 |
39 | :return: (dict) the stats of the noise
40 | """
41 | return {'param_noise_stddev': self.current_stddev}
42 |
43 | def __repr__(self):
44 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})'
45 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient)
46 |
47 |
48 | class ActionNoise(ABC):
49 | """
50 | The action noise base class
51 | """
52 |
53 | def __init__(self):
54 | super(ActionNoise, self).__init__()
55 |
56 | def reset(self) -> None:
57 | """
58 | call end of episode reset for the noise
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def __call__(self) -> np.ndarray:
64 | raise NotImplementedError()
65 |
66 |
67 | class NormalActionNoise(ActionNoise):
68 | """
69 | A Gaussian action noise
70 |
71 | :param mean: (float) the mean value of the noise
72 | :param sigma: (float) the scale of the noise (std here)
73 | """
74 |
75 | def __init__(self, mean, sigma):
76 | super().__init__()
77 | self._mu = mean
78 | self._sigma = sigma
79 |
80 | def __call__(self) -> np.ndarray:
81 | return np.random.normal(self._mu, self._sigma)
82 |
83 | def __repr__(self) -> str:
84 | return 'NormalActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma)
85 |
86 |
87 | class OrnsteinUhlenbeckActionNoise(ActionNoise):
88 | """
89 | An Ornstein-Uhlenbeck action noise, designed to approximate Brownian motion with friction.
90 |
91 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
92 |
93 | :param mean: (float) the mean of the noise
94 | :param sigma: (float) the scale of the noise
95 | :param theta: (float) the rate of mean reversion
96 | :param dt: (float) the timestep for the noise
97 | :param initial_noise: ([float]) the initial value for the noise output, (if None: 0)
98 | """
99 |
100 | def __init__(self, mean, sigma, theta=.15, dt=1e-2, initial_noise=None):
101 | super().__init__()
102 | self._theta = theta
103 | self._mu = mean
104 | self._sigma = sigma
105 | self._dt = dt
106 | self.initial_noise = initial_noise
107 | self.noise_prev = None
108 | self.reset()
109 |
110 | def __call__(self) -> np.ndarray:
111 | noise = self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt + \
112 | self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
113 | self.noise_prev = noise
114 | return noise
115 |
116 | def reset(self) -> None:
117 | """
118 | reset the Ornstein Uhlenbeck noise, to the initial position
119 | """
120 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
121 |
122 | def __repr__(self) -> str:
123 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self._mu, self._sigma)
124 |
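A minimal usage sketch, not part of this file: the noise classes above are typically passed to the off-policy algorithms (e.g. DDPG or TD3) through their action_noise argument; the Pendulum-v0 environment and the 0.1 sigma are only illustrative.

    import numpy as np
    from stable_baselines import TD3
    from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    n_actions = 1  # Pendulum-v0 has a single action dimension
    # Gaussian exploration noise, one entry per action dimension
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    # Alternative: temporally correlated noise
    # action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = TD3('MlpPolicy', 'Pendulum-v0', action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10000)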
--------------------------------------------------------------------------------
/stable_baselines/common/running_mean_std.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RunningMeanStd(object):
5 | def __init__(self, epsilon=1e-4, shape=()):
6 | """
7 | Calculates the running mean and std of a data stream
8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
9 |
10 | :param epsilon: (float) helps with arithmetic issues
11 | :param shape: (tuple) the shape of the data stream's output
12 | """
13 | self.mean = np.zeros(shape, 'float64')
14 | self.var = np.ones(shape, 'float64')
15 | self.count = epsilon
16 |
17 | def update(self, arr):
18 | batch_mean = np.mean(arr, axis=0)
19 | batch_var = np.var(arr, axis=0)
20 | batch_count = arr.shape[0]
21 | self.update_from_moments(batch_mean, batch_var, batch_count)
22 |
23 | def update_from_moments(self, batch_mean, batch_var, batch_count):
24 | delta = batch_mean - self.mean
25 | tot_count = self.count + batch_count
26 |
27 | new_mean = self.mean + delta * batch_count / tot_count
28 | m_a = self.var * self.count
29 | m_b = batch_var * batch_count
30 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
31 | new_var = m_2 / (self.count + batch_count)
32 |
33 | new_count = batch_count + self.count
34 |
35 | self.mean = new_mean
36 | self.var = new_var
37 | self.count = new_count
38 |
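A small sanity-check sketch, not part of this file, showing that the parallel-update formula above tracks the statistics of all data seen so far (the epsilon pseudo-count introduces a tiny bias, hence the tolerance):

    import numpy as np
    from stable_baselines.common.running_mean_std import RunningMeanStd

    np.random.seed(0)
    batches = [np.random.randn(64, 3) for _ in range(5)]

    rms = RunningMeanStd(shape=(3,))
    for batch in batches:
        rms.update(batch)

    pooled = np.concatenate(batches, axis=0)
    assert np.allclose(rms.mean, pooled.mean(axis=0), atol=1e-2)
    assert np.allclose(rms.var, pooled.var(axis=0), atol=1e-2)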
--------------------------------------------------------------------------------
/stable_baselines/common/tile_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def tile_images(img_nhwc):
5 | """
6 | Tile N images into one big PxQ image.
7 | (P, Q) are chosen to be as close as possible, and if N
8 | is a perfect square, then P = Q.
9 |
10 | :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc
11 | n = batch index, h = height, w = width, c = channel
12 | :return: (numpy float) img_HWc, ndim=3
13 | """
14 | img_nhwc = np.asarray(img_nhwc)
15 | n_images, height, width, n_channels = img_nhwc.shape
16 | # new_height was named H before
17 | new_height = int(np.ceil(np.sqrt(n_images)))
18 | # new_width was named W before
19 | new_width = int(np.ceil(float(n_images) / new_height))
20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
21 | # img_HWhwc
22 | out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
23 | # img_HhWwc
24 | out_image = out_image.transpose(0, 2, 1, 3, 4)
25 | # img_Hh_Ww_c
26 | out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
27 | return out_image
28 |
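A quick shape sketch, not part of this file: five 8x8 RGB frames are padded with a blank frame and arranged into a 3x2 grid, so the tiled image is 24x16x3.

    import numpy as np
    from stable_baselines.common.tile_images import tile_images

    images = np.random.randint(0, 255, size=(5, 8, 8, 3), dtype=np.uint8)
    big_image = tile_images(images)
    # ceil(sqrt(5)) = 3 rows, ceil(5 / 3) = 2 columns -> (3 * 8, 2 * 8, 3)
    assert big_image.shape == (24, 16, 3)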
--------------------------------------------------------------------------------
/stable_baselines/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from copy import deepcopy
3 |
4 | import gym
5 |
6 | # flake8: noqa F401
7 | from stable_baselines.common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \
8 | CloudpickleWrapper
9 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
10 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
11 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
12 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize
13 | from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
14 | from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan
15 |
16 |
17 | def unwrap_vec_normalize(env: Union[gym.Env, VecEnv]) -> Union[VecNormalize, None]:
18 | """
19 | :param env: (Union[gym.Env, VecEnv])
20 | :return: (VecNormalize)
21 | """
22 | env_tmp = env
23 | while isinstance(env_tmp, VecEnvWrapper):
24 | if isinstance(env_tmp, VecNormalize):
25 | return env_tmp
26 | env_tmp = env_tmp.venv
27 | return None
28 |
29 |
30 | # Define here to avoid circular import
31 | def sync_envs_normalization(env: Union[gym.Env, VecEnv], eval_env: Union[gym.Env, VecEnv]) -> None:
32 | """
33 | Sync eval and train environments when using VecNormalize
34 |
35 | :param env: (Union[gym.Env, VecEnv]) the training environment
36 | :param eval_env: (Union[gym.Env, VecEnv]) the evaluation environment
37 | """
38 | env_tmp, eval_env_tmp = env, eval_env
39 | # Special case for the _UnvecWrapper
40 | # Avoid circular import
41 | from stable_baselines.common.base_class import _UnvecWrapper
42 | if isinstance(env_tmp, _UnvecWrapper):
43 | return
44 | while isinstance(env_tmp, VecEnvWrapper):
45 | if isinstance(env_tmp, VecNormalize):
46 | # sync reward and observation scaling
47 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
48 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
49 | env_tmp = env_tmp.venv
50 | # Make pytype happy, in theory env and eval_env have the same type
51 | assert isinstance(eval_env_tmp, VecEnvWrapper), "the second env differs from the first env"
52 | eval_env_tmp = eval_env_tmp.venv
53 |
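A minimal sketch, not from this file, of the intended workflow for these helpers: both the training and evaluation envs are wrapped in VecNormalize, and the running statistics are copied from one to the other before evaluation (the Pendulum-v0 env and the training=False flag are illustrative).

    import gym
    from stable_baselines.common.vec_env import (DummyVecEnv, VecNormalize,
                                                 unwrap_vec_normalize, sync_envs_normalization)

    env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]))
    eval_env = VecNormalize(DummyVecEnv([lambda: gym.make('Pendulum-v0')]), training=False)

    # Retrieve the VecNormalize layer (None if the env is not normalized)
    assert unwrap_vec_normalize(env) is not None

    # Copy observation/return statistics from the training env to the evaluation env
    sync_envs_normalization(env, eval_env)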
--------------------------------------------------------------------------------
/stable_baselines/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 |
5 | from collections import OrderedDict
6 |
7 | import gym
8 | import numpy as np
9 |
10 |
11 | def copy_obs_dict(obs):
12 | """
13 | Deep-copy a dict of numpy arrays.
14 |
15 | :param obs: (OrderedDict) a dict of numpy arrays.
16 | :return (OrderedDict) a dict of copied numpy arrays.
17 | """
18 | assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs))
19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
20 |
21 |
22 | def dict_to_obs(space, obs_dict):
23 | """
24 | Convert an internal representation raw_obs into the appropriate type
25 | specified by space.
26 |
27 | :param space: (gym.spaces.Space) an observation space.
28 | :param obs_dict: (OrderedDict) a dict of numpy arrays.
29 | :return (ndarray, tuple or dict): returns an observation
30 | of the same type as space. If space is Dict, function is identity;
31 | if space is Tuple, converts dict to Tuple; otherwise, space is
32 | unstructured and returns the value raw_obs[None].
33 | """
34 | if isinstance(space, gym.spaces.Dict):
35 | return obs_dict
36 | elif isinstance(space, gym.spaces.Tuple):
37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
38 | return tuple((obs_dict[i] for i in range(len(space.spaces))))
39 | else:
40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
41 | return obs_dict[None]
42 |
43 |
44 | def obs_space_info(obs_space):
45 | """
46 | Get dict-structured information about a gym.Space.
47 |
48 | Dict spaces are represented directly by their dict of subspaces.
49 | Tuple spaces are converted into a dict with keys indexing into the tuple.
50 | Unstructured spaces are represented by {None: obs_space}.
51 |
52 | :param obs_space: (gym.spaces.Space) an observation space
53 | :return (tuple) A tuple (keys, shapes, dtypes):
54 | keys: a list of dict keys.
55 | shapes: a dict mapping keys to shapes.
56 | dtypes: a dict mapping keys to dtypes.
57 | """
58 | if isinstance(obs_space, gym.spaces.Dict):
59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
60 | subspaces = obs_space.spaces
61 | elif isinstance(obs_space, gym.spaces.Tuple):
62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
63 | else:
64 | assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space))
65 | subspaces = {None: obs_space}
66 | keys = []
67 | shapes = {}
68 | dtypes = {}
69 | for key, box in subspaces.items():
70 | keys.append(key)
71 | shapes[key] = box.shape
72 | dtypes[key] = box.dtype
73 | return keys, shapes, dtypes
74 |
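A short sketch, not part of this file, of what obs_space_info returns for an unstructured space and a Tuple space (the Dict case simply mirrors the dict keys):

    import numpy as np
    from gym import spaces
    from stable_baselines.common.vec_env.util import obs_space_info

    # Unstructured space: a single entry keyed by None
    keys, shapes, dtypes = obs_space_info(spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32))
    assert keys == [None] and shapes[None] == (4,)

    # Tuple space: keys are the tuple indices
    keys, shapes, _ = obs_space_info(spaces.Tuple((spaces.Discrete(3),
                                                   spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32))))
    assert keys == [0, 1] and shapes[1] == (2,)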
--------------------------------------------------------------------------------
/stable_baselines/common/vec_env/vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
6 |
7 |
8 | class VecCheckNan(VecEnvWrapper):
9 | """
10 | NaN and inf checking wrapper for vectorized environments; raises a warning by default,
11 | letting you know where the NaN or inf originated.
12 |
13 | :param venv: (VecEnv) the vectorized environment to wrap
14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning
15 | :param warn_once: (bool) Whether or not to only warn once.
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well
17 | """
18 |
19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True):
20 | VecEnvWrapper.__init__(self, venv)
21 | self.raise_exception = raise_exception
22 | self.warn_once = warn_once
23 | self.check_inf = check_inf
24 | self._actions = None
25 | self._observations = None
26 | self._user_warned = False
27 |
28 | def step_async(self, actions):
29 | self._check_val(async_step=True, actions=actions)
30 |
31 | self._actions = actions
32 | self.venv.step_async(actions)
33 |
34 | def step_wait(self):
35 | observations, rewards, news, infos = self.venv.step_wait()
36 |
37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
38 |
39 | self._observations = observations
40 | return observations, rewards, news, infos
41 |
42 | def reset(self):
43 | observations = self.venv.reset()
44 | self._actions = None
45 |
46 | self._check_val(async_step=False, observations=observations)
47 |
48 | self._observations = observations
49 | return observations
50 |
51 | def _check_val(self, *, async_step, **kwargs):
52 | # if warn and warn once and have warned once: then stop checking
53 | if not self.raise_exception and self.warn_once and self._user_warned:
54 | return
55 |
56 | found = []
57 | for name, val in kwargs.items():
58 | has_nan = np.any(np.isnan(val))
59 | has_inf = self.check_inf and np.any(np.isinf(val))
60 | if has_inf:
61 | found.append((name, "inf"))
62 | if has_nan:
63 | found.append((name, "nan"))
64 |
65 | if found:
66 | self._user_warned = True
67 | msg = ""
68 | for i, (name, type_val) in enumerate(found):
69 | msg += "found {} in {}".format(type_val, name)
70 | if i != len(found) - 1:
71 | msg += ", "
72 |
73 | msg += ".\r\nOriginated from the "
74 |
75 | if not async_step:
76 | if self._actions is None:
77 | msg += "environment observation (at reset)"
78 | else:
79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions)
80 | else:
81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations)
82 |
83 | if self.raise_exception:
84 | raise ValueError(msg)
85 | else:
86 | warnings.warn(msg, UserWarning)
87 |
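A minimal sketch, not from this file, of how the wrapper is used; with raise_exception=True, an invalid action fed to step() raises a ValueError that names the offending value and whether it came from the model or the environment (the NaN action below is deliberately invalid):

    import gym
    import numpy as np
    from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan

    env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
    env = VecCheckNan(env, raise_exception=True)

    obs = env.reset()
    try:
        env.step(np.array([[np.nan]]))  # NaN action triggers the check in step_async
    except ValueError as exc:
        print(exc)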
--------------------------------------------------------------------------------
/stable_baselines/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
7 |
8 |
9 | class VecFrameStack(VecEnvWrapper):
10 | """
11 | Frame stacking wrapper for vectorized environment
12 |
13 | :param venv: (VecEnv) the vectorized environment to wrap
14 | :param n_stack: (int) Number of frames to stack
15 | """
16 |
17 | def __init__(self, venv, n_stack):
18 | self.venv = venv
19 | self.n_stack = n_stack
20 | wrapped_obs_space = venv.observation_space
21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
26 |
27 | def step_wait(self):
28 | observations, rewards, dones, infos = self.venv.step_wait()
29 | last_ax_size = observations.shape[-1]
30 | self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
31 | for i, done in enumerate(dones):
32 | if done:
33 | if 'terminal_observation' in infos[i]:
34 | old_terminal = infos[i]['terminal_observation']
35 | new_terminal = np.concatenate(
36 | (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
37 | infos[i]['terminal_observation'] = new_terminal
38 | else:
39 | warnings.warn(
40 | "VecFrameStack wrapping a VecEnv without terminal_observation info")
41 | self.stackedobs[i] = 0
42 | self.stackedobs[..., -observations.shape[-1]:] = observations
43 | return self.stackedobs, rewards, dones, infos
44 |
45 | def reset(self):
46 | """
47 | Reset all environments
48 | """
49 | obs = self.venv.reset()
50 | self.stackedobs[...] = 0
51 | self.stackedobs[..., -obs.shape[-1]:] = obs
52 | return self.stackedobs
53 |
54 | def close(self):
55 | self.venv.close()
56 |
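A typical usage sketch, not from this file: stacking the last 4 Atari frames along the channel axis so the policy can infer motion; make_atari_env and the Breakout id are just one common setup.

    from stable_baselines import PPO2
    from stable_baselines.common.cmd_util import make_atari_env
    from stable_baselines.common.vec_env import VecFrameStack

    env = make_atari_env('BreakoutNoFrameskip-v4', num_env=4, seed=0)
    env = VecFrameStack(env, n_stack=4)  # channel dimension grows from 1 to 4

    model = PPO2('CnnPolicy', env, verbose=1)
    model.learn(total_timesteps=10000)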
--------------------------------------------------------------------------------
/stable_baselines/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.wrappers.monitoring import video_recorder
4 |
5 | from stable_baselines import logger
6 | from stable_baselines.common.vec_env.base_vec_env import VecEnvWrapper
7 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
9 | from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
10 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize
11 |
12 |
13 | class VecVideoRecorder(VecEnvWrapper):
14 | """
15 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
16 | It requires ffmpeg or avconv to be installed on the machine.
17 |
18 | :param venv: (VecEnv or VecEnvWrapper)
19 | :param video_folder: (str) Where to save videos
20 | :param record_video_trigger: (func) Function that defines when to start recording.
21 | The function takes the current step number
22 | and returns whether recording should start.
23 | :param video_length: (int) Length of recorded videos
24 | :param name_prefix: (str) Prefix to the video name
25 | """
26 |
27 | def __init__(self, venv, video_folder, record_video_trigger,
28 | video_length=200, name_prefix='rl-video'):
29 |
30 | VecEnvWrapper.__init__(self, venv)
31 |
32 | self.env = venv
33 | # Temp variable to retrieve metadata
34 | temp_env = venv
35 |
36 | # Unwrap to retrieve metadata dict
37 | # that will be used by gym recorder
38 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack):
39 | temp_env = temp_env.venv
40 |
41 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
42 | metadata = temp_env.get_attr('metadata')[0]
43 | else:
44 | metadata = temp_env.metadata
45 |
46 | self.env.metadata = metadata
47 |
48 | self.record_video_trigger = record_video_trigger
49 | self.video_recorder = None
50 |
51 | self.video_folder = os.path.abspath(video_folder)
52 | # Create output folder if needed
53 | os.makedirs(self.video_folder, exist_ok=True)
54 |
55 | self.name_prefix = name_prefix
56 | self.step_id = 0
57 | self.video_length = video_length
58 |
59 | self.recording = False
60 | self.recorded_frames = 0
61 |
62 | def reset(self):
63 | obs = self.venv.reset()
64 | self.start_video_recorder()
65 | return obs
66 |
67 | def start_video_recorder(self):
68 | self.close_video_recorder()
69 |
70 | video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
71 | self.step_id + self.video_length)
72 | base_path = os.path.join(self.video_folder, video_name)
73 | self.video_recorder = video_recorder.VideoRecorder(
74 | env=self.env,
75 | base_path=base_path,
76 | metadata={'step_id': self.step_id}
77 | )
78 |
79 | self.video_recorder.capture_frame()
80 | self.recorded_frames = 1
81 | self.recording = True
82 |
83 | def _video_enabled(self):
84 | return self.record_video_trigger(self.step_id)
85 |
86 | def step_wait(self):
87 | obs, rews, dones, infos = self.venv.step_wait()
88 |
89 | self.step_id += 1
90 | if self.recording:
91 | self.video_recorder.capture_frame()
92 | self.recorded_frames += 1
93 | if self.recorded_frames > self.video_length:
94 | logger.info("Saving video to ", self.video_recorder.path)
95 | self.close_video_recorder()
96 | elif self._video_enabled():
97 | self.start_video_recorder()
98 |
99 | return obs, rews, dones, infos
100 |
101 | def close_video_recorder(self):
102 | if self.recording:
103 | self.video_recorder.close()
104 | self.recording = False
105 | self.recorded_frames = 1
106 |
107 | def close(self):
108 | VecEnvWrapper.close(self)
109 | self.close_video_recorder()
110 |
111 | def __del__(self):
112 | self.close()
113 |
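A minimal recording sketch, not part of this file: record a single 200-step mp4 starting at the first step (the 'videos/' folder and the trigger are illustrative; ffmpeg must be available).

    import gym
    from stable_baselines.common.vec_env import DummyVecEnv, VecVideoRecorder

    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    env = VecVideoRecorder(env, video_folder='videos/',
                           record_video_trigger=lambda step: step == 0,
                           video_length=200)

    obs = env.reset()
    for _ in range(201):
        obs, _, _, _ = env.step([env.action_space.sample()])
    env.close()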
--------------------------------------------------------------------------------
/stable_baselines/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.common.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
2 | from stable_baselines.ddpg.ddpg import DDPG
3 | from stable_baselines.ddpg.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy
4 |
--------------------------------------------------------------------------------
/stable_baselines/ddpg/noise.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.common.noise import NormalActionNoise, AdaptiveParamNoiseSpec, OrnsteinUhlenbeckActionNoise # pylint: disable=unused-import
2 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.deepq.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy
2 | from stable_baselines.deepq.build_graph import build_act, build_train # noqa
3 | from stable_baselines.deepq.dqn import DQN
4 | from stable_baselines.common.buffers import ReplayBuffer, PrioritizedReplayBuffer # noqa
5 |
6 |
7 | def wrap_atari_dqn(env):
8 | """
9 | wrap the environment in atari wrappers for DQN
10 |
11 | :param env: (Gym Environment) the environment
12 | :return: (Gym Environment) the wrapped environment
13 | """
14 | from stable_baselines.common.atari_wrappers import wrap_deepmind
15 | return wrap_deepmind(env, frame_stack=True, scale=False)
16 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/deepq/experiments/__init__.py
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/enjoy_cartpole.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | from stable_baselines.deepq import DQN
6 |
7 |
8 | def main(args):
9 | """
10 | Run a trained model for the cartpole problem
11 |
12 | :param args: (ArgumentParser) the input arguments
13 | """
14 | env = gym.make("CartPole-v0")
15 | model = DQN.load("cartpole_model.zip", env)
16 |
17 | while True:
18 | obs, done = env.reset(), False
19 | episode_rew = 0
20 | while not done:
21 | if not args.no_render:
22 | env.render()
23 | action, _ = model.predict(obs)
24 | obs, rew, done, _ = env.step(action)
25 | episode_rew += rew
26 | print("Episode reward", episode_rew)
27 | # No render is only used for automatic testing
28 | if args.no_render:
29 | break
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser(description="Enjoy trained DQN on cartpole")
34 | parser.add_argument('--no-render', default=False, action="store_true", help="Disable rendering")
35 | args = parser.parse_args()
36 | main(args)
37 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/enjoy_mountaincar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from stable_baselines.deepq import DQN
7 |
8 |
9 | def main(args):
10 | """
11 | Run a trained model for the mountain car problem
12 |
13 | :param args: (ArgumentParser) the input arguments
14 | """
15 | env = gym.make("MountainCar-v0")
16 | model = DQN.load("mountaincar_model.zip", env)
17 |
18 | while True:
19 | obs, done = env.reset(), False
20 | episode_rew = 0
21 | while not done:
22 | if not args.no_render:
23 | env.render()
24 | # Epsilon-greedy
25 | if np.random.random() < 0.02:
26 | action = env.action_space.sample()
27 | else:
28 | action, _ = model.predict(obs, deterministic=True)
29 | obs, rew, done, _ = env.step(action)
30 | episode_rew += rew
31 | print("Episode reward", episode_rew)
32 | # No render is only used for automatic testing
33 | if args.no_render:
34 | break
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = argparse.ArgumentParser(description="Enjoy trained DQN on MountainCar")
39 | parser.add_argument('--no-render', default=False, action="store_true", help="Disable rendering")
40 | args = parser.parse_args()
41 | main(args)
42 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/run_atari.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 |
4 | from stable_baselines import bench, logger
5 | from stable_baselines.common import set_global_seeds
6 | from stable_baselines.common.atari_wrappers import make_atari
7 | from stable_baselines.deepq import DQN, wrap_atari_dqn, CnnPolicy
8 |
9 |
10 | def main():
11 | """
12 | Run the atari test
13 | """
14 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
15 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
16 | parser.add_argument('--seed', help='RNG seed', type=int, default=0)
17 | parser.add_argument('--prioritized', type=int, default=1)
18 | parser.add_argument('--dueling', type=int, default=1)
19 | parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
20 | parser.add_argument('--num-timesteps', type=int, default=int(10e6))
21 |
22 | args = parser.parse_args()
23 | logger.configure()
24 | set_global_seeds(args.seed)
25 | env = make_atari(args.env)
26 | env = bench.Monitor(env, logger.get_dir())
27 | env = wrap_atari_dqn(env)
28 | policy = partial(CnnPolicy, dueling=args.dueling == 1)
29 |
30 | model = DQN(
31 | env=env,
32 | policy=policy,
33 | learning_rate=1e-4,
34 | buffer_size=10000,
35 | exploration_fraction=0.1,
36 | exploration_final_eps=0.01,
37 | train_freq=4,
38 | learning_starts=10000,
39 | target_network_update_freq=1000,
40 | gamma=0.99,
41 | prioritized_replay=bool(args.prioritized),
42 | prioritized_replay_alpha=args.prioritized_replay_alpha,
43 | )
44 | model.learn(total_timesteps=args.num_timesteps)
45 |
46 | env.close()
47 |
48 |
49 | if __name__ == '__main__':
50 | main()
51 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/train_cartpole.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from stable_baselines.deepq import DQN, MlpPolicy
7 |
8 |
9 | def callback(lcl, _glb):
10 | """
11 | The callback function for logging and saving
12 |
13 | :param lcl: (dict) the local variables
14 | :param _glb: (dict) the global variables
15 | :return: (bool) whether training should continue (False once solved)
16 | """
17 | # stop training if reward exceeds 199
18 | if len(lcl['episode_rewards'][-101:-1]) == 0:
19 | mean_100ep_reward = -np.inf
20 | else:
21 | mean_100ep_reward = round(float(np.mean(lcl['episode_rewards'][-101:-1])), 1)
22 | is_solved = lcl['self'].num_timesteps > 100 and mean_100ep_reward >= 199
23 | return not is_solved
24 |
25 |
26 | def main(args):
27 | """
28 | Train and save the DQN model, for the cartpole problem
29 |
30 | :param args: (ArgumentParser) the input arguments
31 | """
32 | env = gym.make("CartPole-v0")
33 | model = DQN(
34 | env=env,
35 | policy=MlpPolicy,
36 | learning_rate=1e-3,
37 | buffer_size=50000,
38 | exploration_fraction=0.1,
39 | exploration_final_eps=0.02,
40 | )
41 | model.learn(total_timesteps=args.max_timesteps, callback=callback)
42 |
43 | print("Saving model to cartpole_model.zip")
44 | model.save("cartpole_model.zip")
45 |
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser(description="Train DQN on cartpole")
49 | parser.add_argument('--max-timesteps', default=100000, type=int, help="Maximum number of timesteps")
50 | args = parser.parse_args()
51 | main(args)
52 |
--------------------------------------------------------------------------------
/stable_baselines/deepq/experiments/train_mountaincar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | from stable_baselines.deepq import DQN
6 |
7 |
8 | def main(args):
9 | """
10 | Train and save the DQN model, for the mountain car problem
11 |
12 | :param args: (ArgumentParser) the input arguments
13 | """
14 | env = gym.make("MountainCar-v0")
15 |
16 | # using layer norm policy here is important for parameter space noise!
17 | model = DQN(
18 | policy="LnMlpPolicy",
19 | env=env,
20 | learning_rate=1e-3,
21 | buffer_size=50000,
22 | exploration_fraction=0.1,
23 | exploration_final_eps=0.1,
24 | param_noise=True,
25 | policy_kwargs=dict(layers=[64])
26 | )
27 | model.learn(total_timesteps=args.max_timesteps)
28 |
29 | print("Saving model to mountaincar_model.zip")
30 | model.save("mountaincar_model")
31 |
32 |
33 | if __name__ == '__main__':
34 | parser = argparse.ArgumentParser(description="Train DQN on MountainCar")
35 | parser.add_argument('--max-timesteps', default=100000, type=int, help="Maximum number of timesteps")
36 | args = parser.parse_args()
37 | main(args)
38 |
--------------------------------------------------------------------------------
/stable_baselines/gail/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.gail.model import GAIL
2 | from stable_baselines.gail.dataset.dataset import ExpertDataset, DataLoader
3 | from stable_baselines.gail.dataset.record_expert import generate_expert_traj
4 |
--------------------------------------------------------------------------------
/stable_baselines/gail/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/__init__.py
--------------------------------------------------------------------------------
/stable_baselines/gail/dataset/expert_cartpole.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/expert_cartpole.npz
--------------------------------------------------------------------------------
/stable_baselines/gail/dataset/expert_pendulum.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/gail/dataset/expert_pendulum.npz
--------------------------------------------------------------------------------
/stable_baselines/gail/model.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.trpo_mpi import TRPO
2 |
3 |
4 | class GAIL(TRPO):
5 | """
6 | Generative Adversarial Imitation Learning (GAIL)
7 |
8 | .. warning::
9 |
10 | Images are not yet handled properly by the current implementation
11 |
12 |
13 | :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
14 | :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
15 | :param expert_dataset: (ExpertDataset) the dataset manager
16 | :param gamma: (float) the discount value
17 | :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
18 | :param max_kl: (float) the Kullback-Leibler loss threshold
19 | :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
20 | :param lam: (float) GAE factor
21 | :param entcoeff: (float) the weight for the entropy loss
22 | :param cg_damping: (float) the conjugate gradient damping factor
23 | :param vf_stepsize: (float) the value function stepsize
24 | :param vf_iters: (int) the number of iterations for value function learning
25 | :param hidden_size: ([int]) the hidden dimension for the MLP
26 | :param g_step: (int) number of steps to train policy in each epoch
27 | :param d_step: (int) number of steps to train discriminator in each epoch
28 | :param d_stepsize: (float) the reward giver stepsize
29 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
30 | :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
31 | :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
32 | WARNING: this logging can take a lot of space quickly
33 | """
34 |
35 | def __init__(self, policy, env, expert_dataset=None,
36 | hidden_size_adversary=100, adversary_entcoeff=1e-3,
37 | g_step=3, d_step=1, d_stepsize=3e-4, verbose=0,
38 | _init_setup_model=True, **kwargs):
39 | super().__init__(policy, env, verbose=verbose, _init_setup_model=False, **kwargs)
40 | self.using_gail = True
41 | self.expert_dataset = expert_dataset
42 | self.g_step = g_step
43 | self.d_step = d_step
44 | self.d_stepsize = d_stepsize
45 | self.hidden_size_adversary = hidden_size_adversary
46 | self.adversary_entcoeff = adversary_entcoeff
47 |
48 | if _init_setup_model:
49 | self.setup_model()
50 |
51 | def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="GAIL",
52 | reset_num_timesteps=True):
53 | assert self.expert_dataset is not None, "You must pass an expert dataset to GAIL for training"
54 | return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
55 |
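A minimal training sketch, not from this file, following the documented GAIL workflow: load expert trajectories into an ExpertDataset and pass it as the expert_dataset argument (the .npz path points at the expert data shipped under gail/dataset/ and is illustrative).

    from stable_baselines import GAIL
    from stable_baselines.gail import ExpertDataset

    dataset = ExpertDataset(expert_path='stable_baselines/gail/dataset/expert_pendulum.npz',
                            traj_limitation=10, verbose=1)

    model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
    model.learn(total_timesteps=1000)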
--------------------------------------------------------------------------------
/stable_baselines/her/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.her.her import HER
2 | from stable_baselines.her.replay_buffer import GoalSelectionStrategy, HindsightExperienceReplayWrapper
3 | from stable_baselines.her.utils import HERGoalEnvWrapper
4 |
--------------------------------------------------------------------------------
/stable_baselines/her/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | # Important: gym mixes up ordered and unordered keys
7 | # and the Dict space may return a different order of keys than the actual one
8 | KEY_ORDER = ['observation', 'achieved_goal', 'desired_goal']
9 |
10 |
11 | class HERGoalEnvWrapper(object):
12 | """
13 | A wrapper that allows using a dict observation space (coming from GoalEnv) with
14 | the RL algorithms.
15 | It assumes that all the spaces of the dict space are of the same type.
16 |
17 | :param env: (gym.GoalEnv)
18 | """
19 |
20 | def __init__(self, env):
21 | super(HERGoalEnvWrapper, self).__init__()
22 | self.env = env
23 | self.metadata = self.env.metadata
24 | self.action_space = env.action_space
25 | self.spaces = list(env.observation_space.spaces.values())
26 | # Check that all spaces are of the same type
27 | # (current limitation of the wrapper)
28 | space_types = [type(env.observation_space.spaces[key]) for key in KEY_ORDER]
29 | assert len(set(space_types)) == 1, "The spaces for goal and observation"\
30 | " must be of the same type"
31 |
32 | if isinstance(self.spaces[0], spaces.Discrete):
33 | self.obs_dim = 1
34 | self.goal_dim = 1
35 | else:
36 | goal_space_shape = env.observation_space.spaces['achieved_goal'].shape
37 | self.obs_dim = env.observation_space.spaces['observation'].shape[0]
38 | self.goal_dim = goal_space_shape[0]
39 |
40 | if len(goal_space_shape) == 2:
41 | assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet"
42 | else:
43 | assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet"
44 |
45 | if isinstance(self.spaces[0], spaces.MultiBinary):
46 | total_dim = self.obs_dim + 2 * self.goal_dim
47 | self.observation_space = spaces.MultiBinary(total_dim)
48 |
49 | elif isinstance(self.spaces[0], spaces.Box):
50 | lows = np.concatenate([space.low for space in self.spaces])
51 | highs = np.concatenate([space.high for space in self.spaces])
52 | self.observation_space = spaces.Box(lows, highs, dtype=np.float32)
53 |
54 | elif isinstance(self.spaces[0], spaces.Discrete):
55 | dimensions = [env.observation_space.spaces[key].n for key in KEY_ORDER]
56 | self.observation_space = spaces.MultiDiscrete(dimensions)
57 |
58 | else:
59 | raise NotImplementedError("{} space is not supported".format(type(self.spaces[0])))
60 |
61 | def convert_dict_to_obs(self, obs_dict):
62 | """
63 | :param obs_dict: (dict)
64 | :return: (np.ndarray)
65 | """
66 | # Note: achieved goal is not removed from the observation
67 | # this keeps the transformation reversible
68 | if isinstance(self.observation_space, spaces.MultiDiscrete):
69 | # Special case for multidiscrete
70 | return np.concatenate([[int(obs_dict[key])] for key in KEY_ORDER])
71 | return np.concatenate([obs_dict[key] for key in KEY_ORDER])
72 |
73 | def convert_obs_to_dict(self, observations):
74 | """
75 | Inverse operation of convert_dict_to_obs
76 |
77 | :param observations: (np.ndarray)
78 | :return: (OrderedDict)
79 | """
80 | return OrderedDict([
81 | ('observation', observations[:self.obs_dim]),
82 | ('achieved_goal', observations[self.obs_dim:self.obs_dim + self.goal_dim]),
83 | ('desired_goal', observations[self.obs_dim + self.goal_dim:]),
84 | ])
85 |
86 | def step(self, action):
87 | obs, reward, done, info = self.env.step(action)
88 | return self.convert_dict_to_obs(obs), reward, done, info
89 |
90 | def seed(self, seed=None):
91 | return self.env.seed(seed)
92 |
93 | def reset(self):
94 | return self.convert_dict_to_obs(self.env.reset())
95 |
96 | def compute_reward(self, achieved_goal, desired_goal, info):
97 | return self.env.compute_reward(achieved_goal, desired_goal, info)
98 |
99 | def render(self, mode='human'):
100 | return self.env.render(mode)
101 |
102 | def close(self):
103 | return self.env.close()
104 |
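A small sketch, not part of this file, of the flattening that convert_dict_to_obs / convert_obs_to_dict perform for Box spaces: the three keys are concatenated in KEY_ORDER and recovered by slicing at obs_dim and obs_dim + goal_dim (the dimensions below are made up for illustration).

    from collections import OrderedDict
    import numpy as np

    obs_dim, goal_dim = 4, 2
    obs_dict = OrderedDict([
        ('observation', np.arange(4.0)),
        ('achieved_goal', np.array([10.0, 11.0])),
        ('desired_goal', np.array([20.0, 21.0])),
    ])
    flat = np.concatenate([obs_dict[key] for key in ['observation', 'achieved_goal', 'desired_goal']])
    assert flat.shape == (obs_dim + 2 * goal_dim,)

    recovered = OrderedDict([
        ('observation', flat[:obs_dim]),
        ('achieved_goal', flat[obs_dim:obs_dim + goal_dim]),
        ('desired_goal', flat[obs_dim + goal_dim:]),
    ])
    assert all(np.array_equal(obs_dict[key], recovered[key]) for key in obs_dict)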
--------------------------------------------------------------------------------
/stable_baselines/ppo1/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.ppo1.pposgd_simple import PPO1
2 |
--------------------------------------------------------------------------------
/stable_baselines/ppo1/experiments/train_cartpole.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple test to check that PPO1 is running with no errors (see issue #50)
3 | """
4 | from stable_baselines import PPO1
5 |
6 |
7 | if __name__ == '__main__':
8 | model = PPO1('MlpPolicy', 'CartPole-v1', schedule='linear', verbose=0)
9 | model.learn(total_timesteps=1000)
10 |
--------------------------------------------------------------------------------
/stable_baselines/ppo1/run_atari.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 |
4 | from mpi4py import MPI
5 |
6 | from stable_baselines.common import set_global_seeds
7 | from stable_baselines import bench, logger, PPO1
8 | from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind
9 | from stable_baselines.common.cmd_util import atari_arg_parser
10 | from stable_baselines.common.policies import CnnPolicy
11 |
12 |
13 | def train(env_id, num_timesteps, seed):
14 | """
15 | Train PPO1 model for Atari environments, for testing purposes
16 |
17 | :param env_id: (str) Environment ID
18 | :param num_timesteps: (int) The total number of samples
19 | :param seed: (int) The initial seed for training
20 | """
21 | rank = MPI.COMM_WORLD.Get_rank()
22 |
23 | if rank == 0:
24 | logger.configure()
25 | else:
26 | logger.configure(format_strs=[])
27 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
28 | set_global_seeds(workerseed)
29 | env = make_atari(env_id)
30 |
31 | env = bench.Monitor(env, logger.get_dir() and
32 | os.path.join(logger.get_dir(), str(rank)))
33 | env.seed(workerseed)
34 |
35 | env = wrap_deepmind(env)
36 | env.seed(workerseed)
37 |
38 | model = PPO1(CnnPolicy, env, timesteps_per_actorbatch=256, clip_param=0.2, entcoeff=0.01, optim_epochs=4,
39 | optim_stepsize=1e-3, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear', verbose=2)
40 | model.learn(total_timesteps=num_timesteps)
41 | env.close()
42 | del env
43 |
44 |
45 | def main():
46 | """
47 | Runs the test
48 | """
49 | args = atari_arg_parser().parse_args()
50 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
55 |
--------------------------------------------------------------------------------
/stable_baselines/ppo1/run_mujoco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from stable_baselines.ppo1 import PPO1
4 | from stable_baselines.common.policies import MlpPolicy
5 | from stable_baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser
6 | from stable_baselines import logger
7 |
8 |
9 | def train(env_id, num_timesteps, seed):
10 | """
11 | Train PPO1 model for the Mujoco environment, for testing purposes
12 |
13 | :param env_id: (str) Environment ID
14 | :param num_timesteps: (int) The total number of samples
15 | :param seed: (int) The initial seed for training
16 | """
17 | env = make_mujoco_env(env_id, seed)
18 | model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=10,
19 | optim_stepsize=3e-4, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear')
20 | model.learn(total_timesteps=num_timesteps)
21 | env.close()
22 |
23 |
24 | def main():
25 | """
26 | Runs the test
27 | """
28 | args = mujoco_arg_parser().parse_args()
29 | logger.configure()
30 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
31 |
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
--------------------------------------------------------------------------------
/stable_baselines/ppo1/run_robotics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from mpi4py import MPI
4 | import mujoco_py # pytype:disable=import-error
5 |
6 | from stable_baselines.common import set_global_seeds
7 | from stable_baselines.common.policies import MlpPolicy
8 | from stable_baselines.common.cmd_util import make_robotics_env, robotics_arg_parser
9 | from stable_baselines.ppo1 import PPO1
10 |
11 |
12 | def train(env_id, num_timesteps, seed):
13 | """
14 | Train PPO1 model for Robotics environment, for testing purposes
15 |
16 | :param env_id: (str) Environment ID
17 | :param num_timesteps: (int) The total number of samples
18 | :param seed: (int) The initial seed for training
19 | """
20 |
21 | rank = MPI.COMM_WORLD.Get_rank()
22 | with mujoco_py.ignore_mujoco_warnings():
23 | workerseed = seed + 10000 * rank
24 | set_global_seeds(workerseed)
25 | env = make_robotics_env(env_id, workerseed, rank=rank)
26 |
27 | model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=5,
28 | optim_stepsize=3e-4, optim_batchsize=256, gamma=0.99, lam=0.95, schedule='linear')
29 | model.learn(total_timesteps=num_timesteps)
30 | env.close()
31 |
32 |
33 | def main():
34 | """
35 | Runs the test
36 | """
37 | args = robotics_arg_parser().parse_args()
38 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
39 |
40 |
41 | if __name__ == '__main__':
42 | main()
43 |
--------------------------------------------------------------------------------
/stable_baselines/ppo2/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.ppo2.ppo2 import PPO2
2 |
--------------------------------------------------------------------------------
/stable_baselines/ppo2/run_atari.py:
--------------------------------------------------------------------------------
1 | from stable_baselines import PPO2, logger
2 | from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser
3 | from stable_baselines.common.vec_env import VecFrameStack
4 | from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy, MlpPolicy
5 |
6 |
7 | def train(env_id, num_timesteps, seed, policy,
8 | n_envs=8, nminibatches=4, n_steps=128):
9 | """
10 | Train PPO2 model for atari environment, for testing purposes
11 |
12 | :param env_id: (str) the environment id string
13 | :param num_timesteps: (int) the number of timesteps to run
14 | :param seed: (int) Used to seed the random generator.
15 | :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
16 | :param n_envs: (int) Number of parallel environments
17 | :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
18 | the number of environments run in parallel should be a multiple of nminibatches.
19 | :param n_steps: (int) The number of steps to run for each environment per update
20 | (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
21 | """
22 |
23 | env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
24 | policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
25 | model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
26 | lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
27 | learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
28 | model.learn(total_timesteps=num_timesteps)
29 |
30 | env.close()
31 | # Free memory
32 | del model
33 |
34 |
35 | def main():
36 | """
37 | Runs the test
38 | """
39 | parser = atari_arg_parser()
40 | parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn')
41 | args = parser.parse_args()
42 | logger.configure()
43 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
44 | policy=args.policy)
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/stable_baselines/ppo2/run_mujoco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import numpy as np
3 | import gym
4 |
5 | from stable_baselines.common.cmd_util import mujoco_arg_parser
6 | from stable_baselines import bench, logger
7 | from stable_baselines.common import set_global_seeds
8 | from stable_baselines.common.vec_env.vec_normalize import VecNormalize
9 | from stable_baselines.ppo2 import PPO2
10 | from stable_baselines.common.policies import MlpPolicy
11 | from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
12 |
13 |
14 | def train(env_id, num_timesteps, seed):
15 | """
16 | Train PPO2 model for Mujoco environment, for testing purposes
17 |
18 | :param env_id: (str) the environment id string
19 | :param num_timesteps: (int) the number of timesteps to run
20 | :param seed: (int) Used to seed the random generator.
21 | """
22 | def make_env():
23 | env_out = gym.make(env_id)
24 | env_out = bench.Monitor(env_out, logger.get_dir(), allow_early_resets=True)
25 | return env_out
26 |
27 | env = DummyVecEnv([make_env])
28 | env = VecNormalize(env)
29 |
30 | set_global_seeds(seed)
31 | policy = MlpPolicy
32 | model = PPO2(policy=policy, env=env, n_steps=2048, nminibatches=32, lam=0.95, gamma=0.99, noptepochs=10,
33 | ent_coef=0.0, learning_rate=3e-4, cliprange=0.2)
34 | model.learn(total_timesteps=num_timesteps)
35 |
36 | return model, env
37 |
38 |
39 | def main():
40 | """
41 | Runs the test
42 | """
43 | args = mujoco_arg_parser().parse_args()
44 | logger.configure()
45 | model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
46 |
47 | if args.play:
48 | logger.log("Running trained model")
49 | obs = np.zeros((env.num_envs,) + env.observation_space.shape)
50 | obs[:] = env.reset()
51 | while True:
52 | actions = model.step(obs)[0]
53 | obs[:] = env.step(actions)[0]
54 | env.render()
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/stable_baselines/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/stable_baselines/py.typed
--------------------------------------------------------------------------------
/stable_baselines/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.sac.sac import SAC
2 | from stable_baselines.sac.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy
3 |
--------------------------------------------------------------------------------
/stable_baselines/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
2 | from stable_baselines.td3.td3 import TD3
3 | from stable_baselines.td3.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy
4 |
--------------------------------------------------------------------------------
/stable_baselines/trpo_mpi/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.trpo_mpi.trpo_mpi import TRPO
2 |
--------------------------------------------------------------------------------
/stable_baselines/trpo_mpi/run_atari.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 |
4 | from mpi4py import MPI
5 |
6 | from stable_baselines.common import set_global_seeds
7 | from stable_baselines import bench, logger, TRPO
8 | from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind
9 | from stable_baselines.common.cmd_util import atari_arg_parser
10 | from stable_baselines.common.policies import CnnPolicy
11 |
12 |
13 | def train(env_id, num_timesteps, seed):
14 | """
15 | Train TRPO model for the atari environment, for testing purposes
16 |
17 | :param env_id: (str) Environment ID
18 | :param num_timesteps: (int) The total number of samples
19 | :param seed: (int) The initial seed for training
20 | """
21 | rank = MPI.COMM_WORLD.Get_rank()
22 |
23 | if rank == 0:
24 | logger.configure()
25 | else:
26 | logger.configure(format_strs=[])
27 |
28 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
29 | set_global_seeds(workerseed)
30 | env = make_atari(env_id)
31 |
32 | env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
33 | env.seed(workerseed)
34 |
35 | env = wrap_deepmind(env)
36 | env.seed(workerseed)
37 |
38 | model = TRPO(CnnPolicy, env, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3, entcoeff=0.0,
39 | gamma=0.98, lam=1, vf_iters=3, vf_stepsize=1e-4)
40 | model.learn(total_timesteps=int(num_timesteps * 1.1))
41 | env.close()
42 | # Free memory
43 | del env
44 |
45 |
46 | def main():
47 | """
48 | Runs the test
49 | """
50 | args = atari_arg_parser().parse_args()
51 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
52 |
53 |
54 | if __name__ == "__main__":
55 | main()
56 |
--------------------------------------------------------------------------------
/stable_baselines/trpo_mpi/run_mujoco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # noinspection PyUnresolvedReferences
3 | from mpi4py import MPI
4 |
5 | from stable_baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser
6 | from stable_baselines.common.policies import MlpPolicy
7 | from stable_baselines import logger
8 | from stable_baselines.trpo_mpi import TRPO
9 | import stable_baselines.common.tf_util as tf_util
10 |
11 |
12 | def train(env_id, num_timesteps, seed):
13 | """
14 | Train TRPO model for the mujoco environment, for testing purposes
15 |
16 | :param env_id: (str) Environment ID
17 | :param num_timesteps: (int) The total number of samples
18 | :param seed: (int) The initial seed for training
19 | """
20 | with tf_util.single_threaded_session():
21 | rank = MPI.COMM_WORLD.Get_rank()
22 | if rank == 0:
23 | logger.configure()
24 | else:
25 | logger.configure(format_strs=[])
26 | logger.set_level(logger.DISABLED)
27 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
28 |
29 | env = make_mujoco_env(env_id, workerseed)
30 | model = TRPO(MlpPolicy, env, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1, entcoeff=0.0,
31 | gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
32 | model.learn(total_timesteps=num_timesteps)
33 | env.close()
34 |
35 |
36 | def main():
37 | """
38 | Runs the test
39 | """
40 | args = mujoco_arg_parser().parse_args()
41 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/stable_baselines/trpo_mpi/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def add_vtarg_and_adv(seg, gamma, lam):
5 | """
6 | Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
7 |
8 | :param seg: (dict) the current segment of the trajectory (see traj_segment_generator return for more information)
9 | :param gamma: (float) Discount factor
10 | :param lam: (float) GAE factor
11 | """
12 | # last element is only used for last vtarg, but we already zeroed it if last new = 1
13 | episode_starts = np.append(seg["episode_starts"], False)
14 | vpred = np.append(seg["vpred"], seg["nextvpred"])
15 | rew_len = len(seg["rewards"])
16 | seg["adv"] = np.empty(rew_len, 'float32')
17 | rewards = seg["rewards"]
18 | lastgaelam = 0
19 | for step in reversed(range(rew_len)):
20 | nonterminal = 1 - float(episode_starts[step + 1])
21 | delta = rewards[step] + gamma * vpred[step + 1] * nonterminal - vpred[step]
22 | seg["adv"][step] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
23 | seg["tdlamret"] = seg["adv"] + seg["vpred"]
24 |
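A tiny worked sketch, not part of this file, of add_vtarg_and_adv on a hand-made segment; with lam=1 and no episode boundary, the advantage reduces to the discounted return (bootstrapped with nextvpred) minus the value baseline.

    import numpy as np
    from stable_baselines.trpo_mpi.utils import add_vtarg_and_adv

    seg = {
        "rewards": np.array([1.0, 1.0, 1.0], dtype=np.float32),
        "vpred": np.array([0.5, 0.5, 0.5], dtype=np.float32),
        "nextvpred": 0.0,  # bootstrap value after the last step
        "episode_starts": np.array([True, False, False]),
    }
    add_vtarg_and_adv(seg, gamma=0.99, lam=1.0)
    # adv[t] = sum_k gamma^k * rewards[t+k] + gamma^(T-t) * nextvpred - vpred[t]
    print(seg["adv"], seg["tdlamret"])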
--------------------------------------------------------------------------------
/stable_baselines/version.txt:
--------------------------------------------------------------------------------
1 | 2.10.3a0 (WIP)
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hill-a/stable-baselines/45beb246833b6818e0f3fc1f44336b1c52351170/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_0deterministic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, SAC, TRPO, TD3
5 | from stable_baselines.common.noise import NormalActionNoise
6 |
7 | N_STEPS_TRAINING = 300
8 | SEED = 0
9 |
10 |
11 | # Weird stuff: TD3 would fail if another algorithm is tested before
12 | # with n_cpu_tf_sess > 1
13 | @pytest.mark.xfail(reason="TD3 deterministic randomly fail when run with others...", strict=False)
14 | def test_deterministic_td3():
15 | results = [[], []]
16 | rewards = [[], []]
17 | kwargs = {'n_cpu_tf_sess': 1}
18 | env_id = 'Pendulum-v0'
19 | kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)})
20 |
21 | for i in range(2):
22 | model = TD3('MlpPolicy', env_id, seed=SEED, **kwargs)
23 | model.learn(N_STEPS_TRAINING)
24 | env = model.get_env()
25 | obs = env.reset()
26 | for _ in range(20):
27 | action, _ = model.predict(obs, deterministic=True)
28 | obs, reward, _, _ = env.step(action)
29 | results[i].append(action)
30 | rewards[i].append(reward)
31 | # without the extended tolerance, test fails for unknown reasons on Github...
32 | assert np.allclose(results[0], results[1], rtol=1e-2), results
33 | assert np.allclose(rewards[0], rewards[1], rtol=1e-2), rewards
34 |
35 |
36 | @pytest.mark.parametrize("algo", [A2C, ACKTR, ACER, DDPG, DQN, PPO1, PPO2, SAC, TRPO])
37 | def test_deterministic_training_common(algo):
38 | results = [[], []]
39 | rewards = [[], []]
40 | kwargs = {'n_cpu_tf_sess': 1}
41 | if algo in [DDPG, TD3, SAC]:
42 | env_id = 'Pendulum-v0'
43 | kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)})
44 | else:
45 | env_id = 'CartPole-v1'
46 | if algo == DQN:
47 | kwargs.update({'learning_starts': 100})
48 |
49 | for i in range(2):
50 | model = algo('MlpPolicy', env_id, seed=SEED, **kwargs)
51 | model.learn(N_STEPS_TRAINING)
52 | env = model.get_env()
53 | obs = env.reset()
54 | for _ in range(20):
55 | action, _ = model.predict(obs, deterministic=False)
56 | obs, reward, _, _ = env.step(action)
57 | results[i].append(action)
58 | rewards[i].append(reward)
59 | assert sum(results[0]) == sum(results[1]), results
60 | assert sum(rewards[0]) == sum(rewards[1]), rewards
61 |
--------------------------------------------------------------------------------
/tests/test_a2c.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import gym
5 |
6 | from stable_baselines import A2C
7 | from stable_baselines.common import make_vec_env
8 | from stable_baselines.common.vec_env import DummyVecEnv
9 |
10 |
11 | def test_a2c_update_n_batch_on_load(tmp_path):
12 | env = make_vec_env("CartPole-v1", n_envs=2)
13 | model = A2C("MlpPolicy", env, n_steps=10)
14 |
15 | model.learn(total_timesteps=100)
16 | model.save(os.path.join(str(tmp_path), "a2c_cartpole.zip"))
17 |
18 | del model
19 |
20 | model = A2C.load(os.path.join(str(tmp_path), "a2c_cartpole.zip"))
21 | test_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
22 |
23 | model.set_env(test_env)
24 | assert model.n_batch == 10
25 | model.learn(100)
26 | os.remove(os.path.join(str(tmp_path), "a2c_cartpole.zip"))
27 |
--------------------------------------------------------------------------------
/tests/test_a2c_conv.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from stable_baselines.common.tf_layers import conv
6 | from stable_baselines.common.input import observation_input
7 |
8 |
9 | ENV_ID = 'BreakoutNoFrameskip-v4'
10 | SEED = 3
11 |
12 |
13 | def test_conv_kernel():
14 | """Test convolution kernel with various input formats."""
15 | filter_size_1 = 4 # The size of squared filter for the first layer
16 | filter_size_2 = (3, 5) # The size of non-squared filter for the second layer
17 | target_shape_1 = [2, 52, 40, 32] # The desired shape of the first layer
18 | target_shape_2 = [2, 13, 9, 32] # The desired shape of the second layer
19 | kwargs = {}
20 | n_envs = 1
21 | n_steps = 2
22 | n_batch = n_envs * n_steps
23 | scale = False
24 | env = gym.make(ENV_ID)
25 | ob_space = env.observation_space
26 |
27 | with tf.Graph().as_default():
28 | _, scaled_images = observation_input(ob_space, n_batch, scale=scale)
29 | activ = tf.nn.relu
30 | layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=filter_size_1,
31 | stride=4, init_scale=np.sqrt(2), **kwargs))
32 | layer_2 = activ(conv(layer_1, 'c2', n_filters=32, filter_size=filter_size_2,
33 | stride=4, init_scale=np.sqrt(2), **kwargs))
34 | assert layer_1.shape == target_shape_1, \
35 | "The shape of layer based on the squared kernel matrix is not correct. " \
36 | "The current shape is {} and the desired shape is {}".format(layer_1.shape, target_shape_1)
37 | assert layer_2.shape == target_shape_2, \
38 | "The shape of layer based on the non-squared kernel matrix is not correct. " \
39 | "The current shape is {} and the desired shape is {}".format(layer_2.shape, target_shape_2)
40 | env.close()
41 |
--------------------------------------------------------------------------------
/tests/test_action_scaling.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import DDPG, TD3, SAC
5 | from stable_baselines.common.identity_env import IdentityEnvBox
6 |
7 | ROLLOUT_STEPS = 100
8 |
9 | MODEL_LIST = [
10 | (DDPG, dict(nb_train_steps=0, nb_rollout_steps=ROLLOUT_STEPS)),
11 | (TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)),
12 | (SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=0)),
13 | (TD3, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS)),
14 | (SAC, dict(train_freq=ROLLOUT_STEPS + 1, learning_starts=ROLLOUT_STEPS))
15 | ]
16 |
17 |
18 | @pytest.mark.parametrize("model_class, model_kwargs", MODEL_LIST)
19 | def test_buffer_actions_scaling(model_class, model_kwargs):
20 | """
21 | Test if actions are scaled to tanh co-domain before being put in a buffer
22 | for algorithms that use tanh-squashing, i.e., DDPG, TD3, SAC
23 |
24 | :param model_class: (BaseRLModel) An RL model
25 | :param model_kwargs: (dict) Dictionary containing named arguments to the given algorithm
26 | """
27 |
28 | # check random and inferred actions as they possibly have different flows
29 | for random_coeff in [0.0, 1.0]:
30 |
31 | env = IdentityEnvBox(-2000, 1000)
32 |
33 | model = model_class("MlpPolicy", env, seed=1, random_exploration=random_coeff, **model_kwargs)
34 | model.learn(total_timesteps=ROLLOUT_STEPS)
35 |
36 | assert hasattr(model, 'replay_buffer')
37 |
38 | buffer = model.replay_buffer
39 |
40 | assert buffer.can_sample(ROLLOUT_STEPS)
41 |
42 | _, actions, _, _, _ = buffer.sample(ROLLOUT_STEPS)
43 |
44 | assert not np.any(actions > np.ones_like(actions))
45 | assert not np.any(actions < -np.ones_like(actions))
46 |
--------------------------------------------------------------------------------
/tests/test_action_space.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import A2C, PPO1, PPO2, TRPO
5 | from stable_baselines.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
6 | from stable_baselines.common.vec_env import DummyVecEnv
7 | from stable_baselines.common.evaluation import evaluate_policy
8 |
9 | MODEL_LIST = [
10 | A2C,
11 | PPO1,
12 | PPO2,
13 | TRPO
14 | ]
15 |
16 |
17 | @pytest.mark.slow
18 | @pytest.mark.parametrize("model_class", MODEL_LIST)
19 | def test_identity_multidiscrete(model_class):
20 | """
21 | Test if the algorithm (with a given policy)
22 | can learn an identity transformation (i.e. return observation as an action)
23 | with a multidiscrete action space
24 |
25 | :param model_class: (BaseRLModel) An RL model
26 | """
27 | env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(10)])
28 |
29 | model = model_class("MlpPolicy", env)
30 | model.learn(total_timesteps=1000)
31 | evaluate_policy(model, env, n_eval_episodes=5)
32 | obs = env.reset()
33 |
34 | assert np.array(model.action_probability(obs)).shape == (2, 1, 10), \
35 | "Error: action_probability not returning correct shape"
36 | assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \
37 | "Error: not scalar probability"
38 |
39 |
40 | @pytest.mark.slow
41 | @pytest.mark.parametrize("model_class", MODEL_LIST)
42 | def test_identity_multibinary(model_class):
43 | """
44 | Test if the algorithm (with a given policy)
45 | can learn an identity transformation (i.e. return observation as an action)
46 | with a multibinary action space
47 |
49 | :param model_class: (BaseRLModel) An RL model
49 | """
50 | env = DummyVecEnv([lambda: IdentityEnvMultiBinary(10)])
51 |
52 | model = model_class("MlpPolicy", env)
53 | model.learn(total_timesteps=1000)
54 | evaluate_policy(model, env, n_eval_episodes=5)
55 | obs = env.reset()
56 |
57 | assert model.action_probability(obs).shape == (1, 10), \
58 | "Error: action_probability not returning correct shape"
59 | assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \
60 | "Error: not scalar probability"
61 |
--------------------------------------------------------------------------------
/tests/test_atari.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from stable_baselines import bench, logger
4 | from stable_baselines.deepq import DQN, wrap_atari_dqn, CnnPolicy
5 | from stable_baselines.common import set_global_seeds
6 | from stable_baselines.common.atari_wrappers import make_atari
7 | import stable_baselines.a2c.run_atari as a2c_atari
8 | import stable_baselines.acer.run_atari as acer_atari
9 | import stable_baselines.acktr.run_atari as acktr_atari
10 | import stable_baselines.ppo1.run_atari as ppo1_atari
11 | import stable_baselines.ppo2.run_atari as ppo2_atari
12 | import stable_baselines.trpo_mpi.run_atari as trpo_atari
13 |
14 |
15 | ENV_ID = 'BreakoutNoFrameskip-v4'
16 | SEED = 3
17 | NUM_TIMESTEPS = 300
18 | NUM_CPU = 2
19 |
20 |
21 | @pytest.mark.slow
22 | @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm'])
23 | def test_a2c(policy):
24 | """
25 | test A2C on atari
26 |
27 | :param policy: (str) the policy to test for A2C
28 | """
29 | a2c_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
30 | policy=policy, lr_schedule='constant', num_env=NUM_CPU)
31 |
32 |
33 | @pytest.mark.slow
34 | @pytest.mark.parametrize("policy", ['cnn', 'lstm'])
35 | def test_acer(policy):
36 | """
37 | test ACER on atari
38 |
39 | :param policy: (str) the policy to test for ACER
40 | """
41 | acer_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
42 | policy=policy, lr_schedule='constant', num_cpu=NUM_CPU)
43 |
44 |
45 | @pytest.mark.slow
46 | def test_acktr():
47 | """
48 | test ACKTR on atari
49 | """
50 | acktr_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, num_cpu=NUM_CPU)
51 |
52 |
53 | @pytest.mark.slow
54 | def test_deepq():
55 | """
56 | test DeepQ on atari
57 | """
58 | logger.configure()
59 | set_global_seeds(SEED)
60 | env = make_atari(ENV_ID)
61 | env = bench.Monitor(env, logger.get_dir())
62 | env = wrap_atari_dqn(env)
63 |
64 | model = DQN(env=env, policy=CnnPolicy, learning_rate=1e-4, buffer_size=10000, exploration_fraction=0.1,
65 | exploration_final_eps=0.01, train_freq=4, learning_starts=100, target_network_update_freq=100,
66 | gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6)
67 | model.learn(total_timesteps=NUM_TIMESTEPS)
68 |
69 | env.close()
70 | del model, env
71 |
72 |
73 | @pytest.mark.slow
74 | def test_ppo1():
75 | """
76 | test PPO1 on atari
77 | """
78 | ppo1_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)
79 |
80 |
81 | @pytest.mark.slow
82 | @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm', 'mlp'])
83 | def test_ppo2(policy):
84 | """
85 | test PPO2 on atari
86 |
87 | :param policy: (str) the policy to test for PPO2
88 | """
89 | ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS,
90 | seed=SEED, policy=policy, n_envs=NUM_CPU,
91 | nminibatches=NUM_CPU, n_steps=16)
92 |
93 |
94 | @pytest.mark.slow
95 | def test_trpo():
96 | """
97 | test TRPO on atari
98 | """
99 | trpo_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)
100 |
--------------------------------------------------------------------------------
/tests/test_auto_vec_detection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import A2C, ACER, ACKTR, DDPG, DQN, PPO1, PPO2, SAC, TRPO, TD3
5 | from stable_baselines.common.vec_env import DummyVecEnv
6 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, \
7 | IdentityEnvMultiDiscrete
8 | from stable_baselines.common.evaluation import evaluate_policy
9 |
10 |
11 | def check_shape(make_env, model_class, shape_1, shape_2):
12 | model = model_class(policy="MlpPolicy", env=DummyVecEnv([make_env]))
13 |
14 | env0 = make_env()
15 | env1 = DummyVecEnv([make_env])
16 |
17 | for env, expected_shape in [(env0, shape_1), (env1, shape_2)]:
18 | def callback(locals_, _globals):
19 | assert np.array(locals_['action']).shape == expected_shape
20 | evaluate_policy(model, env, n_eval_episodes=5, callback=callback)
21 |
22 |
23 | @pytest.mark.slow
24 | @pytest.mark.parametrize("model_class", [A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO])
25 | def test_identity(model_class):
26 | """
27 | test the Discrete environment vectorisation detection
28 |
29 | :param model_class: (BaseRLModel) the RL model
30 | """
31 | check_shape(lambda: IdentityEnv(dim=10), model_class, (), (1,))
32 |
33 |
34 | @pytest.mark.slow
35 | @pytest.mark.parametrize("model_class", [A2C, DDPG, PPO1, PPO2, SAC, TRPO, TD3])
36 | def test_identity_box(model_class):
37 | """
38 | test the Box environment vectorisation detection
39 |
40 | :param model_class: (BaseRLModel) the RL model
41 | """
42 | check_shape(lambda: IdentityEnvBox(eps=0.5), model_class, (1,), (1, 1))
43 |
44 |
45 | @pytest.mark.slow
46 | @pytest.mark.parametrize("model_class", [A2C, PPO1, PPO2, TRPO])
47 | def test_identity_multi_binary(model_class):
48 | """
49 | test the MultiBinary environment vectorisation detection
50 |
51 | :param model_class: (BaseRLModel) the RL model
52 | """
53 | check_shape(lambda: IdentityEnvMultiBinary(dim=10), model_class, (10,), (1, 10))
54 |
55 |
56 | @pytest.mark.slow
57 | @pytest.mark.parametrize("model_class", [A2C, PPO1, PPO2, TRPO])
58 | def test_identity_multi_discrete(model_class):
59 | """
60 | test the MultiDiscrete environment vectorisation detection
61 |
62 | :param model_class: (BaseRLModel) the RL model
63 | """
64 | check_shape(lambda: IdentityEnvMultiDiscrete(dim=10), model_class, (2,), (1, 2))
65 |
--------------------------------------------------------------------------------
/tests/test_common.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import sys
3 |
4 |
5 | def _assert_eq(left, right):
6 | assert left == right, '{} != {}'.format(left, right)
7 |
8 |
9 | def _assert_neq(left, right):
10 | assert left != right, '{} == {}'.format(left, right)
11 |
12 |
13 | @contextmanager
14 | def _maybe_disable_mpi(mpi_disabled):
15 | """A context that can temporarily remove the mpi4py import.
16 |
17 | Useful for testing whether non-MPI algorithms work as intended when
18 | mpi4py isn't installed.
19 |
20 | Args:
21 | mpi_disabled (bool): If True, then this context temporarily removes
22 | the mpi4py import from `sys.modules`
23 | """
24 | if mpi_disabled and "mpi4py" in sys.modules:
25 | temp = sys.modules["mpi4py"]
26 | try:
27 | sys.modules["mpi4py"] = None
28 | yield
29 | finally:
30 | sys.modules["mpi4py"] = temp
31 | else:
32 | yield
33 |
--------------------------------------------------------------------------------
/tests/test_deepq.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.deepq.experiments.train_cartpole import main as train_cartpole
2 | from stable_baselines.deepq.experiments.enjoy_cartpole import main as enjoy_cartpole
3 | from stable_baselines.deepq.experiments.train_mountaincar import main as train_mountaincar
4 | from stable_baselines.deepq.experiments.enjoy_mountaincar import main as enjoy_mountaincar
5 |
6 |
7 | class DummyObject(object):
8 | """
9 | Dummy object to create a fake parsed-arguments object
10 | """
11 | pass
12 |
13 |
14 | args = DummyObject()
15 | args.no_render = True
16 | args.max_timesteps = 200
17 |
18 |
19 | def test_cartpole():
20 | train_cartpole(args)
21 | enjoy_cartpole(args)
22 |
23 |
24 | def test_mountaincar():
25 | train_mountaincar(args)
26 | enjoy_mountaincar(args)
27 |
--------------------------------------------------------------------------------
/tests/test_distri.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | import stable_baselines.common.tf_util as tf_util
5 | from stable_baselines.common.distributions import DiagGaussianProbabilityDistributionType,\
6 | CategoricalProbabilityDistributionType, \
7 | MultiCategoricalProbabilityDistributionType, BernoulliProbabilityDistributionType
8 |
9 |
10 | @tf_util.in_session
11 | def test_probtypes():
12 | """
13 | test probability distribution types
14 | """
15 | np.random.seed(0)
16 |
17 | pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
18 | diag_gauss = DiagGaussianProbabilityDistributionType(pdparam_diag_gauss.size // 2)
19 | validate_probtype(diag_gauss, pdparam_diag_gauss)
20 |
21 | pdparam_categorical = np.array([-.2, .3, .5])
22 | categorical = CategoricalProbabilityDistributionType(pdparam_categorical.size)
23 | validate_probtype(categorical, pdparam_categorical)
24 |
25 | nvec = np.array([1, 2, 3])
26 | pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
27 | multicategorical = MultiCategoricalProbabilityDistributionType(nvec)
28 | validate_probtype(multicategorical, pdparam_multicategorical)
29 |
30 | pdparam_bernoulli = np.array([-.2, .3, .5])
31 | bernoulli = BernoulliProbabilityDistributionType(pdparam_bernoulli.size)
32 | validate_probtype(bernoulli, pdparam_bernoulli)
33 |
34 |
35 | def validate_probtype(probtype, pdparam):
36 | """
37 | validate probability distribution types
38 |
39 | :param probtype: (ProbabilityDistributionType) the type to validate
40 | :param pdparam: ([float]) the flat probabilities to test
41 | """
42 | number_samples = 100000
43 | # Check to see if mean negative log likelihood == differential entropy
44 | mval = np.repeat(pdparam[None, :], number_samples, axis=0)
45 | mval_ph = probtype.param_placeholder([number_samples])
46 | xval_ph = probtype.sample_placeholder([number_samples])
47 | proba_distribution = probtype.proba_distribution_from_flat(mval_ph)
48 | calcloglik = tf_util.function([xval_ph, mval_ph], proba_distribution.logp(xval_ph))
49 | calcent = tf_util.function([mval_ph], proba_distribution.entropy())
50 | xval = tf.get_default_session().run(proba_distribution.sample(), feed_dict={mval_ph: mval})
51 | logliks = calcloglik(xval, mval)
52 | entval_ll = - logliks.mean()
53 | entval_ll_stderr = logliks.std() / np.sqrt(number_samples)
54 | entval = calcent(mval).mean()
55 | assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas
56 |
57 | # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
58 | mval2_ph = probtype.param_placeholder([number_samples])
59 | pd2 = probtype.proba_distribution_from_flat(mval2_ph)
60 | tmp = pdparam + np.random.randn(pdparam.size) * 0.1
61 | mval2 = np.repeat(tmp[None, :], number_samples, axis=0)
62 | calckl = tf_util.function([mval_ph, mval2_ph], proba_distribution.kl(pd2))
63 | klval = calckl(mval, mval2).mean()
64 | logliks = calcloglik(xval, mval2)
65 | klval_ll = - entval - logliks.mean()
66 | klval_ll_stderr = logliks.std() / np.sqrt(number_samples)
67 | assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
68 | print('ok on', probtype, pdparam)
69 |
--------------------------------------------------------------------------------
/tests/test_her.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from stable_baselines import HER, DQN, SAC, DDPG, TD3
6 | from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper
7 | from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY
8 | from stable_baselines.common.bit_flipping_env import BitFlippingEnv
9 | from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
10 |
11 | N_BITS = 10
12 |
13 |
14 | def model_predict(model, env, n_steps, additional_check=None):
15 | """
16 | Test helper
17 | :param model: (rl model)
18 | :param env: (gym.Env)
19 | :param n_steps: (int)
20 | :param additional_check: (callable)
21 | """
22 | obs = env.reset()
23 | for _ in range(n_steps):
24 | action, _ = model.predict(obs)
25 | obs, reward, done, _ = env.step(action)
26 |
27 | if additional_check is not None:
28 | additional_check(obs, action, reward, done)
29 |
30 | if done:
31 | obs = env.reset()
32 |
33 |
34 | @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy))
35 | @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3])
36 | @pytest.mark.parametrize('discrete_obs_space', [False, True])
37 | def test_her(model_class, goal_selection_strategy, discrete_obs_space):
38 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3],
39 | max_steps=N_BITS, discrete_obs_space=discrete_obs_space)
40 |
41 | # Take random actions 10% of the time
42 | kwargs = {'random_exploration': 0.1} if model_class in [DDPG, SAC, TD3] else {}
43 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy,
44 | verbose=0, **kwargs)
45 | model.learn(150)
46 |
47 |
48 | @pytest.mark.parametrize('model_class', [DDPG, SAC, DQN, TD3])
49 | def test_long_episode(model_class):
50 | """
51 | Check that the model does not break when the replay buffer is still empty
52 | after the first rollout (because the episode is not over).
53 | """
54 | # n_bits > nb_rollout_steps
55 | n_bits = 10
56 | env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3],
57 | max_steps=n_bits)
58 | kwargs = {}
59 | if model_class == DDPG:
60 | kwargs['nb_rollout_steps'] = 9 # < n_bits
61 | elif model_class in [DQN, SAC, TD3]:
62 | kwargs['batch_size'] = 8 # < n_bits
63 | kwargs['learning_starts'] = 0
64 |
65 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future',
66 | verbose=0, **kwargs)
67 | model.learn(100)
68 |
69 |
70 | @pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]])
71 | @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3])
72 | def test_model_manipulation(model_class, goal_selection_strategy):
73 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
74 | env = DummyVecEnv([lambda: env])
75 |
76 | model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy,
77 | verbose=0)
78 | model.learn(150)
79 |
80 | model_predict(model, env, n_steps=20, additional_check=None)
81 |
82 | model.save('./test_her.zip')
83 | del model
84 |
85 | # NOTE: HER does not support VecEnvWrapper yet
86 | with pytest.raises(AssertionError):
87 | model = HER.load('./test_her.zip', env=VecNormalize(env))
88 |
89 | model = HER.load('./test_her.zip')
90 |
91 | # Check that the model raises an error when the env
92 | # is not wrapped (or no env passed to the model)
93 | with pytest.raises(ValueError):
94 | model.predict(env.reset())
95 |
96 | env_ = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
97 | env_ = HERGoalEnvWrapper(env_)
98 |
99 | model_predict(model, env_, n_steps=20, additional_check=None)
100 |
101 | model.set_env(env)
102 | model.learn(150)
103 |
104 | model_predict(model, env_, n_steps=20, additional_check=None)
105 |
106 | assert model.n_sampled_goal == 3
107 |
108 | del model
109 |
110 | env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
111 | model = HER.load('./test_her', env=env)
112 | model.learn(150)
113 |
114 | model_predict(model, env_, n_steps=20, additional_check=None)
115 |
116 | assert model.n_sampled_goal == 3
117 |
118 | if os.path.isfile('./test_her.zip'):
119 | os.remove('./test_her.zip')
120 |
--------------------------------------------------------------------------------
/tests/test_identity.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, SAC, PPO1, PPO2, TD3, TRPO
5 | from stable_baselines.ddpg import NormalActionNoise
6 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox
7 | from stable_baselines.common.vec_env import DummyVecEnv
8 | from stable_baselines.common.evaluation import evaluate_policy
9 |
10 |
11 | # Hyperparameters for learning identity for each RL model
12 | LEARN_FUNC_DICT = {
13 | "a2c": lambda e: A2C(
14 | policy="MlpPolicy",
15 | learning_rate=1e-3,
16 | n_steps=4,
17 | gamma=0.4,
18 | ent_coef=0.0,
19 | env=e,
20 | seed=0,
21 | ).learn(total_timesteps=4000),
22 | "acer": lambda e: ACER(
23 | policy="MlpPolicy",
24 | env=e,
25 | seed=0,
26 | n_steps=4,
27 | replay_ratio=1,
28 | ent_coef=0.0,
29 | ).learn(total_timesteps=4000),
30 | "acktr": lambda e: ACKTR(
31 | policy="MlpPolicy", env=e, seed=0, learning_rate=5e-4, ent_coef=0.0, n_steps=4
32 | ).learn(total_timesteps=4000),
33 | "dqn": lambda e: DQN(
34 | policy="MlpPolicy",
35 | batch_size=32,
36 | gamma=0.1,
37 | learning_starts=0,
38 | exploration_final_eps=0.05,
39 | exploration_fraction=0.1,
40 | env=e,
41 | seed=0,
42 | ).learn(total_timesteps=4000),
43 | "ppo1": lambda e: PPO1(
44 | policy="MlpPolicy",
45 | env=e,
46 | seed=0,
47 | lam=0.5,
48 | entcoeff=0.0,
49 | optim_batchsize=16,
50 | gamma=0.4,
51 | optim_stepsize=1e-3,
52 | ).learn(total_timesteps=3000),
53 | "ppo2": lambda e: PPO2(
54 | policy="MlpPolicy",
55 | env=e,
56 | seed=0,
57 | learning_rate=1.5e-3,
58 | lam=0.8,
59 | ent_coef=0.0,
60 | gamma=0.4,
61 | ).learn(total_timesteps=3000),
62 | "trpo": lambda e: TRPO(
63 | policy="MlpPolicy",
64 | env=e,
65 | gamma=0.4,
66 | seed=0,
67 | max_kl=0.05,
68 | lam=0.7,
69 | timesteps_per_batch=256,
70 | ).learn(total_timesteps=4000),
71 | }
72 |
73 |
74 | @pytest.mark.slow
75 | @pytest.mark.parametrize(
76 | "model_name", ["a2c", "acer", "acktr", "dqn", "ppo1", "ppo2", "trpo"]
77 | )
78 | def test_identity_discrete(model_name):
79 | """
80 | Test if the algorithm (with a given policy)
81 | can learn an identity transformation (i.e. return observation as an action)
82 |
83 | :param model_name: (str) Name of the RL model
84 | """
85 | env = DummyVecEnv([lambda: IdentityEnv(10)])
86 |
87 | model = LEARN_FUNC_DICT[model_name](env)
88 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
89 |
90 | obs = env.reset()
91 | assert model.action_probability(obs).shape == (
92 | 1,
93 | 10,
94 | ), "Error: action_probability not returning correct shape"
95 | action = env.action_space.sample()
96 | action_prob = model.action_probability(obs, actions=action)
97 | assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
98 | action_logprob = model.action_probability(obs, actions=action, logp=True)
99 | assert np.allclose(action_prob, np.exp(action_logprob)), (
100 | action_prob,
101 | action_logprob,
102 | )
103 |
104 | # Free memory
105 | del model, env
106 |
107 |
108 | @pytest.mark.slow
109 | @pytest.mark.parametrize("model_class", [DDPG, TD3, SAC])
110 | def test_identity_continuous(model_class):
111 | """
112 | Test if the algorithm (with a given policy)
113 | can learn an identity transformation (i.e. return observation as an action)
114 | """
115 | env = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)])
116 |
117 | n_steps = {SAC: 700, TD3: 500, DDPG: 2000}[model_class]
118 |
119 | kwargs = dict(seed=0, gamma=0.95, buffer_size=1e5)
120 | if model_class in [DDPG, TD3]:
121 | n_actions = 1
122 | action_noise = NormalActionNoise(
123 | mean=np.zeros(n_actions), sigma=0.05 * np.ones(n_actions)
124 | )
125 | kwargs["action_noise"] = action_noise
126 |
127 | if model_class == DDPG:
128 | kwargs["actor_lr"] = 1e-3
129 | kwargs["batch_size"] = 100
130 |
131 | model = model_class("MlpPolicy", env, **kwargs)
132 | model.learn(total_timesteps=n_steps)
133 |
134 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
135 | # Free memory
136 | del model, env
137 |
--------------------------------------------------------------------------------
/tests/test_log_prob.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines import A2C, ACKTR, PPO1, PPO2, TRPO
5 | from stable_baselines.common.identity_env import IdentityEnvBox
6 |
7 |
8 | class Helper:
9 | @staticmethod
10 | def proba_vals(obs, state, mask):
11 | # Return fixed mean, std
12 | return np.array([-0.4]), np.array([[0.1]])
13 |
14 |
15 | @pytest.mark.parametrize("model_class", [A2C, ACKTR, PPO1, PPO2, TRPO])
16 | def test_log_prob_calculation(model_class):
17 | model = model_class("MlpPolicy", IdentityEnvBox())
18 | # Fixed mean/std
19 | model.proba_step = Helper.proba_vals
20 | # Check that the log probability is the one expected for the given mean/std
21 | logprob = model.action_probability(observation=np.array([[0.5], [0.5]]), actions=0.2, logp=True)
22 | assert np.allclose(logprob, np.array([-16.616353440210627])), "Calculation failed for {}".format(model_class)
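The expected value in the assertion above is just the log-density of a Gaussian with the fixed mean/std returned by the helper; a quick stand-alone sanity check (not part of the repository):

import numpy as np

mean, std, action = -0.4, 0.1, 0.2
log_prob = -0.5 * ((action - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)
assert np.isclose(log_prob, -16.616353440210627)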
23 |
--------------------------------------------------------------------------------
/tests/test_logger.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from stable_baselines.logger import make_output_format, read_tb, read_csv, read_json, _demo
5 | from .test_common import _maybe_disable_mpi
6 |
7 |
8 | KEY_VALUES = {
9 | "test": 1,
10 | "b": -3.14,
11 | "8": 9.9,
12 | "l": [1, 2],
13 | "a": np.array([1, 2, 3]),
14 | "f": np.array(1),
15 | "g": np.array([[[1]]]),
16 | }
17 | LOG_DIR = '/tmp/openai_baselines/'
18 |
19 |
20 | def test_main():
21 | """
22 | Dry-run python -m stable_baselines.logger
23 | """
24 | _demo()
25 |
26 |
27 | @pytest.mark.parametrize('_format', ['tensorboard', 'stdout', 'log', 'json', 'csv'])
28 | @pytest.mark.parametrize('mpi_disabled', [False, True])
29 | def test_make_output(_format, mpi_disabled):
30 | """
31 | test make output
32 |
33 | :param _format: (str) output format
34 | """
35 | with _maybe_disable_mpi(mpi_disabled):
36 | writer = make_output_format(_format, LOG_DIR)
37 | writer.writekvs(KEY_VALUES)
38 | if _format == 'tensorboard':
39 | read_tb(LOG_DIR)
40 | elif _format == "csv":
41 | read_csv(LOG_DIR + 'progress.csv')
42 | elif _format == 'json':
43 | read_json(LOG_DIR + 'progress.json')
44 | writer.close()
45 |
46 |
47 | def test_make_output_fail():
48 | """
49 | test value error on logger
50 | """
51 | with pytest.raises(ValueError):
52 | make_output_format('dummy_format', LOG_DIR)
53 |
--------------------------------------------------------------------------------
/tests/test_math_util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from gym.spaces.box import Box
4 |
5 | from stable_baselines.common.math_util import discount_with_boundaries, scale_action, unscale_action
6 |
7 |
8 | def test_discount_with_boundaries():
9 | """
10 | test the discount_with_boundaries function
11 | """
12 | gamma = 0.9
13 | rewards = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
14 | episode_starts = [1.0, 0.0, 0.0, 1.0]
15 | discounted_rewards = discount_with_boundaries(rewards, episode_starts, gamma)
16 | assert np.allclose(discounted_rewards, [1 + gamma * 2 + gamma ** 2 * 3, 2 + gamma * 3, 3, 4])
17 | return
18 |
19 |
20 | def test_scaling_action():
21 | """
22 | test scaling of scalar, 1d and 2d vectors of finite non-NaN real numbers to and from tanh co-domain (per component)
23 | """
24 | test_ranges = [(-1, 1), (-10, 10), (-10, 5), (-10, 0), (-10, -5), (0, 10), (5, 10)]
25 |
26 | # scalars
27 | for (range_low, range_high) in test_ranges:
28 | check_scaled_actions_from_range(range_low, range_high, scalar=True)
29 |
30 | # 1d vectors: wrapped scalars
31 | for test_range in test_ranges:
32 | check_scaled_actions_from_range(*test_range)
33 |
34 | # 2d vectors: all combinations of ranges above
35 | for (r1_low, r1_high) in test_ranges:
36 | for (r2_low, r2_high) in test_ranges:
37 | check_scaled_actions_from_range(np.array([r1_low, r2_low], dtype=np.float),
38 | np.array([r1_high, r2_high], dtype=np.float))
39 |
40 |
41 | def check_scaled_actions_from_range(low, high, scalar=False):
42 | """
43 | helper method that creates a dummy action space spanning the respective components of low and high,
44 | then checks scaling to and from the tanh co-domain for the low, middle and high values of that space
45 | :param low: (np.ndarray), (int) or (float)
46 | :param high: (np.ndarray), (int) or (float)
47 | :param scalar: (bool) Whether to consider a scalar range or wrap it into a 1d vector
48 | """
49 |
50 | if scalar and (isinstance(low, float) or isinstance(low, int)):
51 | ones = 1.
52 | action_space = Box(low, high, shape=(1,))
53 | else:
54 | low = np.atleast_1d(low)
55 | high = np.atleast_1d(high)
56 | ones = np.ones_like(low)
57 | action_space = Box(low, high)
58 |
59 | mid = 0.5 * (low + high)
60 |
61 | expected_mapping = [(low, -ones), (mid, 0. * ones), (high, ones)]
62 |
63 | for (not_scaled, scaled) in expected_mapping:
64 | assert np.allclose(scale_action(action_space, not_scaled), scaled)
65 | assert np.allclose(unscale_action(action_space, scaled), not_scaled)
66 |
67 |
68 | def test_batch_shape_invariant_to_scaling():
69 | """
70 | test that scaling deals well with batches as tensors and numpy matrices in terms of shape
71 | """
72 | action_space = Box(np.array([-10., -5., -1.]), np.array([10., 3., 2.]))
73 |
74 | tensor = tf.constant(1., shape=[2, 3])
75 | matrix = np.ones((2, 3))
76 |
77 | assert scale_action(action_space, tensor).shape == (2, 3)
78 | assert scale_action(action_space, matrix).shape == (2, 3)
79 |
80 | assert unscale_action(action_space, tensor).shape == (2, 3)
81 | assert unscale_action(action_space, matrix).shape == (2, 3)
82 |
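These tests exercise an affine mapping between the action-space bounds and the tanh co-domain [-1, 1]. A minimal sketch of that mapping (not part of the repository), assuming the per-component convention implied by the expected_mapping triples above (low -> -1, mid -> 0, high -> +1):

import numpy as np

def to_tanh_range(low, high, action):
    # map [low, high] to [-1, 1], per component
    return 2.0 * (action - low) / (high - low) - 1.0

def from_tanh_range(low, high, scaled):
    # map [-1, 1] back to [low, high], per component
    return low + 0.5 * (scaled + 1.0) * (high - low)

low, high = np.array([-10.0, -5.0]), np.array([10.0, 3.0])
assert np.allclose(to_tanh_range(low, high, low), -1.0)
assert np.allclose(to_tanh_range(low, high, 0.5 * (low + high)), 0.0)
assert np.allclose(from_tanh_range(low, high, np.ones(2)), high)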
--------------------------------------------------------------------------------
/tests/test_monitor.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import json
3 | import os
4 |
5 | import pandas
6 | import gym
7 |
8 | from stable_baselines.bench import Monitor
9 | from stable_baselines.bench.monitor import get_monitor_files, load_results
10 |
11 |
12 | def test_monitor():
13 | """
14 | test the monitor wrapper
15 | """
16 | env = gym.make("CartPole-v1")
17 | env.seed(0)
18 | mon_file = "/tmp/stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())
19 | menv = Monitor(env, mon_file)
20 | menv.reset()
21 | for _ in range(1000):
22 | _, _, done, _ = menv.step(0)
23 | if done:
24 | menv.reset()
25 |
26 | file_handler = open(mon_file, 'rt')
27 |
28 | firstline = file_handler.readline()
29 | assert firstline.startswith('#')
30 | metadata = json.loads(firstline[1:])
31 | assert metadata['env_id'] == "CartPole-v1"
32 | assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"
33 |
34 | last_logline = pandas.read_csv(file_handler, index_col=None)
35 | assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
36 | file_handler.close()
37 | os.remove(mon_file)
38 |
39 |
40 | def test_monitor_load_results(tmp_path):
41 | """
42 | test load_results on log files produced by the monitor wrapper
43 | """
44 | tmp_path = str(tmp_path)
45 | env1 = gym.make("CartPole-v1")
46 | env1.seed(0)
47 | monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
48 | monitor_env1 = Monitor(env1, monitor_file1)
49 |
50 | monitor_files = get_monitor_files(tmp_path)
51 | assert len(monitor_files) == 1
52 | assert monitor_file1 in monitor_files
53 |
54 | monitor_env1.reset()
55 | episode_count1 = 0
56 | for _ in range(1000):
57 | _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample())
58 | if done:
59 | episode_count1 += 1
60 | monitor_env1.reset()
61 |
62 | results_size1 = len(load_results(os.path.join(tmp_path)).index)
63 | assert results_size1 == episode_count1
64 |
65 | env2 = gym.make("CartPole-v1")
66 | env2.seed(0)
67 | monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
68 | monitor_env2 = Monitor(env2, monitor_file2)
69 | monitor_files = get_monitor_files(tmp_path)
70 | assert len(monitor_files) == 2
71 | assert monitor_file1 in monitor_files
72 | assert monitor_file2 in monitor_files
73 |
74 | monitor_env2.reset()
75 | episode_count2 = 0
76 | for _ in range(1000):
77 | _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample())
78 | if done:
79 | episode_count2 += 1
80 | monitor_env2.reset()
81 |
82 | results_size2 = len(load_results(os.path.join(tmp_path)).index)
83 |
84 | assert results_size2 == (results_size1 + episode_count2)
85 |
86 | os.remove(monitor_file1)
87 | os.remove(monitor_file2)
88 |
--------------------------------------------------------------------------------
/tests/test_mpi_adam.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import pytest
4 |
5 | from .test_common import _assert_eq
6 |
7 |
8 | def test_mpi_adam():
9 | """Test RunningMeanStd object for MPI"""
10 | # Test will be run in CI before pytest is run
11 | pytest.skip()
12 | return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',
13 | 'python', '-m', 'stable_baselines.common.mpi_adam'])
14 | _assert_eq(return_code, 0)
15 |
16 |
17 | def test_mpi_adam_ppo1():
18 | """Running test for ppo1"""
19 | # Test will be run in CI before pytest is run
20 | pytest.skip()
21 | return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',
22 | 'python', '-m',
23 | 'stable_baselines.ppo1.experiments.train_cartpole'])
24 | _assert_eq(return_code, 0)
25 |
--------------------------------------------------------------------------------
/tests/test_multiple_learn.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from stable_baselines import A2C, ACER, ACKTR, PPO2
4 | from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox
5 | from stable_baselines.common.vec_env import DummyVecEnv
6 |
7 | # TODO: Fix multiple-learn on commented-out models (Issue #619).
8 | MODEL_LIST = [
9 | A2C,
10 | ACER,
11 | ACKTR,
12 | PPO2,
13 |
14 | # MPI-based models, which use traj_segment_generator instead of Runner.
15 | #
16 | # PPO1,
17 | # TRPO,
18 |
19 | # Off-policy models, which don't use Runner but reset every .learn() anyways.
20 | #
21 | # DDPG,
22 | # SAC,
23 | # TD3,
24 | ]
25 |
26 |
27 | @pytest.mark.parametrize("model_class", MODEL_LIST)
28 | def test_model_multiple_learn_no_reset(model_class):
29 | """Check that when we call learn multiple times, we don't unnecessarily
30 | reset the environment.
31 | """
32 | if model_class is ACER:
33 | def make_env():
34 | return IdentityEnv(ep_length=1e10, dim=2)
35 | else:
36 | def make_env():
37 | return IdentityEnvBox(ep_length=1e10)
38 | env = make_env()
39 | venv = DummyVecEnv([lambda: env])
40 | model = model_class(policy="MlpPolicy", env=venv)
41 | _check_reset_count(model, env)
42 |
43 | # Try again following a `set_env`.
44 | env = make_env()
45 | venv = DummyVecEnv([lambda: env])
46 | assert env.num_resets == 0
47 |
48 | model.set_env(venv)
49 | _check_reset_count(model, env)
50 |
51 |
52 | def _check_reset_count(model, env: IdentityEnv):
53 | assert env.num_resets == 0
54 | _prev_runner = None
55 | for _ in range(2):
56 | model.learn(total_timesteps=300)
57 | # Lazy constructor for Runner fires upon the first call to learn.
58 | assert env.num_resets == 1
59 | if _prev_runner is not None:
60 | assert _prev_runner is model.runner, "Runner shouldn't change"
61 | _prev_runner = model.runner
62 |
--------------------------------------------------------------------------------
/tests/test_no_mpi.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from .test_common import _maybe_disable_mpi
4 |
5 |
6 | def test_no_mpi_no_crash():
7 | with _maybe_disable_mpi(True):
8 | # Temporarily delete previously imported stable baselines
9 | old_modules = {}
10 | sb_modules = [name for name in sys.modules.keys()
11 | if name.startswith('stable_baselines')]
12 | for name in sb_modules:
13 | old_modules[name] = sys.modules.pop(name)
14 |
15 | # Re-import (with mpi disabled)
16 | import stable_baselines
17 | del stable_baselines # appease Codacy
18 |
19 | # Restore old version of stable baselines (with MPI imported)
20 | for name, mod in old_modules.items():
21 | sys.modules[name] = mod
22 |
--------------------------------------------------------------------------------
/tests/test_ppo2.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import gym
5 |
6 | from stable_baselines import PPO2
7 | from stable_baselines.common import make_vec_env
8 | from stable_baselines.common.vec_env import DummyVecEnv
9 |
10 |
11 | @pytest.mark.parametrize("cliprange", [0.2, lambda x: 0.1 * x])
12 | @pytest.mark.parametrize("cliprange_vf", [None, 0.2, lambda x: 0.3 * x, -1.0])
13 | def test_clipping(tmp_path, cliprange, cliprange_vf):
14 | """Test the different clipping (policy and vf)"""
15 | model = PPO2(
16 | "MlpPolicy",
17 | "CartPole-v1",
18 | cliprange=cliprange,
19 | cliprange_vf=cliprange_vf,
20 | noptepochs=2,
21 | n_steps=64,
22 | ).learn(100)
23 | save_path = os.path.join(str(tmp_path), "ppo2_clip.zip")
24 | model.save(save_path)
25 | env = model.get_env()
26 | model = PPO2.load(save_path, env=env)
27 | model.learn(100)
28 |
29 | if os.path.exists(save_path):
30 | os.remove(save_path)
31 |
32 |
33 | def test_ppo2_update_n_batch_on_load(tmp_path):
34 | env = make_vec_env("CartPole-v1", n_envs=2)
35 | model = PPO2("MlpPolicy", env, n_steps=10, nminibatches=1)
36 | save_path = os.path.join(str(tmp_path), "ppo2_cartpole.zip")
37 |
38 | model.learn(total_timesteps=100)
39 | model.save(save_path)
40 |
41 | del model
42 |
43 | model = PPO2.load(save_path)
44 | test_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
45 |
46 | model.set_env(test_env)
47 | model.learn(total_timesteps=100)
48 |
--------------------------------------------------------------------------------
/tests/test_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from stable_baselines.common.buffers import ReplayBuffer, PrioritizedReplayBuffer
4 |
5 |
6 | def test_extend_uniform():
7 | nvals = 16
8 | states = [np.random.rand(2, 2) for _ in range(nvals)]
9 | actions = [np.random.rand(2) for _ in range(nvals)]
10 | rewards = [np.random.rand() for _ in range(nvals)]
11 | newstate = [np.random.rand(2, 2) for _ in range(nvals)]
12 | done = [np.random.randint(0, 2) for _ in range(nvals)]
13 |
14 | size = 32
15 | baseline = ReplayBuffer(size)
16 | ext = ReplayBuffer(size)
17 | for data in zip(states, actions, rewards, newstate, done):
18 | baseline.add(*data)
19 |
20 | states, actions, rewards, newstates, done = map(
21 | np.array, [states, actions, rewards, newstate, done])
22 |
23 | ext.extend(states, actions, rewards, newstates, done)
24 | assert len(baseline) == len(ext)
25 |
26 | # Check buffers have same values
27 | for i in range(nvals):
28 | for j in range(5):
29 | condition = (baseline.storage[i][j] == ext.storage[i][j])
30 | if isinstance(condition, np.ndarray):
31 | # for obs, obs_t1
32 | assert np.all(condition)
33 | else:
34 | # for done, reward action
35 | assert condition
36 |
37 |
38 | def test_extend_prioritized():
39 | nvals = 16
40 | states = [np.random.rand(2, 2) for _ in range(nvals)]
41 | actions = [np.random.rand(2) for _ in range(nvals)]
42 | rewards = [np.random.rand() for _ in range(nvals)]
43 | newstate = [np.random.rand(2, 2) for _ in range(nvals)]
44 | done = [np.random.randint(0, 2) for _ in range(nvals)]
45 |
46 | size = 32
47 | alpha = 0.99
48 | baseline = PrioritizedReplayBuffer(size, alpha)
49 | ext = PrioritizedReplayBuffer(size, alpha)
50 | for data in zip(states, actions, rewards, newstate, done):
51 | baseline.add(*data)
52 |
53 | states, actions, rewards, newstates, done = map(
54 | np.array, [states, actions, rewards, newstate, done])
55 |
56 | ext.extend(states, actions, rewards, newstates, done)
57 | assert len(baseline) == len(ext)
58 |
59 | # Check buffers have same values
60 | for i in range(nvals):
61 | for j in range(5):
62 | condition = (baseline.storage[i][j] == ext.storage[i][j])
63 | if isinstance(condition, np.ndarray):
64 | # for obs, obs_t1
65 | assert np.all(condition)
66 | else:
67 | # for done, reward action
68 | assert condition
69 |
70 | # assert priorities
71 | assert (baseline._it_min._value == ext._it_min._value).all()
72 | assert (baseline._it_sum._value == ext._it_sum._value).all()
73 |
--------------------------------------------------------------------------------
/tests/test_schedules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from stable_baselines.common.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule
4 |
5 |
6 | def test_piecewise_schedule():
7 | """
8 | test PiecewiseSchedule
9 | """
10 | piecewise_sched = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)],
11 | outside_value=500)
12 |
13 | assert np.isclose(piecewise_sched.value(-10), 500)
14 | assert np.isclose(piecewise_sched.value(0), 150)
15 | assert np.isclose(piecewise_sched.value(5), 200)
16 | assert np.isclose(piecewise_sched.value(9), 80)
17 | assert np.isclose(piecewise_sched.value(50), 50)
18 | assert np.isclose(piecewise_sched.value(80), 50)
19 | assert np.isclose(piecewise_sched.value(150), 0)
20 | assert np.isclose(piecewise_sched.value(175), -25)
21 | assert np.isclose(piecewise_sched.value(201), 500)
22 | assert np.isclose(piecewise_sched.value(500), 500)
23 |
24 | assert np.isclose(piecewise_sched.value(200 - 1e-10), -50)
25 |
26 |
27 | def test_constant_schedule():
28 | """
29 | test ConstantSchedule
30 | """
31 | constant_sched = ConstantSchedule(5)
32 | for i in range(-100, 100):
33 | assert np.isclose(constant_sched.value(i), 5)
34 |
35 |
36 | def test_linear_schedule():
37 | """
38 | test LinearSchedule
39 | """
40 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=0.8)
41 | assert np.isclose(linear_sched.value(50), 0.5)
42 | assert np.isclose(linear_sched.value(0), 0.2)
43 | assert np.isclose(linear_sched.value(100), 0.8)
44 |
45 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.8, final_p=0.2)
46 | assert np.isclose(linear_sched.value(50), 0.5)
47 | assert np.isclose(linear_sched.value(0), 0.8)
48 | assert np.isclose(linear_sched.value(100), 0.2)
49 |
50 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=-0.6, final_p=0.2)
51 | assert np.isclose(linear_sched.value(50), -0.2)
52 | assert np.isclose(linear_sched.value(0), -0.6)
53 | assert np.isclose(linear_sched.value(100), 0.2)
54 |
55 | linear_sched = LinearSchedule(schedule_timesteps=100, initial_p=0.2, final_p=-0.6)
56 | assert np.isclose(linear_sched.value(50), -0.2)
57 | assert np.isclose(linear_sched.value(0), 0.2)
58 | assert np.isclose(linear_sched.value(100), -0.6)
59 |
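The assertions above are consistent with a simple linear interpolation whose progress fraction is clamped at 1 once schedule_timesteps is reached. A minimal sketch of that behaviour (not part of the repository; the actual LinearSchedule lives in stable_baselines/common/schedules.py):

def linear_value(step, schedule_timesteps, initial_p, final_p):
    # interpolate from initial_p to final_p over schedule_timesteps, then hold final_p
    fraction = min(float(step) / schedule_timesteps, 1.0)
    return initial_p + fraction * (final_p - initial_p)

assert abs(linear_value(50, 100, 0.2, 0.8) - 0.5) < 1e-8
assert abs(linear_value(0, 100, 0.2, 0.8) - 0.2) < 1e-8
assert abs(linear_value(150, 100, 0.2, 0.8) - 0.8) < 1e-8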
--------------------------------------------------------------------------------
/tests/test_tensorboard.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import pytest
5 |
6 | from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, SAC, TD3, TRPO
7 |
8 | TENSORBOARD_DIR = '/tmp/tb_dir/'
9 |
10 | if os.path.isdir(TENSORBOARD_DIR):
11 | shutil.rmtree(TENSORBOARD_DIR)
12 |
13 | MODEL_DICT = {
14 | 'a2c': (A2C, 'CartPole-v1'),
15 | 'acer': (ACER, 'CartPole-v1'),
16 | 'acktr': (ACKTR, 'CartPole-v1'),
17 | 'dqn': (DQN, 'CartPole-v1'),
18 | 'ddpg': (DDPG, 'Pendulum-v0'),
19 | 'ppo1': (PPO1, 'CartPole-v1'),
20 | 'ppo2': (PPO2, 'CartPole-v1'),
21 | 'sac': (SAC, 'Pendulum-v0'),
22 | 'td3': (TD3, 'Pendulum-v0'),
23 | 'trpo': (TRPO, 'CartPole-v1'),
24 | }
25 |
26 | N_STEPS = 300
27 |
28 |
29 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys())
30 | def test_tensorboard(model_name):
31 | logname = model_name.upper()
32 | algo, env_id = MODEL_DICT[model_name]
33 | model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR)
34 | model.learn(N_STEPS)
35 | model.learn(N_STEPS, reset_num_timesteps=False)
36 |
37 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_1")
38 | assert not os.path.isdir(TENSORBOARD_DIR + logname + "_2")
39 |
40 |
41 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys())
42 | def test_multiple_runs(model_name):
43 | logname = "tb_multiple_runs_" + model_name
44 | algo, env_id = MODEL_DICT[model_name]
45 | model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR)
46 | model.learn(N_STEPS, tb_log_name=logname)
47 | model.learn(N_STEPS, tb_log_name=logname)
48 |
49 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_1")
50 | # Check that the log dir name increments correctly
51 | assert os.path.isdir(TENSORBOARD_DIR + logname + "_2")
52 |
--------------------------------------------------------------------------------
/tests/test_tf_util.py:
--------------------------------------------------------------------------------
1 | # tests for tf_util
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from stable_baselines.common.tf_util import function, initialize, single_threaded_session, is_image
6 |
7 |
8 | def test_function():
9 | """
10 | test the `function` helper in tf_util
11 | """
12 | with tf.Graph().as_default():
13 | x_ph = tf.placeholder(tf.int32, (), name="x")
14 | y_ph = tf.placeholder(tf.int32, (), name="y")
15 | z_ph = 3 * x_ph + 2 * y_ph
16 | linear_fn = function([x_ph, y_ph], z_ph, givens={y_ph: 0})
17 |
18 | with single_threaded_session():
19 | initialize()
20 |
21 | assert linear_fn(2) == 6
22 | assert linear_fn(2, 2) == 10
23 |
24 |
25 | def test_multikwargs():
26 | """
27 | test the `function` helper in tf_util
28 | """
29 | with tf.Graph().as_default():
30 | x_ph = tf.placeholder(tf.int32, (), name="x")
31 | with tf.variable_scope("other"):
32 | x2_ph = tf.placeholder(tf.int32, (), name="x")
33 | z_ph = 3 * x_ph + 2 * x2_ph
34 |
35 | linear_fn = function([x_ph, x2_ph], z_ph, givens={x2_ph: 0})
36 | with single_threaded_session():
37 | initialize()
38 | assert linear_fn(2) == 6
39 | assert linear_fn(2, 2) == 10
40 |
41 |
42 | def test_image_detection():
43 | rgb = (32, 64, 3)
44 | gray = (43, 23, 1)
45 | rgbd = (12, 32, 4)
46 | invalid_1 = (32, 12)
47 | invalid_2 = (12, 32, 6)
48 |
49 | # TF checks
50 | for shape in (rgb, gray, rgbd):
51 | assert is_image(tf.placeholder(tf.uint8, shape=shape))
52 |
53 | for shape in (invalid_1, invalid_2):
54 | assert not is_image(tf.placeholder(tf.uint8, shape=shape))
55 |
56 | # Numpy checks
57 | for shape in (rgb, gray, rgbd):
58 | assert is_image(np.ones(shape))
59 |
60 | for shape in (invalid_1, invalid_2):
61 | assert not is_image(np.ones(shape))
62 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import pytest
5 | import gym
6 |
7 | from stable_baselines import A2C
8 | from stable_baselines.bench.monitor import Monitor
9 | from stable_baselines.common.evaluation import evaluate_policy
10 | from stable_baselines.common.cmd_util import make_vec_env
11 | from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
12 |
13 |
14 | @pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
15 | @pytest.mark.parametrize("n_envs", [1, 2])
16 | @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv])
17 | @pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit])
18 | def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
19 | env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls,
20 | wrapper_class=wrapper_class, monitor_dir=None, seed=0)
21 |
22 | assert env.num_envs == n_envs
23 |
24 | if vec_env_cls is None:
25 | assert isinstance(env, DummyVecEnv)
26 | if wrapper_class is not None:
27 | assert isinstance(env.envs[0], wrapper_class)
28 | else:
29 | assert isinstance(env.envs[0], Monitor)
30 | else:
31 | assert isinstance(env, SubprocVecEnv)
32 | # Kill subprocesses
33 | env.close()
34 |
35 |
36 | def test_custom_vec_env():
37 | """
38 | Stand-alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
39 | """
40 | monitor_dir = 'logs/test_make_vec_env/'
41 | env = make_vec_env('CartPole-v1', n_envs=1,
42 | monitor_dir=monitor_dir, seed=0,
43 | vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None})
44 |
45 | assert env.num_envs == 1
46 | assert isinstance(env, SubprocVecEnv)
47 | assert os.path.isdir('logs/test_make_vec_env/')
48 | # Kill subprocess
49 | env.close()
50 | # Cleanup folder
51 | shutil.rmtree(monitor_dir)
52 |
53 | # This should fail because DummyVecEnv does not have any keyword argument
54 | with pytest.raises(TypeError):
55 | make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False})
56 |
57 |
58 | def test_evaluate_policy():
59 | model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
60 | n_steps_per_episode, n_eval_episodes = 200, 2
61 | model.n_callback_calls = 0
62 |
63 | def dummy_callback(locals_, _globals):
64 | locals_['model'].n_callback_calls += 1
65 |
66 | _, episode_lengths = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True,
67 | render=False, callback=dummy_callback, reward_threshold=None,
68 | return_episode_rewards=True)
69 |
70 | n_steps = sum(episode_lengths)
71 | assert n_steps == n_steps_per_episode * n_eval_episodes
72 | assert n_steps == model.n_callback_calls
73 |
74 | # Reaching a mean reward of zero is impossible with the Pendulum env
75 | with pytest.raises(AssertionError):
76 | evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0)
77 |
78 | episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
79 | assert len(episode_rewards) == n_eval_episodes
80 |
--------------------------------------------------------------------------------
/tests/test_vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import numpy as np
4 |
5 | from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan
6 |
7 |
8 | class NanAndInfEnv(gym.Env):
9 | """Custom Environment that raised NaNs and Infs"""
10 | metadata = {'render.modes': ['human']}
11 |
12 | def __init__(self):
13 | super(NanAndInfEnv, self).__init__()
14 | self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
15 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
16 |
17 | @staticmethod
18 | def step(action):
19 | if np.all(np.array(action) > 0):
20 | obs = float('NaN')
21 | elif np.all(np.array(action) < 0):
22 | obs = float('inf')
23 | else:
24 | obs = 0
25 | return [obs], 0.0, False, {}
26 |
27 | @staticmethod
28 | def reset():
29 | return [0.0]
30 |
31 | def render(self, mode='human', close=False):
32 | pass
33 |
34 |
35 | def test_check_nan():
36 | """Test VecCheckNan Object"""
37 |
38 | env = DummyVecEnv([NanAndInfEnv])
39 | env = VecCheckNan(env, raise_exception=True)
40 |
41 | env.step([[0]])
42 |
43 | try:
44 | env.step([[float('NaN')]])
45 | except ValueError:
46 | pass
47 | else:
48 | assert False
49 |
50 | try:
51 | env.step([[float('inf')]])
52 | except ValueError:
53 | pass
54 | else:
55 | assert False
56 |
57 | try:
58 | env.step([[-1]])
59 | except ValueError:
60 | pass
61 | else:
62 | assert False
63 |
64 | try:
65 | env.step([[1]])
66 | except ValueError:
67 | pass
68 | else:
69 | assert False
70 |
71 | env.step(np.array([[0, 1], [0, 1]]))
72 |
--------------------------------------------------------------------------------