├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bsuite ├── __init__.py ├── _metadata.py ├── analysis │ └── results.ipynb ├── baselines │ ├── README.md │ ├── __init__.py │ ├── base.py │ ├── experiment.py │ ├── jax │ │ ├── __init__.py │ │ ├── actor_critic │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ ├── actor_critic_rnn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ ├── boot_dqn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ └── dqn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ ├── random │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── run.py │ │ └── run_test.py │ ├── tf │ │ ├── __init__.py │ │ ├── actor_critic │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ ├── actor_critic_rnn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ ├── boot_dqn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ │ └── dqn │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── run.py │ │ │ └── run_test.py │ ├── third_party │ │ ├── __init__.py │ │ ├── dopamine_dqn │ │ │ ├── __init__.py │ │ │ └── run.py │ │ ├── openai_dqn │ │ │ ├── __init__.py │ │ │ └── run.py │ │ └── openai_ppo │ │ │ ├── __init__.py │ │ │ └── run.py │ └── utils │ │ ├── __init__.py │ │ ├── pool.py │ │ ├── replay.py │ │ ├── replay_test.py │ │ ├── sequence.py │ │ └── sequence_test.py ├── bsuite.py ├── environments │ ├── README.md │ ├── __init__.py │ ├── bandit.py │ ├── bandit_test.py │ ├── base.py │ ├── cartpole.py │ ├── cartpole_test.py │ ├── catch.py │ ├── catch_test.py │ ├── deep_sea.py │ ├── deep_sea_test.py │ ├── discounting_chain.py │ ├── discounting_chain_test.py │ ├── memory_chain.py │ ├── memory_chain_test.py │ ├── mnist.py │ ├── mnist_test.py │ ├── mountain_car.py │ ├── mountain_car_test.py │ ├── 
umbrella_chain.py │ └── umbrella_chain_test.py ├── experiments │ ├── README.md │ ├── __init__.py │ ├── bandit │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── bandit.py │ │ └── sweep.py │ ├── bandit_noise │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── bandit_noise.py │ │ ├── bandit_noise_test.py │ │ └── sweep.py │ ├── bandit_scale │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── bandit_scale.py │ │ ├── bandit_scale_test.py │ │ └── sweep.py │ ├── cartpole │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── cartpole.py │ │ └── sweep.py │ ├── cartpole_noise │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── cartpole_noise.py │ │ ├── cartpole_noise_test.py │ │ └── sweep.py │ ├── cartpole_scale │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── cartpole_scale.py │ │ ├── cartpole_scale_test.py │ │ └── sweep.py │ ├── cartpole_swingup │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── cartpole_swingup.py │ │ ├── cartpole_swingup_test.py │ │ └── sweep.py │ ├── catch │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── catch.py │ │ └── sweep.py │ ├── catch_noise │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── catch_noise.py │ │ ├── catch_noise_test.py │ │ └── sweep.py │ ├── catch_scale │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── catch_scale.py │ │ ├── catch_scale_test.py │ │ └── sweep.py │ ├── deep_sea │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── deep_sea.py │ │ └── sweep.py │ ├── deep_sea_stochastic │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── deep_sea_stochastic.py │ │ ├── deep_sea_stochastic_test.py │ │ └── sweep.py │ ├── discounting_chain │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── discounting_chain.py │ │ └── sweep.py │ ├── memory_len │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── memory_len.py │ │ └── sweep.py │ ├── memory_size │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── memory_size.py │ │ └── sweep.py │ ├── mnist │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mnist.py │ │ └── sweep.py │ ├── mnist_noise │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mnist_noise.py │ 
│ ├── mnist_noise_test.py │ │ └── sweep.py │ ├── mnist_scale │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mnist_scale.py │ │ ├── mnist_scale_test.py │ │ └── sweep.py │ ├── mountain_car │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mountain_car.py │ │ └── sweep.py │ ├── mountain_car_noise │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mountain_car_noise.py │ │ ├── mountain_car_noise_test.py │ │ └── sweep.py │ ├── mountain_car_scale │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── mountain_car_scale.py │ │ ├── mountain_car_scale_test.py │ │ └── sweep.py │ ├── summary_analysis.py │ ├── summary_analysis_test.py │ ├── umbrella_distract │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── sweep.py │ │ └── umbrella_distract.py │ └── umbrella_length │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── sweep.py │ │ └── umbrella_length.py ├── logging │ ├── __init__.py │ ├── base.py │ ├── csv_load.py │ ├── csv_load_test.py │ ├── csv_logging.py │ ├── logging_utils.py │ ├── sqlite_load.py │ ├── sqlite_load_test.py │ ├── sqlite_logging.py │ ├── sqlite_logging_test.py │ └── terminal_logging.py ├── sweep.py ├── sweep_test.py ├── tests │ ├── __init__.py │ ├── environments_test.py │ └── sweep_test.py └── utils │ ├── __init__.py │ ├── datasets.py │ ├── gym_wrapper.py │ ├── gym_wrapper_test.py │ ├── plotting.py │ ├── smoothers.py │ ├── wrappers.py │ └── wrappers_test.py ├── reports ├── README.md ├── bsuite_appendix.tex ├── bsuite_preamble.tex ├── iclr_2019 │ ├── fancyhdr.sty │ ├── iclr2019_conference.bib │ ├── iclr2019_conference.bst │ ├── iclr2019_conference.sty │ ├── iclr2019_conference.tex │ ├── images │ │ ├── bar_plot.png │ │ └── radar_plot.png │ ├── math_commands.tex │ └── natbib.sty ├── icml_2019 │ ├── algorithm.sty │ ├── algorithmic.sty │ ├── example_paper.bib │ ├── example_paper.tex │ ├── fancyhdr.sty │ ├── icml2019.bst │ ├── icml2019.sty │ ├── icml_numpapers.eps │ └── images │ │ ├── bar_plot.png │ │ └── radar_plot.png ├── neurips_2019 │ ├── images │ │ ├── bar_plot.png │ │ └── 
radar_plot.png │ ├── neurips_2019.sty │ ├── neurips_2019.tex │ └── references.bib └── standalone │ ├── images │ ├── bar_plot.png │ └── radar_plot.png │ ├── references.bib │ └── standalone.tex ├── run_on_gcp.sh ├── setup.py └── test.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test-ubuntu: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [3.6, 3.7] 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | pip install --upgrade pip setuptools 20 | pip install . 21 | pip install .[baselines_jax] 22 | pip install .[baselines] 23 | pip install .[testing] 24 | - name: Check types with pytype 25 | run: | 26 | pytype -j "$(grep -c ^processor /proc/cpuinfo)" bsuite 27 | - name: Test with pytest 28 | run: | 29 | pytest -n "$(grep -c ^processor /proc/cpuinfo)" bsuite 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | pip install --upgrade pip setuptools twine 19 | - name: Build and publish 20 | env: 21 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 22 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 23 | run: | 24 | python setup.py sdist 25 | twine upload dist/* 26 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Virtual environment files. 2 | bin/ 3 | lib/ 4 | lib64/ 5 | share/ 6 | pyvenv.cfg 7 | 8 | # IDE files. 9 | .vscode 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | -------------------------------------------------------------------------------- /bsuite/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# Re-export the loader functions defined in the `bsuite.bsuite` module so that
# they are reachable directly from the package, e.g. `bsuite.load_from_id`.
load = _bsuite.load
load_from_id = _bsuite.load_from_id
load_and_record = _bsuite.load_and_record
load_and_record_to_sqlite = _bsuite.load_and_record_to_sqlite
load_and_record_to_csv = _bsuite.load_and_record_to_csv
17 | 18 | This is kept in a separate module so that it can be imported from setup.py, at 19 | a time when bsuite's dependencies may not have been installed yet. 20 | """ 21 | 22 | __version__ = '0.3.5' 23 | -------------------------------------------------------------------------------- /bsuite/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/base.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class Agent(abc.ABC):
  """Abstract agent interface: a way to pick actions plus a learning rule.

  Concrete agents must implement both `select_action` and `update`; the class
  cannot be instantiated until they do.
  """

  @abc.abstractmethod
  def select_action(self, timestep: dm_env.TimeStep) -> Action:
    """Samples an action from the agent's policy for the given timestep.

    Args:
      timestep: The most recent timestep observed from the environment.

    Returns:
      The chosen (discrete) action.
    """

  @abc.abstractmethod
  def update(self,
             timestep: dm_env.TimeStep,
             action: Action,
             new_timestep: dm_env.TimeStep) -> None:
    """Updates the agent using a single transition.

    Args:
      timestep: The timestep the action was selected from.
      action: The action that was taken.
      new_timestep: The timestep that resulted from taking `action`.
    """
def _run_episode(agent: base.Agent, environment: dm_env.Environment) -> None:
  """Plays one episode, feeding every transition back to the agent."""
  current = environment.reset()
  while True:
    if current.last():
      return

    # Act, step the environment, then let the agent learn from the transition.
    chosen_action = agent.select_action(current)
    following = environment.step(chosen_action)
    agent.update(current, chosen_action, following)

    # The new timestep becomes the starting point of the next transition.
    current = following


def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
  """Trains an agent on an environment for a fixed number of episodes.

  For bsuite environments logging is handled internally by the environment,
  so nothing is recorded here unless `verbose` is set.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """
  if verbose:
    # Wrap the environment so that results are also printed to the terminal.
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  for _ in range(num_episodes):
    _run_episode(agent, environment)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/jax/actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """A simple actor-critic implementation in JAX.""" 17 | 18 | from bsuite.baselines.jax.actor_critic.agent import ActorCritic 19 | from bsuite.baselines.jax.actor_critic.agent import default_agent 20 | -------------------------------------------------------------------------------- /bsuite/baselines/jax/actor_critic/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test for the JAX actor-critic baseline."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Build the environment and read its specs in the order the agent
    # factory expects them (observation first, then action).
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()
    agent = actor_critic.default_agent(observation_spec, action_spec)

    # Only checks that training runs without raising; nothing is asserted.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """A simple actor-critic implementation in JAX.""" 17 | 18 | from bsuite.baselines.jax.actor_critic_rnn.agent import ActorCriticRNN 19 | from bsuite.baselines.jax.actor_critic_rnn.agent import default_agent 20 | -------------------------------------------------------------------------------- /bsuite/baselines/jax/actor_critic_rnn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test for the JAX recurrent actor-critic baseline."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Build the environment and read its specs in the order the agent
    # factory expects them (observation first, then action).
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()
    agent = actor_critic_rnn.default_agent(observation_spec, action_spec)

    # Only checks that training runs without raising; nothing is asserted.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """A simple DQN agent implemented in JAX.""" 17 | 18 | from bsuite.baselines.jax.boot_dqn.agent import BootstrappedDqn 19 | from bsuite.baselines.jax.boot_dqn.agent import default_agent 20 | -------------------------------------------------------------------------------- /bsuite/baselines/jax/boot_dqn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test for the JAX bootstrapped DQN baseline."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Build the environment and read its specs in the order the agent
    # factory expects them (observation first, then action).
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()
    agent = boot_dqn.default_agent(
        observation_spec, action_spec, num_ensemble=2)

    # Only checks that training runs without raising; nothing is asserted.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """A simple DQN agent implemented in JAX.""" 17 | 18 | from bsuite.baselines.jax.dqn.agent import default_agent 19 | from bsuite.baselines.jax.dqn.agent import DQN 20 | -------------------------------------------------------------------------------- /bsuite/baselines/jax/dqn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test for the JAX DQN baseline."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Build the environment and read its specs in the order the agent
    # factory expects them (observation first, then action).
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()
    agent = dqn.default_agent(observation_spec, action_spec)

    # Only checks that training runs without raising; nothing is asserted.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """An agent that takes uniformly random actions.""" 17 | 18 | from bsuite.baselines.random.agent import default_agent 19 | from bsuite.baselines.random.agent import Random 20 | -------------------------------------------------------------------------------- /bsuite/baselines/random/agent.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class Random(base.Agent):
  """Agent that ignores what it observes and acts uniformly at random."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               seed: Optional[int] = None):
    """Initialises the agent.

    Args:
      action_spec: Spec of the discrete action space; only `num_values` is
        read from it.
      seed: Optional seed passed to the agent's private random number
        generator.
    """
    self._num_actions = action_spec.num_values
    self._rng = np.random.RandomState(seed)

  def select_action(self, timestep: dm_env.TimeStep) -> base.Action:
    """Draws an action from `[0, num_actions)`; `timestep` is not used."""
    del timestep
    sampled_action = self._rng.randint(self._num_actions)
    return sampled_action

  def update(self,
             timestep: dm_env.TimeStep,
             action: base.Action,
             new_timestep: dm_env.TimeStep) -> None:
    """Does nothing: a random agent has nothing to learn."""
    del timestep, action, new_timestep


def default_agent(obs_spec: specs.Array, action_spec: specs.DiscreteArray,
                  **kwargs) -> Random:
  """Builds a `Random` agent.

  Args:
    obs_spec: Ignored; accepted only for compatibility.
    action_spec: Spec of the discrete action space.
    **kwargs: Extra keyword arguments forwarded to the `Random` constructor.

  Returns:
    A newly constructed `Random` agent.
  """
  del obs_spec  # for compatibility
  agent = Random(action_spec=action_spec, **kwargs)
  return agent
class RunTest(parameterized.TestCase):
  """Smoke test: runs the random agent on every id in `sweep.TESTING`."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Load the environment and read its specs in the original call order.
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    agent = random.default_agent(observation_spec, action_spec)

    # Five episodes: the test only checks that the loop runs without error.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/tf/actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """A simple TensorFlow 2-based implementation of the actor-critic algorithm.""" 17 | 18 | from bsuite.baselines.tf.actor_critic.agent import ActorCritic 19 | from bsuite.baselines.tf.actor_critic.agent import default_agent 20 | from bsuite.baselines.tf.actor_critic.agent import PolicyValueNet 21 | -------------------------------------------------------------------------------- /bsuite/baselines/tf/actor_critic/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class RunTest(parameterized.TestCase):
  """Smoke test: runs the TF actor-critic agent on each `sweep.TESTING` id."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Load the environment and read its specs in the original call order.
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    agent = actor_critic.default_agent(observation_spec, action_spec)

    # Five episodes: the test only checks that the loop runs without error.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """A simple TensorFlow 2-based implementation of a recurrent actor-critic.""" 17 | 18 | from bsuite.baselines.tf.actor_critic_rnn.agent import ActorCriticRNN 19 | from bsuite.baselines.tf.actor_critic_rnn.agent import default_agent 20 | from bsuite.baselines.tf.actor_critic_rnn.agent import PolicyValueRNN 21 | -------------------------------------------------------------------------------- /bsuite/baselines/tf/actor_critic_rnn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test: runs the TF recurrent actor-critic on `sweep.TESTING` ids."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Load the environment and read its specs in the original call order.
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    agent = actor_critic_rnn.default_agent(observation_spec, action_spec)

    # Five episodes: the test only checks that the loop runs without error.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """A simple implementation of Bootstrapped DQN with prior networks.""" 17 | 18 | from bsuite.baselines.tf.boot_dqn.agent import BootstrappedDqn 19 | from bsuite.baselines.tf.boot_dqn.agent import default_agent 20 | from bsuite.baselines.tf.boot_dqn.agent import make_ensemble 21 | -------------------------------------------------------------------------------- /bsuite/baselines/tf/boot_dqn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test: runs the TF bootstrapped DQN on each `sweep.TESTING` id."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Load the environment and read its specs in the original call order.
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    # The test overrides `num_ensemble` with 2.
    agent = boot_dqn.default_agent(
        observation_spec, action_spec, num_ensemble=2)

    # Five episodes: the test only checks that the loop runs without error.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | """A simple TensorFlow 2-based DQN implementation.""" 17 | 18 | from bsuite.baselines.tf.dqn.agent import default_agent 19 | from bsuite.baselines.tf.dqn.agent import DQN 20 | -------------------------------------------------------------------------------- /bsuite/baselines/tf/dqn/run_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class RunTest(parameterized.TestCase):
  """Smoke test: runs the TF DQN agent on each `sweep.TESTING` id."""

  @parameterized.parameters(*sweep.TESTING)
  def test_run(self, bsuite_id: str):
    # Load the environment and read its specs in the original call order.
    environment = bsuite.load_from_id(bsuite_id)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    agent = dqn.default_agent(observation_spec, action_spec)

    # Five episodes: the test only checks that the loop runs without error.
    experiment.run(agent=agent, environment=environment, num_episodes=5)
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/third_party/dopamine_dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/third_party/openai_dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/third_party/openai_ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/baselines/utils/pool.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def map_mpi(
    run_fn: Callable[[BsuiteId], BsuiteId],
    bsuite_ids: Sequence[BsuiteId],
    num_processes: Optional[int] = None,
):
  """Maps `run_fn` over `bsuite_ids`, using `num_processes` in parallel.

  Prints a short banner, then dispatches every id to a process pool and
  advances a progress bar as each result is yielded.

  Args:
    run_fn: Callable run in a worker process for each id; its return value is
      shown in the progress bar description.
    bsuite_ids: The ids to map over.
    num_processes: Number of worker processes. Falls back to
      `multiprocessing.cpu_count()` when falsy (`None` or 0).
  """

  num_processes = num_processes or multiprocessing.cpu_count()
  num_experiments = len(bsuite_ids)

  message = """
    Experiment info
    ---------------
    Num experiments: {num_experiments}
    Num worker processes: {num_processes}
  """.format(
      num_processes=num_processes, num_experiments=num_experiments)
  termcolor.cprint(message, color='blue', attrs=['bold'])

  # Create a pool of processes, dispatch the experiments to them, show progress.
  # Both resources are context-managed so the workers are shut down and the
  # progress bar is closed even if one of the experiments raises.
  with futures.ProcessPoolExecutor(num_processes) as pool:
    with tqdm.tqdm(total=num_experiments) as progress_bar:
      # `pool.map` yields results in the order of `bsuite_ids`.
      for bsuite_id in pool.map(run_fn, bsuite_ids):
        description = '[Last finished: {}]'.format(bsuite_id)
        progress_bar.set_description(
            termcolor.colored(description, color='green'))
        progress_bar.update()
class BasicReplayTest(absltest.TestCase):
  """End-to-end check of `replay_lib.Replay` add/sample output shapes."""

  def test_end_to_end(self):
    item_shapes = (10, 10, 3), ()
    capacity = 5

    def generate_sample():
      # One item: a random uint8 array of shape (10, 10, 3) and a scalar.
      image = np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8)
      scalar = np.random.uniform(size=())
      return [image, scalar]

    def check_batch(batch, batch_size):
      # Every sampled field must have a leading dimension of `batch_size`.
      for field, item_shape in zip(batch, item_shapes):
        self.assertEqual(field.shape, (batch_size,) + item_shape)

    replay = replay_lib.Replay(capacity=capacity)

    # Does it crash if we sample when there's barely any data?
    replay.add(generate_sample())
    check_batch(replay.sample(size=2), 2)

    # Fill to capacity.
    for _ in range(capacity - 1):
      replay.add(generate_sample())
    check_batch(replay.sample(size=3), 3)

    # One more item than the capacity, then sample a full-capacity batch.
    replay.add(generate_sample())
    check_batch(replay.sample(size=capacity), capacity)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """bsuite environments package.""" 17 | 18 | from bsuite.environments.base import Environment 19 | -------------------------------------------------------------------------------- /bsuite/environments/bandit_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Plugs `bandit.SimpleBandit` into `test_utils.EnvironmentTestMixin`."""

  def make_object_under_test(self):
    """Returns the `bandit.SimpleBandit(5)` instance to be checked."""
    environment = bandit.SimpleBandit(5)
    return environment

  def make_action_sequence(self):
    """Yields 100 actions drawn from `range(11)` with a fixed seed (42)."""
    rng = np.random.RandomState(42)
    valid_actions = range(11)
    yield from (rng.choice(valid_actions) for _ in range(100))
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Plugs `cartpole.Cartpole` into `test_utils.EnvironmentTestMixin`."""

  def make_object_under_test(self):
    """Returns the `cartpole.Cartpole(seed=22)` instance to be checked."""
    environment = cartpole.Cartpole(seed=22)
    return environment

  def make_action_sequence(self):
    """Yields 100 actions drawn from `[0, 1, 2]` with a fixed seed (42)."""
    rng = np.random.RandomState(42)
    valid_actions = [0, 1, 2]
    yield from (rng.choice(valid_actions) for _ in range(100))
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Plugs `catch.Catch` into `test_utils.EnvironmentTestMixin`."""

  def make_object_under_test(self):
    """Returns the `catch.Catch(rows=10, columns=5)` instance to be checked."""
    environment = catch.Catch(rows=10, columns=5)
    return environment

  def make_action_sequence(self):
    """Yields 100 actions drawn from `[0, 1, 2]` with a fixed seed (42)."""
    rng = np.random.RandomState(42)
    valid_actions = [0, 1, 2]
    yield from (rng.choice(valid_actions) for _ in range(100))
class DeepSeaInterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Plugs `deep_sea.DeepSea(10)` into `test_utils.EnvironmentTestMixin`."""

  def make_object_under_test(self):
    """Returns the `deep_sea.DeepSea(10)` instance to be checked."""
    environment = deep_sea.DeepSea(10)
    return environment

  def make_action_sequence(self):
    """Yields 100 actions drawn from `[0, 1]` with a fixed seed (42)."""
    rng = np.random.RandomState(42)
    valid_actions = [0, 1]
    yield from (rng.choice(valid_actions) for _ in range(100))


class StochasticDeepSeaInterfaceTest(test_utils.EnvironmentTestMixin,
                                     absltest.TestCase):
  """Same mixin checks for `DeepSea` built with `deterministic=False`."""

  def make_object_under_test(self):
    """Returns the `deep_sea.DeepSea(5, deterministic=False)` instance."""
    environment = deep_sea.DeepSea(5, deterministic=False)
    return environment

  def make_action_sequence(self):
    """Yields 100 actions drawn from `[0, 1]` with a fixed seed (42)."""
    rng = np.random.RandomState(42)
    valid_actions = [0, 1]
    yield from (rng.choice(valid_actions) for _ in range(100))
15 | # ============================================================================ 16 | """Tests for bsuite.environments.discounting_chain.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.environments import discounting_chain 20 | from dm_env import test_utils 21 | 22 | import numpy as np 23 | 24 | 25 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 26 | 27 | def make_object_under_test(self): 28 | return discounting_chain.DiscountingChain(10) 29 | 30 | def make_action_sequence(self): 31 | valid_actions = [0, 1, 2, 3, 4] 32 | rng = np.random.RandomState(42) 33 | 34 | for _ in range(100): 35 | yield rng.choice(valid_actions) 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /bsuite/environments/memory_chain_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.environments.memory_chain.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from bsuite.environments import memory_chain 21 | from dm_env import test_utils 22 | import numpy as np 23 | 24 | 25 | class MemoryLengthInterfaceTest(test_utils.EnvironmentTestMixin, 26 | parameterized.TestCase): 27 | 28 | def make_object_under_test(self): 29 | return memory_chain.MemoryChain(memory_length=10, num_bits=1) 30 | 31 | def make_action_sequence(self): 32 | valid_actions = [0, 1] 33 | rng = np.random.RandomState(42) 34 | 35 | for _ in range(100): 36 | yield rng.choice(valid_actions) 37 | 38 | 39 | class MemorySizeInterfaceTest(test_utils.EnvironmentTestMixin, 40 | parameterized.TestCase): 41 | 42 | def make_object_under_test(self): 43 | return memory_chain.MemoryChain(memory_length=2, num_bits=10) 44 | 45 | def make_action_sequence(self): 46 | valid_actions = [0, 1] 47 | rng = np.random.RandomState(42) 48 | 49 | for _ in range(100): 50 | yield rng.choice(valid_actions) 51 | 52 | if __name__ == '__main__': 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /bsuite/environments/mnist_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Tests for bsuite.environments.mnist.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.environments import mnist 20 | from dm_env import test_utils 21 | 22 | import numpy as np 23 | 24 | 25 | class CatchInterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 26 | 27 | def make_object_under_test(self): 28 | return mnist.MNISTBandit(seed=101) 29 | 30 | def make_action_sequence(self): 31 | num_actions = self.environment.action_spec().num_values 32 | rng = np.random.RandomState(42) 33 | 34 | for _ in range(100): 35 | yield rng.randint(num_actions) 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /bsuite/environments/mountain_car_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.environments.mountain_car.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.environments import mountain_car 20 | from dm_env import test_utils 21 | import numpy as np 22 | 23 | 24 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 25 | 26 | def make_object_under_test(self): 27 | return mountain_car.MountainCar(2) 28 | 29 | def make_action_sequence(self): 30 | valid_actions = [0, 1, 2] 31 | rng = np.random.RandomState(42) 32 | 33 | for _ in range(100): 34 | yield rng.choice(valid_actions) 35 | 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /bsuite/environments/umbrella_chain_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.environments.umbrella_chain.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.environments import umbrella_chain 20 | from dm_env import test_utils 21 | 22 | import numpy as np 23 | 24 | 25 | class UmbrellaDistractInterfaceTest(test_utils.EnvironmentTestMixin, 26 | absltest.TestCase): 27 | 28 | def make_object_under_test(self): 29 | return umbrella_chain.UmbrellaChain(chain_length=20, n_distractor=22) 30 | 31 | def make_action_sequence(self): 32 | valid_actions = [0, 1] 33 | rng = np.random.RandomState(42) 34 | 35 | for _ in range(100): 36 | yield rng.choice(valid_actions) 37 | 38 | 39 | class UmbrellaLengthInterfaceTest(test_utils.EnvironmentTestMixin, 40 | absltest.TestCase): 41 | 42 | def make_object_under_test(self): 43 | return umbrella_chain.UmbrellaChain(chain_length=10, n_distractor=0) 44 | 45 | def make_action_sequence(self): 46 | valid_actions = [0, 1] 47 | rng = np.random.RandomState(42) 48 | 49 | for _ in range(100): 50 | yield rng.choice(valid_actions) 51 | 52 | if __name__ == '__main__': 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /bsuite/experiments/README.md: -------------------------------------------------------------------------------- 1 | # `bsuite` Experiments 2 | 3 | This folder contains all of the experiments that constitute `bsuite`. Each 4 | experiment folder contains three files: 5 | 6 | 1. A mechanism for loading a RL environment that adheres to the [dm_env](https://github.com/deepmind/dm_env) 7 | interface (see `environments/` for the precise environment definitions). 8 | 1. `sweep.py`, which contains a list of different configurations of this 9 | environment over which the agent is tested. 10 | 1. `analysis.py`, which specifies how to 'score' the experiment, and 11 | provides utilities for generating relevant plots. 
12 | 13 | Detailed descriptions of each experiment can be found in the Jupyter notebook in 14 | `analysis/results.ipynb`. 15 | -------------------------------------------------------------------------------- /bsuite/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit/bandit.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic bandit environment. 17 | 18 | Observation is a single pixel of 0 - this is an independent arm bandit problem! 19 | Rewards are [0, 0.1, ..., 1], assigned randomly to the 11 arms, and deterministic. 20 | """ 21 | 22 | from bsuite.environments import bandit 23 | 24 | load = bandit.SimpleBandit 25 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Sweep definition for bandit experiment.""" 17 | 18 | NUM_EPISODES = 10000 19 | 20 | SETTINGS = tuple({'mapping_seed': n} for n in range(20)) 21 | TAGS = ('basic',) 22 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_noise/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_noise/bandit_noise.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic bandit_noise challenge. 17 | 18 | Observation is a single pixel of 0 - this is an independent arm bandit problem! 19 | Rewards are np.linspace(0, 1, 11) with some level of reward noise. 
20 | """ 21 | 22 | from bsuite.environments import bandit 23 | from bsuite.experiments.bandit import sweep 24 | from bsuite.utils import wrappers 25 | 26 | 27 | def load(noise_scale, seed, mapping_seed, num_actions=11): 28 | """Load a bandit_noise experiment with the prescribed settings.""" 29 | env = wrappers.RewardNoise( 30 | env=bandit.SimpleBandit(mapping_seed, num_actions=num_actions), 31 | noise_scale=noise_scale, 32 | seed=seed) 33 | env.bsuite_num_episodes = sweep.NUM_EPISODES 34 | return env 35 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_noise/bandit_noise_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.experiments.bandit_noise.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.experiments.bandit_noise import bandit_noise 20 | from dm_env import test_utils 21 | import numpy as np 22 | 23 | 24 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 25 | 26 | def make_object_under_test(self): 27 | return bandit_noise.load(1., 42, 42) 28 | 29 | def make_action_sequence(self): 30 | valid_actions = range(11) 31 | rng = np.random.RandomState(42) 32 | 33 | for _ in range(100): 34 | yield rng.choice(valid_actions) 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_noise/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Sweep definition for bandit_noise experiment.""" 17 | 18 | from bsuite.experiments.bandit import sweep as bandit_sweep 19 | 20 | NUM_EPISODES = bandit_sweep.NUM_EPISODES 21 | 22 | _settings = [] 23 | for scale in [0.1, 0.3, 1.0, 3., 10.]: 24 | for n in range(4): 25 | _settings.append({'noise_scale': scale, 'seed': None, 'mapping_seed': n}) 26 | 27 | SETTINGS = tuple(_settings) 28 | TAGS = ('noise',) 29 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_scale/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_scale/analysis.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Analysis for bandit_scale environments.""" 17 | 18 | from typing import Optional, Sequence 19 | 20 | from bsuite.experiments.bandit import analysis as bandit_analysis 21 | from bsuite.experiments.bandit_noise import analysis as bandit_noise_analysis 22 | from bsuite.experiments.bandit_scale import sweep 23 | import pandas as pd 24 | import plotnine as gg 25 | 26 | 27 | NUM_EPISODES = sweep.NUM_EPISODES 28 | TAGS = sweep.TAGS 29 | 30 | 31 | def score(df: pd.DataFrame) -> float: 32 | return bandit_noise_analysis.score(df, scaling_var='reward_scale') 33 | 34 | 35 | def plot_learning(df: pd.DataFrame, 36 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 37 | return bandit_noise_analysis.plot_learning(df, sweep_vars, 'reward_scale') 38 | 39 | 40 | def plot_average(df: pd.DataFrame, 41 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 42 | return bandit_noise_analysis.plot_average(df, sweep_vars, 'reward_scale') 43 | 44 | 45 | def plot_seeds(df: pd.DataFrame, 46 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 47 | """Plot the performance by individual work unit.""" 48 | return bandit_analysis.plot_seeds( 49 | df_in=df, 50 | sweep_vars=sweep_vars, 51 | colour_var='reward_scale' 52 | ) + gg.ylab('average episodic return (after rescaling)') 53 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_scale/bandit_scale.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic bandit_scale challenge. 17 | 18 | Observation is a single pixel of 0 - this is an independent arm bandit problem! 19 | Rewards are np.linspace(0, 1, 11) with no noise, but rescaled. 20 | """ 21 | 22 | from bsuite.environments import bandit 23 | from bsuite.experiments.bandit import sweep 24 | from bsuite.utils import wrappers 25 | 26 | 27 | def load(reward_scale, seed, mapping_seed): 28 | """Load a bandit_scale experiment with the prescribed settings.""" 29 | env = wrappers.RewardScale( 30 | env=bandit.SimpleBandit(mapping_seed=mapping_seed), 31 | reward_scale=reward_scale, 32 | seed=seed) 33 | env.bsuite_num_episodes = sweep.NUM_EPISODES 34 | return env 35 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_scale/bandit_scale_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Tests for bsuite.experiments.bandit_scale.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.experiments.bandit_scale import bandit_scale 20 | from dm_env import test_utils 21 | import numpy as np 22 | 23 | 24 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 25 | 26 | def make_object_under_test(self): 27 | return bandit_scale.load(10, 42, 42) 28 | 29 | def make_action_sequence(self): 30 | valid_actions = range(11) 31 | rng = np.random.RandomState(42) 32 | 33 | for _ in range(100): 34 | yield rng.choice(valid_actions) 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /bsuite/experiments/bandit_scale/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Sweep definition for bandit_scale experiment.""" 17 | 18 | from bsuite.experiments.bandit import sweep as bandit_sweep 19 | 20 | NUM_EPISODES = bandit_sweep.NUM_EPISODES 21 | 22 | _settings = [] 23 | for scale in [0.001, 0.03, 1.0, 30., 1000.]: 24 | for n in range(4): 25 | _settings.append({'reward_scale': scale, 'seed': None, 'mapping_seed': n}) 26 | 27 | SETTINGS = tuple(_settings) 28 | TAGS = ('scale',) 29 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole/cartpole.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """The Cartpole reinforcement learning environment.""" 17 | 18 | from bsuite.environments import cartpole 19 | 20 | load = cartpole.Cartpole 21 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Sweep definition for a balancing experiment in Cartpole.""" 17 | 18 | NUM_EPISODES = 1000 19 | 20 | SETTINGS = tuple({'seed': None} for _ in range(20)) 21 | TAGS = ('basic', 'credit_assignment', 'generalization') 22 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole_noise/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole_noise/cartpole_noise.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Cartpole environment with noisy rewards.""" 17 | 18 | from bsuite.environments import cartpole 19 | from bsuite.experiments.cartpole_noise import sweep 20 | from bsuite.utils import wrappers 21 | 22 | 23 | def load(noise_scale, seed): 24 | """Load a cartpole_noise experiment with the prescribed settings.""" 25 | env = wrappers.RewardNoise( 26 | env=cartpole.Cartpole(seed=seed), 27 | noise_scale=noise_scale, 28 | seed=seed) 29 | env.bsuite_num_episodes = sweep.NUM_EPISODES 30 | return env 31 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole_noise/cartpole_noise_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
"""Tests for bsuite.experiments.cartpole_noise."""

from absl.testing import absltest
from bsuite.experiments.cartpole_noise import cartpole_noise
from dm_env import test_utils

import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Applies dm_env's `EnvironmentTestMixin` to the cartpole_noise env."""

  def make_object_under_test(self):
    """Return the environment under test (noise_scale=1.0, seed=22)."""
    return cartpole_noise.load(1., 22)

  def make_action_sequence(self):
    """Lazily yield 100 actions drawn from {0, 1, 2} by a RNG seeded at 42."""
    action_rng = np.random.RandomState(42)
    candidate_actions = [0, 1, 2]
    yield from (action_rng.choice(candidate_actions) for _ in range(100))


if __name__ == '__main__':
  absltest.main()
"""Sweep definition for cartpole_noise experiment."""

from bsuite.experiments.cartpole import sweep as cartpole_sweep

# Same episode budget as the base cartpole experiment.
NUM_EPISODES = cartpole_sweep.NUM_EPISODES

# Values used for the `noise_scale` setting, and how many runs share each one.
_NOISE_SCALES = (0.1, 0.3, 1.0, 3., 10.)
_RUNS_PER_SCALE = 4

# One settings dict per run, grouped by scale in the order listed above.
# A generator expression is used so that no loop variables (`scale`, `seed`)
# are left behind as attributes of this module; the per-run index is unused,
# and `seed` is None for every run.
SETTINGS = tuple(
    {'noise_scale': scale, 'seed': None}
    for scale in _NOISE_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('noise', 'generalization')
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/cartpole_scale/analysis.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
"""Analysis for cartpole_scale environments."""

from typing import Optional, Sequence

from bsuite.experiments.cartpole import analysis as cartpole_analysis
from bsuite.experiments.cartpole_noise import analysis as cartpole_noise_analysis
from bsuite.experiments.cartpole_scale import sweep
import pandas as pd
import plotnine as gg

NUM_EPISODES = sweep.NUM_EPISODES
TAGS = sweep.TAGS

# Column handed to the cartpole / cartpole_noise analysis helpers below.
_SCALING_VAR = 'reward_scale'


def score(df: pd.DataFrame) -> float:
  """Score via the cartpole_noise scorer with `scaling_var='reward_scale'`."""
  return cartpole_noise_analysis.score(df, scaling_var=_SCALING_VAR)


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegate to cartpole_noise `plot_learning`, passing 'reward_scale'."""
  return cartpole_noise_analysis.plot_learning(df, sweep_vars, _SCALING_VAR)


def plot_average(df: pd.DataFrame,
                 sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegate to cartpole_noise `plot_average`, passing 'reward_scale'."""
  return cartpole_noise_analysis.plot_average(df, sweep_vars, _SCALING_VAR)


def plot_seeds(df: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plot the performance by individual work unit.

  Uses the base cartpole `plot_seeds`, coloured by 'reward_scale', and
  replaces the y-axis label.
  """
  seeds_plot = cartpole_analysis.plot_seeds(
      df_in=df,
      sweep_vars=sweep_vars,
      colour_var=_SCALING_VAR,
  )
  return seeds_plot + gg.ylab('average episodic return (after rescaling)')
"""Cartpole environment with scaled rewards."""

from bsuite.environments import cartpole
from bsuite.experiments.cartpole_scale import sweep
from bsuite.utils import wrappers


def load(reward_scale, seed):
  """Build a Cartpole environment wrapped in `wrappers.RewardScale`.

  Args:
    reward_scale: Forwarded to `wrappers.RewardScale` as `reward_scale`.
    seed: Given both to `cartpole.Cartpole` and to the scale wrapper.

  Returns:
    The wrapped environment, with its `bsuite_num_episodes` attribute set to
    `sweep.NUM_EPISODES`.
  """
  base_env = cartpole.Cartpole(seed=seed)
  scaled_env = wrappers.RewardScale(
      env=base_env, reward_scale=reward_scale, seed=seed)
  scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
  return scaled_env
"""Tests for bsuite.experiments.cartpole_scale."""

from absl.testing import absltest
from bsuite.experiments.cartpole_scale import cartpole_scale
from dm_env import test_utils

import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Applies dm_env's `EnvironmentTestMixin` to the cartpole_scale env."""

  def make_object_under_test(self):
    """Return the environment under test (reward_scale=10.0, seed=22)."""
    return cartpole_scale.load(10., 22)

  def make_action_sequence(self):
    """Lazily yield 100 actions drawn from {0, 1, 2} by a RNG seeded at 42."""
    action_rng = np.random.RandomState(42)
    candidate_actions = [0, 1, 2]
    yield from (action_rng.choice(candidate_actions) for _ in range(100))


if __name__ == '__main__':
  absltest.main()
"""Sweep definition for cartpole_scale experiment."""

from bsuite.experiments.cartpole import sweep as cartpole_sweep

# Same episode budget as the base cartpole experiment.
NUM_EPISODES = cartpole_sweep.NUM_EPISODES

# Values used for the `reward_scale` setting, and how many runs share each.
_REWARD_SCALES = (0.001, 0.03, 1.0, 30., 1000.)
_RUNS_PER_SCALE = 4

# One settings dict per run, grouped by scale in the order listed above.
# A generator expression is used so that no loop variables (`scale`, `seed`)
# are left behind as attributes of this module; the per-run index is unused,
# and `seed` is None for every run.
SETTINGS = tuple(
    {'reward_scale': scale, 'seed': None}
    for scale in _REWARD_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('scale', 'generalization')
"""Tests for bsuite.experiments.cartpole_swingup."""

from absl.testing import absltest
from bsuite.experiments.cartpole_swingup import cartpole_swingup
from dm_env import test_utils

import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Applies dm_env's `EnvironmentTestMixin` to `CartpoleSwingup`."""

  def make_object_under_test(self):
    """Return a `CartpoleSwingup` with seed=42 and height_threshold=0.8."""
    swingup_env = cartpole_swingup.CartpoleSwingup(
        seed=42, height_threshold=0.8)
    return swingup_env

  def make_action_sequence(self):
    """Lazily yield 100 actions drawn from {0, 1, 2} by a RNG seeded at 42."""
    action_rng = np.random.RandomState(42)
    candidate_actions = [0, 1, 2]
    yield from (action_rng.choice(candidate_actions) for _ in range(100))


if __name__ == '__main__':
  absltest.main()
"""Sweep definition for a swing up experiment in Cartpole."""

from bsuite.experiments.cartpole import sweep as cartpole_sweep

# Same episode budget as the base cartpole experiment.
NUM_EPISODES = cartpole_sweep.NUM_EPISODES

# Twenty settings for n = 0..19: `height_threshold` is n / 20 and
# `x_reward_threshold` is 1 - n / 20.
SETTINGS = tuple([
    {'height_threshold': n / 20, 'x_reward_threshold': 1 - n / 20}
    for n in range(20)
])
TAGS = ('exploration', 'generalization')
"""Analysis for catch."""

from typing import Optional, Sequence

from bsuite.experiments.catch import sweep
from bsuite.utils import plotting
import pandas as pd
import plotnine as gg

NUM_EPISODES = sweep.NUM_EPISODES
BASE_REGRET = 1.6
TAGS = sweep.TAGS


def score(df: pd.DataFrame) -> float:
  """Output a single score for catch.

  Delegates to `plotting.ave_regret_score` with `BASE_REGRET` as the baseline
  regret and `sweep.NUM_EPISODES` as the episode.
  """
  catch_score = plotting.ave_regret_score(
      df, baseline_regret=BASE_REGRET, episode=sweep.NUM_EPISODES)
  return catch_score


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Simple learning curves for catch, with a dashed line at `BASE_REGRET`."""
  learning_plot = plotting.plot_regret_learning(
      df, sweep_vars=sweep_vars, max_episode=sweep.NUM_EPISODES)
  baseline_line = gg.geom_hline(
      gg.aes(yintercept=BASE_REGRET), linetype='dashed', alpha=0.4, size=1.75)
  learning_plot += baseline_line
  return learning_plot


def plot_seeds(df_in: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None,
               colour_var: Optional[str] = None) -> gg.ggplot:
  """Plot the returns through time individually by run.

  Works on a copy of `df_in`, adding an `average_return` column computed as
  1 - diff(total_regret) / diff(episode) between consecutive rows.
  """
  returns_df = df_in.copy()
  regret_delta = returns_df.total_regret.diff()
  episode_delta = returns_df.episode.diff()
  returns_df['average_return'] = 1.0 - (regret_delta / episode_delta)
  seeds_plot = plotting.plot_individual_returns(
      df_in=returns_df,
      max_episode=NUM_EPISODES,
      return_column='average_return',
      colour_var=colour_var,
      yintercept=1.,
      sweep_vars=sweep_vars,
  )
  return seeds_plot + gg.ylab('average episodic return')


# -----------------------------------------------------------------------------
# /bsuite/experiments/catch/catch.py
# Catch reinforcement learning environment.
# -----------------------------------------------------------------------------
from bsuite.environments import catch

# `load` is bound directly to `catch.Catch`; there is no wrapper in between.
load = catch.Catch
"""Sweep definition for catch experiment."""

NUM_EPISODES = 10000

# Twenty settings, one per run. Every entry is its own dict and leaves `seed`
# as None.
SETTINGS = tuple([{'seed': None} for _ in range(20)])
TAGS = ('basic', 'credit_assignment')
"""Catch environment with noisy rewards."""

from bsuite.environments import catch
from bsuite.experiments.catch_noise import sweep
from bsuite.utils import wrappers


def load(noise_scale, seed):
  """Build a Catch environment wrapped in `wrappers.RewardNoise`.

  Args:
    noise_scale: Forwarded to `wrappers.RewardNoise` as `noise_scale`.
    seed: Given both to `catch.Catch` and to the noise wrapper.

  Returns:
    The wrapped environment, with its `bsuite_num_episodes` attribute set to
    `sweep.NUM_EPISODES`.
  """
  base_env = catch.Catch(seed=seed)
  noisy_env = wrappers.RewardNoise(
      env=base_env, noise_scale=noise_scale, seed=seed)
  noisy_env.bsuite_num_episodes = sweep.NUM_EPISODES
  return noisy_env
"""Tests for bsuite.experiments.catch_noise."""

from absl.testing import absltest
from bsuite.experiments.catch_noise import catch_noise
from dm_env import test_utils

import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Applies dm_env's `EnvironmentTestMixin` to the catch_noise env."""

  def make_object_under_test(self):
    """Return the environment under test (noise_scale=1.0, seed=22)."""
    return catch_noise.load(1., 22)

  def make_action_sequence(self):
    """Lazily yield 100 actions drawn from {0, 1, 2} by a RNG seeded at 42."""
    action_rng = np.random.RandomState(42)
    candidate_actions = [0, 1, 2]
    yield from (action_rng.choice(candidate_actions) for _ in range(100))


if __name__ == '__main__':
  absltest.main()
"""Sweep definition for catch_noise experiment."""

from bsuite.experiments.catch import sweep as catch_sweep

# Same episode budget as the base catch experiment.
NUM_EPISODES = catch_sweep.NUM_EPISODES

# Values used for the `noise_scale` setting, and how many runs share each one.
_NOISE_SCALES = (0.1, 0.3, 1.0, 3., 10.)
_RUNS_PER_SCALE = 4

# One settings dict per run, grouped by scale in the order listed above.
# A generator expression is used so that no loop variables (`scale`, `seed`)
# are left behind as attributes of this module; the per-run index is unused,
# and `seed` is None for every run.
SETTINGS = tuple(
    {'noise_scale': scale, 'seed': None}
    for scale in _NOISE_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('noise', 'credit_assignment')
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/catch_scale/analysis.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
"""Analysis for catch scale environments."""

from typing import Optional, Sequence

from bsuite.experiments.catch import analysis as catch_analysis
from bsuite.experiments.catch_noise import analysis as catch_noise_analysis
from bsuite.experiments.catch_scale import sweep
import pandas as pd
import plotnine as gg

NUM_EPISODES = sweep.NUM_EPISODES
TAGS = sweep.TAGS

# Column handed to the catch / catch_noise analysis helpers below.
_SCALING_VAR = 'reward_scale'


def score(df: pd.DataFrame) -> float:
  """Score via the catch_noise scorer with `scaling_var='reward_scale'`."""
  return catch_noise_analysis.score(df, scaling_var=_SCALING_VAR)


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegate to catch_noise `plot_learning`, passing 'reward_scale'."""
  return catch_noise_analysis.plot_learning(df, sweep_vars, _SCALING_VAR)


def plot_average(df: pd.DataFrame,
                 sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegate to catch_noise `plot_average`, passing 'reward_scale'."""
  return catch_noise_analysis.plot_average(df, sweep_vars, _SCALING_VAR)


def plot_seeds(df: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plot the performance by individual work unit.

  Uses the base catch `plot_seeds`, coloured by 'reward_scale', and replaces
  the y-axis label.
  """
  seeds_plot = catch_analysis.plot_seeds(
      df_in=df,
      sweep_vars=sweep_vars,
      colour_var=_SCALING_VAR,
  )
  return seeds_plot + gg.ylab('average episodic return (after rescaling)')
"""Catch environment with scaled rewards."""

from bsuite.environments import catch
from bsuite.experiments.catch_scale import sweep
from bsuite.utils import wrappers


def load(reward_scale, seed):
  """Build a Catch environment wrapped in `wrappers.RewardScale`.

  Args:
    reward_scale: Forwarded to `wrappers.RewardScale` as `reward_scale`.
    seed: Given both to `catch.Catch` and to the scale wrapper.

  Returns:
    The wrapped environment, with its `bsuite_num_episodes` attribute set to
    `sweep.NUM_EPISODES`.
  """
  base_env = catch.Catch(seed=seed)
  scaled_env = wrappers.RewardScale(
      env=base_env, reward_scale=reward_scale, seed=seed)
  scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
  return scaled_env
"""Tests for bsuite.experiments.catch_scale."""

from absl.testing import absltest
from bsuite.experiments.catch_scale import catch_scale
from dm_env import test_utils

import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Applies dm_env's `EnvironmentTestMixin` to the catch_scale env."""

  def make_object_under_test(self):
    """Return the environment under test (reward_scale=10.0, seed=22)."""
    return catch_scale.load(10., 22)

  def make_action_sequence(self):
    """Lazily yield 100 actions drawn from {0, 1, 2} by a RNG seeded at 42."""
    action_rng = np.random.RandomState(42)
    candidate_actions = [0, 1, 2]
    yield from (action_rng.choice(candidate_actions) for _ in range(100))


if __name__ == '__main__':
  absltest.main()
15 | # ============================================================================ 16 | """Sweep definition for catch_scale experiment.""" 17 | 18 | from bsuite.experiments.catch import sweep as catch_sweep 19 | 20 | NUM_EPISODES = catch_sweep.NUM_EPISODES 21 | 22 | _settings = [] 23 | for scale in [0.001, 0.03, 1.0, 30., 1000.]: 24 | for seed in range(4): 25 | _settings.append({'reward_scale': scale, 'seed': None}) 26 | 27 | SETTINGS = tuple(_settings) 28 | TAGS = ('scale', 'credit_assignment') 29 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea/deep_sea.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Python implementation of 'Deep Sea' exploration environment.""" 17 | 18 | from bsuite.environments import deep_sea 19 | 20 | load = deep_sea.DeepSea 21 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Sweep definition for deep_sea experiment.""" 17 | 18 | NUM_EPISODES = 10000 19 | 20 | SETTINGS = tuple({'size': n, 'mapping_seed': 42} for n in range(10, 51, 2)) 21 | TAGS = ('exploration',) 22 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea_stochastic/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea_stochastic/deep_sea_stochastic.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Stochastic Deep Sea environment.""" 17 | 18 | from bsuite.environments import deep_sea 19 | from bsuite.experiments.deep_sea_stochastic import sweep 20 | 21 | 22 | def load(size: int, mapping_seed=0): 23 | """Load a deep sea experiment with the prescribed settings.""" 24 | env = deep_sea.DeepSea( 25 | size=size, 26 | deterministic=False, 27 | mapping_seed=mapping_seed, 28 | ) 29 | env.bsuite_num_episodes = sweep.NUM_EPISODES 30 | return env 31 | 32 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea_stochastic/deep_sea_stochastic_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.experiments.deep_sea_stochastic.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.experiments.deep_sea_stochastic import deep_sea_stochastic 20 | from dm_env import test_utils 21 | 22 | import numpy as np 23 | 24 | 25 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 26 | 27 | def make_object_under_test(self): 28 | return deep_sea_stochastic.load(22) 29 | 30 | def make_action_sequence(self): 31 | valid_actions = [0, 1] 32 | rng = np.random.RandomState(42) 33 | 34 | for _ in range(100): 35 | yield rng.choice(valid_actions) 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /bsuite/experiments/deep_sea_stochastic/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Sweep definition for deep_sea_stochastic.""" 17 | 18 | from bsuite.experiments.deep_sea import sweep as deep_sea_sweep 19 | 20 | NUM_EPISODES = deep_sea_sweep.NUM_EPISODES 21 | 22 | SETTINGS = tuple({'size': n, 'mapping_seed': 42} for n in range(10, 51, 2)) 23 | TAGS = ('exploration', 'noise') 24 | -------------------------------------------------------------------------------- /bsuite/experiments/discounting_chain/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/discounting_chain/discounting_chain.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic discounting challenge. 17 | 18 | Observation is two pixels: (context, time_to_live) 19 | 20 | Context will only be -1 in the first step, then equal to the action selected in 21 | the first step. For all future decisions the agent is in a "chain" for that 22 | action. Reward of +1 comes at one of: 1, 3, 10, 30, 100 23 | 24 | However, depending on the seed, one of these chains has a 10% bonus. 25 | """ 26 | 27 | from bsuite.environments import discounting_chain 28 | 29 | load = discounting_chain.DiscountingChain 30 | -------------------------------------------------------------------------------- /bsuite/experiments/discounting_chain/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ============================================================================ 16 | """Sweep definition for discounting_chain experiment.""" 17 | 18 | NUM_EPISODES = 1000 19 | 20 | SETTINGS = tuple({'mapping_seed': n} for n in range(20)) 21 | TAGS = ('credit_assignment',) 22 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_len/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_len/memory_len.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic memory challenge. 17 | 18 | Observation is given by n+1 pixels: (context, time_to_live). 19 | 20 | Context will only be nonzero in the first step, when it will be +1 or -1 iid 21 | by component. All actions take no effect until time_to_live=0, then the agent 22 | must repeat the observations that it saw bit-by-bit. 23 | """ 24 | 25 | from typing import Optional 26 | 27 | from bsuite.environments import memory_chain 28 | from bsuite.experiments.memory_len import sweep 29 | 30 | 31 | def load(memory_length: int, seed: Optional[int] = 0): 32 | """Memory Chain environment, with variable delay between cue and decision.""" 33 | env = memory_chain.MemoryChain( 34 | memory_length=memory_length, 35 | num_bits=1, 36 | seed=seed, 37 | ) 38 | env.bsuite_num_episodes = sweep.NUM_EPISODES 39 | return env 40 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_len/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Sweep definition for memory_len experiment.""" 17 | 18 | NUM_EPISODES = 10000 19 | 20 | _log_spaced = [] 21 | _log_spaced.extend(range(1, 11)) 22 | _log_spaced.extend([12, 14, 17, 20, 25]) 23 | _log_spaced.extend(range(30, 105, 10)) 24 | 25 | SETTINGS = tuple({'memory_length': n} for n in _log_spaced) 26 | TAGS = ('memory',) 27 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_size/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_size/analysis.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Analysis for memory_size experiment.""" 17 | 18 | from typing import Optional, Sequence 19 | 20 | from bsuite.experiments.memory_len import analysis as memory_len_analysis 21 | from bsuite.experiments.memory_size import sweep 22 | import pandas as pd 23 | import plotnine as gg 24 | 25 | NUM_EPISODES = sweep.NUM_EPISODES 26 | TAGS = sweep.TAGS 27 | 28 | 29 | def score(df: pd.DataFrame) -> float: 30 | return memory_len_analysis.score(df, group_col='num_bits') 31 | 32 | 33 | def plot_learning(df: pd.DataFrame, 34 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 35 | return memory_len_analysis.plot_learning(df, sweep_vars, 'num_bits') 36 | 37 | 38 | def plot_scale(df: pd.DataFrame, 39 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 40 | return memory_len_analysis.plot_scale(df, sweep_vars, 'num_bits') 41 | 42 | 43 | def plot_seeds(df: pd.DataFrame, 44 | sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot: 45 | return memory_len_analysis.plot_seeds( 46 | df_in=df[df.episode > 100], 47 | sweep_vars=sweep_vars, 48 | colour_var='num_bits', 49 | ) 50 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_size/memory_size.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic memory challenge. 17 | 18 | Observation is given by n+1 pixels: (context, time_to_live). 19 | 20 | Context will only be nonzero in the first step, when it will be +1 or -1 iid 21 | by component. All actions take no effect until time_to_live=0, then the agent 22 | must repeat the observations that it saw bit-by-bit. 23 | """ 24 | 25 | from typing import Optional 26 | 27 | from bsuite.environments import memory_chain 28 | from bsuite.experiments.memory_size import sweep 29 | 30 | 31 | def load(num_bits: int, seed: Optional[int] = 0): 32 | """Memory Chain environment, with variable number of bits.""" 33 | env = memory_chain.MemoryChain( 34 | memory_length=2, 35 | num_bits=num_bits, 36 | seed=seed, 37 | ) 38 | env.bsuite_num_episodes = sweep.NUM_EPISODES 39 | return env 40 | 41 | -------------------------------------------------------------------------------- /bsuite/experiments/memory_size/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Sweep definition for memory_size experiment.""" 17 | 18 | from bsuite.experiments.memory_len import sweep as memory_len_sweep 19 | 20 | NUM_EPISODES = memory_len_sweep.NUM_EPISODES 21 | 22 | _log_spaced = [] 23 | _log_spaced.extend(range(1, 11)) 24 | _log_spaced.extend([12, 14, 17, 20, 25]) 25 | _log_spaced.extend(range(30, 50, 10)) 26 | 27 | SETTINGS = tuple({'num_bits': n} for n in _log_spaced) 28 | TAGS = ('memory',) 29 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist/mnist.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """MNIST classification as a bandit. 17 | 18 | In this environment, we test the agent's generalization ability, and abstract 19 | away exploration/planning/memory etc -- i.e. a bandit, with no 'state'. 20 | """ 21 | 22 | from bsuite.environments import mnist 23 | 24 | load = mnist.MNISTBandit 25 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Sweep definition for mnist bandit experiment.""" 17 | 18 | NUM_EPISODES = 10000 19 | 20 | SETTINGS = tuple({'seed': None} for n in range(20)) 21 | TAGS = ('basic', 'generalization') 22 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist_noise/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist_noise/mnist_noise.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """MNIST bandit with noisy rewards.""" 17 | 18 | from bsuite.environments import mnist 19 | from bsuite.experiments.mnist_noise import sweep 20 | from bsuite.utils import wrappers 21 | 22 | 23 | def load(noise_scale, seed): 24 | """Load a mnist_noise experiment with the prescribed settings.""" 25 | env = wrappers.RewardNoise( 26 | env=mnist.MNISTBandit(seed=seed), 27 | noise_scale=noise_scale, 28 | seed=seed) 29 | env.bsuite_num_episodes = sweep.NUM_EPISODES 30 | return env 31 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist_noise/mnist_noise_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.experiments.mnist_noise.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.experiments.mnist_noise import mnist_noise 20 | from dm_env import test_utils 21 | 22 | import numpy as np 23 | 24 | 25 | class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 26 | 27 | def make_object_under_test(self): 28 | return mnist_noise.load(noise_scale=2.0, seed=101) 29 | 30 | def make_action_sequence(self): 31 | num_actions = self.environment.action_spec().num_values 32 | rng = np.random.RandomState(42) 33 | 34 | for _ in range(100): 35 | yield rng.randint(num_actions) 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /bsuite/experiments/mnist_noise/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
NUM_EPISODES = mnist_sweep.NUM_EPISODES

# Values of `noise_scale` covered by the sweep.
_NOISE_SCALES = (0.1, 0.3, 1.0, 3., 10.)
# Number of repeated settings generated for each noise scale.
_RUNS_PER_SCALE = 4

# One independent dict per (scale, repeat), ordered scale-major. Every entry
# carries 'seed': None; a comprehension is used so that no loop variables
# (in particular a misleading module-level `seed`) are left behind in the
# module namespace.
SETTINGS = tuple(
    {'noise_scale': scale, 'seed': None}
    for scale in _NOISE_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('noise', 'generalization')
NUM_EPISODES = sweep.NUM_EPISODES
TAGS = sweep.TAGS

# Name of the swept setting; forwarded to every delegated analysis helper.
_SCALING_VAR = 'reward_scale'


def score(df: pd.DataFrame) -> float:
  """Scores via `mnist_noise_analysis.score`, keyed on the reward scale."""
  return mnist_noise_analysis.score(df, scaling_var=_SCALING_VAR)


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegates to `mnist_noise_analysis.plot_learning` for reward scale."""
  return mnist_noise_analysis.plot_learning(df, sweep_vars, _SCALING_VAR)


def plot_average(df: pd.DataFrame,
                 sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegates to `mnist_noise_analysis.plot_average` for reward scale."""
  return mnist_noise_analysis.plot_average(df, sweep_vars, _SCALING_VAR)


def plot_seeds(df: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plots performance per individual work unit, coloured by reward scale."""
  plot_kwargs = dict(
      df_in=df,
      sweep_vars=sweep_vars,
      colour_var=_SCALING_VAR,
  )
  return mnist_analysis.plot_seeds(**plot_kwargs)
def load(reward_scale, seed):
  """Builds the MNIST bandit with scaled rewards for one sweep setting.

  Args:
    reward_scale: Passed to `wrappers.RewardScale` as `reward_scale`.
    seed: Passed to both `mnist.MNISTBandit` and `wrappers.RewardScale`.

  Returns:
    The `wrappers.RewardScale` environment, with `bsuite_num_episodes` set to
    `sweep.NUM_EPISODES`.
  """
  bandit = mnist.MNISTBandit(seed=seed)
  scaled_bandit = wrappers.RewardScale(
      env=bandit, reward_scale=reward_scale, seed=seed)
  # Record the experiment length on the returned environment object.
  scaled_bandit.bsuite_num_episodes = sweep.NUM_EPISODES
  return scaled_bandit
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Runs the dm_env `EnvironmentTestMixin` against the mnist_scale loader."""

  def make_object_under_test(self):
    """Returns the environment the mixin should exercise."""
    env = mnist_scale.load(reward_scale=2.0, seed=101)
    return env

  def make_action_sequence(self):
    """Yields 100 pseudo-random actions drawn with a fixed RNG seed (42)."""
    spec = self.environment.action_spec()
    sampler = np.random.RandomState(42)

    for _ in range(100):
      # randint(k) draws an integer in [0, k).
      yield sampler.randint(spec.num_values)
NUM_EPISODES = mnist_sweep.NUM_EPISODES

# Values of `reward_scale` covered by the sweep.
_REWARD_SCALES = (0.001, 0.03, 1.0, 30., 1000.)
# Number of repeated settings generated for each reward scale.
_RUNS_PER_SCALE = 4

# One independent dict per (scale, repeat), ordered scale-major. Every entry
# carries 'seed': None; a comprehension is used so that no loop variables
# (in particular a misleading module-level `seed`) are left behind in the
# module namespace.
SETTINGS = tuple(
    {'reward_scale': scale, 'seed': None}
    for scale in _REWARD_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('scale', 'generalization')
# `load` is a direct alias of the environment class: calling `load(...)`
# constructs a `mountain_car.MountainCar` with whatever arguments are given.
# NOTE(review): unlike the wrapper-based loaders, nothing here assigns
# `bsuite_num_episodes` - presumably the environment sets it itself; confirm.
load = mountain_car.MountainCar
NUM_EPISODES = 1000

# Number of repeated settings in the sweep.
_NUM_RUNS = 20

# Each run gets its own {'seed': None} dict (distinct objects, equal content).
SETTINGS = tuple({'seed': None} for _ in range(_NUM_RUNS))
TAGS = ('basic', 'generalization')
def load(noise_scale, seed):
  """Builds the mountain car environment with noisy rewards.

  Args:
    noise_scale: Passed to `wrappers.RewardNoise` as `noise_scale`.
    seed: Passed to both `mountain_car.MountainCar` and
      `wrappers.RewardNoise`.

  Returns:
    The `wrappers.RewardNoise` environment, with `bsuite_num_episodes` set to
    `sweep.NUM_EPISODES`.
  """
  car = mountain_car.MountainCar(seed=seed)
  noisy_car = wrappers.RewardNoise(
      env=car, noise_scale=noise_scale, seed=seed)
  # Record the experiment length on the returned environment object.
  noisy_car.bsuite_num_episodes = sweep.NUM_EPISODES
  return noisy_car
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Runs the dm_env `EnvironmentTestMixin` on the mountain_car_noise loader."""

  def make_object_under_test(self):
    """Returns the environment the mixin should exercise."""
    return mountain_car_noise.load(noise_scale=1., seed=22)

  def make_action_sequence(self):
    """Yields 100 actions sampled from {0, 1, 2} with a fixed RNG seed (42)."""
    choices = [0, 1, 2]
    sampler = np.random.RandomState(42)

    for _ in range(100):
      yield sampler.choice(choices)
NUM_EPISODES = mountain_car_sweep.NUM_EPISODES

# Values of `noise_scale` covered by the sweep.
_NOISE_SCALES = (0.1, 0.3, 1.0, 3., 10.)
# Number of repeated settings generated for each noise scale.
_RUNS_PER_SCALE = 4

# One independent dict per (scale, repeat), ordered scale-major. Every entry
# carries 'seed': None; a comprehension is used so that no loop variables
# (in particular a misleading module-level `seed`) are left behind in the
# module namespace.
SETTINGS = tuple(
    {'noise_scale': scale, 'seed': None}
    for scale in _NOISE_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('noise', 'generalization')
15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/mountain_car_scale/analysis.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
NUM_EPISODES = sweep.NUM_EPISODES
TAGS = sweep.TAGS

# Name of the swept setting; forwarded to every delegated analysis helper.
_SCALING_VAR = 'reward_scale'


def score(df: pd.DataFrame) -> float:
  """Scores via `mc_noise_analysis.score`, keyed on the reward scale."""
  return mc_noise_analysis.score(df, scaling_var=_SCALING_VAR)


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegates to `mc_noise_analysis.plot_learning` for reward scale."""
  return mc_noise_analysis.plot_learning(df, sweep_vars, _SCALING_VAR)


def plot_average(df: pd.DataFrame,
                 sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Delegates to `mc_noise_analysis.plot_average` for reward scale."""
  return mc_noise_analysis.plot_average(df, sweep_vars, _SCALING_VAR)


def plot_seeds(df: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plots performance per individual work unit, coloured by reward scale."""
  base_plot = mc_analysis.plot_seeds(
      df_in=df,
      sweep_vars=sweep_vars,
      colour_var=_SCALING_VAR,
  )
  # Relabel the y-axis on top of the plot returned by the shared helper.
  y_label = gg.ylab('average episodic return (after rescaling)')
  return base_plot + y_label
def load(reward_scale: float, seed: int):
  """Builds the mountain car environment with scaled rewards.

  Args:
    reward_scale: Passed to `wrappers.RewardScale` as `reward_scale`.
    seed: Passed to both `mountain_car.MountainCar` and
      `wrappers.RewardScale`.

  Returns:
    The `wrappers.RewardScale` environment, with `bsuite_num_episodes` set to
    `sweep.NUM_EPISODES`.
  """
  car = mountain_car.MountainCar(seed=seed)
  scaled_car = wrappers.RewardScale(
      env=car, reward_scale=reward_scale, seed=seed)
  # NOTE(review): this module binds `sweep` to the mountain_car_noise sweep,
  # not the mountain_car_scale one. Both re-export mountain_car's
  # NUM_EPISODES, so the value matches - confirm the import is intentional.
  scaled_car.bsuite_num_episodes = sweep.NUM_EPISODES
  return scaled_car
class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
  """Runs the dm_env `EnvironmentTestMixin` on the mountain_car_scale loader."""

  def make_object_under_test(self):
    """Returns the environment the mixin should exercise."""
    return mountain_car_scale.load(reward_scale=10., seed=22)

  def make_action_sequence(self):
    """Yields 100 actions sampled from {0, 1, 2} with a fixed RNG seed (42)."""
    choices = [0, 1, 2]
    sampler = np.random.RandomState(42)

    for _ in range(100):
      yield sampler.choice(choices)
NUM_EPISODES = mountain_car_sweep.NUM_EPISODES

# Values of `reward_scale` covered by the sweep.
_REWARD_SCALES = (0.001, 0.03, 1.0, 30., 1000.)
# Number of repeated settings generated for each reward scale.
_RUNS_PER_SCALE = 4

# One independent dict per (scale, repeat), ordered scale-major. Every entry
# carries 'seed': None; a comprehension is used so that no loop variables
# (in particular a misleading module-level `seed`) are left behind in the
# module namespace.
SETTINGS = tuple(
    {'reward_scale': scale, 'seed': None}
    for scale in _REWARD_SCALES
    for _ in range(_RUNS_PER_SCALE)
)
TAGS = ('scale', 'generalization')
class SummaryAnalysisTest(absltest.TestCase):
  """Smoke test for the constants exposed by `summary_analysis`."""

  def test_constants(self):
    """Checks that `BSUITE_INFO` is not empty."""
    bsuite_info = summary_analysis.BSUITE_INFO
    self.assertNotEmpty(bsuite_info)
NUM_EPISODES = sweep.NUM_EPISODES
TAGS = sweep.TAGS

# Name of the swept setting used to group results in every helper below.
_GROUP_COL = 'n_distractor'


def score(df: pd.DataFrame) -> float:
  """Scores via `umbrella_length_analysis.score_by_group` on n_distractor."""
  return umbrella_length_analysis.score_by_group(df, _GROUP_COL)


def plot_learning(df: pd.DataFrame,
                  sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plots the average regret through time."""
  plot_kwargs = dict(
      df_in=df,
      group_col=_GROUP_COL,
      sweep_vars=sweep_vars,
      max_episode=sweep.NUM_EPISODES,
  )
  return plotting.plot_regret_group_nosmooth(**plot_kwargs)


def plot_scale(df: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plots the average return at end of learning investigating scaling."""
  plot_kwargs = dict(
      df_in=df,
      group_col=_GROUP_COL,
      episode=sweep.NUM_EPISODES,
      regret_thresh=0.5,
      sweep_vars=sweep_vars,
  )
  return plotting.plot_regret_ave_scaling(**plot_kwargs)


def plot_seeds(df_in: pd.DataFrame,
               sweep_vars: Optional[Sequence[str]] = None) -> gg.ggplot:
  """Plots the returns through time individually by run."""
  return umbrella_length_analysis.plot_seeds(df_in, sweep_vars, _GROUP_COL)
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Umbrella chain environment with varying distractor observations.""" 17 | 18 | from bsuite.environments import umbrella_chain 19 | from bsuite.experiments.umbrella_distract import sweep 20 | 21 | 22 | def load(n_distractor: int, seed=0): 23 | """Load an umbrella_distract experiment with the prescribed settings.""" 24 | env = umbrella_chain.UmbrellaChain( 25 | chain_length=20, 26 | n_distractor=n_distractor, 27 | seed=seed, 28 | ) 29 | env.bsuite_num_episodes = sweep.NUM_EPISODES 30 | return env 31 | -------------------------------------------------------------------------------- /bsuite/experiments/umbrella_length/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/experiments/umbrella_length/sweep.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Sweep definition for umbrella_length experiment.""" 17 | 18 | NUM_EPISODES = 10000 19 | 20 | _log_spaced = [] 21 | _log_spaced.extend(range(1, 11)) 22 | _log_spaced.extend([12, 14, 17, 20, 25]) 23 | _log_spaced.extend(range(30, 105, 10)) 24 | 25 | 26 | SETTINGS = tuple({'chain_length': n, 'n_distractor': 20} for n in _log_spaced) 27 | TAGS = ('credit_assignment', 'noise') 28 | -------------------------------------------------------------------------------- /bsuite/experiments/umbrella_length/umbrella_length.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Simple diagnostic credit assignment challenge. 17 | 18 | Observation is 3 + n_distractor pixels: 19 | (need_umbrella, have_umbrella, time_to_live, n x distractors) 20 | 21 | Only the first action takes any effect (pick up umbrella or not). 22 | All other actions take no effect and the reward is +1, -1 on the final step. 23 | Distractor states are always Bernoulli sampled iid each step. 24 | """ 25 | 26 | from bsuite.environments import umbrella_chain 27 | 28 | load = umbrella_chain.UmbrellaChain 29 | -------------------------------------------------------------------------------- /bsuite/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/logging/base.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """An abstract base class for loggers.""" 17 | 18 | import abc 19 | from typing import Any, Mapping 20 | 21 | 22 | class Logger(abc.ABC): 23 | """A logger has a `write` method.""" 24 | 25 | @abc.abstractmethod 26 | def write(self, data: Mapping[str, Any]): 27 | """Writes `data` to destination (file, terminal, database, etc).""" 28 | -------------------------------------------------------------------------------- /bsuite/logging/csv_load.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Read functionality for local csv-based experiments."""

import glob
import os
from typing import List, Tuple

from bsuite import sweep
from bsuite.logging import csv_logging
from bsuite.logging import logging_utils
import pandas as pd


def load_one_result_set(results_dir: str) -> pd.DataFrame:
  """Returns a pandas DataFrame of bsuite results stored in results_dir.

  Every `*.csv` file directly inside `results_dir` whose name starts with
  `csv_logging.BSUITE_PREFIX` is read. The bsuite_id is recovered from the
  file name and stored, together with `results_dir`, as extra columns.

  Args:
    results_dir: Directory containing the csv files written by the csv logger.

  Returns:
    The concatenated per-file results, passed through
    `logging_utils.join_metadata`.

  Raises:
    ValueError: If `results_dir` contains no bsuite csv files.
  """
  data = []
  for file_path in glob.glob(os.path.join(results_dir, '*.csv')):
    name = os.path.basename(file_path)
    # Rough and ready error-checking for only bsuite csv files.
    if not name.startswith(csv_logging.BSUITE_PREFIX):
      print('Warning - we recommend you use a fresh folder for bsuite results.')
      continue

    # Then we will assume that the file is actually a bsuite file
    df = pd.read_csv(file_path)
    # Drop the extension with `splitext`: `str.strip('.csv')` removes any run of
    # the characters '.', 'c', 's', 'v' from BOTH ends of the name, which can
    # silently eat part of the id.
    stem, _ = os.path.splitext(name)
    file_bsuite_id = stem.split(csv_logging.INITIAL_SEPARATOR)[1]
    bsuite_id = file_bsuite_id.replace(csv_logging.SAFE_SEPARATOR,
                                       sweep.SEPARATOR)
    df['bsuite_id'] = bsuite_id
    df['results_dir'] = results_dir
    data.append(df)

  if not data:
    # `pd.concat([])` would raise a ValueError anyway; keep the exception type
    # but say which directory was empty.
    raise ValueError(f'No bsuite csv results found in {results_dir!r}.')

  df = pd.concat(data, sort=False)
  return logging_utils.join_metadata(df)


def load_bsuite(
    results_dirs: logging_utils.PathCollection
) -> Tuple[pd.DataFrame, List[str]]:
  """Returns a pandas DataFrame of bsuite results.

  Delegates to `logging_utils.load_multiple_runs`, using
  `load_one_result_set` to load each entry of `results_dirs`.
  """
  return logging_utils.load_multiple_runs(
      path_collection=results_dirs,
      single_load_fn=load_one_result_set,
  )

# --------------------------------------------------------------------------------
# /bsuite/logging/csv_load_test.py:
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | # ============================================================================ 16 | """Tests for bsuite.utils.csv_load.""" 17 | 18 | import random 19 | import sys 20 | 21 | from absl import flags 22 | from absl.testing import absltest 23 | from bsuite.logging import csv_load 24 | from bsuite.logging import csv_logging 25 | 26 | FLAGS = flags.FLAGS 27 | _NUM_WRITES = 10 28 | 29 | 30 | def generate_results(bsuite_id, results_dir): 31 | logger = csv_logging.Logger(bsuite_id, results_dir) 32 | steps_per_episode = 7 33 | total_return = 0.0 34 | for i in range(_NUM_WRITES): 35 | episode_return = random.random() 36 | total_return += episode_return 37 | data = dict( 38 | steps=i * steps_per_episode, 39 | episode=i, 40 | total_return=total_return, 41 | episode_len=steps_per_episode, 42 | episode_return=episode_return, 43 | extra=42, 44 | ) 45 | logger.write(data) 46 | 47 | 48 | class CsvLoadTest(absltest.TestCase): 49 | 50 | def test_logger(self): 51 | try: 52 | flags.FLAGS.test_tmpdir 53 | except flags.UnparsedFlagAccessError: 54 | # Need to initialize flags when running `pytest`. 55 | flags.FLAGS(sys.argv) 56 | results_dir = self.create_tempdir().full_path 57 | generate_results(bsuite_id='catch/0', results_dir=results_dir) 58 | generate_results(bsuite_id='catch/1', results_dir=results_dir) 59 | 60 | df = csv_load.load_one_result_set(results_dir=results_dir) 61 | self.assertLen(df, _NUM_WRITES * 2) 62 | 63 | # Check that sweep metadata is joined correctly. 64 | # Catch includes a 'seed' parameter, so we expect to see it here. 65 | self.assertIn('seed', df.columns) 66 | self.assertIn('bsuite_id', df.columns) 67 | 68 | 69 | if __name__ == '__main__': 70 | absltest.main() 71 | -------------------------------------------------------------------------------- /bsuite/logging/sqlite_load.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. 
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Read functionality for local SQLite-based experiments."""

from typing import List, Optional, Tuple

from bsuite import sweep
from bsuite.logging import logging_utils
import pandas as pd
import sqlite3


def load_one_result_set(
    db_path: str,
    connection: Optional[sqlite3.Connection] = None) -> pd.DataFrame:
  """Returns a pandas DataFrame of bsuite results.

  Each table in the database is read in full, and a `bsuite_id` column is
  built from the table name and the table's `setting_index` column.

  Args:
    db_path: Path to the database file.
    connection: Optional connection, for testing purposes. If supplied,
      `db_path` will be ignored and the connection is left open for the caller.

  Returns:
    A pandas DataFrame containing bsuite results.

  Raises:
    ValueError: If the database contains no tables.
  """
  # Only close what we opened: a caller-supplied connection stays usable.
  owns_connection = connection is None
  if owns_connection:
    connection = sqlite3.connect(db_path)

  try:
    # Get a list of all table names in this database.
    query = 'select name from sqlite_master where type=\'table\';'
    # `with connection` scopes a transaction; it does not close the connection.
    with connection:
      table_names = [row[0] for row in connection.execute(query).fetchall()]

    dataframes = []
    for table_name in table_names:
      # Quote the identifier so unusual table names cannot break the statement.
      quoted_name = '"' + table_name.replace('"', '""') + '"'
      dataframe = pd.read_sql_query('select * from ' + quoted_name, connection)
      dataframe['bsuite_id'] = [
          table_name + sweep.SEPARATOR + str(setting_index)
          for setting_index in dataframe.setting_index]
      dataframes.append(dataframe)
  finally:
    if owns_connection:
      connection.close()

  if not dataframes:
    # `pd.concat([])` would raise a ValueError anyway; keep the exception type
    # but make the cause explicit.
    raise ValueError(
        f'No bsuite result tables found in SQLite database {db_path!r}.')

  df = pd.concat(dataframes, sort=False)
  return logging_utils.join_metadata(df)


def load_bsuite(
    results_dirs: logging_utils.PathCollection
) -> Tuple[pd.DataFrame, List[str]]:
  """Returns a pandas DataFrame of bsuite results.

  Delegates to `logging_utils.load_multiple_runs`, using
  `load_one_result_set` to load each entry of `results_dirs`.
  """
  return logging_utils.load_multiple_runs(
      path_collection=results_dirs,
      single_load_fn=load_one_result_set,
  )

# --------------------------------------------------------------------------------
# /bsuite/logging/sqlite_load_test.py:
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | # ============================================================================ 16 | """Tests for bsuite.utils.sqlite_load.""" 17 | 18 | import random 19 | 20 | from absl.testing import absltest 21 | from bsuite.logging import sqlite_load 22 | from bsuite.logging import sqlite_logging 23 | 24 | import sqlite3 25 | 26 | _NUM_WRITES = 10 27 | 28 | 29 | def generate_results(experiment_name, setting_index, connection): 30 | logger = sqlite_logging.Logger(db_path='unused', 31 | experiment_name=experiment_name, 32 | setting_index=setting_index, 33 | connection=connection) 34 | 35 | steps_per_episode = 7 36 | 37 | total_return = 0.0 38 | 39 | for i in range(_NUM_WRITES): 40 | episode_return = random.random() 41 | total_return += episode_return 42 | 43 | data = dict( 44 | steps=i * steps_per_episode, 45 | episode=i, 46 | total_return=total_return, 47 | episode_len=steps_per_episode, 48 | episode_return=episode_return, 49 | extra=42, 50 | ) 51 | logger.write(data) 52 | 53 | 54 | class SqliteLoadTest(absltest.TestCase): 55 | 56 | def test_logger(self): 57 | connection = sqlite3.connect(':memory:') 58 | 59 | generate_results( 60 | experiment_name='catch', setting_index=1, connection=connection) 61 | generate_results( 62 | experiment_name='catch', setting_index=2, connection=connection) 63 | 64 | df = sqlite_load.load_one_result_set(db_path='unused', 65 | connection=connection) 66 | self.assertLen(df, _NUM_WRITES * 2) 67 | 68 | # Check that sweep metadata is joined correctly. 69 | # Catch includes a 'seed' parameter, so we expect to see it here. 
70 | self.assertIn('seed', df.columns) 71 | self.assertIn('bsuite_id', df.columns) 72 | 73 | 74 | if __name__ == '__main__': 75 | absltest.main() 76 | -------------------------------------------------------------------------------- /bsuite/logging/terminal_logging.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """A simple logger that pretty-prints to terminal.""" 17 | 18 | import logging as std_logging 19 | import numbers 20 | from typing import Any, Mapping 21 | 22 | from absl import logging 23 | from bsuite import environments 24 | from bsuite.logging import base 25 | from bsuite.utils import wrappers 26 | import dm_env 27 | 28 | 29 | def wrap_environment(env: environments.Environment, 30 | pretty_print: bool = True, 31 | log_every: bool = False, 32 | log_by_step: bool = False) -> dm_env.Environment: 33 | """Returns a wrapped environment that logs to terminal.""" 34 | # Set logging up to show up in STDERR. 
35 | std_logging.getLogger().addHandler(logging.PythonHandler()) 36 | logger = Logger(pretty_print, absl_logging=True) 37 | return wrappers.Logging( 38 | env, logger, log_by_step=log_by_step, log_every=log_every) 39 | 40 | 41 | class Logger(base.Logger): 42 | """Writes data to terminal.""" 43 | 44 | def __init__(self, pretty_print: bool = True, absl_logging: bool = False): 45 | self._pretty_print = pretty_print 46 | self._print_fn = logging.info if absl_logging else print 47 | 48 | def write(self, data: Mapping[str, Any]): 49 | """Writes to terminal, pretty-printing the results.""" 50 | 51 | if self._pretty_print: 52 | data = pretty_dict(data) 53 | 54 | self._print_fn(data) 55 | 56 | 57 | def pretty_dict(data: Mapping[str, Any]) -> str: 58 | """Prettifies a dictionary into a string as `k1 = v1 | ... | kn = vn`.""" 59 | msg = [] 60 | for key in sorted(data): 61 | value = value_format(data[key]) 62 | msg_pair = f'{key} = {value}' 63 | msg.append(msg_pair) 64 | 65 | return ' | '.join(msg) 66 | 67 | 68 | def value_format(value: Any) -> str: 69 | """Convenience function for string formatting.""" 70 | if isinstance(value, numbers.Integral): 71 | return str(value) 72 | if isinstance(value, numbers.Number): 73 | return f'{value:0.4f}' 74 | return str(value) 75 | -------------------------------------------------------------------------------- /bsuite/sweep_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for sweep.py."""

from absl.testing import absltest

from bsuite import sweep
from bsuite.experiments.bandit import sweep as bandit_sweep


class SweepTest(absltest.TestCase):
  """Simple tests for sweeps."""

  def test_sweep_contents(self):
    """Checks that all sweeps have sensible contents."""

    test_bsuite_id = 'bandit/0'
    test_bsuite_id_1 = 'bandit/1'

    # Check `test_bsuite_id` is in BANDIT, SWEEP, and TESTING sweeps.
    self.assertIn(test_bsuite_id, sweep.BANDIT)
    self.assertIn(test_bsuite_id, sweep.SWEEP)
    self.assertIn(test_bsuite_id, sweep.TESTING)

    # `test_bsuite_id_1` should *not* be included in the testing sweep.
    self.assertNotIn(test_bsuite_id_1, sweep.TESTING)

    # Check all settings present in sweep.
    self.assertLen(sweep.BANDIT, len(bandit_sweep.SETTINGS))

    # Check `test_bsuite_id` is found in the 'basic' TAG section.
    self.assertIn(test_bsuite_id, sweep.TAGS['basic'])

  def test_sweep_immutable(self):
    """Checks that all exposed sweeps are immutable."""

    # Each assignment gets its own `assertRaises` context. A single shared
    # context exits on the first exception, so the later assignments would
    # never run and their containers would go unchecked.
    # pytype: disable=attribute-error
    # pytype: disable=unsupported-operands
    with self.assertRaises(TypeError):
      sweep.BANDIT[0] = 'new_bsuite_id'
    with self.assertRaises(TypeError):
      sweep.SWEEP[0] = 'new_bsuite_id'
    with self.assertRaises(TypeError):
      sweep.TESTING[0] = 'new_bsuite_id'
    with self.assertRaises(TypeError):
      sweep.TAGS['new_tag'] = 42
    # pytype: enable=unsupported-operands
    # pytype: enable=attribute-error

if __name__ == '__main__':
  absltest.main()

# --------------------------------------------------------------------------------
# /bsuite/tests/__init__.py:
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# --------------------------------------------------------------------------------
# /bsuite/tests/environments_test.py:
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Tests that we can load all settings in sweep.py with bsuite.load.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from bsuite import bsuite 22 | from bsuite import sweep 23 | 24 | 25 | def _reduced_names_and_kwargs(): 26 | """Returns a subset of sweep.SETTINGS that covers all environment types.""" 27 | result = [] 28 | 29 | last_name = None 30 | last_keywords = None 31 | 32 | for bsuite_id, kwargs in sweep.SETTINGS.items(): 33 | name = bsuite_id.split(sweep.SEPARATOR)[0] 34 | keywords = set(kwargs) 35 | if name != last_name or keywords != last_keywords: 36 | if 'mnist' not in name: 37 | result.append((name, kwargs)) 38 | last_name = name 39 | last_keywords = keywords 40 | return result 41 | 42 | 43 | class EnvironmentsTest(parameterized.TestCase): 44 | 45 | @parameterized.parameters(*_reduced_names_and_kwargs()) 46 | def test_environment(self, name, settings): 47 | env = bsuite.load(name, settings) 48 | self.assertGreater(env.action_spec().num_values, 0) 49 | self.assertGreater(env.bsuite_num_episodes, 0) 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /bsuite/tests/sweep_test.py: 
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for bsuite.sweep."""

from absl.testing import absltest
from absl.testing import parameterized
from bsuite import sweep


class SweepTest(parameterized.TestCase):
  """Sanity checks on the contents and id format of `bsuite.sweep`."""

  def test_access_sweep(self):
    """The global settings mapping is populated."""
    self.assertNotEmpty(sweep.SETTINGS)

  def test_access_experiment_constants(self):
    """Per-experiment constants are exposed on the sweep module."""
    self.assertNotEmpty(sweep.DEEP_SEA)

  @parameterized.parameters(*sweep.SETTINGS)
  def test_sweep_name_format(self, bsuite_id):
    """Every bsuite_id has two non-empty parts joined by the separator."""
    self.assertIn(sweep.SEPARATOR, bsuite_id)
    split = bsuite_id.split(sweep.SEPARATOR)
    # `assertTrue(len(split), 2)` treats `2` as the failure message and passes
    # for any non-empty list; assert the length itself.
    self.assertLen(split, 2)
    self.assertNotEmpty(split[0])
    self.assertNotEmpty(split[1])

if __name__ == '__main__':
  absltest.main()

# --------------------------------------------------------------------------------
# /bsuite/utils/__init__.py:
# --------------------------------------------------------------------------------
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | -------------------------------------------------------------------------------- /bsuite/utils/gym_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for bsuite.utils.gym_wrapper.""" 17 | 18 | from absl.testing import absltest 19 | from bsuite.utils import gym_wrapper 20 | 21 | from dm_env import specs 22 | import gym 23 | import numpy as np 24 | 25 | 26 | class DMEnvFromGymTest(absltest.TestCase): 27 | 28 | def test_gym_cartpole(self): 29 | env = gym_wrapper.DMEnvFromGym(gym.make('CartPole-v0')) 30 | 31 | # Test converted observation spec. 32 | observation_spec = env.observation_spec() 33 | self.assertEqual(type(observation_spec), specs.BoundedArray) 34 | self.assertEqual(observation_spec.shape, (4,)) 35 | self.assertEqual(observation_spec.minimum.shape, (4,)) 36 | self.assertEqual(observation_spec.maximum.shape, (4,)) 37 | self.assertEqual(observation_spec.dtype, np.dtype('float32')) 38 | 39 | # Test converted action spec. 40 | action_spec = env.action_spec() 41 | self.assertEqual(type(action_spec), specs.DiscreteArray) 42 | self.assertEqual(action_spec.shape, ()) 43 | self.assertEqual(action_spec.minimum, 0) 44 | self.assertEqual(action_spec.maximum, 1) 45 | self.assertEqual(action_spec.num_values, 2) 46 | self.assertEqual(action_spec.dtype, np.dtype('int64')) 47 | 48 | # Test step. 49 | timestep = env.reset() 50 | self.assertTrue(timestep.first()) 51 | timestep = env.step(1) 52 | self.assertEqual(timestep.reward, 1.0) 53 | self.assertEqual(timestep.observation.shape, (4,)) 54 | env.close() 55 | 56 | def test_episode_truncation(self): 57 | # Pendulum has no early termination condition. 
58 | gym_env = gym.make('Pendulum-v0') 59 | env = gym_wrapper.DMEnvFromGym(gym_env) 60 | ts = env.reset() 61 | while not ts.last(): 62 | ts = env.step(env.action_spec().generate_value()) 63 | self.assertEqual(ts.discount, 1.0) 64 | env.close() 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /reports/README.md: -------------------------------------------------------------------------------- 1 | # bsuite experiment report 2 | 3 | This is a folder that maintains the latex template for auto-generating a bsuite report appendix, suitable for conference submission. 4 | To do this: 5 | - Fill in bsuite_preamble.tex with relevant links to colab/plots. 6 | - Write description of agents/algorithms in bsuite_appendix.tex together with some commentary on results. 7 | - use \input{} or copy/paste the bsuite_preamble.tex before your \begin{document} and bsuite_appendix.tex inside your document. 8 | - You can find examples of using the bsuite appendix with ICLR, ICML, NeurIPS templates as well as standalone pdf generation in the subfolders here. 
9 | 10 | -------------------------------------------------------------------------------- /reports/iclr_2019/iclr2019_conference.bst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/iclr_2019/iclr2019_conference.bst -------------------------------------------------------------------------------- /reports/iclr_2019/images/bar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/iclr_2019/images/bar_plot.png -------------------------------------------------------------------------------- /reports/iclr_2019/images/radar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/iclr_2019/images/radar_plot.png -------------------------------------------------------------------------------- /reports/icml_2019/algorithm.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHM STYLE -- Released 8 April 1996 2 | % for LaTeX-2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail Peter.Williams@dsto.defence.gov.au 5 | \NeedsTeXFormat{LaTeX2e} 6 | \ProvidesPackage{algorithm} 7 | \typeout{Document Style `algorithm' - floating environment} 8 | 9 | \RequirePackage{float} 10 | \RequirePackage{ifthen} 11 | \newcommand{\ALG@within}{nothing} 12 | \newboolean{ALG@within} 13 | \setboolean{ALG@within}{false} 14 | \newcommand{\ALG@floatstyle}{ruled} 15 | \newcommand{\ALG@name}{Algorithm} 16 | \newcommand{\listalgorithmname}{List of \ALG@name s} 17 | 18 | % Declare Options 19 | % first appearance 20 | \DeclareOption{plain}{ 21 | \renewcommand{\ALG@floatstyle}{plain} 22 | } 23 | \DeclareOption{ruled}{ 24 | 
\renewcommand{\ALG@floatstyle}{ruled} 25 | } 26 | \DeclareOption{boxed}{ 27 | \renewcommand{\ALG@floatstyle}{boxed} 28 | } 29 | % then numbering convention 30 | \DeclareOption{part}{ 31 | \renewcommand{\ALG@within}{part} 32 | \setboolean{ALG@within}{true} 33 | } 34 | \DeclareOption{chapter}{ 35 | \renewcommand{\ALG@within}{chapter} 36 | \setboolean{ALG@within}{true} 37 | } 38 | \DeclareOption{section}{ 39 | \renewcommand{\ALG@within}{section} 40 | \setboolean{ALG@within}{true} 41 | } 42 | \DeclareOption{subsection}{ 43 | \renewcommand{\ALG@within}{subsection} 44 | \setboolean{ALG@within}{true} 45 | } 46 | \DeclareOption{subsubsection}{ 47 | \renewcommand{\ALG@within}{subsubsection} 48 | \setboolean{ALG@within}{true} 49 | } 50 | \DeclareOption{nothing}{ 51 | \renewcommand{\ALG@within}{nothing} 52 | \setboolean{ALG@within}{true} 53 | } 54 | \DeclareOption*{\edef\ALG@name{\CurrentOption}} 55 | 56 | % ALGORITHM 57 | % 58 | \ProcessOptions 59 | \floatstyle{\ALG@floatstyle} 60 | \ifthenelse{\boolean{ALG@within}}{ 61 | \ifthenelse{\equal{\ALG@within}{part}} 62 | {\newfloat{algorithm}{htbp}{loa}[part]}{} 63 | \ifthenelse{\equal{\ALG@within}{chapter}} 64 | {\newfloat{algorithm}{htbp}{loa}[chapter]}{} 65 | \ifthenelse{\equal{\ALG@within}{section}} 66 | {\newfloat{algorithm}{htbp}{loa}[section]}{} 67 | \ifthenelse{\equal{\ALG@within}{subsection}} 68 | {\newfloat{algorithm}{htbp}{loa}[subsection]}{} 69 | \ifthenelse{\equal{\ALG@within}{subsubsection}} 70 | {\newfloat{algorithm}{htbp}{loa}[subsubsection]}{} 71 | \ifthenelse{\equal{\ALG@within}{nothing}} 72 | {\newfloat{algorithm}{htbp}{loa}}{} 73 | }{ 74 | \newfloat{algorithm}{htbp}{loa} 75 | } 76 | \floatname{algorithm}{\ALG@name} 77 | 78 | \newcommand{\listofalgorithms}{\listof{algorithm}{\listalgorithmname}} 79 | 80 | -------------------------------------------------------------------------------- /reports/icml_2019/icml2019.bst: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/icml_2019/icml2019.bst -------------------------------------------------------------------------------- /reports/icml_2019/images/bar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/icml_2019/images/bar_plot.png -------------------------------------------------------------------------------- /reports/icml_2019/images/radar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/icml_2019/images/radar_plot.png -------------------------------------------------------------------------------- /reports/neurips_2019/images/bar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/neurips_2019/images/bar_plot.png -------------------------------------------------------------------------------- /reports/neurips_2019/images/radar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/neurips_2019/images/radar_plot.png -------------------------------------------------------------------------------- /reports/neurips_2019/references.bib: -------------------------------------------------------------------------------- 1 | %% References for bsuite report 2 | 3 | @article{osband2019bsuite, 4 | title={Core {RL} Behaviour Suite}, 5 | author={Osband, Ian and Doron, Yotam and Hessel, Matteo and Aslanides, John and Van Hasselt, Hado and Sezener, Eren and Saraiva, Andre and Lattimore, Tor and Szepesvari, Csaba and 
Singh, Satinder and Van Roy, Benjamin and Sutton, Richard and Silver, David}, 6 | year={2019}, 7 | } 8 | 9 | @inproceedings{osband2016deep, 10 | Author = {Osband, Ian and Blundell, Charles and Pritzel, Alexander and Van Roy, Benjamin}, 11 | Booktitle = {Advances In Neural Information Processing Systems 29}, 12 | Pages = {4026--4034}, 13 | Title = {Deep exploration via bootstrapped {DQN}}, 14 | Year = {2016}} 15 | 16 | @incollection{osband2018rpf, 17 | title = {Randomized Prior Functions for Deep Reinforcement Learning}, 18 | author = {Osband, Ian and Aslanides, John and Cassirer, Albin}, 19 | booktitle = {Advances in Neural Information Processing Systems 31}, 20 | editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett}, 21 | pages = {8617--8629}, 22 | year = {2018}, 23 | publisher = {Curran Associates, Inc.}, 24 | url = {http://papers.nips.cc/paper/8080-randomized-prior-functions-for-deep-reinforcement-learning.pdf} 25 | } 26 | 27 | @article{mnih2015human, 28 | Author = {Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Rusu, Andrei A and Veness, Joel and Bellemare, Marc G and Graves, Alex and Riedmiller, Martin and Fidjeland, Andreas K and Ostrovski, Georg and others}, 29 | Date-Added = {2018-05-18 14:55:54 +0000}, 30 | Date-Modified = {2018-05-18 14:55:54 +0000}, 31 | Journal = {Nature}, 32 | Number = {7540}, 33 | Pages = {529--533}, 34 | Publisher = {Nature Research}, 35 | Title = {Human-level control through deep reinforcement learning}, 36 | Volume = {518}, 37 | Year = {2015}} 38 | 39 | @inproceedings{mnih2016asynchronous, 40 | Author = {Mnih, Volodymyr and Badia, Adria Puigdomenech and Mirza, Mehdi and Graves, Alex and Lillicrap, Timothy and Harley, Tim and Silver, David and Kavukcuoglu, Koray}, 41 | Booktitle = {Proc. 
of ICML}, 42 | Date-Added = {2018-05-18 14:55:54 +0000}, 43 | Date-Modified = {2018-05-18 14:55:54 +0000}, 44 | Title = {Asynchronous methods for deep reinforcement learning}, 45 | Year = {2016}} 46 | -------------------------------------------------------------------------------- /reports/standalone/images/bar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/standalone/images/bar_plot.png -------------------------------------------------------------------------------- /reports/standalone/images/radar_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/bsuite/6d8f64997ca256473c3d10be021431facc5a14d7/reports/standalone/images/radar_plot.png -------------------------------------------------------------------------------- /reports/standalone/references.bib: -------------------------------------------------------------------------------- 1 | %% References for bsuite report 2 | 3 | @article{osband2019bsuite, 4 | title={Core {RL} Behaviour Suite}, 5 | author={Osband, Ian and Doron, Yotam and Hessel, Matteo and Aslanides, John and Van Hasselt, Hado and Sezener, Eren and Saraiva, Andre and Lattimore, Tor and Szepesvari, Csaba and Singh, Satinder and Van Roy, Benjamin and Sutton, Richard and Silver, David}, 6 | year={2019}, 7 | } 8 | 9 | @inproceedings{osband2016deep, 10 | Author = {Osband, Ian and Blundell, Charles and Pritzel, Alexander and Van Roy, Benjamin}, 11 | Booktitle = {Advances In Neural Information Processing Systems 29}, 12 | Pages = {4026--4034}, 13 | Title = {Deep exploration via bootstrapped {DQN}}, 14 | Year = {2016}} 15 | 16 | @incollection{osband2018rpf, 17 | title = {Randomized Prior Functions for Deep Reinforcement Learning}, 18 | author = {Osband, Ian and Aslanides, John and Cassirer, Albin}, 19 | booktitle = 
{Advances in Neural Information Processing Systems 31}, 20 | editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett}, 21 | pages = {8617--8629}, 22 | year = {2018}, 23 | publisher = {Curran Associates, Inc.}, 24 | url = {http://papers.nips.cc/paper/8080-randomized-prior-functions-for-deep-reinforcement-learning.pdf} 25 | } 26 | 27 | @article{mnih2015human, 28 | Author = {Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Rusu, Andrei A and Veness, Joel and Bellemare, Marc G and Graves, Alex and Riedmiller, Martin and Fidjeland, Andreas K and Ostrovski, Georg and others}, 29 | Date-Added = {2018-05-18 14:55:54 +0000}, 30 | Date-Modified = {2018-05-18 14:55:54 +0000}, 31 | Journal = {Nature}, 32 | Number = {7540}, 33 | Pages = {529--533}, 34 | Publisher = {Nature Research}, 35 | Title = {Human-level control through deep reinforcement learning}, 36 | Volume = {518}, 37 | Year = {2015}} 38 | 39 | @inproceedings{mnih2016asynchronous, 40 | Author = {Mnih, Volodymyr and Badia, Adria Puigdomenech and Mirza, Mehdi and Graves, Alex and Lillicrap, Timothy and Harley, Tim and Silver, David and Kavukcuoglu, Koray}, 41 | Booktitle = {Proc. 
of ICML}, 42 | Date-Added = {2018-05-18 14:55:54 +0000}, 43 | Date-Modified = {2018-05-18 14:55:54 +0000}, 44 | Title = {Asynchronous methods for deep reinforcement learning}, 45 | Year = {2016}} 46 | -------------------------------------------------------------------------------- /reports/standalone/standalone.tex: -------------------------------------------------------------------------------- 1 | \documentclass[10pt]{article} 2 | 3 | 4 | \input{../bsuite_preamble} 5 | 6 | 7 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% FORMATTING OPTIONS 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | \usepackage{xcolor} 10 | \usepackage[margin=2cm]{geometry} 11 | \usepackage{titlesec} 12 | \definecolor{darkblue}{RGB}{0,0,140} 13 | \usepackage[colorlinks=true, allcolors=darkblue]{hyperref} 14 | 15 | \titlespacing{\section}{0pt}{1ex}{0ex} 16 | \titlespacing*{\subsection}{0pt}{1ex}{0ex} 17 | \setlength{\parskip}{6pt}% 18 | \setlength{\parindent}{0pt}% 19 | 20 | 21 | 22 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% DOCUMENT START 23 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 24 | 25 | 26 | 27 | \begin{document} 28 | 29 | \appendix 30 | \input{../bsuite_appendix} 31 | 32 | 33 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% BIBLIOGRAPHY 34 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 35 | { 36 | \small 37 | \bibliographystyle{plain} 38 | \bibliography{references} 39 | } 40 | 41 | 42 | \end{document} 43 | -------------------------------------------------------------------------------- /run_on_gcp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Follow the instructions in README.md to setup Cloud SDK first. 
3 | 4 | # gcloud settings below 5 | export IMAGE_FAMILY="tf-1-13-cpu" 6 | export ZONE="us-west1-b" 7 | export INSTANCE_NAME="bsuite$RANDOM" 8 | export MACHINE_TYPE="n1-highcpu-64" # or n1-highcpu-8 for debugging etc 9 | 10 | # run settings below 11 | export SCRIPT_TO_RUN="~/bsuite/bsuite/baselines/tf/dqn/run.py" # path on the VM; quoted so ~ is expanded by the remote shell 12 | export BSUITE_ENV="SWEEP" 13 | 14 | SECONDS=0 # reset bash's elapsed-time counter 15 | 16 | set -e 17 | 18 | gcloud compute instances create $INSTANCE_NAME \ 19 | --zone=$ZONE \ 20 | --image-family=$IMAGE_FAMILY \ 21 | --image-project=deeplearning-platform-release \ 22 | --machine-type=$MACHINE_TYPE 23 | 24 | until gcloud compute ssh $INSTANCE_NAME --command "git clone \ 25 | https://github.com/deepmind/bsuite.git" --zone $ZONE &> /dev/null 26 | do 27 | echo "Waiting for the instance to be initiated." 28 | sleep 10 29 | done 30 | 31 | gcloud compute ssh $INSTANCE_NAME --command "sudo apt-get install -y python3-pip" --zone $ZONE 32 | gcloud compute ssh $INSTANCE_NAME --command "sudo pip3 install virtualenv" --zone $ZONE 33 | gcloud compute ssh $INSTANCE_NAME --command "virtualenv -p /usr/bin/python3 bsuite_env" --zone $ZONE 34 | gcloud compute ssh $INSTANCE_NAME --command "source ~/bsuite_env/bin/activate \ 35 | && pip3 install ~/bsuite/[baselines]" --zone $ZONE 36 | 37 | gcloud compute ssh $INSTANCE_NAME --command "nohup bash -c 'source \ 38 | ~/bsuite_env/bin/activate && python3 $SCRIPT_TO_RUN \ 39 | --bsuite_id=$BSUITE_ENV --logging_mode=sqlite > /dev/null 2>&1 \ 40 | && touch /tmp/bsuite_completed.txt > /dev/null 2>&1' 1>/dev/null \ 41 | 2>/dev/null &" --zone $ZONE 42 | 43 | 44 | until gcloud compute ssh $INSTANCE_NAME --command "cat /tmp/bsuite_completed.txt" \ 45 | --zone $ZONE &> /dev/null 46 | do 47 | echo "Waiting for jobs to be completed." 48 | sleep 60 49 | done 50 | 51 | gcloud compute scp --recurse $INSTANCE_NAME:/tmp/bsuite.db /tmp/bsuite.db --zone $ZONE 52 | 53 | echo "Experiments completed!" 
54 | 55 | gcloud compute instances stop $INSTANCE_NAME \ 56 | --zone=$ZONE 57 | 58 | echo "The experiment took $SECONDS seconds." 59 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Fail on any error. 4 | set -e 5 | # Display commands being run. 6 | set -x 7 | 8 | # Set up a new virtual environment. 9 | python3 -m venv bsuite_testing 10 | source bsuite_testing/bin/activate 11 | 12 | # Install all dependencies. 13 | pip install --upgrade pip setuptools 14 | pip install . 15 | pip install .[baselines_jax] 16 | pip install .[baselines] 17 | 18 | # Install test dependencies. 19 | pip install .[testing] 20 | 21 | N_CPU=$(grep -c ^processor /proc/cpuinfo) 22 | 23 | # Run static type-checking. 24 | pytype -j "${N_CPU}" bsuite 25 | 26 | # Run all tests. 27 | pytest -n "${N_CPU}" bsuite 28 | 29 | # Clean-up. 30 | deactivate 31 | rm -rf bsuite_testing/ 32 | --------------------------------------------------------------------------------