├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .pylintrc ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── acme ├── __init__.py ├── _metadata.py ├── adders │ ├── __init__.py │ ├── base.py │ ├── reverb │ │ ├── __init__.py │ │ ├── base.py │ │ ├── episode.py │ │ ├── episode_test.py │ │ ├── sequence.py │ │ ├── sequence_test.py │ │ ├── structured.py │ │ ├── structured_test.py │ │ ├── test_cases.py │ │ ├── test_utils.py │ │ ├── transition.py │ │ ├── transition_test.py │ │ └── utils.py │ └── wrappers.py ├── agents │ ├── __init__.py │ ├── agent.py │ ├── jax │ │ ├── __init__.py │ │ ├── actor_core.py │ │ ├── actors.py │ │ ├── actors_test.py │ │ ├── ail │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── builder_test.py │ │ │ ├── config.py │ │ │ ├── dac.py │ │ │ ├── gail.py │ │ │ ├── learning.py │ │ │ ├── learning_test.py │ │ │ ├── losses.py │ │ │ ├── losses_test.py │ │ │ ├── networks.py │ │ │ └── rewards.py │ │ ├── ars │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── bc │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── agent_test.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ ├── pretraining.py │ │ │ └── pretraining_test.py │ │ ├── builders.py │ │ ├── bve │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── losses.py │ │ │ └── networks.py │ │ ├── cql │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── agent_test.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── crr │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── agent_test.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── losses.py │ │ │ └── networks.py │ │ ├── d4pg │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── dqn │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── actor.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── learning_lib.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ └── rainbow.py │ │ ├── impala │ │ │ ├── __init__.py │ │ │ ├── acting.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── networks.py │ │ │ └── types.py │ │ ├── lfd │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── lfd_adder.py │ │ │ ├── lfd_adder_test.py │ │ │ ├── sacfd.py │ │ │ └── td3fd.py │ │ ├── mbop │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acting.py │ │ │ ├── agent_test.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── dataset.py │ │ │ ├── dataset_test.py │ │ │ ├── ensemble.py │ │ │ ├── ensemble_test.py │ │ │ ├── learning.py │ │ │ ├── losses.py │ │ │ ├── models.py │ │ │ ├── mppi.py │ │ │ ├── mppi_test.py │ │ │ └── networks.py │ │ ├── mpo │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acting.py │ │ │ ├── builder.py │ │ │ ├── categorical_mpo.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── networks.py │ │ │ ├── rollout_loss.py │ │ │ ├── types.py │ │ │ └── utils.py │ │ ├── multiagent │ │ │ ├── __init__.py │ │ │ └── decentralized │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── actor.py │ │ │ │ ├── builder.py │ │ │ │ ├── config.py │ │ │ │ ├── factories.py │ │ │ │ └── learner_set.py │ │ ├── normalization.py │ │ ├── ppo │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ 
│ │ ├── networks.py │ │ │ └── normalization.py │ │ ├── pwil │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── adder.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ └── rewarder.py │ │ ├── r2d2 │ │ │ ├── __init__.py │ │ │ ├── actor.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── rnd │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── sac │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── sqil │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ └── builder_test.py │ │ ├── td3 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ ├── value_dice │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ └── networks.py │ │ └── wpo │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acting.py │ │ │ ├── builder.py │ │ │ ├── config.py │ │ │ ├── learning.py │ │ │ ├── networks.py │ │ │ ├── types.py │ │ │ └── utils.py │ ├── replay.py │ └── tf │ │ ├── __init__.py │ │ ├── actors.py │ │ ├── actors_test.py │ │ ├── bc │ │ ├── README.md │ │ ├── __init__.py │ │ └── learning.py │ │ ├── bcq │ │ ├── README.md │ │ ├── __init__.py │ │ ├── discrete_learning.py │ │ └── discrete_learning_test.py │ │ ├── crr │ │ ├── __init__.py │ │ └── recurrent_learning.py │ │ ├── d4pg │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ ├── learning.py │ │ └── networks.py │ │ ├── ddpg │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── dmpo │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── dqfd │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_test.py │ │ └── bsuite_demonstrations.py │ │ ├── dqn │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── impala │ │ ├── README.md │ │ ├── __init__.py │ │ ├── acting.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── iqn │ │ ├── README.md │ │ ├── __init__.py │ │ ├── learning.py │ │ └── learning_test.py │ │ ├── mcts │ │ ├── README.md │ │ ├── __init__.py │ │ ├── acting.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_test.py │ │ ├── learning.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── mlp.py │ │ │ ├── simulator.py │ │ │ └── simulator_test.py │ │ ├── search.py │ │ ├── search_test.py │ │ └── types.py │ │ ├── mog_mpo │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent_distributed.py │ │ ├── learning.py │ │ └── networks.py │ │ ├── mompo │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── mpo │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── r2d2 │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── 
agent_distributed_test.py │ │ ├── agent_test.py │ │ └── learning.py │ │ ├── r2d3 │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agent.py │ │ └── agent_test.py │ │ └── svg0_prior │ │ ├── README.md │ │ ├── __init__.py │ │ ├── acting.py │ │ ├── agent.py │ │ ├── agent_distributed.py │ │ ├── agent_distributed_test.py │ │ ├── agent_test.py │ │ ├── learning.py │ │ ├── networks.py │ │ └── utils.py ├── core.py ├── core_test.py ├── datasets │ ├── __init__.py │ ├── image_augmentation.py │ ├── numpy_iterator.py │ ├── numpy_iterator_test.py │ ├── reverb.py │ ├── reverb_benchmark.py │ └── tfds.py ├── environment_loop.py ├── environment_loop_test.py ├── environment_loops │ ├── __init__.py │ ├── open_spiel_environment_loop.py │ └── open_spiel_environment_loop_test.py ├── jax │ ├── __init__.py │ ├── experiments │ │ ├── __init__.py │ │ ├── config.py │ │ ├── make_distributed_experiment.py │ │ ├── make_distributed_offline_experiment.py │ │ ├── run_experiment.py │ │ ├── run_experiment_test.py │ │ ├── run_offline_experiment.py │ │ ├── run_offline_experiment_test.py │ │ └── test_utils.py │ ├── imitation_learning_types.py │ ├── inference_server.py │ ├── losses │ │ ├── __init__.py │ │ ├── impala.py │ │ ├── impala_test.py │ │ ├── mpo.py │ │ └── wpo.py │ ├── networks │ │ ├── __init__.py │ │ ├── atari.py │ │ ├── base.py │ │ ├── continuous.py │ │ ├── distributional.py │ │ ├── duelling.py │ │ ├── embedding.py │ │ ├── multiplexers.py │ │ ├── policy_value.py │ │ ├── rescaling.py │ │ └── resnet.py │ ├── observation_stacking.py │ ├── running_statistics.py │ ├── running_statistics_test.py │ ├── savers.py │ ├── savers_test.py │ ├── snapshotter.py │ ├── snapshotter_test.py │ ├── types.py │ ├── utils.py │ ├── utils_test.py │ ├── variable_utils.py │ └── variable_utils_test.py ├── multiagent │ ├── __init__.py │ ├── types.py │ ├── utils.py │ └── utils_test.py ├── specs.py ├── testing │ ├── __init__.py │ ├── fakes.py │ ├── multiagent_fakes.py │ └── test_utils.py ├── tf │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── distributional.py │ │ ├── distributional_test.py │ │ ├── dpg.py │ │ ├── huber.py │ │ ├── mompo.py │ │ ├── mpo.py │ │ ├── quantile.py │ │ └── r2d2.py │ ├── networks │ │ ├── __init__.py │ │ ├── atari.py │ │ ├── base.py │ │ ├── continuous.py │ │ ├── discrete.py │ │ ├── distributional.py │ │ ├── distributional_test.py │ │ ├── distributions.py │ │ ├── distributions_test.py │ │ ├── duelling.py │ │ ├── embedding.py │ │ ├── legal_actions.py │ │ ├── masked_epsilon_greedy.py │ │ ├── multihead.py │ │ ├── multiplexers.py │ │ ├── noise.py │ │ ├── policy_value.py │ │ ├── quantile.py │ │ ├── recurrence.py │ │ ├── recurrence_test.py │ │ ├── rescaling.py │ │ ├── stochastic.py │ │ └── vision.py │ ├── savers.py │ ├── savers_test.py │ ├── utils.py │ ├── utils_test.py │ ├── variable_utils.py │ └── variable_utils_test.py ├── types.py ├── utils │ ├── __init__.py │ ├── async_utils.py │ ├── counting.py │ ├── counting_test.py │ ├── experiment_utils.py │ ├── frozen_learner.py │ ├── frozen_learner_test.py │ ├── iterator_utils.py │ ├── iterator_utils_test.py │ ├── loggers │ │ ├── __init__.py │ │ ├── aggregators.py │ │ ├── asynchronous.py │ │ ├── auto_close.py │ │ ├── base.py │ │ ├── base_test.py │ │ ├── constant.py │ │ ├── csv.py │ │ ├── csv_test.py │ │ ├── dataframe.py │ │ ├── default.py │ │ ├── filters.py │ │ ├── filters_test.py │ │ ├── flatten.py │ │ ├── image.py │ │ ├── image_test.py │ │ ├── terminal.py │ │ ├── terminal_test.py │ │ ├── tf_summary.py │ │ └── timestamp.py │ ├── lp_utils.py │ ├── lp_utils_test.py │ ├── metrics.py │ ├── 
observers │ │ ├── __init__.py │ │ ├── action_metrics.py │ │ ├── action_metrics_test.py │ │ ├── action_norm.py │ │ ├── action_norm_test.py │ │ ├── base.py │ │ ├── env_info.py │ │ ├── env_info_test.py │ │ ├── measurement_metrics.py │ │ └── measurement_metrics_test.py │ ├── paths.py │ ├── paths_test.py │ ├── reverb_utils.py │ ├── reverb_utils_test.py │ ├── signals.py │ ├── tree_utils.py │ └── tree_utils_test.py └── wrappers │ ├── __init__.py │ ├── action_repeat.py │ ├── atari_wrapper.py │ ├── atari_wrapper_dopamine.py │ ├── atari_wrapper_test.py │ ├── base.py │ ├── base_test.py │ ├── canonical_spec.py │ ├── concatenate_observations.py │ ├── delayed_reward.py │ ├── delayed_reward_test.py │ ├── expand_scalar_observation_shapes.py │ ├── frame_stacking.py │ ├── frame_stacking_test.py │ ├── gym_wrapper.py │ ├── gym_wrapper_test.py │ ├── mujoco.py │ ├── multiagent_dict_key_wrapper.py │ ├── multigrid_wrapper.py │ ├── noop_starts.py │ ├── noop_starts_test.py │ ├── observation_action_reward.py │ ├── open_spiel_wrapper.py │ ├── open_spiel_wrapper_test.py │ ├── single_precision.py │ ├── single_precision_test.py │ ├── step_limit.py │ ├── step_limit_test.py │ └── video.py ├── docs ├── _static │ └── custom.css ├── conf.py ├── faq.md ├── imgs │ ├── acme-notext.png │ ├── acme.png │ └── configure-and-run-experiments.png ├── index.rst ├── requirements.txt └── user │ ├── agents.md │ ├── components.md │ ├── diagrams │ ├── actor_loop.png │ ├── agent_loop.png │ ├── batch_loop.png │ ├── distributed_loop.png │ └── environment_loop.png │ ├── logos │ ├── jax-small.png │ └── tf-small.png │ └── overview.md ├── examples ├── README.md ├── baselines │ ├── README.md │ ├── imitation │ │ ├── README.md │ │ ├── helpers.py │ │ ├── run_bc.py │ │ ├── run_gail.py │ │ ├── run_iqlearn.py │ │ ├── run_pwil.py │ │ └── run_sqil.py │ ├── offline_rl │ │ └── README.md │ ├── rl_continuous │ │ ├── README.md │ │ ├── helpers.py │ │ ├── run_d4pg.py │ │ ├── run_dmpo.py │ │ ├── run_mogmpo.py │ │ ├── run_mpo.py │ │ ├── run_ppo.py │ │ ├── run_sac.py │ │ ├── run_td3.py │ │ └── run_wpo.py │ ├── rl_discrete │ │ ├── README.md │ │ ├── helpers.py │ │ ├── run_dqn.py │ │ ├── run_impala.py │ │ ├── run_mdqn.py │ │ ├── run_muzero.py │ │ ├── run_qr_dqn.py │ │ └── run_r2d2.py │ └── rlfd │ │ └── README.md ├── bsuite │ ├── run_dqn.py │ ├── run_impala.py │ └── run_mcts.py ├── multiagent │ └── multigrid │ │ ├── helpers.py │ │ └── run_multigrid.py ├── offline │ ├── bc_utils.py │ ├── run_bc.py │ ├── run_bc_jax.py │ ├── run_bcq.py │ ├── run_cql_jax.py │ ├── run_crr_jax.py │ ├── run_dqfd.py │ ├── run_mbop_jax.py │ └── run_offline_td3_jax.py ├── open_spiel │ └── run_dqn.py ├── quickstart.ipynb ├── tf │ └── control_suite │ │ ├── helpers.py │ │ ├── lp_d4pg.py │ │ ├── lp_ddpg.py │ │ ├── lp_dmpo.py │ │ ├── lp_dmpo_pixels.py │ │ ├── lp_dmpo_pixels_drqv2.py │ │ └── lp_mpo.py └── tutorial.ipynb ├── setup.py └── test.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: acme-tests 2 | on: [push, pull_request, workflow_dispatch] 3 | 4 | jobs: 5 | run-acme-tests: 6 | runs-on: ubuntu-latest 7 | strategy: 8 | matrix: 9 | docker-image: ["python:3.8", "python:3.9"] 10 | steps: 11 | - name: Checkout acme 12 | uses: actions/checkout@v2 13 | - name: Run tests in docker 14 | run: | 15 | docker run --mount "type=bind,src=$(pwd),dst=/tmp/acme" \ 16 | -w "/tmp/acme" --rm ${{ matrix.docker-image }} /bin/bash test.sh 17 | 18 | -------------------------------------------------------------------------------- 
/.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | pip install --upgrade pip setuptools twine 19 | - name: Build and publish 20 | env: 21 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 22 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 23 | run: | 24 | python setup.py sdist 25 | twine upload dist/* 26 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration. 2 | version: 2 3 | sphinx: 4 | configuration: docs/conf.py 5 | python: 6 | install: 7 | - requirements: docs/requirements.txt 8 | 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /acme/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Acme is a framework for reinforcement learning.""" 16 | 17 | # Internal import. 
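# A minimal usage sketch of the core pieces exposed below (illustrative only;
# it assumes an `actor` implementing acme.core.Actor and uses one of the fake
# environments from acme.testing.fakes):
#
#   import acme
#   from acme.testing import fakes
#
#   environment = fakes.DiscreteEnvironment(num_actions=3, num_observations=5)
#   environment_spec = acme.make_environment_spec(environment)
#   loop = acme.EnvironmentLoop(environment, actor)
#   loop.run(num_episodes=10)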
18 | 19 | # Expose specs and types modules. 20 | from acme import specs 21 | from acme import types 22 | 23 | # Make __version__ accessible. 24 | from acme._metadata import __version__ 25 | 26 | # Expose core interfaces. 27 | from acme.core import Actor 28 | from acme.core import Learner 29 | from acme.core import Saveable 30 | from acme.core import VariableSource 31 | from acme.core import Worker 32 | 33 | # Expose the environment loop. 34 | from acme.environment_loop import EnvironmentLoop 35 | 36 | from acme.specs import make_environment_spec 37 | 38 | -------------------------------------------------------------------------------- /acme/_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Package metadata for acme. 16 | 17 | This is kept in a separate module so that it can be imported from setup.py, at 18 | a time when acme's dependencies may not have been installed yet. 19 | """ 20 | 21 | # We follow Semantic Versioning (https://semver.org/) 22 | _MAJOR_VERSION = '0' 23 | _MINOR_VERSION = '4' 24 | _PATCH_VERSION = '1' 25 | 26 | # Example: '0.4.2' 27 | __version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) 28 | -------------------------------------------------------------------------------- /acme/adders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Adders for sending data from actors to replay buffers.""" 16 | 17 | # pylint: disable=unused-import 18 | 19 | from acme.adders.base import Adder 20 | from acme.adders.wrappers import ForkingAdder 21 | from acme.adders.wrappers import IgnoreExtrasAdder 22 | -------------------------------------------------------------------------------- /acme/adders/reverb/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Adders for Reverb replay buffers.""" 16 | 17 | # pylint: disable=unused-import 18 | 19 | from acme.adders.reverb.base import DEFAULT_PRIORITY_TABLE 20 | from acme.adders.reverb.base import PriorityFn 21 | from acme.adders.reverb.base import PriorityFnInput 22 | from acme.adders.reverb.base import ReverbAdder 23 | from acme.adders.reverb.base import Step 24 | from acme.adders.reverb.base import Trajectory 25 | 26 | from acme.adders.reverb.episode import EpisodeAdder 27 | from acme.adders.reverb.sequence import EndBehavior 28 | from acme.adders.reverb.sequence import SequenceAdder 29 | from acme.adders.reverb.structured import create_n_step_transition_config 30 | from acme.adders.reverb.structured import create_step_spec 31 | from acme.adders.reverb.structured import n_step_from_trajectory 32 | from acme.adders.reverb.structured import StructuredAdder 33 | from acme.adders.reverb.transition import NStepTransitionAdder 34 | -------------------------------------------------------------------------------- /acme/adders/reverb/transition_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for NStepTransition adders.""" 16 | 17 | from acme.adders.reverb import test_cases 18 | from acme.adders.reverb import test_utils 19 | from acme.adders.reverb import transition as adders 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | 25 | class NStepTransitionAdderTest(test_utils.AdderTestMixin, 26 | parameterized.TestCase): 27 | 28 | @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) 29 | def test_adder(self, n_step, additional_discount, first, steps, 30 | expected_transitions): 31 | adder = adders.NStepTransitionAdder(self.client, n_step, 32 | additional_discount) 33 | super().run_test_adder( 34 | adder=adder, 35 | first=first, 36 | steps=steps, 37 | expected_items=expected_transitions, 38 | stack_sequence_fields=False, 39 | signature=adder.signature(*test_utils.get_specs(steps[0]))) 40 | 41 | 42 | if __name__ == '__main__': 43 | absltest.main() 44 | -------------------------------------------------------------------------------- /acme/adders/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A library of useful adder wrappers.""" 16 | 17 | from typing import Iterable 18 | 19 | from acme import types 20 | from acme.adders import base 21 | import dm_env 22 | 23 | 24 | class ForkingAdder(base.Adder): 25 | """An adder that forks data into several other adders.""" 26 | 27 | def __init__(self, adders: Iterable[base.Adder]): 28 | self._adders = adders 29 | 30 | def reset(self): 31 | for adder in self._adders: 32 | adder.reset() 33 | 34 | def add_first(self, timestep: dm_env.TimeStep): 35 | for adder in self._adders: 36 | adder.add_first(timestep) 37 | 38 | def add(self, 39 | action: types.NestedArray, 40 | next_timestep: dm_env.TimeStep, 41 | extras: types.NestedArray = ()): 42 | for adder in self._adders: 43 | adder.add(action, next_timestep, extras) 44 | 45 | 46 | class IgnoreExtrasAdder(base.Adder): 47 | """An adder that ignores extras.""" 48 | 49 | def __init__(self, adder: base.Adder): 50 | self._adder = adder 51 | 52 | def reset(self): 53 | self._adder.reset() 54 | 55 | def add_first(self, timestep: dm_env.TimeStep): 56 | self._adder.add_first(timestep) 57 | 58 | def add(self, 59 | action: types.NestedArray, 60 | next_timestep: dm_env.TimeStep, 61 | extras: types.NestedArray = ()): 62 | self._adder.add(action, next_timestep) 63 | -------------------------------------------------------------------------------- /acme/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Agent implementations.""" 16 | -------------------------------------------------------------------------------- /acme/agents/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """JAX agents.""" 16 | -------------------------------------------------------------------------------- /acme/agents/jax/ail/README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Imitation Learning (AIL) 2 | 3 | This folder contains a modular implementation of an Adversarial 4 | Imitation Learning agent. 5 | The initial algorithm is Generative Adversarial Imitation Learning 6 | (GAIL - [Ho et al., 2016]), but many more tricks and variations are 7 | available. 8 | The corresponding paper ([Orsini et al., 2021]) explains and discusses 9 | the utility of all those tricks. 10 | 11 | AIL requires an off-policy RL algorithm to work, passed in as an 12 | `ActorLearnerBuilder`. 13 | 14 | If you use this code, please cite [Orsini et al., 2021]. 15 | 16 | [Ho et al., 2016]: https://arxiv.org/abs/1606.03476 17 | [Orsini et al., 2021]: https://arxiv.org/abs/2106.00672 18 | -------------------------------------------------------------------------------- /acme/agents/jax/ail/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a AIL agent.""" 16 | 17 | from acme.agents.jax.ail import losses 18 | from acme.agents.jax.ail import rewards 19 | from acme.agents.jax.ail.builder import AILBuilder 20 | from acme.agents.jax.ail.config import AILConfig 21 | from acme.agents.jax.ail.dac import DACBuilder 22 | from acme.agents.jax.ail.dac import DACConfig 23 | from acme.agents.jax.ail.gail import GAILBuilder 24 | from acme.agents.jax.ail.gail import GAILConfig 25 | from acme.agents.jax.ail.learning import AILLearner 26 | from acme.agents.jax.ail.networks import AILNetworks 27 | from acme.agents.jax.ail.networks import AIRLModule 28 | from acme.agents.jax.ail.networks import compute_ail_reward 29 | from acme.agents.jax.ail.networks import DiscriminatorMLP 30 | from acme.agents.jax.ail.networks import DiscriminatorModule 31 | from acme.agents.jax.ail.networks import make_discriminator 32 | -------------------------------------------------------------------------------- /acme/agents/jax/ail/gail.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Builder for GAIL. 16 | 17 | https://arxiv.org/pdf/1606.03476.pdf 18 | """ 19 | 20 | import dataclasses 21 | from typing import Callable, Iterator 22 | 23 | from acme import types 24 | from acme.agents.jax import actor_core as actor_core_lib 25 | from acme.agents.jax import ppo 26 | from acme.agents.jax.ail import builder 27 | from acme.agents.jax.ail import config as ail_config 28 | from acme.agents.jax.ail import losses 29 | 30 | 31 | @dataclasses.dataclass 32 | class GAILConfig: 33 | """Configuration options specific to GAIL.""" 34 | ail_config: ail_config.AILConfig 35 | ppo_config: ppo.PPOConfig 36 | 37 | 38 | class GAILBuilder(builder.AILBuilder[ppo.PPONetworks, 39 | actor_core_lib.FeedForwardPolicyWithExtra] 40 | ): 41 | """GAIL Builder.""" 42 | 43 | def __init__(self, config: GAILConfig, 44 | make_demonstrations: Callable[[int], 45 | Iterator[types.Transition]]): 46 | 47 | ppo_builder = ppo.PPOBuilder(config.ppo_config) 48 | super().__init__( 49 | ppo_builder, 50 | config=config.ail_config, 51 | discriminator_loss=losses.gail_loss(), 52 | make_demonstrations=make_demonstrations) 53 | -------------------------------------------------------------------------------- /acme/agents/jax/ars/README.md: -------------------------------------------------------------------------------- 1 | # Augmented Random Search (ARS) 2 | 3 | This folder contains an implementation of the ARS algorithm 4 | ([Mania et al., 2018]). 5 | 6 | 7 | [Mania et al., 2018]: https://arxiv.org/pdf/1803.07055.pdf 8 | -------------------------------------------------------------------------------- /acme/agents/jax/ars/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ARS agent.""" 16 | 17 | from acme.agents.jax.ars.builder import ARSBuilder 18 | from acme.agents.jax.ars.config import ARSConfig 19 | from acme.agents.jax.ars.networks import make_networks 20 | from acme.agents.jax.ars.networks import make_policy_network 21 | -------------------------------------------------------------------------------- /acme/agents/jax/ars/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ARS config.""" 16 | import dataclasses 17 | 18 | from acme.adders import reverb as adders_reverb 19 | 20 | 21 | @dataclasses.dataclass 22 | class ARSConfig: 23 | """Configuration options for ARS.""" 24 | num_steps: int = 1000000 25 | normalize_observations: bool = True 26 | step_size: float = 0.015 27 | num_directions: int = 60 28 | exploration_noise_std: float = 0.025 29 | top_directions: int = 20 30 | reward_shift: float = 1.0 31 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE 32 | -------------------------------------------------------------------------------- /acme/agents/jax/ars/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ARS networks definition.""" 16 | 17 | from typing import Tuple 18 | 19 | from acme import specs 20 | from acme.jax import networks as networks_lib 21 | import jax.numpy as jnp 22 | 23 | 24 | BEHAVIOR_PARAMS_NAME = 'policy' 25 | EVAL_PARAMS_NAME = 'eval' 26 | 27 | 28 | def make_networks( 29 | spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: 30 | """Creates networks used by the agent. 31 | 32 | The model used by the ARS paper is a simple clipped linear model. 33 | 34 | Args: 35 | spec: an environment spec 36 | 37 | Returns: 38 | A FeedForwardNetwork network. 39 | """ 40 | 41 | obs_size = spec.observations.shape[0] 42 | act_size = spec.actions.shape[0] 43 | return networks_lib.FeedForwardNetwork( 44 | init=lambda _: jnp.zeros((obs_size, act_size)), 45 | apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1)) 46 | 47 | 48 | def make_policy_network( 49 | network: networks_lib.FeedForwardNetwork, 50 | eval_mode: bool = True) -> Tuple[str, networks_lib.FeedForwardNetwork]: 51 | params_name = EVAL_PARAMS_NAME if eval_mode else BEHAVIOR_PARAMS_NAME 52 | return (params_name, network) 53 | -------------------------------------------------------------------------------- /acme/agents/jax/bc/README.md: -------------------------------------------------------------------------------- 1 | # Behavioral Cloning (BC) 2 | 3 | This folder contains an implementation for supervised learning of a policy from 4 | a dataset of observations and target actions. This is an approach of Imitation 5 | Learning known as Behavioral Cloning, introduced by [Pomerleau, 1989]. 
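As a rough sketch of the underlying idea (illustrative only; `policy_apply` and the data arrays are placeholders rather than names from this agent), a mean-squared-error behavioural cloning loss looks like:

```python
import jax
import jax.numpy as jnp


def bc_mse_loss(params, policy_apply, observations, target_actions):
  """Mean squared error between predicted actions and demonstrator actions."""
  predicted_actions = policy_apply(params, observations)
  return jnp.mean(jnp.square(predicted_actions - target_actions))


# The learner minimises this loss w.r.t. the policy parameters, e.g. via:
bc_loss_grad = jax.grad(bc_mse_loss)
```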
6 | 7 | Several losses are implemented: 8 | 9 | * Mean squared error (mse) 10 | * Cross entropy (logp) 11 | * Peer Behavioral Cloning (peerbc), a regularization scheme from 12 | [Wang et al., 2021] 13 | * Reward-regularized Classification for Apprenticeship Learning (rcal), 14 | another regularization scheme from [Piot et al., 2014], defined for discrete 15 | action environments (or discretized action-spaces in case of continuous 16 | control). 17 | 18 | [Pomerleau, 1989]: https://papers.nips.cc/paper/95-alvinn-an-autonomous-land-vehicle-in-a-neural-network.pdf 19 | [Wang et al., 2021]: https://arxiv.org/pdf/2010.01748.pdf 20 | [Piot et al., 2014]: https://www.cristal.univ-lille.fr/~pietquin/pdf/AAMAS_2014_BPMGOP.pdf 21 | -------------------------------------------------------------------------------- /acme/agents/jax/bc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of a behavior cloning (BC) agent.""" 16 | 17 | from acme.agents.jax.bc import pretraining 18 | from acme.agents.jax.bc.builder import BCBuilder 19 | from acme.agents.jax.bc.config import BCConfig 20 | from acme.agents.jax.bc.learning import BCLearner 21 | from acme.agents.jax.bc.losses import BCLoss 22 | from acme.agents.jax.bc.losses import logp 23 | from acme.agents.jax.bc.losses import mse 24 | from acme.agents.jax.bc.losses import peerbc 25 | from acme.agents.jax.bc.losses import rcal 26 | from acme.agents.jax.bc.networks import BCNetworks 27 | from acme.agents.jax.bc.networks import BCPolicyNetwork 28 | from acme.agents.jax.bc.networks import convert_policy_value_to_bc_network 29 | from acme.agents.jax.bc.networks import convert_to_bc_network 30 | -------------------------------------------------------------------------------- /acme/agents/jax/bc/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config classes for BC.""" 16 | import dataclasses 17 | 18 | 19 | @dataclasses.dataclass 20 | class BCConfig: 21 | """Configuration options for BC. 22 | 23 | Attributes: 24 | learning_rate: Learning rate. 25 | num_sgd_steps_per_step: How many gradient updates to perform per step. 
26 | """ 27 | learning_rate: float = 1e-4 28 | num_sgd_steps_per_step: int = 1 29 | -------------------------------------------------------------------------------- /acme/agents/jax/bve/README.md: -------------------------------------------------------------------------------- 1 | # Behavior Value Estimation (BVE) 2 | 3 | This folder contains the implementation BVE algorithm [1]. BVE is an offline RL 4 | algorithm that estimates the behavior value of the policy in the offline 5 | dataset during the training. When deployed in an environment BVE does a single 6 | step of policy improvement. It is a value based method. The original paper also 7 | introduced regularizers to have conservative value estimates. 8 | 9 | For simplicity of implementation the `rlax` sarsa loss function is used in 10 | `loss.py`. The network in `networks.py` is the typical DQN architecture. 11 | 12 | [1] Gulcehre et al., Regularized Behavior Value Estimation, 2021, https://arxiv.org/abs/2103.09575. 13 | -------------------------------------------------------------------------------- /acme/agents/jax/bve/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of a behavior value estimation (BVE).""" 16 | 17 | from acme.agents.jax.bve.builder import BVEBuilder 18 | from acme.agents.jax.bve.config import BVEConfig 19 | from acme.agents.jax.bve.losses import BVELoss 20 | from acme.agents.jax.bve.networks import BVENetworks 21 | -------------------------------------------------------------------------------- /acme/agents/jax/bve/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Network definitions for BVE.""" 16 | 17 | import dataclasses 18 | from typing import Optional 19 | 20 | from acme.jax import networks as networks_lib 21 | 22 | 23 | @dataclasses.dataclass 24 | class BVENetworks: 25 | """The network and pure functions for the BVE agent. 26 | 27 | Attributes: 28 | policy_network: The policy network. 29 | sample_fn: A pure function. Samples an action based on the network output. 30 | log_prob: A pure function. Computes log-probability for an action. 
31 | """ 32 | policy_network: networks_lib.TypedFeedForwardNetwork 33 | sample_fn: networks_lib.SampleFn 34 | log_prob: Optional[networks_lib.LogProbFn] = None 35 | -------------------------------------------------------------------------------- /acme/agents/jax/cql/README.md: -------------------------------------------------------------------------------- 1 | # Conservative Q-Learning (CQL) 2 | 3 | CQL (1) is an offline RL algorithm. It is based on an offline version of SAC 4 | with an additional regularizing ("conservative") component in the critic loss. 5 | 6 | (1) [Kumar et al., *Conservative Q-Learning for Offline Reinforcement Learning*, 7 | 2020](https://arxiv.org/abs/2006.04779) 8 | -------------------------------------------------------------------------------- /acme/agents/jax/cql/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of the CQL agent.""" 16 | 17 | from acme.agents.jax.cql.builder import CQLBuilder 18 | from acme.agents.jax.cql.config import CQLConfig 19 | from acme.agents.jax.cql.learning import CQLLearner 20 | from acme.agents.jax.cql.networks import CQLNetworks 21 | from acme.agents.jax.cql.networks import make_networks 22 | -------------------------------------------------------------------------------- /acme/agents/jax/crr/README.md: -------------------------------------------------------------------------------- 1 | # Critic Regularized Regression (CRR) 2 | 3 | This folder contains an implementation of the CRR algorithm 4 | ([Wang et al., 2020]). It is an offline RL algorithm to learn policies from data 5 | using a form of critic-regularized regression. 6 | 7 | For the advantage estimate, a sampled mean is used. See the losses.py file for 8 | possible weighting coefficients for the policy loss (including exponential 9 | estimated advantage). The policy network assumes a continuous action space. 10 | 11 | [Wang et al., 2020]: https://arxiv.org/abs/2006.15134 12 | -------------------------------------------------------------------------------- /acme/agents/jax/crr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Implementation of the Critic Regularized Regression (CRR) agent.""" 16 | 17 | from acme.agents.jax.crr.builder import CRRBuilder 18 | from acme.agents.jax.crr.config import CRRConfig 19 | from acme.agents.jax.crr.learning import CRRLearner 20 | from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_exp 21 | from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_indicator 22 | from acme.agents.jax.crr.losses import policy_loss_coeff_constant 23 | from acme.agents.jax.crr.losses import PolicyLossCoeff 24 | from acme.agents.jax.crr.networks import CRRNetworks 25 | from acme.agents.jax.crr.networks import make_networks 26 | -------------------------------------------------------------------------------- /acme/agents/jax/crr/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config classes for CRR.""" 16 | import dataclasses 17 | 18 | 19 | @dataclasses.dataclass 20 | class CRRConfig: 21 | """Configuration options for CRR. 22 | 23 | Attributes: 24 | learning_rate: Learning rate. 25 | discount: discount to use for TD updates. 26 | target_update_period: period to update target's parameters. 27 | use_sarsa_target: compute on-policy target using iterator's actions rather 28 | than sampled actions. 29 | Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). 30 | """ 31 | learning_rate: float = 3e-4 32 | discount: float = 0.99 33 | target_update_period: int = 100 34 | use_sarsa_target: bool = False 35 | -------------------------------------------------------------------------------- /acme/agents/jax/d4pg/README.md: -------------------------------------------------------------------------------- 1 | # Distributed Distributional Deep Deterministic Policy Gradient (D4PG) 2 | 3 | This folder contains an implementation of the D4PG agent introduced in 4 | ([Barth-Maron et al., 2018]), which extends previous Deterministic Policy 5 | Gradient (DPG) algorithms ([Silver et al., 2014]; [Lillicrap et al., 2015]) by 6 | using a distributional Q-network similar to C51 ([Bellemare et al., 2017]). 7 | 8 | Note that since the synchronous agent is not distributed (i.e. not using 9 | multiple asynchronous actors), it is not precisely speaking D4PG; a more 10 | accurate name would be Distributional DDPG. In this algorithm, the critic 11 | outputs a distribution over state-action values; in this particular case this 12 | discrete distribution is parametrized as in C51. 13 | 14 | Detailed notes: 15 | 16 | - The `vmin|vmax` hyperparameters of the distributional critic may need tuning 17 | depending on your environment's rewards. A good rule of thumb is to set `vmax` 18 | to the discounted sum of the maximum instantaneous rewards for the maximum 19 | episode length; then set `vmin` to `-vmax`. 
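As a rough worked example of this rule of thumb (the numbers are purely illustrative and not tied to any particular environment):

```python
# Illustrative numbers only: rewards bounded by 1.0, 1000-step episodes,
# discount 0.99.
r_max = 1.0
episode_length = 1000
discount = 0.99

# Discounted sum of the maximum instantaneous reward over one episode.
vmax = r_max * (1.0 - discount ** episode_length) / (1.0 - discount)  # ~= 100.0
vmin = -vmax
```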
20 | 21 | [Barth-Maron et al., 2018]: https://arxiv.org/abs/1804.08617 22 | [Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887 23 | [Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971 24 | [Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14 25 | -------------------------------------------------------------------------------- /acme/agents/jax/d4pg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a D4PG agent.""" 16 | 17 | from acme.agents.jax.d4pg.builder import D4PGBuilder 18 | from acme.agents.jax.d4pg.config import D4PGConfig 19 | from acme.agents.jax.d4pg.learning import D4PGLearner 20 | from acme.agents.jax.d4pg.networks import D4PGNetworks 21 | from acme.agents.jax.d4pg.networks import get_default_behavior_policy 22 | from acme.agents.jax.d4pg.networks import get_default_eval_policy 23 | from acme.agents.jax.d4pg.networks import make_networks 24 | 25 | -------------------------------------------------------------------------------- /acme/agents/jax/d4pg/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config classes for D4PG.""" 16 | import dataclasses 17 | from typing import Optional 18 | from acme.adders import reverb as adders_reverb 19 | 20 | 21 | @dataclasses.dataclass 22 | class D4PGConfig: 23 | """Configuration options for D4PG.""" 24 | sigma: float = 0.3 25 | target_update_period: int = 100 26 | samples_per_insert: Optional[float] = 32.0 27 | 28 | # Loss options 29 | n_step: int = 5 30 | discount: float = 0.99 31 | batch_size: int = 256 32 | learning_rate: float = 1e-4 33 | clipping: bool = True 34 | 35 | # Replay options 36 | min_replay_size: int = 1000 37 | max_replay_size: int = 1000000 38 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE 39 | prefetch_size: int = 4 40 | # Rate to be used for the SampleToInsertRatio rate limitter tolerance. 41 | # See a formula in make_replay_tables for more details. 42 | samples_per_insert_tolerance_rate: float = 0.1 43 | 44 | # How many gradient updates to perform per step. 
45 | num_sgd_steps_per_step: int = 1 46 | -------------------------------------------------------------------------------- /acme/agents/jax/dqn/README.md: -------------------------------------------------------------------------------- 1 | # Deep Q-Networks (DQN) 2 | 3 | This folder contains an implementation of the DQN algorithm 4 | ([Mnih et al., 2013], [Mnih et al., 2015]), with extra bells & whistles, 5 | similar to Rainbow DQN ([Hessel et al., 2017]). 6 | 7 | * Q-learning with neural network function approximation. The loss is given by 8 | the Huber loss applied to the temporal difference error. 9 | * Target Q' network updated periodically ([Mnih et al., 2015]). 10 | * N-step bootstrapping ([Sutton & Barto, 2018]). 11 | * Double Q-learning ([van Hasselt et al., 2015]). 12 | * Prioritized experience replay ([Schaul et al., 2015]). 13 | 14 | This DQN implementation has a configurable loss. In losses.py, you can find 15 | ready-to-use implementations of other methods related to DQN. 16 | 17 | * Vanilla Deep Q-learning [Mnih et al., 2013], with two optimization tweaks 18 | (Adam instead of RMSProp, a squared loss instead of Huber, as suggested e.g. by 19 | [Obando-Ceron et al., 2020]). 20 | * Quantile regression DQN (QrDQN) [Dabney et al., 2017] 21 | * Categorical DQN (C51) [Bellemare et al., 2017] 22 | * Munchausen DQN [Vieillard et al., 2020] 23 | * Regularized DQN (DQNReg) [Co-Reyes et al., 2021] 24 | 25 | 26 | [Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 27 | [Mnih et al., 2015]: https://www.nature.com/articles/nature14236 28 | [van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461 29 | [Schaul et al., 2015]: https://arxiv.org/abs/1511.05952 30 | [Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887 31 | [Dabney et al., 2017]: https://arxiv.org/abs/1710.10044 32 | [Hessel et al., 2017]: https://arxiv.org/abs/1710.02298 33 | [Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 34 | [Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html 35 | [Obando-Ceron et al., 2020]: https://arxiv.org/abs/2011.14826 36 | [Vieillard et al., 2020]: https://arxiv.org/abs/2007.14430 37 | [Co-Reyes et al., 2021]: https://arxiv.org/abs/2101.03958 38 | -------------------------------------------------------------------------------- /acme/agents/jax/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Implementation of a deep Q-networks (DQN) agent.""" 16 | 17 | from acme.agents.jax.dqn.actor import behavior_policy 18 | from acme.agents.jax.dqn.actor import default_behavior_policy 19 | from acme.agents.jax.dqn.actor import DQNPolicy 20 | from acme.agents.jax.dqn.actor import Epsilon 21 | from acme.agents.jax.dqn.actor import EpsilonPolicy 22 | from acme.agents.jax.dqn.builder import DistributionalDQNBuilder 23 | from acme.agents.jax.dqn.builder import DQNBuilder 24 | from acme.agents.jax.dqn.config import DQNConfig 25 | from acme.agents.jax.dqn.learning import DQNLearner 26 | from acme.agents.jax.dqn.learning_lib import LossExtra 27 | from acme.agents.jax.dqn.learning_lib import LossFn 28 | from acme.agents.jax.dqn.learning_lib import ReverbUpdate 29 | from acme.agents.jax.dqn.learning_lib import SGDLearner 30 | from acme.agents.jax.dqn.losses import PrioritizedCategoricalDoubleQLearning 31 | from acme.agents.jax.dqn.losses import PrioritizedDoubleQLearning 32 | from acme.agents.jax.dqn.losses import QLearning 33 | from acme.agents.jax.dqn.losses import QrDqn 34 | from acme.agents.jax.dqn.networks import DQNNetworks 35 | -------------------------------------------------------------------------------- /acme/agents/jax/impala/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Importance-weighted actor-learner architecture (IMPALA) agent.""" 16 | 17 | from acme.agents.jax.impala.builder import IMPALABuilder 18 | from acme.agents.jax.impala.config import IMPALAConfig 19 | from acme.agents.jax.impala.learning import IMPALALearner 20 | from acme.agents.jax.impala.networks import IMPALANetworks 21 | from acme.agents.jax.impala.networks import make_atari_networks 22 | -------------------------------------------------------------------------------- /acme/agents/jax/impala/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """IMPALA networks definition.""" 16 | 17 | from acme import specs 18 | from acme.jax import networks as networks_lib 19 | 20 | 21 | IMPALANetworks = networks_lib.UnrollableNetwork 22 | 23 | 24 | def make_atari_networks(env_spec: specs.EnvironmentSpec) -> IMPALANetworks: 25 | """Builds default IMPALA networks for Atari games.""" 26 | 27 | def make_core_module() -> networks_lib.DeepIMPALAAtariNetwork: 28 | return networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) 29 | 30 | return networks_lib.make_unrollable_network(env_spec, make_core_module) 31 | -------------------------------------------------------------------------------- /acme/agents/jax/impala/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Some types/assumptions used in the IMPALA agent.""" 16 | from typing import Callable, Tuple 17 | 18 | from acme.agents.jax.actor_core import RecurrentState 19 | from acme.jax import networks 20 | from acme.jax import types as jax_types 21 | import jax.numpy as jnp 22 | 23 | # Only simple observations & discrete action spaces for now. 24 | Observation = jnp.ndarray 25 | Action = int 26 | Outputs = Tuple[Tuple[networks.Logits, networks.Value], RecurrentState] 27 | PolicyValueInitFn = Callable[[networks.PRNGKey, RecurrentState], 28 | networks.Params] 29 | PolicyValueFn = Callable[[networks.Params, Observation, RecurrentState], 30 | Outputs] 31 | RecurrentStateFn = Callable[[jax_types.PRNGKey], RecurrentState] 32 | -------------------------------------------------------------------------------- /acme/agents/jax/lfd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Lfd agents.""" 16 | 17 | from acme.agents.jax.lfd.builder import LfdBuilder 18 | from acme.agents.jax.lfd.builder import LfdStep 19 | from acme.agents.jax.lfd.config import LfdConfig 20 | from acme.agents.jax.lfd.sacfd import SACfDBuilder 21 | from acme.agents.jax.lfd.sacfd import SACfDConfig 22 | from acme.agents.jax.lfd.td3fd import TD3fDBuilder 23 | from acme.agents.jax.lfd.td3fd import TD3fDConfig 24 | -------------------------------------------------------------------------------- /acme/agents/jax/lfd/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """LfD config.""" 16 | 17 | import dataclasses 18 | 19 | 20 | @dataclasses.dataclass 21 | class LfdConfig: 22 | """Configuration options for LfD. 23 | 24 | Attributes: 25 | initial_insert_count: Number of steps of demonstrations to add to the replay 26 | buffer before adding any step of the collected episodes. Note that since 27 | only full episodes can be added, this number of steps is only a target. 28 | demonstration_ratio: Ratio of demonstration steps to add to the replay 29 | buffer. ratio = num_demonstration_steps_added / total_num_steps_added. 30 | The ratio must be in [0, 1). 31 | Note that this ratio is the desired ratio in the steady behavior and does 32 | not account for the initial demonstrations inserts. 33 | Note also that this ratio is only a target ratio since the granularity 34 | is the episode. 35 | """ 36 | initial_insert_count: int = 0 37 | demonstration_ratio: float = 0.01 38 | -------------------------------------------------------------------------------- /acme/agents/jax/lfd/sacfd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SAC agent learning from demonstrations.""" 16 | 17 | import dataclasses 18 | from typing import Callable, Iterator 19 | 20 | from acme.agents.jax import actor_core as actor_core_lib 21 | from acme.agents.jax import sac 22 | from acme.agents.jax.lfd import builder 23 | from acme.agents.jax.lfd import config 24 | import reverb 25 | 26 | 27 | @dataclasses.dataclass 28 | class SACfDConfig: 29 | """Configuration options specific to SAC with demonstrations. 
30 | 31 | Attributes: 32 | lfd_config: LfD config. 33 | sac_config: SAC config. 34 | """ 35 | lfd_config: config.LfdConfig 36 | sac_config: sac.SACConfig 37 | 38 | 39 | class SACfDBuilder(builder.LfdBuilder[sac.SACNetworks, 40 | actor_core_lib.FeedForwardPolicy, 41 | reverb.ReplaySample]): 42 | """Builder for SAC agent learning from demonstrations.""" 43 | 44 | def __init__(self, sac_fd_config: SACfDConfig, 45 | lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): 46 | sac_builder = sac.SACBuilder(sac_fd_config.sac_config) 47 | super().__init__(sac_builder, lfd_iterator_fn, sac_fd_config.lfd_config) 48 | -------------------------------------------------------------------------------- /acme/agents/jax/lfd/td3fd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """TD3 agent learning from demonstrations.""" 16 | 17 | import dataclasses 18 | from typing import Callable, Iterator 19 | 20 | from acme.agents.jax import actor_core as actor_core_lib 21 | from acme.agents.jax import td3 22 | from acme.agents.jax.lfd import builder 23 | from acme.agents.jax.lfd import config 24 | import reverb 25 | 26 | 27 | @dataclasses.dataclass 28 | class TD3fDConfig: 29 | """Configuration options specific to TD3 with demonstrations. 30 | 31 | Attributes: 32 | lfd_config: LfD config. 33 | td3_config: TD3 config. 34 | """ 35 | lfd_config: config.LfdConfig 36 | td3_config: td3.TD3Config 37 | 38 | 39 | class TD3fDBuilder(builder.LfdBuilder[td3.TD3Networks, 40 | actor_core_lib.FeedForwardPolicy, 41 | reverb.ReplaySample]): 42 | """Builder for TD3 agent learning from demonstrations.""" 43 | 44 | def __init__(self, td3_fd_config: TD3fDConfig, 45 | lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): 46 | td3_builder = td3.TD3Builder(td3_fd_config.td3_config) 47 | super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config) 48 | -------------------------------------------------------------------------------- /acme/agents/jax/mbop/README.md: -------------------------------------------------------------------------------- 1 | # Model-Based Offline Planning (MBOP) 2 | 3 | This folder contains an implementation of the MBOP algorithm ([Argenson and 4 | Dulac-Arnold, 2021]). It is an offline RL algorithm that generates a model that 5 | can be used to control the system directly through planning. The learning 6 | components, i.e. the world model, policy prior and the n-step return, are simple 7 | supervised ensemble learners. It uses the Model-Predictive Path Integral control 8 | planner. 9 | 10 | The networks assume continuous and flattened observation and action spaces. The 11 | dataset, i.e. demonstrations, should be in timestep-batched format (i.e. triple 12 | transitions of the previous, current and next timesteps) and normalized. 
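A hedged wiring sketch for the SACfD builder defined above. The demonstration iterator is a stub (a real one must yield `LfdStep` items as defined in builder.py), and default-constructing `sac.SACConfig()` is an assumption.

```python
from acme.agents.jax import lfd
from acme.agents.jax import sac

sac_fd_config = lfd.SACfDConfig(
    lfd_config=lfd.LfdConfig(demonstration_ratio=0.1),
    sac_config=sac.SACConfig(),  # assumed to be default-constructible
)


def make_demonstrations():
  # Placeholder: should return an iterator of LfdStep demonstration steps.
  return iter(())


builder = lfd.SACfDBuilder(sac_fd_config, lfd_iterator_fn=make_demonstrations)
```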
See 13 | dataset.py file for helper functions for loading RLDS datasets and 14 | normalization. 15 | 16 | [Argenson and Dulac-Arnold, 2021]: https://arxiv.org/abs/2008.05556 17 | -------------------------------------------------------------------------------- /acme/agents/jax/mbop/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MBOP config.""" 16 | 17 | import dataclasses 18 | 19 | from acme.agents.jax.mbop import mppi 20 | 21 | 22 | @dataclasses.dataclass(frozen=True) 23 | class MBOPConfig: 24 | """Configuration options for the MBOP agent. 25 | 26 | Attributes: 27 | mppi_config: Planner hyperparameters. 28 | learning_rate: Learning rate. 29 | num_networks: Number of networks in the ensembles. 30 | num_sgd_steps_per_step: How many gradient updates to perform per learner 31 | step. 32 | """ 33 | mppi_config: mppi.MPPIConfig = dataclasses.field( 34 | default_factory=mppi.MPPIConfig 35 | ) 36 | learning_rate: float = 3e-4 37 | num_networks: int = 5 38 | num_sgd_steps_per_step: int = 1 39 | -------------------------------------------------------------------------------- /acme/agents/jax/mpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MPO agent module.""" 16 | 17 | from acme.agents.jax.mpo.acting import ActorState 18 | from acme.agents.jax.mpo.acting import make_actor_core 19 | from acme.agents.jax.mpo.builder import MPOBuilder 20 | from acme.agents.jax.mpo.config import MPOConfig 21 | from acme.agents.jax.mpo.learning import MPOLearner 22 | from acme.agents.jax.mpo.networks import make_control_networks 23 | from acme.agents.jax.mpo.networks import MPONetworks 24 | from acme.agents.jax.mpo.types import CategoricalPolicyLossConfig 25 | from acme.agents.jax.mpo.types import CriticType 26 | from acme.agents.jax.mpo.types import GaussianPolicyLossConfig 27 | from acme.agents.jax.mpo.types import PolicyLossConfig 28 | -------------------------------------------------------------------------------- /acme/agents/jax/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 
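A short sketch of constructing the frozen MBOP config above with a larger ensemble; planner options are left at the `MPPIConfig` defaults (its fields live in mppi.py), and the values are illustrative.

```python
from acme.agents.jax.mbop import config as mbop_config
from acme.agents.jax.mbop import mppi

config = mbop_config.MBOPConfig(
    mppi_config=mppi.MPPIConfig(),  # default planner hyperparameters
    learning_rate=1e-4,
    num_networks=10,                # larger ensemble than the default of 5
)
```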
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Multiagent implementations.""" 16 | -------------------------------------------------------------------------------- /acme/agents/jax/multiagent/decentralized/README.md: -------------------------------------------------------------------------------- 1 | # Decentralized Multiagent Learning 2 | 3 | This folder contains an implementation of decentralized multiagent learning. 4 | The current implementation supports homogeneous sub-agents (i.e., all agents 5 | running identical sub-algorithms). 6 | 7 | The underlying multiagent environment should produce observations and rewards 8 | that are each a dict, with keys corresponding to string IDs for the agents that 9 | map to their respective local observation and rewards. Rewards can be 10 | heterogeneous (e.g., for non-cooperative environments). 11 | 12 | The environment step() should consume dict-style actions, with key:value pairs 13 | corresponding to agent:action, as above. 14 | 15 | Discounts are assumed shared between agents (i.e., should be a single scalar). 16 | -------------------------------------------------------------------------------- /acme/agents/jax/multiagent/decentralized/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Decentralized multiagent configuration.""" 16 | 17 | from acme.agents.jax.multiagent.decentralized.builder import DecentralizedMultiAgentBuilder 18 | from acme.agents.jax.multiagent.decentralized.config import DecentralizedMultiagentConfig 19 | from acme.agents.jax.multiagent.decentralized.factories import builder_factory 20 | from acme.agents.jax.multiagent.decentralized.factories import default_config_factory 21 | from acme.agents.jax.multiagent.decentralized.factories import DefaultSupportedAgent 22 | from acme.agents.jax.multiagent.decentralized.factories import network_factory 23 | from acme.agents.jax.multiagent.decentralized.factories import policy_network_factory 24 | -------------------------------------------------------------------------------- /acme/agents/jax/multiagent/decentralized/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
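To illustrate the dict-keyed interface described in the decentralized multiagent README above, here is a toy example; shapes, agent IDs and values are placeholders, not requirements.

```python
import numpy as np

# Observations and rewards are dicts keyed by string agent IDs.
observation = {'0': np.zeros(4, dtype=np.float32),
               '1': np.zeros(4, dtype=np.float32)}
reward = {'0': 1.0, '1': -0.5}   # rewards may differ per agent
discount = 0.99                  # a single scalar shared by all agents

# Actions passed to environment.step() use the same keying.
action = {'0': np.array([0.1], dtype=np.float32),
          '1': np.array([-0.3], dtype=np.float32)}
```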
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Decentralized multiagent config.""" 16 | 17 | import dataclasses 18 | from typing import Dict 19 | 20 | from acme.multiagent import types 21 | 22 | 23 | @dataclasses.dataclass 24 | class DecentralizedMultiagentConfig: 25 | """Configuration options for decentralized multiagent.""" 26 | sub_agent_configs: Dict[types.AgentID, types.AgentConfig] 27 | batch_size: int = 256 28 | prefetch_size: int = 2 29 | -------------------------------------------------------------------------------- /acme/agents/jax/ppo/README.md: -------------------------------------------------------------------------------- 1 | # Proximal Policy Optimization (PPO) 2 | 3 | This folder contains an implementation of the PPO algorithm 4 | ([Schulman et al., 2017]) with clipped surrogate objective. 5 | 6 | Implementation notes: 7 | - PPO is not a strictly on-policy algorithm. In each call to the learner's 8 | step function, a batch of transitions are taken from the Reverb replay 9 | buffer, and N epochs of updates are performed on the data in the batch. 10 | Using larger values for num_epochs and num_minibatches makes the algorithm 11 | "more off-policy". 12 | 13 | [Schulman et al., 2017]: https://arxiv.org/abs/1707.06347 14 | -------------------------------------------------------------------------------- /acme/agents/jax/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
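Following the PPO implementation note above, a hedged sketch of a configuration that leans "more off-policy"; the field names mirror the README wording and should be checked against the agent's config.py.

```python
from acme.agents.jax import ppo

# More epochs and minibatches per learner step mean each sampled batch is
# reused more, hence the "more off-policy" behaviour described above.
config = ppo.PPOConfig(
    num_epochs=10,
    num_minibatches=32,
)
```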
14 | 15 | """PPO agent.""" 16 | 17 | from acme.agents.jax.ppo.builder import PPOBuilder 18 | from acme.agents.jax.ppo.config import PPOConfig 19 | from acme.agents.jax.ppo.learning import PPOLearner 20 | from acme.agents.jax.ppo.networks import EntropyFn 21 | from acme.agents.jax.ppo.networks import make_categorical_ppo_networks 22 | from acme.agents.jax.ppo.networks import make_continuous_networks 23 | from acme.agents.jax.ppo.networks import make_discrete_networks 24 | from acme.agents.jax.ppo.networks import make_inference_fn 25 | from acme.agents.jax.ppo.networks import make_mvn_diag_ppo_networks 26 | from acme.agents.jax.ppo.networks import make_networks 27 | from acme.agents.jax.ppo.networks import make_ppo_networks 28 | from acme.agents.jax.ppo.networks import make_tanh_normal_ppo_networks 29 | from acme.agents.jax.ppo.networks import PPONetworks 30 | from acme.agents.jax.ppo.normalization import build_ema_mean_std_normalizer 31 | from acme.agents.jax.ppo.normalization import build_mean_std_normalizer 32 | from acme.agents.jax.ppo.normalization import NormalizationFns 33 | from acme.agents.jax.ppo.normalization import NormalizedGenericActor 34 | -------------------------------------------------------------------------------- /acme/agents/jax/pwil/README.md: -------------------------------------------------------------------------------- 1 | # PWIL 2 | 3 | This folder contains an implementation of the PWIL algorithm 4 | ([R.Dadashi et al., 2020]). 5 | 6 | The description of PWIL in ([R.Dadashi et al., 2020]) leaves the behavior 7 | unspecified when the episode lengths are not fixed in advance. Here, we assign 8 | zero reward when a trajectory exceeds the desired length, and keep the partial 9 | return unaffected when a trajectory is shorter than the desired length. 10 | 11 | We prefill the replay buffer in a concurrent thread of the learner, to avoid 12 | potential Reverb deadlocks. 13 | 14 | [R.Dadashi et al., 2020]: https://arxiv.org/abs/2006.04678 15 | -------------------------------------------------------------------------------- /acme/agents/jax/pwil/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """PWIL agent.""" 16 | 17 | from acme.agents.jax.pwil.builder import PWILBuilder 18 | from acme.agents.jax.pwil.config import PWILConfig 19 | from acme.agents.jax.pwil.config import PWILDemonstrations 20 | -------------------------------------------------------------------------------- /acme/agents/jax/pwil/adder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Reward-substituting adder wrapper.""" 16 | 17 | from acme import adders 18 | from acme import types 19 | from acme.agents.jax.pwil import rewarder 20 | import dm_env 21 | 22 | 23 | class PWILAdder(adders.Adder): 24 | """Adder wrapper substituting PWIL rewards.""" 25 | 26 | def __init__(self, direct_rl_adder: adders.Adder, 27 | pwil_rewarder: rewarder.WassersteinDistanceRewarder): 28 | self._adder = direct_rl_adder 29 | self._rewarder = pwil_rewarder 30 | self._latest_observation = None 31 | 32 | def add_first(self, timestep: dm_env.TimeStep): 33 | self._rewarder.reset() 34 | self._latest_observation = timestep.observation 35 | self._adder.add_first(timestep) 36 | 37 | def add(self, 38 | action: types.NestedArray, 39 | next_timestep: dm_env.TimeStep, 40 | extras: types.NestedArray = ()): 41 | updated_timestep = next_timestep._replace( 42 | reward=self._rewarder.append_and_compute_reward( 43 | observation=self._latest_observation, action=action)) 44 | self._latest_observation = next_timestep.observation 45 | self._adder.add(action, updated_timestep, extras) 46 | 47 | def reset(self): 48 | self._latest_observation = None 49 | self._adder.reset() 50 | -------------------------------------------------------------------------------- /acme/agents/jax/pwil/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """PWIL config.""" 16 | import dataclasses 17 | from typing import Iterator 18 | 19 | from acme import types 20 | 21 | 22 | @dataclasses.dataclass 23 | class PWILConfig: 24 | """Configuration options for PWIL. 25 | 26 | The default values correspond to the experiment setup from the PWIL 27 | publication http://arxiv.org/abs/2006.04678. 28 | """ 29 | 30 | # Number of transitions to fill the replay buffer with for pretraining. 31 | num_transitions_rb: int = 50000 32 | 33 | # If False, uses only observations for computing the distance; if True, also 34 | # uses the actions. 35 | use_actions_for_distance: bool = True 36 | 37 | # Scaling for the reward function, see equation (6) in 38 | # http://arxiv.org/abs/2006.04678. 39 | alpha: float = 5. 40 | 41 | # Controls the kernel size of the reward function, see equation (6) 42 | # in http://arxiv.org/abs/2006.04678. 43 | beta: float = 5. 44 | 45 | # When False, uses the reward signal from the dataset during prefilling. 
46 | prefill_constant_reward: bool = True 47 | 48 | num_sgd_steps_per_step: int = 1 49 | 50 | 51 | @dataclasses.dataclass 52 | class PWILDemonstrations: 53 | """Unbatched, unshuffled transitions with approximate episode length.""" 54 | demonstrations: Iterator[types.Transition] 55 | episode_length: int 56 | -------------------------------------------------------------------------------- /acme/agents/jax/r2d2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of an R2D2 agent.""" 16 | 17 | from acme.agents.jax.r2d2.actor import EpsilonRecurrentPolicy 18 | from acme.agents.jax.r2d2.actor import make_behavior_policy 19 | from acme.agents.jax.r2d2.builder import R2D2Builder 20 | from acme.agents.jax.r2d2.config import R2D2Config 21 | from acme.agents.jax.r2d2.learning import R2D2Learner 22 | from acme.agents.jax.r2d2.learning import R2D2ReplaySample 23 | from acme.agents.jax.r2d2.networks import make_atari_networks 24 | from acme.agents.jax.r2d2.networks import R2D2Networks 25 | -------------------------------------------------------------------------------- /acme/agents/jax/r2d2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """R2D2 config.""" 16 | import dataclasses 17 | 18 | from acme.adders import reverb as adders_reverb 19 | import rlax 20 | 21 | 22 | @dataclasses.dataclass 23 | class R2D2Config: 24 | """Configuration options for R2D2 agent.""" 25 | discount: float = 0.997 26 | target_update_period: int = 2500 27 | evaluation_epsilon: float = 0.
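A small sketch of filling in the PWIL dataclasses above; the single synthetic transition only illustrates the expected `types.Transition` format and stands in for a real demonstration dataset.

```python
from acme import types
from acme.agents.jax import pwil
import numpy as np

config = pwil.PWILConfig(use_actions_for_distance=False)

# One synthetic transition standing in for real demonstrations.
transition = types.Transition(
    observation=np.zeros(3, dtype=np.float32),
    action=np.zeros(1, dtype=np.float32),
    reward=np.float32(0.0),
    discount=np.float32(1.0),
    next_observation=np.zeros(3, dtype=np.float32),
)
demonstrations = pwil.PWILDemonstrations(
    demonstrations=iter([transition]), episode_length=1000)
```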
28 | num_epsilons: int = 256 29 | variable_update_period: int = 400 30 | 31 | # Learner options 32 | burn_in_length: int = 40 33 | trace_length: int = 80 34 | sequence_period: int = 40 35 | learning_rate: float = 1e-3 36 | bootstrap_n: int = 5 37 | clip_rewards: bool = False 38 | tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR 39 | 40 | # Replay options 41 | samples_per_insert_tolerance_rate: float = 0.1 42 | samples_per_insert: float = 4.0 43 | min_replay_size: int = 50_000 44 | max_replay_size: int = 100_000 45 | batch_size: int = 64 46 | prefetch_size: int = 2 47 | num_parallel_calls: int = 16 48 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE 49 | 50 | # Priority options 51 | importance_sampling_exponent: float = 0.6 52 | priority_exponent: float = 0.9 53 | max_priority_weight: float = 0.9 54 | -------------------------------------------------------------------------------- /acme/agents/jax/r2d2/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """R2D2 Networks.""" 16 | 17 | from acme import specs 18 | from acme.jax import networks as networks_lib 19 | 20 | 21 | R2D2Networks = networks_lib.UnrollableNetwork 22 | 23 | 24 | def make_atari_networks(env_spec: specs.EnvironmentSpec) -> R2D2Networks: 25 | """Builds default R2D2 networks for Atari games.""" 26 | 27 | def make_core_module() -> networks_lib.R2D2AtariNetwork: 28 | return networks_lib.R2D2AtariNetwork(env_spec.actions.num_values) 29 | 30 | return networks_lib.make_unrollable_network(env_spec, make_core_module) 31 | -------------------------------------------------------------------------------- /acme/agents/jax/rnd/README.md: -------------------------------------------------------------------------------- 1 | # Random Network Distillation (RND) 2 | 3 | This folder contains an implementation of the RND algorithm 4 | ([Burda et al., 2018]) 5 | 6 | RND requires a RL algorithm to work, passed in as an `ActorLearnerBuilder`. 7 | 8 | By default this implementation ignores the original reward: the agent is trained 9 | only on the intrinsic exploration reward. To also use extrinsic reward, 10 | intrinsic and extrinsic reward weights can be passed into make_networks. 11 | 12 | [Burda et al., 2018]: https://arxiv.org/abs/1810.12894 13 | -------------------------------------------------------------------------------- /acme/agents/jax/rnd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
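An illustrative override of the sequence-related R2D2 options above; the values are arbitrary and every other field keeps its default from config.py.

```python
from acme.agents.jax import r2d2

config = r2d2.R2D2Config(
    burn_in_length=20,   # steps used only to warm up the recurrent state
    trace_length=40,     # steps that contribute to the loss
    sequence_period=20,  # stride between stored sequences
    bootstrap_n=5,
)
```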
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RND agent.""" 16 | 17 | from acme.agents.jax.rnd.builder import RNDBuilder 18 | from acme.agents.jax.rnd.config import RNDConfig 19 | from acme.agents.jax.rnd.learning import rnd_loss 20 | from acme.agents.jax.rnd.learning import rnd_update_step 21 | from acme.agents.jax.rnd.learning import RNDLearner 22 | from acme.agents.jax.rnd.learning import RNDTrainingState 23 | from acme.agents.jax.rnd.networks import compute_rnd_reward 24 | from acme.agents.jax.rnd.networks import make_networks 25 | from acme.agents.jax.rnd.networks import rnd_reward_fn 26 | from acme.agents.jax.rnd.networks import RNDNetworks 27 | -------------------------------------------------------------------------------- /acme/agents/jax/rnd/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RND config.""" 16 | import dataclasses 17 | 18 | 19 | @dataclasses.dataclass 20 | class RNDConfig: 21 | """Configuration options for RND.""" 22 | 23 | # Learning rate for the predictor. 24 | predictor_learning_rate: float = 1e-4 25 | 26 | # If True, the direct rl algorithm is using the SequenceAdder data format. 27 | is_sequence_based: bool = False 28 | 29 | # How many gradient updates to perform per step. 30 | num_sgd_steps_per_step: int = 1 31 | -------------------------------------------------------------------------------- /acme/agents/jax/sac/README.md: -------------------------------------------------------------------------------- 1 | # Soft Actor-Critic (SAC) 2 | 3 | This folder contains an implementation of the SAC algorithm 4 | ([Haarnoja et al., 2018]) with automatic tuning of the temperature 5 | ([Haarnoja et al., 2019]). 6 | 7 | This is an actor-critic method with: 8 | 9 | - a stochastic policy optimization (as opposed to, e.g., DPG) with a maximum entropy regularization; and 10 | - two critics to mitigate the over-estimation bias in policy evaluation ([Fujimoto et al., 2018]). 11 | 12 | For the maximum entropy regularization, we provide a commonly used heuristic for specifying entropy target (`target_entropy_from_env_spec`). 13 | The heuristic returns `-num_actions` by default or `num_actions * target_entropy_per_dimension` 14 | if `target_entropy_per_dimension` is specified. 
15 | 16 | 17 | [Haarnoja et al., 2018]: https://arxiv.org/abs/1801.01290 18 | [Haarnoja et al., 2019]: https://arxiv.org/abs/1812.05905 19 | [Fujimoto et al., 2018]: https://arxiv.org/abs/1802.09477 20 | -------------------------------------------------------------------------------- /acme/agents/jax/sac/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SAC agent.""" 16 | 17 | from acme.agents.jax.sac.builder import SACBuilder 18 | from acme.agents.jax.sac.config import SACConfig 19 | from acme.agents.jax.sac.config import target_entropy_from_env_spec 20 | from acme.agents.jax.sac.learning import SACLearner 21 | from acme.agents.jax.sac.networks import apply_policy_and_sample 22 | from acme.agents.jax.sac.networks import default_models_to_snapshot 23 | from acme.agents.jax.sac.networks import make_networks 24 | from acme.agents.jax.sac.networks import SACNetworks 25 | -------------------------------------------------------------------------------- /acme/agents/jax/sqil/README.md: -------------------------------------------------------------------------------- 1 | # Soft Q imitation learning (SQIL) 2 | 3 | This folder contains an implementation of the SQIL algorithm 4 | ([Reddy et al., 2019]) 5 | 6 | SQIL requires an off-policy RL algorithm to work, passed in as an 7 | `ActorLearnerBuilder`. 8 | 9 | [Reddy et al., 2019]: https://arxiv.org/abs/1905.11108 10 | -------------------------------------------------------------------------------- /acme/agents/jax/sqil/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SQIL agent.""" 16 | 17 | from acme.agents.jax.sqil.builder import SQILBuilder 18 | -------------------------------------------------------------------------------- /acme/agents/jax/sqil/builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
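The SAC README's target-entropy heuristic above is simple enough to restate in a few lines. This is a standalone sketch of that rule of thumb, not the actual `target_entropy_from_env_spec` helper, whose real signature takes an environment spec.

```python
from typing import Optional


def target_entropy(num_actions: int,
                   target_entropy_per_dimension: Optional[float] = None) -> float:
  """Heuristic from the SAC README: -num_actions unless a per-dim value is given."""
  if target_entropy_per_dimension is None:
    return -float(num_actions)
  return num_actions * target_entropy_per_dimension


assert target_entropy(6) == -6.0
assert target_entropy(6, target_entropy_per_dimension=-0.5) == -3.0
```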
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for the SQIL iterator.""" 16 | 17 | from acme import types 18 | from acme.agents.jax.sqil import builder 19 | import numpy as np 20 | import reverb 21 | 22 | from absl.testing import absltest 23 | 24 | 25 | class BuilderTest(absltest.TestCase): 26 | 27 | def test_sqil_iterator(self): 28 | demonstrations = [ 29 | types.Transition(np.array([[1], [2], [3]]), (), (), (), ()) 30 | ] 31 | replay = [ 32 | reverb.ReplaySample( 33 | info=(), 34 | data=types.Transition(np.array([[4], [5], [6]]), (), (), (), ())) 35 | ] 36 | sqil_it = builder._generate_sqil_samples(iter(demonstrations), iter(replay)) 37 | np.testing.assert_array_equal( 38 | next(sqil_it).data.observation, np.array([[1], [3], [5]])) 39 | np.testing.assert_array_equal( 40 | next(sqil_it).data.observation, np.array([[2], [4], [6]])) 41 | self.assertRaises(StopIteration, lambda: next(sqil_it)) 42 | 43 | if __name__ == '__main__': 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /acme/agents/jax/td3/README.md: -------------------------------------------------------------------------------- 1 | # Twin Delayed Deep Deterministic policy gradient algorithm (TD3) 2 | 3 | This folder contains an implementation of the TD3 algorithm, 4 | [Fujimoto, 2018]. 5 | 6 | 7 | Note the following differences with the original author's implementation: 8 | 9 | * the default network architecture is a LayerNorm MLP, 10 | * there is no initial exploration phase with a random policy, 11 | * the target critic and twin critic updates are not delayed. 12 | 13 | [Fujimoto, 2018]: https://arxiv.org/pdf/1802.09477.pdf 14 | -------------------------------------------------------------------------------- /acme/agents/jax/td3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """TD3 agent.""" 16 | 17 | from acme.agents.jax.td3.builder import TD3Builder 18 | from acme.agents.jax.td3.config import TD3Config 19 | from acme.agents.jax.td3.learning import TD3Learner 20 | from acme.agents.jax.td3.networks import get_default_behavior_policy 21 | from acme.agents.jax.td3.networks import make_networks 22 | from acme.agents.jax.td3.networks import TD3Networks 23 | -------------------------------------------------------------------------------- /acme/agents/jax/value_dice/README.md: -------------------------------------------------------------------------------- 1 | # Value Dice 2 | 3 | This folder contains an implementation of the ValueDice algorithm 4 | ([Kostrikov et al., 2019]). 5 | 6 | The implementation supports both: 7 | - offline training (demonstrations only) 8 | - mixed mode 9 | 10 | Offline training is achieved by setting 'nu_reg_scale' and 'alpha' to 0. 11 | 12 | [Kostrikov et al., 2019]: https://arxiv.org/abs/1912.05032 13 | -------------------------------------------------------------------------------- /acme/agents/jax/value_dice/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ValueDice agent.""" 16 | 17 | from acme.agents.jax.value_dice.builder import ValueDiceBuilder 18 | from acme.agents.jax.value_dice.config import ValueDiceConfig 19 | from acme.agents.jax.value_dice.learning import ValueDiceLearner 20 | from acme.agents.jax.value_dice.networks import apply_policy_and_sample 21 | from acme.agents.jax.value_dice.networks import make_networks 22 | from acme.agents.jax.value_dice.networks import ValueDiceNetworks 23 | -------------------------------------------------------------------------------- /acme/agents/jax/value_dice/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ValueDice config.""" 16 | 17 | import dataclasses 18 | 19 | from acme.adders import reverb as adders_reverb 20 | 21 | 22 | @dataclasses.dataclass 23 | class ValueDiceConfig: 24 | """Configuration options for ValueDice.""" 25 | 26 | policy_learning_rate: float = 1e-5 27 | nu_learning_rate: float = 1e-3 28 | discount: float = .99 29 | batch_size: int = 256 30 | alpha: float = 0.05 31 | policy_reg_scale: float = 1e-4 32 | nu_reg_scale: float = 10.0 33 | 34 | # Replay options 35 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE 36 | samples_per_insert: float = 256 * 4 37 | # Rate to be used for the SampleToInsertRatio rate limitter tolerance. 38 | # See a formula in make_replay_tables for more details. 39 | samples_per_insert_tolerance_rate: float = 0.1 40 | min_replay_size: int = 1000 41 | max_replay_size: int = 1000000 42 | prefetch_size: int = 4 43 | 44 | # How many gradient updates to perform per step. 45 | num_sgd_steps_per_step: int = 1 46 | -------------------------------------------------------------------------------- /acme/agents/jax/wpo/README.md: -------------------------------------------------------------------------------- 1 | # Wasserstein Policy Optimization (WPO) 2 | 3 | This folder contains an implementation of Wasserstein Policy 4 | Optimization (WPO) introduced in ([Pfau et al., 2025]). This implementation is 5 | forked from the implementation of Maximum a Posteriori Policy Optimization (MPO) 6 | ([Abdolmaleki et al., 2018a], [2018b]) written by Bobak Shahriari, with options 7 | such as distributional critics removed. 8 | 9 | [Pfau et al., 2025]: https://arxiv.org/pdf/2505.00663.pdf 10 | [Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf 11 | [2018b]: https://arxiv.org/pdf/1812.02256.pdf 12 | 13 | -------------------------------------------------------------------------------- /acme/agents/jax/wpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """WPO agent module.""" 16 | 17 | from acme.agents.jax.wpo.acting import ActorState 18 | from acme.agents.jax.wpo.acting import make_actor_core 19 | from acme.agents.jax.wpo.builder import WPOBuilder 20 | from acme.agents.jax.wpo.config import WPOConfig 21 | from acme.agents.jax.wpo.learning import WPOLearner 22 | from acme.agents.jax.wpo.networks import make_control_networks 23 | from acme.agents.jax.wpo.networks import WPONetworks 24 | from acme.agents.jax.wpo.types import GaussianPolicyLossConfig 25 | from acme.agents.jax.wpo.types import PolicyLossConfig 26 | -------------------------------------------------------------------------------- /acme/agents/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /acme/agents/tf/bc/README.md: -------------------------------------------------------------------------------- 1 | # Behavioral Cloning (BC) 2 | 3 | This folder contains an implementation for supervised learning of a policy from 4 | a dataset of observations and target actions. This is an approach known as 5 | Behavioral Cloning, introduced by [Pomerleau, 1989]. There is an example which 6 | generates data for bsuite environment `Deep Sea` using an optimal policy. 7 | 8 | [Pomerleau, 1989]: https://papers.nips.cc/paper/95-alvinn-an-autonomous-land-vehicle-in-a-neural-network.pdf 9 | -------------------------------------------------------------------------------- /acme/agents/tf/bc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of a behavior cloning (BC) agent.""" 16 | 17 | from acme.agents.tf.bc.learning import BCLearner 18 | -------------------------------------------------------------------------------- /acme/agents/tf/bcq/README.md: -------------------------------------------------------------------------------- 1 | # Discrete Batch-Constrained Deep Q-learning (BCQ) 2 | 3 | This folder contains an implementation of the discrete BCQ algorithm introduced 4 | in ([Fujimoto et al., 2019]), which is a variant of the BCQ algorithm 5 | ([Fujimoto et al., 2018]). 6 | 7 | [Fujimoto et al., 2018]: https://arxiv.org/pdf/1812.02900.pdf 8 | [Fujimoto et al., 2019]: https://arxiv.org/pdf/1910.01708.pdf 9 | -------------------------------------------------------------------------------- /acme/agents/tf/bcq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Batch-Constrained Deep Q-learning (BCQ).""" 16 | 17 | from acme.agents.tf.bcq.discrete_learning import DiscreteBCQLearner 18 | -------------------------------------------------------------------------------- /acme/agents/tf/crr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a CRR agent.""" 16 | 17 | from acme.agents.tf.crr.recurrent_learning import RCRRLearner 18 | -------------------------------------------------------------------------------- /acme/agents/tf/d4pg/README.md: -------------------------------------------------------------------------------- 1 | # Distributed Distributional Deep Deterministic Policy Gradient (D4PG) 2 | 3 | This folder contains an implementation of the D4PG agent introduced in 4 | ([Barth-Maron et al., 2018]), which extends previous Deterministic Policy 5 | Gradient (DPG) algorithms ([Silver et al., 2014]; [Lillicrap et al., 2015]) by 6 | using a distributional Q-network similar to C51 ([Bellemare et al., 2017]). 7 | 8 | Note that since the synchronous agent is not distributed (i.e. not using 9 | multiple asynchronous actors), it is not precisely speaking D4PG; a more 10 | accurate name would be Distributional DDPG. In this algorithm, the critic 11 | outputs a distribution over state-action values; in this particular case this 12 | discrete distribution is parametrized as in C51. 13 | 14 | Detailed notes: 15 | 16 | - The `vmin|vmax` hyperparameters of the distributional critic may need tuning 17 | depending on your environment's rewards. A good rule of thumb is to set 18 | `vmax` to the discounted sum of the maximum instantaneous rewards for the 19 | maximum episode length; then set `vmin` to `-vmax`. 20 | 21 | [Barth-Maron et al., 2018]: https://arxiv.org/abs/1804.08617 22 | [Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887 23 | [Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971 24 | [Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14 25 | -------------------------------------------------------------------------------- /acme/agents/tf/d4pg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
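A quick worked example of the `vmin|vmax` rule of thumb from the D4PG notes above; the reward, discount and episode length are environment-specific placeholders.

```python
# Rule of thumb: vmax = discounted sum of the maximum per-step reward over
# the maximum episode length; vmin mirrors it.
max_reward = 1.0            # environment-specific maximum instantaneous reward
discount = 0.99
max_episode_length = 1000

vmax = sum(max_reward * discount**t for t in range(max_episode_length))
vmin = -vmax
# vmax is roughly 100 here (geometric series limit 1 / (1 - 0.99)).
```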
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a D4PG agent.""" 16 | 17 | from acme.agents.tf.d4pg.agent import D4PG 18 | from acme.agents.tf.d4pg.agent_distributed import DistributedD4PG 19 | from acme.agents.tf.d4pg.learning import D4PGLearner 20 | from acme.agents.tf.d4pg.networks import make_default_networks 21 | -------------------------------------------------------------------------------- /acme/agents/tf/ddpg/README.md: -------------------------------------------------------------------------------- 1 | # Deep Deterministic Policy Gradient (DDPG) 2 | 3 | This folder contains an implementation of the DDPG agent introduced in ( 4 | [Lillicrap et al., 2015]), which extends the Deterministic Policy Gradient (DPG) 5 | algorithm (introduced in [Silver et al., 2014]) to the realm of deep learning. 6 | 7 | DDPG is an off-policy [actor-critic algorithm]. In this algorithm, the critic is a 8 | network that takes an observation and an action and outputs a value estimate 9 | based on the current policy. It is trained to minimize the squared 10 | temporal-difference (TD) error. The actor is the policy network that takes 11 | observations as input and outputs actions. For each observation, it is trained 12 | to maximize the critic's value estimate. 13 | 14 | [Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971 15 | [Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14 16 | [actor-critic algorithm]: http://incompleteideas.net/book/RLbook2018.pdf#page=353 17 | -------------------------------------------------------------------------------- /acme/agents/tf/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a DDPG agent.""" 16 | 17 | from acme.agents.tf.ddpg.agent import DDPG 18 | from acme.agents.tf.ddpg.agent_distributed import DistributedDDPG 19 | from acme.agents.tf.ddpg.learning import DDPGLearner 20 | -------------------------------------------------------------------------------- /acme/agents/tf/dmpo/README.md: -------------------------------------------------------------------------------- 1 | # Distributional Maximum a posteriori Policy Optimization (DMPO) 2 | 3 | This folder contains an implementation of a novel agent (DMPO) introduced in the 4 | original Acme release.
This work extends the MPO algorithm 5 | ([Abdolmaleki et al., 2018a], [2018b]) by using a distributional Q-network 6 | similar to C51 ([Bellemare et al., 2017]). Therefore, as in the case of the D4PG 7 | agent, this algorithm's critic outputs a distribution over state-action values. 8 | 9 | As with our MPO agent, while this is a more general algorithm, the current 10 | implementation targets the continuous control setting and is most readily 11 | applied to the DeepMind control suite or similar control tasks. This 12 | implementation also includes the options of: 13 | 14 | * per-dimension KL constraint satisfaction, and 15 | * action penalization via the multi-objective MPO work of 16 | [Abdolmaleki et al., 2020]. 17 | 18 | Detailed notes: 19 | 20 | * The `vmin|vmax` hyperparameters of the distributional critic may need tuning 21 | depending on your environment's rewards. A good rule of thumb is to set 22 | `vmax` to the discounted sum of the maximum instantaneous rewards for the 23 | maximum episode length; then set `vmin` to `-vmax`. 24 | * When using per-dimension KL constraint satisfaction, you may need to tune 25 | the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule 26 | of thumb would be to divide it by the number of dimensions in the action 27 | space. 28 | 29 | [Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf 30 | [2018b]: https://arxiv.org/pdf/1812.02256.pdf 31 | [Abdolmaleki et al., 2020]: https://arxiv.org/pdf/2005.07513.pdf 32 | [Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887 33 | -------------------------------------------------------------------------------- /acme/agents/tf/dmpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a distributional MPO agent.""" 16 | 17 | from acme.agents.tf.dmpo.agent import DistributionalMPO 18 | from acme.agents.tf.dmpo.agent_distributed import DistributedDistributionalMPO 19 | from acme.agents.tf.dmpo.learning import DistributionalMPOLearner 20 | -------------------------------------------------------------------------------- /acme/agents/tf/dqfd/README.md: -------------------------------------------------------------------------------- 1 | # Deep Q-learning from Demonstrations (DQfD) 2 | 3 | This folder contains an implementation of the DQfD algorithm 4 | ([Hester et al., 2017]). This agent extends DQN by mixing expert demonstrations 5 | with the agent's experience in each mini-batch. 6 | 7 | [Hester et al., 2017]: https://arxiv.org/abs/1704.03732 8 | -------------------------------------------------------------------------------- /acme/agents/tf/dqfd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for DQfD.""" 16 | 17 | from acme.agents.tf.dqfd.agent import DQfD 18 | from acme.agents.tf.dqfd.bsuite_demonstrations import DemonstrationRecorder 19 | -------------------------------------------------------------------------------- /acme/agents/tf/dqn/README.md: -------------------------------------------------------------------------------- 1 | # Deep Q-Networks (DQN) 2 | 3 | This folder contains an implementation of the DQN algorithm 4 | ([Mnih et al., 2013], [Mnih et al., 2015]), with extra bells & whistles, 5 | similar to Rainbow DQN ([Hessel et al., 2017]). 6 | 7 | * Q-learning with neural network function approximation. The loss is given by 8 | the Huber loss applied to the temporal difference error. 9 | * Target Q' network updated periodically ([Mnih et al., 2015]). 10 | * N-step bootstrapping ([Sutton & Barto, 2018]). 11 | * Double Q-learning ([van Hasselt et al., 2015]). 12 | * Prioritized experience replay ([Schaul et al., 2015]). 13 | 14 | [Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 15 | [Mnih et al., 2015]: https://www.nature.com/articles/nature14236 16 | [van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461 17 | [Schaul et al., 2015]: https://arxiv.org/abs/1511.05952 18 | [Hessel et al., 2017]: https://arxiv.org/abs/1710.02298 19 | [Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 20 | [Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html 21 | -------------------------------------------------------------------------------- /acme/agents/tf/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of a deep Q-networks (DQN) agent.""" 16 | 17 | from acme.agents.tf.dqn.agent import DQN 18 | from acme.agents.tf.dqn.agent_distributed import DistributedDQN 19 | from acme.agents.tf.dqn.learning import DQNLearner 20 | -------------------------------------------------------------------------------- /acme/agents/tf/dqn/agent_distributed_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Integration test for the distributed agent.""" 16 | 17 | import acme 18 | from acme.agents.tf import dqn 19 | from acme.testing import fakes 20 | from acme.tf import networks 21 | import launchpad as lp 22 | 23 | from absl.testing import absltest 24 | 25 | 26 | class DistributedAgentTest(absltest.TestCase): 27 | """Simple integration/smoke test for the distributed agent.""" 28 | 29 | def test_atari(self): 30 | """Tests that the agent can run for some steps without crashing.""" 31 | env_factory = lambda x: fakes.fake_atari_wrapped() 32 | net_factory = lambda spec: networks.DQNAtariNetwork(spec.num_values) 33 | 34 | agent = dqn.DistributedDQN( 35 | environment_factory=env_factory, 36 | network_factory=net_factory, 37 | num_actors=2, 38 | batch_size=32, 39 | min_replay_size=32, 40 | max_replay_size=1000, 41 | ) 42 | program = agent.build() 43 | 44 | (learner_node,) = program.groups['learner'] 45 | learner_node.disable_run() 46 | 47 | lp.launch(program, launch_type='test_mt') 48 | 49 | learner: acme.Learner = learner_node.create_handle().dereference() 50 | 51 | for _ in range(5): 52 | learner.step() 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /acme/agents/tf/dqn/agent_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for DQN agent.""" 16 | 17 | import acme 18 | from acme import specs 19 | from acme.agents.tf import dqn 20 | from acme.testing import fakes 21 | import numpy as np 22 | import sonnet as snt 23 | 24 | from absl.testing import absltest 25 | 26 | 27 | def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: 28 | return snt.Sequential([ 29 | snt.Flatten(), 30 | snt.nets.MLP([50, 50, action_spec.num_values]), 31 | ]) 32 | 33 | 34 | class DQNTest(absltest.TestCase): 35 | 36 | def test_dqn(self): 37 | # Create a fake environment to test with. 38 | environment = fakes.DiscreteEnvironment( 39 | num_actions=5, 40 | num_observations=10, 41 | obs_dtype=np.float32, 42 | episode_length=10) 43 | spec = specs.make_environment_spec(environment) 44 | 45 | # Construct the agent. 
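    # NOTE: the batch, replay and samples-per-insert settings below are kept
    # deliberately small so that this smoke test finishes quickly; they are not
    # sensible values for real training.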
46 | agent = dqn.DQN( 47 | environment_spec=spec, 48 | network=_make_network(spec.actions), 49 | batch_size=10, 50 | samples_per_insert=2, 51 | min_replay_size=10) 52 | 53 | # Try running the environment loop. We have no assertions here because all 54 | # we care about is that the agent runs without raising any errors. 55 | loop = acme.EnvironmentLoop(environment, agent) 56 | loop.run(num_episodes=2) 57 | 58 | 59 | if __name__ == '__main__': 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /acme/agents/tf/impala/README.md: -------------------------------------------------------------------------------- 1 | # Importance-weighted actor-learner architecture (IMPALA) 2 | 3 | This agent is an implementation of the algorithm described in *IMPALA: Scalable 4 | Distributed Deep-RL with Importance Weighted Actor-Learner Architectures* 5 | ([Espeholt et al., 2018]). 6 | 7 | [Espeholt et al., 2018]: https://arxiv.org/abs/1802.01561 8 | -------------------------------------------------------------------------------- /acme/agents/tf/impala/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Importance-weighted actor-learner architecture (IMPALA) agent.""" 16 | 17 | from acme.agents.tf.impala.acting import IMPALAActor 18 | from acme.agents.tf.impala.agent import IMPALA 19 | from acme.agents.tf.impala.agent_distributed import DistributedIMPALA 20 | from acme.agents.tf.impala.learning import IMPALALearner 21 | -------------------------------------------------------------------------------- /acme/agents/tf/impala/agent_distributed_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Integration test for the distributed agent.""" 16 | 17 | import acme 18 | from acme.agents.tf import impala 19 | from acme.testing import fakes 20 | from acme.tf import networks 21 | import launchpad as lp 22 | 23 | from absl.testing import absltest 24 | 25 | 26 | class DistributedAgentTest(absltest.TestCase): 27 | """Simple integration/smoke test for the distributed agent.""" 28 | 29 | def test_atari(self): 30 | """Tests that the agent can run for some steps without crashing.""" 31 | env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) 32 | net_factory = lambda spec: networks.IMPALAAtariNetwork(spec.num_values) 33 | 34 | agent = impala.DistributedIMPALA( 35 | environment_factory=env_factory, 36 | network_factory=net_factory, 37 | num_actors=2, 38 | batch_size=32, 39 | sequence_length=5, 40 | sequence_period=1, 41 | ) 42 | program = agent.build() 43 | 44 | (learner_node,) = program.groups['learner'] 45 | learner_node.disable_run() 46 | 47 | lp.launch(program, launch_type='test_mt') 48 | 49 | learner: acme.Learner = learner_node.create_handle().dereference() 50 | 51 | for _ in range(5): 52 | learner.step() 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /acme/agents/tf/iqn/README.md: -------------------------------------------------------------------------------- 1 | # Implicit Quantile Networks for Distributional RL (IQN) 2 | 3 | This folder contains an implementation of the IQN algorithm introduced in 4 | ([Dabney et al., 2018]). 5 | 6 | [Dabney et al., 2018]: https://arxiv.org/abs/1806.06923 7 | -------------------------------------------------------------------------------- /acme/agents/tf/iqn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of an IQN agent.""" 16 | 17 | from acme.agents.tf.iqn.learning import IQNLearner 18 | -------------------------------------------------------------------------------- /acme/agents/tf/mcts/README.md: -------------------------------------------------------------------------------- 1 | # Monte-Carlo Tree Search (MCTS) 2 | 3 | This agent implements planning with a simulator (learned or otherwise), with 4 | search guided by policy and value networks. This can be thought of as a 5 | scaled-down and simplified version of the AlphaZero algorithm 6 | ([Silver et al., 2018]). 7 | 8 | The algorithm is agnostic to the choice of environment model -- this can be an 9 | exact simulator (as in AlphaZero), or a learned transition model; we provide 10 | examples of both cases. 
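To make the interface concrete, here is a rough, hypothetical sketch of an
exact-simulator model. The class name `DeepCopyModel` and the `copy.deepcopy`
snapshotting are illustrative assumptions, not the implementation shipped in
this folder; the sketch only relies on the `Model` interface defined in
`models/base.py`, i.e. a `dm_env.Environment` that can additionally checkpoint
and restore its state and report whether it needs a reset:

```python
import copy
from typing import Optional

import dm_env

from acme.agents.tf.mcts import types
from acme.agents.tf.mcts.models import base


class DeepCopyModel(base.Model):
  """Hypothetical exact model that snapshots a copyable dm_env.Environment."""

  def __init__(self, environment: dm_env.Environment):
    self._environment = environment
    self._checkpoint = None
    self._needs_reset = True

  def step(self, action: types.Action) -> dm_env.TimeStep:
    timestep = self._environment.step(action)
    self._needs_reset = timestep.last()
    return timestep

  def reset(self, initial_state: Optional[types.Observation] = None):
    del initial_state  # An exact simulator ignores proposed start states.
    self._needs_reset = False
    return self._environment.reset()

  def save_checkpoint(self):
    # deepcopy only works for pure-Python environments; real simulators may
    # need an explicit get_state/set_state API instead.
    self._checkpoint = copy.deepcopy(self._environment)

  def load_checkpoint(self):
    self._environment = copy.deepcopy(self._checkpoint)

  def update(self, timestep, action, next_timestep) -> dm_env.TimeStep:
    return next_timestep  # Nothing to learn: the simulator is already exact.

  @property
  def needs_reset(self) -> bool:
    return self._needs_reset

  def observation_spec(self):
    return self._environment.observation_spec()

  def action_spec(self):
    return self._environment.action_spec()
```

A learned model would keep the same interface but implement `update` to fit a
transition model from the observed `(timestep, action, next_timestep)` tuples
instead of ignoring them.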
11 | 12 | [Silver et al., 2018]: https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go 13 | -------------------------------------------------------------------------------- /acme/agents/tf/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Monte-Carlo tree search (MCTS) agent.""" 16 | 17 | from acme.agents.tf.mcts.agent import MCTS 18 | from acme.agents.tf.mcts.agent_distributed import DistributedMCTS 19 | -------------------------------------------------------------------------------- /acme/agents/tf/mcts/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Models for planning via MCTS.""" 16 | 17 | # pylint: disable=unused-import 18 | 19 | from acme.agents.tf.mcts.models.base import Model 20 | -------------------------------------------------------------------------------- /acme/agents/tf/mcts/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Base model class, specifying the interface..""" 16 | 17 | import abc 18 | from typing import Optional 19 | 20 | from acme.agents.tf.mcts import types 21 | 22 | import dm_env 23 | 24 | 25 | class Model(dm_env.Environment, abc.ABC): 26 | """Base (abstract) class for models used for planning via MCTS.""" 27 | 28 | @abc.abstractmethod 29 | def load_checkpoint(self): 30 | """Loads a saved model state, if it exists.""" 31 | 32 | @abc.abstractmethod 33 | def save_checkpoint(self): 34 | """Saves the model state so that we can reset it after a rollout.""" 35 | 36 | @abc.abstractmethod 37 | def update( 38 | self, 39 | timestep: dm_env.TimeStep, 40 | action: types.Action, 41 | next_timestep: dm_env.TimeStep, 42 | ) -> dm_env.TimeStep: 43 | """Updates the model given an observation, action, reward, and discount.""" 44 | 45 | @abc.abstractmethod 46 | def reset(self, initial_state: Optional[types.Observation] = None): 47 | """Resets the model, optionally to an initial state.""" 48 | 49 | @property 50 | @abc.abstractmethod 51 | def needs_reset(self) -> bool: 52 | """Returns whether or not the model needs to be reset.""" 53 | -------------------------------------------------------------------------------- /acme/agents/tf/mcts/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Type aliases and assumptions that are specific to the MCTS agent.""" 16 | 17 | from typing import Callable, Tuple, Union 18 | import numpy as np 19 | 20 | # pylint: disable=invalid-name 21 | 22 | # Assumption: actions are scalar and discrete (integral). 23 | Action = Union[int, np.int32, np.int64] 24 | 25 | # Assumption: observations are array-like. 26 | Observation = np.ndarray 27 | 28 | # Assumption: rewards and discounts are scalar. 29 | Reward = Union[float, np.float32, np.float64] 30 | Discount = Union[float, np.float32, np.float64] 31 | 32 | # Notation: policy logits/probabilities are simply a vector of floats. 33 | Probs = np.ndarray 34 | 35 | # Notation: the value function is scalar-valued. 36 | Value = float 37 | 38 | # Notation: the 'evaluation function' maps observations -> (probs, value). 39 | EvaluationFn = Callable[[Observation], Tuple[Probs, Value]] 40 | -------------------------------------------------------------------------------- /acme/agents/tf/mog_mpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a (MoG) distributional MPO agent.""" 16 | 17 | from acme.agents.tf.mog_mpo.agent_distributed import DistributedMoGMPO 18 | from acme.agents.tf.mog_mpo.learning import MoGMPOLearner 19 | from acme.agents.tf.mog_mpo.learning import PolicyEvaluationConfig 20 | from acme.agents.tf.mog_mpo.networks import make_default_networks 21 | -------------------------------------------------------------------------------- /acme/agents/tf/mompo/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Objective Maximum a posteriori Policy Optimization (MO-MPO) 2 | 3 | This folder contains an implementation of Multi-Objective Maximum a posteriori 4 | Policy Optimization (MO-MPO), introduced in ([Abdolmaleki, Huang et al., 2020]). 5 | This trains a policy that optimizes for multiple objectives, with the desired 6 | preference across objectives encoded by the hyperparameters `epsilon`. 7 | 8 | As with our MPO agent, while this is a more general algorithm, the current 9 | implementation targets the continuous control setting and is most readily 10 | applied to the DeepMind control suite or similar control tasks. This 11 | implementation also includes the options of: 12 | 13 | * per-dimension KL constraint satisfaction, and 14 | * distributional (per-objective) critics, as used by the DMPO agent 15 | 16 | Detailed notes: 17 | 18 | * When using per-dimension KL constraint satisfaction, you may need to tune 19 | the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule 20 | of thumb would be to divide it by the number of dimensions in the action 21 | space. 22 | * If using a distributional critic, the `vmin|vmax` hyperparameters of the 23 | distributional critic may need tuning depending on your environment's 24 | rewards. A good rule of thumb is to set `vmax` to the discounted sum of the 25 | maximum instantaneous rewards for the maximum episode length; then set 26 | `vmin` to `-vmax`. 27 | 28 | [Abdolmaleki, Huang et al., 2020]: https://arxiv.org/abs/2005.07513 29 | -------------------------------------------------------------------------------- /acme/agents/tf/mompo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Implementations of a distributional MPO agent.""" 16 | 17 | from acme.agents.tf.mompo.agent import MultiObjectiveMPO 18 | from acme.agents.tf.mompo.agent_distributed import DistributedMultiObjectiveMPO 19 | from acme.agents.tf.mompo.learning import MultiObjectiveMPOLearner 20 | from acme.agents.tf.mompo.learning import QValueObjective 21 | from acme.agents.tf.mompo.learning import RewardObjective 22 | -------------------------------------------------------------------------------- /acme/agents/tf/mpo/README.md: -------------------------------------------------------------------------------- 1 | # Maximum a posteriori Policy Optimization (MPO) 2 | 3 | This folder contains an implementation of Maximum a posteriori Policy 4 | Optimization (MPO) introduced in ([Abdolmaleki et al., 2018a], [2018b]). While 5 | this is a more general algorithm, the current implementation targets the 6 | continuous control setting and is most readily applied to the DeepMind control 7 | suite or similar control tasks. 8 | 9 | This implementation includes a few important options such as: 10 | 11 | * per-dimension KL constraint satisfaction, and 12 | * action penalization via the multi-objective MPO work of 13 | [Abdolmaleki, Huang et al., 2020]. 14 | 15 | See the DMPO agent directory for a similar agent that uses a distributional 16 | critic. See the MO-MPO agent directory for an agent that optimizes for multiple 17 | objectives. 18 | 19 | Detailed notes: 20 | 21 | * When using per-dimension KL constraint satisfaction, you may need to tune 22 | the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule 23 | of thumb would be to divide it by the number of dimensions in the action 24 | space. 25 | 26 | [Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf 27 | [2018b]: https://arxiv.org/pdf/1812.02256.pdf 28 | [Abdolmaleki, Huang et al., 2020]: https://arxiv.org/pdf/2005.07513.pdf 29 | -------------------------------------------------------------------------------- /acme/agents/tf/mpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of a MPO agent.""" 16 | 17 | from acme.agents.tf.mpo.agent import MPO 18 | from acme.agents.tf.mpo.agent_distributed import DistributedMPO 19 | from acme.agents.tf.mpo.learning import MPOLearner 20 | -------------------------------------------------------------------------------- /acme/agents/tf/r2d2/README.md: -------------------------------------------------------------------------------- 1 | # R2D2 - Recurrent Experience Replay in Distributed Reinforcement Learning 2 | 3 | This folder contains an implementation of the R2D2 agent introduced in 4 | ([Kapturowski et al., 2019]). 
This work builds upon the DQN algorithm 5 | ([Mnih et al., 2013], [Mnih et al., 2015]) and the Ape-X framework ([Horgan et al., 6 | 2018]), extending distributed Q-learning to use recurrent neural networks. This 7 | version of the agent is synchronous, and is therefore not distributed. 8 | 9 | [Kapturowski et al., 2019]: https://openreview.net/forum?id=r1lyTjAqYX 10 | [Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 11 | [Mnih et al., 2015]: https://www.nature.com/articles/nature14236 12 | [Horgan et al., 2018]: https://arxiv.org/pdf/1803.00933 13 | -------------------------------------------------------------------------------- /acme/agents/tf/r2d2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for Recurrent DQN (R2D2).""" 16 | 17 | from acme.agents.tf.r2d2.agent import R2D2 18 | from acme.agents.tf.r2d2.agent_distributed import DistributedR2D2 19 | -------------------------------------------------------------------------------- /acme/agents/tf/r2d2/agent_distributed_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Integration test for the distributed agent.""" 16 | 17 | import acme 18 | from acme.agents.tf import r2d2 19 | from acme.testing import fakes 20 | from acme.tf import networks 21 | import launchpad as lp 22 | 23 | from absl.testing import absltest 24 | 25 | 26 | class DistributedAgentTest(absltest.TestCase): 27 | """Simple integration/smoke test for the distributed agent.""" 28 | 29 | def test_agent(self): 30 | env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) 31 | net_factory = lambda spec: networks.R2D2AtariNetwork(spec.num_values) 32 | 33 | agent = r2d2.DistributedR2D2( 34 | environment_factory=env_factory, 35 | network_factory=net_factory, 36 | num_actors=2, 37 | batch_size=32, 38 | min_replay_size=32, 39 | max_replay_size=1000, 40 | replay_period=1, 41 | burn_in_length=1, 42 | trace_length=10, 43 | ) 44 | program = agent.build() 45 | 46 | (learner_node,) = program.groups['learner'] 47 | learner_node.disable_run() 48 | 49 | lp.launch(program, launch_type='test_mt') 50 | 51 | learner: acme.Learner = learner_node.create_handle().dereference() 52 | 53 | for _ in range(5): 54 | learner.step() 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /acme/agents/tf/r2d3/README.md: -------------------------------------------------------------------------------- 1 | # R2D3 - R2D2 from Demonstrations 2 | 3 | This folder contains an implementation of the R2D3 agent introduced in 4 | ([Paine et al., 2019]). This work builds upon the R2D2 algorithm 5 | ([Kapturowski et al., 2019]). 6 | 7 | In this case a learner similar to the one used in R2D2 receives batches with a 8 | fixed proportion of replay buffer and demonstration data. 9 | 10 | [Paine et al., 2019]: https://arxiv.org/abs/1909.01387 11 | [Kapturowski et al., 2019]: https://openreview.net/forum?id=r1lyTjAqYX 12 | -------------------------------------------------------------------------------- /acme/agents/tf/r2d3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for Recurrent DQfD (R2D3).""" 16 | 17 | from acme.agents.tf.r2d3.agent import R2D3 18 | -------------------------------------------------------------------------------- /acme/agents/tf/svg0_prior/README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Value Gradients (SVG) with Behavior Prior. 2 | 3 | This folder contains a version of the SVG-0 agent introduced in 4 | ([Heess et al., 2015]) that has been extended with an entropy bonus, RETRACE 5 | ([Munos et al., 2016]) for off-policy correction and code to learn behavior 6 | priors ([Tirumala et al., 2019], [Galashov et al., 2019]). 
7 | 8 | The base SVG-0 algorithm is similar to DPG and DDPG ([Silver et al., 2014], 9 | [Lillicrap et al., 2015]) but uses the reparameterization trick to learn 10 | stochastic and not deterministic policies. In addition, the RETRACE algorithm is 11 | used to learn value functions using multiple timesteps of data with importance 12 | sampling for off-policy correction. 13 | 14 | In addition, an optional behavior prior can be learnt using this setup with an 15 | information asymmetry that has been shown to boost performance in some domains. 16 | Example code to run with and without behavior priors on the DeepMind Control 17 | Suite and Locomotion tasks is provided in the `examples` folder. 18 | 19 | [Heess et al., 2015]: https://arxiv.org/abs/1510.09142 20 | [Munos et al., 2016]: https://arxiv.org/abs/1606.02647 21 | [Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971 22 | [Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14 23 | [Tirumala et al., 2020]: https://arxiv.org/abs/2010.14274 24 | [Galashov et al., 2019]: https://arxiv.org/abs/1905.01240 25 | -------------------------------------------------------------------------------- /acme/agents/tf/svg0_prior/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementations of an SVG0 agent with prior.""" 16 | 17 | from acme.agents.tf.svg0_prior.agent import SVG0 18 | from acme.agents.tf.svg0_prior.agent_distributed import DistributedSVG0 19 | from acme.agents.tf.svg0_prior.learning import SVG0Learner 20 | from acme.agents.tf.svg0_prior.networks import make_default_networks 21 | from acme.agents.tf.svg0_prior.networks import make_network_with_prior 22 | -------------------------------------------------------------------------------- /acme/core_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for core.py.""" 16 | 17 | from typing import List 18 | 19 | from acme import core 20 | from acme import types 21 | 22 | from absl.testing import absltest 23 | 24 | 25 | class StepCountingLearner(core.Learner): 26 | """A learner which counts `num_steps` and then raises `StopIteration`.""" 27 | 28 | def __init__(self, num_steps: int): 29 | self.step_count = 0 30 | self.num_steps = num_steps 31 | 32 | def step(self): 33 | self.step_count += 1 34 | if self.step_count >= self.num_steps: 35 | raise StopIteration() 36 | 37 | def get_variables(self, unused: List[str]) -> List[types.NestedArray]: 38 | del unused 39 | return [] 40 | 41 | 42 | class CoreTest(absltest.TestCase): 43 | 44 | def test_learner_run_with_limit(self): 45 | learner = StepCountingLearner(100) 46 | learner.run(7) 47 | self.assertEqual(learner.step_count, 7) 48 | 49 | def test_learner_run_no_limit(self): 50 | learner = StepCountingLearner(100) 51 | with self.assertRaises(StopIteration): 52 | learner.run() 53 | self.assertEqual(learner.step_count, 100) 54 | 55 | 56 | if __name__ == '__main__': 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /acme/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dataset interfaces.""" 16 | 17 | from acme.datasets.numpy_iterator import NumpyIterator 18 | from acme.datasets.reverb import make_reverb_dataset 19 | # from acme.datasets.reverb import make_reverb_dataset_trajectory 20 | -------------------------------------------------------------------------------- /acme/datasets/numpy_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A iterator that does zero-copy conversion of `tf.Tensor`s into `np.ndarray`s.""" 16 | 17 | from typing import Iterator 18 | 19 | from acme import types 20 | import numpy as np 21 | import tree 22 | 23 | 24 | class NumpyIterator(Iterator[types.NestedArray]): 25 | """Iterator over a dataset with elements converted to numpy. 26 | 27 | Note: This iterator returns read-only numpy arrays. 
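  A short usage sketch (mirroring `numpy_iterator_test.py`; `tf` here stands
  for TensorFlow and is imported only for the example):

    ds = tf.data.Dataset.range(3)
    assert list(NumpyIterator(ds)) == [0, 1, 2]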
28 | 29 | This iterator (compared to `tf.data.Dataset.as_numpy_iterator()`) does not 30 | copy the data when comverting `tf.Tensor`s to `np.ndarray`s. 31 | 32 | TODO(b/178684359): Remove this when it is upstreamed into `tf.data`. 33 | """ 34 | 35 | __slots__ = ['_iterator'] 36 | 37 | def __init__(self, dataset): 38 | self._iterator: Iterator[types.NestedTensor] = iter(dataset) 39 | 40 | def __iter__(self) -> 'NumpyIterator': 41 | return self 42 | 43 | def __next__(self) -> types.NestedArray: 44 | return tree.map_structure(lambda t: np.asarray(memoryview(t)), 45 | next(self._iterator)) 46 | 47 | def next(self): 48 | return self.__next__() 49 | -------------------------------------------------------------------------------- /acme/datasets/numpy_iterator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for acme.datasets.numpy_iterator.""" 16 | 17 | import collections 18 | 19 | from acme.datasets import numpy_iterator 20 | import tensorflow as tf 21 | 22 | from absl.testing import absltest 23 | 24 | 25 | class NumpyIteratorTest(absltest.TestCase): 26 | 27 | def testBasic(self): 28 | ds = tf.data.Dataset.range(3) 29 | self.assertEqual([0, 1, 2], list(numpy_iterator.NumpyIterator(ds))) 30 | 31 | def testNestedStructure(self): 32 | point = collections.namedtuple('Point', ['x', 'y']) 33 | ds = tf.data.Dataset.from_tensor_slices({ 34 | 'a': ([1, 2], [3, 4]), 35 | 'b': [5, 6], 36 | 'c': point([7, 8], [9, 10]) 37 | }) 38 | self.assertEqual([{ 39 | 'a': (1, 3), 40 | 'b': 5, 41 | 'c': point(7, 9) 42 | }, { 43 | 'a': (2, 4), 44 | 'b': 6, 45 | 'c': point(8, 10) 46 | }], list(numpy_iterator.NumpyIterator(ds))) 47 | 48 | if __name__ == '__main__': 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /acme/environment_loops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Specialized environment loops.""" 16 | 17 | try: 18 | # pylint: disable=g-import-not-at-top 19 | from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop 20 | except ImportError: 21 | pass 22 | -------------------------------------------------------------------------------- /acme/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /acme/jax/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """JAX experiment utils.""" 16 | 17 | from acme.jax.experiments.config import CheckpointingConfig 18 | from acme.jax.experiments.config import default_evaluator_factory 19 | from acme.jax.experiments.config import DeprecatedPolicyFactory 20 | from acme.jax.experiments.config import EvaluatorFactory 21 | from acme.jax.experiments.config import ExperimentConfig 22 | from acme.jax.experiments.config import make_policy 23 | from acme.jax.experiments.config import MakeActorFn 24 | from acme.jax.experiments.config import NetworkFactory 25 | from acme.jax.experiments.config import OfflineExperimentConfig 26 | from acme.jax.experiments.config import PolicyFactory 27 | from acme.jax.experiments.config import SnapshotModelFactory 28 | from acme.jax.experiments.make_distributed_experiment import make_distributed_experiment 29 | from acme.jax.experiments.make_distributed_offline_experiment import make_distributed_offline_experiment 30 | from acme.jax.experiments.run_experiment import run_experiment 31 | from acme.jax.experiments.run_offline_experiment import run_offline_experiment 32 | -------------------------------------------------------------------------------- /acme/jax/experiments/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for testing of acme.jax.experiments functions.""" 16 | 17 | from acme.jax import experiments 18 | from acme.tf import savers 19 | from acme.utils import counting 20 | 21 | 22 | def restore_counter( 23 | checkpointing_config: experiments.CheckpointingConfig) -> counting.Counter: 24 | """Restores a counter from the latest checkpoint saved with this config.""" 25 | counter = counting.Counter() 26 | savers.Checkpointer( 27 | objects_to_save={'counter': counter}, 28 | directory=checkpointing_config.directory, 29 | add_uid=checkpointing_config.add_uid, 30 | max_to_keep=checkpointing_config.max_to_keep) 31 | return counter 32 | -------------------------------------------------------------------------------- /acme/jax/imitation_learning_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """JAX type definitions for imitation and apprenticeship learning algorithms.""" 16 | 17 | from typing import TypeVar 18 | 19 | # Common TypeVars that correspond to various aspects of the direct RL algorithm. 20 | DirectPolicyNetwork = TypeVar('DirectPolicyNetwork') 21 | DirectRLNetworks = TypeVar('DirectRLNetworks') 22 | DirectRLTrainingState = TypeVar('DirectRLTrainingState') 23 | -------------------------------------------------------------------------------- /acme/jax/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Common loss functions.""" 16 | 17 | from acme.jax.losses.impala import impala_loss 18 | from acme.jax.losses.mpo import MPO 19 | from acme.jax.losses.mpo import MPOParams 20 | from acme.jax.losses.mpo import MPOStats 21 | from acme.jax.losses.wpo import WPO 22 | from acme.jax.losses.wpo import WPOParams 23 | from acme.jax.losses.wpo import WPOStats 24 | -------------------------------------------------------------------------------- /acme/jax/networks/policy_value.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Policy-value network head for actor-critic algorithms.""" 16 | 17 | from typing import Tuple 18 | 19 | import haiku as hk 20 | import jax.numpy as jnp 21 | 22 | 23 | class PolicyValueHead(hk.Module): 24 | """A network with two linear layers, for policy and value respectively.""" 25 | 26 | def __init__(self, num_actions: int): 27 | super().__init__(name='policy_value_network') 28 | self._policy_layer = hk.Linear(num_actions) 29 | self._value_layer = hk.Linear(1) 30 | 31 | def __call__(self, inputs: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 32 | """Returns a (Logits, Value) tuple.""" 33 | logits = self._policy_layer(inputs) # [B, A] 34 | value = jnp.squeeze(self._value_layer(inputs), axis=-1) # [B] 35 | 36 | return logits, value 37 | -------------------------------------------------------------------------------- /acme/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Multiagent helpers.""" 16 | -------------------------------------------------------------------------------- /acme/multiagent/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Types for multiagent setups.""" 16 | 17 | from typing import Any, Callable, Dict, Tuple 18 | 19 | from acme import specs 20 | from acme.agents.jax import builders as jax_builders 21 | from acme.utils.loggers import base 22 | import reverb 23 | 24 | 25 | # Sub-agent types 26 | AgentID = str 27 | EvalMode = bool 28 | GenericAgent = Any 29 | AgentConfig = Any 30 | Networks = Any 31 | PolicyNetwork = Any 32 | LoggerFn = Callable[[], base.Logger] 33 | InitNetworkFn = Callable[[GenericAgent, specs.EnvironmentSpec], Networks] 34 | InitPolicyNetworkFn = Callable[ 35 | [GenericAgent, Networks, specs.EnvironmentSpec, AgentConfig, bool], 36 | Networks] 37 | InitBuilderFn = Callable[[GenericAgent, AgentConfig], 38 | jax_builders.GenericActorLearnerBuilder] 39 | 40 | # Multiagent types 41 | MultiAgentLoggerFn = Dict[AgentID, LoggerFn] 42 | MultiAgentNetworks = Dict[AgentID, Networks] 43 | MultiAgentPolicyNetworks = Dict[AgentID, PolicyNetwork] 44 | MultiAgentSample = Tuple[reverb.ReplaySample, ...] 45 | NetworkFactory = Callable[[specs.EnvironmentSpec], MultiAgentNetworks] 46 | PolicyFactory = Callable[[MultiAgentNetworks, EvalMode], 47 | MultiAgentPolicyNetworks] 48 | BuilderFactory = Callable[[ 49 | Dict[AgentID, GenericAgent], 50 | Dict[AgentID, AgentConfig], 51 | ], Dict[AgentID, jax_builders.GenericActorLearnerBuilder]] 52 | -------------------------------------------------------------------------------- /acme/multiagent/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Multiagent utilities.""" 16 | 17 | from acme import specs 18 | from acme.multiagent import types 19 | import dm_env 20 | 21 | 22 | def get_agent_spec(env_spec: specs.EnvironmentSpec, 23 | agent_id: types.AgentID) -> specs.EnvironmentSpec: 24 | """Returns a single agent spec from environment spec. 25 | 26 | Args: 27 | env_spec: environment spec, wherein observation, action, and reward specs 28 | are simply lists (with each entry specifying the respective spec for the 29 | given agent index). Discounts are scalars shared amongst agents. 30 | agent_id: agent index. 
31 | """ 32 | return specs.EnvironmentSpec( 33 | actions=env_spec.actions[agent_id], 34 | discounts=env_spec.discounts, 35 | observations=env_spec.observations[agent_id], 36 | rewards=env_spec.rewards[agent_id]) 37 | 38 | 39 | def get_agent_timestep(timestep: dm_env.TimeStep, 40 | agent_id: types.AgentID) -> dm_env.TimeStep: 41 | """Returns the extracted timestep for a particular agent.""" 42 | # Discounts are assumed to be shared amongst agents 43 | reward = None if timestep.reward is None else timestep.reward[agent_id] 44 | return dm_env.TimeStep( 45 | observation=timestep.observation[agent_id], 46 | reward=reward, 47 | discount=timestep.discount, 48 | step_type=timestep.step_type) 49 | -------------------------------------------------------------------------------- /acme/specs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Objects which specify the input/output spaces of an environment. 16 | 17 | This module exposes the same spec classes as `dm_env` as well as providing an 18 | additional `EnvironmentSpec` class which collects all of the specs for a given 19 | environment. An `EnvironmentSpec` instance can be created directly or by using 20 | the `make_environment_spec` helper given a `dm_env.Environment` instance. 21 | """ 22 | 23 | from typing import Any, NamedTuple 24 | 25 | import dm_env 26 | from dm_env import specs 27 | 28 | Array = specs.Array 29 | BoundedArray = specs.BoundedArray 30 | DiscreteArray = specs.DiscreteArray 31 | 32 | 33 | class EnvironmentSpec(NamedTuple): 34 | """Full specification of the domains used by a given environment.""" 35 | # TODO(b/144758674): Use NestedSpec type here. 36 | observations: Any 37 | actions: Any 38 | rewards: Any 39 | discounts: Any 40 | 41 | 42 | def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec: 43 | """Returns an `EnvironmentSpec` describing values used by an environment.""" 44 | return EnvironmentSpec( 45 | observations=environment.observation_spec(), 46 | actions=environment.action_spec(), 47 | rewards=environment.reward_spec(), 48 | discounts=environment.discount_spec()) 49 | -------------------------------------------------------------------------------- /acme/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Testing helpers.""" 16 | -------------------------------------------------------------------------------- /acme/testing/multiagent_fakes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Fake (mock) components for multiagent testing.""" 16 | 17 | from typing import Dict, List 18 | 19 | from acme import specs 20 | import numpy as np 21 | 22 | 23 | def _make_multiagent_spec(agent_indices: List[str]) -> Dict[str, specs.Array]: 24 | """Returns dummy multiagent sub-spec (e.g., observation or action spec). 25 | 26 | Args: 27 | agent_indices: a list of agent indices. 28 | """ 29 | return { 30 | agent_id: specs.BoundedArray((1,), np.float32, 0, 1) 31 | for agent_id in agent_indices 32 | } 33 | 34 | 35 | def make_multiagent_environment_spec( 36 | agent_indices: List[str]) -> specs.EnvironmentSpec: 37 | """Returns dummy multiagent environment spec. 38 | 39 | Args: 40 | agent_indices: a list of agent indices. 41 | """ 42 | action_spec = _make_multiagent_spec(agent_indices) 43 | discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) 44 | observation_spec = _make_multiagent_spec(agent_indices) 45 | reward_spec = _make_multiagent_spec(agent_indices) 46 | return specs.EnvironmentSpec( 47 | actions=action_spec, 48 | discounts=discount_spec, 49 | observations=observation_spec, 50 | rewards=reward_spec) 51 | -------------------------------------------------------------------------------- /acme/testing/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
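A hypothetical test sketch built on the fakes above; the agent ids and the asserted shape simply mirror what `_make_multiagent_spec` produces.

```python
# Hypothetical test sketch using the multiagent fakes above.
from acme.multiagent import utils as multiagent_utils
from acme.testing import multiagent_fakes

spec = multiagent_fakes.make_multiagent_environment_spec(['0', '1'])
assert set(spec.actions.keys()) == {'0', '1'}

single_spec = multiagent_utils.get_agent_spec(spec, agent_id='1')
assert single_spec.observations.shape == (1,)
```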
14 | 15 | """Testing utilities.""" 16 | 17 | import sys 18 | from typing import Optional 19 | 20 | from absl import flags 21 | from absl.testing import parameterized 22 | 23 | 24 | class TestCase(parameterized.TestCase): 25 | """A custom TestCase which handles FLAG parsing for pytest compatibility.""" 26 | 27 | def get_tempdir(self, name: Optional[str] = None) -> str: 28 | try: 29 | flags.FLAGS.test_tmpdir 30 | except flags.UnparsedFlagAccessError: 31 | # Need to initialize flags when running `pytest`. 32 | flags.FLAGS(sys.argv, known_only=True) 33 | return self.create_tempdir(name).full_path 34 | -------------------------------------------------------------------------------- /acme/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /acme/tf/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Various losses for training agent components (policies, critics, etc).""" 16 | 17 | from acme.tf.losses.distributional import categorical 18 | from acme.tf.losses.distributional import multiaxis_categorical 19 | from acme.tf.losses.dpg import dpg 20 | from acme.tf.losses.huber import huber 21 | from acme.tf.losses.mompo import KLConstraint 22 | from acme.tf.losses.mompo import MultiObjectiveMPO 23 | from acme.tf.losses.mpo import MPO 24 | from acme.tf.losses.r2d2 import transformed_n_step_loss 25 | 26 | # Internal imports. 27 | # pylint: disable=g-bad-import-order,g-import-not-at-top 28 | from acme.tf.losses.quantile import NonUniformQuantileRegression 29 | from acme.tf.losses.quantile import QuantileDistribution 30 | -------------------------------------------------------------------------------- /acme/tf/networks/discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Networks used in discrete-action agents.""" 16 | 17 | import sonnet as snt 18 | import tensorflow as tf 19 | 20 | 21 | class DiscreteFilteredQNetwork(snt.Module): 22 | """Discrete filtered Q-network. 23 | 24 | This produces filtered Q values according to the method used in the discrete 25 | BCQ algorithm (https://arxiv.org/pdf/1910.01708.pdf - section 4). 26 | """ 27 | 28 | def __init__(self, 29 | g_network: snt.Module, 30 | q_network: snt.Module, 31 | threshold: float): 32 | super().__init__(name='discrete_filtered_qnet') 33 | assert threshold >= 0 and threshold <= 1 34 | self.g_network = g_network 35 | self.q_network = q_network 36 | self._threshold = threshold 37 | 38 | def __call__(self, o_t: tf.Tensor) -> tf.Tensor: 39 | q_t = self.q_network(o_t) 40 | g_t = tf.nn.softmax(self.g_network(o_t)) 41 | normalized_g_t = g_t / tf.reduce_max(g_t, axis=-1, keepdims=True) 42 | 43 | # Filter actions based on g_network outputs. 44 | min_q = tf.reduce_min(q_t, axis=-1, keepdims=True) 45 | return tf.where(normalized_g_t >= self._threshold, q_t, min_q) 46 | -------------------------------------------------------------------------------- /acme/tf/networks/duelling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A duelling network architecture, as described in [0]. 16 | 17 | [0] https://arxiv.org/abs/1511.06581 18 | """ 19 | 20 | from typing import Sequence 21 | 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | 26 | class DuellingMLP(snt.Module): 27 | """A Duelling MLP Q-network.""" 28 | 29 | def __init__( 30 | self, 31 | num_actions: int, 32 | hidden_sizes: Sequence[int], 33 | ): 34 | super().__init__(name='duelling_q_network') 35 | 36 | self._value_mlp = snt.nets.MLP([*hidden_sizes, 1]) 37 | self._advantage_mlp = snt.nets.MLP([*hidden_sizes, num_actions]) 38 | 39 | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: 40 | """Forward pass of the duelling network. 41 | 42 | Args: 43 | inputs: 2-D tensor of shape [batch_size, embedding_size]. 44 | 45 | Returns: 46 | q_values: 2-D tensor of action values of shape [batch_size, num_actions] 47 | """ 48 | 49 | # Compute value & advantage for duelling. 50 | value = self._value_mlp(inputs) # [B, 1] 51 | advantages = self._advantage_mlp(inputs) # [B, A] 52 | 53 | # Advantages have zero mean. 
54 | advantages -= tf.reduce_mean(advantages, axis=-1, keepdims=True) # [B, A] 55 | 56 | q_values = value + advantages # [B, A] 57 | 58 | return q_values 59 | -------------------------------------------------------------------------------- /acme/tf/networks/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Modules for computing custom embeddings.""" 16 | 17 | from acme.tf.networks import base 18 | from acme.wrappers import observation_action_reward 19 | 20 | import sonnet as snt 21 | import tensorflow as tf 22 | 23 | 24 | class OAREmbedding(snt.Module): 25 | """Module for embedding (observation, action, reward) inputs together.""" 26 | 27 | def __init__(self, torso: base.Module, num_actions: int): 28 | super().__init__(name='oar_embedding') 29 | self._num_actions = num_actions 30 | self._torso = torso 31 | 32 | def __call__(self, inputs: observation_action_reward.OAR) -> tf.Tensor: 33 | """Embed each of the (observation, action, reward) inputs & concatenate.""" 34 | 35 | # Add dummy trailing dimension to rewards if necessary. 36 | if len(inputs.reward.shape.dims) == 1: 37 | inputs = inputs._replace(reward=tf.expand_dims(inputs.reward, axis=-1)) 38 | 39 | features = self._torso(inputs.observation) # [T?, B, D] 40 | action = tf.one_hot(inputs.action, depth=self._num_actions) # [T?, B, A] 41 | reward = tf.nn.tanh(inputs.reward) # [T?, B, 1] 42 | 43 | embedding = tf.concat([features, action, reward], axis=-1) # [T?, B, D+A+1] 44 | 45 | return embedding 46 | -------------------------------------------------------------------------------- /acme/tf/networks/noise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
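An illustrative call to the `DuellingMLP` defined above; the batch size, embedding size, and hidden sizes are arbitrary choices.

```python
# Illustrative only: computing Q-values with the DuellingMLP above.
import tensorflow as tf

from acme.tf.networks import duelling

network = duelling.DuellingMLP(num_actions=4, hidden_sizes=[64, 64])
embeddings = tf.zeros([8, 32])   # [batch_size, embedding_size]
q_values = network(embeddings)   # -> shape [8, 4]
```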
14 | 15 | """Noise layers (for exploration).""" 16 | 17 | from acme import types 18 | import sonnet as snt 19 | import tensorflow as tf 20 | import tensorflow_probability as tfp 21 | import tree 22 | 23 | tfd = tfp.distributions 24 | 25 | 26 | class ClippedGaussian(snt.Module): 27 | """Sonnet module for adding clipped Gaussian noise to each output.""" 28 | 29 | def __init__(self, stddev: float, name: str = 'clipped_gaussian'): 30 | super().__init__(name=name) 31 | self._noise = tfd.Normal(loc=0., scale=stddev) 32 | 33 | def __call__(self, inputs: types.NestedTensor) -> types.NestedTensor: 34 | def add_noise(tensor: tf.Tensor): 35 | output = tensor + tf.cast(self._noise.sample(tensor.shape), 36 | dtype=tensor.dtype) 37 | output = tf.clip_by_value(output, -1.0, 1.0) 38 | return output 39 | 40 | return tree.map_structure(add_noise, inputs) 41 | -------------------------------------------------------------------------------- /acme/tf/networks/policy_value.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Policy-value network head for actor-critic algorithms.""" 16 | 17 | from typing import Tuple 18 | 19 | import sonnet as snt 20 | import tensorflow as tf 21 | 22 | 23 | class PolicyValueHead(snt.Module): 24 | """A network with two linear layers, for policy and value respectively.""" 25 | 26 | def __init__(self, num_actions: int): 27 | super().__init__(name='policy_value_network') 28 | self._policy_layer = snt.Linear(num_actions) 29 | self._value_layer = snt.Linear(1) 30 | 31 | def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: 32 | """Returns a (Logits, Value) tuple.""" 33 | logits = self._policy_layer(inputs) # [B, A] 34 | value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] 35 | 36 | return logits, value 37 | -------------------------------------------------------------------------------- /acme/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Acme utility functions.""" 16 | -------------------------------------------------------------------------------- /acme/utils/experiment_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility definitions for Acme experiments.""" 16 | 17 | from typing import Optional 18 | 19 | from acme.utils import loggers 20 | 21 | 22 | def make_experiment_logger(label: str, 23 | steps_key: Optional[str] = None, 24 | task_instance: int = 0) -> loggers.Logger: 25 | del task_instance 26 | if steps_key is None: 27 | steps_key = f'{label}_steps' 28 | return loggers.make_default_logger(label=label, steps_key=steps_key) 29 | 30 | 31 | def create_experiment_logger_factory() -> loggers.LoggerFactory: 32 | return make_experiment_logger 33 | -------------------------------------------------------------------------------- /acme/utils/frozen_learner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Frozen learner.""" 16 | 17 | from typing import Callable, List, Optional, Sequence 18 | 19 | import acme 20 | 21 | 22 | class FrozenLearner(acme.Learner): 23 | """Wraps a learner ignoring the step calls, i.e. freezing it.""" 24 | 25 | def __init__(self, 26 | learner: acme.Learner, 27 | step_fn: Optional[Callable[[], None]] = None): 28 | """Initializes the frozen learner. 29 | 30 | Args: 31 | learner: Learner to be wrapped. 32 | step_fn: Function to call instead of the step() method of the learner. 33 | This can be used, e.g. to drop samples from an iterator that would 34 | normally be consumed by the learner. 
35 | """ 36 | self._learner = learner 37 | self._step_fn = step_fn 38 | 39 | def step(self): 40 | """See base class.""" 41 | if self._step_fn: 42 | self._step_fn() 43 | 44 | def run(self, num_steps: Optional[int] = None): 45 | """See base class.""" 46 | self._learner.run(num_steps) 47 | 48 | def save(self): 49 | """See base class.""" 50 | return self._learner.save() 51 | 52 | def restore(self, state): 53 | """See base class.""" 54 | self._learner.restore(state) 55 | 56 | def get_variables(self, names: Sequence[str]) -> List[acme.types.NestedArray]: 57 | """See base class.""" 58 | return self._learner.get_variables(names) 59 | -------------------------------------------------------------------------------- /acme/utils/iterator_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Iterator utilities.""" 16 | import itertools 17 | import operator 18 | from typing import Any, Iterator, List, Sequence 19 | 20 | 21 | def unzip_iterators(zipped_iterators: Iterator[Sequence[Any]], 22 | num_sub_iterators: int) -> List[Iterator[Any]]: 23 | """Returns unzipped iterators. 24 | 25 | Note that simply returning: 26 | [(x[i] for x in iter_tuple[i]) for i in range(num_sub_iterators)] 27 | seems to cause all iterators to point to the final value of i, thus causing 28 | all sub_learners to consume data from this final iterator. 29 | 30 | Args: 31 | zipped_iterators: zipped iterators (e.g., from zip_iterators()). 32 | num_sub_iterators: the number of sub-iterators in the zipped iterator. 33 | """ 34 | iter_tuple = itertools.tee(zipped_iterators, num_sub_iterators) 35 | return [ 36 | map(operator.itemgetter(i), iter_tuple[i]) 37 | for i in range(num_sub_iterators) 38 | ] 39 | -------------------------------------------------------------------------------- /acme/utils/iterator_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for iterator_utils.""" 16 | 17 | from acme.utils import iterator_utils 18 | import numpy as np 19 | 20 | from absl.testing import absltest 21 | 22 | 23 | class IteratorUtilsTest(absltest.TestCase): 24 | 25 | def test_iterator_zipping(self): 26 | 27 | def get_iters(): 28 | x = iter(range(0, 10)) 29 | y = iter(range(20, 30)) 30 | return [x, y] 31 | 32 | zipped = zip(*get_iters()) 33 | unzipped = iterator_utils.unzip_iterators(zipped, num_sub_iterators=2) 34 | expected_x, expected_y = get_iters() 35 | np.testing.assert_equal(list(unzipped[0]), list(expected_x)) 36 | np.testing.assert_equal(list(unzipped[1]), list(expected_y)) 37 | 38 | 39 | if __name__ == '__main__': 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /acme/utils/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Acme loggers.""" 16 | 17 | from acme.utils.loggers.aggregators import Dispatcher 18 | from acme.utils.loggers.asynchronous import AsyncLogger 19 | from acme.utils.loggers.auto_close import AutoCloseLogger 20 | from acme.utils.loggers.base import Logger 21 | from acme.utils.loggers.base import LoggerFactory 22 | from acme.utils.loggers.base import LoggerLabel 23 | from acme.utils.loggers.base import LoggerStepsKey 24 | from acme.utils.loggers.base import LoggingData 25 | from acme.utils.loggers.base import NoOpLogger 26 | from acme.utils.loggers.base import TaskInstance 27 | from acme.utils.loggers.base import to_numpy 28 | from acme.utils.loggers.constant import ConstantLogger 29 | from acme.utils.loggers.csv import CSVLogger 30 | from acme.utils.loggers.dataframe import InMemoryLogger 31 | from acme.utils.loggers.filters import GatedFilter 32 | from acme.utils.loggers.filters import KeyFilter 33 | from acme.utils.loggers.filters import NoneFilter 34 | from acme.utils.loggers.filters import TimeFilter 35 | from acme.utils.loggers.flatten import FlattenDictLogger 36 | from acme.utils.loggers.default import make_default_logger # pylint: disable=g-bad-import-order 37 | from acme.utils.loggers.terminal import TerminalLogger 38 | from acme.utils.loggers.timestamp import TimestampLogger 39 | 40 | # Internal imports. 41 | -------------------------------------------------------------------------------- /acme/utils/loggers/aggregators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for aggregating to other loggers.""" 16 | 17 | from typing import Callable, Optional, Sequence 18 | from acme.utils.loggers import base 19 | 20 | 21 | class Dispatcher(base.Logger): 22 | """Writes data to multiple `Logger` objects.""" 23 | 24 | def __init__( 25 | self, 26 | to: Sequence[base.Logger], 27 | serialize_fn: Optional[Callable[[base.LoggingData], str]] = None, 28 | ): 29 | """Initialize `Dispatcher` connected to several `Logger` objects.""" 30 | self._to = to 31 | self._serialize_fn = serialize_fn 32 | 33 | def write(self, values: base.LoggingData): 34 | """Writes `values` to the underlying `Logger` objects.""" 35 | if self._serialize_fn: 36 | values = self._serialize_fn(values) 37 | for logger in self._to: 38 | logger.write(values) 39 | 40 | def close(self): 41 | for logger in self._to: 42 | logger.close() 43 | -------------------------------------------------------------------------------- /acme/utils/loggers/asynchronous.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Logger which makes another logger asynchronous.""" 16 | 17 | from typing import Any, Mapping 18 | 19 | from acme.utils import async_utils 20 | from acme.utils.loggers import base 21 | 22 | 23 | class AsyncLogger(base.Logger): 24 | """Logger which makes the logging to another logger asyncronous.""" 25 | 26 | def __init__(self, to: base.Logger): 27 | """Initializes the logger. 28 | 29 | Args: 30 | to: A `Logger` object to which the current object will forward its results 31 | when `write` is called. 32 | """ 33 | self._to = to 34 | self._async_worker = async_utils.AsyncExecutor(self._to.write, queue_size=5) 35 | 36 | def write(self, values: Mapping[str, Any]): 37 | self._async_worker.put(values) 38 | 39 | def close(self): 40 | """Closes the logger, closing is synchronous.""" 41 | self._async_worker.close() 42 | self._to.close() 43 | -------------------------------------------------------------------------------- /acme/utils/loggers/auto_close.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
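A sketch of the `AsyncLogger` above making writes asynchronous so logging does not block a training step; wrapping a `TerminalLogger` here is just an example choice.

```python
# Sketch: asynchronous logging around an example inner logger.
from acme.utils.loggers import asynchronous
from acme.utils.loggers import terminal

logger = asynchronous.AsyncLogger(to=terminal.TerminalLogger(label='learner'))
logger.write({'loss': 0.1})  # Returns quickly; the write happens in a worker.
logger.close()               # Synchronous: flushes the queue, closes the inner logger.
```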
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Logger which self closes on exit if not closed yet.""" 16 | 17 | import weakref 18 | 19 | from acme.utils.loggers import base 20 | 21 | 22 | class AutoCloseLogger(base.Logger): 23 | """Logger which auto closes itself on exit if not already closed.""" 24 | 25 | def __init__(self, logger: base.Logger): 26 | self._logger = logger 27 | # The finalizer "logger.close" is invoked in one of the following scenario: 28 | # 1) the current logger is GC 29 | # 2) from the python doc, when the program exits, each remaining live 30 | # finalizer is called. 31 | # Note that in the normal flow, where "close" is explicitly called, 32 | # the finalizer is marked as dead using the detach function so that 33 | # the underlying logger is not closed twice (once explicitly and once 34 | # implicitly when the object is GC or when the program exits). 35 | self._finalizer = weakref.finalize(self, logger.close) 36 | 37 | def write(self, values: base.LoggingData): 38 | if self._logger is None: 39 | raise ValueError('init not called') 40 | self._logger.write(values) 41 | 42 | def close(self): 43 | if self._finalizer.detach(): 44 | self._logger.close() 45 | self._logger = None 46 | -------------------------------------------------------------------------------- /acme/utils/loggers/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for acme.utils.loggers.base.""" 16 | 17 | from acme.utils.loggers import base 18 | import jax.numpy as jnp 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from absl.testing import absltest 23 | 24 | 25 | class BaseTest(absltest.TestCase): 26 | 27 | def test_tensor_serialisation(self): 28 | data = {'x': tf.zeros(shape=(32,))} 29 | output = base.to_numpy(data) 30 | expected = {'x': np.zeros(shape=(32,))} 31 | np.testing.assert_array_equal(output['x'], expected['x']) 32 | 33 | def test_device_array_serialisation(self): 34 | data = {'x': jnp.zeros(shape=(32,))} 35 | output = base.to_numpy(data) 36 | expected = {'x': np.zeros(shape=(32,))} 37 | np.testing.assert_array_equal(output['x'], expected['x']) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /acme/utils/loggers/constant.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Logger for values that remain constant.""" 16 | 17 | from acme.utils.loggers import base 18 | 19 | 20 | class ConstantLogger(base.Logger): 21 | """Logger for values that remain constant throughout the experiment. 22 | 23 | This logger is used to log additional values e.g. level_name or 24 | hyperparameters that do not change in an experiment. Having these values 25 | allows to group or facet plots when analysing data post-experiment. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | constant_data: base.LoggingData, 31 | to: base.Logger, 32 | ): 33 | """Initialise the extra info logger. 34 | 35 | Args: 36 | constant_data: Key-value pairs containing the constant info to be logged. 37 | to: The logger to add these extra info to. 38 | """ 39 | self._constant_data = constant_data 40 | self._to = to 41 | 42 | def write(self, data: base.LoggingData): 43 | self._to.write({**self._constant_data, **data}) 44 | 45 | def close(self): 46 | self._to.close() 47 | -------------------------------------------------------------------------------- /acme/utils/loggers/dataframe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Logger for writing to an in-memory list. 16 | 17 | This is convenient for e.g. interactive usage via Google Colab. 18 | 19 | For example, for usage with pandas: 20 | 21 | ```python 22 | from acme.utils import loggers 23 | import pandas as pd 24 | 25 | logger = InMemoryLogger() 26 | # ... 
27 | logger.write({'foo': 1.337, 'bar': 420}) 28 | 29 | results = pd.DataFrame(logger.data) 30 | ``` 31 | """ 32 | 33 | from typing import Sequence 34 | 35 | from acme.utils.loggers import base 36 | 37 | 38 | class InMemoryLogger(base.Logger): 39 | """A simple logger that keeps all data in memory.""" 40 | 41 | def __init__(self): 42 | self._data = [] 43 | 44 | def write(self, data: base.LoggingData): 45 | self._data.append(data) 46 | 47 | def close(self): 48 | pass 49 | 50 | @property 51 | def data(self) -> Sequence[base.LoggingData]: 52 | return self._data 53 | -------------------------------------------------------------------------------- /acme/utils/loggers/terminal_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for terminal logger.""" 16 | 17 | from acme.utils.loggers import terminal 18 | 19 | from absl.testing import absltest 20 | 21 | 22 | class LoggingTest(absltest.TestCase): 23 | 24 | def test_logging_output_format(self): 25 | inputs = { 26 | 'c': 'foo', 27 | 'a': 1337, 28 | 'b': 42.0001, 29 | } 30 | expected_outputs = 'A = 1337 | B = 42.000 | C = foo' 31 | test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) 32 | 33 | logger = terminal.TerminalLogger(print_fn=test_fn) 34 | logger.write(inputs) 35 | 36 | def test_label(self): 37 | inputs = {'foo': 'bar', 'baz': 123} 38 | expected_outputs = '[Test] Baz = 123 | Foo = bar' 39 | test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) 40 | 41 | logger = terminal.TerminalLogger(print_fn=test_fn, label='test') 42 | logger.write(inputs) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /acme/utils/loggers/timestamp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Timestamp logger.""" 16 | 17 | import time 18 | 19 | from acme.utils.loggers import base 20 | 21 | 22 | class TimestampLogger(base.Logger): 23 | """Logger which populates the timestamp key with the current timestamp.""" 24 | 25 | def __init__(self, logger: base.Logger, timestamp_key: str): 26 | self._logger = logger 27 | self._timestamp_key = timestamp_key 28 | 29 | def write(self, values: base.LoggingData): 30 | values = dict(values) 31 | values[self._timestamp_key] = time.time() 32 | self._logger.write(values) 33 | 34 | def close(self): 35 | self._logger.close() 36 | -------------------------------------------------------------------------------- /acme/utils/lp_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for acme launchpad utilities.""" 16 | 17 | from acme.utils import lp_utils 18 | 19 | from absl.testing import absltest 20 | 21 | 22 | class LpUtilsTest(absltest.TestCase): 23 | 24 | def test_partial_kwargs(self): 25 | 26 | def foo(a, b, c=2): 27 | return a, b, c 28 | 29 | def bar(a, b): 30 | return a, b 31 | 32 | # Override the default values. The last two should be no-ops. 33 | foo1 = lp_utils.partial_kwargs(foo, c=1) 34 | foo2 = lp_utils.partial_kwargs(foo) 35 | bar1 = lp_utils.partial_kwargs(bar) 36 | 37 | # Check that we raise errors on overriding kwargs with no default values 38 | with self.assertRaises(ValueError): 39 | lp_utils.partial_kwargs(foo, a=2) 40 | 41 | # CHeck the we raise if we try to override a kwarg that doesn't exist. 42 | with self.assertRaises(ValueError): 43 | lp_utils.partial_kwargs(foo, d=2) 44 | 45 | # Make sure we get back the correct values. 46 | self.assertEqual(foo1(1, 2), (1, 2, 1)) 47 | self.assertEqual(foo2(1, 2), (1, 2, 2)) 48 | self.assertEqual(bar1(1, 2), (1, 2)) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /acme/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """This module does nothing and exists solely for the sake of OS compatibility.""" 16 | 17 | from typing import Type, TypeVar 18 | 19 | T = TypeVar('T') 20 | 21 | 22 | def record_class_usage(cls: Type[T]) -> Type[T]: 23 | return cls 24 | -------------------------------------------------------------------------------- /acme/utils/observers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Acme observers.""" 16 | 17 | from acme.utils.observers.action_metrics import ContinuousActionObserver 18 | from acme.utils.observers.action_norm import ActionNormObserver 19 | from acme.utils.observers.base import EnvLoopObserver 20 | from acme.utils.observers.base import Number 21 | from acme.utils.observers.env_info import EnvInfoObserver 22 | from acme.utils.observers.measurement_metrics import MeasurementObserver 23 | -------------------------------------------------------------------------------- /acme/utils/observers/action_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """An observer that collects action norm stats. 
16 | """ 17 | from typing import Dict 18 | 19 | from acme.utils.observers import base 20 | import dm_env 21 | import numpy as np 22 | 23 | 24 | class ActionNormObserver(base.EnvLoopObserver): 25 | """An observer that collects action norm stats.""" 26 | 27 | def __init__(self): 28 | self._action_norms = None 29 | 30 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep 31 | ) -> None: 32 | """Observes the initial state.""" 33 | self._action_norms = [] 34 | 35 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 36 | action: np.ndarray) -> None: 37 | """Records one environment step.""" 38 | self._action_norms.append(np.linalg.norm(action)) 39 | 40 | def get_metrics(self) -> Dict[str, base.Number]: 41 | """Returns metrics collected for the current episode.""" 42 | return {'action_norm_avg': np.mean(self._action_norms), 43 | 'action_norm_min': np.min(self._action_norms), 44 | 'action_norm_max': np.max(self._action_norms)} 45 | -------------------------------------------------------------------------------- /acme/utils/observers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Metrics observers.""" 16 | 17 | import abc 18 | from typing import Dict, Union 19 | 20 | import dm_env 21 | import numpy as np 22 | 23 | 24 | Number = Union[int, float] 25 | 26 | 27 | class EnvLoopObserver(abc.ABC): 28 | """An interface for collecting metrics/counters in EnvironmentLoop.""" 29 | 30 | @abc.abstractmethod 31 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep 32 | ) -> None: 33 | """Observes the initial state.""" 34 | 35 | @abc.abstractmethod 36 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 37 | action: np.ndarray) -> None: 38 | """Records one environment step.""" 39 | 40 | @abc.abstractmethod 41 | def get_metrics(self) -> Dict[str, Number]: 42 | """Returns metrics collected for the current episode.""" 43 | -------------------------------------------------------------------------------- /acme/utils/observers/env_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """An observer that returns env's info. 
16 | """ 17 | from typing import Dict 18 | 19 | from acme.utils.observers import base 20 | import dm_env 21 | import numpy as np 22 | 23 | 24 | class EnvInfoObserver(base.EnvLoopObserver): 25 | """An observer that collects and accumulates scalars from env's info.""" 26 | 27 | def __init__(self): 28 | self._metrics = None 29 | 30 | def _accumulate_metrics(self, env: dm_env.Environment) -> None: 31 | if not hasattr(env, 'get_info'): 32 | return 33 | info = getattr(env, 'get_info')() 34 | if not info: 35 | return 36 | for k, v in info.items(): 37 | if np.isscalar(v): 38 | self._metrics[k] = self._metrics.get(k, 0) + v 39 | 40 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep 41 | ) -> None: 42 | """Observes the initial state.""" 43 | self._metrics = {} 44 | self._accumulate_metrics(env) 45 | 46 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 47 | action: np.ndarray) -> None: 48 | """Records one environment step.""" 49 | self._accumulate_metrics(env) 50 | 51 | def get_metrics(self) -> Dict[str, base.Number]: 52 | """Returns metrics collected for the current episode.""" 53 | return self._metrics 54 | -------------------------------------------------------------------------------- /acme/utils/paths_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for paths.""" 16 | 17 | from unittest import mock 18 | 19 | from acme.testing import test_utils 20 | import acme.utils.paths as paths # pylint: disable=consider-using-from-import 21 | 22 | from absl.testing import absltest 23 | 24 | 25 | class PathTest(test_utils.TestCase): 26 | 27 | def test_process_path(self): 28 | root_directory = self.get_tempdir() 29 | with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: 30 | mock_unique_id.return_value = ('test',) 31 | path = paths.process_path(root_directory, 'foo', 'bar') 32 | self.assertEqual(path, f'{root_directory}/test/foo/bar') 33 | 34 | def test_unique_id_with_flag(self): 35 | with mock.patch.object(paths, 'ACME_ID') as mock_acme_id: 36 | mock_acme_id.value = 'test_flag' 37 | self.assertEqual(paths.get_unique_id(), ('test_flag',)) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /acme/utils/signals.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper methods for handling signals.""" 16 | 17 | import contextlib 18 | import ctypes 19 | import threading 20 | from typing import Any, Callable, Optional 21 | 22 | import launchpad 23 | 24 | _Handler = Callable[[], Any] 25 | 26 | 27 | @contextlib.contextmanager 28 | def runtime_terminator(callback: Optional[_Handler] = None): 29 | """Runtime terminator used for stopping computation upon agent termination. 30 | 31 | Runtime terminator optionally executes a provided `callback` and then raises 32 | a `SystemExit` exception in the thread performing the computation. 33 | 34 | Args: 35 | callback: callback to execute before raising the exception. 36 | 37 | Yields: 38 | None. 39 | """ 40 | worker_id = threading.get_ident() 41 | def signal_handler(): 42 | if callback: 43 | callback() 44 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc( 45 | ctypes.c_long(worker_id), ctypes.py_object(SystemExit)) 46 | assert res < 2, 'Stopping worker failed' 47 | launchpad.register_stop_handler(signal_handler) 48 | yield 49 | launchpad.unregister_stop_handler(signal_handler) 50 | -------------------------------------------------------------------------------- /acme/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
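# A minimal usage sketch for the `runtime_terminator` helper in
# acme/utils/signals.py above. It only has an effect inside a Launchpad
# program, where the registered stop handler is invoked on termination;
# `learner` is a hypothetical object with a step() method, not an Acme API.
from acme.utils import signals


def run_until_terminated(learner):
  """Steps the learner until Launchpad asks this worker to stop."""
  with signals.runtime_terminator():
    # On termination, SystemExit is raised asynchronously in this thread,
    # breaking the loop and unwinding through the context manager.
    while True:
      learner.step()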
14 | 15 | """Common environment wrapper classes.""" 16 | 17 | from acme.wrappers.action_repeat import ActionRepeatWrapper 18 | from acme.wrappers.atari_wrapper import AtariWrapper 19 | from acme.wrappers.base import EnvironmentWrapper 20 | from acme.wrappers.base import wrap_all 21 | from acme.wrappers.canonical_spec import CanonicalSpecWrapper 22 | from acme.wrappers.concatenate_observations import ConcatObservationWrapper 23 | from acme.wrappers.delayed_reward import DelayedRewardWrapper 24 | from acme.wrappers.expand_scalar_observation_shapes import ExpandScalarObservationShapesWrapper 25 | from acme.wrappers.frame_stacking import FrameStacker 26 | from acme.wrappers.frame_stacking import FrameStackingWrapper 27 | from acme.wrappers.gym_wrapper import GymAtariAdapter 28 | from acme.wrappers.gym_wrapper import GymWrapper 29 | from acme.wrappers.noop_starts import NoopStartsWrapper 30 | from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper 31 | from acme.wrappers.single_precision import SinglePrecisionWrapper 32 | from acme.wrappers.step_limit import StepLimitWrapper 33 | from acme.wrappers.video import MujocoVideoWrapper 34 | from acme.wrappers.video import VideoWrapper 35 | 36 | try: 37 | # pylint: disable=g-import-not-at-top 38 | from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper 39 | except ImportError: 40 | pass 41 | -------------------------------------------------------------------------------- /acme/wrappers/action_repeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Wrapper that implements action repeats.""" 16 | 17 | from acme import types 18 | from acme.wrappers import base 19 | import dm_env 20 | 21 | 22 | class ActionRepeatWrapper(base.EnvironmentWrapper): 23 | """Action repeat wrapper.""" 24 | 25 | def __init__(self, environment: dm_env.Environment, num_repeats: int = 1): 26 | super().__init__(environment) 27 | self._num_repeats = num_repeats 28 | 29 | def step(self, action: types.NestedArray) -> dm_env.TimeStep: 30 | # Initialize accumulated reward and discount. 31 | reward = 0. 32 | discount = 1. 33 | 34 | # Step the environment by repeating action. 35 | for _ in range(self._num_repeats): 36 | timestep = self._environment.step(action) 37 | 38 | # Accumulate reward and discount. 39 | reward += timestep.reward * discount 40 | discount *= timestep.discount 41 | 42 | # Don't go over episode boundaries. 43 | if timestep.last(): 44 | break 45 | 46 | # Replace the final timestep's reward and discount. 47 | return timestep._replace(reward=reward, discount=discount) 48 | -------------------------------------------------------------------------------- /acme/wrappers/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
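# A minimal sketch composing the wrappers exported above: `wrap_all` applies
# each callable in order, so the environment below repeats every action four
# times and casts observations, rewards and discounts to single precision.
# The gym environment id is an arbitrary example and assumes `gym` is installed.
from acme import wrappers
import gym


def make_environment() -> wrappers.EnvironmentWrapper:
  environment = wrappers.GymWrapper(gym.make('MountainCarContinuous-v0'))
  return wrappers.wrap_all(environment, [
      lambda env: wrappers.ActionRepeatWrapper(env, num_repeats=4),
      wrappers.SinglePrecisionWrapper,
  ])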
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for base.""" 16 | 17 | import copy 18 | import pickle 19 | 20 | from acme.testing import fakes 21 | from acme.wrappers import base 22 | 23 | from absl.testing import absltest 24 | 25 | 26 | class BaseTest(absltest.TestCase): 27 | 28 | def test_pickle_unpickle(self): 29 | test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) 30 | 31 | test_env_pickled = pickle.dumps(test_env) 32 | test_env_restored = pickle.loads(test_env_pickled) 33 | self.assertEqual( 34 | test_env.observation_spec(), 35 | test_env_restored.observation_spec(), 36 | ) 37 | 38 | def test_deepcopy(self): 39 | test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) 40 | copied_env = copy.deepcopy(test_env) 41 | del copied_env 42 | 43 | if __name__ == '__main__': 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | div.version { 2 | color:#404040 !important; 3 | } 4 | 5 | .wy-side-nav-search { 6 | background-color: #fff; 7 | } 8 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sphinx configuration. 
16 | """ 17 | 18 | project = 'Acme' 19 | author = 'DeepMind Technologies Limited' 20 | copyright = '2018, DeepMind Technologies Limited' # pylint: disable=redefined-builtin 21 | version = '' 22 | release = '' 23 | master_doc = 'index' 24 | 25 | extensions = [ 26 | 'myst_parser' 27 | ] 28 | 29 | html_theme = 'sphinx_rtd_theme' 30 | html_logo = 'imgs/acme.png' 31 | html_theme_options = { 32 | 'logo_only': True, 33 | } 34 | html_css_files = [ 35 | 'custom.css', 36 | ] 37 | 38 | templates_path = [] 39 | html_static_path = ['_static'] 40 | exclude_patterns = ['_build', 'requirements.txt'] 41 | 42 | -------------------------------------------------------------------------------- /docs/imgs/acme-notext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/imgs/acme-notext.png -------------------------------------------------------------------------------- /docs/imgs/acme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/imgs/acme.png -------------------------------------------------------------------------------- /docs/imgs/configure-and-run-experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/imgs/configure-and-run-experiments.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Acme 2 | --------------- 3 | 4 | Acme is a library of reinforcement learning (RL) building blocks that strives to 5 | expose simple, efficient, and readable agents. These agents first and foremost 6 | serve as reference implementations while also providing strong baselines 7 | for algorithm performance. However, the baseline agents exposed by Acme should 8 | also provide enough flexibility and simplicity that they can be used as a 9 | starting block for novel research. Finally, the building blocks of Acme are 10 | designed in such a way that the agents can be written at multiple scales (e.g. 11 | single-stream vs. distributed agents). 12 | 13 | .. toctree:: 14 | :hidden: 15 | :titlesonly: 16 | 17 | self 18 | 19 | 20 | ..
toctree:: :caption: Getting started 21 | :titlesonly: 22 | 23 | user/overview 24 | user/agents 25 | user/components 26 | faq 27 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | myst-parser 2 | markdown-callouts 3 | 4 | -------------------------------------------------------------------------------- /docs/user/diagrams/actor_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/diagrams/actor_loop.png -------------------------------------------------------------------------------- /docs/user/diagrams/agent_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/diagrams/agent_loop.png -------------------------------------------------------------------------------- /docs/user/diagrams/batch_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/diagrams/batch_loop.png -------------------------------------------------------------------------------- /docs/user/diagrams/distributed_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/diagrams/distributed_loop.png -------------------------------------------------------------------------------- /docs/user/diagrams/environment_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/diagrams/environment_loop.png -------------------------------------------------------------------------------- /docs/user/logos/jax-small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/logos/jax-small.png -------------------------------------------------------------------------------- /docs/user/logos/tf-small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/acme/0be0d4f9e27f06de0eab482e3280be54f792ae15/docs/user/logos/tf-small.png -------------------------------------------------------------------------------- /examples/baselines/README.md: -------------------------------------------------------------------------------- 1 | # Acme baseline examples 2 | 3 | This directory contains launcher scripts for the baselines referenced in [the paper](https://arxiv.org/abs/2006.00979). 4 | These scripts reproduce the plots given in the paper. 
5 | 6 | ## How to run 7 | 8 | A command line for running the SAC baseline example in distributed mode on the Hopper environment, limited to 100k environment steps: 9 | ``` 10 | cd examples/baselines/rl_continuous 11 | python run_sac.py --run_distributed=True --env_name=gym:Hopper-v2 --num_steps=100_000 12 | ``` 13 | -------------------------------------------------------------------------------- /examples/baselines/imitation/README.md: -------------------------------------------------------------------------------- 1 | # imitation 2 | 3 | This folder contains run scripts for imitation learning algorithms. This setup 4 | consists of imitating expert traces provided in a dataset, with no 5 | reward function. 6 | -------------------------------------------------------------------------------- /examples/baselines/offline_rl/README.md: -------------------------------------------------------------------------------- 1 | # offline_rl 2 | 3 | Coming soon... 4 | 5 | This folder contains run scripts for offline RL algorithms. This setup 6 | consists of maximizing the performance of an algorithm on an environment 7 | without actually interacting with it. The only information used is a 8 | dataset of logged interactions that contains the reward information. 9 | -------------------------------------------------------------------------------- /examples/baselines/rl_continuous/README.md: -------------------------------------------------------------------------------- 1 | # Baselines for Continuous-action RL Algorithms 2 | 3 | This folder contains run scripts which reproduce published baselines for 4 | various RL algorithms focusing on continuous action spaces. 5 | 6 | For consistency, the architecture of any network used within these baselines 7 | (e.g. the policy or critic) should be comparable in size and complexity to 8 | an MLP([256] * 3). 9 | -------------------------------------------------------------------------------- /examples/baselines/rl_discrete/README.md: -------------------------------------------------------------------------------- 1 | # Baselines for Discrete Action RL Algorithms 2 | 3 | This folder contains run scripts which reproduce published baselines for 4 | various discrete RL algorithms. 5 | 6 | For consistency, the architecture of any network used within these baselines 7 | (e.g. the policy or critic) should be comparable in size and complexity. This 8 | is equivalent to a small ResNet; see `helpers.py` for more details. 9 | -------------------------------------------------------------------------------- /examples/baselines/rlfd/README.md: -------------------------------------------------------------------------------- 1 | # rlfd 2 | 3 | Coming soon... 4 | 5 | This folder contains run scripts for RL from demonstrations algorithms. This 6 | setup consists of maximizing the performance of an algorithm in an environment, 7 | making the most of both environment interactions and demonstrations provided as 8 | an external dataset that contains the reward information. 9 | --------------------------------------------------------------------------------
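To make the `MLP([256] * 3)` size convention from `rl_continuous/README.md` above concrete, the sketch below builds a torso of that size in Haiku. The function name, the choice of Haiku and the Hopper observation size are illustrative assumptions, not the exact `helpers.py` implementation used by the run scripts.

```
import haiku as hk
import jax
import jax.numpy as jnp


def mlp_torso(observations: jnp.ndarray) -> jnp.ndarray:
  """A [256] * 3 MLP torso, matching the size convention referenced above."""
  return hk.nets.MLP([256, 256, 256], activate_final=True)(observations)


# Initialise and apply the torso on a dummy observation batch.
network = hk.without_apply_rng(hk.transform(mlp_torso))
dummy_observations = jnp.zeros((1, 11))  # Hopper-v2 observations are 11-dimensional.
params = network.init(jax.random.PRNGKey(0), dummy_observations)
features = network.apply(params, dummy_observations)  # features.shape == (1, 256)
```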