├── logs └── .keep ├── output └── .keep ├── stable-baselines3 ├── tests │ ├── __init__.py │ ├── test_tensorboard.py │ ├── test_deterministic.py │ ├── test_custom_policy.py │ ├── test_cnn.py │ ├── test_predict.py │ ├── test_vec_check_nan.py │ ├── test_spaces.py │ ├── test_identity.py │ ├── test_callbacks.py │ ├── test_logger.py │ ├── test_sde.py │ ├── test_run.py │ ├── test_monitor.py │ ├── test_distributions.py │ └── test_envs.py ├── .dockerignore ├── stable_baselines3 │ ├── py.typed │ ├── version.txt │ ├── a2c │ │ ├── __init__.py │ │ └── policies.py │ ├── dqn │ │ └── __init__.py │ ├── ppo │ │ ├── __init__.py │ │ └── policies.py │ ├── sac │ │ └── __init__.py │ ├── td3 │ │ └── __init__.py │ ├── ddpg │ │ ├── __init__.py │ │ └── policies.py │ ├── common │ │ ├── __init__.py │ │ ├── type_aliases.py │ │ ├── running_mean_std.py │ │ ├── vec_env │ │ │ ├── __init__.py │ │ │ ├── vec_frame_stack.py │ │ │ ├── vec_transpose.py │ │ │ ├── util.py │ │ │ ├── vec_check_nan.py │ │ │ ├── vec_video_recorder.py │ │ │ └── dummy_vec_env.py │ │ ├── evaluation.py │ │ ├── results_plotter.py │ │ ├── bit_flipping_env.py │ │ ├── preprocessing.py │ │ └── identity_env.py │ └── __init__.py ├── scripts │ ├── run_tests.sh │ ├── run_docker_cpu.sh │ ├── run_docker_gpu.sh │ └── build_docker.sh ├── docs │ ├── common │ │ ├── utils.rst │ │ ├── logger.rst │ │ ├── noise.rst │ │ ├── monitor.rst │ │ ├── cmd_util.rst │ │ ├── evaluation.rst │ │ ├── atari_wrappers.rst │ │ ├── env_checker.rst │ │ └── distributions.rst │ ├── _static │ │ ├── img │ │ │ ├── logo.png │ │ │ ├── mistake.png │ │ │ ├── try_it.png │ │ │ ├── breakout.gif │ │ │ ├── Tensorboard_example.png │ │ │ ├── colab.svg │ │ │ └── colab-badge.svg │ │ └── css │ │ │ └── baselines_theme.css │ ├── guide │ │ ├── migration.rst │ │ ├── rl.rst │ │ ├── quickstart.rst │ │ ├── algos.rst │ │ ├── tensorboard.rst │ │ ├── custom_env.rst │ │ ├── vec_envs.rst │ │ ├── rl_zoo.rst │ │ ├── custom_policy.rst │ │ ├── developer.rst │ │ └── install.rst │ ├── conda_env.yml │ ├── README.md │ ├── Makefile │ ├── modules │ │ ├── base.rst │ │ ├── a2c.rst │ │ ├── dqn.rst │ │ ├── ppo.rst │ │ ├── ddpg.rst │ │ ├── td3.rst │ │ └── sac.rst │ ├── make.bat │ ├── misc │ │ └── projects.rst │ ├── spelling_wordlist.txt │ └── index.rst ├── .coveragerc ├── .gitlab-ci.yml ├── .readthedocs.yml ├── .gitignore ├── LICENSE ├── NOTICE ├── Dockerfile ├── Makefile ├── .github │ ├── workflows │ │ └── ci.yml │ ├── ISSUE_TEMPLATE │ │ └── issue-template.md │ └── PULL_REQUEST_TEMPLATE.md ├── setup.cfg ├── CONTRIBUTING.md └── setup.py ├── .gitmodules ├── bashfiles ├── blocks_train_selfplay.bash ├── hanabi_train_selfplay.bash ├── arms_train_selfplay.bash ├── hanabi_adapt_to_selfplay.bash ├── arms_human_adapt_to_fixed.bash ├── blocks_adapt_to_selfplay.bash ├── arms_adapt_to_selfplay.bash ├── blocks_adapt_to_fixed.bash └── arms_adapt_to_fixed.bash ├── setup.py ├── my_gym ├── envs │ ├── __init__.py │ ├── hanabi_env.py │ ├── arms_human_env.py │ ├── arms_env.py │ └── generate_grid.py └── __init__.py ├── tabular.py ├── README.md ├── partner.py └── run_arms_human.py /logs/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable-baselines3/tests/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /stable-baselines3/.dockerignore: -------------------------------------------------------------------------------- 1 | .gitignore -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/version.txt: -------------------------------------------------------------------------------- 1 | 0.8.0a5 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hanabi"] 2 | path = hanabi 3 | url = https://github.com/deepmind/hanabi-learning-environment 4 | -------------------------------------------------------------------------------- /bashfiles/blocks_train_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | python run_blocks.py --run=$runid --selfplay -------------------------------------------------------------------------------- /bashfiles/hanabi_train_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | python run_hanabi.py --run=$runid --selfplay -------------------------------------------------------------------------------- /stable-baselines3/scripts/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v 3 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | Utils 4 | ===== 5 | 6 | .. automodule:: stable_baselines3.common.utils 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/logger.rst: -------------------------------------------------------------------------------- 1 | .. _logger: 2 | 3 | Logger 4 | ====== 5 | 6 | .. 
automodule:: stable_baselines3.common.logger 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/logo.png -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/mistake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/mistake.png -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/try_it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/try_it.png -------------------------------------------------------------------------------- /stable-baselines3/docs/common/noise.rst: -------------------------------------------------------------------------------- 1 | .. _noise: 2 | 3 | Action Noise 4 | ============= 5 | 6 | .. automodule:: stable_baselines3.common.noise 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.a2c.a2c import A2C 2 | from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy 3 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.dqn.dqn import DQN 2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy 3 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy 2 | from stable_baselines3.ppo.ppo import PPO 3 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy 2 | from stable_baselines3.sac.sac import SAC 3 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy 2 | from stable_baselines3.td3.td3 import TD3 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='my_gym', 4 | version='0.0.1', 5 | install_requires=['gym']#And any other dependencies required 6 | ) -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/breakout.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/breakout.gif -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ddpg.ddpg import DDPG 2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy 3 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/monitor.rst: -------------------------------------------------------------------------------- 1 | .. _monitor: 2 | 3 | Monitor Wrapper 4 | =============== 5 | 6 | .. automodule:: stable_baselines3.common.monitor 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/ddpg/policies.py: -------------------------------------------------------------------------------- 1 | # DDPG can be view as a special case of TD3 2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy # noqa:F401 3 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/cmd_util.rst: -------------------------------------------------------------------------------- 1 | .. _cmd_util: 2 | 3 | Command Utils 4 | ========================= 5 | 6 | .. automodule:: stable_baselines3.common.cmd_util 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _eval: 2 | 3 | Evaluation Helper 4 | ================= 5 | 6 | .. automodule:: stable_baselines3.common.evaluation 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/Tensorboard_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stanford-ILIAD/Conventions-ModularPolicy/HEAD/stable-baselines3/docs/_static/img/Tensorboard_example.png -------------------------------------------------------------------------------- /stable-baselines3/docs/common/atari_wrappers.rst: -------------------------------------------------------------------------------- 1 | .. _atari_wrapper: 2 | 3 | Atari Wrappers 4 | ============== 5 | 6 | .. automodule:: stable_baselines3.common.atari_wrappers 7 | :members: 8 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env 2 | from stable_baselines3.common.utils import set_random_seed 3 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/env_checker.rst: -------------------------------------------------------------------------------- 1 | .. _env_checker: 2 | 3 | Gym Environment Checker 4 | ======================== 5 | 6 | .. 
automodule:: stable_baselines3.common.env_checker 7 | :members: 8 | -------------------------------------------------------------------------------- /bashfiles/arms_train_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | m=$2 7 | 8 | for i in {0..9} 9 | do 10 | python run_arms.py --run=$(($runid + $i)) --selfplay --m=${m} 11 | done -------------------------------------------------------------------------------- /my_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from my_gym.envs.blocks_env import BlocksEnv 2 | from my_gym.envs.arms_env import ArmsEnv 3 | from my_gym.envs.hanabi_env import HanabiEnvWrapper 4 | from my_gym.envs.arms_human_env import ArmsHumanEnv -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/migration.rst: -------------------------------------------------------------------------------- 1 | .. _migration: 2 | 3 | ================================ 4 | Migrating from Stable-Baselines 5 | ================================ 6 | 7 | 8 | This is a guide to migrate from Stable-Baselines to Stable-Baselines3. 9 | 10 | It also references the main changes. 11 | 12 | **TODO** 13 | -------------------------------------------------------------------------------- /stable-baselines3/docs/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: root 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - cpuonly=1.0=0 7 | - pip=20.0 8 | - python=3.6 9 | - pytorch=1.5.0=py3.6_cpu_0 10 | - pip: 11 | - gym==0.17.2 12 | - cloudpickle 13 | - opencv-python-headless 14 | - pandas 15 | - numpy 16 | - matplotlib 17 | -------------------------------------------------------------------------------- /stable-baselines3/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = False 3 | omit = 4 | tests/* 5 | setup.py 6 | # Require graphical interface 7 | stable_baselines3/common/results_plotter.py 8 | # Require ffmpeg 9 | stable_baselines3/common/vec_env/vec_video_recorder.py 10 | 11 | [report] 12 | exclude_lines = 13 | pragma: no cover 14 | raise NotImplementedError() 15 | if typing.TYPE_CHECKING: 16 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/a2c/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for A2C 3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy 4 | 5 | MlpPolicy = ActorCriticPolicy 6 | CnnPolicy = ActorCriticCnnPolicy 7 | 8 | register_policy("MlpPolicy", ActorCriticPolicy) 9 | register_policy("CnnPolicy", ActorCriticCnnPolicy) 10 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/ppo/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for PPO 3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy 4 | 5 | MlpPolicy = ActorCriticPolicy 6 | CnnPolicy = ActorCriticCnnPolicy 7 | 8 | register_policy("MlpPolicy", ActorCriticPolicy) 9 | register_policy("CnnPolicy", ActorCriticCnnPolicy) 10 | 
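# These registrations let a model be constructed from the policy's string name (e.g. PPO("MlpPolicy", ...)).
# A minimal usage sketch, for illustration only; it assumes gym and the CartPole-v1 environment are available:
#
#     from stable_baselines3 import PPO
#
#     model = PPO("MlpPolicy", "CartPole-v1", n_steps=64)
#     model.learn(total_timesteps=1000)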
-------------------------------------------------------------------------------- /stable-baselines3/scripts/run_docker_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Launch an experiment using the docker cpu image 3 | 4 | cmd_line="$@" 5 | 6 | echo "Executing in the docker (cpu image):" 7 | echo $cmd_line 8 | 9 | docker run -it --rm --network host --ipc=host \ 10 | --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ 11 | bash -c "cd /root/code/stable-baselines3/ && $cmd_line" 12 | -------------------------------------------------------------------------------- /stable-baselines3/.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: stablebaselines/stable-baselines3-cpu:0.8.0a4 2 | 3 | type-check: 4 | script: 5 | - make type 6 | 7 | pytest: 8 | script: 9 | # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error 10 | - MKL_THREADING_LAYER=GNU make pytest 11 | 12 | doc-build: 13 | script: 14 | - make doc 15 | 16 | lint-check: 17 | script: 18 | - make check-codestyle 19 | - make lint 20 | -------------------------------------------------------------------------------- /my_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='blocks-v0', 5 | entry_point='my_gym.envs:BlocksEnv', 6 | ) 7 | 8 | register( 9 | id='arms-v0', 10 | entry_point='my_gym.envs:ArmsEnv', 11 | ) 12 | 13 | register( 14 | id='hanabi-v0', 15 | entry_point='my_gym.envs:HanabiEnvWrapper', 16 | ) 17 | 18 | register( 19 | id='arms-human-v0', 20 | entry_point='my_gym.envs:ArmsHumanEnv', 21 | ) -------------------------------------------------------------------------------- /stable-baselines3/.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | # Optionally build your docs in additional formats such as PDF and ePub 12 | formats: all 13 | 14 | # Set requirements using conda env 15 | conda: 16 | environment: docs/conda_env.yml 17 | -------------------------------------------------------------------------------- /stable-baselines3/docs/README.md: -------------------------------------------------------------------------------- 1 | ## Stable Baselines3 Documentation 2 | 3 | This folder contains documentation for the RL baselines. 4 | 5 | 6 | ### Build the Documentation 7 | 8 | #### Install Sphinx and Theme 9 | 10 | ``` 11 | pip install sphinx sphinx-autobuild sphinx-rtd-theme 12 | ``` 13 | 14 | #### Building the Docs 15 | 16 | In the `docs/` folder: 17 | ``` 18 | make html 19 | ``` 20 | 21 | if you want to rebuild the docs each time a file is changed: 22 | 23 | ``` 24 | sphinx-autobuild . 
_build/html 25 | ``` 26 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from stable_baselines3.a2c import A2C 4 | from stable_baselines3.ddpg import DDPG 5 | from stable_baselines3.dqn import DQN 6 | from stable_baselines3.ppo import PPO 7 | from stable_baselines3.sac import SAC 8 | from stable_baselines3.td3 import TD3 9 | 10 | # Read version from file 11 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 12 | with open(version_file, "r") as file_handler: 13 | __version__ = file_handler.read().strip() 14 | -------------------------------------------------------------------------------- /stable-baselines3/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.pkl 4 | *.py~ 5 | *.bak 6 | .pytest_cache 7 | .DS_Store 8 | .idea 9 | .coverage 10 | .coverage.* 11 | __pycache__/ 12 | _build/ 13 | *.npz 14 | *.pth 15 | .pytype/ 16 | git_rewrite_commit_history.sh 17 | 18 | # Setuptools distribution and build folders. 19 | /dist/ 20 | /build 21 | keys/ 22 | 23 | # Virtualenv 24 | /env 25 | /venv 26 | 27 | 28 | *.sublime-project 29 | *.sublime-workspace 30 | 31 | .idea 32 | 33 | logs/ 34 | 35 | .ipynb_checkpoints 36 | ghostdriver.log 37 | 38 | htmlcov 39 | 40 | junk 41 | src 42 | 43 | *.egg-info 44 | .cache 45 | *.lprof 46 | *.prof 47 | 48 | MUJOCO_LOG.TXT 49 | -------------------------------------------------------------------------------- /stable-baselines3/scripts/run_docker_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Launch an experiment using the docker gpu image 3 | 4 | cmd_line="$@" 5 | 6 | echo "Executing in the docker (gpu image):" 7 | echo $cmd_line 8 | 9 | # TODO: always use new-style once sufficiently widely used (probably 2021 onwards) 10 | if [ -x "$(which nvidia-docker)" ]; then 11 | # old-style nvidia-docker2 12 | NVIDIA_ARG="--runtime=nvidia" 13 | else 14 | NVIDIA_ARG="--gpus all" 15 | fi 16 | 17 | docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ 18 | --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ 19 | bash -c "cd /root/code/stable-baselines3/ && $cmd_line" 20 | -------------------------------------------------------------------------------- /stable-baselines3/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W # make warnings fatal 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = StableBaselines 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/base.rst: -------------------------------------------------------------------------------- 1 | .. _base_algo: 2 | 3 | .. 
automodule:: stable_baselines3.common.base_class 4 | 5 | 6 | Base RL Class 7 | ============= 8 | 9 | Common interface for all the RL algorithms 10 | 11 | .. autoclass:: BaseAlgorithm 12 | :members: 13 | 14 | 15 | .. automodule:: stable_baselines3.common.off_policy_algorithm 16 | 17 | 18 | Base Off-Policy Class 19 | ^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | The base RL algorithm for Off-Policy algorithm (ex: SAC/TD3) 22 | 23 | .. autoclass:: OffPolicyAlgorithm 24 | :members: 25 | 26 | 27 | .. automodule:: stable_baselines3.common.on_policy_algorithm 28 | 29 | 30 | Base On-Policy Class 31 | ^^^^^^^^^^^^^^^^^^^^^ 32 | 33 | The base RL algorithm for On-Policy algorithm (ex: A2C/PPO) 34 | 35 | .. autoclass:: OnPolicyAlgorithm 36 | :members: 37 | -------------------------------------------------------------------------------- /stable-baselines3/scripts/build_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CPU_PARENT=ubuntu:16.04 4 | GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04 5 | 6 | TAG=stablebaselines/stable-baselines3 7 | VERSION=$(cat ./stable_baselines3/version.txt) 8 | 9 | if [[ ${USE_GPU} == "True" ]]; then 10 | PARENT=${GPU_PARENT} 11 | PYTORCH_DEPS="cudatoolkit=10.1" 12 | else 13 | PARENT=${CPU_PARENT} 14 | PYTORCH_DEPS="cpuonly" 15 | TAG="${TAG}-cpu" 16 | fi 17 | 18 | echo "docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} ." 19 | docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} . 20 | docker tag ${TAG}:${VERSION} ${TAG}:latest 21 | 22 | if [[ ${RELEASE} == "True" ]]; then 23 | docker push ${TAG}:${VERSION} 24 | docker push ${TAG}:latest 25 | fi 26 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/rl.rst: -------------------------------------------------------------------------------- 1 | .. _rl: 2 | 3 | ================================ 4 | Reinforcement Learning Resources 5 | ================================ 6 | 7 | 8 | Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL). 9 | 10 | However, if you want to learn about RL, there are several good resources to get started: 11 | 12 | - `OpenAI Spinning Up `_ 13 | - `David Silver's course `_ 14 | - `Lilian Weng's blog `_ 15 | - `Berkeley's Deep RL Bootcamp `_ 16 | - `Berkeley's Deep Reinforcement Learning course `_ 17 | - `More resources `_ 18 | -------------------------------------------------------------------------------- /stable-baselines3/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=StableBaselines 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/quickstart.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | =============== 4 | Getting Started 5 | =============== 6 | 7 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. 8 | 9 | Here is a quick example of how to train and run A2C on a CartPole environment: 10 | 11 | .. code-block:: python 12 | 13 | import gym 14 | 15 | from stable_baselines3 import A2C 16 | 17 | env = gym.make('CartPole-v1') 18 | 19 | model = A2C('MlpPolicy', env, verbose=1) 20 | model.learn(total_timesteps=10000) 21 | 22 | obs = env.reset() 23 | for i in range(1000): 24 | action, _state = model.predict(obs, deterministic=True) 25 | obs, reward, done, info = env.step(action) 26 | env.render() 27 | if done: 28 | obs = env.reset() 29 | 30 | 31 | Or just train a model with a one liner if 32 | `the environment is registered in Gym `_ and if 33 | the policy is registered: 34 | 35 | .. code-block:: python 36 | 37 | from stable_baselines3 import A2C 38 | 39 | model = A2C('MlpPolicy', 'CartPole-v1').learn(10000) 40 | -------------------------------------------------------------------------------- /stable-baselines3/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Antonin Raffin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/css/baselines_theme.css: -------------------------------------------------------------------------------- 1 | /* Main colors adapted from pytorch doc */ 2 | :root{ 3 | --main-bg-color: #343A40; 4 | --link-color: #FD7E14; 5 | } 6 | 7 | /* Header fonts y */ 8 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 9 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 10 | } 11 | 12 | 13 | /* Docs background */ 14 | .wy-side-nav-search{ 15 | background-color: var(--main-bg-color); 16 | } 17 | 18 | /* Mobile version */ 19 | .wy-nav-top{ 20 | background-color: var(--main-bg-color); 21 | } 22 | 23 | /* Change link colors (except for the menu) */ 24 | a { 25 | color: var(--link-color); 26 | } 27 | 28 | a:hover { 29 | color: #4F778F; 30 | } 31 | 32 | .wy-menu a { 33 | color: #b3b3b3; 34 | } 35 | 36 | .wy-menu a:hover { 37 | color: #b3b3b3; 38 | } 39 | 40 | a.icon.icon-home { 41 | color: #b3b3b3; 42 | } 43 | 44 | .version{ 45 | color: var(--link-color) !important; 46 | } 47 | 48 | 49 | /* Make code blocks have a background */ 50 | .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { 51 | background: #f8f8f8;; 52 | } 53 | -------------------------------------------------------------------------------- /stable-baselines3/docs/common/distributions.rst: -------------------------------------------------------------------------------- 1 | .. _distributions: 2 | 3 | Probability Distributions 4 | ========================= 5 | 6 | Probability distributions used for the different action spaces: 7 | 8 | - ``CategoricalDistribution`` -> Discrete 9 | - ``DiagGaussianDistribution`` -> Box (continuous actions) 10 | - ``StateDependentNoiseDistribution`` -> Box (continuous actions) when ``use_sde=True`` 11 | 12 | .. - ``MultiCategoricalDistribution`` -> MultiDiscrete 13 | .. - ``BernoulliDistribution`` -> MultiBinary 14 | 15 | The policy networks output parameters for the distributions (named ``flat`` in the methods). 16 | Actions are then sampled from those distributions. 17 | 18 | For instance, in the case of discrete actions, the policy network outputs the probability 19 | of taking each action. The ``CategoricalDistribution`` then allows sampling from it, 20 | computing the entropy and the log probability (``log_prob``), and backpropagating the gradient. 21 | 22 | In the case of continuous actions, a Gaussian distribution is used. The policy network outputs 23 | the mean and (log) std of the distribution (assumed to be a ``DiagGaussianDistribution``). 24 | 25 | .. 
automodule:: stable_baselines3.common.distributions 26 | :members: 27 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from stable_baselines3 import A2C, PPO, SAC, TD3 6 | 7 | MODEL_DICT = { 8 | "a2c": (A2C, "CartPole-v1"), 9 | "ppo": (PPO, "CartPole-v1"), 10 | "sac": (SAC, "Pendulum-v0"), 11 | "td3": (TD3, "Pendulum-v0"), 12 | } 13 | 14 | N_STEPS = 100 15 | 16 | 17 | @pytest.mark.parametrize("model_name", MODEL_DICT.keys()) 18 | def test_tensorboard(tmp_path, model_name): 19 | # Skip if no tensorboard installed 20 | pytest.importorskip("tensorboard") 21 | 22 | logname = model_name.upper() 23 | algo, env_id = MODEL_DICT[model_name] 24 | model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path) 25 | model.learn(N_STEPS) 26 | model.learn(N_STEPS, reset_num_timesteps=False) 27 | 28 | assert os.path.isdir(tmp_path / str(logname + "_1")) 29 | assert not os.path.isdir(tmp_path / str(logname + "_2")) 30 | 31 | logname = "tb_multiple_runs_" + model_name 32 | model.learn(N_STEPS, tb_log_name=logname) 33 | model.learn(N_STEPS, tb_log_name=logname) 34 | 35 | assert os.path.isdir(tmp_path / str(logname + "_1")) 36 | # Check that the log dir name increments correctly 37 | assert os.path.isdir(tmp_path / str(logname + "_2")) 38 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/type_aliases.py: -------------------------------------------------------------------------------- 1 | """Common aliases for type hints""" 2 | 3 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union 4 | 5 | import gym 6 | import numpy as np 7 | import torch as th 8 | 9 | from stable_baselines3.common.callbacks import BaseCallback 10 | from stable_baselines3.common.vec_env import VecEnv 11 | 12 | GymEnv = Union[gym.Env, VecEnv] 13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] 14 | GymStepReturn = Tuple[GymObs, float, bool, Dict] 15 | TensorDict = Dict[str, th.Tensor] 16 | OptimizerStateDict = Dict[str, Any] 17 | MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback] 18 | 19 | 20 | class RolloutBufferSamples(NamedTuple): 21 | observations: th.Tensor 22 | actions: th.Tensor 23 | old_values: th.Tensor 24 | old_log_prob: th.Tensor 25 | advantages: th.Tensor 26 | returns: th.Tensor 27 | 28 | 29 | class ReplayBufferSamples(NamedTuple): 30 | observations: th.Tensor 31 | actions: th.Tensor 32 | next_observations: th.Tensor 33 | dones: th.Tensor 34 | rewards: th.Tensor 35 | 36 | 37 | class RolloutReturn(NamedTuple): 38 | episode_reward: float 39 | episode_timesteps: int 40 | n_episodes: int 41 | continue_training: bool 42 | -------------------------------------------------------------------------------- /my_gym/envs/hanabi_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error, spaces, utils 3 | from hanabi_learning_environment.rl_env import HanabiEnv 4 | 5 | class HanabiEnvWrapper(HanabiEnv, gym.Env): 6 | metadata = {'render.modes': ['human']} 7 | def __init__(self, config): 8 | self.config = config 9 | super(HanabiEnvWrapper, self).__init__(config=self.config) 10 | 11 | observation_shape = super().vectorized_observation_shape() 12 | self.observation_space = spaces.MultiBinary(observation_shape[0]) 13 | self.action_space = 
spaces.Discrete(self.game.max_moves()) 14 | 15 | def reset(self): 16 | obs = super().reset() 17 | obs = obs['player_observations'][obs['current_player']]['vectorized'] 18 | return obs 19 | 20 | def step(self, action): 21 | # action is a integer from 0 to self.action_space 22 | # we map it to one of the legal moves 23 | # the legal move array may be too small in some cases, so just modulo action by the array length 24 | legal_moves = self.state.legal_moves() 25 | move = legal_moves[action % len(legal_moves)].to_dict() 26 | 27 | obs, reward, done, info = super().step(move) 28 | obs = obs['player_observations'][obs['current_player']]['vectorized'] 29 | return obs, reward, done, info -------------------------------------------------------------------------------- /stable-baselines3/tests/test_deterministic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 4 | from stable_baselines3.common.noise import NormalActionNoise 5 | 6 | N_STEPS_TRAINING = 3000 7 | SEED = 0 8 | 9 | 10 | @pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3]) 11 | def test_deterministic_training_common(algo): 12 | results = [[], []] 13 | rewards = [[], []] 14 | # Smaller network 15 | kwargs = {"policy_kwargs": dict(net_arch=[64])} 16 | if algo in [TD3, SAC]: 17 | env_id = "Pendulum-v0" 18 | kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100}) 19 | else: 20 | env_id = "CartPole-v1" 21 | if algo == DQN: 22 | kwargs.update({"learning_starts": 100}) 23 | 24 | for i in range(2): 25 | model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) 26 | model.learn(N_STEPS_TRAINING) 27 | env = model.get_env() 28 | obs = env.reset() 29 | for _ in range(100): 30 | action, _ = model.predict(obs, deterministic=False) 31 | obs, reward, _, _ = env.step(action) 32 | results[i].append(action) 33 | rewards[i].append(reward) 34 | assert sum(results[0]) == sum(results[1]), results 35 | assert sum(rewards[0]) == sum(rewards[1]), rewards 36 | -------------------------------------------------------------------------------- /my_gym/envs/arms_human_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error, spaces, utils 3 | from gym.utils import seeding 4 | 5 | import numpy as np 6 | import sys 7 | 8 | class ArmsHumanEnv(gym.Env): 9 | """ 10 | Two player game. 11 | There are N (say 4) possible arms to pull. 12 | 0 1 2 3 13 | Some of the arms are forbidden (determined by the state) 14 | Goal is to pull same arm. 
15 | """ 16 | 17 | def __init__(self): 18 | super(ArmsHumanEnv, self).__init__() 19 | self.n = 3 20 | self.a = 4 21 | self.action_space = spaces.MultiDiscrete([self.a, self.a]) 22 | self.observation_space = spaces.MultiDiscrete([self.n]) 23 | self.reset() 24 | 25 | def step(self, a): 26 | context = self.state[0] 27 | match = (a[0] == a[1]) 28 | green = [ 29 | [1,0,0,0], 30 | [0,0,1,1], 31 | [0,1,0,1], 32 | ] 33 | correct = match and green[context][a[0]] 34 | self.reward = (int)(correct) 35 | return [self.state, self.reward, True, {}] 36 | 37 | def reset(self, state=None): 38 | if state: 39 | self.state = state 40 | else: 41 | self.state = [np.random.randint(self.n)] 42 | 43 | return self.state 44 | 45 | def render(self): 46 | print(self.n, self.state) -------------------------------------------------------------------------------- /stable-baselines3/tests/test_custom_policy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch as th 3 | 4 | from stable_baselines3 import A2C, PPO, SAC, TD3 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "net_arch", 9 | [ 10 | [12, dict(vf=[16], pi=[8])], 11 | [4], 12 | [], 13 | [4, 4], 14 | [12, dict(vf=[8, 4], pi=[8])], 15 | [12, dict(vf=[8], pi=[8, 4])], 16 | [12, dict(pi=[8])], 17 | ], 18 | ) 19 | @pytest.mark.parametrize("model_class", [A2C, PPO]) 20 | def test_flexible_mlp(model_class, net_arch): 21 | _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) 22 | 23 | 24 | @pytest.mark.parametrize("net_arch", [[4], [4, 4],]) 25 | @pytest.mark.parametrize("model_class", [SAC, TD3]) 26 | def test_custom_offpolicy(model_class, net_arch): 27 | _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000) 28 | 29 | 30 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3]) 31 | @pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)]) 32 | def test_custom_optimizer(model_class, optimizer_kwargs): 33 | policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) 34 | _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000) 35 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 7 | from stable_baselines3.common.identity_env import FakeImageEnv 8 | 9 | 10 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) 11 | def test_cnn(tmp_path, model_class): 12 | SAVE_NAME = "cnn_model.zip" 13 | # Fake grayscale with frameskip 14 | # Atari after preprocessing: 84x84x1, here we are using lower resolution 15 | # to check that the network handle it automatically 16 | env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) 17 | if model_class in {A2C, PPO}: 18 | kwargs = dict(n_steps=100) 19 | else: 20 | # Avoid memory error when using replay buffer 21 | # Reduce the size of the features 22 | kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) 23 | model = model_class("CnnPolicy", env, **kwargs).learn(250) 24 | 25 | obs = env.reset() 26 | 27 | action, _ = model.predict(obs, deterministic=True) 28 | 29 | model.save(tmp_path / SAVE_NAME) 30 | del model 31 | 32 
| model = model_class.load(tmp_path / SAVE_NAME) 33 | 34 | # Check that the prediction is the same 35 | assert np.allclose(action, model.predict(obs, deterministic=True)[0]) 36 | 37 | os.remove(str(tmp_path / SAVE_NAME)) 38 | -------------------------------------------------------------------------------- /stable-baselines3/NOTICE: -------------------------------------------------------------------------------- 1 | Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines, 2 | both licensed under the MIT License: 3 | 4 | before the fork (June 2018): 5 | Copyright (c) 2017 OpenAI (http://openai.com) 6 | 7 | after the fork (June 2018): 8 | Copyright (c) 2018-2019 Stable-Baselines Team 9 | 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in 19 | all copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | THE SOFTWARE. 
28 | -------------------------------------------------------------------------------- /stable-baselines3/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PARENT_IMAGE 2 | FROM $PARENT_IMAGE 3 | ARG PYTORCH_DEPS=cpuonly 4 | ARG PYTHON_VERSION=3.6 5 | 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | cmake \ 9 | git \ 10 | curl \ 11 | ca-certificates \ 12 | libjpeg-dev \ 13 | libpng-dev \ 14 | libglib2.0-0 && \ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | # Install anaconda and dependencies 18 | RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 19 | chmod +x ~/miniconda.sh && \ 20 | ~/miniconda.sh -b -p /opt/conda && \ 21 | rm ~/miniconda.sh && \ 22 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ 23 | /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \ 24 | /opt/conda/bin/conda clean -ya 25 | ENV PATH /opt/conda/bin:$PATH 26 | 27 | ENV CODE_DIR /root/code 28 | 29 | # Copy setup file only to install dependencies 30 | COPY ./setup.py ${CODE_DIR}/stable-baselines3/setup.py 31 | COPY ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt 32 | 33 | RUN \ 34 | cd ${CODE_DIR}/stable-baselines3 && \ 35 | pip install -e .[extra,tests,docs] && \ 36 | # Use headless version for docker 37 | pip uninstall -y opencv-python && \ 38 | pip install opencv-python-headless && \ 39 | rm -rf $HOME/.cache/pip 40 | 41 | CMD /bin/bash 42 | -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/colab.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /stable-baselines3/Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/bin/bash 2 | LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py 3 | 4 | pytest: 5 | ./scripts/run_tests.sh 6 | 7 | type: 8 | pytype 9 | 10 | lint: 11 | # stop the build if there are Python syntax errors or undefined names 12 | # see https://lintlyci.github.io/Flake8Rules/ 13 | flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics 14 | # exit-zero treats all errors as warnings. 
15 | flake8 ${LINT_PATHS} --count --exit-zero --statistics 16 | 17 | format: 18 | # Sort imports 19 | isort ${LINT_PATHS} 20 | # Reformat using black 21 | black -l 127 ${LINT_PATHS} 22 | 23 | check-codestyle: 24 | # Sort imports 25 | isort --check ${LINT_PATHS} 26 | # Reformat using black 27 | black --check -l 127 ${LINT_PATHS} 28 | 29 | commit-checks: format type lint 30 | 31 | doc: 32 | cd docs && make html 33 | 34 | spelling: 35 | cd docs && make spelling 36 | 37 | clean: 38 | cd docs && make clean 39 | 40 | # Build docker images 41 | # If you do export RELEASE=True, it will also push them 42 | docker: docker-cpu docker-gpu 43 | 44 | docker-cpu: 45 | ./scripts/build_docker.sh 46 | 47 | docker-gpu: 48 | USE_GPU=True ./scripts/build_docker.sh 49 | 50 | # PyPi package release 51 | release: 52 | python setup.py sdist 53 | python setup.py bdist_wheel 54 | twine upload dist/* 55 | 56 | # Test PyPi package release 57 | test-release: 58 | python setup.py sdist 59 | python setup.py bdist_wheel 60 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 61 | 62 | .PHONY: clean spelling doc lint format check-codestyle commit-checks 63 | -------------------------------------------------------------------------------- /stable-baselines3/docs/misc/projects.rst: -------------------------------------------------------------------------------- 1 | .. _projects: 2 | 3 | Projects 4 | ========= 5 | 6 | This is a list of projects using stable-baselines3. 7 | Please tell us, if you want your project to appear on this page ;) 8 | 9 | 10 | .. RL Racing Robot 11 | .. -------------------------- 12 | .. Implementation of reinforcement learning approach to make a donkey car learn to race. 13 | .. Uses SAC on autoencoder features 14 | .. 15 | .. | Author: Antonin Raffin (@araffin) 16 | .. | Github repo: https://github.com/araffin/RL-Racing-Robot 17 | 18 | 19 | Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics 20 | ----------------------------------------------------------------------------------- 21 | 22 | An exploration method to train RL agent directly on real robots. 23 | It was the starting point of Stable-Baselines3. 24 | 25 | | Author: Antonin Raffin, Freek Stulp 26 | | Github: https://github.com/DLR-RM/stable-baselines3/tree/sde 27 | | Paper: https://arxiv.org/abs/2005.05719 28 | 29 | Reacher 30 | ------- 31 | A solution to the second project of the Udacity deep reinforcement learning course. 32 | It is an example of: 33 | 34 | - wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3 35 | - creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is https://github.com/DLR-RM/rl-baselines3-zoo) 36 | - generating several pre-trained models which solve the reacher environment 37 | 38 | | Author: Marios Koulakis 39 | | Github: https://github.com/koulakis/reacher-deep-reinforcement-learning 40 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | class RunningMeanStd(object): 7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= ()): 8 | """ 9 | Calulates the running mean and std of a data stream 10 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 11 | 12 | :param epsilon: (float) helps with arithmetic issues 13 | :param shape: (tuple) the shape of the data stream's output 14 | """ 15 | self.mean = np.zeros(shape, np.float64) 16 | self.var = np.ones(shape, np.float64) 17 | self.count = epsilon 18 | 19 | def update(self, arr: np.ndarray) -> None: 20 | batch_mean = np.mean(arr, axis=0) 21 | batch_var = np.var(arr, axis=0) 22 | batch_count = arr.shape[0] 23 | self.update_from_moments(batch_mean, batch_var, batch_count) 24 | 25 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None: 26 | delta = batch_mean - self.mean 27 | tot_count = self.count + batch_count 28 | 29 | new_mean = self.mean + delta * batch_count / tot_count 30 | m_a = self.var * self.count 31 | m_b = batch_var * batch_count 32 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 33 | new_var = m_2 / (self.count + batch_count) 34 | 35 | new_count = batch_count + self.count 36 | 37 | self.mean = new_mean 38 | self.var = new_var 39 | self.count = new_count 40 | -------------------------------------------------------------------------------- /stable-baselines3/.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | # Skip CI if [ci skip] in the commit message 15 | if: "! 
contains(toJSON(github.event.commits.*.message), '[ci skip]')" 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: [3.6, 3.7] # 3.8 not supported yet by pytype 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | # cpu version of pytorch 31 | pip install torch==1.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 32 | pip install .[extra,tests,docs] 33 | # Use headless version 34 | pip install opencv-python-headless 35 | - name: Build the doc 36 | run: | 37 | make doc 38 | - name: Type check 39 | run: | 40 | make type 41 | - name: Check codestyle 42 | run: | 43 | make check-codestyle 44 | - name: Lint with flake8 45 | run: | 46 | make lint 47 | - name: Test with pytest 48 | run: | 49 | make pytest 50 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_predict.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | 4 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 5 | from stable_baselines3.common.vec_env import DummyVecEnv 6 | 7 | MODEL_LIST = [ 8 | PPO, 9 | A2C, 10 | TD3, 11 | SAC, 12 | DQN, 13 | ] 14 | 15 | 16 | @pytest.mark.parametrize("model_class", MODEL_LIST) 17 | def test_auto_wrap(model_class): 18 | # test auto wrapping of env into a VecEnv 19 | 20 | # Use different environment for DQN 21 | if model_class is DQN: 22 | env_name = "CartPole-v0" 23 | else: 24 | env_name = "Pendulum-v0" 25 | env = gym.make(env_name) 26 | eval_env = gym.make(env_name) 27 | model = model_class("MlpPolicy", env) 28 | model.learn(100, eval_env=eval_env) 29 | 30 | 31 | @pytest.mark.parametrize("model_class", MODEL_LIST) 32 | @pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) 33 | def test_predict(model_class, env_id): 34 | if env_id == "CartPole-v1": 35 | if model_class in [SAC, TD3]: 36 | return 37 | elif model_class in [DQN]: 38 | return 39 | 40 | # test detection of different shapes by the predict method 41 | model = model_class("MlpPolicy", env_id) 42 | env = gym.make(env_id) 43 | vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) 44 | 45 | obs = env.reset() 46 | action, _ = model.predict(obs) 47 | assert action.shape == env.action_space.shape 48 | assert env.action_space.contains(action) 49 | 50 | vec_env_obs = vec_env.reset() 51 | action, _ = model.predict(vec_env_obs) 52 | assert action.shape[0] == vec_env_obs.shape[0] 53 | -------------------------------------------------------------------------------- /tabular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | def tabular_q_learning(env, discount_rate=0.01, step_size=0.01, eps=1e-4): 5 | action_space = env.action_space.nvec # multi-discrete 6 | state_space = env.observation_space.nvec # multi-discrete 7 | q_values = np.zeros( np.concatenate((state_space, action_space)) , np.float64) 8 | 9 | for _ in range(1000): 10 | for state_action, q_val in np.ndenumerate(q_values): 11 | state, action = state_action[:len(state_space)], state_action[len(state_space):] 12 | #print("state, action: ", state, action) 13 | 14 | env.reset(list(state)) 15 | next_state, reward, done, _ = env.step(action) 16 | 17 | delta = reward - 
q_values[state_action] 18 | if not done: 19 | next_state = tuple(next_state) 20 | delta = reward + discount_rate * q_values[next_state].max() - q_values[state_action] 21 | q_values[state_action] += step_size * delta 22 | 23 | q_values = th.tensor(q_values) 24 | 25 | maxout_actions = th.max(th.max(q_values, dim=-2, keepdim=True).values, dim=-1, keepdim=True).values 26 | maxout_action2 = th.max(q_values, dim=-1, keepdim=True).values 27 | maxout_action1 = th.max(q_values, dim=-2, keepdim=True).values 28 | 29 | optimal_action1_mask = (maxout_action2 - maxout_actions > -eps).squeeze(-1) 30 | optimal_action2_mask = (maxout_action1 - maxout_actions > -eps).squeeze(-2) 31 | 32 | # print(q_values) 33 | print(optimal_action1_mask) 34 | print(optimal_action2_mask) 35 | return q_values, optimal_action1_mask, optimal_action2_mask 36 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan 7 | 8 | 9 | class NanAndInfEnv(gym.Env): 10 | """Custom Environment that raised NaNs and Infs""" 11 | 12 | metadata = {"render.modes": ["human"]} 13 | 14 | def __init__(self): 15 | super(NanAndInfEnv, self).__init__() 16 | self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) 17 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) 18 | 19 | @staticmethod 20 | def step(action): 21 | if np.all(np.array(action) > 0): 22 | obs = float("NaN") 23 | elif np.all(np.array(action) < 0): 24 | obs = float("inf") 25 | else: 26 | obs = 0 27 | return [obs], 0.0, False, {} 28 | 29 | @staticmethod 30 | def reset(): 31 | return [0.0] 32 | 33 | def render(self, mode="human", close=False): 34 | pass 35 | 36 | 37 | def test_check_nan(): 38 | """Test VecCheckNan Object""" 39 | 40 | env = DummyVecEnv([NanAndInfEnv]) 41 | env = VecCheckNan(env, raise_exception=True) 42 | 43 | env.step([[0]]) 44 | 45 | with pytest.raises(ValueError): 46 | env.step([[float("NaN")]]) 47 | 48 | with pytest.raises(ValueError): 49 | env.step([[float("inf")]]) 50 | 51 | with pytest.raises(ValueError): 52 | env.step([[-1]]) 53 | 54 | with pytest.raises(ValueError): 55 | env.step([[1]]) 56 | 57 | env.step(np.array([[0, 1], [0, 1]])) 58 | 59 | env.reset() 60 | -------------------------------------------------------------------------------- /stable-baselines3/docs/spelling_wordlist.txt: -------------------------------------------------------------------------------- 1 | py 2 | env 3 | atari 4 | argparse 5 | Argparse 6 | TensorFlow 7 | feedforward 8 | envs 9 | VecEnv 10 | pretrain 11 | petrained 12 | tf 13 | th 14 | nn 15 | np 16 | str 17 | mujoco 18 | cpu 19 | ndarray 20 | ndarrays 21 | timestep 22 | timesteps 23 | stepsize 24 | dataset 25 | adam 26 | fn 27 | normalisation 28 | Kullback 29 | Leibler 30 | boolean 31 | deserialized 32 | pretrained 33 | minibatch 34 | subprocesses 35 | ArgumentParser 36 | Tensorflow 37 | Gaussian 38 | approximator 39 | minibatches 40 | hyperparameters 41 | hyperparameter 42 | vectorized 43 | rl 44 | colab 45 | dataloader 46 | npz 47 | datasets 48 | vf 49 | logits 50 | num 51 | Utils 52 | backpropagate 53 | prepend 54 | NaN 55 | preprocessing 56 | Cloudpickle 57 | async 58 | multiprocess 59 | tensorflow 60 | mlp 61 | cnn 62 | neglogp 63 | tanh 64 | coef 65 
| repo 66 | Huber 67 | params 68 | ppo 69 | arxiv 70 | Arxiv 71 | func 72 | DQN 73 | Uhlenbeck 74 | Ornstein 75 | multithread 76 | cancelled 77 | Tensorboard 78 | parallelize 79 | customising 80 | serializable 81 | Multiprocessed 82 | cartpole 83 | toolset 84 | lstm 85 | rescale 86 | ffmpeg 87 | avconv 88 | unnormalized 89 | Github 90 | pre 91 | preprocess 92 | backend 93 | attr 94 | preprocess 95 | Antonin 96 | Raffin 97 | araffin 98 | Homebrew 99 | Numpy 100 | Theano 101 | rollout 102 | kfac 103 | Piecewise 104 | csv 105 | nvidia 106 | visdom 107 | tensorboard 108 | preprocessed 109 | namespace 110 | sklearn 111 | GoalEnv 112 | Torchy 113 | pytorch 114 | dicts 115 | optimizers 116 | Deprecations 117 | forkserver 118 | cuda 119 | Polyak 120 | gSDE 121 | rollouts 122 | -------------------------------------------------------------------------------- /my_gym/envs/arms_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import error, spaces, utils 3 | from gym.utils import seeding 4 | 5 | import numpy as np 6 | import sys 7 | 8 | class ArmsEnv(gym.Env): 9 | """ 10 | Two player game. 11 | There are N (say 4) possible arms to pull. 12 | 0 1 2 3 13 | Some of the arms are forbidden (determined by the state) 14 | Goal is to pull same arm. 15 | """ 16 | 17 | def __init__(self, n, m): 18 | super(ArmsEnv, self).__init__() 19 | self.n = n 20 | self.m = m # number of contexts with hard rules 21 | self.action_space = spaces.MultiDiscrete([2*n, 2*n]) 22 | self.observation_space = spaces.MultiDiscrete([n]) 23 | self.reset() 24 | 25 | self.invert = False 26 | 27 | def step(self, a): 28 | context = self.state[0] 29 | match = (a[0] == a[1]) 30 | 31 | if context < self.m: 32 | if not self.invert: 33 | correct = match and a[0] == context 34 | else: 35 | correct = match and a[0] - self.n == context 36 | else: 37 | correct = match and a[0]%self.n == context # action mod n equals context 38 | 39 | self.reward = (int)(correct) 40 | return [self.state, self.reward, True, {}] 41 | 42 | def reset(self, state=None): 43 | if not hasattr(self, 'last'): self.last = 0 44 | 45 | self.rep = 0 46 | if state: 47 | self.state = state 48 | else: 49 | self.state = [np.random.randint(self.n)] 50 | 51 | return self.state 52 | 53 | def render(self): 54 | print(self.n, self.state) 55 | 56 | def set_invert(self, invert): 57 | self.invert = invert -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/a2c.rst: -------------------------------------------------------------------------------- 1 | .. _a2c: 2 | 3 | .. automodule:: stable_baselines3.a2c 4 | 5 | 6 | A2C 7 | ==== 8 | 9 | A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3C) `_. 10 | It uses multiple workers to avoid the use of a replay buffer. 11 | 12 | 13 | Notes 14 | ----- 15 | 16 | - Original paper: https://arxiv.org/abs/1602.01783 17 | - OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/ 18 | 19 | 20 | Can I use? 21 | ---------- 22 | 23 | - Recurrent policies: ✔️ 24 | - Multi processing: ✔️ 25 | - Gym spaces: 26 | 27 | 28 | ============= ====== =========== 29 | Space Action Observation 30 | ============= ====== =========== 31 | Discrete ✔️ ✔️ 32 | Box ✔️ ✔️ 33 | MultiDiscrete ✔️ ✔️ 34 | MultiBinary ✔️ ✔️ 35 | ============= ====== =========== 36 | 37 | 38 | Example 39 | ------- 40 | 41 | Train a A2C agent on ``CartPole-v1`` using 4 environments. 42 | 43 | .. 
code-block:: python 44 | 45 | import gym 46 | 47 | from stable_baselines3 import A2C 48 | from stable_baselines3.a2c import MlpPolicy 49 | from stable_baselines3.common.cmd_util import make_vec_env 50 | 51 | # Parallel environments 52 | env = make_vec_env('CartPole-v1', n_envs=4) 53 | 54 | model = A2C(MlpPolicy, env, verbose=1) 55 | model.learn(total_timesteps=25000) 56 | model.save("a2c_cartpole") 57 | 58 | del model # remove to demonstrate saving and loading 59 | 60 | model = A2C.load("a2c_cartpole") 61 | 62 | obs = env.reset() 63 | while True: 64 | action, _states = model.predict(obs) 65 | obs, rewards, dones, info = env.step(action) 66 | env.render() 67 | 68 | Parameters 69 | ---------- 70 | 71 | .. autoclass:: A2C 72 | :members: 73 | :inherited-members: 74 | -------------------------------------------------------------------------------- /stable-baselines3/.github/ISSUE_TEMPLATE/issue-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue Template 3 | about: How to create an issue for this repository 4 | 5 | --- 6 | 7 | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. 8 | Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case. 9 | 10 | If you have any questions, feel free to create an issue with the tag [question]. 11 | If you wish to suggest an enhancement or feature request, add the tag [feature request]. 12 | If you are submitting a bug report, please fill in the following details. 13 | 14 | If your issue is related to a custom gym environment, please check it first using: 15 | 16 | ```python 17 | from stable_baselines3.common.env_checker import check_env 18 | 19 | env = CustomEnv(arg1, ...) 20 | # It will check your custom environment and output additional warnings if needed 21 | check_env(env) 22 | ``` 23 | 24 | **Describe the bug** 25 | A clear and concise description of what the bug is. 26 | 27 | **Code example** 28 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 29 | 30 | Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) 31 | for both code and stack traces. 32 | 33 | ```python 34 | from stable_baselines3 import ... 35 | 36 | ``` 37 | 38 | ```bash 39 | Traceback (most recent call last): File ... 40 | 41 | ``` 42 | 43 | **System Info** 44 | Describe the characteristic of your environment: 45 | * Describe how the library was installed (pip, docker, source, ...) 46 | * GPU models and configuration 47 | * Python version 48 | * PyTorch version 49 | * Gym version 50 | * Versions of any other relevant libraries 51 | 52 | **Additional context** 53 | Add any other context about the problem here. 
54 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_spaces.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | 5 | from stable_baselines3 import DQN, SAC, TD3 6 | from stable_baselines3.common.evaluation import evaluate_policy 7 | 8 | 9 | class DummyMultiDiscreteSpace(gym.Env): 10 | def __init__(self, nvec): 11 | super(DummyMultiDiscreteSpace, self).__init__() 12 | self.observation_space = gym.spaces.MultiDiscrete(nvec) 13 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) 14 | 15 | def reset(self): 16 | return self.observation_space.sample() 17 | 18 | def step(self, action): 19 | return self.observation_space.sample(), 0.0, False, {} 20 | 21 | 22 | class DummyMultiBinary(gym.Env): 23 | def __init__(self, n): 24 | super(DummyMultiBinary, self).__init__() 25 | self.observation_space = gym.spaces.MultiBinary(n) 26 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) 27 | 28 | def reset(self): 29 | return self.observation_space.sample() 30 | 31 | def step(self, action): 32 | return self.observation_space.sample(), 0.0, False, {} 33 | 34 | 35 | @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) 36 | @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) 37 | def test_identity_spaces(model_class, env): 38 | """ 39 | Additional tests for DQ/SAC/TD3 to check observation space support 40 | for MultiDiscrete and MultiBinary. 41 | """ 42 | # DQN only support discrete actions 43 | if model_class == DQN: 44 | env.action_space = gym.spaces.Discrete(4) 45 | 46 | env = gym.wrappers.TimeLimit(env, max_episode_steps=100) 47 | 48 | model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64])) 49 | model.learn(total_timesteps=500) 50 | 51 | evaluate_policy(model, env, n_eval_episodes=5) 52 | -------------------------------------------------------------------------------- /bashfiles/hanabi_adapt_to_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | 8 | if (( ${thread} == 1 )); then python run_hanabi.py --run=$(($runid + 10)) --mreg=0.0; fi 9 | if (( ${thread} == 2 )); then python run_hanabi.py --run=$(($runid + 20)) --mreg=0.3; fi 10 | if (( ${thread} == 3 )); then python run_hanabi.py --run=$(($runid + 30)) --mreg=0.5; fi 11 | 12 | if (( ${thread} == 4 )); then python run_hanabi.py --run=$(($runid + 40)) --baseline; fi 13 | if (( ${thread} == 5 )); then python run_hanabi.py --run=$(($runid + 50)) --baseline --timesteps=250000; fi 14 | if (( ${thread} == 6 )); then python run_hanabi.py --run=$(($runid + 60)) --nomain; fi 15 | if (( ${thread} == 7 )); then python run_hanabi.py --run=$(($runid + 70)) --mreg=0.5 --latentz=50; fi 16 | 17 | for ((i=0;i<=3;i++)) 18 | do 19 | if (( ${thread} == 1 )); then python run_hanabi.py --run=$(($runid + 10)) --mreg=0.0 --k=$i --testing | tee -a logs/hanabi$(($runid + 10 + $i)).txt; fi 20 | if (( ${thread} == 2 )); then python run_hanabi.py --run=$(($runid + 20)) --mreg=0.3 --k=$i --testing | tee -a logs/hanabi$(($runid + 20 + $i)).txt; fi 21 | if (( ${thread} == 3 )); then python run_hanabi.py --run=$(($runid + 30)) --mreg=0.5 --k=$i --testing | tee -a logs/hanabi$(($runid + 30 + $i)).txt; fi 22 | 23 | if (( ${thread} == 4 )); then python run_hanabi.py --run=$(($runid + 40)) 
--baseline --k=$i --testing | tee -a logs/hanabi$(($runid + 40 + $i)).txt; fi 24 | if (( ${thread} == 5 )); then python run_hanabi.py --run=$(($runid + 50)) --baseline --timesteps=250000 --k=$i --testing | tee -a logs/hanabi$(($runid + 50 + $i)).txt; fi 25 | if (( ${thread} == 6 )); then python run_hanabi.py --run=$(($runid + 60)) --nomain --k=$i --testing | tee -a logs/hanabi$(($runid + 60 + $i)).txt; fi 26 | if (( ${thread} == 7 )); then python run_hanabi.py --run=$(($runid + 70)) --mreg=0.5 --latentz=50 --k=$i --testing | tee -a logs/hanabi$(($runid + 70 + $i)).txt; fi 27 | done -------------------------------------------------------------------------------- /stable-baselines3/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file in the wheel. 3 | license_file = LICENSE 4 | 5 | [tool:pytest] 6 | # Deterministic ordering for tests; useful for pytest-xdist. 7 | env = 8 | PYTHONHASHSEED=0 9 | filterwarnings = 10 | # Tensorboard/Tensorflow warnings 11 | ignore:inspect.getargspec:DeprecationWarning:tensorflow 12 | ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning 13 | ignore:The binary mode of fromstring is deprecated:DeprecationWarning 14 | ignore::FutureWarning:tensorflow 15 | # Gym warnings 16 | ignore:Parameters to load are deprecated.:DeprecationWarning 17 | ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning 18 | ignore::UserWarning:gym 19 | 20 | [pytype] 21 | inputs = stable_baselines3 22 | 23 | [flake8] 24 | ignore = W503,W504,E203,E231 # line breaks before and after binary operators 25 | # Ignore import not used when aliases are defined 26 | per-file-ignores = 27 | ./stable_baselines3/__init__.py:F401 28 | ./stable_baselines3/common/__init__.py:F401 29 | ./stable_baselines3/a2c/__init__.py:F401 30 | ./stable_baselines3/ddpg/__init__.py:F401 31 | ./stable_baselines3/dqn/__init__.py:F401 32 | ./stable_baselines3/ppo/__init__.py:F401 33 | ./stable_baselines3/sac/__init__.py:F401 34 | ./stable_baselines3/td3/__init__.py:F401 35 | ./stable_baselines3/common/vec_env/__init__.py:F401 36 | exclude = 37 | # No need to traverse our git directory 38 | .git, 39 | # There's no value in checking cache directories 40 | __pycache__, 41 | # Don't check the doc 42 | docs/ 43 | # This contains our built documentation 44 | build, 45 | # This contains builds of flake8 that we don't want to check 46 | dist 47 | *.egg-info 48 | max-complexity = 15 49 | # The GitHub editor is 127 chars wide 50 | max-line-length = 127 51 | 52 | [isort] 53 | profile = black 54 | line_length = 127 55 | src_paths = stable_baselines3 56 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_identity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 5 | from stable_baselines3.common.evaluation import evaluate_policy 6 | from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete 7 | from stable_baselines3.common.noise import NormalActionNoise 8 | from stable_baselines3.common.vec_env import DummyVecEnv 9 | 10 | DIM = 4 11 | 12 | 13 | @pytest.mark.parametrize("model_class", [A2C, PPO, DQN]) 14 | @pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) 15 | def 
test_discrete(model_class, env): 16 | env_ = DummyVecEnv([lambda: env]) 17 | kwargs = {} 18 | n_steps = 3000 19 | if model_class == DQN: 20 | kwargs = dict(learning_starts=0) 21 | n_steps = 4000 22 | # DQN only support discrete actions 23 | if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): 24 | return 25 | 26 | model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps) 27 | 28 | evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90) 29 | obs = env.reset() 30 | 31 | assert np.shape(model.predict(obs)[0]) == np.shape(obs) 32 | 33 | 34 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3]) 35 | def test_continuous(model_class): 36 | env = IdentityEnvBox(eps=0.5) 37 | 38 | n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] 39 | 40 | kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) 41 | if model_class in [TD3]: 42 | n_actions = 1 43 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) 44 | kwargs["action_noise"] = action_noise 45 | 46 | model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) 47 | 48 | evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) 49 | -------------------------------------------------------------------------------- /bashfiles/arms_human_adapt_to_fixed.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | 8 | if (( ${thread} == 1 )); then python run_arms_human.py --run=$(($runid + 10)) --mreg=0.0; fi 9 | if (( ${thread} == 2 )); then python run_arms_human.py --run=$(($runid + 20)) --mreg=0.3; fi 10 | if (( ${thread} == 3 )); then python run_arms_human.py --run=$(($runid + 30)) --mreg=0.5; fi 11 | 12 | if (( ${thread} == 4 )); then python run_arms_human.py --run=$(($runid + 40)) --baseline; fi 13 | if (( ${thread} == 5 )); then python run_arms_human.py --run=$(($runid + 50)) --baseline --timesteps=6000; fi 14 | if (( ${thread} == 6 )); then python run_arms_human.py --run=$(($runid + 60)) --nomain; fi 15 | if (( ${thread} == 7 )); then python run_arms_human.py --run=$(($runid + 70)) --mreg=0.5 --latentz=5; fi 16 | 17 | for ((i=0;i<=9;i++)) 18 | do 19 | if (( ${thread} == 1 )); then python run_arms_human.py --run=$(($runid + 10)) --mreg=0.0 --k=$i --testing | tee -a logs/arms_human_$(($runid + 10 + $i)).txt; fi 20 | if (( ${thread} == 2 )); then python run_arms_human.py --run=$(($runid + 20)) --mreg=0.3 --k=$i --testing | tee -a logs/arms_human_$(($runid + 20 + $i)).txt; fi 21 | if (( ${thread} == 3 )); then python run_arms_human.py --run=$(($runid + 30)) --mreg=0.5 --k=$i --testing | tee -a logs/arms_human_$(($runid + 30 + $i)).txt; fi 22 | 23 | if (( ${thread} == 4 )); then python run_arms_human.py --run=$(($runid + 40)) --baseline --k=$i --testing | tee -a logs/arms_human_$(($runid + 40 + $i)).txt; fi 24 | if (( ${thread} == 5 )); then python run_arms_human.py --run=$(($runid + 50)) --baseline --timesteps=6000 --k=$i --testing | tee -a logs/arms_human_$(($runid + 50 + $i)).txt; fi 25 | if (( ${thread} == 6 )); then python run_arms_human.py --run=$(($runid + 60)) --nomain --k=$i --testing | tee -a logs/arms_human_$(($runid + 60 + $i)).txt; fi 26 | if (( ${thread} == 7 )); then python run_arms_human.py --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --k=$i --testing | tee -a logs/arms_human_$(($runid + 70 + $i)).txt; fi 27 | done 
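For reference, a minimal invocation sketch (an assumption for illustration, not a file from the repository): each adaptation script above takes a base run id as its first argument and a setting index as its second, matching the case analysis at the top of the script; the run id 100 below is an arbitrary example.

```
# Hypothetical single-setting launch: adapt to the fixed human-arms partners
# with the modular method and regularization lambda = 0.0 (setting/thread 1).
bash bashfiles/arms_human_adapt_to_fixed.bash 100 1

# With runid=100 and thread=1, the per-partner test runs append their output to
# logs/arms_human_110.txt through logs/arms_human_119.txt (one file per partner k).
```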
-------------------------------------------------------------------------------- /stable-baselines3/.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | ## Motivation and Context 7 | 8 | 9 | 10 | - [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes) 11 | 12 | ## Types of changes 13 | 14 | - [ ] Bug fix (non-breaking change which fixes an issue) 15 | - [ ] New feature (non-breaking change which adds functionality) 16 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 17 | - [ ] Documentation (update in the documentation) 18 | 19 | ## Checklist: 20 | 21 | 22 | - [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**) 23 | - [ ] I have updated the changelog accordingly (**required**). 24 | - [ ] My change requires a change to the documentation. 25 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 26 | - [ ] I have updated the documentation accordingly. 27 | - [ ] I have reformatted the code using `make format` (**required**) 28 | - [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**) 29 | - [ ] I have ensured `make pytest` and `make type` both pass. (**required**) 30 | 31 | 32 | Note: we are using a maximum length of 127 characters per line 33 | 34 | 35 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | import typing 3 | from copy import deepcopy 4 | from typing import Optional, Union 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import ( 7 | AlreadySteppingError, 8 | CloudpickleWrapper, 9 | NotSteppingError, 10 | VecEnv, 11 | VecEnvWrapper, 12 | ) 13 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 14 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 15 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan 16 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack 17 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize 18 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage 19 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder 20 | 21 | # Avoid circular import 22 | if typing.TYPE_CHECKING: 23 | from stable_baselines3.common.type_aliases import GymEnv 24 | 25 | 26 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: 27 | """ 28 | :param env: (gym.Env) 29 | :return: (VecNormalize) 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, VecNormalize): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | # Define here to avoid circular import 40 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: 41 | """ 42 | Sync eval env and train env when using VecNormalize 43 | 44 | :param env: (GymEnv) 45 | :param eval_env: (GymEnv) 46 | """ 47 | env_tmp, eval_env_tmp = env, eval_env 48 | while isinstance(env_tmp, VecEnvWrapper): 49 | if isinstance(env_tmp, VecNormalize): 50 | eval_env_tmp.obs_rms = 
deepcopy(env_tmp.obs_rms) 51 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 52 | env_tmp = env_tmp.venv 53 | eval_env_tmp = eval_env_tmp.venv 54 | -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/dqn.rst: -------------------------------------------------------------------------------- 1 | .. _dqn: 2 | 3 | .. automodule:: stable_baselines3.dqn 4 | 5 | 6 | DQN 7 | === 8 | 9 | `Deep Q Network (DQN) `_ 10 | 11 | .. rubric:: Available Policies 12 | 13 | .. autosummary:: 14 | :nosignatures: 15 | 16 | MlpPolicy 17 | CnnPolicy 18 | 19 | 20 | Notes 21 | ----- 22 | 23 | - Original paper: https://arxiv.org/abs/1312.5602 24 | - Further reference: https://www.nature.com/articles/nature14236 25 | 26 | .. note:: 27 | This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay. 28 | 29 | 30 | Can I use? 31 | ---------- 32 | 33 | - Recurrent policies: ❌ 34 | - Multi processing: ❌ 35 | - Gym spaces: 36 | 37 | 38 | ============= ====== =========== 39 | Space Action Observation 40 | ============= ====== =========== 41 | Discrete ✔ ✔ 42 | Box ❌ ✔ 43 | MultiDiscrete ❌ ✔ 44 | MultiBinary ❌ ✔ 45 | ============= ====== =========== 46 | 47 | 48 | Example 49 | ------- 50 | 51 | .. code-block:: python 52 | 53 | import gym 54 | import numpy as np 55 | 56 | from stable_baselines3 import DQN 57 | from stable_baselines3.dqn import MlpPolicy 58 | 59 | env = gym.make('CartPole-v1') 60 | 61 | model = DQN(MlpPolicy, env, verbose=1) 62 | model.learn(total_timesteps=10000, log_interval=4) 63 | model.save("dqn_cartpole") 64 | 65 | del model # remove to demonstrate saving and loading 66 | 67 | model = DQN.load("dqn_cartpole") 68 | 69 | obs = env.reset() 70 | while True: 71 | action, _states = model.predict(obs, deterministic=True) 72 | obs, reward, done, info = env.step(action) 73 | env.render() 74 | if done: 75 | obs = env.reset() 76 | 77 | Parameters 78 | ---------- 79 | 80 | .. autoclass:: DQN 81 | :members: 82 | :inherited-members: 83 | 84 | .. _dqn_policies: 85 | 86 | DQN Policies 87 | ------------- 88 | 89 | .. autoclass:: MlpPolicy 90 | :members: 91 | :inherited-members: 92 | 93 | ..
autoclass:: CnnPolicy 94 | :members: 95 | -------------------------------------------------------------------------------- /bashfiles/blocks_adapt_to_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | 8 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --ppopartners; fi 9 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --ppopartners; fi 10 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --ppopartners; fi 11 | 12 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --ppopartners; fi 13 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=1000000 --ppopartners; fi 14 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --ppopartners; fi 15 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --ppopartners; fi 16 | 17 | for ((i=0;i<=7;i++)) 18 | do 19 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 10 + $i)).txt; fi 20 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 20 + $i)).txt; fi 21 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 30 + $i)).txt; fi 22 | 23 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 40 + $i)).txt; fi 24 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=1000000 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 50 + $i)).txt; fi 25 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 60 + $i)).txt; fi 26 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --ppopartners --k=$i --testing | tee -a logs/blocksppo$(($runid + 70 + $i)).txt; fi 27 | done 28 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import gym 5 | import pytest 6 | 7 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 8 | from stable_baselines3.common.callbacks import ( 9 | CallbackList, 10 | CheckpointCallback, 11 | EvalCallback, 12 | EveryNTimesteps, 13 | StopTrainingOnRewardThreshold, 14 | ) 15 | 16 | 17 | @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) 18 | def test_callbacks(tmp_path, model_class): 19 | log_folder = tmp_path / "logs/callbacks/" 20 | 21 | # Dyn only support discrete actions 22 | env_name = select_env(model_class) 23 | # Create RL model 24 | # Small network for fast test 25 | model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32])) 26 | 27 | checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder) 28 | 29 | eval_env = gym.make(env_name) 30 | # Stop training if the performance is good enough 31 | callback_on_best = 
StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) 32 | 33 | eval_callback = EvalCallback( 34 | eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100 35 | ) 36 | 37 | # Equivalent to the `checkpoint_callback` 38 | # but here in an event-driven manner 39 | checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event") 40 | event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) 41 | 42 | callback = CallbackList([checkpoint_callback, eval_callback, event_callback]) 43 | 44 | model.learn(500, callback=callback) 45 | model.learn(500, callback=None) 46 | # Transform callback into a callback list automatically 47 | model.learn(500, callback=[checkpoint_callback, eval_callback]) 48 | # Automatic wrapping, old way of doing callbacks 49 | model.learn(500, callback=lambda _locals, _globals: True) 50 | if os.path.exists(log_folder): 51 | shutil.rmtree(log_folder) 52 | 53 | 54 | def select_env(model_class) -> str: 55 | if model_class is DQN: 56 | return "CartPole-v0" 57 | else: 58 | return "Pendulum-v0" 59 | -------------------------------------------------------------------------------- /bashfiles/arms_adapt_to_selfplay.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | m=$3 8 | 9 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --ppopartners; fi 10 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --ppopartners; fi 11 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --ppopartners; fi 12 | 13 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --ppopartners; fi 14 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --ppopartners; fi 15 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --ppopartners; fi 16 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --ppopartners; fi 17 | 18 | for ((i=0;i<=9;i++)) 19 | do 20 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 10 + $i)).txt; fi 21 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 20 + $i)).txt; fi 22 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 30 + $i)).txt; fi 23 | 24 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 40 + $i)).txt; fi 25 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 50 + $i)).txt; fi 26 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --ppopartners --k=$i --testing | tee -a logs/armsppo${m}_$(($runid + 60 + $i)).txt; fi 27 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --ppopartners --k=$i 
--testing | tee -a logs/armsppo${m}_$(($runid + 70 + $i)).txt; fi 28 | done -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 7 | 8 | 9 | class VecFrameStack(VecEnvWrapper): 10 | """ 11 | Frame stacking wrapper for vectorized environment 12 | 13 | :param venv: the vectorized environment to wrap 14 | :param n_stack: Number of frames to stack 15 | """ 16 | 17 | def __init__(self, venv: VecEnv, n_stack: int): 18 | self.venv = venv 19 | self.n_stack = n_stack 20 | wrapped_obs_space = venv.observation_space 21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) 22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) 23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 26 | 27 | def step_wait(self): 28 | observations, rewards, dones, infos = self.venv.step_wait() 29 | last_ax_size = observations.shape[-1] 30 | self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) 31 | for i, done in enumerate(dones): 32 | if done: 33 | if "terminal_observation" in infos[i]: 34 | old_terminal = infos[i]["terminal_observation"] 35 | new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) 36 | infos[i]["terminal_observation"] = new_terminal 37 | else: 38 | warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") 39 | self.stackedobs[i] = 0 40 | self.stackedobs[..., -observations.shape[-1] :] = observations 41 | return self.stackedobs, rewards, dones, infos 42 | 43 | def reset(self): 44 | """ 45 | Reset all environments 46 | """ 47 | obs = self.venv.reset() 48 | self.stackedobs[...] = 0 49 | self.stackedobs[..., -obs.shape[-1] :] = obs 50 | return self.stackedobs 51 | 52 | def close(self): 53 | self.venv.close() 54 | -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/ppo.rst: -------------------------------------------------------------------------------- 1 | .. _ppo2: 2 | 3 | .. automodule:: stable_baselines3.ppo 4 | 5 | PPO 6 | === 7 | 8 | The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers) 9 | and TRPO (it uses a trust region to improve the actor). 10 | 11 | The main idea is that after an update, the new policy should not be too far from the old policy. 12 | For that, PPO uses clipping to avoid too large an update. 13 | 14 | 15 | .. note:: 16 | 17 | PPO contains several modifications from the original algorithm not documented 18 | by OpenAI: advantages are normalized and the value function can also be clipped. 19 | 20 | 21 | Notes 22 | ----- 23 | 24 | - Original paper: https://arxiv.org/abs/1707.06347 25 | - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 26 | - OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ 27 | - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html 28 | 29 | 30 | Can I use?
31 | ---------- 32 | 33 | - Recurrent policies: ❌ 34 | - Multi processing: ✔️ 35 | - Gym spaces: 36 | 37 | 38 | ============= ====== =========== 39 | Space Action Observation 40 | ============= ====== =========== 41 | Discrete ✔️ ✔️ 42 | Box ✔️ ✔️ 43 | MultiDiscrete ✔️ ✔️ 44 | MultiBinary ✔️ ✔️ 45 | ============= ====== =========== 46 | 47 | Example 48 | ------- 49 | 50 | Train a PPO agent on ``CartPole-v1`` using 4 environments. 51 | 52 | .. code-block:: python 53 | 54 | import gym 55 | 56 | from stable_baselines3 import PPO 57 | from stable_baselines3.ppo import MlpPolicy 58 | from stable_baselines3.common.cmd_util import make_vec_env 59 | 60 | # Parallel environments 61 | env = make_vec_env('CartPole-v1', n_envs=4) 62 | 63 | model = PPO(MlpPolicy, env, verbose=1) 64 | model.learn(total_timesteps=25000) 65 | model.save("ppo_cartpole") 66 | 67 | del model # remove to demonstrate saving and loading 68 | 69 | model = PPO.load("ppo_cartpole") 70 | 71 | obs = env.reset() 72 | while True: 73 | action, _states = model.predict(obs) 74 | obs, rewards, dones, info = env.step(action) 75 | env.render() 76 | 77 | Parameters 78 | ---------- 79 | 80 | .. autoclass:: PPO 81 | :members: 82 | :inherited-members: 83 | -------------------------------------------------------------------------------- /stable-baselines3/docs/_static/img/colab-badge.svg: -------------------------------------------------------------------------------- 1 | Open in ColabOpen in Colab 2 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/vec_transpose.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.preprocessing import is_image_space 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 8 | 9 | if typing.TYPE_CHECKING: 10 | from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401 11 | 12 | 13 | class VecTransposeImage(VecEnvWrapper): 14 | """ 15 | Re-order channels, from HxWxC to CxHxW. 16 | It is required for PyTorch convolution layers. 17 | 18 | :param venv: (VecEnv) 19 | """ 20 | 21 | def __init__(self, venv: VecEnv): 22 | assert is_image_space(venv.observation_space), "The observation space must be an image" 23 | 24 | observation_space = self.transpose_space(venv.observation_space) 25 | super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) 26 | 27 | @staticmethod 28 | def transpose_space(observation_space: spaces.Box) -> spaces.Box: 29 | """ 30 | Transpose an observation space (re-order channels). 31 | 32 | :param observation_space: (spaces.Box) 33 | :return: (spaces.Box) 34 | """ 35 | assert is_image_space(observation_space), "The observation space must be an image" 36 | width, height, channels = observation_space.shape 37 | new_shape = (channels, width, height) 38 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) 39 | 40 | @staticmethod 41 | def transpose_image(image: np.ndarray) -> np.ndarray: 42 | """ 43 | Transpose an image or batch of images (re-order channels).
44 | 45 | :param image: (np.ndarray) 46 | :return: (np.ndarray) 47 | """ 48 | if len(image.shape) == 3: 49 | return np.transpose(image, (2, 0, 1)) 50 | return np.transpose(image, (0, 3, 1, 2)) 51 | 52 | def step_wait(self) -> "GymStepReturn": 53 | observations, rewards, dones, infos = self.venv.step_wait() 54 | return self.transpose_image(observations), rewards, dones, infos 55 | 56 | def reset(self) -> np.ndarray: 57 | """ 58 | Reset all environments 59 | """ 60 | return self.transpose_image(self.venv.reset()) 61 | 62 | def close(self) -> None: 63 | self.venv.close() 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains code for the paper: 2 | 3 | [On the Critical Role of Conventions in Adaptive Human-AI Collaboration](https://openreview.net/pdf?id=8Ln-Bq0mZcy) 4 | 5 | ``` 6 | "On the Critical Role of Conventions in Adaptive Human-AI Collaboration" 7 | Andy Shih, Arjun Sawhney, Jovana Kondic, Stefano Ermon, Dorsa Sadigh 8 | Proceedings of the 9th International Conference on Learning Representations (ICLR 2021) 9 | 10 | @inproceedings{SSKESiclr21, 11 | author = {Andy Shih and Arjun Sawhney and Jovana Kondic and Stefano Ermon and Dorsa Sadigh}, 12 | title = {On the Critical Role of Conventions in Adaptive Human-AI Collaboration}, 13 | booktitle = {Proceedings of the 9th International Conference on Learning Representations (ICLR)}, 14 | month = {may}, 15 | year = {2021}, 16 | keywords = {conference} 17 | } 18 | ``` 19 | 20 | # Instructions 21 | 22 | Install gym environment, hanabi environment, and stable baselines 23 | ``` 24 | pip install . 25 | 26 | cd hanabi 27 | pip install . 28 | cd ../ 29 | 30 | cd stable-baselines3 31 | pip install -e . 32 | cd ../ 33 | ``` 34 | 35 | Train partner agents 36 | ``` 37 | bash bashfiles/arms_train_selfplay.bash 1230 2 38 | bash bashfiles/arms_train_selfplay.bash 1240 2 39 | 40 | for ((i=1230;i<=1235;i++)) 41 | do 42 | bash bashfiles/blocks_train_selfplay.bash $i 43 | done 44 | for ((i=1240;i<=1245;i++)) 45 | do 46 | bash bashfiles/blocks_train_selfplay.bash $i 47 | done 48 | for ((i=1240;i<=1247;i++)) 49 | do 50 | bash bashfiles/hanabi_train_selfplay.bash $i 51 | done 52 | ``` 53 | 54 | Run adaptation experiments. 
55 | Choose from one of the settings: 56 | - t=1: modular, regularization lambda=0.0 57 | - t=2: modular, regularization lambda=0.3 58 | - t=3: modular, regularization lambda=0.5 59 | - t=4: baseline agg, aggregate gradients 60 | - t=5: baseline agg, aggregate gradients, early stopping 61 | - t=6: baseline modular, no main logits 62 | - t=7: low-dim z + modular, regularization lambda=0.5 63 | 64 | ``` 65 | t=1 66 | runid=100 67 | 68 | bash bashfiles/arms_adapt_to_selfplay.bash $runid $t 2 69 | bash bashfiles/arms_adapt_to_fixed.bash $runid $t 2 70 | bash bashfiles/arms_human_adapt_to_fixed.bash $runid $t 2 71 | 72 | bash bashfiles/blocks_adapt_to_selfplay.bash $runid $t 73 | bash bashfiles/blocks_adapt_to_fixed.bash $runid $t 74 | 75 | bash bashfiles/hanabi_adapt_to_selfplay.bash $runid $t 76 | 77 | ``` -------------------------------------------------------------------------------- /my_gym/envs/generate_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | 4 | def generate_grid(grid_size, blocks, blocks_per_color, colors): 5 | grid = [[0 for _ in range(grid_size)] for _ in range(grid_size)] 6 | 7 | for color in colors: 8 | arr = np.random.choice(len(blocks), blocks_per_color, replace=False) 9 | for i in arr: 10 | h, w = blocks[i] 11 | 12 | while True: 13 | r, c = np.random.randint(grid_size - h + 1), np.random.randint(grid_size - w + 1) 14 | valid = True 15 | for i in range(r,r+h): 16 | for j in range(c, c+w): 17 | if grid[i][j]: 18 | valid = False 19 | if not valid: 20 | continue 21 | for i in range(r,r+h): 22 | for j in range(c, c+w): 23 | grid[i][j] = color 24 | break; 25 | 26 | return grid 27 | 28 | def generate_mask(grid_size, percent_mask): 29 | mask = [[0 for _ in range(grid_size)] for _ in range(grid_size)] 30 | grid_sq = grid_size * grid_size 31 | arr = np.random.choice(grid_sq, int(grid_sq * percent_mask), replace=False) 32 | 33 | for a in arr: 34 | r = a // grid_size; 35 | c = a % grid_size; 36 | mask[r][c] = 1; 37 | 38 | return mask 39 | 40 | def make_grids(grid_size, vis1, vis2, blocks, blocks_per_color=1, colors=[2,3]): 41 | goal_grid = generate_grid(grid_size, blocks, blocks_per_color, colors) 42 | 43 | v, p = [vis1, vis2], [None, None] 44 | for i in range(2): 45 | if v[i] == 1: 46 | p[i] = generate_mask(grid_size, 0) 47 | elif v[i] == 2: 48 | p[i] = generate_mask(grid_size, 0.5) 49 | elif v[i] == 3: 50 | p[i] = generate_mask(grid_size, 1) 51 | elif v[i] == 4: 52 | p[i] = generate_mask(grid_size, float(np.random.uniform() < 0.5) ) 53 | elif v[i] == 5: 54 | p[0] = generate_mask(grid_size, 0.5) 55 | p[1] = [[1 - x for x in r] for r in p[0]] 56 | else: 57 | raise ValueError('Visibility value is not an integer between [1,5].') 58 | p1_mask, p2_mask = p[0], p[1] 59 | 60 | goal_grid = np.array(goal_grid) 61 | p1_grid = (np.array(1)-p1_mask) * goal_grid + p1_mask # masked squares are denotes as 1 62 | p2_grid = (np.array(1)-p2_mask) * goal_grid + p2_mask 63 | 64 | return goal_grid, p1_grid, p2_grid 65 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from stable_baselines3.common.logger import ( 5 | DEBUG, 6 | ScopedConfigure, 7 | configure, 8 | debug, 9 | dump, 10 | error, 11 | info, 12 | make_output_format, 13 | read_csv, 14 | read_json, 15 | record, 16 | record_dict, 17 | record_mean, 18 | reset, 19 
| set_level, 20 | warn, 21 | ) 22 | 23 | KEY_VALUES = { 24 | "test": 1, 25 | "b": -3.14, 26 | "8": 9.9, 27 | "l": [1, 2], 28 | "a": np.array([1, 2, 3]), 29 | "f": np.array(1), 30 | "g": np.array([[[1]]]), 31 | } 32 | 33 | KEY_EXCLUDED = {} 34 | for key in KEY_VALUES.keys(): 35 | KEY_EXCLUDED[key] = None 36 | 37 | 38 | def test_main(tmp_path): 39 | """ 40 | tests for the logger module 41 | """ 42 | info("hi") 43 | debug("shouldn't appear") 44 | set_level(DEBUG) 45 | debug("should appear") 46 | configure(folder=str(tmp_path)) 47 | record("a", 3) 48 | record("b", 2.5) 49 | dump() 50 | record("b", -2.5) 51 | record("a", 5.5) 52 | dump() 53 | info("^^^ should see a = 5.5") 54 | record_mean("b", -22.5) 55 | record_mean("b", -44.4) 56 | record("a", 5.5) 57 | dump() 58 | with ScopedConfigure(None, None): 59 | info("^^^ should see b = 33.3") 60 | 61 | with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]): 62 | record("b", -2.5) 63 | dump() 64 | 65 | reset() 66 | record("a", "longasslongasslongasslongasslongasslongassvalue") 67 | dump() 68 | warn("hey") 69 | error("oh") 70 | record_dict({"test": 1}) 71 | 72 | 73 | @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) 74 | def test_make_output(tmp_path, _format): 75 | """ 76 | test make output 77 | 78 | :param _format: (str) output format 79 | """ 80 | if _format == "tensorboard": 81 | # Skip if no tensorboard installed 82 | pytest.importorskip("tensorboard") 83 | 84 | writer = make_output_format(_format, tmp_path) 85 | writer.write(KEY_VALUES, KEY_EXCLUDED) 86 | if _format == "csv": 87 | read_csv(tmp_path / "progress.csv") 88 | elif _format == "json": 89 | read_json(tmp_path / "progress.json") 90 | writer.close() 91 | 92 | 93 | def test_make_output_fail(tmp_path): 94 | """ 95 | test value error on logger 96 | """ 97 | with pytest.raises(ValueError): 98 | make_output_format("dummy_format", tmp_path) 99 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/algos.rst: -------------------------------------------------------------------------------- 1 | RL Algorithms 2 | ============= 3 | 4 | This table displays the rl algorithms that are implemented in the Stable Baselines3 project, 5 | along with some useful characteristics: support for discrete/continuous actions, multiprocessing. 6 | 7 | 8 | ============ =========== ============ ================= =============== ================ 9 | Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing 10 | ============ =========== ============ ================= =============== ================ 11 | A2C ✔️ ✔️ ✔️ ✔️ ✔️ 12 | DDPG ✔️ ❌ ❌ ❌ ❌ 13 | DQN ❌ ✔️ ❌ ❌ ❌ 14 | PPO ✔️ ✔️ ✔️ ✔️ ✔️ 15 | SAC ✔️ ❌ ❌ ❌ ❌ 16 | TD3 ✔️ ❌ ❌ ❌ ❌ 17 | ============ =========== ============ ================= =============== ================ 18 | 19 | 20 | .. note:: 21 | Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. 22 | 23 | Actions ``gym.spaces``: 24 | 25 | - ``Box``: A N-dimensional box that contains every point in the action 26 | space. 27 | - ``Discrete``: A list of possible actions, where each timestep only 28 | one of the actions can be used. 29 | - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used. 30 | - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. 31 | 32 | 33 | .. 
note:: 34 | 35 | Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper 36 | See `Issue #339 `_ for more info. 37 | 38 | 39 | Reproducibility 40 | --------------- 41 | 42 | Completely reproducible results are not guaranteed across Tensorflow releases or different platforms. 43 | Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds. 44 | 45 | In order to make computations deterministics, on your specific problem on one specific platform, 46 | you need to pass a ``seed`` argument at the creation of a model. 47 | If you pass an environment to the model using ``set_env()``, then you also need to seed the environment first. 48 | 49 | 50 | Credit: part of the *Reproducibility* section comes from `PyTorch Documentation `_ 51 | -------------------------------------------------------------------------------- /bashfiles/blocks_adapt_to_fixed.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | 8 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners; fi 9 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --fixedpartners; fi 10 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners; fi 11 | 12 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --fixedpartners; fi 13 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=700000 --fixedpartners; fi 14 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners; fi 15 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --fixedpartners; fi 16 | 17 | for ((i=0;i<=5;i++)) 18 | do 19 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 10 + $i)).txt; fi 20 | if (( ${thread} == 2 )); then python run_blocks.py --run=$(($runid + 20)) --mreg=0.3 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 20 + $i)).txt; fi 21 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 30 + $i)).txt; fi 22 | 23 | if (( ${thread} == 4 )); then python run_blocks.py --run=$(($runid + 40)) --baseline --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 40 + $i)).txt; fi 24 | if (( ${thread} == 5 )); then python run_blocks.py --run=$(($runid + 50)) --baseline --timesteps=700000 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 50 + $i)).txt; fi 25 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 60 + $i)).txt; fi 26 | if (( ${thread} == 7 )); then python run_blocks.py --run=$(($runid + 70)) --mreg=0.5 --latentz=20 --fixedpartners --k=$i --testing | tee -a logs/blocks$(($runid + 70 + $i)).txt; fi 27 | done 28 | 29 | if (( ${thread} == 1 )); then python run_blocks.py --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 10)).txt; fi 30 | if (( ${thread} == 3 )); then python run_blocks.py --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 
30)).txt; fi 31 | if (( ${thread} == 6 )); then python run_blocks.py --run=$(($runid + 60)) --nomain --fixedpartners --testing --zeroshot | tee -a logs/blockszero$(($runid + 60)).txt; fi -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/ddpg.rst: -------------------------------------------------------------------------------- 1 | .. _ddpg: 2 | 3 | .. automodule:: stable_baselines3.ddpg 4 | 5 | 6 | DDPG 7 | ==== 8 | 9 | `Deep Deterministic Policy Gradient (DDPG) `_ combines the 10 | tricks from DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions. 11 | 12 | 13 | .. rubric:: Available Policies 14 | 15 | .. autosummary:: 16 | :nosignatures: 17 | 18 | MlpPolicy 19 | 20 | 21 | Notes 22 | ----- 23 | 24 | - Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf 25 | - DDPG Paper: https://arxiv.org/abs/1509.02971 26 | - OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html 27 | 28 | .. note:: 29 | 30 | The default policy for DDPG uses a ReLU activation, to match the original paper, 31 | whereas most other algorithms' MlpPolicy uses a tanh activation. 32 | 33 | 34 | Can I use? 35 | ---------- 36 | 37 | - Recurrent policies: ❌ 38 | - Multi processing: ❌ 39 | - Gym spaces: 40 | 41 | 42 | ============= ====== =========== 43 | Space Action Observation 44 | ============= ====== =========== 45 | Discrete ❌ ✔️ 46 | Box ✔️ ✔️ 47 | MultiDiscrete ❌ ✔️ 48 | MultiBinary ❌ ✔️ 49 | ============= ====== =========== 50 | 51 | 52 | Example 53 | ------- 54 | 55 | .. code-block:: python 56 | 57 | import gym 58 | import numpy as np 59 | 60 | from stable_baselines3 import DDPG 61 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 62 | 63 | env = gym.make('Pendulum-v0') 64 | 65 | # The noise objects for DDPG 66 | n_actions = env.action_space.shape[-1] 67 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) 68 | 69 | model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1) 70 | model.learn(total_timesteps=10000, log_interval=10) 71 | model.save("ddpg_pendulum") 72 | env = model.get_env() 73 | 74 | del model # remove to demonstrate saving and loading 75 | 76 | model = DDPG.load("ddpg_pendulum") 77 | 78 | obs = env.reset() 79 | while True: 80 | action, _states = model.predict(obs) 81 | obs, rewards, dones, info = env.step(action) 82 | env.render() 83 | 84 | 85 | Parameters 86 | ---------- 87 | 88 | .. autoclass:: DDPG 89 | :members: 90 | :inherited-members: 91 | 92 | .. _ddpg_policies: 93 | 94 | DDPG Policies 95 | ------------- 96 | 97 | .. autoclass:: MlpPolicy 98 | :members: 99 | :inherited-members: 100 | 101 | 102 | .. .. autoclass:: CnnPolicy 103 | .. :members: 104 | ..
:inherited-members: 105 | -------------------------------------------------------------------------------- /bashfiles/arms_adapt_to_fixed.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | runid=$1 6 | thread=$2 7 | m=$3 8 | 9 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners; fi 10 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --fixedpartners; fi 11 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners; fi 12 | 13 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --fixedpartners; fi 14 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --fixedpartners; fi 15 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners; fi 16 | if (( ${thread} == 8 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --fixedpartners; fi 17 | 18 | for ((i=0;i<=3;i++)) 19 | do 20 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 10 + $i)).txt; fi 21 | if (( ${thread} == 2 )); then python run_arms.py --m=${m} --run=$(($runid + 20)) --mreg=0.3 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 20 + $i)).txt; fi 22 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 30 + $i)).txt; fi 23 | 24 | if (( ${thread} == 4 )); then python run_arms.py --m=${m} --run=$(($runid + 40)) --baseline --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 40 + $i)).txt; fi 25 | if (( ${thread} == 5 )); then python run_arms.py --m=${m} --run=$(($runid + 50)) --baseline --timesteps=6000 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 50 + $i)).txt; fi 26 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 60 + $i)).txt; fi 27 | if (( ${thread} == 7 )); then python run_arms.py --m=${m} --run=$(($runid + 70)) --mreg=0.5 --latentz=5 --fixedpartners --k=$i --testing | tee -a logs/arms${m}_$(($runid + 70 + $i)).txt; fi 28 | done 29 | 30 | if (( ${thread} == 1 )); then python run_arms.py --m=${m} --run=$(($runid + 10)) --mreg=0.0 --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 10)).txt; fi 31 | if (( ${thread} == 3 )); then python run_arms.py --m=${m} --run=$(($runid + 30)) --mreg=0.5 --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 30)).txt; fi 32 | if (( ${thread} == 6 )); then python run_arms.py --m=${m} --run=$(($runid + 60)) --nomain --fixedpartners --testing --zeroshot | tee -a logs/armszero${m}_$(($runid + 60)).txt; fi -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copied from stable_baselines 2 | import numpy as np 3 | 4 | from stable_baselines3.common.vec_env import VecEnv 5 | 6 | 7 | def evaluate_policy( 8 | model, 9 | env, 10 | partner_idx=0, 11 | n_eval_episodes=10, 12 | deterministic=True, 13 | render=False, 
14 | callback=None, 15 | reward_threshold=None, 16 | return_episode_rewards=False, 17 | ): 18 | """ 19 | Runs policy for ``n_eval_episodes`` episodes and returns average reward. 20 | This is made to work only with one env. 21 | 22 | :param model: (BaseAlgorithm) The RL agent you want to evaluate. 23 | :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv`` 24 | this must contain only one environment. 25 | :param n_eval_episodes: (int) Number of episode to evaluate the agent 26 | :param deterministic: (bool) Whether to use deterministic or stochastic actions 27 | :param render: (bool) Whether to render the environment or not 28 | :param callback: (callable) callback function to do additional checks, 29 | called after each step. 30 | :param reward_threshold: (float) Minimum expected reward per episode, 31 | this will raise an error if the performance is not met 32 | :param return_episode_rewards: (bool) If True, a list of reward per episode 33 | will be returned instead of the mean. 34 | :return: (float, float) Mean reward per episode, std of reward per episode 35 | returns ([float], [int]) when ``return_episode_rewards`` is True 36 | """ 37 | if isinstance(env, VecEnv): 38 | assert env.num_envs == 1, "You must pass only one environment when using this function" 39 | 40 | episode_rewards, episode_lengths = [], [] 41 | for _ in range(n_eval_episodes): 42 | obs = env.reset() 43 | done, state = False, None 44 | episode_reward = 0.0 45 | episode_length = 0 46 | while not done: 47 | action, _ = model.predict(observation=obs, partner_idx=partner_idx, deterministic=deterministic) 48 | obs, reward, done, _info = env.step(action) 49 | episode_reward += reward 50 | if callback is not None: 51 | callback(locals(), globals()) 52 | episode_length += 1 53 | if render: 54 | env.render() 55 | episode_rewards.append(episode_reward) 56 | episode_lengths.append(episode_length) 57 | mean_reward = np.mean(episode_rewards) 58 | std_reward = np.std(episode_rewards) 59 | if reward_threshold is not None: 60 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" 61 | if return_episode_rewards: 62 | return episode_rewards, episode_lengths 63 | return mean_reward, std_reward 64 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_sde.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch as th 3 | from torch.distributions import Normal 4 | 5 | from stable_baselines3 import A2C, PPO, SAC 6 | 7 | 8 | def test_state_dependent_exploration_grad(): 9 | """ 10 | Check that the gradient correspond to the expected one 11 | """ 12 | n_states = 2 13 | state_dim = 3 14 | action_dim = 10 15 | sigma_hat = th.ones(state_dim, action_dim, requires_grad=True) 16 | # Reduce the number of parameters 17 | # sigma_ = th.ones(state_dim, action_dim) * sigma_ 18 | # weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma)) 19 | th.manual_seed(2) 20 | weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat) 21 | weights = weights_dist.rsample() 22 | 23 | state = th.rand(n_states, state_dim) 24 | mu = th.ones(action_dim) 25 | noise = th.mm(state, weights) 26 | 27 | action = mu + noise 28 | 29 | variance = th.mm(state ** 2, sigma_hat ** 2) 30 | action_dist = Normal(mu, th.sqrt(variance)) 31 | 32 | # Sum over the action dimension because we assume they are independent 33 | loss = 
action_dist.log_prob(action.detach()).sum(dim=-1).mean() 34 | loss.backward() 35 | 36 | # From Rueckstiess paper: check that the computed gradient 37 | # correspond to the analytical form 38 | grad = th.zeros_like(sigma_hat) 39 | for j in range(action_dim): 40 | # sigma_hat is the std of the gaussian distribution of the noise matrix weights 41 | # sigma_j = sum_j(state_i **2 * sigma_hat_ij ** 2) 42 | # sigma_j is the standard deviation of the policy gaussian distribution 43 | sigma_j = th.sqrt(variance[:, j]) 44 | for i in range(state_dim): 45 | # Derivative of the log probability of the jth component of the action 46 | # w.r.t. the standard deviation sigma_j 47 | d_log_policy_j = (noise[:, j] ** 2 - sigma_j ** 2) / sigma_j ** 3 48 | # Derivative of sigma_j w.r.t. sigma_hat_ij 49 | d_log_sigma_j = (state[:, i] ** 2 * sigma_hat[i, j]) / sigma_j 50 | # Chain rule, average over the minibatch 51 | grad[i, j] = (d_log_policy_j * d_log_sigma_j).mean() 52 | 53 | # sigma.grad should be equal to grad 54 | assert sigma_hat.grad.allclose(grad) 55 | 56 | 57 | @pytest.mark.parametrize("model_class", [SAC, A2C, PPO]) 58 | @pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []]) 59 | @pytest.mark.parametrize("use_expln", [False, True]) 60 | def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln): 61 | model = model_class( 62 | "MlpPolicy", 63 | "Pendulum-v0", 64 | use_sde=True, 65 | seed=None, 66 | create_eval_env=True, 67 | verbose=1, 68 | policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln), 69 | ) 70 | model.learn(total_timesteps=int(500), eval_freq=250) 71 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/tensorboard.rst: -------------------------------------------------------------------------------- 1 | .. _tensorboard: 2 | 3 | Tensorboard Integration 4 | ======================= 5 | 6 | Basic Usage 7 | ------------ 8 | 9 | To use Tensorboard with stable baselines3, you simply need to pass the location of the log folder to the RL agent: 10 | 11 | .. code-block:: python 12 | 13 | from stable_baselines3 import A2C 14 | 15 | model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/") 16 | model.learn(total_timesteps=10000) 17 | 18 | 19 | You can also define custom logging name when training (by default it is the algorithm name) 20 | 21 | .. code-block:: python 22 | 23 | from stable_baselines3 import A2C 24 | 25 | model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/") 26 | model.learn(total_timesteps=10000, tb_log_name="first_run") 27 | # Pass reset_num_timesteps=False to continue the training curve in tensorboard 28 | # By default, it will create a new curve 29 | model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False) 30 | model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False) 31 | 32 | 33 | Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command: 34 | 35 | .. code-block:: bash 36 | 37 | tensorboard --logdir ./a2c_cartpole_tensorboard/ 38 | 39 | you can also add past logging folders: 40 | 41 | .. code-block:: bash 42 | 43 | tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/ 44 | 45 | It will display information such as the episode reward (when using a ``Monitor`` wrapper), the model losses and other parameter unique to some models. 46 | 47 | .. 
image:: ../_static/img/Tensorboard_example.png 48 | :width: 600 49 | :alt: plotting 50 | 51 | Logging More Values 52 | ------------------- 53 | 54 | Using a callback, you can easily log more values with TensorBoard. 55 | Here is a simple example on how to log both additional tensor or arbitrary scalar value: 56 | 57 | .. code-block:: python 58 | 59 | import numpy as np 60 | 61 | from stable_baselines3 import SAC 62 | from stable_baselines3.common.callbacks import BaseCallback 63 | 64 | model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1) 65 | 66 | 67 | class TensorboardCallback(BaseCallback): 68 | """ 69 | Custom callback for plotting additional values in tensorboard. 70 | """ 71 | 72 | def __init__(self, verbose=0): 73 | super(TensorboardCallback, self).__init__(verbose) 74 | 75 | def _on_step(self) -> bool: 76 | # Log scalar value (here a random variable) 77 | value = np.random.random() 78 | self.logger.record('random_value', value) 79 | return True 80 | 81 | 82 | model.learn(50000, callback=TensorboardCallback()) 83 | -------------------------------------------------------------------------------- /partner.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | import torch as th 4 | import numpy as np 5 | from stable_baselines3 import PPO 6 | 7 | class PartnerPolicy(ABC): 8 | def __init__(self): 9 | pass 10 | 11 | @abstractmethod 12 | def forward(self, obs, deterministic=True) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: 13 | pass 14 | 15 | class Partner: 16 | def __init__(self, policy : PartnerPolicy): 17 | self.policy = policy 18 | 19 | class PPOPartnerPolicy(PartnerPolicy): 20 | def __init__(self, model_path): 21 | super(PartnerPolicy, self).__init__() 22 | self.model = PPO.load(model_path) 23 | print("PPO Partner loaded successfully: %s" % model_path) 24 | 25 | def forward(self, obs, deterministic=True): 26 | return self.model.policy.forward(obs, partner_idx=0, deterministic=deterministic) 27 | 28 | class BlocksPermutationPartnerPolicy(PartnerPolicy): 29 | def __init__(self, perm, n=2): 30 | super(PartnerPolicy, self).__init__() 31 | self.perm = perm 32 | self.n = n 33 | 34 | self.action_index = [[i*n + j for j in range(n)] for i in range(n)] 35 | 36 | def forward(self, obs, deterministic=True): 37 | obs = obs[0] 38 | assert(2*self.n**2+1 == len(obs)) 39 | goal_grid = obs[:self.n**2].reshape(self.n,self.n) 40 | working_grid = obs[self.n**2:2*(self.n**2)].reshape(self.n,self.n) 41 | turn = obs[-1] 42 | 43 | r, c = self.get_red_block_position(working_grid, self.n, self.n) 44 | 45 | #if r == None or turn >= 2: 46 | if r == None or turn >= 2: 47 | action = self.n**2+1 # pass turn 48 | else: 49 | action = self.perm[self.action_index[r][c]] 50 | 51 | return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0]) 52 | 53 | def get_block_position(self, grid, r, c, target): 54 | for i in range(r): 55 | for j in range(c): 56 | if grid[i][j] == target: 57 | return i, j 58 | return None, None 59 | 60 | def get_blue_block_position(self, grid, r, c): 61 | return self.get_block_position(grid, r, c, 3) 62 | 63 | def get_red_block_position(self, grid, r, c): 64 | return self.get_block_position(grid, r, c, 2) 65 | 66 | class ArmsPartnerPolicy(PartnerPolicy): 67 | def __init__(self, perm): 68 | super(PartnerPolicy, self).__init__() 69 | self.perm = th.tensor(perm) 70 | 71 | def forward(self, obs, deterministic=True): 72 | action = self.perm[obs] 73 | return th.cat((action, 
action), dim=1), th.tensor([0.0]), th.tensor([0.0]) 74 | 75 | class LowRankPartnerPolicy(PartnerPolicy): 76 | def __init__(self, n): 77 | super(PartnerPolicy, self).__init__() 78 | self.n = n 79 | 80 | def forward(self, obs, deterministic=True): 81 | action = self.n 82 | return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0]) -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/td3.rst: -------------------------------------------------------------------------------- 1 | .. _td3: 2 | 3 | .. automodule:: stable_baselines3.td3 4 | 5 | 6 | TD3 7 | === 8 | 9 | `Twin Delayed DDPG (TD3) `_ Addressing Function Approximation Error in Actor-Critic Methods. 10 | 11 | TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. 12 | We recommend reading `OpenAI Spinning guide on TD3 `_ to learn more about those. 13 | 14 | 15 | .. rubric:: Available Policies 16 | 17 | .. autosummary:: 18 | :nosignatures: 19 | 20 | MlpPolicy 21 | 22 | 23 | Notes 24 | ----- 25 | 26 | - Original paper: https://arxiv.org/pdf/1802.09477.pdf 27 | - OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html 28 | - Original Implementation: https://github.com/sfujim/TD3 29 | 30 | .. note:: 31 | 32 | The default policies for TD3 differ a bit from others MlpPolicy: it uses ReLU instead of tanh activation, 33 | to match the original paper 34 | 35 | 36 | Can I use? 37 | ---------- 38 | 39 | - Recurrent policies: ❌ 40 | - Multi processing: ❌ 41 | - Gym spaces: 42 | 43 | 44 | ============= ====== =========== 45 | Space Action Observation 46 | ============= ====== =========== 47 | Discrete ❌ ✔️ 48 | Box ✔️ ✔️ 49 | MultiDiscrete ❌ ✔️ 50 | MultiBinary ❌ ✔️ 51 | ============= ====== =========== 52 | 53 | 54 | Example 55 | ------- 56 | 57 | .. code-block:: python 58 | 59 | import gym 60 | import numpy as np 61 | 62 | from stable_baselines3 import TD3 63 | from stable_baselines3.td3.policies import MlpPolicy 64 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 65 | 66 | env = gym.make('Pendulum-v0') 67 | 68 | # The noise objects for TD3 69 | n_actions = env.action_space.shape[-1] 70 | action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) 71 | 72 | model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1) 73 | model.learn(total_timesteps=10000, log_interval=10) 74 | model.save("td3_pendulum") 75 | env = model.get_env() 76 | 77 | del model # remove to demonstrate saving and loading 78 | 79 | model = TD3.load("td3_pendulum") 80 | 81 | obs = env.reset() 82 | while True: 83 | action, _states = model.predict(obs) 84 | obs, rewards, dones, info = env.step(action) 85 | env.render() 86 | 87 | 88 | Parameters 89 | ---------- 90 | 91 | .. autoclass:: TD3 92 | :members: 93 | :inherited-members: 94 | 95 | .. _td3_policies: 96 | 97 | TD3 Policies 98 | ------------- 99 | 100 | .. autoclass:: MlpPolicy 101 | :members: 102 | :inherited-members: 103 | 104 | 105 | .. .. autoclass:: CnnPolicy 106 | .. :members: 107 | .. :inherited-members: 108 | -------------------------------------------------------------------------------- /stable-baselines3/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Stable Baselines3 documentation master file, created by 2 | sphinx-quickstart on Thu Sep 26 11:06:54 2019. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Stable Baselines3 docs! - RL Baselines Made Easy 7 | =========================================================== 8 | 9 | `Stable Baselines3 `_ is a set of improved implementations of reinforcement learning algorithms in PyTorch. 10 | It is the next major version of `Stable Baselines `_. 11 | 12 | 13 | Github repository: https://github.com/DLR-RM/stable-baselines3 14 | 15 | RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/rl-baselines3-zoo 16 | 17 | RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning. 18 | 19 | 20 | Main Features 21 | -------------- 22 | 23 | - Unified structure for all algorithms 24 | - PEP8 compliant (unified code style) 25 | - Documented functions and classes 26 | - Tests, high code coverage and type hints 27 | - Clean code 28 | - Tensorboard support 29 | 30 | 31 | .. toctree:: 32 | :maxdepth: 2 33 | :caption: User Guide 34 | 35 | guide/install 36 | guide/quickstart 37 | guide/rl_tips 38 | guide/rl 39 | guide/algos 40 | guide/examples 41 | guide/vec_envs 42 | guide/custom_env 43 | guide/custom_policy 44 | guide/callbacks 45 | guide/tensorboard 46 | guide/rl_zoo 47 | guide/migration 48 | guide/checking_nan 49 | guide/developer 50 | 51 | 52 | .. toctree:: 53 | :maxdepth: 1 54 | :caption: RL Algorithms 55 | 56 | modules/base 57 | modules/a2c 58 | modules/ddpg 59 | modules/dqn 60 | modules/ppo 61 | modules/sac 62 | modules/td3 63 | 64 | .. toctree:: 65 | :maxdepth: 1 66 | :caption: Common 67 | 68 | common/atari_wrappers 69 | common/cmd_util 70 | common/distributions 71 | common/evaluation 72 | common/env_checker 73 | common/monitor 74 | common/logger 75 | common/noise 76 | common/utils 77 | 78 | .. toctree:: 79 | :maxdepth: 1 80 | :caption: Misc 81 | 82 | misc/changelog 83 | misc/projects 84 | 85 | 86 | Citing Stable Baselines3 87 | ------------------------ 88 | To cite this project in publications: 89 | 90 | .. code-block:: bibtex 91 | 92 | @misc{stable-baselines3, 93 | author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, 94 | title = {Stable Baselines3}, 95 | year = {2019}, 96 | publisher = {GitHub}, 97 | journal = {GitHub repository}, 98 | howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}}, 99 | } 100 | 101 | Indices and tables 102 | ------------------- 103 | 104 | * :ref:`genindex` 105 | * :ref:`search` 106 | * :ref:`modindex` 107 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy a dict of numpy arrays. 14 | 15 | :param obs: (OrderedDict): a dict of numpy arrays. 16 | :return (OrderedDict) a dict of copied numpy arrays. 17 | """ 18 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 20 | 21 | 22 | def dict_to_obs(space, obs_dict): 23 | """ 24 | Convert an internal representation raw_obs into the appropriate type 25 | specified by space. 
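For example (illustrative only): for an unstructured space the internal representation is OrderedDict([(None, array)]) and this helper simply returns array.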
26 | 27 | :param space: (gym.spaces.Space) an observation space. 28 | :param obs_dict: (OrderedDict) a dict of numpy arrays. 29 | :return (ndarray, tuple or dict): returns an observation 30 | of the same type as space. If space is Dict, function is identity; 31 | if space is Tuple, converts dict to Tuple; otherwise, space is 32 | unstructured and returns the value raw_obs[None]. 33 | """ 34 | if isinstance(space, gym.spaces.Dict): 35 | return obs_dict 36 | elif isinstance(space, gym.spaces.Tuple): 37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space" 38 | return tuple((obs_dict[i] for i in range(len(space.spaces)))) 39 | else: 40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 41 | return obs_dict[None] 42 | 43 | 44 | def obs_space_info(obs_space): 45 | """ 46 | Get dict-structured information about a gym.Space. 47 | 48 | Dict spaces are represented directly by their dict of subspaces. 49 | Tuple spaces are converted into a dict with keys indexing into the tuple. 50 | Unstructured spaces are represented by {None: obs_space}. 51 | 52 | :param obs_space: (gym.spaces.Space) an observation space 53 | :return (tuple) A tuple (keys, shapes, dtypes): 54 | keys: a list of dict keys. 55 | shapes: a dict mapping keys to shapes. 56 | dtypes: a dict mapping keys to dtypes. 57 | """ 58 | if isinstance(obs_space, gym.spaces.Dict): 59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 60 | subspaces = obs_space.spaces 61 | elif isinstance(obs_space, gym.spaces.Tuple): 62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 63 | else: 64 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 65 | subspaces = {None: obs_space} 66 | keys = [] 67 | shapes = {} 68 | dtypes = {} 69 | for key, box in subspaces.items(): 70 | keys.append(key) 71 | shapes[key] = box.shape 72 | dtypes[key] = box.dtype 73 | return keys, shapes, dtypes 74 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/custom_env.rst: -------------------------------------------------------------------------------- 1 | .. _custom_env: 2 | 3 | Using Custom Environments 4 | ========================== 5 | 6 | To use the rl baselines with custom environments, they just need to follow the *gym* interface. 7 | That is to say, your environment must implement the following methods (and inherits from OpenAI Gym Class): 8 | 9 | 10 | .. note:: 11 | If you are using images as input, the input values must be in [0, 255] as the observation 12 | is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. 13 | 14 | 15 | 16 | .. code-block:: python 17 | 18 | import gym 19 | from gym import spaces 20 | 21 | class CustomEnv(gym.Env): 22 | """Custom Environment that follows gym interface""" 23 | metadata = {'render.modes': ['human']} 24 | 25 | def __init__(self, arg1, arg2, ...): 26 | super(CustomEnv, self).__init__() 27 | # Define action and observation space 28 | # They must be gym.spaces objects 29 | # Example when using discrete actions: 30 | self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS) 31 | # Example for using image as input: 32 | self.observation_space = spaces.Box(low=0, high=255, 33 | shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8) 34 | 35 | def step(self, action): 36 | ... 37 | return observation, reward, done, info 38 | def reset(self): 39 | ... 
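        # (illustration) reset any internal state and build the first observation here;
        # it must be a valid element of self.observation_space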
40 | return observation # reward, done, info can't be included 41 | def render(self, mode='human'): 42 | ... 43 | def close (self): 44 | ... 45 | 46 | 47 | Then you can define and train a RL agent with: 48 | 49 | .. code-block:: python 50 | 51 | # Instantiate the env 52 | env = CustomEnv(arg1, ...) 53 | # Define and Train the agent 54 | model = A2C('CnnPolicy', env).learn(total_timesteps=1000) 55 | 56 | 57 | To check that your environment follows the gym interface, please use: 58 | 59 | .. code-block:: python 60 | 61 | from stable_baselines3.common.env_checker import check_env 62 | 63 | env = CustomEnv(arg1, ...) 64 | # It will check your custom environment and output additional warnings if needed 65 | check_env(env) 66 | 67 | 68 | 69 | We have created a `colab notebook `_ for 70 | a concrete example of creating a custom environment. 71 | 72 | You can also find a `complete guide online `_ 73 | on creating a custom Gym environment. 74 | 75 | 76 | Optionally, you can also register the environment with gym, 77 | that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env). 78 | 79 | 80 | In the project, for testing purposes, we use a custom environment named ``IdentityEnv`` 81 | defined `in this file `_. 82 | An example of how to use it can be found `here `_. 83 | -------------------------------------------------------------------------------- /stable-baselines3/docs/modules/sac.rst: -------------------------------------------------------------------------------- 1 | .. _sac: 2 | 3 | .. automodule:: stable_baselines3.sac 4 | 5 | 6 | SAC 7 | === 8 | 9 | `Soft Actor Critic (SAC) `_ Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. 10 | 11 | SAC is the successor of `Soft Q-Learning SQL `_ and incorporates the double Q-learning trick from TD3. 12 | A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy. 13 | 14 | 15 | .. rubric:: Available Policies 16 | 17 | .. autosummary:: 18 | :nosignatures: 19 | 20 | MlpPolicy 21 | CnnPolicy 22 | 23 | 24 | Notes 25 | ----- 26 | 27 | - Original paper: https://arxiv.org/abs/1801.01290 28 | - OpenAI Spinning Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html 29 | - Original Implementation: https://github.com/haarnoja/sac 30 | - Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/ 31 | 32 | .. note:: 33 | In our implementation, we use an entropy coefficient (as in OpenAI Spinning or Facebook Horizon), 34 | which is the equivalent to the inverse of reward scale in the original SAC paper. 35 | The main reason is that it avoids having too high errors when updating the Q functions. 36 | 37 | 38 | .. note:: 39 | 40 | The default policies for SAC differ a bit from others MlpPolicy: it uses ReLU instead of tanh activation, 41 | to match the original paper 42 | 43 | 44 | Can I use? 45 | ---------- 46 | 47 | - Recurrent policies: ❌ 48 | - Multi processing: ❌ 49 | - Gym spaces: 50 | 51 | 52 | ============= ====== =========== 53 | Space Action Observation 54 | ============= ====== =========== 55 | Discrete ❌ ✔️ 56 | Box ✔️ ✔️ 57 | MultiDiscrete ❌ ✔️ 58 | MultiBinary ❌ ✔️ 59 | ============= ====== =========== 60 | 61 | 62 | Example 63 | ------- 64 | 65 | .. 
code-block:: python 66 | 67 | import gym 68 | import numpy as np 69 | 70 | from stable_baselines3 import SAC 71 | from stable_baselines3.sac import MlpPolicy 72 | 73 | env = gym.make('Pendulum-v0') 74 | 75 | model = SAC(MlpPolicy, env, verbose=1) 76 | model.learn(total_timesteps=10000, log_interval=4) 77 | model.save("sac_pendulum") 78 | 79 | del model # remove to demonstrate saving and loading 80 | 81 | model = SAC.load("sac_pendulum") 82 | 83 | obs = env.reset() 84 | while True: 85 | action, _states = model.predict(obs) 86 | obs, reward, done, info = env.step(action) 87 | env.render() 88 | if done: 89 | obs = env.reset() 90 | 91 | Parameters 92 | ---------- 93 | 94 | .. autoclass:: SAC 95 | :members: 96 | :inherited-members: 97 | 98 | .. _sac_policies: 99 | 100 | SAC Policies 101 | ------------- 102 | 103 | .. autoclass:: MlpPolicy 104 | :members: 105 | :inherited-members: 106 | 107 | .. .. autoclass:: CnnPolicy 108 | .. :members: 109 | .. :inherited-members: 110 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/vec_envs.rst: -------------------------------------------------------------------------------- 1 | .. _vec_env: 2 | 3 | .. automodule:: stable_baselines3.common.vec_env 4 | 5 | Vectorized Environments 6 | ======================= 7 | 8 | Vectorized Environments are a method for stacking multiple independent environments into a single environment. 9 | Instead of training an RL agent on 1 environment per step, it allows us to train it on ``n`` environments per step. 10 | Because of this, ``actions`` passed to the environment are now a vector (of dimension ``n``). 11 | It is the same for ``observations``, ``rewards`` and end of episode signals (``dones``). 12 | In the case of non-array observation spaces such as ``Dict`` or ``Tuple``, where different sub-spaces 13 | may have different shapes, the sub-observations are vectors (of dimension ``n``). 14 | 15 | ============= ======= ============ ======== ========= ================ 16 | Name ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing 17 | ============= ======= ============ ======== ========= ================ 18 | DummyVecEnv ✔️ ✔️ ✔️ ✔️ ❌️ 19 | SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ 20 | ============= ======= ============ ======== ========= ================ 21 | 22 | .. note:: 23 | 24 | Vectorized environments are required when using wrappers for frame-stacking or normalization. 25 | 26 | .. note:: 27 | 28 | When using vectorized environments, the environments are automatically reset at the end of each episode. 29 | Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated. 30 | You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv. 31 | 32 | .. warning:: 33 | 34 | When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows). 35 | On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks. 36 | 37 | For more information, see Python's `multiprocessing guidelines `_. 38 | 39 | VecEnv 40 | ------ 41 | 42 | .. autoclass:: VecEnv 43 | :members: 44 | 45 | DummyVecEnv 46 | ----------- 47 | 48 | .. 
autoclass:: DummyVecEnv 49 | :members: 50 | 51 | SubprocVecEnv 52 | ------------- 53 | 54 | .. autoclass:: SubprocVecEnv 55 | :members: 56 | 57 | Wrappers 58 | -------- 59 | 60 | VecFrameStack 61 | ~~~~~~~~~~~~~ 62 | 63 | .. autoclass:: VecFrameStack 64 | :members: 65 | 66 | 67 | VecNormalize 68 | ~~~~~~~~~~~~ 69 | 70 | .. autoclass:: VecNormalize 71 | :members: 72 | 73 | 74 | VecVideoRecorder 75 | ~~~~~~~~~~~~~~~~ 76 | 77 | .. autoclass:: VecVideoRecorder 78 | :members: 79 | 80 | 81 | VecCheckNan 82 | ~~~~~~~~~~~~~~~~ 83 | 84 | .. autoclass:: VecCheckNan 85 | :members: 86 | 87 | 88 | VecTransposeImage 89 | ~~~~~~~~~~~~~~~~~ 90 | 91 | .. autoclass:: VecTransposeImage 92 | :members: 93 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/rl_zoo.rst: -------------------------------------------------------------------------------- 1 | .. _rl_zoo: 2 | 3 | ================== 4 | RL Baselines3 Zoo 5 | ================== 6 | 7 | `RL Baselines3 Zoo `_. is a collection of pre-trained Reinforcement Learning agents using 8 | Stable-Baselines3. 9 | It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos. 10 | 11 | Goals of this repository: 12 | 13 | 1. Provide a simple interface to train and enjoy RL agents 14 | 2. Benchmark the different Reinforcement Learning algorithms 15 | 3. Provide tuned hyperparameters for each environment and RL algorithm 16 | 4. Have fun with the trained agents! 17 | 18 | Installation 19 | ------------ 20 | 21 | 1. Clone the repository: 22 | 23 | :: 24 | 25 | git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo 26 | cd rl-baselines3-zoo/ 27 | 28 | 29 | .. note:: 30 | 31 | You can remove the ``--recursive`` option if you don't want to download the trained agents 32 | 33 | 34 | 2. Install dependencies 35 | :: 36 | 37 | apt-get install swig cmake ffmpeg 38 | pip install -r requirements.txt 39 | 40 | 41 | Train an Agent 42 | -------------- 43 | 44 | The hyperparameters for each environment are defined in 45 | ``hyperparameters/algo_name.yml``. 46 | 47 | If the environment exists in this file, then you can train an agent 48 | using: 49 | 50 | :: 51 | 52 | python train.py --algo algo_name --env env_id 53 | 54 | For example (with evaluation and checkpoints): 55 | 56 | :: 57 | 58 | python train.py --algo ppo2 --env CartPole-v1 --eval-freq 10000 --save-freq 50000 59 | 60 | 61 | Continue training (here, load pretrained agent for Breakout and continue 62 | training for 5000 steps): 63 | 64 | :: 65 | 66 | python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000 67 | 68 | 69 | Enjoy a Trained Agent 70 | --------------------- 71 | 72 | If the trained agent exists, then you can see it in action using: 73 | 74 | :: 75 | 76 | python enjoy.py --algo algo_name --env env_id 77 | 78 | For example, enjoy A2C on Breakout during 5000 timesteps: 79 | 80 | :: 81 | 82 | python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000 83 | 84 | 85 | Hyperparameter Optimization 86 | --------------------------- 87 | 88 | We use `Optuna `_ for optimizing the hyperparameters. 
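Conceptually, each trial samples a set of hyperparameters, trains an agent with them and reports an evaluation score back to the study, which then decides what to try next. The snippet below is only a hand-written sketch of that loop using the public Optuna and Stable-Baselines3 APIs; it is not the zoo's implementation, and the environment, search ranges and budgets are placeholders:

.. code-block:: python

    import gym
    import optuna

    from stable_baselines3 import PPO


    def objective(trial):
        # Sample candidate hyperparameters for this trial
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-3)
        gamma = trial.suggest_uniform("gamma", 0.9, 0.9999)

        model = PPO("MlpPolicy", "CartPole-v1", learning_rate=learning_rate, gamma=gamma, verbose=0)
        model.learn(total_timesteps=10000)

        # Score the candidate over a few deterministic evaluation episodes
        env = gym.make("CartPole-v1")
        returns = []
        for _ in range(5):
            obs, done, episode_return = env.reset(), False, 0.0
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, done, _ = env.step(action)
                episode_return += reward
            returns.append(episode_return)
        return sum(returns) / len(returns)


    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=10)
    print(study.best_params)

The zoo exposes this workflow through ``train.py``, as in the command below.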
89 | 90 | 91 | Tune the hyperparameters for PPO, using a random sampler and median pruner, 2 parallels jobs, 92 | with a budget of 1000 trials and a maximum of 50000 steps: 93 | 94 | :: 95 | 96 | python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \ 97 | --sampler random --pruner median 98 | 99 | 100 | Colab Notebook: Try it Online! 101 | ------------------------------ 102 | 103 | You can train agents online using Google `colab notebook `_. 104 | 105 | 106 | .. note:: 107 | 108 | You can find more information about the rl baselines3 zoo in the repo `README `_. For instance, how to record a video of a trained agent. 109 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: (bool) Whether or not to only warn once. 16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions): 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self): 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self): 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step, **kwargs): 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += "found {} in {}".format(type_val, name) 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions) 80 | else: 81 | msg += "RL model, Last given value 
was: \r\n\tobservations={}".format(self._observations) 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 5 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 6 | 7 | normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) 8 | 9 | 10 | @pytest.mark.parametrize("model_class", [TD3, DDPG]) 11 | @pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))]) 12 | def test_deterministic_pg(model_class, action_noise): 13 | """ 14 | Test for DDPG and variants (TD3). 15 | """ 16 | model = model_class( 17 | "MlpPolicy", 18 | "Pendulum-v0", 19 | policy_kwargs=dict(net_arch=[64, 64]), 20 | learning_starts=100, 21 | verbose=1, 22 | create_eval_env=True, 23 | action_noise=action_noise, 24 | ) 25 | model.learn(total_timesteps=1000, eval_freq=500) 26 | 27 | 28 | @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) 29 | def test_a2c(env_id): 30 | model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) 31 | model.learn(total_timesteps=1000, eval_freq=500) 32 | 33 | 34 | @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) 35 | @pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2]) 36 | def test_ppo(env_id, clip_range_vf): 37 | if clip_range_vf is not None and clip_range_vf < 0: 38 | # Should throw an error 39 | with pytest.raises(AssertionError): 40 | model = PPO( 41 | "MlpPolicy", 42 | env_id, 43 | seed=0, 44 | policy_kwargs=dict(net_arch=[16]), 45 | verbose=1, 46 | create_eval_env=True, 47 | clip_range_vf=clip_range_vf, 48 | ) 49 | else: 50 | model = PPO( 51 | "MlpPolicy", 52 | env_id, 53 | seed=0, 54 | policy_kwargs=dict(net_arch=[16]), 55 | verbose=1, 56 | create_eval_env=True, 57 | clip_range_vf=clip_range_vf, 58 | ) 59 | model.learn(total_timesteps=1000, eval_freq=500) 60 | 61 | 62 | @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) 63 | def test_sac(ent_coef): 64 | model = SAC( 65 | "MlpPolicy", 66 | "Pendulum-v0", 67 | policy_kwargs=dict(net_arch=[64, 64]), 68 | learning_starts=100, 69 | verbose=1, 70 | create_eval_env=True, 71 | ent_coef=ent_coef, 72 | action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)), 73 | ) 74 | model.learn(total_timesteps=1000, eval_freq=500) 75 | 76 | 77 | @pytest.mark.parametrize("n_critics", [1, 3]) 78 | def test_n_critics(n_critics): 79 | # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG 80 | model = SAC( 81 | "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 82 | ) 83 | model.learn(total_timesteps=1000) 84 | 85 | 86 | def test_dqn(): 87 | model = DQN( 88 | "MlpPolicy", 89 | "CartPole-v1", 90 | policy_kwargs=dict(net_arch=[64, 64]), 91 | learning_starts=500, 92 | buffer_size=500, 93 | learning_rate=3e-4, 94 | verbose=1, 95 | create_eval_env=True, 96 | ) 97 | model.learn(total_timesteps=1000, eval_freq=500) 98 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/custom_policy.rst: 
-------------------------------------------------------------------------------- 1 | .. _custom_policy: 2 | 3 | Custom Policy Network 4 | --------------------- 5 | 6 | Stable Baselines3 provides policy networks for images (CnnPolicies) 7 | and other type of input features (MlpPolicies). 8 | 9 | One way of customising the policy network architecture is to pass arguments when creating the model, 10 | using ``policy_kwargs`` parameter: 11 | 12 | .. code-block:: python 13 | 14 | import gym 15 | import torch as th 16 | 17 | from stable_baselines3 import PPO 18 | 19 | # Custom MLP policy of two layers of size 32 each with tanh activation function 20 | policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[32, 32]) 21 | # Create the agent 22 | model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) 23 | # Retrieve the environment 24 | env = model.get_env() 25 | # Train the agent 26 | model.learn(total_timesteps=100000) 27 | # Save the agent 28 | model.save("ppo-cartpole") 29 | 30 | del model 31 | # the policy_kwargs are automatically loaded 32 | model = PPO.load("ppo-cartpole") 33 | 34 | 35 | You can also easily define a custom architecture for the policy (or value) network: 36 | 37 | .. note:: 38 | 39 | Defining a custom policy class is equivalent to passing ``policy_kwargs``. 40 | However, it lets you name the policy and so usually makes the code clearer. 41 | ``policy_kwargs`` is particularly useful when doing hyperparameter search. 42 | 43 | 44 | 45 | The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many 46 | of them are shared between the policy network and the value network. It is assumed to be a list with the following 47 | structure: 48 | 49 | 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer. 50 | If the number of ints is zero, there will be no shared layers. 51 | 2. An optional dict, to specify the following non-shared layers for the value network and the policy network. 52 | It is formatted like ``dict(vf=[], pi=[])``. 53 | If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. 54 | 55 | In short: ``[, dict(vf=[], pi=[])]``. 56 | 57 | Examples 58 | ~~~~~~~~ 59 | 60 | Two shared layers of size 128: ``net_arch=[128, 128]`` 61 | 62 | 63 | .. code-block:: none 64 | 65 | obs 66 | | 67 | <128> 68 | | 69 | <128> 70 | / \ 71 | action value 72 | 73 | 74 | Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]`` 75 | 76 | .. code-block:: none 77 | 78 | obs 79 | | 80 | <128> 81 | / \ 82 | action <256> 83 | | 84 | <256> 85 | | 86 | value 87 | 88 | 89 | Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]`` 90 | 91 | .. code-block:: none 92 | 93 | obs 94 | | 95 | <128> 96 | / \ 97 | <16> <256> 98 | | | 99 | action value 100 | 101 | 102 | 103 | If your task requires even more granular control over the policy architecture, you can redefine the policy directly. 
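For example, a named policy class can be built on top of ``ActorCriticPolicy`` (a minimal sketch written against the upstream Stable-Baselines3 API; the layer sizes are arbitrary and only illustrate the idea):

.. code-block:: python

    from stable_baselines3 import PPO
    from stable_baselines3.common.policies import ActorCriticPolicy


    class CustomActorCriticPolicy(ActorCriticPolicy):
        def __init__(self, *args, **kwargs):
            # Equivalent to policy_kwargs=dict(net_arch=[dict(pi=[64, 64], vf=[128, 128])]),
            # but packaged as a reusable, named policy class
            super(CustomActorCriticPolicy, self).__init__(
                *args,
                net_arch=[dict(pi=[64, 64], vf=[128, 128])],
                **kwargs,
            )


    # The class itself is passed to the algorithm instead of the "MlpPolicy" string
    model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
    model.learn(total_timesteps=10000)

Packaging the architecture as a class like this keeps experiment scripts short and makes the same policy easy to reuse across algorithms.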
104 | 105 | **TODO** 106 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_monitor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import uuid 4 | 5 | import gym 6 | import pandas 7 | 8 | from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results 9 | 10 | 11 | def test_monitor(tmp_path): 12 | """ 13 | Test the monitor wrapper 14 | """ 15 | env = gym.make("CartPole-v1") 16 | env.seed(0) 17 | monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) 18 | monitor_env = Monitor(env, monitor_file) 19 | monitor_env.reset() 20 | total_steps = 1000 21 | ep_rewards = [] 22 | ep_lengths = [] 23 | ep_len, ep_reward = 0, 0 24 | for _ in range(total_steps): 25 | _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample()) 26 | ep_len += 1 27 | ep_reward += reward 28 | if done: 29 | ep_rewards.append(ep_reward) 30 | ep_lengths.append(ep_len) 31 | monitor_env.reset() 32 | ep_len, ep_reward = 0, 0 33 | 34 | monitor_env.close() 35 | assert monitor_env.get_total_steps() == total_steps 36 | assert sum(ep_lengths) == sum(monitor_env.get_episode_lengths()) 37 | assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards) 38 | _ = monitor_env.get_episode_times() 39 | 40 | with open(monitor_file, "rt") as file_handler: 41 | first_line = file_handler.readline() 42 | assert first_line.startswith("#") 43 | metadata = json.loads(first_line[1:]) 44 | assert metadata["env_id"] == "CartPole-v1" 45 | assert set(metadata.keys()) == {"env_id", "t_start"}, "Incorrect keys in monitor metadata" 46 | 47 | last_logline = pandas.read_csv(file_handler, index_col=None) 48 | assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline" 49 | os.remove(monitor_file) 50 | 51 | 52 | def test_monitor_load_results(tmp_path): 53 | """ 54 | test load_results on log files produced by the monitor wrapper 55 | """ 56 | tmp_path = str(tmp_path) 57 | env1 = gym.make("CartPole-v1") 58 | env1.seed(0) 59 | monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) 60 | monitor_env1 = Monitor(env1, monitor_file1) 61 | 62 | monitor_files = get_monitor_files(tmp_path) 63 | assert len(monitor_files) == 1 64 | assert monitor_file1 in monitor_files 65 | 66 | monitor_env1.reset() 67 | episode_count1 = 0 68 | for _ in range(1000): 69 | _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) 70 | if done: 71 | episode_count1 += 1 72 | monitor_env1.reset() 73 | 74 | results_size1 = len(load_results(os.path.join(tmp_path)).index) 75 | assert results_size1 == episode_count1 76 | 77 | env2 = gym.make("CartPole-v1") 78 | env2.seed(0) 79 | monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) 80 | monitor_env2 = Monitor(env2, monitor_file2) 81 | monitor_files = get_monitor_files(tmp_path) 82 | assert len(monitor_files) == 2 83 | assert monitor_file1 in monitor_files 84 | assert monitor_file2 in monitor_files 85 | 86 | monitor_env2.reset() 87 | episode_count2 = 0 88 | for _ in range(1000): 89 | _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) 90 | if done: 91 | episode_count2 += 1 92 | monitor_env2.reset() 93 | 94 | results_size2 = len(load_results(os.path.join(tmp_path)).index) 95 | 96 | assert results_size2 == (results_size1 + episode_count2) 97 | 98 | os.remove(monitor_file1) 99 | 
os.remove(monitor_file2) 100 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.wrappers.monitoring import video_recorder 4 | 5 | from stable_baselines3.common import logger 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack 10 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize 11 | 12 | 13 | class VecVideoRecorder(VecEnvWrapper): 14 | """ 15 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 16 | It requires ffmpeg or avconv to be installed on the machine. 17 | 18 | :param venv: (VecEnv or VecEnvWrapper) 19 | :param video_folder: (str) Where to save videos 20 | :param record_video_trigger: (func) Function that defines when to start recording. 21 | The function takes the current number of step, 22 | and returns whether we should start recording or not. 23 | :param video_length: (int) Length of recorded videos 24 | :param name_prefix: (str) Prefix to the video name 25 | """ 26 | 27 | def __init__(self, venv, video_folder, record_video_trigger, video_length=200, name_prefix="rl-video"): 28 | 29 | VecEnvWrapper.__init__(self, venv) 30 | 31 | self.env = venv 32 | # Temp variable to retrieve metadata 33 | temp_env = venv 34 | 35 | # Unwrap to retrieve metadata dict 36 | # that will be used by gym recorder 37 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack): 38 | temp_env = temp_env.venv 39 | 40 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 41 | metadata = temp_env.get_attr("metadata")[0] 42 | else: 43 | metadata = temp_env.metadata 44 | 45 | self.env.metadata = metadata 46 | 47 | self.record_video_trigger = record_video_trigger 48 | self.video_recorder = None 49 | 50 | self.video_folder = os.path.abspath(video_folder) 51 | # Create output folder if needed 52 | os.makedirs(self.video_folder, exist_ok=True) 53 | 54 | self.name_prefix = name_prefix 55 | self.step_id = 0 56 | self.video_length = video_length 57 | 58 | self.recording = False 59 | self.recorded_frames = 0 60 | 61 | def reset(self): 62 | obs = self.venv.reset() 63 | self.start_video_recorder() 64 | return obs 65 | 66 | def start_video_recorder(self): 67 | self.close_video_recorder() 68 | 69 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" 70 | base_path = os.path.join(self.video_folder, video_name) 71 | self.video_recorder = video_recorder.VideoRecorder( 72 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id} 73 | ) 74 | 75 | self.video_recorder.capture_frame() 76 | self.recorded_frames = 1 77 | self.recording = True 78 | 79 | def _video_enabled(self): 80 | return self.record_video_trigger(self.step_id) 81 | 82 | def step_wait(self): 83 | obs, rews, dones, infos = self.venv.step_wait() 84 | 85 | self.step_id += 1 86 | if self.recording: 87 | self.video_recorder.capture_frame() 88 | self.recorded_frames += 1 89 | if self.recorded_frames > self.video_length: 90 | logger.info("Saving video to ", self.video_recorder.path) 91 | self.close_video_recorder() 92 | elif self._video_enabled(): 
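            # Not recording yet: record_video_trigger fired for this step_id, so start a new clip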
93 | self.start_video_recorder() 94 | 95 | return obs, rews, dones, infos 96 | 97 | def close_video_recorder(self): 98 | if self.recording: 99 | self.video_recorder.close() 100 | self.recording = False 101 | self.recorded_frames = 1 102 | 103 | def close(self): 104 | VecEnvWrapper.close(self) 105 | self.close_video_recorder() 106 | 107 | def __del__(self): 108 | self.close() 109 | -------------------------------------------------------------------------------- /stable-baselines3/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to Stable-Baselines3 2 | 3 | If you are interested in contributing to Stable-Baselines, your contributions will fall 4 | into two categories: 5 | 1. You want to propose a new Feature and implement it 6 | - Create an issue about your intended feature, and we shall discuss the design and 7 | implementation. Once we agree that the plan looks good, go ahead and implement it. 8 | 2. You want to implement a feature or bug-fix for an outstanding issue 9 | - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues 10 | - Pick an issue or feature and comment on the task that you want to work on this feature. 11 | - If you need more context on a particular issue, please ask and we shall provide. 12 | 13 | Once you finish implementing a feature or bug-fix, please send a Pull Request to 14 | https://github.com/DLR-RM/stable-baselines3 15 | 16 | 17 | If you are not familiar with creating a Pull Request, here are some guides: 18 | - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request 19 | - https://help.github.com/articles/creating-a-pull-request/ 20 | 21 | 22 | ## Developing Stable-Baselines3 23 | 24 | To develop Stable-Baselines3 on your machine, here are some tips: 25 | 26 | 1. Clone a copy of Stable-Baselines3 from source: 27 | 28 | ```bash 29 | git clone https://github.com/DLR-RM/stable-baselines3 30 | cd stable-baselines3/ 31 | ``` 32 | 33 | 2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests: 34 | 35 | ```bash 36 | pip install -e .[docs,tests,extra] 37 | ``` 38 | 39 | ## Codestyle 40 | 41 | We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports. 42 | 43 | **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`. 44 | 45 | Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template: 46 | 47 | ```python 48 | 49 | def my_function(arg1: type1, arg2: type2) -> returntype: 50 | """ 51 | Short description of the function. 52 | 53 | :param arg1: (type1) describe what is arg1 54 | :param arg2: (type2) describe what is arg2 55 | :return: (returntype) describe what is returned 56 | """ 57 | ... 58 | return my_variable 59 | ``` 60 | 61 | ## Pull Request (PR) 62 | 63 | Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process. 64 | 65 | Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli). 66 | A PR must pass the Continuous Integration tests to be merged with the master branch. 
67 | 68 | 69 | ## Tests 70 | 71 | All new features must add tests in the `tests/` folder ensuring that everything works fine. 72 | We use [pytest](https://pytest.org/). 73 | Also, when a bug fix is proposed, tests should be added to avoid regression. 74 | 75 | To run tests with `pytest`: 76 | 77 | ``` 78 | make pytest 79 | ``` 80 | 81 | Type checking with `pytype`: 82 | 83 | ``` 84 | make type 85 | ``` 86 | 87 | Codestyle check with `black`, `isort` and `flake8`: 88 | 89 | ``` 90 | make check-codestyle 91 | make lint 92 | ``` 93 | 94 | To run `pytype`, `format` and `lint` in one command: 95 | ``` 96 | make commit-checks 97 | ``` 98 | 99 | Build the documentation: 100 | 101 | ``` 102 | make doc 103 | ``` 104 | 105 | Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that): 106 | 107 | ``` 108 | make spelling 109 | ``` 110 | 111 | 112 | ## Changelog and Documentation 113 | 114 | Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed. 115 | You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too. 116 | A README is present in the `docs/` folder for instructions on how to build the documentation. 117 | 118 | 119 | Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one. 120 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_distributions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch as th 3 | 4 | from stable_baselines3 import A2C, PPO 5 | from stable_baselines3.common.distributions import ( 6 | BernoulliDistribution, 7 | CategoricalDistribution, 8 | DiagGaussianDistribution, 9 | MultiCategoricalDistribution, 10 | SquashedDiagGaussianDistribution, 11 | StateDependentNoiseDistribution, 12 | TanhBijector, 13 | ) 14 | from stable_baselines3.common.utils import set_random_seed 15 | 16 | N_ACTIONS = 2 17 | N_FEATURES = 3 18 | N_SAMPLES = int(5e6) 19 | 20 | 21 | def test_bijector(): 22 | """ 23 | Test TanhBijector 24 | """ 25 | actions = th.ones(5) * 2.0 26 | bijector = TanhBijector() 27 | 28 | squashed_actions = bijector.forward(actions) 29 | # Check that the boundaries are not violated 30 | assert th.max(th.abs(squashed_actions)) <= 1.0 31 | # Check the inverse method 32 | assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all() 33 | 34 | 35 | @pytest.mark.parametrize("model_class", [A2C, PPO]) 36 | def test_squashed_gaussian(model_class): 37 | """ 38 | Test run with squashed Gaussian (notably entropy computation) 39 | """ 40 | model = model_class("MlpPolicy", "Pendulum-v0", use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) 41 | model.learn(500) 42 | 43 | gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS) 44 | dist = SquashedDiagGaussianDistribution(N_ACTIONS) 45 | _, log_std = dist.proba_distribution_net(N_FEATURES) 46 | dist = dist.proba_distribution(gaussian_mean, log_std) 47 | actions = dist.get_actions() 48 | assert th.max(th.abs(actions)) <= 1.0 49 | 50 | 51 | def test_sde_distribution(): 52 | n_actions = 1 53 | deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1 54 | state = th.ones(N_SAMPLES, N_FEATURES) * 0.3 55 | dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False) 56 | 57 | set_random_seed(1) 58 | _, log_std = dist.proba_distribution_net(N_FEATURES) 59 | 
dist.sample_weights(log_std, batch_size=N_SAMPLES) 60 | 61 | dist = dist.proba_distribution(deterministic_actions, log_std, state) 62 | actions = dist.get_actions() 63 | 64 | assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=2e-3) 65 | assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=2e-3) 66 | 67 | 68 | # TODO: analytical form for squashed Gaussian? 69 | @pytest.mark.parametrize( 70 | "dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),] 71 | ) 72 | def test_entropy(dist): 73 | # The entropy can be approximated by averaging the negative log likelihood 74 | # mean negative log likelihood == differential entropy 75 | set_random_seed(1) 76 | state = th.rand(N_SAMPLES, N_FEATURES) 77 | deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS) 78 | _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) 79 | 80 | if isinstance(dist, DiagGaussianDistribution): 81 | dist = dist.proba_distribution(deterministic_actions, log_std) 82 | else: 83 | dist.sample_weights(log_std, batch_size=N_SAMPLES) 84 | dist = dist.proba_distribution(deterministic_actions, log_std, state) 85 | 86 | actions = dist.get_actions() 87 | entropy = dist.entropy() 88 | log_prob = dist.log_prob(actions) 89 | assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) 90 | 91 | 92 | categorical_params = [ 93 | (CategoricalDistribution(N_ACTIONS), N_ACTIONS), 94 | (MultiCategoricalDistribution([2, 3]), sum([2, 3])), 95 | (BernoulliDistribution(N_ACTIONS), N_ACTIONS), 96 | ] 97 | 98 | 99 | @pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params) 100 | def test_categorical(dist, CAT_ACTIONS): 101 | # The entropy can be approximated by averaging the negative log likelihood 102 | # mean negative log likelihood == entropy 103 | set_random_seed(1) 104 | action_logits = th.rand(N_SAMPLES, CAT_ACTIONS) 105 | dist = dist.proba_distribution(action_logits) 106 | actions = dist.get_actions() 107 | entropy = dist.entropy() 108 | log_prob = dist.log_prob(actions) 109 | assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) 110 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/developer.rst: -------------------------------------------------------------------------------- 1 | .. _developer: 2 | 3 | ================ 4 | Developer Guide 5 | ================ 6 | 7 | This guide is meant for those who want to understand the internals and the design choices of Stable-Baselines3. 8 | 9 | 10 | At first, you should read the two issues where the design choices were discussed: 11 | 12 | - https://github.com/hill-a/stable-baselines/issues/576 13 | - https://github.com/hill-a/stable-baselines/issues/733 14 | 15 | 16 | The library is not meant to be modular, although inheritance is used to reduce code duplication. 17 | 18 | 19 | Algorithms Structure 20 | ==================== 21 | 22 | 23 | Each algorithm (on-policy and off-policy ones) follows a common structure. 24 | Policy contains code for acting in the environment, and algorithm updates this policy. 25 | There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (``policies.py``). 26 | 27 | Each algorithm has two main methods: 28 | 29 | - ``.collect_rollouts()`` which defines how new samples are collected, usually inherited from the base class. 
Those samples are then stored in a ``RolloutBuffer`` (discarded after the gradient update) or ``ReplayBuffer`` 30 | 31 | - ``.train()`` which updates the parameters using samples from the buffer 32 | 33 | 34 | Where to start? 35 | =============== 36 | 37 | The first thing you need to read and understand are the base classes in the ``common/`` folder: 38 | 39 | - ``BaseAlgorithm`` in ``base_class.py`` which defines how an RL class should look like. 40 | It contains also all the "glue code" for saving/loading and the common operations (wrapping environments) 41 | 42 | - ``BasePolicy`` in ``policies.py`` which defines how a policy class should look like. 43 | It contains also all the magic for the ``.predict()`` method, to handle as many spaces/cases as possible 44 | 45 | - ``OffPolicyAlgorithm`` in ``off_policy_algorithm.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms, 46 | and similarly ``OnPolicyAlgorithm`` in ``on_policy_algorithm.py``. 47 | 48 | 49 | All the environments handled internally are assumed to be ``VecEnv`` (``gym.Env`` are automatically wrapped). 50 | 51 | 52 | Pre-Processing 53 | ============== 54 | 55 | To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation). 56 | Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``. 57 | 58 | For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention. 59 | 60 | 61 | Policy Structure 62 | ================ 63 | 64 | When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology. 65 | In SB3, "policy" refers to the class that handles all the networks useful for training, 66 | so not only the network used to predict actions (the "learned controller"). 67 | For instance, the ``TD3`` policy contains the actor, the critic and the target networks. 68 | 69 | To avoid the hassle of importing specific policy classes for specific algorithm (e.g. both A2C and PPO use ``ActorCriticPolicy``), 70 | SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer policies using small feed-forward networks or convolutional networks, 71 | respectively. Importing ``[algorithm]/policies.py`` registers an appropriate policy for that algorithm under those names. 72 | 73 | Probability distributions 74 | ========================= 75 | 76 | When needed, the policies handle the different probability distributions. 77 | All distributions are located in ``common/distributions.py`` and follow the same interface. 78 | Each distribution corresponds to a type of action space (e.g. ``Categorical`` is the one used for discrete actions. 79 | For continuous actions, we can use multiple distributions ("DiagGaussian", "SquashedGaussian" or "StateDependentDistribution") 80 | 81 | State-Dependent Exploration 82 | =========================== 83 | 84 | State-Dependent Exploration (SDE) is a type of exploration that allows to use RL directly on real robots, 85 | that was the starting point for the Stable-Baselines3 library. 86 | I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): https://arxiv.org/abs/2005.05719 87 | 88 | Misc 89 | ==== 90 | 91 | The rest of the ``common/`` is composed of helpers (e.g. evaluation helpers) or basic components (like the callbacks). 92 | The ``type_aliases.py`` file contains common type hint aliases like ``GymStepReturn``. 93 | 94 | Et voilà? 
95 | 96 | After reading this guide and the mentioned files, you should be now able to understand the design logic behind the library ;) 97 | -------------------------------------------------------------------------------- /run_arms_human.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import argparse 3 | import os, time 4 | 5 | import gym 6 | import my_gym 7 | 8 | from stable_baselines3 import PPO 9 | from stable_baselines3.common.env_checker import check_env 10 | from stable_baselines3.common import make_vec_env 11 | from stable_baselines3.common.vec_env import DummyVecEnv 12 | from stable_baselines3.common.evaluation import evaluate_policy 13 | 14 | import torch as th 15 | import torch.nn as nn 16 | from interactive_policy import ArmsPolicy 17 | from partner_config import get_arms_human_partners 18 | from util import check_optimal, learn, load_model 19 | from util import adapt_task, adapt_partner_baseline, adapt_partner_modular, adapt_partner_scratch 20 | 21 | warnings.simplefilter(action='ignore', category=FutureWarning) 22 | 23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument('--run', type=int, default=0, help="Run ID. In case you want to run replicates") 25 | parser.add_argument('--netsz', type=int, default=30, help="Size of policy network") 26 | parser.add_argument('--latentz', type=int, default=30, help="Size of latent z dimension") 27 | 28 | parser.add_argument('--mreg', type=float, default=0.0, help="Marginal regularization.") 29 | parser.add_argument('--baseline', action='store_true', default=False, help="Baseline: no modular separation.") 30 | parser.add_argument('--nomain', action='store_true', default=False, help="Baseline: don't use main logits.") 31 | 32 | parser.add_argument('--timesteps', type=int, default=10000, help="Number of timesteps to train for") 33 | parser.add_argument('--testing', action='store_true', default=False, help="Testing.") 34 | 35 | parser.add_argument('--k', type=int, default=0, help="When fixedpartner=True, k is the index of the test partner") 36 | 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | def get_model_name_and_path(run, mreg=0.00): 41 | layout = [ 42 | ('run={:04d}', run), 43 | ('netsz={:03d}', args.netsz), 44 | ('mreg={:.2f}', mreg), 45 | ] 46 | 47 | m_name = '_'.join([t.format(v) for (t, v) in layout]) 48 | m_path = 'output/armshuman_' + m_name 49 | return m_name, m_path 50 | 51 | model_name, model_path = get_model_name_and_path(args.run, mreg=args.mreg) 52 | 53 | HP = { 54 | 'n_steps': 64, 55 | 'n_steps_testing': 16, 56 | 'batch_size': 16, 57 | 'n_epochs': 20, 58 | 'n_epochs_testing': 50, 59 | 'mreg': args.mreg, 60 | } 61 | 62 | setting, partner_type = "", "fixed" 63 | TRAIN_PARTNERS, TEST_PARTNERS = get_arms_human_partners(setting, partner_type) 64 | PARTNERS = [ TEST_PARTNERS[args.k % len(TEST_PARTNERS)] ] if args.testing else TRAIN_PARTNERS 65 | 66 | def main(): 67 | global PARTNERS 68 | env = gym.make('arms-human-v0') 69 | num_partners = len(PARTNERS) if PARTNERS is not None else 1 70 | 71 | print("model path: ", model_path) 72 | net_arch = [args.netsz,args.latentz] 73 | partner_net_arch = [args.netsz,args.netsz] 74 | policy_kwargs = dict(activation_fn=nn.ReLU, 75 | net_arch=[dict(vf=net_arch, pi=net_arch)], 76 | partner_net_arch=[dict(vf=partner_net_arch, pi=partner_net_arch)], 77 | num_partners=num_partners, 78 | baseline=args.baseline, 79 | nomain=args.nomain, 80 | ) 81 | 82 | def 
load_model_fn(partners, testing, try_load=True): 83 | return load_model(model_path=model_path, policy_class=ArmsPolicy, policy_kwargs=policy_kwargs, env=env, hp=HP, partners=partners, testing=testing, try_load=try_load) 84 | 85 | def learn_model_fn(model, timesteps, save, period): 86 | return learn(model, model_name=model_name, model_path=model_path, timesteps=timesteps, save=save, period=period, save_thresh=None) 87 | 88 | # TRAINING 89 | if not args.testing: 90 | print("#section Training") 91 | model = load_model_fn(partners=PARTNERS, testing=False) 92 | learn_model_fn(model, timesteps=args.timesteps, save=True, period=200) 93 | 94 | ts, period = 240, HP['n_steps_testing'] 95 | 96 | # TESTING 97 | if args.testing: 98 | if args.baseline: adapt_partner_baseline(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True) 99 | else: adapt_partner_modular(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True) 100 | adapt_partner_scratch(load_model_fn, learn_model_fn, partners=PARTNERS, timesteps=ts, period=period, do_optimal=True) 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /stable-baselines3/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler: 6 | __version__ = file_handler.read().strip() 7 | 8 | 9 | long_description = """ 10 | 11 | # Stable Baselines3 12 | 13 | Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). 14 | 15 | These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details. 16 | 17 | 18 | ## Links 19 | 20 | Repository: 21 | https://github.com/DLR-RM/stable-baselines3 22 | 23 | Medium article: 24 | https://medium.com/@araffin/df87c4b2fc82 25 | 26 | Documentation: 27 | https://stable-baselines3.readthedocs.io/en/master/ 28 | 29 | RL Baselines3 Zoo: 30 | https://github.com/DLR-RM/rl-baselines3-zoo 31 | 32 | ## Quick example 33 | 34 | Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym. 
35 | 36 | Here is a quick example of how to train and run PPO on a cartpole environment: 37 | 38 | ```python 39 | import gym 40 | 41 | from stable_baselines3 import PPO 42 | 43 | env = gym.make('CartPole-v1') 44 | 45 | model = PPO('MlpPolicy', env, verbose=1) 46 | model.learn(total_timesteps=10000) 47 | 48 | obs = env.reset() 49 | for i in range(1000): 50 | action, _states = model.predict(obs, deterministic=True) 51 | obs, reward, done, info = env.step(action) 52 | env.render() 53 | if done: 54 | obs = env.reset() 55 | ``` 56 | 57 | Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): 58 | 59 | ```python 60 | from stable_baselines3 import PPO 61 | 62 | model = PPO('MlpPolicy', 'CartPole-v1').learn(10000) 63 | ``` 64 | 65 | """ # noqa:E501 66 | 67 | 68 | setup( 69 | name="stable_baselines3", 70 | packages=[package for package in find_packages() if package.startswith("stable_baselines3")], 71 | package_data={"stable_baselines3": ["py.typed", "version.txt"]}, 72 | install_requires=[ 73 | "gym>=0.17", 74 | "numpy", 75 | "torch>=1.4.0", 76 | # For saving models 77 | "cloudpickle", 78 | # For reading logs 79 | "pandas", 80 | # Plotting learning curves 81 | "matplotlib", 82 | ], 83 | extras_require={ 84 | "tests": [ 85 | # Run tests and coverage 86 | "pytest", 87 | "pytest-cov", 88 | "pytest-env", 89 | "pytest-xdist", 90 | # Type check 91 | "pytype", 92 | # Lint code 93 | "flake8>=3.8", 94 | # Sort imports 95 | "isort>=5.0", 96 | # Reformat 97 | "black", 98 | ], 99 | "docs": [ 100 | "sphinx", 101 | "sphinx-autobuild", 102 | "sphinx-rtd-theme", 103 | # For spelling 104 | "sphinxcontrib.spelling", 105 | # Type hints support 106 | # 'sphinx-autodoc-typehints' 107 | ], 108 | "extra": [ 109 | # For render 110 | "opencv-python", 111 | # For atari games, 112 | "atari_py~=0.2.0", 113 | "pillow", 114 | # Tensorboard support 115 | "tensorboard", 116 | # Checking memory taken by replay buffer 117 | "psutil", 118 | ], 119 | }, 120 | description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", 121 | author="Antonin Raffin", 122 | url="https://github.com/DLR-RM/stable-baselines3", 123 | author_email="antonin.raffin@dlr.de", 124 | keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " 125 | "gym openai stable baselines toolbox python data-science", 126 | license="MIT", 127 | long_description=long_description, 128 | long_description_content_type="text/markdown", 129 | version=__version__, 130 | ) 131 | 132 | # python setup.py sdist 133 | # python setup.py bdist_wheel 134 | # twine upload --repository-url https://test.pypi.org/legacy/ dist/* 135 | # twine upload dist/* 136 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/results_plotter.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | # import matplotlib 7 | # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode 8 | from matplotlib import pyplot as plt 9 | 10 | from stable_baselines3.common.monitor import load_results 11 | 12 | X_TIMESTEPS = "timesteps" 13 | X_EPISODES = "episodes" 14 | X_WALLTIME = "walltime_hrs" 15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, 
X_WALLTIME] 16 | EPISODES_WINDOW = 100 17 | 18 | 19 | def rolling_window(array: np.ndarray, window: int) -> np.ndarray: 20 | """ 21 | Apply a rolling window to a np.ndarray 22 | 23 | :param array: (np.ndarray) the input Array 24 | :param window: (int) length of the rolling window 25 | :return: (np.ndarray) rolling window on the input array 26 | """ 27 | shape = array.shape[:-1] + (array.shape[-1] - window + 1, window) 28 | strides = array.strides + (array.strides[-1],) 29 | return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides) 30 | 31 | 32 | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]: 33 | """ 34 | Apply a function to the rolling window of 2 arrays 35 | 36 | :param var_1: (np.ndarray) variable 1 37 | :param var_2: (np.ndarray) variable 2 38 | :param window: (int) length of the rolling window 39 | :param func: (numpy function) function to apply on the rolling window on variable 2 (such as np.mean) 40 | :return: (Tuple[np.ndarray, np.ndarray]) the rolling output with applied function 41 | """ 42 | var_2_window = rolling_window(var_2, window) 43 | function_on_var2 = func(var_2_window, axis=-1) 44 | return var_1[window - 1 :], function_on_var2 45 | 46 | 47 | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: 48 | """ 49 | Decompose a data frame variable to x ans ys 50 | 51 | :param data_frame: (pd.DataFrame) the input data 52 | :param x_axis: (str) the axis for the x and y output 53 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 54 | :return: (Tuple[np.ndarray, np.ndarray]) the x and y output 55 | """ 56 | if x_axis == X_TIMESTEPS: 57 | x_var = np.cumsum(data_frame.l.values) 58 | y_var = data_frame.r.values 59 | elif x_axis == X_EPISODES: 60 | x_var = np.arange(len(data_frame)) 61 | y_var = data_frame.r.values 62 | elif x_axis == X_WALLTIME: 63 | # Convert to hours 64 | x_var = data_frame.t.values / 3600.0 65 | y_var = data_frame.r.values 66 | else: 67 | raise NotImplementedError 68 | return x_var, y_var 69 | 70 | 71 | def plot_curves( 72 | xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2) 73 | ) -> None: 74 | """ 75 | plot the curves 76 | 77 | :param xy_list: (List[Tuple[np.ndarray, np.ndarray]]) the x and y coordinates to plot 78 | :param x_axis: (str) the axis for the x and y output 79 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 80 | :param title: (str) the title of the plot 81 | :param figsize: (Tuple[int, int]) Size of the figure (width, height) 82 | """ 83 | 84 | plt.figure(title, figsize=figsize) 85 | max_x = max(xy[0][-1] for xy in xy_list) 86 | min_x = 0 87 | for (i, (x, y)) in enumerate(xy_list): 88 | plt.scatter(x, y, s=2) 89 | # Do not plot the smoothed curve at all if the timeseries is shorter than window size. 90 | if x.shape[0] >= EPISODES_WINDOW: 91 | # Compute and plot rolling mean with window of size EPISODE_WINDOW 92 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) 93 | plt.plot(x, y_mean) 94 | plt.xlim(min_x, max_x) 95 | plt.title(title) 96 | plt.xlabel(x_axis) 97 | plt.ylabel("Episode Rewards") 98 | plt.tight_layout() 99 | 100 | 101 | def plot_results( 102 | dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2) 103 | ) -> None: 104 | """ 105 | Plot the results using csv files from ``Monitor`` wrapper. 
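    A typical call, assuming ``Monitor`` csv files were written under a ``./logs/`` folder (the path and title here are only examples), is
    ``plot_results(["./logs/"], None, X_TIMESTEPS, "MyTask")``, followed by ``plt.show()``.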
106 | 107 | :param dirs: ([str]) the save location of the results to plot 108 | :param num_timesteps: (int or None) only plot the points below this value 109 | :param x_axis: (str) the axis for the x and y output 110 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 111 | :param task_name: (str) the title of the task to plot 112 | :param figsize: (Tuple[int, int]) Size of the figure (width, height) 113 | """ 114 | 115 | data_frames = [] 116 | for folder in dirs: 117 | data_frame = load_results(folder) 118 | if num_timesteps is not None: 119 | data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps] 120 | data_frames.append(data_frame) 121 | xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames] 122 | plot_curves(xy_list, x_axis, task_name, figsize) 123 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Sequence 4 | 5 | import numpy as np 6 | 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv 8 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 9 | 10 | 11 | class DummyVecEnv(VecEnv): 12 | """ 13 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 14 | Python process. This is useful for computationally simple environments such as ``CartPole-v1``, 15 | as the overhead of multiprocessing or multithreading outweighs the environment computation time. 16 | This can also be used for RL methods that 17 | require a vectorized environment but where you only want a single environment to train with.
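    A minimal usage sketch (``CartPole-v1`` is only an example id, assuming it is registered in Gym):

    .. code-block:: python

        import gym

        env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
        obs = env.reset()  # batched observations: first dimension is the number of envs (here 1)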
18 | 19 | :param env_fns: ([Gym Environment]) the list of environments to vectorize 20 | """ 21 | 22 | def __init__(self, env_fns): 23 | self.envs = [fn() for fn in env_fns] 24 | env = self.envs[0] 25 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 26 | obs_space = env.observation_space 27 | self.keys, shapes, dtypes = obs_space_info(obs_space) 28 | 29 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) 30 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) 31 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 32 | self.buf_infos = [{} for _ in range(self.num_envs)] 33 | self.actions = None 34 | self.metadata = env.metadata 35 | 36 | def step_async(self, actions): 37 | self.actions = actions 38 | 39 | def step_wait(self): 40 | for env_idx in range(self.num_envs): 41 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 42 | self.actions[env_idx] 43 | ) 44 | if self.buf_dones[env_idx]: 45 | # save final observation where user can get it, then reset 46 | self.buf_infos[env_idx]["terminal_observation"] = obs 47 | obs = self.envs[env_idx].reset() 48 | self._save_obs(env_idx, obs) 49 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 50 | 51 | def seed(self, seed=None): 52 | seeds = list() 53 | for idx, env in enumerate(self.envs): 54 | seeds.append(env.seed(seed + idx)) 55 | return seeds 56 | 57 | def reset(self): 58 | for env_idx in range(self.num_envs): 59 | obs = self.envs[env_idx].reset() 60 | self._save_obs(env_idx, obs) 61 | return self._obs_from_buf() 62 | 63 | def close(self): 64 | for env in self.envs: 65 | env.close() 66 | 67 | def get_images(self) -> Sequence[np.ndarray]: 68 | return [env.render(mode="rgb_array") for env in self.envs] 69 | 70 | def render(self, mode: str = "human"): 71 | """ 72 | Gym environment rendering. If there are multiple environments then 73 | they are tiled together in one image via ``BaseVecEnv.render()``. 74 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the 75 | underlying environment. 76 | 77 | Therefore, some arguments such as ``mode`` will have values that are valid 78 | only when ``num_envs == 1``. 79 | 80 | :param mode: The rendering type. 
81 | """ 82 | if self.num_envs == 1: 83 | return self.envs[0].render(mode=mode) 84 | else: 85 | return super().render(mode=mode) 86 | 87 | def _save_obs(self, env_idx, obs): 88 | for key in self.keys: 89 | if key is None: 90 | self.buf_obs[key][env_idx] = obs 91 | else: 92 | self.buf_obs[key][env_idx] = obs[key] 93 | 94 | def _obs_from_buf(self): 95 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 96 | 97 | def get_attr(self, attr_name, indices=None): 98 | """Return attribute from vectorized environment (see base class).""" 99 | target_envs = self._get_target_envs(indices) 100 | return [getattr(env_i, attr_name) for env_i in target_envs] 101 | 102 | def set_attr(self, attr_name, value, indices=None): 103 | """Set attribute inside vectorized environments (see base class).""" 104 | target_envs = self._get_target_envs(indices) 105 | for env_i in target_envs: 106 | setattr(env_i, attr_name, value) 107 | 108 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 109 | """Call instance methods of vectorized environments.""" 110 | target_envs = self._get_target_envs(indices) 111 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 112 | 113 | def _get_target_envs(self, indices): 114 | indices = self._get_indices(indices) 115 | return [self.envs[i] for i in indices] 116 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/bit_flipping_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Optional, Union 3 | 4 | import numpy as np 5 | from gym import GoalEnv, spaces 6 | 7 | from stable_baselines3.common.type_aliases import GymStepReturn 8 | 9 | 10 | class BitFlippingEnv(GoalEnv): 11 | """ 12 | Simple bit flipping env, useful to test HER. 13 | The goal is to flip all the bits to get a vector of ones. 14 | In the continuous variant, if the ith action component has a value > 0, 15 | then the ith bit will be flipped. 
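    A short interaction sketch (illustrative only):

    .. code-block:: python

        env = BitFlippingEnv(n_bits=4)
        obs = env.reset()  # dict with "observation", "achieved_goal" and "desired_goal"
        obs, reward, done, info = env.step(env.action_space.sample())  # reward is 0.0 or -1.0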
16 | 17 | :param n_bits: (int) Number of bits to flip 18 | :param continuous: (bool) Whether to use the continuous actions version or not, 19 | by default, it uses the discrete one 20 | :param max_steps: (Optional[int]) Max number of steps, by default, equal to n_bits 21 | :param discrete_obs_space: (bool) Whether to use the discrete observation 22 | version or not, by default, it uses the MultiBinary one 23 | """ 24 | 25 | def __init__( 26 | self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False 27 | ): 28 | super(BitFlippingEnv, self).__init__() 29 | # The achieved goal is determined by the current state 30 | # here, it is a special where they are equal 31 | if discrete_obs_space: 32 | # In the discrete case, the agent act on the binary 33 | # representation of the observation 34 | self.observation_space = spaces.Dict( 35 | { 36 | "observation": spaces.Discrete(2 ** n_bits - 1), 37 | "achieved_goal": spaces.Discrete(2 ** n_bits - 1), 38 | "desired_goal": spaces.Discrete(2 ** n_bits - 1), 39 | } 40 | ) 41 | else: 42 | self.observation_space = spaces.Dict( 43 | { 44 | "observation": spaces.MultiBinary(n_bits), 45 | "achieved_goal": spaces.MultiBinary(n_bits), 46 | "desired_goal": spaces.MultiBinary(n_bits), 47 | } 48 | ) 49 | 50 | self.obs_space = spaces.MultiBinary(n_bits) 51 | 52 | if continuous: 53 | self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) 54 | else: 55 | self.action_space = spaces.Discrete(n_bits) 56 | self.continuous = continuous 57 | self.discrete_obs_space = discrete_obs_space 58 | self.state = None 59 | self.desired_goal = np.ones((n_bits,)) 60 | if max_steps is None: 61 | max_steps = n_bits 62 | self.max_steps = max_steps 63 | self.current_step = 0 64 | self.reset() 65 | 66 | def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: 67 | """ 68 | Convert to discrete space if needed. 69 | 70 | :param state: (np.ndarray) 71 | :return: (np.ndarray or int) 72 | """ 73 | if self.discrete_obs_space: 74 | # The internal state is the binary representation of the 75 | # observed one 76 | return int(sum([state[i] * 2 ** i for i in range(len(state))])) 77 | return state 78 | 79 | def _get_obs(self) -> OrderedDict: 80 | """ 81 | Helper to create the observation. 
82 | 83 | :return: (OrderedDict) 84 | """ 85 | return OrderedDict( 86 | [ 87 | ("observation", self.convert_if_needed(self.state.copy())), 88 | ("achieved_goal", self.convert_if_needed(self.state.copy())), 89 | ("desired_goal", self.convert_if_needed(self.desired_goal.copy())), 90 | ] 91 | ) 92 | 93 | def reset(self) -> OrderedDict: 94 | self.current_step = 0 95 | self.state = self.obs_space.sample() 96 | return self._get_obs() 97 | 98 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: 99 | if self.continuous: 100 | self.state[action > 0] = 1 - self.state[action > 0] 101 | else: 102 | self.state[action] = 1 - self.state[action] 103 | obs = self._get_obs() 104 | reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None) 105 | done = reward == 0 106 | self.current_step += 1 107 | # Episode terminate when we reached the goal or the max number of steps 108 | info = {"is_success": done} 109 | done = done or self.current_step >= self.max_steps 110 | return obs, reward, done, info 111 | 112 | def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> float: 113 | # Deceptive reward: it is positive only when the goal is achieved 114 | if self.discrete_obs_space: 115 | return 0.0 if achieved_goal == desired_goal else -1.0 116 | return 0.0 if (achieved_goal == desired_goal).all() else -1.0 117 | 118 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 119 | if mode == "rgb_array": 120 | return self.state.copy() 121 | print(self.state) 122 | 123 | def close(self) -> None: 124 | pass 125 | -------------------------------------------------------------------------------- /stable-baselines3/docs/guide/install.rst: -------------------------------------------------------------------------------- 1 | .. _install: 2 | 3 | Installation 4 | ============ 5 | 6 | Prerequisites 7 | ------------- 8 | 9 | Stable-Baselines3 requires python 3.6+. 10 | 11 | Windows 10 12 | ~~~~~~~~~~ 13 | 14 | We recommend using `Anaconda `_ for Windows users for easier installation of Python packages and required libraries. You need an environment with Python version 3.6 or above. 15 | 16 | For a quick start you can move straight to installing Stable-Baselines3 in the next step. 17 | 18 | .. note:: 19 | 20 | Trying to create Atari environments may result to vague errors related to missing DLL files and modules. This is an 21 | issue with atari-py package. `See this discussion for more information `_. 22 | 23 | 24 | Stable Release 25 | ~~~~~~~~~~~~~~ 26 | To install Stable Baselines3 with pip, execute: 27 | 28 | .. code-block:: bash 29 | 30 | pip install stable-baselines3[extra] 31 | 32 | This includes an optional dependencies like Tensorboard, OpenCV or ```atari-py``` to train on atari games. If you do not need those, you can use: 33 | 34 | .. code-block:: bash 35 | 36 | pip install stable-baselines3 37 | 38 | 39 | Bleeding-edge version 40 | --------------------- 41 | 42 | .. code-block:: bash 43 | 44 | pip install git+https://github.com/DLR-RM/stable-baselines3 45 | 46 | 47 | Development version 48 | ------------------- 49 | 50 | To contribute to Stable-Baselines3, with support for running tests and building the documentation. 51 | 52 | .. 
code-block:: bash 53 | 54 | git clone https://github.com/DLR-RM/stable-baselines3 && cd stable-baselines3 55 | pip install -e .[docs,tests,extra] 56 | 57 | 58 | Using Docker Images 59 | ------------------- 60 | 61 | If you are looking for docker images with stable-baselines already installed in it, 62 | we recommend using images from `RL Baselines3 Zoo `_. 63 | 64 | Otherwise, the following images contained all the dependencies for stable-baselines3 but not the stable-baselines3 package itself. 65 | They are made for development. 66 | 67 | Use Built Images 68 | ~~~~~~~~~~~~~~~~ 69 | 70 | GPU image (requires `nvidia-docker`_): 71 | 72 | .. code-block:: bash 73 | 74 | docker pull stablebaselines/stable-baselines3 75 | 76 | CPU only: 77 | 78 | .. code-block:: bash 79 | 80 | docker pull stablebaselines/stable-baselines3-cpu 81 | 82 | Build the Docker Images 83 | ~~~~~~~~~~~~~~~~~~~~~~~~ 84 | 85 | Build GPU image (with nvidia-docker): 86 | 87 | .. code-block:: bash 88 | 89 | make docker-gpu 90 | 91 | Build CPU image: 92 | 93 | .. code-block:: bash 94 | 95 | make docker-cpu 96 | 97 | Note: if you are using a proxy, you need to pass extra params during 98 | build and do some `tweaks`_: 99 | 100 | .. code-block:: bash 101 | 102 | --network=host --build-arg HTTP_PROXY=http://your.proxy.fr:8080/ --build-arg http_proxy=http://your.proxy.fr:8080/ --build-arg HTTPS_PROXY=https://your.proxy.fr:8080/ --build-arg https_proxy=https://your.proxy.fr:8080/ 103 | 104 | Run the images (CPU/GPU) 105 | ~~~~~~~~~~~~~~~~~~~~~~~~ 106 | 107 | Run the nvidia-docker GPU image 108 | 109 | .. code-block:: bash 110 | 111 | docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' 112 | 113 | Or, with the shell file: 114 | 115 | .. code-block:: bash 116 | 117 | ./scripts/run_docker_gpu.sh pytest tests/ 118 | 119 | Run the docker CPU image 120 | 121 | .. code-block:: bash 122 | 123 | docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' 124 | 125 | Or, with the shell file: 126 | 127 | .. code-block:: bash 128 | 129 | ./scripts/run_docker_cpu.sh pytest tests/ 130 | 131 | Explanation of the docker command: 132 | 133 | - ``docker run -it`` create an instance of an image (=container), and 134 | run it interactively (so ctrl+c will work) 135 | - ``--rm`` option means to remove the container once it exits/stops 136 | (otherwise, you will have to use ``docker rm``) 137 | - ``--network host`` don't use network isolation, this allow to use 138 | tensorboard/visdom on host machine 139 | - ``--ipc=host`` Use the host system’s IPC namespace. IPC (POSIX/SysV IPC) namespace provides 140 | separation of named shared memory segments, semaphores and message 141 | queues. 142 | - ``--name test`` give explicitly the name ``test`` to the container, 143 | otherwise it will be assigned a random name 144 | - ``--mount src=...`` give access of the local directory (``pwd`` 145 | command) to the container (it will be map to ``/root/code/stable-baselines``), so 146 | all the logs created in the container in this folder will be kept 147 | - ``bash -c '...'`` Run command inside the docker image, here run the tests 148 | (``pytest tests/``) 149 | 150 | .. 
_nvidia-docker: https://github.com/NVIDIA/nvidia-docker 151 | .. _tweaks: https://stackoverflow.com/questions/23111631/cannot-download-docker-images-behind-a-proxy 152 | -------------------------------------------------------------------------------- /stable-baselines3/tests/test_envs.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.bit_flipping_env import BitFlippingEnv 7 | from stable_baselines3.common.env_checker import check_env 8 | from stable_baselines3.common.identity_env import ( 9 | FakeImageEnv, 10 | IdentityEnv, 11 | IdentityEnvBox, 12 | IdentityEnvMultiBinary, 13 | IdentityEnvMultiDiscrete, 14 | ) 15 | 16 | ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv] 17 | 18 | 19 | @pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"]) 20 | def test_env(env_id): 21 | """ 22 | Check that environmnent integrated in Gym pass the test. 23 | 24 | :param env_id: (str) 25 | """ 26 | env = gym.make(env_id) 27 | with pytest.warns(None) as record: 28 | check_env(env) 29 | 30 | # Pendulum-v0 will produce a warning because the action space is 31 | # in [-2, 2] and not [-1, 1] 32 | if env_id == "Pendulum-v0": 33 | assert len(record) == 1 34 | else: 35 | # The other environments must pass without warning 36 | assert len(record) == 0 37 | 38 | 39 | @pytest.mark.parametrize("env_class", ENV_CLASSES) 40 | def test_custom_envs(env_class): 41 | env = env_class() 42 | check_env(env) 43 | 44 | 45 | def test_high_dimension_action_space(): 46 | """ 47 | Test for continuous action space 48 | with more than one action. 49 | """ 50 | env = FakeImageEnv() 51 | # Patch the action space 52 | env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32) 53 | 54 | # Patch to avoid error 55 | def patched_step(_action): 56 | return env.observation_space.sample(), 0.0, False, {} 57 | 58 | env.step = patched_step 59 | check_env(env) 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "new_obs_space", 64 | [ 65 | # Small image 66 | spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), 67 | # Range not in [0, 255] 68 | spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8), 69 | # Wrong dtype 70 | spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32), 71 | # Not an image, it should be a 1D vector 72 | spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), 73 | # Tuple space is not supported by SB 74 | spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), 75 | # Dict space is not supported by SB when env is not a GoalEnv 76 | spaces.Dict({"position": spaces.Discrete(5)}), 77 | ], 78 | ) 79 | def test_non_default_spaces(new_obs_space): 80 | env = FakeImageEnv() 81 | env.observation_space = new_obs_space 82 | # Patch methods to avoid errors 83 | env.reset = new_obs_space.sample 84 | 85 | def patched_step(_action): 86 | return new_obs_space.sample(), 0.0, False, {} 87 | 88 | env.step = patched_step 89 | with pytest.warns(UserWarning): 90 | check_env(env) 91 | 92 | 93 | def check_reset_assert_error(env, new_reset_return): 94 | """ 95 | Helper to check that the error is caught. 
96 | :param env: (gym.Env) 97 | :param new_reset_return: (Any) 98 | """ 99 | 100 | def wrong_reset(): 101 | return new_reset_return 102 | 103 | # Patch the reset method with a wrong one 104 | env.reset = wrong_reset 105 | with pytest.raises(AssertionError): 106 | check_env(env) 107 | 108 | 109 | def test_common_failures_reset(): 110 | """ 111 | Test that common failure cases of the `reset_method` are caught 112 | """ 113 | env = IdentityEnvBox() 114 | # Return an observation that does not match the observation_space 115 | check_reset_assert_error(env, np.ones((3,))) 116 | # The observation is not a numpy array 117 | check_reset_assert_error(env, 1) 118 | 119 | # Return not only the observation 120 | check_reset_assert_error(env, (env.observation_space.sample(), False)) 121 | 122 | 123 | def check_step_assert_error(env, new_step_return=()): 124 | """ 125 | Helper to check that the error is caught. 126 | :param env: (gym.Env) 127 | :param new_step_return: (tuple) 128 | """ 129 | 130 | def wrong_step(_action): 131 | return new_step_return 132 | 133 | # Patch the step method with a wrong one 134 | env.step = wrong_step 135 | with pytest.raises(AssertionError): 136 | check_env(env) 137 | 138 | 139 | def test_common_failures_step(): 140 | """ 141 | Test that common failure cases of the `step` method are caught 142 | """ 143 | env = IdentityEnvBox() 144 | 145 | # Wrong shape for the observation 146 | check_step_assert_error(env, (np.ones((4,)), 1.0, False, {})) 147 | # Obs is not a numpy array 148 | check_step_assert_error(env, (1, 1.0, False, {})) 149 | 150 | # Return a wrong reward 151 | check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {})) 152 | 153 | # Info dict is not returned 154 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, False)) 155 | 156 | # Done is not a boolean 157 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) 158 | check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) 159 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/preprocessing.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch as th 5 | from gym import spaces 6 | from torch.nn import functional as F 7 | 8 | 9 | def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool: 10 | """ 11 | Check if a observation space has the shape, limits and dtype 12 | of a valid image. 13 | The check is conservative, so that it returns False 14 | if there is a doubt. 15 | 16 | Valid images: RGB, RGBD, GrayScale with values in [0, 255] 17 | 18 | :param observation_space: (spaces.Space) 19 | :param channels_last: (bool) 20 | :param check_channels: (bool) Whether to do or not the check for the number of channels. 21 | e.g., with frame-stacking, the observation space may have more channels than expected. 
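        For example, ``spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)`` is treated as an image space,
        while a float ``Box`` or a 1D vector is not.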
22 | :return: (bool) 23 | """ 24 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: 25 | # Check the type 26 | if observation_space.dtype != np.uint8: 27 | return False 28 | 29 | # Check the value range 30 | if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): 31 | return False 32 | 33 | # Skip channels check 34 | if not check_channels: 35 | return True 36 | # Check the number of channels 37 | if channels_last: 38 | n_channels = observation_space.shape[-1] 39 | else: 40 | n_channels = observation_space.shape[0] 41 | # RGB, RGBD, GrayScale 42 | return n_channels in [1, 3, 4] 43 | return False 44 | 45 | 46 | def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor: 47 | """ 48 | Preprocess observation to be to a neural network. 49 | For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) 50 | For discrete observations, it create a one hot vector. 51 | 52 | :param obs: (th.Tensor) Observation 53 | :param observation_space: (spaces.Space) 54 | :param normalize_images: (bool) Whether to normalize images or not 55 | (True by default) 56 | :return: (th.Tensor) 57 | """ 58 | if isinstance(observation_space, spaces.Box): 59 | if is_image_space(observation_space) and normalize_images: 60 | return obs.float() / 255.0 61 | return obs.float() 62 | 63 | elif isinstance(observation_space, spaces.Discrete): 64 | # One hot encoding and convert to float to avoid errors 65 | return F.one_hot(obs.long(), num_classes=observation_space.n).float() 66 | 67 | elif isinstance(observation_space, spaces.MultiDiscrete): 68 | # Tensor concatenation of one hot encodings of each Categorical sub-space 69 | return th.cat( 70 | [ 71 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() 72 | for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1)) 73 | ], 74 | dim=-1, 75 | ).view(obs.shape[0], sum(observation_space.nvec)) 76 | 77 | elif isinstance(observation_space, spaces.MultiBinary): 78 | return obs.float() 79 | 80 | else: 81 | raise NotImplementedError() 82 | 83 | 84 | def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: 85 | """ 86 | Get the shape of the observation (useful for the buffers). 87 | 88 | :param observation_space: (spaces.Space) 89 | :return: (Tuple[int, ...]) 90 | """ 91 | if isinstance(observation_space, spaces.Box): 92 | return observation_space.shape 93 | elif isinstance(observation_space, spaces.Discrete): 94 | # Observation is an int 95 | return (1,) 96 | elif isinstance(observation_space, spaces.MultiDiscrete): 97 | # Number of discrete features 98 | return (int(len(observation_space.nvec)),) 99 | elif isinstance(observation_space, spaces.MultiBinary): 100 | # Number of binary features 101 | return (int(observation_space.n),) 102 | else: 103 | raise NotImplementedError() 104 | 105 | 106 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int: 107 | """ 108 | Get the dimension of the observation space when flattened. 109 | It does not apply to image observation space. 110 | 111 | :param observation_space: (spaces.Space) 112 | :return: (int) 113 | """ 114 | # See issue https://github.com/openai/gym/issues/1915 115 | # it may be a problem for Dict/Tuple spaces too... 
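    # For MultiDiscrete, the flattened (one-hot) dimension is the sum of the sub-space sizes,
    # e.g. MultiDiscrete([2, 3]) -> 2 + 3 = 5; other spaces defer to gym's flatdim()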
116 | if isinstance(observation_space, spaces.MultiDiscrete): 117 | return sum(observation_space.nvec) 118 | else: 119 | # Use Gym internal method 120 | return spaces.utils.flatdim(observation_space) 121 | 122 | 123 | def get_action_dim(action_space: spaces.Space) -> int: 124 | """ 125 | Get the dimension of the action space. 126 | 127 | :param action_space: (spaces.Space) 128 | :return: (int) 129 | """ 130 | if isinstance(action_space, spaces.Box): 131 | return int(np.prod(action_space.shape)) 132 | elif isinstance(action_space, spaces.Discrete): 133 | # Action is an int 134 | return 1 135 | elif isinstance(action_space, spaces.MultiDiscrete): 136 | # Number of discrete actions 137 | return int(len(action_space.nvec)) 138 | elif isinstance(action_space, spaces.MultiBinary): 139 | # Number of binary actions 140 | return int(action_space.n) 141 | else: 142 | raise NotImplementedError() 143 | -------------------------------------------------------------------------------- /stable-baselines3/stable_baselines3/common/identity_env.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | from gym import Env, Space 5 | from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete 6 | 7 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn 8 | 9 | 10 | class IdentityEnv(Env): 11 | def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100): 12 | """ 13 | Identity environment for testing purposes 14 | 15 | :param dim: the size of the action and observation dimension you want 16 | to learn. Provide at most one of ``dim`` and ``space``. If both are 17 | None, then initialization proceeds with ``dim=1`` and ``space=None``. 18 | :param space: the action and observation space. Provide at most one of 19 | ``dim`` and ``space``. 20 | :param ep_length: the length of each episode in timesteps 21 | """ 22 | if space is None: 23 | if dim is None: 24 | dim = 1 25 | space = Discrete(dim) 26 | else: 27 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed" 28 | 29 | self.action_space = self.observation_space = space 30 | self.ep_length = ep_length 31 | self.current_step = 0 32 | self.num_resets = -1 # Becomes 0 after __init__ exits. 
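        # reset() below draws the first state via _choose_next_state() and bumps num_resets to 0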
33 | self.reset() 34 | 35 | def reset(self) -> GymObs: 36 | self.current_step = 0 37 | self.num_resets += 1 38 | self._choose_next_state() 39 | return self.state 40 | 41 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: 42 | reward = self._get_reward(action) 43 | self._choose_next_state() 44 | self.current_step += 1 45 | done = self.current_step >= self.ep_length 46 | return self.state, reward, done, {} 47 | 48 | def _choose_next_state(self) -> None: 49 | self.state = self.action_space.sample() 50 | 51 | def _get_reward(self, action: Union[int, np.ndarray]) -> float: 52 | return 1.0 if np.all(self.state == action) else 0.0 53 | 54 | def render(self, mode: str = "human") -> None: 55 | pass 56 | 57 | 58 | class IdentityEnvBox(IdentityEnv): 59 | def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100): 60 | """ 61 | Identity environment for testing purposes 62 | 63 | :param low: (float) the lower bound of the box dim 64 | :param high: (float) the upper bound of the box dim 65 | :param eps: (float) the epsilon bound for correct value 66 | :param ep_length: (int) the length of each episode in timesteps 67 | """ 68 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32) 69 | super().__init__(ep_length=ep_length, space=space) 70 | self.eps = eps 71 | 72 | def step(self, action: np.ndarray) -> GymStepReturn: 73 | reward = self._get_reward(action) 74 | self._choose_next_state() 75 | self.current_step += 1 76 | done = self.current_step >= self.ep_length 77 | return self.state, reward, done, {} 78 | 79 | def _get_reward(self, action: np.ndarray) -> float: 80 | return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 81 | 82 | 83 | class IdentityEnvMultiDiscrete(IdentityEnv): 84 | def __init__(self, dim: int = 1, ep_length: int = 100): 85 | """ 86 | Identity environment for testing purposes 87 | 88 | :param dim: (int) the size of the dimensions you want to learn 89 | :param ep_length: (int) the length of each episode in timesteps 90 | """ 91 | space = MultiDiscrete([dim, dim]) 92 | super().__init__(ep_length=ep_length, space=space) 93 | 94 | 95 | class IdentityEnvMultiBinary(IdentityEnv): 96 | def __init__(self, dim: int = 1, ep_length: int = 100): 97 | """ 98 | Identity environment for testing purposes 99 | 100 | :param dim: (int) the size of the dimensions you want to learn 101 | :param ep_length: (int) the length of each episode in timesteps 102 | """ 103 | space = MultiBinary(dim) 104 | super().__init__(ep_length=ep_length, space=space) 105 | 106 | 107 | class FakeImageEnv(Env): 108 | """ 109 | Fake image environment for testing purposes, it mimics Atari games. 
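    With the defaults below, observations are 84x84 single-channel ``uint8`` images and the action space is ``Discrete(6)``,
    e.g. ``obs = FakeImageEnv().reset()  # shape (84, 84, 1)``.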
110 | 111 | :param action_dim: (int) Number of discrete actions 112 | :param screen_height: (int) Height of the image 113 | :param screen_width: (int) Width of the image 114 | :param n_channels: (int) Number of color channels 115 | :param discrete: (bool) 116 | """ 117 | 118 | def __init__( 119 | self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True 120 | ): 121 | 122 | self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8) 123 | if discrete: 124 | self.action_space = Discrete(action_dim) 125 | else: 126 | self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32) 127 | self.ep_length = 10 128 | self.current_step = 0 129 | 130 | def reset(self) -> np.ndarray: 131 | self.current_step = 0 132 | return self.observation_space.sample() 133 | 134 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: 135 | reward = 0.0 136 | self.current_step += 1 137 | done = self.current_step >= self.ep_length 138 | return self.observation_space.sample(), reward, done, {} 139 | 140 | def render(self, mode: str = "human") -> None: 141 | pass 142 | --------------------------------------------------------------------------------
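A minimal sanity-check sketch (not a file in this repo) showing how the helper environments above are typically exercised, assuming ``stable-baselines3`` is installed:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.identity_env import IdentityEnv

env = IdentityEnv(dim=2)           # Discrete(2) observation and action spaces
check_env(env)                     # raises if the Gym interface is violated
model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=2048)  # the identity task should be learned quickly
```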