├── logs
│   └── .keep
├── output
│   └── .keep
├── stable-baselines3
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_tensorboard.py
│   │   ├── test_deterministic.py
│   │   ├── test_custom_policy.py
│   │   ├── test_cnn.py
│   │   ├── test_predict.py
│   │   ├── test_vec_check_nan.py
│   │   ├── test_spaces.py
│   │   ├── test_identity.py
│   │   ├── test_callbacks.py
│   │   ├── test_logger.py
│   │   ├── test_sde.py
│   │   ├── test_run.py
│   │   ├── test_monitor.py
│   │   ├── test_distributions.py
│   │   └── test_envs.py
│   ├── .dockerignore
│   ├── stable_baselines3
│   │   ├── py.typed
│   │   ├── version.txt
│   │   ├── a2c
│   │   │   ├── __init__.py
│   │   │   └── policies.py
│   │   ├── dqn
│   │   │   └── __init__.py
│   │   ├── ppo
│   │   │   ├── __init__.py
│   │   │   └── policies.py
│   │   ├── sac
│   │   │   └── __init__.py
│   │   ├── td3
│   │   │   └── __init__.py
│   │   ├── ddpg
│   │   │   ├── __init__.py
│   │   │   └── policies.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── type_aliases.py
│   │   │   ├── running_mean_std.py
│   │   │   ├── vec_env
│   │   │   │   ├── __init__.py
│   │   │   │   ├── vec_frame_stack.py
│   │   │   │   ├── vec_transpose.py
│   │   │   │   ├── util.py
│   │   │   │   ├── vec_check_nan.py
│   │   │   │   ├── vec_video_recorder.py
│   │   │   │   └── dummy_vec_env.py
│   │   │   ├── evaluation.py
│   │   │   ├── results_plotter.py
│   │   │   ├── bit_flipping_env.py
│   │   │   ├── preprocessing.py
│   │   │   └── identity_env.py
│   │   └── __init__.py
│   ├── scripts
│   │   ├── run_tests.sh
│   │   ├── run_docker_cpu.sh
│   │   ├── run_docker_gpu.sh
│   │   └── build_docker.sh
│   ├── docs
│   │   ├── common
│   │   │   ├── utils.rst
│   │   │   ├── logger.rst
│   │   │   ├── noise.rst
│   │   │   ├── monitor.rst
│   │   │   ├── cmd_util.rst
│   │   │   ├── evaluation.rst
│   │   │   ├── atari_wrappers.rst
│   │   │   ├── env_checker.rst
│   │   │   └── distributions.rst
│   │   ├── _static
│   │   │   ├── img
│   │   │   │   ├── logo.png
│   │   │   │   ├── mistake.png
│   │   │   │   ├── try_it.png
│   │   │   │   ├── breakout.gif
│   │   │   │   ├── Tensorboard_example.png
│   │   │   │   ├── colab.svg
│   │   │   │   └── colab-badge.svg
│   │   │   └── css
│   │   │       └── baselines_theme.css
│   │   ├── guide
│   │   │   ├── migration.rst
│   │   │   ├── rl.rst
│   │   │   ├── quickstart.rst
│   │   │   ├── algos.rst
│   │   │   ├── tensorboard.rst
│   │   │   ├── custom_env.rst
│   │   │   ├── vec_envs.rst
│   │   │   ├── rl_zoo.rst
│   │   │   ├── custom_policy.rst
│   │   │   ├── developer.rst
│   │   │   └── install.rst
│   │   ├── conda_env.yml
│   │   ├── README.md
│   │   ├── Makefile
│   │   ├── modules
│   │   │   ├── base.rst
│   │   │   ├── a2c.rst
│   │   │   ├── dqn.rst
│   │   │   ├── ppo.rst
│   │   │   ├── ddpg.rst
│   │   │   ├── td3.rst
│   │   │   └── sac.rst
│   │   ├── make.bat
│   │   ├── misc
│   │   │   └── projects.rst
│   │   ├── spelling_wordlist.txt
│   │   └── index.rst
│   ├── .coveragerc
│   ├── .gitlab-ci.yml
│   ├── .readthedocs.yml
│   ├── .gitignore
│   ├── LICENSE
│   ├── NOTICE
│   ├── Dockerfile
│   ├── Makefile
│   ├── .github
│   │   ├── workflows
│   │   │   └── ci.yml
│   │   ├── ISSUE_TEMPLATE
│   │   │   └── issue-template.md
│   │   └── PULL_REQUEST_TEMPLATE.md
│   ├── setup.cfg
│   ├── CONTRIBUTING.md
│   └── setup.py
├── .gitmodules
├── bashfiles
│   ├── blocks_train_selfplay.bash
│   ├── hanabi_train_selfplay.bash
│   ├── arms_train_selfplay.bash
│   ├── hanabi_adapt_to_selfplay.bash
│   ├── arms_human_adapt_to_fixed.bash
│   ├── blocks_adapt_to_selfplay.bash
│   ├── arms_adapt_to_selfplay.bash
│   ├── blocks_adapt_to_fixed.bash
│   └── arms_adapt_to_fixed.bash
├── setup.py
├── my_gym
│   ├── envs
│   │   ├── __init__.py
│   │   ├── hanabi_env.py
│   │   ├── arms_human_env.py
│   │   ├── arms_env.py
│   │   └── generate_grid.py
│   └── __init__.py
├── tabular.py
├── README.md
├── partner.py
└── run_arms_human.py
/logs/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/stable-baselines3/.dockerignore:
--------------------------------------------------------------------------------
1 | .gitignore
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/version.txt:
--------------------------------------------------------------------------------
1 | 0.8.0a5
2 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "hanabi"]
2 | path = hanabi
3 | url = https://github.com/deepmind/hanabi-learning-environment
4 |
--------------------------------------------------------------------------------
/bashfiles/blocks_train_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | python run_blocks.py --run=$runid --selfplay
--------------------------------------------------------------------------------
/bashfiles/hanabi_train_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | python run_hanabi.py --run=$runid --selfplay
--------------------------------------------------------------------------------
/stable-baselines3/scripts/run_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v
3 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/utils.rst:
--------------------------------------------------------------------------------
1 | .. _utils:
2 |
3 | Utils
4 | =====
5 |
6 | .. automodule:: stable_baselines3.common.utils
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/logger.rst:
--------------------------------------------------------------------------------
1 | .. _logger:
2 |
3 | Logger
4 | ======
5 |
6 | .. automodule:: stable_baselines3.common.logger
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/logo.png
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/mistake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/mistake.png
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/try_it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/try_it.png
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/noise.rst:
--------------------------------------------------------------------------------
1 | .. _noise:
2 |
3 | Action Noise
4 | =============
5 |
6 | .. automodule:: stable_baselines3.common.noise
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.a2c.a2c import A2C
2 | from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy
3 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.dqn.dqn import DQN
2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy
3 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy
2 | from stable_baselines3.ppo.ppo import PPO
3 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy
2 | from stable_baselines3.sac.sac import SAC
3 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy
2 | from stable_baselines3.td3.td3 import TD3
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='my_gym',
4 | version='0.0.1',
5 | install_requires=['gym'],  # and any other dependencies required
6 | )
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/breakout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/breakout.gif
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ddpg.ddpg import DDPG
2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy
3 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/monitor.rst:
--------------------------------------------------------------------------------
1 | .. _monitor:
2 |
3 | Monitor Wrapper
4 | ===============
5 |
6 | .. automodule:: stable_baselines3.common.monitor
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/ddpg/policies.py:
--------------------------------------------------------------------------------
1 | # DDPG can be viewed as a special case of TD3
2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy # noqa:F401
3 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/cmd_util.rst:
--------------------------------------------------------------------------------
1 | .. _cmd_util:
2 |
3 | Command Utils
4 | =========================
5 |
6 | .. automodule:: stable_baselines3.common.cmd_util
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/evaluation.rst:
--------------------------------------------------------------------------------
1 | .. _eval:
2 |
3 | Evaluation Helper
4 | =================
5 |
6 | .. automodule:: stable_baselines3.common.evaluation
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/Tensorboard_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/Tensorboard_example.png
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/atari_wrappers.rst:
--------------------------------------------------------------------------------
1 | .. _atari_wrapper:
2 |
3 | Atari Wrappers
4 | ==============
5 |
6 | .. automodule:: stable_baselines3.common.atari_wrappers
7 | :members:
8 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
2 | from stable_baselines3.common.utils import set_random_seed
3 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/env_checker.rst:
--------------------------------------------------------------------------------
1 | .. _env_checker:
2 |
3 | Gym Environment Checker
4 | ========================
5 |
6 | .. automodule:: stable_baselines3.common.env_checker
7 | :members:
8 |
--------------------------------------------------------------------------------
/bashfiles/arms_train_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | m=$2
7 |
8 | for i in {0..9}
9 | do
10 | python run_arms.py --run=$(($runid + $i)) --selfplay --m=${m}
11 | done
--------------------------------------------------------------------------------
/my_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from my_gym.envs.blocks_env import BlocksEnv
2 | from my_gym.envs.arms_env import ArmsEnv
3 | from my_gym.envs.hanabi_env import HanabiEnvWrapper
4 | from my_gym.envs.arms_human_env import ArmsHumanEnv
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/migration.rst:
--------------------------------------------------------------------------------
1 | .. _migration:
2 |
3 | ================================
4 | Migrating from Stable-Baselines
5 | ================================
6 |
7 |
8 | This is a guide to migrate from Stable-Baselines to Stable-Baselines3.
9 |
10 | It also references the main changes.
11 |
12 | **TODO**
13 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: root
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - cpuonly=1.0=0
7 | - pip=20.0
8 | - python=3.6
9 | - pytorch=1.5.0=py3.6_cpu_0
10 | - pip:
11 | - gym==0.17.2
12 | - cloudpickle
13 | - opencv-python-headless
14 | - pandas
15 | - numpy
16 | - matplotlib
17 |
--------------------------------------------------------------------------------
/stable-baselines3/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = False
3 | omit =
4 | tests/*
5 | setup.py
6 | # Require graphical interface
7 | stable_baselines3/common/results_plotter.py
8 | # Require ffmpeg
9 | stable_baselines3/common/vec_env/vec_video_recorder.py
10 |
11 | [report]
12 | exclude_lines =
13 | pragma: no cover
14 | raise NotImplementedError()
15 | if typing.TYPE_CHECKING:
16 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/a2c/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for A2C
3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy
4 |
5 | MlpPolicy = ActorCriticPolicy
6 | CnnPolicy = ActorCriticCnnPolicy
7 |
8 | register_policy("MlpPolicy", ActorCriticPolicy)
9 | register_policy("CnnPolicy", ActorCriticCnnPolicy)
10 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/ppo/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for PPO
3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy
4 |
5 | MlpPolicy = ActorCriticPolicy
6 | CnnPolicy = ActorCriticCnnPolicy
7 |
8 | register_policy("MlpPolicy", ActorCriticPolicy)
9 | register_policy("CnnPolicy", ActorCriticCnnPolicy)
10 |
--------------------------------------------------------------------------------
/stable-baselines3/scripts/run_docker_cpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Launch an experiment using the docker cpu image
3 |
4 | cmd_line="$@"
5 |
6 | echo "Executing in the docker (cpu image):"
7 | echo $cmd_line
8 |
9 | docker run -it --rm --network host --ipc=host \
10 | --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \
11 | bash -c "cd /root/code/stable-baselines3/ && $cmd_line"
12 |
--------------------------------------------------------------------------------
/stable-baselines3/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | image: stablebaselines/stable-baselines3-cpu:0.8.0a4
2 |
3 | type-check:
4 | script:
5 | - make type
6 |
7 | pytest:
8 | script:
9 | # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
10 | - MKL_THREADING_LAYER=GNU make pytest
11 |
12 | doc-build:
13 | script:
14 | - make doc
15 |
16 | lint-check:
17 | script:
18 | - make check-codestyle
19 | - make lint
20 |
--------------------------------------------------------------------------------
/my_gym/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='blocks-v0',
5 | entry_point='my_gym.envs:BlocksEnv',
6 | )
7 |
8 | register(
9 | id='arms-v0',
10 | entry_point='my_gym.envs:ArmsEnv',
11 | )
12 |
13 | register(
14 | id='hanabi-v0',
15 | entry_point='my_gym.envs:HanabiEnvWrapper',
16 | )
17 |
18 | register(
19 | id='arms-human-v0',
20 | entry_point='my_gym.envs:ArmsHumanEnv',
21 | )
--------------------------------------------------------------------------------
/stable-baselines3/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Build documentation in the docs/ directory with Sphinx
8 | sphinx:
9 | configuration: docs/conf.py
10 |
11 | # Optionally build your docs in additional formats such as PDF and ePub
12 | formats: all
13 |
14 | # Set requirements using conda env
15 | conda:
16 | environment: docs/conda_env.yml
17 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Stable Baselines3 Documentation
2 |
3 | This folder contains documentation for the RL baselines.
4 |
5 |
6 | ### Build the Documentation
7 |
8 | #### Install Sphinx and Theme
9 |
10 | ```
11 | pip install sphinx sphinx-autobuild sphinx-rtd-theme
12 | ```
13 |
14 | #### Building the Docs
15 |
16 | In the `docs/` folder:
17 | ```
18 | make html
19 | ```
20 |
21 | If you want to rebuild the docs each time a file is changed:
22 |
23 | ```
24 | sphinx-autobuild . _build/html
25 | ```
26 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from stable_baselines3.a2c import A2C
4 | from stable_baselines3.ddpg import DDPG
5 | from stable_baselines3.dqn import DQN
6 | from stable_baselines3.ppo import PPO
7 | from stable_baselines3.sac import SAC
8 | from stable_baselines3.td3 import TD3
9 |
10 | # Read version from file
11 | version_file = os.path.join(os.path.dirname(__file__), "version.txt")
12 | with open(version_file, "r") as file_handler:
13 | __version__ = file_handler.read().strip()
14 |
--------------------------------------------------------------------------------
/stable-baselines3/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *.pkl
4 | *.py~
5 | *.bak
6 | .pytest_cache
7 | .DS_Store
8 | .idea
9 | .coverage
10 | .coverage.*
11 | __pycache__/
12 | _build/
13 | *.npz
14 | *.pth
15 | .pytype/
16 | git_rewrite_commit_history.sh
17 |
18 | # Setuptools distribution and build folders.
19 | /dist/
20 | /build
21 | keys/
22 |
23 | # Virtualenv
24 | /env
25 | /venv
26 |
27 |
28 | *.sublime-project
29 | *.sublime-workspace
30 |
31 | .idea
32 |
33 | logs/
34 |
35 | .ipynb_checkpoints
36 | ghostdriver.log
37 |
38 | htmlcov
39 |
40 | junk
41 | src
42 |
43 | *.egg-info
44 | .cache
45 | *.lprof
46 | *.prof
47 |
48 | MUJOCO_LOG.TXT
49 |
--------------------------------------------------------------------------------
/stable-baselines3/scripts/run_docker_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Launch an experiment using the docker gpu image
3 |
4 | cmd_line="$@"
5 |
6 | echo "Executing in the docker (gpu image):"
7 | echo $cmd_line
8 |
9 | # TODO: always use new-style once sufficiently widely used (probably 2021 onwards)
10 | if [ -x "$(which nvidia-docker)" ]; then
11 | # old-style nvidia-docker2
12 | NVIDIA_ARG="--runtime=nvidia"
13 | else
14 | NVIDIA_ARG="--gpus all"
15 | fi
16 |
17 | docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \
18 | --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \
19 | bash -c "cd /root/code/stable-baselines3/ && $cmd_line"
20 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W # make warnings fatal
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = StableBaselines
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/base.rst:
--------------------------------------------------------------------------------
1 | .. _base_algo:
2 |
3 | .. automodule:: stable_baselines3.common.base_class
4 |
5 |
6 | Base RL Class
7 | =============
8 |
9 | Common interface for all the RL algorithms
10 |
11 | .. autoclass:: BaseAlgorithm
12 | :members:
13 |
14 |
15 | .. automodule:: stable_baselines3.common.off_policy_algorithm
16 |
17 |
18 | Base Off-Policy Class
19 | ^^^^^^^^^^^^^^^^^^^^^
20 |
21 | The base class for Off-Policy algorithms (ex: SAC/TD3)
22 |
23 | .. autoclass:: OffPolicyAlgorithm
24 | :members:
25 |
26 |
27 | .. automodule:: stable_baselines3.common.on_policy_algorithm
28 |
29 |
30 | Base On-Policy Class
31 | ^^^^^^^^^^^^^^^^^^^^^
32 |
33 | The base class for On-Policy algorithms (ex: A2C/PPO)
34 |
35 | .. autoclass:: OnPolicyAlgorithm
36 | :members:
37 |
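38 | As an illustration of the shared ``BaseAlgorithm`` interface (a minimal sketch, not part of the API reference above; any algorithm could stand in for PPO):
39 |
40 | .. code-block:: python
41 |
42 |     from stable_baselines3 import PPO
43 |
44 |     # All algorithms share the same learn/predict/save/load interface
45 |     model = PPO("MlpPolicy", "CartPole-v1")
46 |     model.learn(total_timesteps=1000)
47 |
48 |     obs = model.get_env().reset()
49 |     action, _states = model.predict(obs, deterministic=True)
50 |
51 |     model.save("ppo_cartpole")
52 |     model = PPO.load("ppo_cartpole")
53 |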
--------------------------------------------------------------------------------
/stable-baselines3/scripts/build_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CPU_PARENT=ubuntu:16.04
4 | GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04
5 |
6 | TAG=stablebaselines/stable-baselines3
7 | VERSION=$(cat ./stable_baselines3/version.txt)
8 |
9 | if [[ ${USE_GPU} == "True" ]]; then
10 | PARENT=${GPU_PARENT}
11 | PYTORCH_DEPS="cudatoolkit=10.1"
12 | else
13 | PARENT=${CPU_PARENT}
14 | PYTORCH_DEPS="cpuonly"
15 | TAG="${TAG}-cpu"
16 | fi
17 |
18 | echo "docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} ."
19 | docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} .
20 | docker tag ${TAG}:${VERSION} ${TAG}:latest
21 |
22 | if [[ ${RELEASE} == "True" ]]; then
23 | docker push ${TAG}:${VERSION}
24 | docker push ${TAG}:latest
25 | fi
26 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/rl.rst:
--------------------------------------------------------------------------------
1 | .. _rl:
2 |
3 | ================================
4 | Reinforcement Learning Resources
5 | ================================
6 |
7 |
8 | Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL).
9 |
10 | However, if you want to learn about RL, there are several good resources to get started:
11 |
12 | - `OpenAI Spinning Up `_
13 | - `David Silver's course `_
14 | - `Lilian Weng's blog `_
15 | - `Berkeley's Deep RL Bootcamp `_
16 | - `Berkeley's Deep Reinforcement Learning course `_
17 | - `More resources `_
18 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=StableBaselines
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/quickstart.rst:
--------------------------------------------------------------------------------
1 | .. _quickstart:
2 |
3 | ===============
4 | Getting Started
5 | ===============
6 |
7 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
8 |
9 | Here is a quick example of how to train and run A2C on a CartPole environment:
10 |
11 | .. code-block:: python
12 |
13 | import gym
14 |
15 | from stable_baselines3 import A2C
16 |
17 | env = gym.make('CartPole-v1')
18 |
19 | model = A2C('MlpPolicy', env, verbose=1)
20 | model.learn(total_timesteps=10000)
21 |
22 | obs = env.reset()
23 | for i in range(1000):
24 | action, _state = model.predict(obs, deterministic=True)
25 | obs, reward, done, info = env.step(action)
26 | env.render()
27 | if done:
28 | obs = env.reset()
29 |
30 |
31 | Or just train a model with a one-liner if
32 | `the environment is registered in Gym `_ and if
33 | the policy is registered:
34 |
35 | .. code-block:: python
36 |
37 | from stable_baselines3 import A2C
38 |
39 | model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
40 |
--------------------------------------------------------------------------------
/stable-baselines3/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2019 Antonin Raffin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/css/baselines_theme.css:
--------------------------------------------------------------------------------
1 | /* Main colors adapted from pytorch doc */
2 | :root{
3 | --main-bg-color: #343A40;
4 | --link-color: #FD7E14;
5 | }
6 |
7 | /* Header fonts */
8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
10 | }
11 |
12 |
13 | /* Docs background */
14 | .wy-side-nav-search{
15 | background-color: var(--main-bg-color);
16 | }
17 |
18 | /* Mobile version */
19 | .wy-nav-top{
20 | background-color: var(--main-bg-color);
21 | }
22 |
23 | /* Change link colors (except for the menu) */
24 | a {
25 | color: var(--link-color);
26 | }
27 |
28 | a:hover {
29 | color: #4F778F;
30 | }
31 |
32 | .wy-menu a {
33 | color: #b3b3b3;
34 | }
35 |
36 | .wy-menu a:hover {
37 | color: #b3b3b3;
38 | }
39 |
40 | a.icon.icon-home {
41 | color: #b3b3b3;
42 | }
43 |
44 | .version{
45 | color: var(--link-color) !important;
46 | }
47 |
48 |
49 | /* Make code blocks have a background */
50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
51 | background: #f8f8f8;
52 | }
53 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/common/distributions.rst:
--------------------------------------------------------------------------------
1 | .. _distributions:
2 |
3 | Probability Distributions
4 | =========================
5 |
6 | Probability distributions used for the different action spaces:
7 |
8 | - ``CategoricalDistribution`` -> Discrete
9 | - ``DiagGaussianDistribution`` -> Box (continuous actions)
10 | - ``StateDependentNoiseDistribution`` -> Box (continuous actions) when ``use_sde=True``
11 |
12 | .. - ``MultiCategoricalDistribution`` -> MultiDiscrete
13 | .. - ``BernoulliDistribution`` -> MultiBinary
14 |
15 | The policy networks output parameters for the distributions (named ``flat`` in the methods).
16 | Actions are then sampled from those distributions.
17 |
18 | For instance, in the case of discrete actions, the policy network outputs the probability
19 | of taking each action. The ``CategoricalDistribution`` allows sampling from it,
20 | computing the entropy and the log probability (``log_prob``), and backpropagating the gradient.
21 |
22 | In the case of continuous actions, a Gaussian distribution is used. The policy network outputs
23 | mean and (log) std of the distribution (assumed to be a ``DiagGaussianDistribution``).
24 |
25 | .. automodule:: stable_baselines3.common.distributions
26 | :members:
27 |
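28 | For illustration, a minimal sketch of the discrete case described above (using ``torch.distributions`` directly rather than the SB3 wrapper classes):
29 |
30 | .. code-block:: python
31 |
32 |     import torch as th
33 |     from torch.distributions import Categorical
34 |
35 |     logits = th.randn(4)               # stand-in for the policy network output (4 discrete actions)
36 |     dist = Categorical(logits=logits)  # categorical distribution over actions
37 |     action = dist.sample()             # sample an action
38 |     log_prob = dist.log_prob(action)   # log-probability used in the policy gradient
39 |     entropy = dist.entropy()           # entropy term used as an exploration bonus
40 |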
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_tensorboard.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from stable_baselines3 import A2C, PPO, SAC, TD3
6 |
7 | MODEL_DICT = {
8 | "a2c": (A2C, "CartPole-v1"),
9 | "ppo": (PPO, "CartPole-v1"),
10 | "sac": (SAC, "Pendulum-v0"),
11 | "td3": (TD3, "Pendulum-v0"),
12 | }
13 |
14 | N_STEPS = 100
15 |
16 |
17 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys())
18 | def test_tensorboard(tmp_path, model_name):
19 | # Skip if no tensorboard installed
20 | pytest.importorskip("tensorboard")
21 |
22 | logname = model_name.upper()
23 | algo, env_id = MODEL_DICT[model_name]
24 | model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
25 | model.learn(N_STEPS)
26 | model.learn(N_STEPS, reset_num_timesteps=False)
27 |
28 | assert os.path.isdir(tmp_path / str(logname + "_1"))
29 | assert not os.path.isdir(tmp_path / str(logname + "_2"))
30 |
31 | logname = "tb_multiple_runs_" + model_name
32 | model.learn(N_STEPS, tb_log_name=logname)
33 | model.learn(N_STEPS, tb_log_name=logname)
34 |
35 | assert os.path.isdir(tmp_path / str(logname + "_1"))
36 | # Check that the log dir name increments correctly
37 | assert os.path.isdir(tmp_path / str(logname + "_2"))
38 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/type_aliases.py:
--------------------------------------------------------------------------------
1 | """Common aliases for type hints"""
2 |
3 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
4 |
5 | import gym
6 | import numpy as np
7 | import torch as th
8 |
9 | from stable_baselines3.common.callbacks import BaseCallback
10 | from stable_baselines3.common.vec_env import VecEnv
11 |
12 | GymEnv = Union[gym.Env, VecEnv]
13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
14 | GymStepReturn = Tuple[GymObs, float, bool, Dict]
15 | TensorDict = Dict[str, th.Tensor]
16 | OptimizerStateDict = Dict[str, Any]
17 | MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
18 |
19 |
20 | class RolloutBufferSamples(NamedTuple):
21 | observations: th.Tensor
22 | actions: th.Tensor
23 | old_values: th.Tensor
24 | old_log_prob: th.Tensor
25 | advantages: th.Tensor
26 | returns: th.Tensor
27 |
28 |
29 | class ReplayBufferSamples(NamedTuple):
30 | observations: th.Tensor
31 | actions: th.Tensor
32 | next_observations: th.Tensor
33 | dones: th.Tensor
34 | rewards: th.Tensor
35 |
36 |
37 | class RolloutReturn(NamedTuple):
38 | episode_reward: float
39 | episode_timesteps: int
40 | n_episodes: int
41 | continue_training: bool
42 |
--------------------------------------------------------------------------------
/my_gym/envs/hanabi_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import error, spaces, utils
3 | from hanabi_learning_environment.rl_env import HanabiEnv
4 |
5 | class HanabiEnvWrapper(HanabiEnv, gym.Env):
6 | metadata = {'render.modes': ['human']}
7 | def __init__(self, config):
8 | self.config = config
9 | super(HanabiEnvWrapper, self).__init__(config=self.config)
10 |
11 | observation_shape = super().vectorized_observation_shape()
12 | self.observation_space = spaces.MultiBinary(observation_shape[0])
13 | self.action_space = spaces.Discrete(self.game.max_moves())
14 |
15 | def reset(self):
16 | obs = super().reset()
17 | obs = obs['player_observations'][obs['current_player']]['vectorized']
18 | return obs
19 |
20 | def step(self, action):
21 | # action is an integer in [0, self.action_space.n)
22 | # we map it to one of the legal moves
23 | # the list of legal moves may be shorter than the action space, so take the action modulo its length
24 | legal_moves = self.state.legal_moves()
25 | move = legal_moves[action % len(legal_moves)].to_dict()
26 |
27 | obs, reward, done, info = super().step(move)
28 | obs = obs['player_observations'][obs['current_player']]['vectorized']
29 | return obs, reward, done, info
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_deterministic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
4 | from stable_baselines3.common.noise import NormalActionNoise
5 |
6 | N_STEPS_TRAINING = 3000
7 | SEED = 0
8 |
9 |
10 | @pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
11 | def test_deterministic_training_common(algo):
12 | results = [[], []]
13 | rewards = [[], []]
14 | # Smaller network
15 | kwargs = {"policy_kwargs": dict(net_arch=[64])}
16 | if algo in [TD3, SAC]:
17 | env_id = "Pendulum-v0"
18 | kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100})
19 | else:
20 | env_id = "CartPole-v1"
21 | if algo == DQN:
22 | kwargs.update({"learning_starts": 100})
23 |
24 | for i in range(2):
25 | model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
26 | model.learn(N_STEPS_TRAINING)
27 | env = model.get_env()
28 | obs = env.reset()
29 | for _ in range(100):
30 | action, _ = model.predict(obs, deterministic=False)
31 | obs, reward, _, _ = env.step(action)
32 | results[i].append(action)
33 | rewards[i].append(reward)
34 | assert sum(results[0]) == sum(results[1]), results
35 | assert sum(rewards[0]) == sum(rewards[1]), rewards
36 |
--------------------------------------------------------------------------------
/my_gym/envs/arms_human_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import error, spaces, utils
3 | from gym.utils import seeding
4 |
5 | import numpy as np
6 | import sys
7 |
8 | class ArmsHumanEnv(gym.Env):
9 | """
10 | Two player game.
11 | There are N (say 4) possible arms to pull.
12 | 0 1 2 3
13 | Some of the arms are forbidden (determined by the state).
14 | The goal is for both players to pull the same (allowed) arm.
15 | """
16 |
17 | def __init__(self):
18 | super(ArmsHumanEnv, self).__init__()
19 | self.n = 3
20 | self.a = 4
21 | self.action_space = spaces.MultiDiscrete([self.a, self.a])
22 | self.observation_space = spaces.MultiDiscrete([self.n])
23 | self.reset()
24 |
25 | def step(self, a):
26 | context = self.state[0]
27 | match = (a[0] == a[1])
28 | green = [  # green[context][arm]: 1 if pulling that arm is rewarded in this context
29 | [1,0,0,0],
30 | [0,0,1,1],
31 | [0,1,0,1],
32 | ]
33 | correct = match and green[context][a[0]]
34 | self.reward = int(correct)
35 | return [self.state, self.reward, True, {}]
36 |
37 | def reset(self, state=None):
38 | if state:
39 | self.state = state
40 | else:
41 | self.state = [np.random.randint(self.n)]
42 |
43 | return self.state
44 |
45 | def render(self):
46 | print(self.n, self.state)
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_custom_policy.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch as th
3 |
4 | from stable_baselines3 import A2C, PPO, SAC, TD3
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "net_arch",
9 | [
10 | [12, dict(vf=[16], pi=[8])],
11 | [4],
12 | [],
13 | [4, 4],
14 | [12, dict(vf=[8, 4], pi=[8])],
15 | [12, dict(vf=[8], pi=[8, 4])],
16 | [12, dict(pi=[8])],
17 | ],
18 | )
19 | @pytest.mark.parametrize("model_class", [A2C, PPO])
20 | def test_flexible_mlp(model_class, net_arch):
21 | _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
22 |
23 |
24 | @pytest.mark.parametrize("net_arch", [[4], [4, 4],])
25 | @pytest.mark.parametrize("model_class", [SAC, TD3])
26 | def test_custom_offpolicy(model_class, net_arch):
27 | _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000)
28 |
29 |
30 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
31 | @pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
32 | def test_custom_optimizer(model_class, optimizer_kwargs):
33 | policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
34 | _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
35 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
7 | from stable_baselines3.common.identity_env import FakeImageEnv
8 |
9 |
10 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
11 | def test_cnn(tmp_path, model_class):
12 | SAVE_NAME = "cnn_model.zip"
13 | # Fake grayscale with frameskip
14 | # Atari after preprocessing: 84x84x1, here we are using lower resolution
15 | # to check that the network handles it automatically
16 | env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
17 | if model_class in {A2C, PPO}:
18 | kwargs = dict(n_steps=100)
19 | else:
20 | # Avoid memory error when using replay buffer
21 | # Reduce the size of the features
22 | kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
23 | model = model_class("CnnPolicy", env, **kwargs).learn(250)
24 |
25 | obs = env.reset()
26 |
27 | action, _ = model.predict(obs, deterministic=True)
28 |
29 | model.save(tmp_path / SAVE_NAME)
30 | del model
31 |
32 | model = model_class.load(tmp_path / SAVE_NAME)
33 |
34 | # Check that the prediction is the same
35 | assert np.allclose(action, model.predict(obs, deterministic=True)[0])
36 |
37 | os.remove(str(tmp_path / SAVE_NAME))
38 |
--------------------------------------------------------------------------------
/stable-baselines3/NOTICE:
--------------------------------------------------------------------------------
1 | Large portions of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines,
2 | both licensed under the MIT License:
3 |
4 | before the fork (June 2018):
5 | Copyright (c) 2017 OpenAI (http://openai.com)
6 |
7 | after the fork (June 2018):
8 | Copyright (c) 2018-2019 Stable-Baselines Team
9 |
10 |
11 | Permission is hereby granted, free of charge, to any person obtaining a copy
12 | of this software and associated documentation files (the "Software"), to deal
13 | in the Software without restriction, including without limitation the rights
14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | copies of the Software, and to permit persons to whom the Software is
16 | furnished to do so, subject to the following conditions:
17 |
18 | The above copyright notice and this permission notice shall be included in
19 | all copies or substantial portions of the Software.
20 |
21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | THE SOFTWARE.
28 |
--------------------------------------------------------------------------------
/stable-baselines3/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PARENT_IMAGE
2 | FROM $PARENT_IMAGE
3 | ARG PYTORCH_DEPS=cpuonly
4 | ARG PYTHON_VERSION=3.6
5 |
6 | RUN apt-get update && apt-get install -y --no-install-recommends \
7 | build-essential \
8 | cmake \
9 | git \
10 | curl \
11 | ca-certificates \
12 | libjpeg-dev \
13 | libpng-dev \
14 | libglib2.0-0 && \
15 | rm -rf /var/lib/apt/lists/*
16 |
17 | # Install Anaconda and dependencies
18 | RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
19 | chmod +x ~/miniconda.sh && \
20 | ~/miniconda.sh -b -p /opt/conda && \
21 | rm ~/miniconda.sh && \
22 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \
23 | /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \
24 | /opt/conda/bin/conda clean -ya
25 | ENV PATH /opt/conda/bin:$PATH
26 |
27 | ENV CODE_DIR /root/code
28 |
29 | # Copy setup file only to install dependencies
30 | COPY ./setup.py ${CODE_DIR}/stable-baselines3/setup.py
31 | COPY ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt
32 |
33 | RUN \
34 | cd ${CODE_DIR}/stable-baselines3 && \
35 | pip install -e .[extra,tests,docs] && \
36 | # Use headless version for docker
37 | pip uninstall -y opencv-python && \
38 | pip install opencv-python-headless && \
39 | rm -rf $HOME/.cache/pip
40 |
41 | CMD /bin/bash
42 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/colab.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/stable-baselines3/Makefile:
--------------------------------------------------------------------------------
1 | SHELL=/bin/bash
2 | LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
3 |
4 | pytest:
5 | ./scripts/run_tests.sh
6 |
7 | type:
8 | pytype
9 |
10 | lint:
11 | # stop the build if there are Python syntax errors or undefined names
12 | # see https://lintlyci.github.io/Flake8Rules/
13 | flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
14 | # exit-zero treats all errors as warnings.
15 | flake8 ${LINT_PATHS} --count --exit-zero --statistics
16 |
17 | format:
18 | # Sort imports
19 | isort ${LINT_PATHS}
20 | # Reformat using black
21 | black -l 127 ${LINT_PATHS}
22 |
23 | check-codestyle:
24 | # Sort imports
25 | isort --check ${LINT_PATHS}
26 | # Reformat using black
27 | black --check -l 127 ${LINT_PATHS}
28 |
29 | commit-checks: format type lint
30 |
31 | doc:
32 | cd docs && make html
33 |
34 | spelling:
35 | cd docs && make spelling
36 |
37 | clean:
38 | cd docs && make clean
39 |
40 | # Build docker images
41 | # If you do export RELEASE=True, it will also push them
42 | docker: docker-cpu docker-gpu
43 |
44 | docker-cpu:
45 | ./scripts/build_docker.sh
46 |
47 | docker-gpu:
48 | USE_GPU=True ./scripts/build_docker.sh
49 |
50 | # PyPi package release
51 | release:
52 | python setup.py sdist
53 | python setup.py bdist_wheel
54 | twine upload dist/*
55 |
56 | # Test PyPi package release
57 | test-release:
58 | python setup.py sdist
59 | python setup.py bdist_wheel
60 | twine upload --repository-url https://test.pypi.org/legacy/ dist/*
61 |
62 | .PHONY: clean spelling doc lint format check-codestyle commit-checks
63 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/misc/projects.rst:
--------------------------------------------------------------------------------
1 | .. _projects:
2 |
3 | Projects
4 | =========
5 |
6 | This is a list of projects using stable-baselines3.
7 | Please tell us if you want your project to appear on this page ;)
8 |
9 |
10 | .. RL Racing Robot
11 | .. --------------------------
12 | .. Implementation of reinforcement learning approach to make a donkey car learn to race.
13 | .. Uses SAC on autoencoder features
14 | ..
15 | .. | Author: Antonin Raffin (@araffin)
16 | .. | Github repo: https://github.com/araffin/RL-Racing-Robot
17 |
18 |
19 | Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
20 | -----------------------------------------------------------------------------------
21 |
22 | An exploration method to train RL agents directly on real robots.
23 | It was the starting point of Stable-Baselines3.
24 |
25 | | Author: Antonin Raffin, Freek Stulp
26 | | Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
27 | | Paper: https://arxiv.org/abs/2005.05719
28 |
29 | Reacher
30 | -------
31 | A solution to the second project of the Udacity deep reinforcement learning course.
32 | It is an example of:
33 |
34 | - wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3
35 | - creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is https://github.com/DLR-RM/rl-baselines3-zoo)
36 | - generating several pre-trained models which solve the reacher environment
37 |
38 | | Author: Marios Koulakis
39 | | Github: https://github.com/koulakis/reacher-deep-reinforcement-learning
40 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/running_mean_std.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class RunningMeanStd(object):
7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
8 | """
9 | Calculates the running mean and std of a data stream
10 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
11 |
12 | :param epsilon: (float) helps with arithmetic issues
13 | :param shape: (tuple) the shape of the data stream's output
14 | """
15 | self.mean = np.zeros(shape, np.float64)
16 | self.var = np.ones(shape, np.float64)
17 | self.count = epsilon
18 |
19 | def update(self, arr: np.ndarray) -> None:
20 | batch_mean = np.mean(arr, axis=0)
21 | batch_var = np.var(arr, axis=0)
22 | batch_count = arr.shape[0]
23 | self.update_from_moments(batch_mean, batch_var, batch_count)
24 |
25 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
26 | delta = batch_mean - self.mean
27 | tot_count = self.count + batch_count
28 |
29 | new_mean = self.mean + delta * batch_count / tot_count
30 | m_a = self.var * self.count
31 | m_b = batch_var * batch_count
32 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
33 | new_var = m_2 / (self.count + batch_count)
34 |
35 | new_count = batch_count + self.count
36 |
37 | self.mean = new_mean
38 | self.var = new_var
39 | self.count = new_count
40 |
--------------------------------------------------------------------------------
/stable-baselines3/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 | # Skip CI if [ci skip] in the commit message
15 | if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
16 | runs-on: ubuntu-latest
17 | strategy:
18 | matrix:
19 | python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
20 |
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | # cpu version of pytorch
31 | pip install torch==1.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
32 | pip install .[extra,tests,docs]
33 | # Use headless version
34 | pip install opencv-python-headless
35 | - name: Build the doc
36 | run: |
37 | make doc
38 | - name: Type check
39 | run: |
40 | make type
41 | - name: Check codestyle
42 | run: |
43 | make check-codestyle
44 | - name: Lint with flake8
45 | run: |
46 | make lint
47 | - name: Test with pytest
48 | run: |
49 | make pytest
50 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_predict.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import pytest
3 |
4 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
5 | from stable_baselines3.common.vec_env import DummyVecEnv
6 |
7 | MODEL_LIST = [
8 | PPO,
9 | A2C,
10 | TD3,
11 | SAC,
12 | DQN,
13 | ]
14 |
15 |
16 | @pytest.mark.parametrize("model_class", MODEL_LIST)
17 | def test_auto_wrap(model_class):
18 | # test auto wrapping of env into a VecEnv
19 |
20 | # Use different environment for DQN
21 | if model_class is DQN:
22 | env_name = "CartPole-v0"
23 | else:
24 | env_name = "Pendulum-v0"
25 | env = gym.make(env_name)
26 | eval_env = gym.make(env_name)
27 | model = model_class("MlpPolicy", env)
28 | model.learn(100, eval_env=eval_env)
29 |
30 |
31 | @pytest.mark.parametrize("model_class", MODEL_LIST)
32 | @pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
33 | def test_predict(model_class, env_id):
34 | if env_id == "CartPole-v1":
35 | if model_class in [SAC, TD3]:
36 | return
37 | elif model_class in [DQN]:
38 | return
39 |
40 | # test detection of different shapes by the predict method
41 | model = model_class("MlpPolicy", env_id)
42 | env = gym.make(env_id)
43 | vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
44 |
45 | obs = env.reset()
46 | action, _ = model.predict(obs)
47 | assert action.shape == env.action_space.shape
48 | assert env.action_space.contains(action)
49 |
50 | vec_env_obs = vec_env.reset()
51 | action, _ = model.predict(vec_env_obs)
52 | assert action.shape[0] == vec_env_obs.shape[0]
53 |
--------------------------------------------------------------------------------
/tabular.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | def tabular_q_learning(env, discount_rate=0.01, step_size=0.01, eps=1e-4):
5 | action_space = env.action_space.nvec # multi-discrete
6 | state_space = env.observation_space.nvec # multi-discrete
7 | q_values = np.zeros( np.concatenate((state_space, action_space)) , np.float64)
8 |
9 | for _ in range(1000):
10 | for state_action, q_val in np.ndenumerate(q_values):
11 | state, action = state_action[:len(state_space)], state_action[len(state_space):]
12 | #print("state, action: ", state, action)
13 |
14 | env.reset(list(state))
15 | next_state, reward, done, _ = env.step(action)
16 |
17 | delta = reward - q_values[state_action]
18 | if not done:
19 | next_state = tuple(next_state)
20 | delta = reward + discount_rate * q_values[next_state].max() - q_values[state_action]
21 | q_values[state_action] += step_size * delta
22 |
23 | q_values = th.tensor(q_values)
24 |
25 | maxout_actions = th.max(th.max(q_values, dim=-2, keepdim=True).values, dim=-1, keepdim=True).values
26 | maxout_action2 = th.max(q_values, dim=-1, keepdim=True).values
27 | maxout_action1 = th.max(q_values, dim=-2, keepdim=True).values
28 |
29 | optimal_action1_mask = (maxout_action2 - maxout_actions > -eps).squeeze(-1)
30 | optimal_action2_mask = (maxout_action1 - maxout_actions > -eps).squeeze(-2)
31 |
32 | # print(q_values)
33 | print(optimal_action1_mask)
34 | print(optimal_action2_mask)
35 | return q_values, optimal_action1_mask, optimal_action2_mask
36 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import pytest
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
7 |
8 |
9 | class NanAndInfEnv(gym.Env):
10 | """Custom Environment that raises NaNs and Infs"""
11 |
12 | metadata = {"render.modes": ["human"]}
13 |
14 | def __init__(self):
15 | super(NanAndInfEnv, self).__init__()
16 | self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
17 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
18 |
19 | @staticmethod
20 | def step(action):
21 | if np.all(np.array(action) > 0):
22 | obs = float("NaN")
23 | elif np.all(np.array(action) < 0):
24 | obs = float("inf")
25 | else:
26 | obs = 0
27 | return [obs], 0.0, False, {}
28 |
29 | @staticmethod
30 | def reset():
31 | return [0.0]
32 |
33 | def render(self, mode="human", close=False):
34 | pass
35 |
36 |
37 | def test_check_nan():
38 | """Test VecCheckNan Object"""
39 |
40 | env = DummyVecEnv([NanAndInfEnv])
41 | env = VecCheckNan(env, raise_exception=True)
42 |
43 | env.step([[0]])
44 |
45 | with pytest.raises(ValueError):
46 | env.step([[float("NaN")]])
47 |
48 | with pytest.raises(ValueError):
49 | env.step([[float("inf")]])
50 |
51 | with pytest.raises(ValueError):
52 | env.step([[-1]])
53 |
54 | with pytest.raises(ValueError):
55 | env.step([[1]])
56 |
57 | env.step(np.array([[0, 1], [0, 1]]))
58 |
59 | env.reset()
60 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/spelling_wordlist.txt:
--------------------------------------------------------------------------------
1 | py
2 | env
3 | atari
4 | argparse
5 | Argparse
6 | TensorFlow
7 | feedforward
8 | envs
9 | VecEnv
10 | pretrain
11 | petrained
12 | tf
13 | th
14 | nn
15 | np
16 | str
17 | mujoco
18 | cpu
19 | ndarray
20 | ndarrays
21 | timestep
22 | timesteps
23 | stepsize
24 | dataset
25 | adam
26 | fn
27 | normalisation
28 | Kullback
29 | Leibler
30 | boolean
31 | deserialized
32 | pretrained
33 | minibatch
34 | subprocesses
35 | ArgumentParser
36 | Tensorflow
37 | Gaussian
38 | approximator
39 | minibatches
40 | hyperparameters
41 | hyperparameter
42 | vectorized
43 | rl
44 | colab
45 | dataloader
46 | npz
47 | datasets
48 | vf
49 | logits
50 | num
51 | Utils
52 | backpropagate
53 | prepend
54 | NaN
55 | preprocessing
56 | Cloudpickle
57 | async
58 | multiprocess
59 | tensorflow
60 | mlp
61 | cnn
62 | neglogp
63 | tanh
64 | coef
65 | repo
66 | Huber
67 | params
68 | ppo
69 | arxiv
70 | Arxiv
71 | func
72 | DQN
73 | Uhlenbeck
74 | Ornstein
75 | multithread
76 | cancelled
77 | Tensorboard
78 | parallelize
79 | customising
80 | serializable
81 | Multiprocessed
82 | cartpole
83 | toolset
84 | lstm
85 | rescale
86 | ffmpeg
87 | avconv
88 | unnormalized
89 | Github
90 | pre
91 | preprocess
92 | backend
93 | attr
94 | preprocess
95 | Antonin
96 | Raffin
97 | araffin
98 | Homebrew
99 | Numpy
100 | Theano
101 | rollout
102 | kfac
103 | Piecewise
104 | csv
105 | nvidia
106 | visdom
107 | tensorboard
108 | preprocessed
109 | namespace
110 | sklearn
111 | GoalEnv
112 | Torchy
113 | pytorch
114 | dicts
115 | optimizers
116 | Deprecations
117 | forkserver
118 | cuda
119 | Polyak
120 | gSDE
121 | rollouts
122 |
--------------------------------------------------------------------------------
/my_gym/envs/arms_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import error, spaces, utils
3 | from gym.utils import seeding
4 |
5 | import numpy as np
6 | import sys
7 |
8 | class ArmsEnv(gym.Env):
9 | """
10 | Two player game.
11 | There are N (say 4) possible arms to pull.
12 | 0 1 2 3
13 | Some of the arms are forbidden (determined by the state)
14 | Goal is to pull same arm.
15 | """
16 |
17 | def __init__(self, n, m):
18 | super(ArmsEnv, self).__init__()
19 | self.n = n
20 | self.m = m # number of contexts with hard rules
21 | self.action_space = spaces.MultiDiscrete([2*n, 2*n])
22 | self.observation_space = spaces.MultiDiscrete([n])
23 | self.reset()
24 |
25 | self.invert = False
26 |
27 | def step(self, a):
28 | context = self.state[0]
29 | match = (a[0] == a[1])
30 |
31 | if context < self.m:
32 | if not self.invert:
33 | correct = match and a[0] == context
34 | else:
35 | correct = match and a[0] - self.n == context
36 | else:
37 | correct = match and a[0]%self.n == context # action mod n equals context
38 |
39 | self.reward = int(correct)
40 | return [self.state, self.reward, True, {}]
41 |
42 | def reset(self, state=None):
43 | if not hasattr(self, 'last'): self.last = 0
44 |
45 | self.rep = 0
46 | if state:
47 | self.state = state
48 | else:
49 | self.state = [np.random.randint(self.n)]
50 |
51 | return self.state
52 |
53 | def render(self):
54 | print(self.n, self.state)
55 |
56 | def set_invert(self, invert):
57 | self.invert = invert
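# Minimal smoke test (illustrative sketch, assuming the package is importable):
# pulling the arm named by the context is rewarded when invert is False.
if __name__ == "__main__":
    env = ArmsEnv(n=4, m=2)
    context = env.reset()[0]
    _, reward, done, _ = env.step((context, context))
    print(context, reward, done)  # reward == 1, done is True (one-step episodes)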
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/a2c.rst:
--------------------------------------------------------------------------------
1 | .. _a2c:
2 |
3 | .. automodule:: stable_baselines3.a2c
4 |
5 |
6 | A2C
7 | ====
8 |
9 | A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3C) <https://arxiv.org/abs/1602.01783>`_.
10 | It uses multiple workers to avoid the use of a replay buffer.
11 |
12 |
13 | Notes
14 | -----
15 |
16 | - Original paper: https://arxiv.org/abs/1602.01783
17 | - OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/
18 |
19 |
20 | Can I use?
21 | ----------
22 |
23 | - Recurrent policies: ✔️
24 | - Multi processing: ✔️
25 | - Gym spaces:
26 |
27 |
28 | ============= ====== ===========
29 | Space Action Observation
30 | ============= ====== ===========
31 | Discrete ✔️ ✔️
32 | Box ✔️ ✔️
33 | MultiDiscrete ✔️ ✔️
34 | MultiBinary ✔️ ✔️
35 | ============= ====== ===========
36 |
37 |
38 | Example
39 | -------
40 |
41 | Train an A2C agent on ``CartPole-v1`` using 4 environments.
42 |
43 | .. code-block:: python
44 |
45 | import gym
46 |
47 | from stable_baselines3 import A2C
48 | from stable_baselines3.a2c import MlpPolicy
49 | from stable_baselines3.common.cmd_util import make_vec_env
50 |
51 | # Parallel environments
52 | env = make_vec_env('CartPole-v1', n_envs=4)
53 |
54 | model = A2C(MlpPolicy, env, verbose=1)
55 | model.learn(total_timesteps=25000)
56 | model.save("a2c_cartpole")
57 |
58 | del model # remove to demonstrate saving and loading
59 |
60 | model = A2C.load("a2c_cartpole")
61 |
62 | obs = env.reset()
63 | while True:
64 | action, _states = model.predict(obs)
65 | obs, rewards, dones, info = env.step(action)
66 | env.render()
67 |
68 | Parameters
69 | ----------
70 |
71 | .. autoclass:: A2C
72 | :members:
73 | :inherited-members:
74 |
--------------------------------------------------------------------------------
/stable-baselines3/.github/ISSUE_TEMPLATE/issue-template.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Issue Template
3 | about: How to create an issue for this repository
4 |
5 | ---
6 |
7 | **Important Note: We do not offer technical support or consulting** and do not answer personal questions by email.
8 | Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case.
9 |
10 | If you have any questions, feel free to create an issue with the tag [question].
11 | If you wish to suggest an enhancement or feature request, add the tag [feature request].
12 | If you are submitting a bug report, please fill in the following details.
13 |
14 | If your issue is related to a custom gym environment, please check it first using:
15 |
16 | ```python
17 | from stable_baselines3.common.env_checker import check_env
18 |
19 | env = CustomEnv(arg1, ...)
20 | # It will check your custom environment and output additional warnings if needed
21 | check_env(env)
22 | ```
23 |
24 | **Describe the bug**
25 | A clear and concise description of what the bug is.
26 |
27 | **Code example**
28 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
29 |
30 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
31 | for both code and stack traces.
32 |
33 | ```python
34 | from stable_baselines3 import ...
35 |
36 | ```
37 |
38 | ```bash
39 | Traceback (most recent call last): File ...
40 |
41 | ```
42 |
43 | **System Info**
44 | Describe the characteristic of your environment:
45 | * Describe how the library was installed (pip, docker, source, ...)
46 | * GPU models and configuration
47 | * Python version
48 | * PyTorch version
49 | * Gym version
50 | * Versions of any other relevant libraries
51 |
52 | **Additional context**
53 | Add any other context about the problem here.
54 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_spaces.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import pytest
4 |
5 | from stable_baselines3 import DQN, SAC, TD3
6 | from stable_baselines3.common.evaluation import evaluate_policy
7 |
8 |
9 | class DummyMultiDiscreteSpace(gym.Env):
10 | def __init__(self, nvec):
11 | super(DummyMultiDiscreteSpace, self).__init__()
12 | self.observation_space = gym.spaces.MultiDiscrete(nvec)
13 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
14 |
15 | def reset(self):
16 | return self.observation_space.sample()
17 |
18 | def step(self, action):
19 | return self.observation_space.sample(), 0.0, False, {}
20 |
21 |
22 | class DummyMultiBinary(gym.Env):
23 | def __init__(self, n):
24 | super(DummyMultiBinary, self).__init__()
25 | self.observation_space = gym.spaces.MultiBinary(n)
26 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
27 |
28 | def reset(self):
29 | return self.observation_space.sample()
30 |
31 | def step(self, action):
32 | return self.observation_space.sample(), 0.0, False, {}
33 |
34 |
35 | @pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
36 | @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
37 | def test_identity_spaces(model_class, env):
38 | """
39 | Additional tests for DQN/SAC/TD3 to check observation space support
40 | for MultiDiscrete and MultiBinary.
41 | """
42 | # DQN only supports discrete actions
43 | if model_class == DQN:
44 | env.action_space = gym.spaces.Discrete(4)
45 |
46 | env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
47 |
48 | model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
49 | model.learn(total_timesteps=500)
50 |
51 | evaluate_policy(model, env, n_eval_episodes=5)
52 |
--------------------------------------------------------------------------------
/bashfiles/hanabi_adapt_to_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 |
8 | if (( ${thread} == 1 )); then python run_hanabi.py --run=$(($runid + 10)) --mreg=0.0; fi
9 | if (( ${thread} == 2 )); then python run_hanabi.py --run=$(($runid + 20)) --mreg=0.3; fi
10 | if (( ${thread} == 3 )); then python run_hanabi.py --run=$(($runid + 30)) --mreg=0.5; fi
11 |
12 | if (( ${thread} == 4 )); then python run_hanabi.py --run=$(($runid + 40)) --baseline; fi
13 | if (( ${thread} == 5 )); then python run_hanabi.py --run=$(($runid + 50)) --baseline --timesteps=250000; fi
14 | if (( ${thread} == 6 )); then python run_hanabi.py --run=$(($runid + 60)) --nomain; fi
15 | if (( ${thread} == 7 )); then python run_hanabi.py --run=$(($runid + 70)) --mreg=0.5 --latentz=50; fi
16 |
17 | for ((i=0;i<=3;i++))
18 | do
19 | if (( ${thread} == 1 )); then python run_hanabi.py --run=$(($runid + 10)) --mreg=0.0 --k=$i --testing | tee -a logs/hanabi$(($runid + 10 + $i)).txt; fi
20 | if (( ${thread} == 2 )); then python run_hanabi.py --run=$(($runid + 20)) --mreg=0.3 --k=$i --testing | tee -a logs/hanabi$(($runid + 20 + $i)).txt; fi
21 | if (( ${thread} == 3 )); then python run_hanabi.py --run=$(($runid + 30)) --mreg=0.5 --k=$i --testing | tee -a logs/hanabi$(($runid + 30 + $i)).txt; fi
22 |
23 | if (( ${thread} == 4 )); then python run_hanabi.py --run=$(($runid + 40)) --baseline --k=$i --testing | tee -a logs/hanabi$(($runid + 40 + $i)).txt; fi
24 | if (( ${thread} == 5 )); then python run_hanabi.py --run=$(($runid + 50)) --baseline --timesteps=250000 --k=$i --testing | tee -a logs/hanabi$(($runid + 50 + $i)).txt; fi
25 | if (( ${thread} == 6 )); then python run_hanabi.py --run=$(($runid + 60)) --nomain --k=$i --testing | tee -a logs/hanabi$(($runid + 60 + $i)).txt; fi
26 | if (( ${thread} == 7 )); then python run_hanabi.py --run=$(($runid + 70)) --mreg=0.5 --latentz=50 --k=$i --testing | tee -a logs/hanabi$(($runid + 70 + $i)).txt; fi
27 | done
--------------------------------------------------------------------------------
/stable-baselines3/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | # This includes the license file in the wheel.
3 | license_file = LICENSE
4 |
5 | [tool:pytest]
6 | # Deterministic ordering for tests; useful for pytest-xdist.
7 | env =
8 | PYTHONHASHSEED=0
9 | filterwarnings =
10 | # Tensorboard/Tensorflow warnings
11 | ignore:inspect.getargspec:DeprecationWarning:tensorflow
12 | ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
13 | ignore:The binary mode of fromstring is deprecated:DeprecationWarning
14 | ignore::FutureWarning:tensorflow
15 | # Gym warnings
16 | ignore:Parameters to load are deprecated.:DeprecationWarning
17 | ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
18 | ignore::UserWarning:gym
19 |
20 | [pytype]
21 | inputs = stable_baselines3
22 |
23 | [flake8]
24 | ignore = W503,W504,E203,E231 # line breaks before and after binary operators
25 | # Ignore import not used when aliases are defined
26 | per-file-ignores =
27 | ./stable_baselines3/__init__.py:F401
28 | ./stable_baselines3/common/__init__.py:F401
29 | ./stable_baselines3/a2c/__init__.py:F401
30 | ./stable_baselines3/ddpg/__init__.py:F401
31 | ./stable_baselines3/dqn/__init__.py:F401
32 | ./stable_baselines3/ppo/__init__.py:F401
33 | ./stable_baselines3/sac/__init__.py:F401
34 | ./stable_baselines3/td3/__init__.py:F401
35 | ./stable_baselines3/common/vec_env/__init__.py:F401
36 | exclude =
37 | # No need to traverse our git directory
38 | .git,
39 | # There's no value in checking cache directories
40 | __pycache__,
41 | # Don't check the doc
42 | docs/
43 | # This contains our built documentation
44 | build,
45 | # This contains builds of flake8 that we don't want to check
46 | dist
47 | *.egg-info
48 | max-complexity = 15
49 | # The GitHub editor is 127 chars wide
50 | max-line-length = 127
51 |
52 | [isort]
53 | profile = black
54 | line_length = 127
55 | src_paths = stable_baselines3
56 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_identity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
5 | from stable_baselines3.common.evaluation import evaluate_policy
6 | from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
7 | from stable_baselines3.common.noise import NormalActionNoise
8 | from stable_baselines3.common.vec_env import DummyVecEnv
9 |
10 | DIM = 4
11 |
12 |
13 | @pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
14 | @pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
15 | def test_discrete(model_class, env):
16 | env_ = DummyVecEnv([lambda: env])
17 | kwargs = {}
18 | n_steps = 3000
19 | if model_class == DQN:
20 | kwargs = dict(learning_starts=0)
21 | n_steps = 4000
22 | # DQN only supports discrete actions
23 | if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
24 | return
25 |
26 | model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)
27 |
28 | evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)
29 | obs = env.reset()
30 |
31 | assert np.shape(model.predict(obs)[0]) == np.shape(obs)
32 |
33 |
34 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3])
35 | def test_continuous(model_class):
36 | env = IdentityEnvBox(eps=0.5)
37 |
38 | n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class]
39 |
40 | kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95)
41 | if model_class in [TD3]:
42 | n_actions = 1
43 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
44 | kwargs["action_noise"] = action_noise
45 |
46 | model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
47 |
48 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
49 |
--------------------------------------------------------------------------------
/bashfiles/arms_human_adapt_to_fixed.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 |
8 | if (( ${thread} == 1 )); then python run_arms_human.py --run=$(($runid + 10)) --mreg=0.0; fi
9 | if (( ${thread} == 2 )); then python run_arms_human.py --run=$(($runid + 20)) --mreg=0.3; fi
10 | if (( ${thread} == 3 )); then python run_arms_human.py --run=$(($runid + 30)) --mreg=0.5; fi
11 |
12 | if (( ${thread} == 4 )); then python run_arms_human.py --run=$(($runid + 40)) --baseline; fi
13 | if (( ${thread} == 5 )); then python run_arms_human.py --run=$(($runid + 50)) --baseline --timesteps=6000; fi
14 | if (( ${thread} == 6 )); then python run_arms_human.py --run=$(($runid + 60)) --nomain; fi
15 | if (( ${thread} == 7 )); then python run_arms_human.py --run=$(($runid + 70)) --mreg=0.5 --latentz=5; fi
16 |
17 | for ((i=0;i<=9;i++))
18 | do
19 | if (( ${thread} == 1 )); then python run_arms_human.py --run=$(($runid + 10)) --mreg=0.0 --k=$i --testing | tee -a logs/arms_human_$(($runid + 10 + $i)).txt; fi
20 | if (( ${thread} == 2 )); then python run_arms_human.py --run=$(($runid + 20)) --mreg=0.3 --k=$i --testing | tee -a logs/arms_human_$(($runid + 20 + $i)).txt; fi
21 | if (( ${thread} == 3 )); then python run_arms_human.py --run=$(($runid + 30)) --mreg=0.5 --k=$i --testing | tee -a logs/arms_human_$(($runid + 30 + $i)).txt; fi
22 |
23 | if (( ${thread} == 4 )); then python run_arms_human.py --run=$(($runid + 40)) --baseline --k=$i --testing | tee -a logs/arms_human_$(($runid + 40 + $i)).txt; fi
24 | if (( ${thread} == 5 )); then python run_arms_human.py --run=$(($runid + 50)) --baseline --timesteps=6000 --k=$i --testing | tee -a logs/arms_human_$(($runid + 50 + $i)).txt; fi
25 | if (( ${thread} == 6 )); then python run_arms_human.py --run=$(($runid + 60)) --nomain --k=$i --testing | tee -a logs/arms_human_$(($runid + 60 + $i)).txt; fi
26 | if (( ${thread} == 7 )); then python run_arms_human.py --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --k=$i --testing | tee -a logs/arms_human_$(($runid + 70 + $i)).txt; fi
27 | done
--------------------------------------------------------------------------------
/stable-baselines3/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Description
4 |
5 |
6 | ## Motivation and Context
7 |
8 |
9 |
10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
11 |
12 | ## Types of changes
13 |
14 | - [ ] Bug fix (non-breaking change which fixes an issue)
15 | - [ ] New feature (non-breaking change which adds functionality)
16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change)
17 | - [ ] Documentation (update in the documentation)
18 |
19 | ## Checklist:
20 |
21 |
22 | - [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
23 | - [ ] I have updated the changelog accordingly (**required**).
24 | - [ ] My change requires a change to the documentation.
25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
26 | - [ ] I have updated the documentation accordingly.
27 | - [ ] I have reformatted the code using `make format` (**required**)
28 | - [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
29 | - [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
30 |
31 |
32 | Note: we are using a maximum length of 127 characters per line
33 |
34 |
35 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa F401
2 | import typing
3 | from copy import deepcopy
4 | from typing import Optional, Union
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import (
7 | AlreadySteppingError,
8 | CloudpickleWrapper,
9 | NotSteppingError,
10 | VecEnv,
11 | VecEnvWrapper,
12 | )
13 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
14 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
15 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
16 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
17 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
18 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
19 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
20 |
21 | # Avoid circular import
22 | if typing.TYPE_CHECKING:
23 | from stable_baselines3.common.type_aliases import GymEnv
24 |
25 |
26 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
27 | """
28 | :param env: (gym.Env)
29 | :return: (VecNormalize)
30 | """
31 | env_tmp = env
32 | while isinstance(env_tmp, VecEnvWrapper):
33 | if isinstance(env_tmp, VecNormalize):
34 | return env_tmp
35 | env_tmp = env_tmp.venv
36 | return None
37 |
38 |
39 | # Define here to avoid circular import
40 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
41 | """
42 | Sync eval env and train env when using VecNormalize
43 |
44 | :param env: (GymEnv)
45 | :param eval_env: (GymEnv)
46 | """
47 | env_tmp, eval_env_tmp = env, eval_env
48 | while isinstance(env_tmp, VecEnvWrapper):
49 | if isinstance(env_tmp, VecNormalize):
50 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
51 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
52 | env_tmp = env_tmp.venv
53 | eval_env_tmp = eval_env_tmp.venv
54 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/dqn.rst:
--------------------------------------------------------------------------------
1 | .. _dqn:
2 |
3 | .. automodule:: stable_baselines3.dqn
4 |
5 |
6 | DQN
7 | ===
8 |
9 | `Deep Q Network (DQN) <https://arxiv.org/abs/1312.5602>`_
10 |
11 | .. rubric:: Available Policies
12 |
13 | .. autosummary::
14 | :nosignatures:
15 |
16 | MlpPolicy
17 | CnnPolicy
18 |
19 |
20 | Notes
21 | -----
22 |
23 | - Original paper: https://arxiv.org/abs/1312.5602
24 | - Further reference: https://www.nature.com/articles/nature14236
25 |
26 | .. note::
27 | This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
28 |
29 |
30 | Can I use?
31 | ----------
32 |
33 | - Recurrent policies: ❌
34 | - Multi processing: ❌
35 | - Gym spaces:
36 |
37 |
38 | ============= ====== ===========
39 | Space Action Observation
40 | ============= ====== ===========
41 | Discrete ✔ ✔
42 | Box ❌ ✔
43 | MultiDiscrete ❌ ✔
44 | MultiBinary ❌ ✔
45 | ============= ====== ===========
46 |
47 |
48 | Example
49 | -------
50 |
51 | .. code-block:: python
52 |
53 | import gym
54 | import numpy as np
55 |
56 | from stable_baselines3 import DQN
57 | from stable_baselines3.dqn import MlpPolicy
58 |
59 | env = gym.make('CartPole-v0')
60 |
61 | model = DQN(MlpPolicy, env, verbose=1)
62 | model.learn(total_timesteps=10000, log_interval=4)
63 | model.save("dqn_cartpole")
64 |
65 | del model # remove to demonstrate saving and loading
66 |
67 | model = DQN.load("dqn_cartpole")
68 |
69 | obs = env.reset()
70 | while True:
71 | action, _states = model.predict(obs, deterministic=True)
72 | obs, reward, done, info = env.step(action)
73 | env.render()
74 | if done:
75 | obs = env.reset()
76 |
77 | Parameters
78 | ----------
79 |
80 | .. autoclass:: DQN
81 | :members:
82 | :inherited-members:
83 |
84 | .. _dqn_policies:
85 |
86 | DQN Policies
87 | -------------
88 |
89 | .. autoclass:: MlpPolicy
90 | :members:
91 | :inherited-members:
92 |
93 | .. autoclass:: CnnPolicy
94 | :members:
95 |
--------------------------------------------------------------------------------
/bashfiles/blocks_adapt_to_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 |
8 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --ppopartners; fi
9 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --ppopartners; fi
10 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --ppopartners; fi
11 |
12 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --ppopartners; fi
13 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=1000000 --ppopartners; fi
14 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --ppopartners; fi
15 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --ppopartners; fi
16 |
17 | for ((i=0;i<=7;i++))
18 | do
19 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 10 + $i)).txt; fi
20 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 20 + $i)).txt; fi
21 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 30 + $i)).txt; fi
22 |
23 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 40 + $i)).txt; fi
24 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=1000000 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 50 + $i)).txt; fi
25 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 60 + $i)).txt; fi
26 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 70 + $i)).txt; fi
27 | done
28 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import gym
5 | import pytest
6 |
7 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
8 | from stable_baselines3.common.callbacks import (
9 | CallbackList,
10 | CheckpointCallback,
11 | EvalCallback,
12 | EveryNTimesteps,
13 | StopTrainingOnRewardThreshold,
14 | )
15 |
16 |
17 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
18 | def test_callbacks(tmp_path, model_class):
19 | log_folder = tmp_path / "logs/callbacks/"
20 |
21 | # DQN only supports discrete actions
22 | env_name = select_env(model_class)
23 | # Create RL model
24 | # Small network for fast test
25 | model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32]))
26 |
27 | checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
28 |
29 | eval_env = gym.make(env_name)
30 | # Stop training if the performance is good enough
31 | callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
32 |
33 | eval_callback = EvalCallback(
34 | eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
35 | )
36 |
37 | # Equivalent to the `checkpoint_callback`
38 | # but here in an event-driven manner
39 | checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
40 | event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
41 |
42 | callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
43 |
44 | model.learn(500, callback=callback)
45 | model.learn(500, callback=None)
46 | # Transform callback into a callback list automatically
47 | model.learn(500, callback=[checkpoint_callback, eval_callback])
48 | # Automatic wrapping, old way of doing callbacks
49 | model.learn(500, callback=lambda _locals, _globals: True)
50 | if os.path.exists(log_folder):
51 | shutil.rmtree(log_folder)
52 |
53 |
54 | def select_env(model_class) -> str:
55 | if model_class is DQN:
56 | return "CartPole-v0"
57 | else:
58 | return "Pendulum-v0"
59 |
--------------------------------------------------------------------------------
/bashfiles/arms_adapt_to_selfplay.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 | m=$3
8 |
9 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --ppopartners; fi
10 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --ppopartners; fi
11 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --ppopartners; fi
12 |
13 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --ppopartners; fi
14 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --ppopartners; fi
15 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --ppopartners; fi
16 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --ppopartners; fi
17 |
18 | for ((i=0;i<=9;i++))
19 | do
20 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 10 + $i)).txt; fi
21 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 20 + $i)).txt; fi
22 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 30 + $i)).txt; fi
23 |
24 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 40 + $i)).txt; fi
25 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 50 + $i)).txt; fi
26 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 60 + $i)).txt; fi
27 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 70 + $i)).txt; fi
28 | done
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7 |
8 |
9 | class VecFrameStack(VecEnvWrapper):
10 | """
11 | Frame stacking wrapper for vectorized environment
12 |
13 | :param venv: the vectorized environment to wrap
14 | :param n_stack: Number of frames to stack
15 | """
16 |
17 | def __init__(self, venv: VecEnv, n_stack: int):
18 | self.venv = venv
19 | self.n_stack = n_stack
20 | wrapped_obs_space = venv.observation_space
21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
26 |
27 | def step_wait(self):
28 | observations, rewards, dones, infos = self.venv.step_wait()
29 | last_ax_size = observations.shape[-1]
30 | self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
31 | for i, done in enumerate(dones):
32 | if done:
33 | if "terminal_observation" in infos[i]:
34 | old_terminal = infos[i]["terminal_observation"]
35 | new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
36 | infos[i]["terminal_observation"] = new_terminal
37 | else:
38 | warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
39 | self.stackedobs[i] = 0
40 | self.stackedobs[..., -observations.shape[-1] :] = observations
41 | return self.stackedobs, rewards, dones, infos
42 |
43 | def reset(self):
44 | """
45 | Reset all environments
46 | """
47 | obs = self.venv.reset()
48 | self.stackedobs[...] = 0
49 | self.stackedobs[..., -obs.shape[-1] :] = obs
50 | return self.stackedobs
51 |
52 | def close(self):
53 | self.venv.close()
54 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/ppo.rst:
--------------------------------------------------------------------------------
1 | .. _ppo2:
2 |
3 | .. automodule:: stable_baselines3.ppo
4 |
5 | PPO
6 | ===
7 |
8 | The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
9 | and TRPO (it uses a trust region to improve the actor).
10 |
11 | The main idea is that after an update, the new policy should not be too far from the old policy.
12 | For that, PPO uses clipping to avoid too large an update.
13 |
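For reference, the clipped surrogate objective from the original paper (see the references below) is

.. math::

    L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\hat{A}_t\right)\right],
    \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

where :math:`\hat{A}_t` is the advantage estimate and :math:`\epsilon` is the clipping range (the ``clip_range`` parameter).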
14 |
15 | .. note::
16 |
17 | PPO contains several modifications from the original algorithm not documented
18 | by OpenAI: advantages are normalized and the value function can also be clipped.
19 |
20 |
21 | Notes
22 | -----
23 |
24 | - Original paper: https://arxiv.org/abs/1707.06347
25 | - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
26 | - OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
27 | - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html
28 |
29 |
30 | Can I use?
31 | ----------
32 |
33 | - Recurrent policies: ❌
34 | - Multi processing: ✔️
35 | - Gym spaces:
36 |
37 |
38 | ============= ====== ===========
39 | Space Action Observation
40 | ============= ====== ===========
41 | Discrete ✔️ ✔️
42 | Box ✔️ ✔️
43 | MultiDiscrete ✔️ ✔️
44 | MultiBinary ✔️ ✔️
45 | ============= ====== ===========
46 |
47 | Example
48 | -------
49 |
50 | Train a PPO agent on ``CartPole-v1`` using 4 environments.
51 |
52 | .. code-block:: python
53 |
54 | import gym
55 |
56 | from stable_baselines3 import PPO
57 | from stable_baselines3.ppo import MlpPolicy
58 | from stable_baselines3.common.cmd_util import make_vec_env
59 |
60 | # Parallel environments
61 | env = make_vec_env('CartPole-v1', n_envs=4)
62 |
63 | model = PPO(MlpPolicy, env, verbose=1)
64 | model.learn(total_timesteps=25000)
65 | model.save("ppo_cartpole")
66 |
67 | del model # remove to demonstrate saving and loading
68 |
69 | model = PPO.load("ppo_cartpole")
70 |
71 | obs = env.reset()
72 | while True:
73 | action, _states = model.predict(obs)
74 | obs, rewards, dones, info = env.step(action)
75 | env.render()
76 |
77 | Parameters
78 | ----------
79 |
80 | .. autoclass:: PPO
81 | :members:
82 | :inherited-members:
83 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/_static/img/colab-badge.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/vec_transpose.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.preprocessing import is_image_space
7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
8 |
9 | if typing.TYPE_CHECKING:
10 | from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401
11 |
12 |
13 | class VecTransposeImage(VecEnvWrapper):
14 | """
15 | Re-order channels, from HxWxC to CxHxW.
16 | It is required for PyTorch convolution layers.
17 |
18 | :param venv: (VecEnv)
19 | """
20 |
21 | def __init__(self, venv: VecEnv):
22 | assert is_image_space(venv.observation_space), "The observation space must be an image"
23 |
24 | observation_space = self.transpose_space(venv.observation_space)
25 | super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
26 |
27 | @staticmethod
28 | def transpose_space(observation_space: spaces.Box) -> spaces.Box:
29 | """
30 | Transpose an observation space (re-order channels).
31 |
32 | :param observation_space: (spaces.Box)
33 | :return: (spaces.Box)
34 | """
35 | assert is_image_space(observation_space), "The observation space must be an image"
36 | width, height, channels = observation_space.shape
37 | new_shape = (channels, width, height)
38 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype)
39 |
40 | @staticmethod
41 | def transpose_image(image: np.ndarray) -> np.ndarray:
42 | """
43 | Transpose an image or batch of images (re-order channels).
44 |
45 | :param image: (np.ndarray)
46 | :return: (np.ndarray)
47 | """
48 | if len(image.shape) == 3:
49 | return np.transpose(image, (2, 0, 1))
50 | return np.transpose(image, (0, 3, 1, 2))
51 |
52 | def step_wait(self) -> "GymStepReturn":
53 | observations, rewards, dones, infos = self.venv.step_wait()
54 | return self.transpose_image(observations), rewards, dones, infos
55 |
56 | def reset(self) -> np.ndarray:
57 | """
58 | Reset all environments
59 | """
60 | return self.transpose_image(self.venv.reset())
61 |
62 | def close(self) -> None:
63 | self.venv.close()
64 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains code for the paper:
2 |
3 | [On the Critical Role of Conventions in Adaptive Human-AI Collaboration](https://openreview.net/pdf?id=8Ln-Bq0mZcy)
4 |
5 | ```
6 | "On the Critical Role of Conventions in Adaptive Human-AI Collaboration"
7 | Andy Shih, Arjun Sawhney, Jovana Kondic, Stefano Ermon, Dorsa Sadigh
8 | Proceedings of the 9th International Conference on Learning Representations (ICLR 2021)
9 |
10 | @inproceedings{SSKESiclr21,
11 | author = {Andy Shih and Arjun Sawhney and Jovana Kondic and Stefano Ermon and Dorsa Sadigh},
12 | title = {On the Critical Role of Conventions in Adaptive Human-AI Collaboration},
13 | booktitle = {Proceedings of the 9th International Conference on Learning Representations (ICLR)},
14 | month = {may},
15 | year = {2021},
16 | keywords = {conference}
17 | }
18 | ```
19 |
20 | # Instructions
21 |
22 | Install the gym environments, the Hanabi environment, and Stable Baselines3
23 | ```
24 | pip install .
25 |
26 | cd hanabi
27 | pip install .
28 | cd ../
29 |
30 | cd stable-baselines3
31 | pip install -e .
32 | cd ../
33 | ```
34 |
35 | Train partner agents
36 | ```
37 | bash bashfiles/arms_train_selfplay.bash 1230 2
38 | bash bashfiles/arms_train_selfplay.bash 1240 2
39 |
40 | for ((i=1230;i<=1235;i++))
41 | do
42 | bash bashfiles/blocks_train_selfplay.bash $i
43 | done
44 | for ((i=1240;i<=1245;i++))
45 | do
46 | bash bashfiles/blocks_train_selfplay.bash $i
47 | done
48 | for ((i=1240;i<=1247;i++))
49 | do
50 | bash bashfiles/hanabi_train_selfplay.bash $i
51 | done
52 | ```
53 |
54 | Run adaptation experiments.
55 | Choose one of the settings:
56 | - t=1: modular, regularization lambda=0.0
57 | - t=2: modular, regularization lambda=0.3
58 | - t=3: modular, regularization lambda=0.5
59 | - t=4: baseline agg, aggregate gradients
60 | - t=5: baseline agg, aggregate gradients, early stopping
61 | - t=6: baseline modular, no main logits
62 | - t=7: low-dim z + modular, regularization lambda=0.5
63 |
64 | ```
65 | t=1
66 | runid=100
67 |
68 | bash bashfiles/arms_adapt_to_selfplay.bash $runid $t 2
69 | bash bashfiles/arms_adapt_to_fixed.bash $runid $t 2
70 | bash bashfiles/arms_human_adapt_to_fixed.bash $runid $t 2
71 |
72 | bash bashfiles/blocks_adapt_to_selfplay.bash $runid $t
73 | bash bashfiles/blocks_adapt_to_fixed.bash $runid $t
74 |
75 | bash bashfiles/hanabi_adapt_to_selfplay.bash $runid $t
76 |
77 | ```
--------------------------------------------------------------------------------
/my_gym/envs/generate_grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 |
4 | def generate_grid(grid_size, blocks, blocks_per_color, colors):
5 | grid = [[0 for _ in range(grid_size)] for _ in range(grid_size)]
6 |
7 | for color in colors:
8 | arr = np.random.choice(len(blocks), blocks_per_color, replace=False)
9 | for i in arr:
10 | h, w = blocks[i]
11 |
12 | while True:
13 | r, c = np.random.randint(grid_size - h + 1), np.random.randint(grid_size - w + 1)
14 | valid = True
15 | for i in range(r,r+h):
16 | for j in range(c, c+w):
17 | if grid[i][j]:
18 | valid = False
19 | if not valid:
20 | continue
21 | for i in range(r,r+h):
22 | for j in range(c, c+w):
23 | grid[i][j] = color
24 | break
25 |
26 | return grid
27 |
28 | def generate_mask(grid_size, percent_mask):
29 | mask = [[0 for _ in range(grid_size)] for _ in range(grid_size)]
30 | grid_sq = grid_size * grid_size
31 | arr = np.random.choice(grid_sq, int(grid_sq * percent_mask), replace=False)
32 |
33 | for a in arr:
34 | r = a // grid_size
35 | c = a % grid_size
36 | mask[r][c] = 1
37 |
38 | return mask
39 |
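# Visibility codes used by make_grids below (as read from the branches): 1 = fully
# visible (no mask), 2 = half of the squares masked at random, 3 = fully masked,
# 4 = all-or-nothing mask chosen at random, 5 = the two players get complementary
# half masks. Masked squares are overwritten with 1 in each player's grid.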
40 | def make_grids(grid_size, vis1, vis2, blocks, blocks_per_color=1, colors=[2,3]):
41 | goal_grid = generate_grid(grid_size, blocks, blocks_per_color, colors)
42 |
43 | v, p = [vis1, vis2], [None, None]
44 | for i in range(2):
45 | if v[i] == 1:
46 | p[i] = generate_mask(grid_size, 0)
47 | elif v[i] == 2:
48 | p[i] = generate_mask(grid_size, 0.5)
49 | elif v[i] == 3:
50 | p[i] = generate_mask(grid_size, 1)
51 | elif v[i] == 4:
52 | p[i] = generate_mask(grid_size, float(np.random.uniform() < 0.5) )
53 | elif v[i] == 5:
54 | p[0] = generate_mask(grid_size, 0.5)
55 | p[1] = [[1 - x for x in r] for r in p[0]]
56 | else:
57 | raise ValueError('Visibility value is not an integer between [1,5].')
58 | p1_mask, p2_mask = p[0], p[1]
59 |
60 | goal_grid = np.array(goal_grid)
61 | p1_grid = (np.array(1)-p1_mask) * goal_grid + p1_mask # masked squares are denoted as 1
62 | p2_grid = (np.array(1)-p2_mask) * goal_grid + p2_mask
63 |
64 | return goal_grid, p1_grid, p2_grid
65 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from stable_baselines3.common.logger import (
5 | DEBUG,
6 | ScopedConfigure,
7 | configure,
8 | debug,
9 | dump,
10 | error,
11 | info,
12 | make_output_format,
13 | read_csv,
14 | read_json,
15 | record,
16 | record_dict,
17 | record_mean,
18 | reset,
19 | set_level,
20 | warn,
21 | )
22 |
23 | KEY_VALUES = {
24 | "test": 1,
25 | "b": -3.14,
26 | "8": 9.9,
27 | "l": [1, 2],
28 | "a": np.array([1, 2, 3]),
29 | "f": np.array(1),
30 | "g": np.array([[[1]]]),
31 | }
32 |
33 | KEY_EXCLUDED = {}
34 | for key in KEY_VALUES.keys():
35 | KEY_EXCLUDED[key] = None
36 |
37 |
38 | def test_main(tmp_path):
39 | """
40 | tests for the logger module
41 | """
42 | info("hi")
43 | debug("shouldn't appear")
44 | set_level(DEBUG)
45 | debug("should appear")
46 | configure(folder=str(tmp_path))
47 | record("a", 3)
48 | record("b", 2.5)
49 | dump()
50 | record("b", -2.5)
51 | record("a", 5.5)
52 | dump()
53 | info("^^^ should see a = 5.5")
54 | record_mean("b", -22.5)
55 | record_mean("b", -44.4)
56 | record("a", 5.5)
57 | dump()
58 | with ScopedConfigure(None, None):
59 | info("^^^ should see b = 33.3")
60 |
61 | with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
62 | record("b", -2.5)
63 | dump()
64 |
65 | reset()
66 | record("a", "longasslongasslongasslongasslongasslongassvalue")
67 | dump()
68 | warn("hey")
69 | error("oh")
70 | record_dict({"test": 1})
71 |
72 |
73 | @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
74 | def test_make_output(tmp_path, _format):
75 | """
76 | test make output
77 |
78 | :param _format: (str) output format
79 | """
80 | if _format == "tensorboard":
81 | # Skip if no tensorboard installed
82 | pytest.importorskip("tensorboard")
83 |
84 | writer = make_output_format(_format, tmp_path)
85 | writer.write(KEY_VALUES, KEY_EXCLUDED)
86 | if _format == "csv":
87 | read_csv(tmp_path / "progress.csv")
88 | elif _format == "json":
89 | read_json(tmp_path / "progress.json")
90 | writer.close()
91 |
92 |
93 | def test_make_output_fail(tmp_path):
94 | """
95 | test value error on logger
96 | """
97 | with pytest.raises(ValueError):
98 | make_output_format("dummy_format", tmp_path)
99 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/algos.rst:
--------------------------------------------------------------------------------
1 | RL Algorithms
2 | =============
3 |
4 | This table displays the RL algorithms that are implemented in the Stable Baselines3 project,
5 | along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
6 |
7 |
8 | ============ =========== ============ ================= =============== ================
9 | Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
10 | ============ =========== ============ ================= =============== ================
11 | A2C ✔️ ✔️ ✔️ ✔️ ✔️
12 | DDPG ✔️ ❌ ❌ ❌ ❌
13 | DQN ❌ ✔️ ❌ ❌ ❌
14 | PPO ✔️ ✔️ ✔️ ✔️ ✔️
15 | SAC ✔️ ❌ ❌ ❌ ❌
16 | TD3 ✔️ ❌ ❌ ❌ ❌
17 | ============ =========== ============ ================= =============== ================
18 |
19 |
20 | .. note::
21 | Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
22 |
23 | Actions ``gym.spaces``:
24 |
25 | - ``Box``: An N-dimensional box that contains every point in the action
26 | space.
27 | - ``Discrete``: A list of possible actions, where each timestep only
28 | one of the actions can be used.
29 | - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
30 | - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
31 |
32 |
33 | .. note::
34 |
35 | Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper.
36 | See `Issue #339 `_ for more info.
37 |
38 |
39 | Reproducibility
40 | ---------------
41 |
42 | Completely reproducible results are not guaranteed across PyTorch releases or different platforms.
43 | Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds.
44 |
45 | In order to make computations deterministic on your specific problem on one specific platform,
46 | you need to pass a ``seed`` argument at the creation of a model.
47 | If you pass an environment to the model using ``set_env()``, then you also need to seed the environment first.
48 |
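A minimal sketch (using A2C and a Gym environment as an example):

.. code-block:: python

    import gym

    from stable_baselines3 import A2C

    env = gym.make('CartPole-v1')
    env.seed(0)  # seed the environment first when passing it in explicitly

    model = A2C('MlpPolicy', env, seed=0)
    model.learn(total_timesteps=1000)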
49 |
50 | Credit: part of the *Reproducibility* section comes from `PyTorch Documentation `_
51 |
--------------------------------------------------------------------------------
/bashfiles/blocks_adapt_to_fixed.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 |
8 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners; fi
9 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --fixedpartners; fi
10 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners; fi
11 |
12 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --fixedpartners; fi
13 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=700000 --fixedpartners; fi
14 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners; fi
15 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --fixedpartners; fi
16 |
17 | for ((i=0;i<=5;i++))
18 | do
19 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 10 + $i)).txt; fi
20 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 20 + $i)).txt; fi
21 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 30 + $i)).txt; fi
22 |
23 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 40 + $i)).txt; fi
24 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=700000 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 50 + $i)).txt; fi
25 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 60 + $i)).txt; fi
26 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 70 + $i)).txt; fi
27 | done
28 |
29 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 10)).txt; fi
30 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 30)).txt; fi
31 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 60)).txt; fi
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/ddpg.rst:
--------------------------------------------------------------------------------
1 | .. _ddpg:
2 |
3 | .. automodule:: stable_baselines3.ddpg
4 |
5 |
6 | DDPG
7 | ====
8 |
9 | `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/abs/1509.02971>`_ combines the
10 | trick for DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions.
11 |
12 |
13 | .. rubric:: Available Policies
14 |
15 | .. autosummary::
16 | :nosignatures:
17 |
18 | MlpPolicy
19 |
20 |
21 | Notes
22 | -----
23 |
24 | - Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
25 | - DDPG Paper: https://arxiv.org/abs/1509.02971
26 | - OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
27 |
28 | .. note::
29 |
30 | The default policy for DDPG uses a ReLU activation to match the original paper, whereas most other algorithms' MlpPolicy uses a tanh activation.
32 |
33 |
34 | Can I use?
35 | ----------
36 |
37 | - Recurrent policies: ❌
38 | - Multi processing: ❌
39 | - Gym spaces:
40 |
41 |
42 | ============= ====== ===========
43 | Space Action Observation
44 | ============= ====== ===========
45 | Discrete ❌ ✔️
46 | Box ✔️ ✔️
47 | MultiDiscrete ❌ ✔️
48 | MultiBinary ❌ ✔️
49 | ============= ====== ===========
50 |
51 |
52 | Example
53 | -------
54 |
55 | .. code-block:: python
56 |
57 | import gym
58 | import numpy as np
59 |
60 | from stable_baselines3 import DDPG
61 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
62 |
63 | env = gym.make('Pendulum-v0')
64 |
65 | # The noise objects for DDPG
66 | n_actions = env.action_space.shape[-1]
67 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
68 |
69 | model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
70 | model.learn(total_timesteps=10000, log_interval=10)
71 | model.save("ddpg_pendulum")
72 | env = model.get_env()
73 |
74 | del model # remove to demonstrate saving and loading
75 |
76 | model = DDPG.load("ddpg_pendulum")
77 |
78 | obs = env.reset()
79 | while True:
80 | action, _states = model.predict(obs)
81 | obs, rewards, dones, info = env.step(action)
82 | env.render()
83 |
84 |
85 | Parameters
86 | ----------
87 |
88 | .. autoclass:: DDPG
89 | :members:
90 | :inherited-members:
91 |
92 | .. _ddpg_policies:
93 |
94 | DDPG Policies
95 | -------------
96 |
97 | .. autoclass:: MlpPolicy
98 | :members:
99 | :inherited-members:
100 |
101 |
102 | .. .. autoclass:: CnnPolicy
103 | .. :members:
104 | .. :inherited-members:
105 |
--------------------------------------------------------------------------------
/bashfiles/arms_adapt_to_fixed.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | runid=$1
6 | thread=$2
7 | m=$3
8 |
9 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners; fi
10 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --fixedpartners; fi
11 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners; fi
12 |
13 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --fixedpartners; fi
14 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --fixedpartners; fi
15 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners; fi
16 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --fixedpartners; fi
17 |
18 | for ((i=0;i<=3;i++))
19 | do
20 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 10 + $i)).txt; fi
21 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 20 + $i)).txt; fi
22 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 30 + $i)).txt; fi
23 |
24 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 40 + $i)).txt; fi
25 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 50 + $i)).txt; fi
26 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 60 + $i)).txt; fi
27 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 70 + $i)).txt; fi
28 | done
29 |
30 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 10)).txt; fi
31 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 30)).txt; fi
32 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 60)).txt; fi
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copied from stable_baselines
2 | import numpy as np
3 |
4 | from stable_baselines3.common.vec_env import VecEnv
5 |
6 |
7 | def evaluate_policy(
8 | model,
9 | env,
10 | partner_idx=0,
11 | n_eval_episodes=10,
12 | deterministic=True,
13 | render=False,
14 | callback=None,
15 | reward_threshold=None,
16 | return_episode_rewards=False,
17 | ):
18 | """
19 | Runs policy for ``n_eval_episodes`` episodes and returns average reward.
20 | This is made to work only with one env.
21 |
22 | :param model: (BaseAlgorithm) The RL agent you want to evaluate.
23 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv``
24 | this must contain only one environment.
25 |     :param n_eval_episodes: (int) Number of episodes to evaluate the agent
26 | :param deterministic: (bool) Whether to use deterministic or stochastic actions
27 | :param render: (bool) Whether to render the environment or not
28 | :param callback: (callable) callback function to do additional checks,
29 | called after each step.
30 | :param reward_threshold: (float) Minimum expected reward per episode,
31 | this will raise an error if the performance is not met
32 | :param return_episode_rewards: (bool) If True, a list of reward per episode
33 | will be returned instead of the mean.
34 | :return: (float, float) Mean reward per episode, std of reward per episode
35 | returns ([float], [int]) when ``return_episode_rewards`` is True
36 | """
37 | if isinstance(env, VecEnv):
38 | assert env.num_envs == 1, "You must pass only one environment when using this function"
39 |
40 | episode_rewards, episode_lengths = [], []
41 | for _ in range(n_eval_episodes):
42 | obs = env.reset()
43 | done, state = False, None
44 | episode_reward = 0.0
45 | episode_length = 0
46 | while not done:
47 | action, _ = model.predict(observation=obs, partner_idx=partner_idx, deterministic=deterministic)
48 | obs, reward, done, _info = env.step(action)
49 | episode_reward += reward
50 | if callback is not None:
51 | callback(locals(), globals())
52 | episode_length += 1
53 | if render:
54 | env.render()
55 | episode_rewards.append(episode_reward)
56 | episode_lengths.append(episode_length)
57 | mean_reward = np.mean(episode_rewards)
58 | std_reward = np.std(episode_rewards)
59 | if reward_threshold is not None:
60 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
61 | if return_episode_rewards:
62 | return episode_rewards, episode_lengths
63 | return mean_reward, std_reward
64 |
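# Illustrative usage (model and env are assumed to be created elsewhere):
#   mean_reward, std_reward = evaluate_policy(model, env, partner_idx=0, n_eval_episodes=10)
#   episode_rewards, episode_lengths = evaluate_policy(model, env, return_episode_rewards=True)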
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_sde.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch as th
3 | from torch.distributions import Normal
4 |
5 | from stable_baselines3 import A2C, PPO, SAC
6 |
7 |
8 | def test_state_dependent_exploration_grad():
9 | """
10 |     Check that the gradient corresponds to the expected one
11 | """
12 | n_states = 2
13 | state_dim = 3
14 | action_dim = 10
15 | sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)
16 | # Reduce the number of parameters
17 | # sigma_ = th.ones(state_dim, action_dim) * sigma_
18 | # weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
19 | th.manual_seed(2)
20 | weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat)
21 | weights = weights_dist.rsample()
22 |
23 | state = th.rand(n_states, state_dim)
24 | mu = th.ones(action_dim)
25 | noise = th.mm(state, weights)
26 |
27 | action = mu + noise
28 |
29 | variance = th.mm(state ** 2, sigma_hat ** 2)
30 | action_dist = Normal(mu, th.sqrt(variance))
31 |
32 | # Sum over the action dimension because we assume they are independent
33 | loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean()
34 | loss.backward()
35 |
36 | # From Rueckstiess paper: check that the computed gradient
37 |     # corresponds to the analytical form
38 | grad = th.zeros_like(sigma_hat)
39 | for j in range(action_dim):
40 | # sigma_hat is the std of the gaussian distribution of the noise matrix weights
41 |         # sigma_j ** 2 = sum_i(state_i ** 2 * sigma_hat_ij ** 2)
42 | # sigma_j is the standard deviation of the policy gaussian distribution
43 | sigma_j = th.sqrt(variance[:, j])
44 | for i in range(state_dim):
45 | # Derivative of the log probability of the jth component of the action
46 | # w.r.t. the standard deviation sigma_j
47 | d_log_policy_j = (noise[:, j] ** 2 - sigma_j ** 2) / sigma_j ** 3
48 | # Derivative of sigma_j w.r.t. sigma_hat_ij
49 | d_log_sigma_j = (state[:, i] ** 2 * sigma_hat[i, j]) / sigma_j
50 | # Chain rule, average over the minibatch
51 | grad[i, j] = (d_log_policy_j * d_log_sigma_j).mean()
52 |
53 | # sigma.grad should be equal to grad
54 | assert sigma_hat.grad.allclose(grad)
55 |
56 |
57 | @pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
58 | @pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
59 | @pytest.mark.parametrize("use_expln", [False, True])
60 | def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
61 | model = model_class(
62 | "MlpPolicy",
63 | "Pendulum-v0",
64 | use_sde=True,
65 | seed=None,
66 | create_eval_env=True,
67 | verbose=1,
68 | policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln),
69 | )
70 | model.learn(total_timesteps=int(500), eval_freq=250)
71 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/tensorboard.rst:
--------------------------------------------------------------------------------
1 | .. _tensorboard:
2 |
3 | Tensorboard Integration
4 | =======================
5 |
6 | Basic Usage
7 | ------------
8 |
9 | To use Tensorboard with Stable Baselines3, you simply need to pass the location of the log folder to the RL agent:
10 |
11 | .. code-block:: python
12 |
13 | from stable_baselines3 import A2C
14 |
15 | model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
16 | model.learn(total_timesteps=10000)
17 |
18 |
19 | You can also define a custom logging name when training (by default, it is the algorithm name):
20 |
21 | .. code-block:: python
22 |
23 | from stable_baselines3 import A2C
24 |
25 | model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
26 | model.learn(total_timesteps=10000, tb_log_name="first_run")
27 | # Pass reset_num_timesteps=False to continue the training curve in tensorboard
28 | # By default, it will create a new curve
29 | model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)
30 | model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False)
31 |
32 |
33 | Once the learn function is called, you can monitor the RL agent during or after training with the following bash command:
34 |
35 | .. code-block:: bash
36 |
37 | tensorboard --logdir ./a2c_cartpole_tensorboard/
38 |
39 | You can also add past logging folders:
40 |
41 | .. code-block:: bash
42 |
43 | tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/
44 |
45 | It will display information such as the episode reward (when using a ``Monitor`` wrapper), the model losses and other parameters unique to some models.
46 |
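For the episode reward to show up, the environment has to be wrapped with a ``Monitor`` wrapper.
A minimal sketch (keeping the episode statistics in memory only, without writing a log file):

.. code-block:: python

    import gym

    from stable_baselines3 import A2C
    from stable_baselines3.common.monitor import Monitor

    # filename=None: keep the episode statistics in memory instead of writing a monitor.csv file
    env = Monitor(gym.make('CartPole-v1'), filename=None)
    model = A2C('MlpPolicy', env, verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10000)
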
47 | .. image:: ../_static/img/Tensorboard_example.png
48 | :width: 600
49 | :alt: plotting
50 |
51 | Logging More Values
52 | -------------------
53 |
54 | Using a callback, you can easily log more values with TensorBoard.
55 | Here is a simple example of how to log both an additional tensor and an arbitrary scalar value:
56 |
57 | .. code-block:: python
58 |
59 | import numpy as np
60 |
61 | from stable_baselines3 import SAC
62 | from stable_baselines3.common.callbacks import BaseCallback
63 |
64 | model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
65 |
66 |
67 | class TensorboardCallback(BaseCallback):
68 | """
69 | Custom callback for plotting additional values in tensorboard.
70 | """
71 |
72 | def __init__(self, verbose=0):
73 | super(TensorboardCallback, self).__init__(verbose)
74 |
75 | def _on_step(self) -> bool:
76 | # Log scalar value (here a random variable)
77 | value = np.random.random()
78 | self.logger.record('random_value', value)
79 | return True
80 |
81 |
82 | model.learn(50000, callback=TensorboardCallback())
83 |
--------------------------------------------------------------------------------
/partner.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple
3 | import torch as th
4 | import numpy as np
5 | from stable_baselines3 import PPO
6 |
7 | class PartnerPolicy(ABC):
8 | def __init__(self):
9 | pass
10 |
11 | @abstractmethod
12 | def forward(self, obs, deterministic=True) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
13 | pass
14 |
15 | class Partner:
16 | def __init__(self, policy : PartnerPolicy):
17 | self.policy = policy
18 |
19 | class PPOPartnerPolicy(PartnerPolicy):
20 | def __init__(self, model_path):
21 |         super().__init__()
22 | self.model = PPO.load(model_path)
23 | print("PPO Partner loaded successfully: %s" % model_path)
24 |
25 | def forward(self, obs, deterministic=True):
26 | return self.model.policy.forward(obs, partner_idx=0, deterministic=deterministic)
27 |
28 | class BlocksPermutationPartnerPolicy(PartnerPolicy):
29 | def __init__(self, perm, n=2):
30 |         super().__init__()
31 | self.perm = perm
32 | self.n = n
33 |
34 | self.action_index = [[i*n + j for j in range(n)] for i in range(n)]
35 |
36 | def forward(self, obs, deterministic=True):
37 | obs = obs[0]
38 | assert(2*self.n**2+1 == len(obs))
39 | goal_grid = obs[:self.n**2].reshape(self.n,self.n)
40 | working_grid = obs[self.n**2:2*(self.n**2)].reshape(self.n,self.n)
41 | turn = obs[-1]
42 |
43 | r, c = self.get_red_block_position(working_grid, self.n, self.n)
44 |
46 |         if r is None or turn >= 2:
47 | action = self.n**2+1 # pass turn
48 | else:
49 | action = self.perm[self.action_index[r][c]]
50 |
51 | return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0])
52 |
53 | def get_block_position(self, grid, r, c, target):
54 | for i in range(r):
55 | for j in range(c):
56 | if grid[i][j] == target:
57 | return i, j
58 | return None, None
59 |
60 | def get_blue_block_position(self, grid, r, c):
61 | return self.get_block_position(grid, r, c, 3)
62 |
63 | def get_red_block_position(self, grid, r, c):
64 | return self.get_block_position(grid, r, c, 2)
65 |
66 | class ArmsPartnerPolicy(PartnerPolicy):
67 | def __init__(self, perm):
68 |         super().__init__()
69 | self.perm = th.tensor(perm)
70 |
71 | def forward(self, obs, deterministic=True):
72 | action = self.perm[obs]
73 | return th.cat((action, action), dim=1), th.tensor([0.0]), th.tensor([0.0])
74 |
75 | class LowRankPartnerPolicy(PartnerPolicy):
76 | def __init__(self, n):
77 |         super().__init__()
78 | self.n = n
79 |
80 | def forward(self, obs, deterministic=True):
81 | action = self.n
82 | return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0])
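# Illustrative usage (the permutation and the observation are placeholders):
#   policy = ArmsPartnerPolicy(perm=[1, 0, 2])
#   partner = Partner(policy)
#   action, value, log_prob = partner.policy.forward(obs)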
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/td3.rst:
--------------------------------------------------------------------------------
1 | .. _td3:
2 |
3 | .. automodule:: stable_baselines3.td3
4 |
5 |
6 | TD3
7 | ===
8 |
9 | `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_ Addressing Function Approximation Error in Actor-Critic Methods.
10 |
11 | TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
12 | We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
13 |
14 |
15 | .. rubric:: Available Policies
16 |
17 | .. autosummary::
18 | :nosignatures:
19 |
20 | MlpPolicy
21 |
22 |
23 | Notes
24 | -----
25 |
26 | - Original paper: https://arxiv.org/pdf/1802.09477.pdf
27 | - OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
28 | - Original Implementation: https://github.com/sfujim/TD3
29 |
30 | .. note::
31 |
32 |     The default policy for TD3 differs a bit from the other MlpPolicy classes: it uses ReLU instead of tanh activation,
33 |     to match the original paper.
34 |
35 |
36 | Can I use?
37 | ----------
38 |
39 | - Recurrent policies: ❌
40 | - Multi processing: ❌
41 | - Gym spaces:
42 |
43 |
44 | ============= ====== ===========
45 | Space Action Observation
46 | ============= ====== ===========
47 | Discrete ❌ ✔️
48 | Box ✔️ ✔️
49 | MultiDiscrete ❌ ✔️
50 | MultiBinary ❌ ✔️
51 | ============= ====== ===========
52 |
53 |
54 | Example
55 | -------
56 |
57 | .. code-block:: python
58 |
59 | import gym
60 | import numpy as np
61 |
62 | from stable_baselines3 import TD3
63 | from stable_baselines3.td3.policies import MlpPolicy
64 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
65 |
66 | env = gym.make('Pendulum-v0')
67 |
68 | # The noise objects for TD3
69 | n_actions = env.action_space.shape[-1]
70 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
71 |
72 | model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
73 | model.learn(total_timesteps=10000, log_interval=10)
74 | model.save("td3_pendulum")
75 | env = model.get_env()
76 |
77 | del model # remove to demonstrate saving and loading
78 |
79 | model = TD3.load("td3_pendulum")
80 |
81 | obs = env.reset()
82 | while True:
83 | action, _states = model.predict(obs)
84 | obs, rewards, dones, info = env.step(action)
85 | env.render()
86 |
87 |
88 | Parameters
89 | ----------
90 |
91 | .. autoclass:: TD3
92 | :members:
93 | :inherited-members:
94 |
95 | .. _td3_policies:
96 |
97 | TD3 Policies
98 | -------------
99 |
100 | .. autoclass:: MlpPolicy
101 | :members:
102 | :inherited-members:
103 |
104 |
105 | .. .. autoclass:: CnnPolicy
106 | .. :members:
107 | .. :inherited-members:
108 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Stable Baselines3 documentation master file, created by
2 | sphinx-quickstart on Thu Sep 26 11:06:54 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Stable Baselines3 docs! - RL Baselines Made Easy
7 | ===========================================================
8 |
9 | `Stable Baselines3 `_ is a set of improved implementations of reinforcement learning algorithms in PyTorch.
10 | It is the next major version of `Stable Baselines `_.
11 |
12 |
13 | Github repository: https://github.com/DLR-RM/stable-baselines3
14 |
15 | RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/rl-baselines3-zoo
16 |
17 | RL Baselines3 Zoo also offers a simple interface to train and evaluate agents, and to do hyperparameter tuning.
18 |
19 |
20 | Main Features
21 | --------------
22 |
23 | - Unified structure for all algorithms
24 | - PEP8 compliant (unified code style)
25 | - Documented functions and classes
26 | - Tests, high code coverage and type hints
27 | - Clean code
28 | - Tensorboard support
29 |
30 |
31 | .. toctree::
32 | :maxdepth: 2
33 | :caption: User Guide
34 |
35 | guide/install
36 | guide/quickstart
37 | guide/rl_tips
38 | guide/rl
39 | guide/algos
40 | guide/examples
41 | guide/vec_envs
42 | guide/custom_env
43 | guide/custom_policy
44 | guide/callbacks
45 | guide/tensorboard
46 | guide/rl_zoo
47 | guide/migration
48 | guide/checking_nan
49 | guide/developer
50 |
51 |
52 | .. toctree::
53 | :maxdepth: 1
54 | :caption: RL Algorithms
55 |
56 | modules/base
57 | modules/a2c
58 | modules/ddpg
59 | modules/dqn
60 | modules/ppo
61 | modules/sac
62 | modules/td3
63 |
64 | .. toctree::
65 | :maxdepth: 1
66 | :caption: Common
67 |
68 | common/atari_wrappers
69 | common/cmd_util
70 | common/distributions
71 | common/evaluation
72 | common/env_checker
73 | common/monitor
74 | common/logger
75 | common/noise
76 | common/utils
77 |
78 | .. toctree::
79 | :maxdepth: 1
80 | :caption: Misc
81 |
82 | misc/changelog
83 | misc/projects
84 |
85 |
86 | Citing Stable Baselines3
87 | ------------------------
88 | To cite this project in publications:
89 |
90 | .. code-block:: bibtex
91 |
92 | @misc{stable-baselines3,
93 | author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
94 | title = {Stable Baselines3},
95 | year = {2019},
96 | publisher = {GitHub},
97 | journal = {GitHub repository},
98 | howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
99 | }
100 |
101 | Indices and tables
102 | -------------------
103 |
104 | * :ref:`genindex`
105 | * :ref:`search`
106 | * :ref:`modindex`
107 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 |
5 | from collections import OrderedDict
6 |
7 | import gym
8 | import numpy as np
9 |
10 |
11 | def copy_obs_dict(obs):
12 | """
13 | Deep-copy a dict of numpy arrays.
14 |
15 | :param obs: (OrderedDict): a dict of numpy arrays.
16 | :return (OrderedDict) a dict of copied numpy arrays.
17 | """
18 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
20 |
21 |
22 | def dict_to_obs(space, obs_dict):
23 | """
24 | Convert an internal representation raw_obs into the appropriate type
25 | specified by space.
26 |
27 | :param space: (gym.spaces.Space) an observation space.
28 | :param obs_dict: (OrderedDict) a dict of numpy arrays.
29 | :return (ndarray, tuple or dict): returns an observation
30 | of the same type as space. If space is Dict, function is identity;
31 | if space is Tuple, converts dict to Tuple; otherwise, space is
32 | unstructured and returns the value raw_obs[None].
33 | """
34 | if isinstance(space, gym.spaces.Dict):
35 | return obs_dict
36 | elif isinstance(space, gym.spaces.Tuple):
37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space"
38 | return tuple((obs_dict[i] for i in range(len(space.spaces))))
39 | else:
40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
41 | return obs_dict[None]
42 |
43 |
44 | def obs_space_info(obs_space):
45 | """
46 | Get dict-structured information about a gym.Space.
47 |
48 | Dict spaces are represented directly by their dict of subspaces.
49 | Tuple spaces are converted into a dict with keys indexing into the tuple.
50 | Unstructured spaces are represented by {None: obs_space}.
51 |
52 | :param obs_space: (gym.spaces.Space) an observation space
53 | :return (tuple) A tuple (keys, shapes, dtypes):
54 | keys: a list of dict keys.
55 | shapes: a dict mapping keys to shapes.
56 | dtypes: a dict mapping keys to dtypes.
57 | """
58 | if isinstance(obs_space, gym.spaces.Dict):
59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
60 | subspaces = obs_space.spaces
61 | elif isinstance(obs_space, gym.spaces.Tuple):
62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
63 | else:
64 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
65 | subspaces = {None: obs_space}
66 | keys = []
67 | shapes = {}
68 | dtypes = {}
69 | for key, box in subspaces.items():
70 | keys.append(key)
71 | shapes[key] = box.shape
72 | dtypes[key] = box.dtype
73 | return keys, shapes, dtypes
74 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/custom_env.rst:
--------------------------------------------------------------------------------
1 | .. _custom_env:
2 |
3 | Using Custom Environments
4 | ==========================
5 |
6 | To use the RL baselines with custom environments, they just need to follow the *gym* interface.
7 | That is to say, your environment must implement the following methods (and inherit from the OpenAI Gym class):
8 |
9 |
10 | .. note::
11 | If you are using images as input, the input values must be in [0, 255] as the observation
12 | is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies.
13 |
14 |
15 |
16 | .. code-block:: python
17 |
18 |     import gym
19 |     import numpy as np
20 |     from gym import spaces
20 |
21 | class CustomEnv(gym.Env):
22 | """Custom Environment that follows gym interface"""
23 | metadata = {'render.modes': ['human']}
24 |
25 | def __init__(self, arg1, arg2, ...):
26 | super(CustomEnv, self).__init__()
27 | # Define action and observation space
28 | # They must be gym.spaces objects
29 | # Example when using discrete actions:
30 | self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
31 | # Example for using image as input:
32 | self.observation_space = spaces.Box(low=0, high=255,
33 | shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
34 |
35 | def step(self, action):
36 | ...
37 | return observation, reward, done, info
38 | def reset(self):
39 | ...
40 | return observation # reward, done, info can't be included
41 | def render(self, mode='human'):
42 | ...
43 |     def close(self):
44 | ...
45 |
46 |
47 | Then you can define and train a RL agent with:
48 |
49 | .. code-block:: python
50 |
51 | # Instantiate the env
52 | env = CustomEnv(arg1, ...)
53 | # Define and Train the agent
54 | model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
55 |
56 |
57 | To check that your environment follows the gym interface, please use:
58 |
59 | .. code-block:: python
60 |
61 | from stable_baselines3.common.env_checker import check_env
62 |
63 | env = CustomEnv(arg1, ...)
64 | # It will check your custom environment and output additional warnings if needed
65 | check_env(env)
66 |
67 |
68 |
69 | We have created a `colab notebook `_ for
70 | a concrete example of creating a custom environment.
71 |
72 | You can also find a `complete guide online `_
73 | on creating a custom Gym environment.
74 |
75 |
76 | Optionally, you can also register the environment with gym,
77 | which will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env); a minimal sketch is shown below.
78 |
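A minimal registration sketch (the id, entry point and keyword arguments below are placeholders, not part of this project):

.. code-block:: python

    import gym
    from gym.envs.registration import register

    # Register the custom environment under an id so that gym.make() can build it
    register(
        id='CustomEnv-v0',
        entry_point='my_package.envs:CustomEnv',
        kwargs={'arg1': 0, 'arg2': 1},
    )

    env = gym.make('CustomEnv-v0')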
79 |
80 | In the project, for testing purposes, we use a custom environment named ``IdentityEnv``
81 | defined `in this file `_.
82 | An example of how to use it can be found `here `_.
83 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/modules/sac.rst:
--------------------------------------------------------------------------------
1 | .. _sac:
2 |
3 | .. automodule:: stable_baselines3.sac
4 |
5 |
6 | SAC
7 | ===
8 |
9 | `Soft Actor Critic (SAC) <https://arxiv.org/abs/1801.01290>`_ Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
10 |
11 | SAC is the successor of `Soft Q-Learning SQL `_ and incorporates the double Q-learning trick from TD3.
12 | A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
13 |
14 |
15 | .. rubric:: Available Policies
16 |
17 | .. autosummary::
18 | :nosignatures:
19 |
20 | MlpPolicy
21 | CnnPolicy
22 |
23 |
24 | Notes
25 | -----
26 |
27 | - Original paper: https://arxiv.org/abs/1801.01290
28 | - OpenAI Spinning Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
29 | - Original Implementation: https://github.com/haarnoja/sac
30 | - Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/
31 |
32 | .. note::
33 |     In our implementation, we use an entropy coefficient (as in OpenAI Spinning or Facebook Horizon),
34 |     which is equivalent to the inverse of the reward scale in the original SAC paper.
35 |     The main reason is that it avoids having too high errors when updating the Q functions.
36 |
37 |
38 | .. note::
39 |
40 |     The default policy for SAC differs a bit from the other MlpPolicy classes: it uses ReLU instead of tanh activation,
41 |     to match the original paper.
42 |
43 |
44 | Can I use?
45 | ----------
46 |
47 | - Recurrent policies: ❌
48 | - Multi processing: ❌
49 | - Gym spaces:
50 |
51 |
52 | ============= ====== ===========
53 | Space Action Observation
54 | ============= ====== ===========
55 | Discrete ❌ ✔️
56 | Box ✔️ ✔️
57 | MultiDiscrete ❌ ✔️
58 | MultiBinary ❌ ✔️
59 | ============= ====== ===========
60 |
61 |
62 | Example
63 | -------
64 |
65 | .. code-block:: python
66 |
67 | import gym
68 | import numpy as np
69 |
70 | from stable_baselines3 import SAC
71 | from stable_baselines3.sac import MlpPolicy
72 |
73 | env = gym.make('Pendulum-v0')
74 |
75 | model = SAC(MlpPolicy, env, verbose=1)
76 | model.learn(total_timesteps=10000, log_interval=4)
77 | model.save("sac_pendulum")
78 |
79 | del model # remove to demonstrate saving and loading
80 |
81 | model = SAC.load("sac_pendulum")
82 |
83 | obs = env.reset()
84 | while True:
85 | action, _states = model.predict(obs)
86 | obs, reward, done, info = env.step(action)
87 | env.render()
88 | if done:
89 | obs = env.reset()
90 |
91 | Parameters
92 | ----------
93 |
94 | .. autoclass:: SAC
95 | :members:
96 | :inherited-members:
97 |
98 | .. _sac_policies:
99 |
100 | SAC Policies
101 | -------------
102 |
103 | .. autoclass:: MlpPolicy
104 | :members:
105 | :inherited-members:
106 |
107 | .. .. autoclass:: CnnPolicy
108 | .. :members:
109 | .. :inherited-members:
110 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/vec_envs.rst:
--------------------------------------------------------------------------------
1 | .. _vec_env:
2 |
3 | .. automodule:: stable_baselines3.common.vec_env
4 |
5 | Vectorized Environments
6 | =======================
7 |
8 | Vectorized Environments are a method for stacking multiple independent environments into a single environment.
9 | Instead of training an RL agent on 1 environment per step, it allows us to train it on ``n`` environments per step.
10 | Because of this, ``actions`` passed to the environment are now a vector (of dimension ``n``).
11 | It is the same for ``observations``, ``rewards`` and end of episode signals (``dones``).
12 | In the case of non-array observation spaces such as ``Dict`` or ``Tuple``, where different sub-spaces
13 | may have different shapes, the sub-observations are vectors (of dimension ``n``).
14 |
15 | ============= ======= ============ ======== ========= ================
16 | Name ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing
17 | ============= ======= ============ ======== ========= ================
18 | DummyVecEnv ✔️ ✔️ ✔️ ✔️ ❌️
19 | SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
20 | ============= ======= ============ ======== ========= ================
21 |
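As a minimal sketch of the API (here using ``DummyVecEnv``, the simplest option from the table above):

.. code-block:: python

    import gym

    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv

    # DummyVecEnv expects a list of functions, each returning a fresh environment instance
    env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)
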
22 | .. note::
23 |
24 | Vectorized environments are required when using wrappers for frame-stacking or normalization.
25 |
26 | .. note::
27 |
28 | When using vectorized environments, the environments are automatically reset at the end of each episode.
29 | Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
30 | You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
31 |
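A short sketch of how to retrieve that final observation (``venv`` and ``actions`` are placeholders):

.. code-block:: python

    obs, rewards, dones, infos = venv.step(actions)
    for env_idx, done in enumerate(dones):
        if done:
            # obs[env_idx] is already the first observation of the new episode;
            # the last observation of the finished episode is stored in the info dict
            terminal_obs = infos[env_idx]["terminal_observation"]
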
32 | .. warning::
33 |
34 | When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
35 | On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
36 |
37 | For more information, see Python's `multiprocessing guidelines `_.
38 |
39 | VecEnv
40 | ------
41 |
42 | .. autoclass:: VecEnv
43 | :members:
44 |
45 | DummyVecEnv
46 | -----------
47 |
48 | .. autoclass:: DummyVecEnv
49 | :members:
50 |
51 | SubprocVecEnv
52 | -------------
53 |
54 | .. autoclass:: SubprocVecEnv
55 | :members:
56 |
57 | Wrappers
58 | --------
59 |
60 | VecFrameStack
61 | ~~~~~~~~~~~~~
62 |
63 | .. autoclass:: VecFrameStack
64 | :members:
65 |
66 |
67 | VecNormalize
68 | ~~~~~~~~~~~~
69 |
70 | .. autoclass:: VecNormalize
71 | :members:
72 |
73 |
74 | VecVideoRecorder
75 | ~~~~~~~~~~~~~~~~
76 |
77 | .. autoclass:: VecVideoRecorder
78 | :members:
79 |
80 |
81 | VecCheckNan
82 | ~~~~~~~~~~~~~~~~
83 |
84 | .. autoclass:: VecCheckNan
85 | :members:
86 |
87 |
88 | VecTransposeImage
89 | ~~~~~~~~~~~~~~~~~
90 |
91 | .. autoclass:: VecTransposeImage
92 | :members:
93 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/rl_zoo.rst:
--------------------------------------------------------------------------------
1 | .. _rl_zoo:
2 |
3 | ==================
4 | RL Baselines3 Zoo
5 | ==================
6 |
7 | `RL Baselines3 Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ is a collection of pre-trained Reinforcement Learning agents using
8 | Stable-Baselines3.
9 | It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.
10 |
11 | Goals of this repository:
12 |
13 | 1. Provide a simple interface to train and enjoy RL agents
14 | 2. Benchmark the different Reinforcement Learning algorithms
15 | 3. Provide tuned hyperparameters for each environment and RL algorithm
16 | 4. Have fun with the trained agents!
17 |
18 | Installation
19 | ------------
20 |
21 | 1. Clone the repository:
22 |
23 | ::
24 |
25 | git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo
26 | cd rl-baselines3-zoo/
27 |
28 |
29 | .. note::
30 |
31 | You can remove the ``--recursive`` option if you don't want to download the trained agents
32 |
33 |
34 | 2. Install dependencies
35 | ::
36 |
37 | apt-get install swig cmake ffmpeg
38 | pip install -r requirements.txt
39 |
40 |
41 | Train an Agent
42 | --------------
43 |
44 | The hyperparameters for each environment are defined in
45 | ``hyperparameters/algo_name.yml``.
46 |
47 | If the environment exists in this file, then you can train an agent
48 | using:
49 |
50 | ::
51 |
52 | python train.py --algo algo_name --env env_id
53 |
54 | For example (with evaluation and checkpoints):
55 |
56 | ::
57 |
58 |     python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
59 |
60 |
61 | Continue training (here, load pretrained agent for Breakout and continue
62 | training for 5000 steps):
63 |
64 | ::
65 |
66 | python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
67 |
68 |
69 | Enjoy a Trained Agent
70 | ---------------------
71 |
72 | If the trained agent exists, then you can see it in action using:
73 |
74 | ::
75 |
76 | python enjoy.py --algo algo_name --env env_id
77 |
78 | For example, enjoy A2C on Breakout during 5000 timesteps:
79 |
80 | ::
81 |
82 | python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
83 |
84 |
85 | Hyperparameter Optimization
86 | ---------------------------
87 |
88 | We use `Optuna `_ for optimizing the hyperparameters.
89 |
90 |
91 | Tune the hyperparameters for PPO, using a random sampler and median pruner, 2 parallel jobs,
92 | with a budget of 1000 trials and a maximum of 50000 steps:
93 |
94 | ::
95 |
96 | python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
97 | --sampler random --pruner median
98 |
99 |
100 | Colab Notebook: Try it Online!
101 | ------------------------------
102 |
103 | You can train agents online using Google `colab notebook `_.
104 |
105 |
106 | .. note::
107 |
108 | You can find more information about the rl baselines3 zoo in the repo `README `_. For instance, how to record a video of a trained agent.
109 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper
6 |
7 |
8 | class VecCheckNan(VecEnvWrapper):
9 | """
10 |     NaN and inf checking wrapper for vectorized environments, will raise a warning by default,
11 |     allowing you to know where the NaN or inf originated from.
12 |
13 | :param venv: (VecEnv) the vectorized environment to wrap
14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning
15 | :param warn_once: (bool) Whether or not to only warn once.
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well
17 | """
18 |
19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True):
20 | VecEnvWrapper.__init__(self, venv)
21 | self.raise_exception = raise_exception
22 | self.warn_once = warn_once
23 | self.check_inf = check_inf
24 | self._actions = None
25 | self._observations = None
26 | self._user_warned = False
27 |
28 | def step_async(self, actions):
29 | self._check_val(async_step=True, actions=actions)
30 |
31 | self._actions = actions
32 | self.venv.step_async(actions)
33 |
34 | def step_wait(self):
35 | observations, rewards, news, infos = self.venv.step_wait()
36 |
37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
38 |
39 | self._observations = observations
40 | return observations, rewards, news, infos
41 |
42 | def reset(self):
43 | observations = self.venv.reset()
44 | self._actions = None
45 |
46 | self._check_val(async_step=False, observations=observations)
47 |
48 | self._observations = observations
49 | return observations
50 |
51 | def _check_val(self, *, async_step, **kwargs):
52 | # if warn and warn once and have warned once: then stop checking
53 | if not self.raise_exception and self.warn_once and self._user_warned:
54 | return
55 |
56 | found = []
57 | for name, val in kwargs.items():
58 | has_nan = np.any(np.isnan(val))
59 | has_inf = self.check_inf and np.any(np.isinf(val))
60 | if has_inf:
61 | found.append((name, "inf"))
62 | if has_nan:
63 | found.append((name, "nan"))
64 |
65 | if found:
66 | self._user_warned = True
67 | msg = ""
68 | for i, (name, type_val) in enumerate(found):
69 | msg += "found {} in {}".format(type_val, name)
70 | if i != len(found) - 1:
71 | msg += ", "
72 |
73 | msg += ".\r\nOriginated from the "
74 |
75 | if not async_step:
76 | if self._actions is None:
77 | msg += "environment observation (at reset)"
78 | else:
79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions)
80 | else:
81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations)
82 |
83 | if self.raise_exception:
84 | raise ValueError(msg)
85 | else:
86 | warnings.warn(msg, UserWarning)
87 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_run.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
5 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
6 |
7 | normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
8 |
9 |
10 | @pytest.mark.parametrize("model_class", [TD3, DDPG])
11 | @pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
12 | def test_deterministic_pg(model_class, action_noise):
13 | """
14 | Test for DDPG and variants (TD3).
15 | """
16 | model = model_class(
17 | "MlpPolicy",
18 | "Pendulum-v0",
19 | policy_kwargs=dict(net_arch=[64, 64]),
20 | learning_starts=100,
21 | verbose=1,
22 | create_eval_env=True,
23 | action_noise=action_noise,
24 | )
25 | model.learn(total_timesteps=1000, eval_freq=500)
26 |
27 |
28 | @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
29 | def test_a2c(env_id):
30 | model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
31 | model.learn(total_timesteps=1000, eval_freq=500)
32 |
33 |
34 | @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
35 | @pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
36 | def test_ppo(env_id, clip_range_vf):
37 | if clip_range_vf is not None and clip_range_vf < 0:
38 | # Should throw an error
39 | with pytest.raises(AssertionError):
40 | model = PPO(
41 | "MlpPolicy",
42 | env_id,
43 | seed=0,
44 | policy_kwargs=dict(net_arch=[16]),
45 | verbose=1,
46 | create_eval_env=True,
47 | clip_range_vf=clip_range_vf,
48 | )
49 | else:
50 | model = PPO(
51 | "MlpPolicy",
52 | env_id,
53 | seed=0,
54 | policy_kwargs=dict(net_arch=[16]),
55 | verbose=1,
56 | create_eval_env=True,
57 | clip_range_vf=clip_range_vf,
58 | )
59 | model.learn(total_timesteps=1000, eval_freq=500)
60 |
61 |
62 | @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
63 | def test_sac(ent_coef):
64 | model = SAC(
65 | "MlpPolicy",
66 | "Pendulum-v0",
67 | policy_kwargs=dict(net_arch=[64, 64]),
68 | learning_starts=100,
69 | verbose=1,
70 | create_eval_env=True,
71 | ent_coef=ent_coef,
72 | action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
73 | )
74 | model.learn(total_timesteps=1000, eval_freq=500)
75 |
76 |
77 | @pytest.mark.parametrize("n_critics", [1, 3])
78 | def test_n_critics(n_critics):
79 |     # Test SAC with different numbers of critics; for TD3, n_critics=1 corresponds to DDPG
80 | model = SAC(
81 | "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
82 | )
83 | model.learn(total_timesteps=1000)
84 |
85 |
86 | def test_dqn():
87 | model = DQN(
88 | "MlpPolicy",
89 | "CartPole-v1",
90 | policy_kwargs=dict(net_arch=[64, 64]),
91 | learning_starts=500,
92 | buffer_size=500,
93 | learning_rate=3e-4,
94 | verbose=1,
95 | create_eval_env=True,
96 | )
97 | model.learn(total_timesteps=1000, eval_freq=500)
98 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/custom_policy.rst:
--------------------------------------------------------------------------------
1 | .. _custom_policy:
2 |
3 | Custom Policy Network
4 | ---------------------
5 |
6 | Stable Baselines3 provides policy networks for images (CnnPolicies)
7 | and other type of input features (MlpPolicies).
8 |
9 | One way of customising the policy network architecture is to pass arguments when creating the model,
10 | using ``policy_kwargs`` parameter:
11 |
12 | .. code-block:: python
13 |
14 | import gym
15 | import torch as th
16 |
17 | from stable_baselines3 import PPO
18 |
19 |     # Custom MLP policy of two layers of size 32 each with ReLU activation function
20 | policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[32, 32])
21 | # Create the agent
22 | model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
23 | # Retrieve the environment
24 | env = model.get_env()
25 | # Train the agent
26 | model.learn(total_timesteps=100000)
27 | # Save the agent
28 | model.save("ppo-cartpole")
29 |
30 | del model
31 | # the policy_kwargs are automatically loaded
32 | model = PPO.load("ppo-cartpole")
33 |
34 |
35 | You can also easily define a custom architecture for the policy (or value) network:
36 |
37 | .. note::
38 |
39 | Defining a custom policy class is equivalent to passing ``policy_kwargs``.
40 | However, it lets you name the policy and so usually makes the code clearer.
41 | ``policy_kwargs`` is particularly useful when doing hyperparameter search.
42 |
43 |
44 |
45 | The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows you to specify the number and size of the hidden layers and how many
46 | of them are shared between the policy network and the value network. It is assumed to be a list with the following
47 | structure:
48 |
49 | 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
50 | If the number of ints is zero, there will be no shared layers.
51 | 2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
52 |    It is formatted like ``dict(vf=[<list of layer sizes>], pi=[<list of layer sizes>])``.
53 |    If it is missing any of the keys (``pi`` or ``vf``), zero non-shared layers (an empty list) are assumed for that network.
54 |
55 | In short: ``[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]``.
56 |
57 | Examples
58 | ~~~~~~~~
59 |
60 | Two shared layers of size 128: ``net_arch=[128, 128]``
61 |
62 |
63 | .. code-block:: none
64 |
65 | obs
66 | |
67 | <128>
68 | |
69 | <128>
70 | / \
71 | action value
72 |
73 |
74 | Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]``
75 |
76 | .. code-block:: none
77 |
78 | obs
79 | |
80 | <128>
81 | / \
82 | action <256>
83 | |
84 | <256>
85 | |
86 | value
87 |
88 |
89 | Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
90 |
91 | .. code-block:: none
92 |
93 | obs
94 | |
95 | <128>
96 | / \
97 | <16> <256>
98 | | |
99 | action value
100 |
101 |
102 |
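As a usage sketch, the last architecture above can be requested through ``policy_kwargs``:

.. code-block:: python

    from stable_baselines3 import PPO

    # One shared layer of size 128, then a 16-unit policy head and a 256-unit value head
    policy_kwargs = dict(net_arch=[128, dict(vf=[256], pi=[16])])
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=10000)
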
103 | If your task requires even more granular control over the policy architecture, you can redefine the policy directly.
104 |
105 | **TODO**
106 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_monitor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import uuid
4 |
5 | import gym
6 | import pandas
7 |
8 | from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
9 |
10 |
11 | def test_monitor(tmp_path):
12 | """
13 | Test the monitor wrapper
14 | """
15 | env = gym.make("CartPole-v1")
16 | env.seed(0)
17 | monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
18 | monitor_env = Monitor(env, monitor_file)
19 | monitor_env.reset()
20 | total_steps = 1000
21 | ep_rewards = []
22 | ep_lengths = []
23 | ep_len, ep_reward = 0, 0
24 | for _ in range(total_steps):
25 | _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample())
26 | ep_len += 1
27 | ep_reward += reward
28 | if done:
29 | ep_rewards.append(ep_reward)
30 | ep_lengths.append(ep_len)
31 | monitor_env.reset()
32 | ep_len, ep_reward = 0, 0
33 |
34 | monitor_env.close()
35 | assert monitor_env.get_total_steps() == total_steps
36 | assert sum(ep_lengths) == sum(monitor_env.get_episode_lengths())
37 | assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
38 | _ = monitor_env.get_episode_times()
39 |
40 | with open(monitor_file, "rt") as file_handler:
41 | first_line = file_handler.readline()
42 | assert first_line.startswith("#")
43 | metadata = json.loads(first_line[1:])
44 | assert metadata["env_id"] == "CartPole-v1"
45 | assert set(metadata.keys()) == {"env_id", "t_start"}, "Incorrect keys in monitor metadata"
46 |
47 | last_logline = pandas.read_csv(file_handler, index_col=None)
48 | assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
49 | os.remove(monitor_file)
50 |
51 |
52 | def test_monitor_load_results(tmp_path):
53 | """
54 | test load_results on log files produced by the monitor wrapper
55 | """
56 | tmp_path = str(tmp_path)
57 | env1 = gym.make("CartPole-v1")
58 | env1.seed(0)
59 | monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
60 | monitor_env1 = Monitor(env1, monitor_file1)
61 |
62 | monitor_files = get_monitor_files(tmp_path)
63 | assert len(monitor_files) == 1
64 | assert monitor_file1 in monitor_files
65 |
66 | monitor_env1.reset()
67 | episode_count1 = 0
68 | for _ in range(1000):
69 | _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample())
70 | if done:
71 | episode_count1 += 1
72 | monitor_env1.reset()
73 |
74 | results_size1 = len(load_results(os.path.join(tmp_path)).index)
75 | assert results_size1 == episode_count1
76 |
77 | env2 = gym.make("CartPole-v1")
78 | env2.seed(0)
79 | monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
80 | monitor_env2 = Monitor(env2, monitor_file2)
81 | monitor_files = get_monitor_files(tmp_path)
82 | assert len(monitor_files) == 2
83 | assert monitor_file1 in monitor_files
84 | assert monitor_file2 in monitor_files
85 |
86 | monitor_env2.reset()
87 | episode_count2 = 0
88 | for _ in range(1000):
89 | _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample())
90 | if done:
91 | episode_count2 += 1
92 | monitor_env2.reset()
93 |
94 | results_size2 = len(load_results(os.path.join(tmp_path)).index)
95 |
96 | assert results_size2 == (results_size1 + episode_count2)
97 |
98 | os.remove(monitor_file1)
99 | os.remove(monitor_file2)
100 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from gym.wrappers.monitoring import video_recorder
4 |
5 | from stable_baselines3.common import logger
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
9 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
10 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
11 |
12 |
13 | class VecVideoRecorder(VecEnvWrapper):
14 | """
15 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
16 | It requires ffmpeg or avconv to be installed on the machine.
17 |
18 | :param venv: (VecEnv or VecEnvWrapper)
19 | :param video_folder: (str) Where to save videos
20 | :param record_video_trigger: (func) Function that defines when to start recording.
21 |         The function takes the current number of steps,
22 | and returns whether we should start recording or not.
23 | :param video_length: (int) Length of recorded videos
24 | :param name_prefix: (str) Prefix to the video name
25 | """
26 |
27 | def __init__(self, venv, video_folder, record_video_trigger, video_length=200, name_prefix="rl-video"):
28 |
29 | VecEnvWrapper.__init__(self, venv)
30 |
31 | self.env = venv
32 | # Temp variable to retrieve metadata
33 | temp_env = venv
34 |
35 | # Unwrap to retrieve metadata dict
36 | # that will be used by gym recorder
37 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack):
38 | temp_env = temp_env.venv
39 |
40 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
41 | metadata = temp_env.get_attr("metadata")[0]
42 | else:
43 | metadata = temp_env.metadata
44 |
45 | self.env.metadata = metadata
46 |
47 | self.record_video_trigger = record_video_trigger
48 | self.video_recorder = None
49 |
50 | self.video_folder = os.path.abspath(video_folder)
51 | # Create output folder if needed
52 | os.makedirs(self.video_folder, exist_ok=True)
53 |
54 | self.name_prefix = name_prefix
55 | self.step_id = 0
56 | self.video_length = video_length
57 |
58 | self.recording = False
59 | self.recorded_frames = 0
60 |
61 | def reset(self):
62 | obs = self.venv.reset()
63 | self.start_video_recorder()
64 | return obs
65 |
66 | def start_video_recorder(self):
67 | self.close_video_recorder()
68 |
69 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
70 | base_path = os.path.join(self.video_folder, video_name)
71 | self.video_recorder = video_recorder.VideoRecorder(
72 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
73 | )
74 |
75 | self.video_recorder.capture_frame()
76 | self.recorded_frames = 1
77 | self.recording = True
78 |
79 | def _video_enabled(self):
80 | return self.record_video_trigger(self.step_id)
81 |
82 | def step_wait(self):
83 | obs, rews, dones, infos = self.venv.step_wait()
84 |
85 | self.step_id += 1
86 | if self.recording:
87 | self.video_recorder.capture_frame()
88 | self.recorded_frames += 1
89 | if self.recorded_frames > self.video_length:
90 | logger.info("Saving video to ", self.video_recorder.path)
91 | self.close_video_recorder()
92 | elif self._video_enabled():
93 | self.start_video_recorder()
94 |
95 | return obs, rews, dones, infos
96 |
97 | def close_video_recorder(self):
98 | if self.recording:
99 | self.video_recorder.close()
100 | self.recording = False
101 | self.recorded_frames = 1
102 |
103 | def close(self):
104 | VecEnvWrapper.close(self)
105 | self.close_video_recorder()
106 |
107 | def __del__(self):
108 | self.close()
109 |
--------------------------------------------------------------------------------
/stable-baselines3/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing to Stable-Baselines3
2 |
3 | If you are interested in contributing to Stable-Baselines, your contributions will fall
4 | into two categories:
5 | 1. You want to propose a new Feature and implement it
6 | - Create an issue about your intended feature, and we shall discuss the design and
7 | implementation. Once we agree that the plan looks good, go ahead and implement it.
8 | 2. You want to implement a feature or bug-fix for an outstanding issue
9 | - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
10 | - Pick an issue or feature and comment on the task that you want to work on this feature.
11 | - If you need more context on a particular issue, please ask and we shall provide.
12 |
13 | Once you finish implementing a feature or bug-fix, please send a Pull Request to
14 | https://github.com/DLR-RM/stable-baselines3
15 |
16 |
17 | If you are not familiar with creating a Pull Request, here are some guides:
18 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
19 | - https://help.github.com/articles/creating-a-pull-request/
20 |
21 |
22 | ## Developing Stable-Baselines3
23 |
24 | To develop Stable-Baselines3 on your machine, here are some tips:
25 |
26 | 1. Clone a copy of Stable-Baselines3 from source:
27 |
28 | ```bash
29 | git clone https://github.com/DLR-RM/stable-baselines3
30 | cd stable-baselines3/
31 | ```
32 |
33 | 2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:
34 |
35 | ```bash
36 | pip install -e .[docs,tests,extra]
37 | ```
38 |
39 | ## Codestyle
40 |
41 | We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
42 |
43 | **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
44 |
45 | Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:
46 |
47 | ```python
48 |
49 | def my_function(arg1: type1, arg2: type2) -> returntype:
50 | """
51 | Short description of the function.
52 |
53 | :param arg1: (type1) describe what is arg1
54 | :param arg2: (type2) describe what is arg2
55 | :return: (returntype) describe what is returned
56 | """
57 | ...
58 | return my_variable
59 | ```
60 |
61 | ## Pull Request (PR)
62 |
63 | Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicated PRs from being proposed and also eases the code review process.
64 |
65 | Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli).
66 | A PR must pass the Continuous Integration tests to be merged with the master branch.
67 |
68 |
69 | ## Tests
70 |
71 | All new features must add tests in the `tests/` folder ensuring that everything works fine.
72 | We use [pytest](https://pytest.org/).
73 | Also, when a bug fix is proposed, tests should be added to avoid regression.
74 |
75 | To run tests with `pytest`:
76 |
77 | ```
78 | make pytest
79 | ```
80 |
81 | Type checking with `pytype`:
82 |
83 | ```
84 | make type
85 | ```
86 |
87 | Codestyle check with `black`, `isort` and `flake8`:
88 |
89 | ```
90 | make check-codestyle
91 | make lint
92 | ```
93 |
94 | To run `pytype`, `format` and `lint` in one command:
95 | ```
96 | make commit-checks
97 | ```
98 |
99 | Build the documentation:
100 |
101 | ```
102 | make doc
103 | ```
104 |
105 | Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that):
106 |
107 | ```
108 | make spelling
109 | ```
110 |
111 |
112 | ## Changelog and Documentation
113 |
114 | Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed.
115 | You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too.
116 | A README is present in the `docs/` folder for instructions on how to build the documentation.
117 |
118 |
119 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
120 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_distributions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch as th
3 |
4 | from stable_baselines3 import A2C, PPO
5 | from stable_baselines3.common.distributions import (
6 | BernoulliDistribution,
7 | CategoricalDistribution,
8 | DiagGaussianDistribution,
9 | MultiCategoricalDistribution,
10 | SquashedDiagGaussianDistribution,
11 | StateDependentNoiseDistribution,
12 | TanhBijector,
13 | )
14 | from stable_baselines3.common.utils import set_random_seed
15 |
16 | N_ACTIONS = 2
17 | N_FEATURES = 3
18 | N_SAMPLES = int(5e6)
19 |
20 |
21 | def test_bijector():
22 | """
23 | Test TanhBijector
24 | """
25 | actions = th.ones(5) * 2.0
26 | bijector = TanhBijector()
27 |
28 | squashed_actions = bijector.forward(actions)
29 | # Check that the boundaries are not violated
30 | assert th.max(th.abs(squashed_actions)) <= 1.0
31 | # Check the inverse method
32 | assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
33 |
34 |
35 | @pytest.mark.parametrize("model_class", [A2C, PPO])
36 | def test_squashed_gaussian(model_class):
37 | """
38 | Test run with squashed Gaussian (notably entropy computation)
39 | """
40 | model = model_class("MlpPolicy", "Pendulum-v0", use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
41 | model.learn(500)
42 |
43 | gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
44 | dist = SquashedDiagGaussianDistribution(N_ACTIONS)
45 | _, log_std = dist.proba_distribution_net(N_FEATURES)
46 | dist = dist.proba_distribution(gaussian_mean, log_std)
47 | actions = dist.get_actions()
48 | assert th.max(th.abs(actions)) <= 1.0
49 |
50 |
51 | def test_sde_distribution():
52 | n_actions = 1
53 | deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1
54 | state = th.ones(N_SAMPLES, N_FEATURES) * 0.3
55 | dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False)
56 |
57 | set_random_seed(1)
58 | _, log_std = dist.proba_distribution_net(N_FEATURES)
59 | dist.sample_weights(log_std, batch_size=N_SAMPLES)
60 |
61 | dist = dist.proba_distribution(deterministic_actions, log_std, state)
62 | actions = dist.get_actions()
63 |
64 | assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=2e-3)
65 | assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=2e-3)
66 |
67 |
68 | # TODO: analytical form for squashed Gaussian?
69 | @pytest.mark.parametrize(
70 | "dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),]
71 | )
72 | def test_entropy(dist):
73 | # The entropy can be approximated by averaging the negative log likelihood
74 | # mean negative log likelihood == differential entropy
75 | set_random_seed(1)
76 | state = th.rand(N_SAMPLES, N_FEATURES)
77 | deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS)
78 | _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
79 |
80 | if isinstance(dist, DiagGaussianDistribution):
81 | dist = dist.proba_distribution(deterministic_actions, log_std)
82 | else:
83 | dist.sample_weights(log_std, batch_size=N_SAMPLES)
84 | dist = dist.proba_distribution(deterministic_actions, log_std, state)
85 |
86 | actions = dist.get_actions()
87 | entropy = dist.entropy()
88 | log_prob = dist.log_prob(actions)
89 | assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
90 |
91 |
92 | categorical_params = [
93 | (CategoricalDistribution(N_ACTIONS), N_ACTIONS),
94 | (MultiCategoricalDistribution([2, 3]), sum([2, 3])),
95 | (BernoulliDistribution(N_ACTIONS), N_ACTIONS),
96 | ]
97 |
98 |
99 | @pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params)
100 | def test_categorical(dist, CAT_ACTIONS):
101 | # The entropy can be approximated by averaging the negative log likelihood
102 | # mean negative log likelihood == entropy
103 | set_random_seed(1)
104 | action_logits = th.rand(N_SAMPLES, CAT_ACTIONS)
105 | dist = dist.proba_distribution(action_logits)
106 | actions = dist.get_actions()
107 | entropy = dist.entropy()
108 | log_prob = dist.log_prob(actions)
109 | assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
110 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/developer.rst:
--------------------------------------------------------------------------------
1 | .. _developer:
2 |
3 | ================
4 | Developer Guide
5 | ================
6 |
7 | This guide is meant for those who want to understand the internals and the design choices of Stable-Baselines3.
8 |
9 |
10 | First, you should read the two issues where the design choices were discussed:
11 |
12 | - https://github.com/hill-a/stable-baselines/issues/576
13 | - https://github.com/hill-a/stable-baselines/issues/733
14 |
15 |
16 | The library is not meant to be modular, although inheritance is used to reduce code duplication.
17 |
18 |
19 | Algorithms Structure
20 | ====================
21 |
22 |
23 | Each algorithm (on-policy and off-policy alike) follows a common structure.
24 | The policy contains the code for acting in the environment, and the algorithm updates this policy.
25 | There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (``policies.py``).
26 |
27 | Each algorithm has two main methods:
28 |
29 | - ``.collect_rollouts()`` which defines how new samples are collected, usually inherited from the base class. Those samples are then stored in a ``RolloutBuffer`` (discarded after the gradient update) or ``ReplayBuffer``
30 |
31 | - ``.train()`` which updates the parameters using samples from the buffer
32 |
33 |
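As a minimal, runnable illustration of this alternation (the environment id and hyperparameters below are arbitrary, chosen only to keep the run short):

.. code-block:: python

    from stable_baselines3 import PPO

    # PPO is an on-policy algorithm: .learn() repeatedly calls
    # .collect_rollouts() (fill the RolloutBuffer with fresh experience)
    # and then .train() (gradient updates on those samples, which are then discarded).
    model = PPO("MlpPolicy", "CartPole-v1", n_steps=64)
    # With a single environment and n_steps=64, this corresponds to 4 collect/train cycles.
    model.learn(total_timesteps=256)
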
34 | Where to start?
35 | ===============
36 |
37 | The first thing you need to read and understand are the base classes in the ``common/`` folder:
38 |
39 | - ``BaseAlgorithm`` in ``base_class.py`` which defines what an RL algorithm class should look like.
40 | It also contains all the "glue code" for saving/loading and the common operations (wrapping environments)
41 |
42 | - ``BasePolicy`` in ``policies.py`` which defines what a policy class should look like.
43 | It also contains all the magic for the ``.predict()`` method, to handle as many spaces/cases as possible
44 |
45 | - ``OffPolicyAlgorithm`` in ``off_policy_algorithm.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms,
46 | and similarly ``OnPolicyAlgorithm`` in ``on_policy_algorithm.py``.
47 |
48 |
49 | All the environments handled internally are assumed to be ``VecEnv`` (``gym.Env`` instances are automatically wrapped).
50 |
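For example (a small sketch using the standard ``CartPole-v1`` environment), passing a plain ``gym.Env`` or an explicitly vectorized environment is equivalent from the algorithm's point of view:

.. code-block:: python

    import gym

    from stable_baselines3 import A2C
    from stable_baselines3.common.vec_env import DummyVecEnv

    # A plain gym.Env is automatically wrapped into a DummyVecEnv with one environment...
    model = A2C("MlpPolicy", gym.make("CartPole-v1"))

    # ...so this is equivalent to passing the vectorized environment yourself.
    vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    model = A2C("MlpPolicy", vec_env)
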
51 |
52 | Pre-Processing
53 | ==============
54 |
55 | To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observations).
56 | Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.
57 |
58 | For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
59 |
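As a small sketch of what this pre-processing does (the space and values below are arbitrary):

.. code-block:: python

    import torch as th
    from gym import spaces

    from stable_baselines3.common.preprocessing import preprocess_obs

    # Discrete observations are converted to a one-hot encoding
    obs_space = spaces.Discrete(4)
    obs = th.tensor([0, 3])  # a batch of two observations
    one_hot = preprocess_obs(obs, obs_space)
    # tensor([[1., 0., 0., 0.],
    #         [0., 0., 0., 1.]])

    # Image observations (uint8 values in [0, 255]) are instead cast to float
    # and divided by 255 to lie in [0, 1].
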
60 |
61 | Policy Structure
62 | ================
63 |
64 | When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
65 | In SB3, "policy" refers to the class that handles all the networks useful for training,
66 | so not only the network used to predict actions (the "learned controller").
67 | For instance, the ``TD3`` policy contains the actor, the critic and the target networks.
68 |
69 | To avoid the hassle of importing specific policy classes for a specific algorithm (e.g. both A2C and PPO use ``ActorCriticPolicy``),
70 | SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks,
71 | respectively. Importing ``[algorithm]/policies.py`` registers an appropriate policy for that algorithm under those names.
72 |
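A minimal illustration (using ``CartPole-v1`` as an arbitrary environment):

.. code-block:: python

    from stable_baselines3 import A2C, PPO

    # The same policy name works for several algorithms: it resolves to the class
    # registered by the corresponding [algorithm]/policies.py module.
    a2c_model = A2C("MlpPolicy", "CartPole-v1")
    ppo_model = PPO("MlpPolicy", "CartPole-v1")

    # For A2C and PPO, "MlpPolicy" maps to ActorCriticPolicy in both cases.
    print(type(a2c_model.policy).__name__, type(ppo_model.policy).__name__)
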
73 | Probability distributions
74 | =========================
75 |
76 | When needed, the policies handle the different probability distributions.
77 | All distributions are located in ``common/distributions.py`` and follow the same interface.
78 | Each distribution corresponds to a type of action space (e.g. ``Categorical`` is the one used for discrete actions).
79 | For continuous actions, we can use multiple distributions ("DiagGaussian", "SquashedGaussian" or "StateDependentDistribution").
80 |
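A short sketch of the interface (the dimensions below are arbitrary), mirroring how the tests exercise it:

.. code-block:: python

    import torch as th

    from stable_baselines3.common.distributions import DiagGaussianDistribution

    n_actions = 2
    dist = DiagGaussianDistribution(n_actions)
    # Build the layer that outputs the mean, plus the (state-independent) log_std parameter
    mean_net, log_std = dist.proba_distribution_net(3)  # latent dimension of 3

    mean_actions = th.zeros(1, n_actions)
    dist = dist.proba_distribution(mean_actions, log_std)
    actions = dist.get_actions()        # sample from the distribution
    log_prob = dist.log_prob(actions)   # log-likelihood of the sampled actions
    entropy = dist.entropy()
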
81 | State-Dependent Exploration
82 | ===========================
83 |
84 | State-Dependent Exploration (SDE) is a type of exploration that allows using RL directly on real robots,
85 | and it was the starting point for the Stable-Baselines3 library.
86 | I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): https://arxiv.org/abs/2005.05719
87 |
88 | Misc
89 | ====
90 |
91 | The rest of the ``common/`` folder is composed of helpers (e.g. evaluation helpers) or basic components (like the callbacks).
92 | The ``type_aliases.py`` file contains common type hint aliases like ``GymStepReturn``.
93 |
94 | Et voilà?
95 |
96 | After reading this guide and the mentioned files, you should now be able to understand the design logic behind the library ;)
97 |
--------------------------------------------------------------------------------
/run_arms_human.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import argparse
3 | import os, time
4 |
5 | import gym
6 | import my_gym
7 |
8 | from stable_baselines3 import PPO
9 | from stable_baselines3.common.env_checker import check_env
10 | from stable_baselines3.common import make_vec_env
11 | from stable_baselines3.common.vec_env import DummyVecEnv
12 | from stable_baselines3.common.evaluation import evaluate_policy
13 |
14 | import torch as th
15 | import torch.nn as nn
16 | from interactive_policy import ArmsPolicy
17 | from partner_config import get_arms_human_partners
18 | from util import check_optimal, learn, load_model
19 | from util import adapt_task, adapt_partner_baseline, adapt_partner_modular, adapt_partner_scratch
20 |
21 | warnings.simplefilter(action='ignore', category=FutureWarning)
22 |
23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24 | parser.add_argument('--run', type=int, default=0, help="Run ID. In case you want to run replicates")
25 | parser.add_argument('--netsz', type=int, default=30, help="Size of policy network")
26 | parser.add_argument('--latentz', type=int, default=30, help="Size of latent z dimension")
27 |
28 | parser.add_argument('--mreg', type=float, default=0.0, help="Marginal regularization.")
29 | parser.add_argument('--baseline', action='store_true', default=False, help="Baseline: no modular separation.")
30 | parser.add_argument('--nomain', action='store_true', default=False, help="Baseline: don't use main logits.")
31 |
32 | parser.add_argument('--timesteps', type=int, default=10000, help="Number of timesteps to train for")
33 | parser.add_argument('--testing', action='store_true', default=False, help="Testing.")
34 |
35 | parser.add_argument('--k', type=int, default=0, help="When fixedpartner=True, k is the index of the test partner")
36 |
37 | args = parser.parse_args()
38 | print(args)
39 |
40 | def get_model_name_and_path(run, mreg=0.00):
41 | layout = [
42 | ('run={:04d}', run),
43 | ('netsz={:03d}', args.netsz),
44 | ('mreg={:.2f}', mreg),
45 | ]
46 |
47 | m_name = '_'.join([t.format(v) for (t, v) in layout])
48 | m_path = 'output/armshuman_' + m_name
49 | return m_name, m_path
50 |
51 | model_name, model_path = get_model_name_and_path(args.run, mreg=args.mreg)
52 |
53 | HP = {
54 | 'n_steps': 64,
55 | 'n_steps_testing': 16,
56 | 'batch_size': 16,
57 | 'n_epochs': 20,
58 | 'n_epochs_testing': 50,
59 | 'mreg': args.mreg,
60 | }
61 |
62 | setting, partner_type = "", "fixed"
63 | TRAIN_PARTNERS, TEST_PARTNERS = get_arms_human_partners(setting, partner_type)
64 | PARTNERS = [ TEST_PARTNERS[args.k % len(TEST_PARTNERS)] ] if args.testing else TRAIN_PARTNERS
65 |
66 | def main():
67 | global PARTNERS
68 | env = gym.make('arms-human-v0')
69 | num_partners = len(PARTNERS) if PARTNERS is not None else 1
70 |
71 | print("model path: ", model_path)
72 | net_arch = [args.netsz,args.latentz]
73 | partner_net_arch = [args.netsz,args.netsz]
74 | policy_kwargs = dict(activation_fn=nn.ReLU,
75 | net_arch=[dict(vf=net_arch, pi=net_arch)],
76 | partner_net_arch=[dict(vf=partner_net_arch, pi=partner_net_arch)],
77 | num_partners=num_partners,
78 | baseline=args.baseline,
79 | nomain=args.nomain,
80 | )
81 |
82 | def load_model_fn(partners, testing, try_load=True):
83 | return load_model(model_path=model_path, policy_class=ArmsPolicy, policy_kwargs=policy_kwargs, env=env, hp=HP, partners=partners, testing=testing, try_load=try_load)
84 |
85 | def learn_model_fn(model, timesteps, save, period):
86 | return learn(model, model_name=model_name, model_path=model_path, timesteps=timesteps, save=save, period=period, save_thresh=None)
87 |
88 | # TRAINING
89 | if not args.testing:
90 | print("#section Training")
91 | model = load_model_fn(partners=PARTNERS, testing=False)
92 | learn_model_fn(model, timesteps=args.timesteps, save=True, period=200)
93 |
94 | ts, period = 240, HP['n_steps_testing']
95 |
96 | # TESTING
97 | if args.testing:
98 | if args.baseline: adapt_partner_baseline(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True)
99 | else: adapt_partner_modular(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True)
100 | adapt_partner_scratch(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True)
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/stable-baselines3/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
5 | with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler:
6 | __version__ = file_handler.read().strip()
7 |
8 |
9 | long_description = """
10 |
11 | # Stable Baselines3
12 |
13 | Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
14 |
15 | These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
16 |
17 |
18 | ## Links
19 |
20 | Repository:
21 | https://github.com/DLR-RM/stable-baselines3
22 |
23 | Medium article:
24 | https://medium.com/@araffin/df87c4b2fc82
25 |
26 | Documentation:
27 | https://stable-baselines3.readthedocs.io/en/master/
28 |
29 | RL Baselines3 Zoo:
30 | https://github.com/DLR-RM/rl-baselines3-zoo
31 |
32 | ## Quick example
33 |
34 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym.
35 |
36 | Here is a quick example of how to train and run PPO on a cartpole environment:
37 |
38 | ```python
39 | import gym
40 |
41 | from stable_baselines3 import PPO
42 |
43 | env = gym.make('CartPole-v1')
44 |
45 | model = PPO('MlpPolicy', env, verbose=1)
46 | model.learn(total_timesteps=10000)
47 |
48 | obs = env.reset()
49 | for i in range(1000):
50 | action, _states = model.predict(obs, deterministic=True)
51 | obs, reward, done, info = env.step(action)
52 | env.render()
53 | if done:
54 | obs = env.reset()
55 | ```
56 |
57 | Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
58 |
59 | ```python
60 | from stable_baselines3 import PPO
61 |
62 | model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
63 | ```
64 |
65 | """ # noqa:E501
66 |
67 |
68 | setup(
69 | name="stable_baselines3",
70 | packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
71 | package_data={"stable_baselines3": ["py.typed", "version.txt"]},
72 | install_requires=[
73 | "gym>=0.17",
74 | "numpy",
75 | "torch>=1.4.0",
76 | # For saving models
77 | "cloudpickle",
78 | # For reading logs
79 | "pandas",
80 | # Plotting learning curves
81 | "matplotlib",
82 | ],
83 | extras_require={
84 | "tests": [
85 | # Run tests and coverage
86 | "pytest",
87 | "pytest-cov",
88 | "pytest-env",
89 | "pytest-xdist",
90 | # Type check
91 | "pytype",
92 | # Lint code
93 | "flake8>=3.8",
94 | # Sort imports
95 | "isort>=5.0",
96 | # Reformat
97 | "black",
98 | ],
99 | "docs": [
100 | "sphinx",
101 | "sphinx-autobuild",
102 | "sphinx-rtd-theme",
103 | # For spelling
104 | "sphinxcontrib.spelling",
105 | # Type hints support
106 | # 'sphinx-autodoc-typehints'
107 | ],
108 | "extra": [
109 | # For render
110 | "opencv-python",
111 | # For atari games,
112 | "atari_py~=0.2.0",
113 | "pillow",
114 | # Tensorboard support
115 | "tensorboard",
116 | # Checking memory taken by replay buffer
117 | "psutil",
118 | ],
119 | },
120 | description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
121 | author="Antonin Raffin",
122 | url="https://github.com/DLR-RM/stable-baselines3",
123 | author_email="antonin.raffin@dlr.de",
124 | keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
125 | "gym openai stable baselines toolbox python data-science",
126 | license="MIT",
127 | long_description=long_description,
128 | long_description_content_type="text/markdown",
129 | version=__version__,
130 | )
131 |
132 | # python setup.py sdist
133 | # python setup.py bdist_wheel
134 | # twine upload --repository-url https://test.pypi.org/legacy/ dist/*
135 | # twine upload dist/*
136 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/results_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | # import matplotlib
7 | # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
8 | from matplotlib import pyplot as plt
9 |
10 | from stable_baselines3.common.monitor import load_results
11 |
12 | X_TIMESTEPS = "timesteps"
13 | X_EPISODES = "episodes"
14 | X_WALLTIME = "walltime_hrs"
15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
16 | EPISODES_WINDOW = 100
17 |
18 |
19 | def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
20 | """
21 | Apply a rolling window to a np.ndarray
22 |
23 | :param array: (np.ndarray) the input Array
24 | :param window: (int) length of the rolling window
25 | :return: (np.ndarray) rolling window on the input array
26 | """
27 | shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
28 | strides = array.strides + (array.strides[-1],)
29 | return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
30 |
31 |
32 | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
33 | """
34 | Apply a function to the rolling window of 2 arrays
35 |
36 | :param var_1: (np.ndarray) variable 1
37 | :param var_2: (np.ndarray) variable 2
38 | :param window: (int) length of the rolling window
39 | :param func: (numpy function) function to apply on the rolling window on variable 2 (such as np.mean)
40 | :return: (Tuple[np.ndarray, np.ndarray]) the rolling output with applied function
41 | """
42 | var_2_window = rolling_window(var_2, window)
43 | function_on_var2 = func(var_2_window, axis=-1)
44 | return var_1[window - 1 :], function_on_var2
45 |
46 |
47 | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
48 | """
49 | Decompose a data frame variable into x and y arrays
50 |
51 | :param data_frame: (pd.DataFrame) the input data
52 | :param x_axis: (str) the axis for the x and y output
53 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
54 | :return: (Tuple[np.ndarray, np.ndarray]) the x and y output
55 | """
56 | if x_axis == X_TIMESTEPS:
57 | x_var = np.cumsum(data_frame.l.values)
58 | y_var = data_frame.r.values
59 | elif x_axis == X_EPISODES:
60 | x_var = np.arange(len(data_frame))
61 | y_var = data_frame.r.values
62 | elif x_axis == X_WALLTIME:
63 | # Convert to hours
64 | x_var = data_frame.t.values / 3600.0
65 | y_var = data_frame.r.values
66 | else:
67 | raise NotImplementedError
68 | return x_var, y_var
69 |
70 |
71 | def plot_curves(
72 | xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
73 | ) -> None:
74 | """
75 | plot the curves
76 |
77 | :param xy_list: (List[Tuple[np.ndarray, np.ndarray]]) the x and y coordinates to plot
78 | :param x_axis: (str) the axis for the x and y output
79 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
80 | :param title: (str) the title of the plot
81 | :param figsize: (Tuple[int, int]) Size of the figure (width, height)
82 | """
83 |
84 | plt.figure(title, figsize=figsize)
85 | max_x = max(xy[0][-1] for xy in xy_list)
86 | min_x = 0
87 | for (i, (x, y)) in enumerate(xy_list):
88 | plt.scatter(x, y, s=2)
89 | # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
90 | if x.shape[0] >= EPISODES_WINDOW:
91 | # Compute and plot rolling mean with window of size EPISODE_WINDOW
92 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
93 | plt.plot(x, y_mean)
94 | plt.xlim(min_x, max_x)
95 | plt.title(title)
96 | plt.xlabel(x_axis)
97 | plt.ylabel("Episode Rewards")
98 | plt.tight_layout()
99 |
100 |
101 | def plot_results(
102 | dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
103 | ) -> None:
104 | """
105 | Plot the results using csv files from ``Monitor`` wrapper.
106 |
107 | :param dirs: ([str]) the save location of the results to plot
108 | :param num_timesteps: (int or None) only plot the points below this value
109 | :param x_axis: (str) the axis for the x and y output
110 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
111 | :param task_name: (str) the title of the task to plot
112 | :param figsize: (Tuple[int, int]) Size of the figure (width, height)
113 | """
114 |
115 | data_frames = []
116 | for folder in dirs:
117 | data_frame = load_results(folder)
118 | if num_timesteps is not None:
119 | data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
120 | data_frames.append(data_frame)
121 | xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
122 | plot_curves(xy_list, x_axis, task_name, figsize)
123 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from copy import deepcopy
3 | from typing import Sequence
4 |
5 | import numpy as np
6 |
7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv
8 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
9 |
10 |
11 | class DummyVecEnv(VecEnv):
12 | """
13 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
14 | Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
15 | as the overhead of multiprocessing or multithreading outweighs the environment computation time.
16 | This can also be used for RL methods that
17 | require a vectorized environment, but where you want a single environment to train with.
18 |
19 | :param env_fns: ([Gym Environment]) the list of environments to vectorize
20 | """
21 |
22 | def __init__(self, env_fns):
23 | self.envs = [fn() for fn in env_fns]
24 | env = self.envs[0]
25 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
26 | obs_space = env.observation_space
27 | self.keys, shapes, dtypes = obs_space_info(obs_space)
28 |
29 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
30 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
31 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
32 | self.buf_infos = [{} for _ in range(self.num_envs)]
33 | self.actions = None
34 | self.metadata = env.metadata
35 |
36 | def step_async(self, actions):
37 | self.actions = actions
38 |
39 | def step_wait(self):
40 | for env_idx in range(self.num_envs):
41 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
42 | self.actions[env_idx]
43 | )
44 | if self.buf_dones[env_idx]:
45 | # save final observation where user can get it, then reset
46 | self.buf_infos[env_idx]["terminal_observation"] = obs
47 | obs = self.envs[env_idx].reset()
48 | self._save_obs(env_idx, obs)
49 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
50 |
51 | def seed(self, seed=None):
52 | seeds = list()
53 | for idx, env in enumerate(self.envs):
54 | seeds.append(env.seed(seed + idx))
55 | return seeds
56 |
57 | def reset(self):
58 | for env_idx in range(self.num_envs):
59 | obs = self.envs[env_idx].reset()
60 | self._save_obs(env_idx, obs)
61 | return self._obs_from_buf()
62 |
63 | def close(self):
64 | for env in self.envs:
65 | env.close()
66 |
67 | def get_images(self) -> Sequence[np.ndarray]:
68 | return [env.render(mode="rgb_array") for env in self.envs]
69 |
70 | def render(self, mode: str = "human"):
71 | """
72 | Gym environment rendering. If there are multiple environments then
73 | they are tiled together in one image via ``BaseVecEnv.render()``.
74 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
75 | underlying environment.
76 |
77 | Therefore, some arguments such as ``mode`` will have values that are valid
78 | only when ``num_envs == 1``.
79 |
80 | :param mode: The rendering type.
81 | """
82 | if self.num_envs == 1:
83 | return self.envs[0].render(mode=mode)
84 | else:
85 | return super().render(mode=mode)
86 |
87 | def _save_obs(self, env_idx, obs):
88 | for key in self.keys:
89 | if key is None:
90 | self.buf_obs[key][env_idx] = obs
91 | else:
92 | self.buf_obs[key][env_idx] = obs[key]
93 |
94 | def _obs_from_buf(self):
95 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
96 |
97 | def get_attr(self, attr_name, indices=None):
98 | """Return attribute from vectorized environment (see base class)."""
99 | target_envs = self._get_target_envs(indices)
100 | return [getattr(env_i, attr_name) for env_i in target_envs]
101 |
102 | def set_attr(self, attr_name, value, indices=None):
103 | """Set attribute inside vectorized environments (see base class)."""
104 | target_envs = self._get_target_envs(indices)
105 | for env_i in target_envs:
106 | setattr(env_i, attr_name, value)
107 |
108 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
109 | """Call instance methods of vectorized environments."""
110 | target_envs = self._get_target_envs(indices)
111 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
112 |
113 | def _get_target_envs(self, indices):
114 | indices = self._get_indices(indices)
115 | return [self.envs[i] for i in indices]
116 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/bit_flipping_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Optional, Union
3 |
4 | import numpy as np
5 | from gym import GoalEnv, spaces
6 |
7 | from stable_baselines3.common.type_aliases import GymStepReturn
8 |
9 |
10 | class BitFlippingEnv(GoalEnv):
11 | """
12 | Simple bit flipping env, useful to test HER.
13 | The goal is to flip all the bits to get a vector of ones.
14 | In the continuous variant, if the ith action component has a value > 0,
15 | then the ith bit will be flipped.
16 |
17 | :param n_bits: (int) Number of bits to flip
18 | :param continuous: (bool) Whether to use the continuous actions version or not,
19 | by default, it uses the discrete one
20 | :param max_steps: (Optional[int]) Max number of steps, by default, equal to n_bits
21 | :param discrete_obs_space: (bool) Whether to use the discrete observation
22 | version or not, by default, it uses the MultiBinary one
23 | """
24 |
25 | def __init__(
26 | self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False
27 | ):
28 | super(BitFlippingEnv, self).__init__()
29 | # The achieved goal is determined by the current state
30 | # here, it is a special case where they are equal
31 | if discrete_obs_space:
32 | # In the discrete case, the agent act on the binary
33 | # representation of the observation
34 | self.observation_space = spaces.Dict(
35 | {
36 | "observation": spaces.Discrete(2 ** n_bits - 1),
37 | "achieved_goal": spaces.Discrete(2 ** n_bits - 1),
38 | "desired_goal": spaces.Discrete(2 ** n_bits - 1),
39 | }
40 | )
41 | else:
42 | self.observation_space = spaces.Dict(
43 | {
44 | "observation": spaces.MultiBinary(n_bits),
45 | "achieved_goal": spaces.MultiBinary(n_bits),
46 | "desired_goal": spaces.MultiBinary(n_bits),
47 | }
48 | )
49 |
50 | self.obs_space = spaces.MultiBinary(n_bits)
51 |
52 | if continuous:
53 | self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
54 | else:
55 | self.action_space = spaces.Discrete(n_bits)
56 | self.continuous = continuous
57 | self.discrete_obs_space = discrete_obs_space
58 | self.state = None
59 | self.desired_goal = np.ones((n_bits,))
60 | if max_steps is None:
61 | max_steps = n_bits
62 | self.max_steps = max_steps
63 | self.current_step = 0
64 | self.reset()
65 |
66 | def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
67 | """
68 | Convert to discrete space if needed.
69 |
70 | :param state: (np.ndarray)
71 | :return: (np.ndarray or int)
72 | """
73 | if self.discrete_obs_space:
74 | # The internal state is the binary representation of the
75 | # observed one
76 | return int(sum([state[i] * 2 ** i for i in range(len(state))]))
77 | return state
78 |
79 | def _get_obs(self) -> OrderedDict:
80 | """
81 | Helper to create the observation.
82 |
83 | :return: (OrderedDict)
84 | """
85 | return OrderedDict(
86 | [
87 | ("observation", self.convert_if_needed(self.state.copy())),
88 | ("achieved_goal", self.convert_if_needed(self.state.copy())),
89 | ("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
90 | ]
91 | )
92 |
93 | def reset(self) -> OrderedDict:
94 | self.current_step = 0
95 | self.state = self.obs_space.sample()
96 | return self._get_obs()
97 |
98 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
99 | if self.continuous:
100 | self.state[action > 0] = 1 - self.state[action > 0]
101 | else:
102 | self.state[action] = 1 - self.state[action]
103 | obs = self._get_obs()
104 | reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)
105 | done = reward == 0
106 | self.current_step += 1
107 | # The episode terminates when we reach the goal or the max number of steps
108 | info = {"is_success": done}
109 | done = done or self.current_step >= self.max_steps
110 | return obs, reward, done, info
111 |
112 | def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> float:
113 | # Deceptive reward: it is positive only when the goal is achieved
114 | if self.discrete_obs_space:
115 | return 0.0 if achieved_goal == desired_goal else -1.0
116 | return 0.0 if (achieved_goal == desired_goal).all() else -1.0
117 |
118 | def render(self, mode: str = "human") -> Optional[np.ndarray]:
119 | if mode == "rgb_array":
120 | return self.state.copy()
121 | print(self.state)
122 |
123 | def close(self) -> None:
124 | pass
125 |
--------------------------------------------------------------------------------
/stable-baselines3/docs/guide/install.rst:
--------------------------------------------------------------------------------
1 | .. _install:
2 |
3 | Installation
4 | ============
5 |
6 | Prerequisites
7 | -------------
8 |
9 | Stable-Baselines3 requires Python 3.6+.
10 |
11 | Windows 10
12 | ~~~~~~~~~~
13 |
14 | For Windows users, we recommend using `Anaconda `_ for easier installation of Python packages and required libraries. You need an environment with Python version 3.6 or above.
15 |
16 | For a quick start you can move straight to installing Stable-Baselines3 in the next step.
17 |
18 | .. note::
19 |
20 | Trying to create Atari environments may result in vague errors related to missing DLL files and modules. This is an
21 | issue with the atari-py package. `See this discussion for more information `_.
22 |
23 |
24 | Stable Release
25 | ~~~~~~~~~~~~~~
26 | To install Stable Baselines3 with pip, execute:
27 |
28 | .. code-block:: bash
29 |
30 | pip install stable-baselines3[extra]
31 |
32 | This includes optional dependencies like Tensorboard, OpenCV or ``atari-py`` to train on Atari games. If you do not need those, you can use:
33 |
34 | .. code-block:: bash
35 |
36 | pip install stable-baselines3
37 |
38 |
39 | Bleeding-edge version
40 | ---------------------
41 |
42 | .. code-block:: bash
43 |
44 | pip install git+https://github.com/DLR-RM/stable-baselines3
45 |
46 |
47 | Development version
48 | -------------------
49 |
50 | To contribute to Stable-Baselines3, install it from source with support for running tests and building the documentation:
51 |
52 | .. code-block:: bash
53 |
54 | git clone https://github.com/DLR-RM/stable-baselines3 && cd stable-baselines3
55 | pip install -e .[docs,tests,extra]
56 |
57 |
58 | Using Docker Images
59 | -------------------
60 |
61 | If you are looking for docker images with stable-baselines3 already installed,
62 | we recommend using images from `RL Baselines3 Zoo `_.
63 |
64 | Otherwise, the following images contain all the dependencies for stable-baselines3 but not the stable-baselines3 package itself.
65 | They are made for development.
66 |
67 | Use Built Images
68 | ~~~~~~~~~~~~~~~~
69 |
70 | GPU image (requires `nvidia-docker`_):
71 |
72 | .. code-block:: bash
73 |
74 | docker pull stablebaselines/stable-baselines3
75 |
76 | CPU only:
77 |
78 | .. code-block:: bash
79 |
80 | docker pull stablebaselines/stable-baselines3-cpu
81 |
82 | Build the Docker Images
83 | ~~~~~~~~~~~~~~~~~~~~~~~~
84 |
85 | Build GPU image (with nvidia-docker):
86 |
87 | .. code-block:: bash
88 |
89 | make docker-gpu
90 |
91 | Build CPU image:
92 |
93 | .. code-block:: bash
94 |
95 | make docker-cpu
96 |
97 | Note: if you are using a proxy, you need to pass extra params during
98 | build and do some `tweaks`_:
99 |
100 | .. code-block:: bash
101 |
102 | --network=host --build-arg HTTP_PROXY=http://your.proxy.fr:8080/ --build-arg http_proxy=http://your.proxy.fr:8080/ --build-arg HTTPS_PROXY=https://your.proxy.fr:8080/ --build-arg https_proxy=https://your.proxy.fr:8080/
103 |
104 | Run the images (CPU/GPU)
105 | ~~~~~~~~~~~~~~~~~~~~~~~~
106 |
107 | Run the nvidia-docker GPU image
108 |
109 | .. code-block:: bash
110 |
111 | docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /root/code/stable-baselines3/ && pytest tests/'
112 |
113 | Or, with the shell file:
114 |
115 | .. code-block:: bash
116 |
117 | ./scripts/run_docker_gpu.sh pytest tests/
118 |
119 | Run the docker CPU image
120 |
121 | .. code-block:: bash
122 |
123 | docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /root/code/stable-baselines3/ && pytest tests/'
124 |
125 | Or, with the shell file:
126 |
127 | .. code-block:: bash
128 |
129 | ./scripts/run_docker_cpu.sh pytest tests/
130 |
131 | Explanation of the docker command:
132 |
133 | - ``docker run -it`` creates an instance of an image (=container), and
134 | runs it interactively (so ctrl+c will work)
135 | - the ``--rm`` option removes the container once it exits/stops
136 | (otherwise, you will have to use ``docker rm``)
137 | - ``--network host`` disables network isolation, which allows using
138 | tensorboard/visdom on the host machine
139 | - ``--ipc=host`` uses the host system's IPC namespace. The IPC (POSIX/SysV IPC) namespace provides
140 | separation of named shared memory segments, semaphores and message
141 | queues.
142 | - ``--name test`` explicitly gives the name ``test`` to the container,
143 | otherwise it will be assigned a random name
144 | - ``--mount src=...`` gives the container access to the local directory (``pwd``
145 | command); it will be mapped to ``/root/code/stable-baselines3``, so
146 | all the logs created in this folder inside the container will be kept
147 | - ``bash -c '...'`` runs the command inside the docker image, here running the tests
148 | (``pytest tests/``)
149 |
150 | .. _nvidia-docker: https://github.com/NVIDIA/nvidia-docker
151 | .. _tweaks: https://stackoverflow.com/questions/23111631/cannot-download-docker-images-behind-a-proxy
152 |
--------------------------------------------------------------------------------
/stable-baselines3/tests/test_envs.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import pytest
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
7 | from stable_baselines3.common.env_checker import check_env
8 | from stable_baselines3.common.identity_env import (
9 | FakeImageEnv,
10 | IdentityEnv,
11 | IdentityEnvBox,
12 | IdentityEnvMultiBinary,
13 | IdentityEnvMultiDiscrete,
14 | )
15 |
16 | ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv]
17 |
18 |
19 | @pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"])
20 | def test_env(env_id):
21 | """
22 | Check that environments integrated in Gym pass the test.
23 |
24 | :param env_id: (str)
25 | """
26 | env = gym.make(env_id)
27 | with pytest.warns(None) as record:
28 | check_env(env)
29 |
30 | # Pendulum-v0 will produce a warning because the action space is
31 | # in [-2, 2] and not [-1, 1]
32 | if env_id == "Pendulum-v0":
33 | assert len(record) == 1
34 | else:
35 | # The other environments must pass without warning
36 | assert len(record) == 0
37 |
38 |
39 | @pytest.mark.parametrize("env_class", ENV_CLASSES)
40 | def test_custom_envs(env_class):
41 | env = env_class()
42 | check_env(env)
43 |
44 |
45 | def test_high_dimension_action_space():
46 | """
47 | Test for continuous action space
48 | with more than one action.
49 | """
50 | env = FakeImageEnv()
51 | # Patch the action space
52 | env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
53 |
54 | # Patch to avoid error
55 | def patched_step(_action):
56 | return env.observation_space.sample(), 0.0, False, {}
57 |
58 | env.step = patched_step
59 | check_env(env)
60 |
61 |
62 | @pytest.mark.parametrize(
63 | "new_obs_space",
64 | [
65 | # Small image
66 | spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
67 | # Range not in [0, 255]
68 | spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
69 | # Wrong dtype
70 | spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
71 | # Not an image, it should be a 1D vector
72 | spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
73 | # Tuple space is not supported by SB
74 | spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
75 | # Dict space is not supported by SB when env is not a GoalEnv
76 | spaces.Dict({"position": spaces.Discrete(5)}),
77 | ],
78 | )
79 | def test_non_default_spaces(new_obs_space):
80 | env = FakeImageEnv()
81 | env.observation_space = new_obs_space
82 | # Patch methods to avoid errors
83 | env.reset = new_obs_space.sample
84 |
85 | def patched_step(_action):
86 | return new_obs_space.sample(), 0.0, False, {}
87 |
88 | env.step = patched_step
89 | with pytest.warns(UserWarning):
90 | check_env(env)
91 |
92 |
93 | def check_reset_assert_error(env, new_reset_return):
94 | """
95 | Helper to check that the error is caught.
96 | :param env: (gym.Env)
97 | :param new_reset_return: (Any)
98 | """
99 |
100 | def wrong_reset():
101 | return new_reset_return
102 |
103 | # Patch the reset method with a wrong one
104 | env.reset = wrong_reset
105 | with pytest.raises(AssertionError):
106 | check_env(env)
107 |
108 |
109 | def test_common_failures_reset():
110 | """
111 | Test that common failure cases of the `reset` method are caught
112 | """
113 | env = IdentityEnvBox()
114 | # Return an observation that does not match the observation_space
115 | check_reset_assert_error(env, np.ones((3,)))
116 | # The observation is not a numpy array
117 | check_reset_assert_error(env, 1)
118 |
119 | # Return not only the observation
120 | check_reset_assert_error(env, (env.observation_space.sample(), False))
121 |
122 |
123 | def check_step_assert_error(env, new_step_return=()):
124 | """
125 | Helper to check that the error is caught.
126 | :param env: (gym.Env)
127 | :param new_step_return: (tuple)
128 | """
129 |
130 | def wrong_step(_action):
131 | return new_step_return
132 |
133 | # Patch the step method with a wrong one
134 | env.step = wrong_step
135 | with pytest.raises(AssertionError):
136 | check_env(env)
137 |
138 |
139 | def test_common_failures_step():
140 | """
141 | Test that common failure cases of the `step` method are caught
142 | """
143 | env = IdentityEnvBox()
144 |
145 | # Wrong shape for the observation
146 | check_step_assert_error(env, (np.ones((4,)), 1.0, False, {}))
147 | # Obs is not a numpy array
148 | check_step_assert_error(env, (1, 1.0, False, {}))
149 |
150 | # Return a wrong reward
151 | check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {}))
152 |
153 | # Info dict is not returned
154 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, False))
155 |
156 | # Done is not a boolean
157 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {}))
158 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {}))
159 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/preprocessing.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import torch as th
5 | from gym import spaces
6 | from torch.nn import functional as F
7 |
8 |
9 | def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool:
10 | """
11 | Check if an observation space has the shape, limits and dtype
12 | of a valid image.
13 | The check is conservative, so that it returns False
14 | if there is a doubt.
15 |
16 | Valid images: RGB, RGBD, GrayScale with values in [0, 255]
17 |
18 | :param observation_space: (spaces.Space)
19 | :param channels_last: (bool)
20 | :param check_channels: (bool) Whether or not to check the number of channels.
21 | e.g., with frame-stacking, the observation space may have more channels than expected.
22 | :return: (bool)
23 | """
24 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
25 | # Check the type
26 | if observation_space.dtype != np.uint8:
27 | return False
28 |
29 | # Check the value range
30 | if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
31 | return False
32 |
33 | # Skip channels check
34 | if not check_channels:
35 | return True
36 | # Check the number of channels
37 | if channels_last:
38 | n_channels = observation_space.shape[-1]
39 | else:
40 | n_channels = observation_space.shape[0]
41 | # RGB, RGBD, GrayScale
42 | return n_channels in [1, 3, 4]
43 | return False
44 |
45 |
46 | def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor:
47 | """
48 | Preprocess observation to be fed to a neural network.
49 | For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
50 | For discrete observations, it creates a one-hot vector.
51 |
52 | :param obs: (th.Tensor) Observation
53 | :param observation_space: (spaces.Space)
54 | :param normalize_images: (bool) Whether to normalize images or not
55 | (True by default)
56 | :return: (th.Tensor)
57 | """
58 | if isinstance(observation_space, spaces.Box):
59 | if is_image_space(observation_space) and normalize_images:
60 | return obs.float() / 255.0
61 | return obs.float()
62 |
63 | elif isinstance(observation_space, spaces.Discrete):
64 | # One hot encoding and convert to float to avoid errors
65 | return F.one_hot(obs.long(), num_classes=observation_space.n).float()
66 |
67 | elif isinstance(observation_space, spaces.MultiDiscrete):
68 | # Tensor concatenation of one hot encodings of each Categorical sub-space
69 | return th.cat(
70 | [
71 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
72 | for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
73 | ],
74 | dim=-1,
75 | ).view(obs.shape[0], sum(observation_space.nvec))
76 |
77 | elif isinstance(observation_space, spaces.MultiBinary):
78 | return obs.float()
79 |
80 | else:
81 | raise NotImplementedError()
82 |
83 |
84 | def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
85 | """
86 | Get the shape of the observation (useful for the buffers).
87 |
88 | :param observation_space: (spaces.Space)
89 | :return: (Tuple[int, ...])
90 | """
91 | if isinstance(observation_space, spaces.Box):
92 | return observation_space.shape
93 | elif isinstance(observation_space, spaces.Discrete):
94 | # Observation is an int
95 | return (1,)
96 | elif isinstance(observation_space, spaces.MultiDiscrete):
97 | # Number of discrete features
98 | return (int(len(observation_space.nvec)),)
99 | elif isinstance(observation_space, spaces.MultiBinary):
100 | # Number of binary features
101 | return (int(observation_space.n),)
102 | else:
103 | raise NotImplementedError()
104 |
105 |
106 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
107 | """
108 | Get the dimension of the observation space when flattened.
109 | It does not apply to image observation space.
110 |
111 | :param observation_space: (spaces.Space)
112 | :return: (int)
113 | """
114 | # See issue https://github.com/openai/gym/issues/1915
115 | # it may be a problem for Dict/Tuple spaces too...
116 | if isinstance(observation_space, spaces.MultiDiscrete):
117 | return sum(observation_space.nvec)
118 | else:
119 | # Use Gym internal method
120 | return spaces.utils.flatdim(observation_space)
121 |
122 |
123 | def get_action_dim(action_space: spaces.Space) -> int:
124 | """
125 | Get the dimension of the action space.
126 |
127 | :param action_space: (spaces.Space)
128 | :return: (int)
129 | """
130 | if isinstance(action_space, spaces.Box):
131 | return int(np.prod(action_space.shape))
132 | elif isinstance(action_space, spaces.Discrete):
133 | # Action is an int
134 | return 1
135 | elif isinstance(action_space, spaces.MultiDiscrete):
136 | # Number of discrete actions
137 | return int(len(action_space.nvec))
138 | elif isinstance(action_space, spaces.MultiBinary):
139 | # Number of binary actions
140 | return int(action_space.n)
141 | else:
142 | raise NotImplementedError()
143 |
--------------------------------------------------------------------------------
/stable-baselines3/stable_baselines3/common/identity_env.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | from gym import Env, Space
5 | from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
6 |
7 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
8 |
9 |
10 | class IdentityEnv(Env):
11 | def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
12 | """
13 | Identity environment for testing purposes
14 |
15 | :param dim: the size of the action and observation dimension you want
16 | to learn. Provide at most one of ``dim`` and ``space``. If both are
17 | None, then initialization proceeds with ``dim=1`` and ``space=None``.
18 | :param space: the action and observation space. Provide at most one of
19 | ``dim`` and ``space``.
20 | :param ep_length: the length of each episode in timesteps
21 | """
22 | if space is None:
23 | if dim is None:
24 | dim = 1
25 | space = Discrete(dim)
26 | else:
27 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
28 |
29 | self.action_space = self.observation_space = space
30 | self.ep_length = ep_length
31 | self.current_step = 0
32 | self.num_resets = -1 # Becomes 0 after __init__ exits.
33 | self.reset()
34 |
35 | def reset(self) -> GymObs:
36 | self.current_step = 0
37 | self.num_resets += 1
38 | self._choose_next_state()
39 | return self.state
40 |
41 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
42 | reward = self._get_reward(action)
43 | self._choose_next_state()
44 | self.current_step += 1
45 | done = self.current_step >= self.ep_length
46 | return self.state, reward, done, {}
47 |
48 | def _choose_next_state(self) -> None:
49 | self.state = self.action_space.sample()
50 |
51 | def _get_reward(self, action: Union[int, np.ndarray]) -> float:
52 | return 1.0 if np.all(self.state == action) else 0.0
53 |
54 | def render(self, mode: str = "human") -> None:
55 | pass
56 |
57 |
58 | class IdentityEnvBox(IdentityEnv):
59 | def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
60 | """
61 | Identity environment for testing purposes
62 |
63 | :param low: (float) the lower bound of the box dim
64 | :param high: (float) the upper bound of the box dim
65 | :param eps: (float) the epsilon bound for correct value
66 | :param ep_length: (int) the length of each episode in timesteps
67 | """
68 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
69 | super().__init__(ep_length=ep_length, space=space)
70 | self.eps = eps
71 |
72 | def step(self, action: np.ndarray) -> GymStepReturn:
73 | reward = self._get_reward(action)
74 | self._choose_next_state()
75 | self.current_step += 1
76 | done = self.current_step >= self.ep_length
77 | return self.state, reward, done, {}
78 |
79 | def _get_reward(self, action: np.ndarray) -> float:
80 | return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
81 |
82 |
83 | class IdentityEnvMultiDiscrete(IdentityEnv):
84 | def __init__(self, dim: int = 1, ep_length: int = 100):
85 | """
86 | Identity environment for testing purposes
87 |
88 | :param dim: (int) the size of the dimensions you want to learn
89 | :param ep_length: (int) the length of each episode in timesteps
90 | """
91 | space = MultiDiscrete([dim, dim])
92 | super().__init__(ep_length=ep_length, space=space)
93 |
94 |
95 | class IdentityEnvMultiBinary(IdentityEnv):
96 | def __init__(self, dim: int = 1, ep_length: int = 100):
97 | """
98 | Identity environment for testing purposes
99 |
100 | :param dim: (int) the size of the dimensions you want to learn
101 | :param ep_length: (int) the length of each episode in timesteps
102 | """
103 | space = MultiBinary(dim)
104 | super().__init__(ep_length=ep_length, space=space)
105 |
106 |
107 | class FakeImageEnv(Env):
108 | """
109 | Fake image environment for testing purposes; it mimics Atari games.
110 |
111 | :param action_dim: (int) Number of discrete actions
112 | :param screen_height: (int) Height of the image
113 | :param screen_width: (int) Width of the image
114 | :param n_channels: (int) Number of color channels
115 | :param discrete: (bool)
116 | """
117 |
118 | def __init__(
119 | self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True
120 | ):
121 |
122 | self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8)
123 | if discrete:
124 | self.action_space = Discrete(action_dim)
125 | else:
126 | self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
127 | self.ep_length = 10
128 | self.current_step = 0
129 |
130 | def reset(self) -> np.ndarray:
131 | self.current_step = 0
132 | return self.observation_space.sample()
133 |
134 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
135 | reward = 0.0
136 | self.current_step += 1
137 | done = self.current_step >= self.ep_length
138 | return self.observation_space.sample(), reward, done, {}
139 |
140 | def render(self, mode: str = "human") -> None:
141 | pass
142 |
--------------------------------------------------------------------------------