├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── artifacts └── montezuma_base │ ├── 20240425-115850 │ └── logs │ │ └── evaluator │ │ └── logs.csv │ ├── 20240425-115853 │ └── logs │ │ └── learner │ │ └── logs.csv │ ├── 20240506-133311 │ └── logs │ │ └── evaluator │ │ └── logs.csv │ ├── 20240506-133314 │ └── logs │ │ └── learner │ │ └── logs.csv │ ├── 20240506-135934 │ └── logs │ │ └── evaluator │ │ └── logs.csv │ ├── 20240506-135937 │ └── logs │ │ └── learner │ │ └── logs.csv │ └── config.txt ├── compose.yaml ├── docker-configurations ├── local-docker-images.md ├── python3.10 │ └── Dockerfile └── python3.7 │ ├── Dockerfile │ ├── compose.yaml │ ├── post-install.sh │ └── requirements.txt ├── docs ├── DRLearner_notes.md ├── atari_pong.md ├── aws-setup.md ├── debug_and_monitor.md ├── docker.md ├── img │ ├── lunar_lander.png │ ├── notebook-instance.png │ ├── tensorboard.png │ └── wandb.png ├── unity.md └── vertexai.md ├── drlearner ├── __init__.py ├── configs │ ├── config_atari.py │ ├── config_discomaze.py │ ├── config_lunar_lander.py │ └── resources │ │ ├── __init__.py │ │ ├── atari.py │ │ ├── local_resources.py │ │ └── toy_env.py ├── core │ ├── __init__.py │ ├── distributed_layout.py │ ├── environment_loop.py │ ├── local_layout.py │ ├── loggers │ │ ├── __init__.py │ │ └── image.py │ └── observers │ │ ├── __init__.py │ │ ├── action_dist.py │ │ ├── actions.py │ │ ├── discomaze_unique_states.py │ │ ├── distillation_coef.py │ │ ├── intrinsic_reward.py │ │ ├── lazy_dict.py │ │ ├── meta_controller.py │ │ └── video.py ├── drlearner │ ├── __init__.py │ ├── actor.py │ ├── actor_core.py │ ├── agent.py │ ├── builder.py │ ├── config.py │ ├── distributed_agent.py │ ├── drlearner_types.py │ ├── learning.py │ ├── lifelong_curiosity.py │ ├── networks │ │ ├── __init__.py │ │ ├── distillation_network.py │ │ ├── embedding_network.py │ │ ├── networks.py │ │ ├── networks_zoo │ │ │ ├── __init__.py │ │ │ ├── atari.py │ │ │ ├── discomaze.py │ │ │ └── lunar_lander.py │ │ ├── policy_networks.py │ │ ├── uvfa_network.py │ │ └── uvfa_torso.py │ └── utils.py ├── environments │ ├── __init__.py │ ├── atari.py │ ├── disco_maze.py │ └── lunar_lander.py └── utils │ ├── __init__.py │ ├── stats.py │ └── utils.py ├── examples ├── distrun_atari.py ├── distrun_discomaze.py ├── distrun_lunar_lander.py ├── play_atari.py ├── run_atari.py ├── run_discomaze.py └── run_lunar_lander.py ├── external ├── vertex.py └── xm_docker.py ├── my_process_entry.py ├── requirements.txt └── scripts └── update_tb.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | **/__pycache__/ 3 | 4 | venv/ 5 | 6 | checkpoints/ 7 | scratch/ 8 | experiments/ 9 | 10 | *.json 11 | roms/ 12 | wandb/ 13 | .env 14 | .ipynb_checkpoints -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | ## Basic dependencies. 3 | ADD . /app 4 | 5 | WORKDIR /app 6 | 7 | ### Installing dependencies. 
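# These apt packages back the rest of the image: xvfb provides a virtual display for headless rendering,
# ffmpeg handles video encoding for the observers, swig/cmake/build-essential are used when building the
# native gym extensions from requirements.txt, and unar unpacks the Atari ROM archive fetched further below.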
8 | RUN apt-get update \ 9 | && apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | curl \ 12 | wget \ 13 | xvfb \ 14 | ffmpeg \ 15 | xorg-dev \ 16 | libsdl2-dev \ 17 | swig \ 18 | cmake \ 19 | unar \ 20 | libpython3.10 \ 21 | tmux 22 | 23 | # Conda environment 24 | 25 | ENV CONDA_DIR /opt/conda 26 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 27 | /bin/bash ~/miniconda.sh -b -p /opt/conda 28 | 29 | # # Put conda in path so we can use conda activate 30 | ENV PATH=$CONDA_DIR/bin:$PATH 31 | 32 | # SHELL ["/bin/bash", "-c"] 33 | RUN conda create --name drlearner python=3.10 -y 34 | RUN python --version 35 | RUN echo "source activate drlearner" > ~/.bashrc 36 | ENV PATH /opt/conda/envs/drlearner/bin:$PATH 37 | 38 | # RUN conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib 39 | # RUN conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd) 40 | 41 | # Install dependencies (some of them are old + maybe there is need to check support of cuda) 42 | RUN python3.10 -m pip install --upgrade pip 43 | RUN python3.10 -m pip install --no-cache-dir -r requirements.txt 44 | RUN python3.10 -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 45 | RUN conda install conda-forge::ffmpeg 46 | 47 | # RUN pip install git+https://github.com/google-deepmind/acme.git@4c6351ef8ff3f4045a9a24bee6a994667d89c69c 48 | 49 | 50 | # RUN conda install -c conda-forge cudatoolkit=11.2.2 cudnn=8.1.0 51 | 52 | # Get binaries for Atari games 53 | RUN wget http://www.atarimania.com/roms/Roms.rar 54 | RUN unar Roms.rar 55 | RUN mv Roms roms 56 | RUN ale-import-roms roms/ 57 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib 58 | ENV PYTHONPATH=$PYTHONPATH:$(pwd) 59 | ENV XLA_PYTHON_CLIENT_PREALLOCATE='0' 60 | 61 | # ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib 62 | RUN chmod -R 777 ./ 63 | 64 | 65 | # CMD ["python3" ,"examples/run_lunar_lander.py"] 66 | CMD ["/bin/bash"] 67 | 68 | # CMD ["/bin/bash python3", "examples/run_atari.py --level PongNoFrameskip-v4 --num_episodes 1000 --exp_path experiments/test_pong/"] 69 | 70 | # sudo docker build -t rdlearner:latest . -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DRLearner 3 | Open-source Deep Reinforcement Learning (DRL) library, based on Agent 57 (Badia et al., 2020). 4 | We recommend reading this documentation [page](docs/DRLearner_notes.md) to get the essence of DRLearner. 5 | 6 | # Table of contents 7 | - [DRLearner](#drlearner) 8 | - [Table of contents](#table-of-contents) 9 | - [System Requirements](#system-requirements) 10 | - [Installation](#installation) 11 | - [Running DRLearner Agent](#running-drlearner-agent) 12 | - [Documentation](#documentation) 13 | - [Ongoing Support](#ongoing-support) 14 | 15 | 16 | ## System Requirements 17 | 18 | The hardware and cloud infrastructure used for DRLearner testing are listed below. For more information on specific configurations for running experiments, see GCP Hardware Specs and Running Experiments at the bottom of this document. 19 | 20 | | Google Cloud Configuration | Local Configuration | 21 | | --- | --- | 22 | | (GCP) | (Local) | 23 | | Tested on Ubuntu 20.04 with Python 3.7 | Tested on Ubuntu 22.04 with Python 3.10 | 24 | | Hardware: NVIDIA Tesla GPU, 500 GB drive | Hardware: 8-core i7 | 25 | 26 | Depending on the exact OS and hardware, additional packages such as git, Python 3.7, Anaconda/Miniconda or gcc may need to be installed. 27 | 28 | ## Installation 29 | 30 | We recommend the [Docker-based](docs/docker.md) installation; to install from scratch, follow the instructions below. 31 | 32 | 33 | Clone the repo: 34 | ``` 35 | git clone https://github.com/PatternsandPredictions/DRLearner_beta.git 36 | cd DRLearner_beta/ 37 | ``` 38 | 39 | Install xvfb for a virtual display: 40 | ``` 41 | sudo apt-get update 42 | sudo apt-get install xvfb 43 | ``` 44 | 45 | ### Creating the environment 46 | 47 | #### Conda 48 | 49 | Restarting the environment after creating and activating it is recommended, to make sure that the environment variables get updated.
50 | ``` 51 | sudo apt-get update 52 | sudo apt-get install libpython3.10 ffmpeg swig 53 | conda create --name drlearner python=3.10 54 | conda activate drlearner 55 | 56 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib 57 | export PYTHONPATH=$PYTHONPATH:$(pwd) 58 | conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib 59 | conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd) 60 | ``` 61 | 62 | Install packages: 63 | ``` 64 | pip install --no-cache-dir -r requirements.txt 65 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 66 | ``` 67 | 68 | #### Venv 69 | ``` 70 | sudo apt-get update 71 | sudo apt-get install libpython3.10 swig ffmpeg -y 72 | python3.10 -m venv venv 73 | source venv/bin/activate 74 | 75 | export PYTHONPATH=$PYTHONPATH:$(pwd) 76 | ``` 77 | 78 | Install packages: 79 | ``` 80 | pip install --no-cache-dir -r requirements.txt 81 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 82 | ``` 83 | 84 | ### Binary files for Atari games 85 | ``` 86 | sudo apt-get install unrar 87 | wget http://www.atarimania.com/roms/Roms.rar 88 | unrar e Roms.rar roms/ 89 | ale-import-roms roms/ 90 | 91 | ``` 92 | 93 | ## Running DRLearner Agent 94 | 95 | DRLearner comes with the following available environments: 96 | - Lunar Lander: 97 | - [Config](drlearner/configs/config_lunar_lander.py) 98 | - [Synchronous Agent](examples/run_lunar_lander.py) 99 | - [Asynchronous Agent](examples/distrun_lunar_lander.py) 100 | - Atari: 101 | - [Config](drlearner/configs/config_atari.py) 102 | - [Synchronous Agent](examples/run_atari.py) 103 | - [Asynchronous Agent](examples/distrun_atari.py) 104 | - [Example](docs/atari_pong.md) 105 | - Disco Maze: 106 | - [Config](drlearner/configs/config_discomaze.py) 107 | - [Synchronous Agent](examples/run_discomaze.py) 108 | - [Asynchronous Agent](examples/distrun_discomaze.py) 109 | 110 | ### Lunar Lander example 111 | 112 | #### Training 113 | ``` 114 | python ./examples/run_lunar_lander.py --num_episodes 1000 --exp_path experiments/test_pong/ --exp_name my_first_experiment 115 | ``` 116 | Terminal output like the following means that training has been launched successfully: 117 | 118 | `[Enviroment] Mean Distillation Alpha = 1.000 | Action Mean Time = 0.027 | Env Step Mean Time = 0.000 | Episode Length = 63 | Episode Return = -453.10748291015625 | Episodes = 1 | Intrinsic Rewards Mean = 2.422 | Intrinsic Rewards Sum = 155.000 | Observe Mean Time = 0.014 | Steps = 63 | Steps Per Second = 15.544 119 | [Actor] Idm Accuracy = 0.12812499701976776 | Idm Loss = 1.4282478094100952 | Rnd Loss = 0.07360860705375671 | Extrinsic Uvfa Loss = 36.87723159790039 | Intrinsic Uvfa Loss = 19.602252960205078 | Steps = 1 | Time Elapsed = 65.282 120 | ` 121 | 122 | To specify the directory where results and checkpoints are saved, use the exp_path flag. If a model already exists in exp_path, it will be loaded and training will resume. 123 | To name the experiment in W&B, specify the exp_name flag. 124 | 125 | #### Observing Lunar Lander in action 126 | To visualize any environment, pass an instance of StorageVideoObserver to the environment loop. The observer takes an instance of DRLearnerConfig.
In the config you can define video logging options such as the video log period and the output directory. 127 | 128 | ``` 129 | observers = [IntrinsicRewardObserver(), DistillationCoefObserver(), StorageVideoObserver(config)] 130 | loop = EnvironmentLoop(env, agent, logger=logger_env, observers=observers) 131 | loop.run(FLAGS.num_episodes) 132 | ``` 133 | ![Alt text](docs/img/lunar_lander.png) 134 | 135 | 136 | #### Training with checkpoints (Montezuma) 137 | 138 | The model will pick up from where it stopped in the previous training run. Montezuma's Revenge is the most difficult game, so make sure you have enough computational power. The total number of actors is `num_actors_per_mixture` * `num_mixtures`. If you try to run too many actors, your setup might break; with a 16-core CPU we advise around 12 actors in total. 139 | 140 | ``` 141 | python ./examples/distrun_atari.py --exp_path artifacts/montezuma_base --exp_name montezuma_training 142 | ``` 143 | 144 | More examples of synchronous and distributed agent training within the environments can be found in `examples/`. 145 | 146 | ## Documentation 147 | - [Debugging and monitoring](docs/debug_and_monitor.md) 148 | - [Docker installation](docs/docker.md) 149 | - [Apptainer on Unity cluster](docs/unity.md) 150 | - [Running on Vertex AI](docs/vertexai.md) 151 | - [Running on AWS](docs/aws-setup.md) 152 | 153 | ## Ongoing Support 154 | 155 | Join the [DRLearner Developers List](https://groups.google.com/g/drlearner?pli=10). 156 | 157 | -------------------------------------------------------------------------------- /artifacts/montezuma_base/20240506-133311/logs/evaluator/logs.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/artifacts/montezuma_base/20240506-133311/logs/evaluator/logs.csv -------------------------------------------------------------------------------- /artifacts/montezuma_base/20240506-133314/logs/learner/logs.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/artifacts/montezuma_base/20240506-133314/logs/learner/logs.csv -------------------------------------------------------------------------------- /artifacts/montezuma_base/config.txt: -------------------------------------------------------------------------------- 1 | DRLearnerConfig(gamma_min=0.99, gamma_max=0.997, num_mixtures=32, target_update_period=400, evaluation_epsilon=0.01, epsilon=0.01, actor_epsilon=0.4, target_epsilon=0.01, variable_update_period=800, retrace_lambda=0.95, burn_in_length=0, trace_length=80, sequence_period=40, num_sgd_steps_per_step=1, uvfa_learning_rate=0.0001, idm_learning_rate=0.0005, distillation_learning_rate=0.0005, idm_weight_decay=1e-05, distillation_weight_decay=1e-05, idm_clip_steps=5, distillation_clip_steps=5, clip_rewards=True, max_absolute_reward=1.0, tx_pair=TxPair(apply=, apply_inv=), distillation_moving_average_coef=0.001, beta_min=0.0, beta_max=0.3, observation_embed_dim=32, episodic_memory_num_neighbors=10, episodic_memory_max_size=1500, episodic_memory_max_similarity=8.0, episodic_memory_cluster_distance=0.008, episodic_memory_pseudo_counts=0.001, episodic_memory_epsilon=0.0001, distillation_embed_dim=128, max_lifelong_modulation=5.0, samples_per_insert_tolerance_rate=0.5, samples_per_insert=2.0, min_replay_size=6250, max_replay_size=100000, batch_size=64, prefetch_size=1, num_parallel_calls=16,
replay_table_name='priority_table', importance_sampling_exponent=0.6, priority_exponent=0.9, max_priority_weight=0.9, window=160, actor_window=160, evaluation_window=3600, n_arms=32, mc_epsilon=0.5, actor_mc_epsilon=0.3, evaluation_mc_epsilon=0.01, mc_beta=1.0, env_library='gym', video_log_period=50, actions_log_period=1, logs_dir='experiments/ll/', num_episodes=50) -------------------------------------------------------------------------------- /compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | drlearner: 3 | build: . 4 | volumes: 5 | - .:/app 6 | stdin_open: true # docker run -i 7 | tty: true # docker run -t 8 | env_file: 9 | - .env 10 | -------------------------------------------------------------------------------- /docker-configurations/local-docker-images.md: -------------------------------------------------------------------------------- 1 | # Docker image with tested GPU support. 2 | 3 | Present docker image contains dockerfile, updated docker-compose and post-installation script for docker-based setup. 4 | 5 | ## Prerequisites 6 | *****This setup was tested on Linix host machine only.** - most of the cloud setups will have variations of linux host. 7 | 8 | In order to operate successfully with docker image, user should have local NVIDIA drivers installed as well as NVIDIA CUDA toolkit. 9 | Both installations from ubuntu repo and according to NVIDIA instructions. 10 | 11 | ### Useful references: 12 | 1. [Download Nvidia drivers](https://www.nvidia.com/download/index.aspx) 13 | 2. [Installation guide for CUDA on linux host](https://www.cherryservers.com/blog/install-cuda-ubuntu) 14 | 3. [Official CUDA toolkit installation guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) - tested both installations on Ubuntu, AWS Linux AMI and installation with conda. 15 | 16 | ### Usage 17 | 1. Once Nvidia drivers, CUDA toolkit and nvcc are installed, installation can be validatied running the next commands. 18 | ```nvcc --version``` 19 | ```nvidia-smi``` 20 | 21 | 2. Copy Dockerfile, compose.yaml and post-install from current directory to root project directory. 22 | 3. Run ```docker-compose build --no-cache && docker-compose up -d``` 23 | 4. Once system is fully built, enter the container, and check state of installed nvidia libs, using commands above. 24 | 5. Change permissions to post-install.sh files to be executable if needed. 25 | 6. Inside of container, run `./post-install.sh` to install the post-installation requirements. 26 | 7. Run any of examples. 27 | 28 | -------------------------------------------------------------------------------- /docker-configurations/python3.10/Dockerfile: -------------------------------------------------------------------------------- 1 | # Start from a CUDA development image 2 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 3 | 4 | # Make non-interactive environment. 5 | ENV DEBIAN_FRONTEND noninteractive 6 | 7 | ## Installing dependencies. 8 | RUN apt-get update -y \ 9 | && apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | python3.10 \ 12 | python3.10-dev \ 13 | python3-pip \ 14 | curl \ 15 | wget \ 16 | xvfb \ 17 | ffmpeg \ 18 | xorg-dev \ 19 | libsdl2-dev \ 20 | swig \ 21 | cmake \ 22 | git \ 23 | unar \ 24 | libpython3.10 \ 25 | zlib1g-dev \ 26 | tmux \ 27 | && rm -rf /var/lib/apt/lists/* 28 | 29 | ## Workdir 30 | ADD . /app 31 | WORKDIR /app 32 | # Library paths. 
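# Note: Dockerfile ENV performs variable expansion ($LD_LIBRARY_PATH, $PYTHONPATH) but not shell
# command substitution, so the "$(pwd)" fragments below are stored literally rather than expanded
# to the /app working directory set above.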
33 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib/:$(pwd) 34 | ENV PYTHONPATH=$PYTHONPATH:$(pwd) 35 | 36 | # Update pip to the latest version & install packages. 37 | RUN python3.10 -m pip install --upgrade pip 38 | RUN python3.10 -m pip install jax==0.4.3 39 | RUN python3.10 -m pip install jaxlib==0.4.3+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 40 | RUN python3.10 -m pip install --no-cache-dir -r requirements.txt 41 | RUN python3.10 -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 42 | 43 | # Atari games. 44 | RUN wget http://www.atarimania.com/roms/Roms.rar 45 | RUN unar Roms.rar 46 | RUN mv Roms roms 47 | RUN ale-import-roms roms/ 48 | 49 | RUN chmod +x ./ 50 | 51 | CMD ["/bin/bash"] 52 | -------------------------------------------------------------------------------- /docker-configurations/python3.7/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.1.1-base-ubuntu20.04 2 | # FROM python:3.7 3 | ## Basic dependencies. 4 | ADD . /app 5 | 6 | WORKDIR /app 7 | 8 | RUN ln -snf /usr/share/zoneinfo/$CONTAINER_TIMEZONE /etc/localtime && echo $CONTAINER_TIMEZONE > /etc/timezone 9 | 10 | ### Installing dependencies. 11 | RUN \ 12 | --mount=type=cache,target=/var/cache/apt \ 13 | apt-get update \ 14 | && apt-get install -y --no-install-recommends \ 15 | build-essential \ 16 | curl \ 17 | wget \ 18 | xvfb \ 19 | ffmpeg \ 20 | xorg-dev \ 21 | libsdl2-dev \ 22 | swig \ 23 | cmake \ 24 | git \ 25 | unar \ 26 | libpython3.7 27 | 28 | # Conda environment 29 | ENV CONDA_DIR /opt/conda 30 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 31 | /bin/bash ~/miniconda.sh -b -p /opt/conda 32 | 33 | # # Put conda in path so we can use conda activate 34 | ENV PATH=$CONDA_DIR/bin:$PATH 35 | 36 | # SHELL ["/bin/bash", "-c"] 37 | RUN conda create --name drlearner python=3.7 -y 38 | RUN python --version 39 | RUN echo "source activate drlearner" > ~/.bashrc 40 | ENV PATH /opt/conda/envs/env/bin:$PATH 41 | 42 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib 43 | 44 | ENV PYTHONPATH=$PYTHONPATH:$(pwd) 45 | RUN conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib 46 | RUN conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd) 47 | RUN conda install nvidia/label/cuda-11.3.1::cuda-nvcc -y 48 | RUN conda install -c conda-forge cudatoolkit=11.3.1 cudnn=8.2 -y 49 | RUN conda install -c anaconda git 50 | 51 | RUN chmod +x ./ 52 | 53 | CMD ["/bin/bash"] 54 | 55 | 56 | -------------------------------------------------------------------------------- /docker-configurations/python3.7/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | drlearner: 3 | build: . 
4 | volumes: 5 | - .:/app 6 | stdin_open: true # docker run -i 7 | tty: true # docker run -t 8 | deploy: 9 | resources: 10 | reservations: 11 | devices: 12 | - driver: nvidia 13 | count: 1 14 | capabilities: [gpu] 15 | -------------------------------------------------------------------------------- /docker-configurations/python3.7/post-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "This is post-installation script" 4 | 5 | pip install pip==21.3 6 | pip install jax==0.3.7 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 7 | pip install jaxlib==0.3.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 8 | 9 | pip install setuptools==65.5.0 10 | pip install wheel==0.38.0 11 | pip install git+https://github.com/horus95/lazydict.git 12 | pip install --no-cache-dir -r requirements.txt 13 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 14 | 15 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib 16 | echo $LD_LIBRARY_PATH 17 | 18 | wget http://www.atarimania.com/roms/Roms.rar 19 | unar Roms.rar 20 | mv Roms roms 21 | ale-import-roms roms/ 22 | -------------------------------------------------------------------------------- /docker-configurations/python3.7/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | ale_py==0.7.5 3 | numpy==1.21.5 4 | cloudpickle==2.0.0 5 | six==1.16.0 6 | dm-acme==0.4.0 7 | libpython==0.2 8 | dm-acme[tf] 9 | chex==0.1.3 10 | Cython==0.29.28 11 | flax==0.4.1 12 | optax==0.1.2 13 | rlax==0.1.2 14 | pyglet==1.5.24 15 | #jax==0.3.15 # 0.4.10 # 0.4.13 # 0.3.7 16 | #jaxlib==0.3.15 # 0.4.10 # 0.4.13 # 0.3.7 pip install --upgrade jaxlib==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.html 17 | dm-haiku==0.0.5 18 | dm-acme[reverb] 19 | #gym[accept-rom-license, atari, Box2D] # ==0.21.0 20 | xmanager==0.1.5 21 | pyvirtualdisplay==3.0 22 | #lazydict==1.0.0b2 # pip install git+https://github.com/horus95/lazydict 23 | sk-video==1.1.10 24 | ffmpeg-python==0.2.0 25 | wandb==0.16.2 26 | ##pip install --upgrade ml_dtypes==0.2.0 27 | -------------------------------------------------------------------------------- /docs/atari_pong.md: -------------------------------------------------------------------------------- 1 | # Playing Pong on Atari: Your best Pong score ever 2 | 3 | ``` 4 | python ./examples/run_atari.py --level PongNoFrameskip-v4 --num_episodes 1000 --exp_path experiments/test_pong/ --exp_name test_pong 5 | ``` 6 | Correct terminal output like this means that the training has been launched successfully: 7 | 8 | `[Learner] Action Mean Time = 0.015 | Env Step Mean Time = 0.005 | Episode Length = 825 | Episode Return = -21.0 | Episodes = 1 | Observe Mean Time = 0.016 | Steps = 825 | Steps Per Second = 24.269` 9 | 10 | Training the model may take up to several hours to run, depending on configuration. 11 | -------------------------------------------------------------------------------- /docs/aws-setup.md: -------------------------------------------------------------------------------- 1 | # AWS installation 2 | 3 | Current reposiory contains pre-built docker images to run in any cloud / on-premise platform. 4 | This is the recommended way to run in in destineed containers, as they are compatible and tested in GPU and CPU setups, 5 | and they are a basis for containerized distributed scheme. 
6 | 7 | In the given file you will find installation instructions to run in [Amazon SageMaker](https://aws.amazon.com/sagemaker/), but they are applicable to according EC2 instances. 8 | 9 | ## Pre-requisites 10 | 11 | 1. Familiriality with AWS cloud is assumed. 12 | 2. Root or IAM account is configured. 13 | 3. *****Disclaimer: AWS is a paid service, and any computations imply costs.** 14 | 4. Navigate to the [console](https://us-east-1.console.aws.amazon.com/console/home?region=us-east-1#) for your selected region. 15 | 5. Create or run your [SageMaker](https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/notebook-instances) instance 16 | 6. Open jupyter lab 17 | 7. Upload your files. 18 | 8. Click Terminal among available options. Validate `nvidia-smi` to make sure that drivers are successfully installed. 19 | 9. Run docker compose build && docker compose up -d according to instructions. 20 | 21 | 22 | Alternatively one may try to setup the appropriate image to EC2 together with drivers, and install application as per guide. 23 | 24 | # CUDA ON EC2 FROM SCRATCH 25 | This instruction helps to set up Pytorch with CUDA on an EC2 instance with plain, Ubuntu AMI. 26 | 27 | ## Pre-installation actions 28 | 1) Verify the instance has the CUDA-capable GPU 29 | ``` 30 | lspci | grep -i nvidia 31 | ``` 32 | 33 | 2) Install kernel headers and development packages 34 | ``` 35 | sudo apt-get install linux-headers-$(uname -r) 36 | ``` 37 | 38 | ## NVIDIA drivers installation 39 | 1) Download a CUDA keyring for your distribution $distro and architecture $arch 40 | ``` 41 | wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.1-1_all.deb 42 | ``` 43 | i.e. for Ubuntu 22.04 with x86_64 the command would look as follows: 44 | ``` 45 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb 46 | ``` 47 | 2) Add the downloaded keyring package 48 | ``` 49 | sudo dpkg -i cuda-keyring_1.1-1_all.deb 50 | ``` 51 | 3) Update the APT repository cache 52 | ``` 53 | sudo apt-get update 54 | ``` 55 | 4) Install the drivers 56 | ``` 57 | sudo apt-get -y install cuda-drivers 58 | ``` 59 | 5) Reboot the instance 60 | ``` 61 | sudo reboot 62 | ``` 63 | 6) Verify the installation 64 | ``` 65 | nvidia-smi 66 | ``` 67 | It is important to keep in mind CUDA Version is displayed in the upper-right corner, as PyTorch needs to be compatible with it. 68 | 69 | **NOTE:** At this stage NVIDIA recommends following [Post-installation actions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#environment-setup). I didn't and it worked but some unexpected errors might occur. 
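Since DRLearner itself runs on JAX, it can also be useful at this point to confirm that the JAX stack sees the GPU. A minimal check, assuming `jax` and a CUDA-enabled `jaxlib` are already installed in the active environment:

```python
# Minimal GPU-visibility check for the JAX stack used by DRLearner.
# Assumes jax and a CUDA build of jaxlib are installed in this environment.
import jax

print(jax.default_backend())  # expected to print "gpu" when the driver and CUDA libraries are found
print(jax.devices())          # should list at least one CUDA device on a GPU instance
```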
70 | ## PyTorch installation 71 | 72 | ### Install package manager 73 | I used conda, but pip+venv *should* also work. 74 | 1) Install conda 75 | ``` 76 | mkdir -p ~/miniconda3 77 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 78 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 79 | rm -rf ~/miniconda3/miniconda.sh 80 | ``` 81 | 2) Initialize conda 82 | ``` 83 | ~/miniconda3/bin/conda init bash 84 | ``` 85 | 3) Reload bash 86 | ``` 87 | source ~/.bashrc 88 | ``` 89 | 4) Create a new conda environment 90 | ``` 91 | conda create -n env 92 | ``` 93 | 5) Activate the newly created environment 94 | ``` 95 | conda activate env 96 | ``` 97 | -------------------------------------------------------------------------------- /docs/debug_and_monitor.md: -------------------------------------------------------------------------------- 1 | # Debugging and monitoring 2 | Our framework supports several loggers to help monitor and debug trained agents. All checkpoints will be saved in the workdir directory passed to the agent. 3 | Currently we log 18 different parameters: 7 for the actor and 11 for the environment, but these can easily be extended by the user. 4 | The actor's parameters include, for example, the IDM accuracy and loss, the RND loss, and the extrinsic and intrinsic UVFA losses (see the example training output in the README). 5 | 6 | ## Terminal logger 7 | The terminal logger logs progress to standard output. To use it, pass an instance of TerminalLogger to the agent and the environment. Logs aren't saved anywhere. 8 | 9 | ## Tensorboard logger 10 | This framework supports the standard TensorBoard logger. To use it, pass an instance of TFSummaryLogger to the agent and the environment. Logs will be saved in the workdir directory passed to the TFSummaryLogger logger. To visualize the logs, run the code snippet below. 11 | ``` 12 | tensorboard --logdir <logdir> 13 | 14 | ``` 15 | ![Alt text](img/tensorboard.png) 16 | 17 | ## CSV Logger 18 | This is a standard CSV logger. To use it, pass an instance of CSVLogger (or CloudCSVLogger when running on Vertex AI) to the agent and the environment. Logs will be saved in the workdir directory passed to the logger. 19 | 20 | ## Weights and biases 21 | Our framework supports the [W&B](https://wandb.ai/) logger. It is installed along with the Python requirements.txt. To use it, use the code snippet below. 22 | ``` 23 | from drlearner.utils.utils import make_wandb_logger 24 | ``` 25 | or directly 26 | ``` 27 | from drlearner.utils.utils import WandbLogger 28 | ``` 29 | Set the WANDB_API_KEY environment variable to your personal API key. If you are using Docker/Compose, you can set this environment variable in the ".env" file. You can find the API key in your W&B profile > User settings > Danger zone > Reveal API key. 30 | Logs will be saved locally to the wandb/ directory and to your W&B account in the cloud.
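For example, a minimal sketch of creating a W&B logger and passing it to the environment loop. The keyword arguments mirror the combined-loggers snippet further below; the directory, label and hyperparameter values are illustrative placeholders, the `EnvironmentLoop` import path is assumed from the repo layout, and `env` and `agent` are assumed to be built as in the scripts under `examples/`:

```python
# Sketch only: constructor arguments follow the "Combining loggers" snippet below;
# logdir, label and hyperparams are illustrative placeholders.
from drlearner.core.environment_loop import EnvironmentLoop  # import path assumed from the repo layout
from drlearner.utils.utils import WandbLogger

logger_env = WandbLogger(
    logdir='experiments/my_run/',        # where local logs are written
    label='environment',                 # name of this log stream in W&B
    hyperparams={'num_episodes': 1000},  # any run metadata you want tracked
)

# env and agent are created as in examples/run_lunar_lander.py
loop = EnvironmentLoop(env, agent, logger=logger_env)
loop.run(1000)
```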
31 | ![Alt text](img/wandb.png) 32 | 33 | ### Combining loggers 34 | To use more than one logger, see the code snippet below: 35 | ``` 36 | wandb_logger = WandbLogger(logdir=tb_workdir, label=label, hyperparams=hyperparams) 37 | tensorboard_logger = TFSummaryLogger(logdir=tb_workdir, label=label) 38 | terminal_logger = loggers.terminal.TerminalLogger(label=label, print_fn=print_fn) 39 | 40 | all_loggers = [wandb_logger, tensorboard_logger, terminal_logger] 41 | 42 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn) 43 | logger = loggers.filters.NoneFilter(logger) 44 | ``` 45 | -------------------------------------------------------------------------------- /docs/docker.md: -------------------------------------------------------------------------------- 1 | # Running in Docker 2 | 3 | Clone the repo: 4 | ``` 5 | git clone https://github.com/PatternsandPredictions/DRLearner_beta.git 6 | cd DRLearner_beta/ 7 | ``` 8 | 9 | Install Docker (if not already installed) and Docker Compose (optional): 10 | ``` 11 | https://docs.docker.com/desktop/install/linux-install/ 12 | https://docs.docker.com/compose/install/linux/ 13 | ``` 14 | 15 | 1. Use the Dockerfile directly 16 | ``` 17 | docker build -t drlearner:latest . 18 | docker run -it --name drlearner -d drlearner:latest 19 | ``` 20 | 2. Use Docker Compose 21 | ``` 22 | docker compose up 23 | ``` 24 | 25 | Now you can attach yourself to the Docker container to play with it. 26 | ``` 27 | docker exec -it drlearner bash 28 | ``` 29 | ## Dockerfile 30 | Using a Python image, setting "/app" as the workdir and installing essential Linux dependencies: 31 | ``` 32 | FROM python:3.10 33 | ADD . /app 34 | WORKDIR /app 35 | 36 | RUN apt-get update \ 37 | && apt-get install -y --no-install-recommends \ 38 | build-essential \ 39 | curl \ 40 | wget \ 41 | xvfb \ 42 | ffmpeg \ 43 | xorg-dev \ 44 | libsdl2-dev \ 45 | swig \ 46 | cmake \ 47 | unar \ 48 | libpython3.10 \ 49 | tmux 50 | ``` 51 | Downloading conda and creating the environment: 52 | ``` 53 | ENV CONDA_DIR /opt/conda 54 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 55 | /bin/bash ~/miniconda.sh -b -p /opt/conda 56 | 57 | RUN conda create --name drlearner python=3.10 -y 58 | RUN python --version 59 | RUN echo "source activate drlearner" > ~/.bashrc 60 | ENV PATH /opt/conda/envs/drlearner/bin:$PATH 61 | ``` 62 | 63 | Installing the Python requirements and downloading game ROMs: 64 | ``` 65 | RUN pip install --upgrade pip 66 | RUN pip install --no-cache-dir -r requirements.txt 67 | RUN pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable 68 | 69 | # Get binaries for Atari games 70 | RUN wget http://www.atarimania.com/roms/Roms.rar 71 | RUN unar Roms.rar 72 | RUN mv Roms roms 73 | RUN ale-import-roms roms/ 74 | ``` 75 | 76 | Setting up environment variables and changing the access mode for all files: 77 | ``` 78 | ENV PYTHONPATH=$PYTHONPATH:$(pwd) 79 | RUN chmod -R 777 ./ 80 | ``` 81 | 82 | Default command for running the container. You can modify it, or simply use "CMD ["/bin/bash"]" and attach yourself to the container to run commands directly. 83 | ``` 84 | CMD ["python3", "examples/run_atari.py", "--level", "PongNoFrameskip-v4", "--num_episodes", "1000", "--exp_path", "experiments/test_pong/", "--exp_name", "my_first_experiment"] 85 | or 86 | CMD ["/bin/bash"] 87 | ``` 88 | 89 | ## Docker compose 90 | Compose runs one service called drlearner that is built using the Dockerfile in the main directory.
Thanks to mounting the volume as .:/app, we don't have to rebuild the container each time we change the codebase. Setting the stdin_open and tty flags enables the container's interactive mode, so the user can attach to the container and use it interactively. 91 | 92 | ``` 93 | services: 94 | drlearner: 95 | build: . 96 | volumes: 97 | - .:/app 98 | stdin_open: true # docker run -i 99 | tty: true # docker run -t 100 | env_file: 101 | - .env 102 | ``` 103 | All the environment variables will be read from the .env file. 104 | -------------------------------------------------------------------------------- /docs/img/lunar_lander.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/lunar_lander.png -------------------------------------------------------------------------------- /docs/img/notebook-instance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/notebook-instance.png -------------------------------------------------------------------------------- /docs/img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/tensorboard.png -------------------------------------------------------------------------------- /docs/img/wandb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/wandb.png -------------------------------------------------------------------------------- /docs/unity.md: -------------------------------------------------------------------------------- 1 | # Running an Apptainer container on Unity 2 | 3 | Install the `spython` utility to make an Apptainer definition file from the Dockerfile: 4 | ``` 5 | pip install spython 6 | spython recipe Dockerfile1 > Apptainer1.def 7 | ``` 8 | Modify the definition file with the following environment settings: 9 | ``` 10 | # Make non-interactive environment.
11 | export TZ='America/New_York' 12 | export DEBIAN_FRONTEND=noninteractive 13 | ``` 14 | Build the Apptainer image (sif): 15 | ``` 16 | module load apptainer/latest 17 | unset APPTAINER_BINDPATH 18 | apptainer build --fakeroot sifs/drlearner1.sif Apptainer1.def 19 | ``` 20 | Allocate a computational node and load the required modules: 21 | ``` 22 | salloc -N 1 -n 1 -p gpu-preempt -G 1 -t 2:00:00 --constraint=a100 23 | module load cuda/11.8.0 24 | module load cudnn/8.7.0.84-11.8 25 | ``` 26 | Run the container: 27 | ``` 28 | apptainer exec --nv sifs/drlearner1.sif bash 29 | ``` 30 | Export a user's WANDB key for logging the job (for illustrative purposes only!): 31 | ``` 32 | export WANDB_API_KEY=c5180d032d5325b08df49b65f9574c8cd59af6b1 33 | ``` 34 | Run the Atari example: 35 | ``` 36 | python3.10 examples/distrun_atari.py --exp_path experiments/apptainer_test_distrun_atari --exp_name apptainer_test_distrun_atari 37 | ``` -------------------------------------------------------------------------------- /docs/vertexai.md: -------------------------------------------------------------------------------- 1 | # Running on Vertex AI 2 | 3 | ## Installation and set-up 4 | 5 | 1. (Local) Install `gcloud`. 6 | ``` 7 | sudo apt-get install apt-transport-https ca-certificates gnupg curl 8 | 9 | echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 10 | 11 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 12 | sudo apt-get update && sudo apt-get install google-cloud-sdk 13 | ``` 14 | 15 | 2. (Local) Set up GCP project. 16 | ``` 17 | gcloud init # choose the existing project or create a new one 18 | export GCP_PROJECT= 19 | echo $GCP_PROJECT # make sure it's the DRLearner project 20 | conda env config vars set GCP_PROJECT= # optional 21 | ``` 22 | 3. (Local) Authorise the use of GCP services by DRLearner. 23 | ``` 24 | gcloud auth application-default login # get credentials to allow DRLearner code calls to GC APIs 25 | export GOOGLE_APPLICATION_CREDENTIALS=/home//.config/gcloud/application_default_credentials.json 26 | conda env config vars set GOOGLE_APPLICATION_CREDENTIALS=/home//.config/gcloud/application_default_credentials.json # optional 27 | ``` 28 | 4. (Local) Install and configure Docker. 29 | ``` 30 | sudo apt-get remove docker docker-engine docker.io containerd runc 31 | sudo apt-get update && sudo apt-get install lsb-release 32 | sudo mkdir -p /etc/apt/keyrings 33 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg 34 | echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null 35 | sudo apt-get update 36 | sudo apt-get install docker-ce docker-ce-cli containerd.io 37 | 38 | sudo groupadd docker 39 | sudo usermod -aG docker 40 | 41 | gcloud auth configure-docker 42 | ``` 43 | 44 | 5. (GCP console) Enable IAM, Enable Vertex AI, Enable Container Registry in ``. 45 | 46 | 47 | 6. (GCP console) Set up a xmanager service account. 48 | - Create xmanager service account in `IAM & Admin/Service accounts` . 49 | - Add 'Storage Admin', 'Vertex AI Administrator', 'Vertex AI User' , 'Service Account User' roles. 50 | 51 | 7. Set up a Cloud storage bucket. 
52 | - (GCP console) Create a Cloud storage bucket in Cloud Storage in `us-central1` region. 53 | - (Local) `export GOOGLE_CLOUD_BUCKET_NAME=` 54 | - (Local, optional) `conda env config vars set GOOGLE_CLOUD_BUCKET_NAME=` 55 | 56 | 8. (Local) Replace `envs/drlearner/lib/python3.10/site-packages/launchpad/nodes/python/xm_docker.py` with `./external/xm_docker.py` (to get the correct Docker instructions)* 57 | 58 | *Can't rebuild launchpad package with those changes because the of complicated build process (requires Bazel...) 59 | 60 | 61 | 9. (Local) Replace `envs/drlearner/lib/python3.10/site-packages/xmanager/cloud/vertex.py` with `./external/vertex.py` (to add new machine types, allow web access to nodes from GCP console). 62 | 63 | 64 | 10. (Local) Tensorboard instructions: 65 | - Use scripts/update_tb.py to download current tfevents file which is saved in `` 66 | ``` 67 | python update_tb.py / 68 | ``` 69 | ! We recommend syncing tf files regularly and keeping older versions as well, 70 | since Vertex AI silently restarts the workers which are down, 71 | and they start writing logs in tf file from scratch ! 72 | 73 | ## GCP Hardware Specs 74 | The hardware requirements for running DRLearner on Vertex AI are specified in `drlearner/configs/resources/` - there are two setups: for easy environment (i.e. Atari Boxing) and a more complex one (i.e. Atari Montezuma Revenge). See the table below. 75 | 76 | 77 | | | Simple env | Complex env | 78 | |---------------|:------------------------------------------:|---------------------------------------------:| 79 | | Actor | e2-standard-4 (4 CPU, 16 RAM) | e2-standard-4 (4 CPU, 16 RAM) | 80 | | Learner | n1-standard-4 (4 CPU, 16 RAM + TESLA P100) | n1-highmem-16 (16 CPU, 104 RAM + TESLA P100) | 81 | | Replay Buffer | e2-highmem-8 (8 CPU, 64 RAM) | e2-highmem-16 (16 CPU, 128 RAM) | 82 | 83 | New configurations can be added using the same xm_docker.DockerConfig and xm.JobRequirements classes. Available for use on Vertex AI machine types are listed here https://cloud.google.com/vertex-ai/pricing. 84 | But it might require adding the new machine names to `external/vertex.py` i.e. `'n2-standard-64': (64, 256 * xm.GiB),`. 85 | 86 | ## GCP Troubleshooting 87 | In case of any 'Permission denied' issues, go to `IAM & Admin/` in GCP console and try adding 'Service Account User' role to your User, and 88 | 'Compute Storage Admin' role to 'AI Platform Custom Code Service Agent' Service Account. 89 | 90 | ## Running experiments 91 | ``` 92 | python ./examples/distrun_atari.py --run_on_vertex --exp_path /gcs/$GOOGLE_CLOUD_BUCKET_NAME/test_pong/ --level PongNoFrameskip-v4 --num_actors_per_mixture 3 93 | ``` 94 | - add `--noxm_build_image_locally` to build Docker images with Cloud Build, otherwise it will be built locally. 95 | - number of nodes running Actor code is `--num_actors_per_mixture` x `num_mixtures` - default number of mixtures for Atari is 32 - so be careful and don't launch the full-scale experiment before testing that everything works correctly. 
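To make the scale concrete, a quick back-of-the-envelope check (plain Python; the mixture count comes from the Atari config and the actor count from the example command above):

```python
# Actor nodes launched by the example command above.
num_mixtures = 32             # default in drlearner/configs/config_atari.py
num_actors_per_mixture = 3    # value passed via --num_actors_per_mixture

print(num_mixtures * num_actors_per_mixture)  # 96 actor nodes, in addition to the learner and replay buffer nodes
```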
96 | -------------------------------------------------------------------------------- /drlearner/__init__.py: -------------------------------------------------------------------------------- 1 | from .drlearner import * 2 | from .core import * 3 | from .configs import * 4 | -------------------------------------------------------------------------------- /drlearner/configs/config_atari.py: -------------------------------------------------------------------------------- 1 | from drlearner.drlearner.config import DRLearnerConfig 2 | import rlax 3 | from acme.adders import reverb as adders_reverb 4 | 5 | AtariDRLearnerConfig = DRLearnerConfig( 6 | gamma_min=0.99, 7 | gamma_max=0.997, 8 | num_mixtures=32, 9 | target_update_period=400, 10 | evaluation_epsilon=0.01, 11 | actor_epsilon=0.4, 12 | target_epsilon=0.01, 13 | variable_update_period=800, 14 | 15 | # Learner options 16 | retrace_lambda=0.95, 17 | burn_in_length=0, 18 | trace_length=80, 19 | sequence_period=40, 20 | num_sgd_steps_per_step=1, 21 | uvfa_learning_rate=1e-4, 22 | idm_learning_rate=5e-4, 23 | distillation_learning_rate=5e-4, 24 | idm_weight_decay=1e-5, 25 | distillation_weight_decay=1e-5, 26 | idm_clip_steps=5, 27 | distillation_clip_steps=5, 28 | clip_rewards=True, 29 | max_absolute_reward=1.0, 30 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR, 31 | distillation_moving_average_coef=1e-3, 32 | 33 | # Intrinsic reward multipliers 34 | beta_min=0., 35 | beta_max=0.3, 36 | 37 | # Embedding network options 38 | observation_embed_dim=32, 39 | episodic_memory_num_neighbors=10, 40 | episodic_memory_max_size=1500, 41 | episodic_memory_max_similarity=8., 42 | episodic_memory_cluster_distance=8e-3, 43 | episodic_memory_pseudo_counts=1e-3, 44 | episodic_memory_epsilon=1e-4, 45 | 46 | # Distillation network 47 | distillation_embed_dim=128, 48 | max_lifelong_modulation=5.0, 49 | 50 | # Replay options 51 | samples_per_insert_tolerance_rate=0.5, 52 | samples_per_insert=2., 53 | min_replay_size=6250, 54 | max_replay_size=100_000, 55 | batch_size=64, 56 | prefetch_size=1, 57 | num_parallel_calls=16, 58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE, 59 | 60 | # Priority options 61 | importance_sampling_exponent=0.6, 62 | priority_exponent=0.9, 63 | max_priority_weight=0.9, 64 | 65 | # Meta Controller options 66 | actor_window=160, 67 | evaluation_window=3600, 68 | n_arms=32, 69 | actor_mc_epsilon=0.3, 70 | evaluation_mc_epsilon=0.01, 71 | mc_beta=1., 72 | 73 | # Agent video logging options 74 | env_library='gym', 75 | video_log_period=50, 76 | actions_log_period=1, 77 | logs_dir='experiments/videos/', 78 | num_episodes=50, 79 | 80 | ) -------------------------------------------------------------------------------- /drlearner/configs/config_discomaze.py: -------------------------------------------------------------------------------- 1 | import rlax 2 | from acme.adders import reverb as adders_reverb 3 | 4 | from drlearner.drlearner.config import DRLearnerConfig 5 | 6 | DiscomazeDRLearnerConfig = DRLearnerConfig( 7 | gamma_min=0.99, 8 | gamma_max=0.99, 9 | num_mixtures=3, 10 | target_update_period=100, 11 | evaluation_epsilon=0., 12 | actor_epsilon=0.05, 13 | target_epsilon=0.0, 14 | variable_update_period=1000, 15 | 16 | # Learner options 17 | retrace_lambda=0.97, 18 | burn_in_length=0, 19 | trace_length=30, 20 | sequence_period=30, 21 | num_sgd_steps_per_step=1, 22 | uvfa_learning_rate=1e-3, 23 | idm_learning_rate=1e-3, 24 | distillation_learning_rate=1e-3, 25 | idm_weight_decay=1e-5, 26 | distillation_weight_decay=1e-5, 27 | 
idm_clip_steps=5, 28 | distillation_clip_steps=5, 29 | clip_rewards=False, 30 | max_absolute_reward=1.0, 31 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR, 32 | 33 | # Intrinsic reward multipliers 34 | beta_min=0., 35 | beta_max=0.5, 36 | 37 | # Embedding network options 38 | observation_embed_dim=16, 39 | episodic_memory_num_neighbors=10, 40 | episodic_memory_max_size=5_000, 41 | episodic_memory_max_similarity=8., 42 | episodic_memory_cluster_distance=8e-3, 43 | episodic_memory_pseudo_counts=1e-3, 44 | episodic_memory_epsilon=1e-2, 45 | 46 | # Distillation network 47 | distillation_embed_dim=32, 48 | max_lifelong_modulation=5.0, 49 | 50 | # Replay options 51 | samples_per_insert_tolerance_rate=1.0, 52 | samples_per_insert=0.0, 53 | min_replay_size=1, 54 | max_replay_size=100_000, 55 | batch_size=64, 56 | prefetch_size=1, 57 | num_parallel_calls=16, 58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE, 59 | 60 | # Priority options 61 | importance_sampling_exponent=0.6, 62 | priority_exponent=0.9, 63 | max_priority_weight=0.9, 64 | 65 | # Meta Controller options 66 | actor_window=160, 67 | evaluation_window=1000, 68 | n_arms=3, 69 | actor_mc_epsilon=0.3, 70 | evaluation_mc_epsilon=0.01, 71 | mc_beta=1., 72 | 73 | # Agent video logging options 74 | env_library='discomaze', 75 | video_log_period=10, 76 | actions_log_period=1, 77 | logs_dir='experiments/emb_size_4_nn1_less_learning_steps', 78 | num_episodes=50, 79 | ) 80 | -------------------------------------------------------------------------------- /drlearner/configs/config_lunar_lander.py: -------------------------------------------------------------------------------- 1 | import rlax 2 | from acme.adders import reverb as adders_reverb 3 | 4 | from drlearner.drlearner.config import DRLearnerConfig 5 | 6 | LunarLanderDRLearnerConfig = DRLearnerConfig( 7 | gamma_min=0.99, 8 | gamma_max=0.99, 9 | num_mixtures=3, 10 | target_update_period=50, 11 | evaluation_epsilon=0.01, 12 | actor_epsilon=0.05, 13 | target_epsilon=0.01, 14 | variable_update_period=100, 15 | 16 | # Learner options 17 | retrace_lambda=0.95, 18 | burn_in_length=0, 19 | trace_length=30, 20 | sequence_period=30, 21 | num_sgd_steps_per_step=1, 22 | uvfa_learning_rate=5e-4, 23 | idm_learning_rate=5e-4, 24 | distillation_learning_rate=5e-4, 25 | idm_weight_decay=1e-5, 26 | distillation_weight_decay=1e-5, 27 | idm_clip_steps=5, 28 | distillation_clip_steps=5, 29 | clip_rewards=True, 30 | max_absolute_reward=1.0, 31 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR, 32 | 33 | # Intrinsic reward multipliers 34 | beta_min=0., 35 | beta_max=0., 36 | 37 | # Embedding network options 38 | observation_embed_dim=1, 39 | episodic_memory_num_neighbors=1, 40 | episodic_memory_max_size=1, 41 | episodic_memory_max_similarity=8., 42 | episodic_memory_cluster_distance=8e-3, 43 | episodic_memory_pseudo_counts=1e-3, 44 | episodic_memory_epsilon=1e-2, 45 | 46 | # Distillation network 47 | distillation_embed_dim=1, 48 | max_lifelong_modulation=1, 49 | 50 | # Replay options 51 | samples_per_insert_tolerance_rate=1.0, 52 | samples_per_insert=0.0, 53 | min_replay_size=1, 54 | max_replay_size=100_000, 55 | batch_size=64, 56 | prefetch_size=1, 57 | num_parallel_calls=16, 58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE, 59 | 60 | # Priority options 61 | importance_sampling_exponent=0.6, 62 | priority_exponent=0.9, 63 | max_priority_weight=0.9, 64 | 65 | # Meta Controller options 66 | actor_window=160, 67 | evaluation_window=1000, 68 | n_arms=3, 69 | actor_mc_epsilon=0.3, 70 | evaluation_mc_epsilon=0.01, 71 | 
mc_beta=1., 72 | 73 | # Agent video logging options 74 | env_library='gym', 75 | video_log_period=50, 76 | actions_log_period=1, 77 | logs_dir='experiments/ll/', 78 | num_episodes=50, 79 | ) 80 | -------------------------------------------------------------------------------- /drlearner/configs/resources/__init__.py: -------------------------------------------------------------------------------- 1 | from .atari import get_vertex_resources as get_atari_vertex_resources 2 | from .toy_env import get_vertex_resources as get_toy_env_vertex_resources 3 | from .local_resources import get_local_resources 4 | -------------------------------------------------------------------------------- /drlearner/configs/resources/atari.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from launchpad.nodes.python import xm_docker 4 | from xmanager import xm 5 | import xmanager.cloud.build_image 6 | 7 | 8 | def get_vertex_resources(): 9 | resources = dict() 10 | 11 | resources['learner'] = xm_docker.DockerConfig( 12 | os.getcwd() + '/', 13 | os.getcwd() + '/requirements.txt', 14 | xm.JobRequirements(cpu=16, memory=104 * xm.GiB, P100=1) 15 | ) 16 | 17 | resources['counter'] = xm_docker.DockerConfig( 18 | os.getcwd() + '/', 19 | os.getcwd() + '/requirements.txt', 20 | xm.JobRequirements(cpu=2, memory=16 * xm.GiB) 21 | ) 22 | 23 | for node in ['actor', 'evaluator']: 24 | resources[node] = xm_docker.DockerConfig( 25 | os.getcwd() + '/', 26 | os.getcwd() + '/requirements.txt', 27 | xm.JobRequirements(cpu=4, memory=16 * xm.GiB) 28 | ) 29 | 30 | resources['replay'] = xm_docker.DockerConfig( 31 | os.getcwd() + '/', 32 | os.getcwd() + '/requirements.txt', 33 | xm.JobRequirements(cpu=16, memory=128 * xm.GiB) 34 | ) 35 | 36 | return resources 37 | -------------------------------------------------------------------------------- /drlearner/configs/resources/local_resources.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from launchpad.nodes.python import xm_docker 4 | from xmanager import xm 5 | import xmanager.cloud.build_image 6 | from launchpad.nodes.python import local_multi_processing 7 | 8 | 9 | def get_local_resources(): 10 | local_resources = dict( 11 | actor=local_multi_processing.PythonProcess( 12 | env=dict(CUDA_VISIBLE_DEVICES='-1') 13 | ), 14 | counter=local_multi_processing.PythonProcess( 15 | env=dict(CUDA_VISIBLE_DEVICES='-1') 16 | ), 17 | evaluator=local_multi_processing.PythonProcess( 18 | env=dict(CUDA_VISIBLE_DEVICES='-1') 19 | ), 20 | replay=local_multi_processing.PythonProcess( 21 | env=dict(CUDA_VISIBLE_DEVICES='-1') 22 | ), 23 | learner=local_multi_processing.PythonProcess( 24 | env=dict( 25 | # XLA_PYTHON_CLIENT_MEM_FRACTION='0.1', 26 | CUDA_VISIBLE_DEVICES='0', 27 | XLA_PYTHON_CLIENT_PREALLOCATE='0', 28 | LD_LIBRARY_PATH=os.environ.get('LD_LIBRARY_PATH', '') + ':/usr/local/cuda/lib64')) 29 | ) 30 | return local_resources -------------------------------------------------------------------------------- /drlearner/configs/resources/toy_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from launchpad.nodes.python import xm_docker 4 | from xmanager import xm 5 | import xmanager.cloud.build_image 6 | 7 | 8 | def get_vertex_resources(): 9 | resources = dict() 10 | 11 | resources['learner'] = xm_docker.DockerConfig( 12 | os.getcwd() + '/', 13 | os.getcwd() + '/requirements.txt', 14 | xm.JobRequirements(cpu=4, memory=15 * xm.GiB, 
P100=1) 15 | ) 16 | 17 | resources['counter'] = xm_docker.DockerConfig( 18 | os.getcwd() + '/', 19 | os.getcwd() + '/requirements.txt', 20 | xm.JobRequirements(cpu=2, memory=16 * xm.GiB) 21 | ) 22 | 23 | for node in ['actor', 'evaluator']: 24 | resources[node] = xm_docker.DockerConfig( 25 | os.getcwd() + '/', 26 | os.getcwd() + '/requirements.txt', 27 | xm.JobRequirements(cpu=4, memory=16 * xm.GiB) 28 | ) 29 | 30 | resources['replay'] = xm_docker.DockerConfig( 31 | os.getcwd() + '/', 32 | os.getcwd() + '/requirements.txt', 33 | xm.JobRequirements(cpu=8, memory=64 * xm.GiB) 34 | ) 35 | 36 | return resources 37 | -------------------------------------------------------------------------------- /drlearner/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/core/__init__.py -------------------------------------------------------------------------------- /drlearner/core/environment_loop.py: -------------------------------------------------------------------------------- 1 | """A simple agent-environment training loop.""" 2 | 3 | import operator 4 | import time 5 | from typing import Optional, Sequence 6 | import platform 7 | 8 | import numpy as np 9 | from pyvirtualdisplay import Display 10 | import tree 11 | 12 | from acme import core 13 | from acme.utils import counting 14 | from acme.utils import loggers 15 | from acme.utils import observers as observers_lib 16 | from acme.utils import signals 17 | 18 | import dm_env 19 | from dm_env import specs 20 | 21 | from drlearner.core.loggers import disable_view_window 22 | from drlearner.core.observers import VideoObserver 23 | 24 | 25 | class EnvironmentLoop(core.Worker): 26 | """A simple RL environment loop. 27 | This takes `Environment` and `Actor` instances and coordinates their 28 | interaction. Agent is updated if `should_update=True`. This can be used as: 29 | loop = EnvironmentLoop(environment, actor) 30 | loop.run(num_episodes) 31 | A `Counter` instance can optionally be given in order to maintain counts 32 | between different Acme components. If not given a local Counter will be 33 | created to maintain counts between calls to the `run` method. 34 | A `Logger` instance can also be passed in order to control the output of the 35 | loop. If not given a platform-specific default logger will be used as defined 36 | by utils.loggers.make_default_logger. A string `label` can be passed to easily 37 | change the label associated with the default logger; this is ignored if a 38 | `Logger` instance is given. 39 | A list of 'Observer' instances can be specified to generate additional metrics 40 | to be logged by the logger. They have access to the 'Environment' instance, 41 | the current timestep datastruct and the current action. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | environment: dm_env.Environment, 47 | actor: core.Actor, 48 | counter: Optional[counting.Counter] = None, 49 | logger: Optional[loggers.Logger] = None, 50 | should_update: bool = True, 51 | label: str = 'environment_loop', 52 | observers: Sequence[observers_lib.EnvLoopObserver] = (), 53 | ): 54 | # Internalize agent and environment. 
55 | self._environment = environment 56 | self._actor = actor 57 | self._counter = counter or counting.Counter() 58 | self._logger = logger or loggers.make_default_logger(label) 59 | self._should_update = should_update 60 | self._observers = observers 61 | 62 | self.platform = platform.system().lower() 63 | 64 | if any([isinstance(o, VideoObserver) for o in observers]): 65 | if 'linux' in self.platform: 66 | display = Display(visible=0, size=(1400, 900)) 67 | display.start() 68 | else: 69 | disable_view_window() 70 | 71 | def run_episode(self, episode_count: int) -> loggers.LoggingData: 72 | """Run one episode. 73 | Each episode is a loop which interacts first with the environment to get an 74 | observation and then give that observation to the agent in order to retrieve 75 | an action. 76 | Returns: 77 | An instance of `loggers.LoggingData`. 78 | """ 79 | # Reset any counts and start the environment. 80 | start_time = time.time() 81 | action_time, env_step_time, observe_time = 0., 0., 0. 82 | episode_steps = 0 83 | 84 | # For evaluation, this keeps track of the total undiscounted reward 85 | # accumulated during the episode. 86 | episode_return = tree.map_structure(_generate_zeros_from_spec, 87 | self._environment.reward_spec()) 88 | timestep = self._environment.reset() 89 | # Make the first observation. 90 | self._actor.observe_first(timestep) 91 | actor_extras = self._actor.get_extras() 92 | 93 | for observer in self._observers: 94 | # Initialize the observer with the current state of the env after reset 95 | # and the initial timestep. 96 | if hasattr(observer, 'observe_first'): 97 | observer.observe_first( 98 | self._environment, 99 | timestep, 100 | actor_extras, 101 | episode=episode_count, 102 | step=episode_steps, 103 | ) 104 | 105 | # Run an episode. 106 | while not timestep.last(): 107 | # Generate an action from the agent's policy and step the environment. 108 | t = time.perf_counter() 109 | action = self._actor.select_action(timestep.observation) 110 | action_time += time.perf_counter() - t 111 | 112 | t = time.perf_counter() 113 | timestep = self._environment.step(action) 114 | env_step_time += time.perf_counter() - t 115 | 116 | # Have the agent observe the timestep and let the actor update itself. 117 | t = time.perf_counter() 118 | self._actor.observe(action, next_timestep=timestep) 119 | observe_time += time.perf_counter() - t 120 | 121 | actor_extras = self._actor.get_extras() 122 | for observer in self._observers: 123 | # One environment step was completed. Observe the current state of the 124 | # environment, the current timestep and the action. 125 | observer.observe( 126 | self._environment, 127 | timestep, 128 | action, 129 | actor_extras, 130 | episode=episode_count, 131 | step=episode_steps, 132 | ) 133 | if self._should_update: 134 | self._actor.update() 135 | 136 | # Book-keeping. 137 | episode_steps += 1 138 | 139 | # Equivalent to: episode_return += timestep.reward 140 | # We capture the return value because if timestep.reward is a JAX 141 | # DeviceArray, episode_return will not be mutated in-place. (In all other 142 | # cases, the returned episode_return will be the same object as the 143 | # argument episode_return.) 144 | episode_return = tree.map_structure(operator.iadd, 145 | episode_return, 146 | timestep.reward) 147 | 148 | # Record counts. 149 | counts = self._counter.increment(episodes=1, steps=episode_steps) 150 | 151 | # Collect the results and combine with counts. 
152 | steps_per_second = episode_steps / (time.time() - start_time) 153 | result = { 154 | 'episode_length': episode_steps, 155 | 'episode_return': episode_return, 156 | 'steps_per_second': steps_per_second, 157 | 'action_mean_time': action_time / episode_steps, 158 | 'env_step_mean_time': env_step_time / episode_steps, 159 | 'observe_mean_time': observe_time / episode_steps, 160 | 161 | } 162 | result.update(counts) 163 | for observer in self._observers: 164 | if hasattr(observer, 'get_metrics'): 165 | result.update( 166 | observer.get_metrics(timestep=timestep, episode=episode_count), 167 | ) 168 | 169 | return result 170 | 171 | def run(self, 172 | num_episodes: Optional[int] = None, 173 | num_steps: Optional[int] = None): 174 | """Perform the run loop. 175 | Run the environment loop either for `num_episodes` episodes or for at 176 | least `num_steps` steps (the last episode is always run until completion, 177 | so the total number of steps may be slightly more than `num_steps`). 178 | At least one of these two arguments has to be None. 179 | Upon termination of an episode a new episode will be started. If the number 180 | of episodes and the number of steps are not given then this will interact 181 | with the environment infinitely. 182 | Args: 183 | num_episodes: number of episodes to run the loop for. 184 | num_steps: minimal number of steps to run the loop for. 185 | Raises: 186 | ValueError: If both 'num_episodes' and 'num_steps' are not None. 187 | """ 188 | 189 | if not (num_episodes is None or num_steps is None): 190 | raise ValueError('Either "num_episodes" or "num_steps" should be None.') 191 | 192 | def should_terminate(episode_count: int, step_count: int) -> bool: 193 | return ((num_episodes is not None and episode_count >= num_episodes) or 194 | (num_steps is not None and step_count >= num_steps)) 195 | 196 | episode_count, step_count = 0, 0 197 | with signals.runtime_terminator(): 198 | while not should_terminate(episode_count, step_count): 199 | result = self.run_episode(episode_count) 200 | episode_count += 1 201 | step_count += result['episode_length'] 202 | # Log the given episode results. 203 | self._logger.write(result) 204 | 205 | 206 | # Placeholder for an EnvironmentLoop alias 207 | 208 | 209 | def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: 210 | return np.zeros(spec.shape, spec.dtype) 211 | -------------------------------------------------------------------------------- /drlearner/core/local_layout.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Local agent based on builders.""" 17 | 18 | from typing import Any, Optional 19 | 20 | from acme import specs 21 | from acme.agents import agent 22 | from acme.agents.jax import builders 23 | from acme.jax import utils 24 | from acme.tf import savers 25 | from acme.utils import counting 26 | import jax 27 | import reverb 28 | 29 | 30 | class LocalLayout(agent.Agent): 31 | """An Agent that runs an algorithm defined by 'builder' on a single machine. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | seed: int, 37 | environment_spec: specs.EnvironmentSpec, 38 | builder: builders.GenericActorLearnerBuilder, 39 | networks: Any, 40 | policy_network: Any, 41 | workdir: Optional[str] = '~/acme', 42 | min_replay_size: int = 1000, 43 | samples_per_insert: float = 256.0, 44 | batch_size: int = 256, 45 | num_sgd_steps_per_step: int = 1, 46 | prefetch_size: int = 1, 47 | device_prefetch: bool = True, 48 | counter: Optional[counting.Counter] = None, 49 | checkpoint: bool = True, 50 | ): 51 | """Initialize the agent. 52 | 53 | Args: 54 | seed: A random seed to use for this layout instance. 55 | environment_spec: description of the actions, observations, etc. 56 | builder: builder defining an RL algorithm to train. 57 | networks: network objects to be passed to the learner. 58 | policy_network: function that given an observation returns actions. 59 | workdir: if provided saves the state of the learner and the counter 60 | (if the counter is not None) into workdir. 61 | min_replay_size: minimum replay size before updating. 62 | samples_per_insert: number of samples to take from replay for every insert 63 | that is made. 64 | batch_size: batch size for updates. 65 | num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call. 66 | For performance reasons (especially to reduce TPU host-device transfer 67 | times) it is performance-beneficial to do multiple sgd updates at once, 68 | provided that it does not hurt the training, which needs to be verified 69 | empirically for each environment. 70 | prefetch_size: whether to prefetch iterator. 71 | device_prefetch: whether prefetching should happen to a device. 72 | counter: counter object used to keep track of steps. 73 | checkpoint: boolean indicating whether to checkpoint the learner 74 | and the counter (if the counter is not None). 75 | """ 76 | if prefetch_size < 0: 77 | raise ValueError(f'Prefetch size={prefetch_size} should be non negative') 78 | 79 | key = jax.random.PRNGKey(seed) 80 | 81 | # Create the replay server and grab its address. 82 | replay_tables = builder.make_replay_tables(environment_spec) 83 | replay_server = reverb.Server(replay_tables, port=None) 84 | replay_client = reverb.Client(f'localhost:{replay_server.port}') 85 | 86 | # Create actor, dataset, and learner for generating, storing, and consuming 87 | # data respectively. 88 | adder = builder.make_adder(replay_client) 89 | 90 | def _is_reverb_queue(reverb_table: reverb.Table, 91 | reverb_client: reverb.Client) -> bool: 92 | """Returns True iff the Reverb Table is actually a queue.""" 93 | # TODO(sinopalnikov): make it more generic and check for a table that 94 | # needs special handling on update. 
95 | info = reverb_client.server_info() 96 | table_info = info[reverb_table.name] 97 | is_queue = ( 98 | table_info.max_times_sampled == 1 and 99 | table_info.sampler_options.fifo and 100 | table_info.remover_options.fifo) 101 | return is_queue 102 | 103 | is_reverb_queue = any(_is_reverb_queue(table, replay_client) 104 | for table in replay_tables) 105 | 106 | dataset = builder.make_dataset_iterator(replay_client) 107 | if prefetch_size > 1: 108 | device = jax.devices()[0] if device_prefetch else None 109 | dataset = utils.prefetch(dataset, buffer_size=prefetch_size, 110 | device=device) 111 | learner_key, key = jax.random.split(key) 112 | learner = builder.make_learner( 113 | random_key=learner_key, 114 | networks=networks, 115 | dataset=dataset, 116 | replay_client=replay_client, 117 | counter=counter) 118 | if not checkpoint or workdir is None: 119 | self._checkpointer = None 120 | else: 121 | objects_to_save = {'learner': learner} 122 | if counter is not None: 123 | objects_to_save.update({'counter': counter}) 124 | self._checkpointer = savers.Checkpointer( 125 | objects_to_save, 126 | time_delta_minutes=30, 127 | subdirectory='learner', 128 | directory=workdir, 129 | add_uid=(workdir == '~/acme')) 130 | 131 | actor_key, key = jax.random.split(key) 132 | # use actor_id as 0 for local layout 133 | actor = builder.make_actor( 134 | actor_key, policy_network, adder, variable_source=learner) 135 | self._custom_update_fn = None 136 | if is_reverb_queue: 137 | # Reverb queue requires special handling on update: custom logic to 138 | # decide when it is safe to make a learner step. This is only needed for 139 | # the local agent, where the actor and the learner are running 140 | # synchronously and the learner will deadlock if it makes a step with 141 | # no data available. 142 | def custom_update(): 143 | should_update_actor = False 144 | # Run a number of learner steps (usually gradient steps). 145 | # TODO(raveman): This is wrong. When running multi-level learners, 146 | # different levels might have different batch sizes. Find a solution. 147 | while all(table.can_sample(batch_size) for table in replay_tables): 148 | learner.step() 149 | should_update_actor = True 150 | 151 | if should_update_actor: 152 | # "wait=True" to make it more onpolicy 153 | actor.update(wait=True) 154 | 155 | self._custom_update_fn = custom_update 156 | 157 | effective_batch_size = batch_size * num_sgd_steps_per_step 158 | super().__init__( 159 | actor=actor, 160 | learner=learner, 161 | min_observations=max(effective_batch_size, min_replay_size), 162 | observations_per_step=float(effective_batch_size) / samples_per_insert) 163 | 164 | # Save the replay so we don't garbage collect it. 
165 | self._replay_server = replay_server 166 | 167 | def update(self): 168 | if self._custom_update_fn: 169 | self._custom_update_fn() 170 | else: 171 | super().update() 172 | if self._checkpointer: 173 | self._checkpointer.save() 174 | 175 | def save(self): 176 | """Checkpoint the state of the agent.""" 177 | if self._checkpointer: 178 | self._checkpointer.save(force=True) 179 | -------------------------------------------------------------------------------- /drlearner/core/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from drlearner.core.loggers.image import ImageLogger, disable_view_window 2 | -------------------------------------------------------------------------------- /drlearner/core/loggers/image.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from acme.utils.loggers.tf_summary import TFSummaryLogger 3 | 4 | 5 | def disable_view_window(): 6 | """ 7 | Disables gym view window 8 | """ 9 | from gym.envs.classic_control import rendering 10 | org_constructor = rendering.Viewer.__init__ 11 | 12 | def constructor(self, *args, **kwargs): 13 | org_constructor(self, *args, **kwargs) 14 | self.window.set_visible(visible=False) 15 | 16 | rendering.Viewer.__init__ = constructor 17 | 18 | 19 | class ImageLogger(TFSummaryLogger): 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | 23 | def write_image(self, name, image, step=0): 24 | with self.summary.as_default(): 25 | tf.summary.image(name, [image], step=step) -------------------------------------------------------------------------------- /drlearner/core/observers/__init__.py: -------------------------------------------------------------------------------- 1 | from .action_dist import ActionProbObserver 2 | from .discomaze_unique_states import UniqueStatesDiscoMazeObserver 3 | from .intrinsic_reward import IntrinsicRewardObserver 4 | from .meta_controller import MetaControllerObserver 5 | from .distillation_coef import DistillationCoefObserver 6 | 7 | from .video import VideoObserver, StorageVideoObserver 8 | from .actions import ActionsObserver 9 | -------------------------------------------------------------------------------- /drlearner/core/observers/action_dist.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import numpy as np 3 | 4 | 5 | class ActionProbObserver: 6 | def __init__(self, num_actions): 7 | self._num_actions = num_actions 8 | self._action_counter = None 9 | 10 | def observe_first(self, *args, **kwargs) -> None: 11 | # todo: defaultdict 12 | self._action_counter = {i: 0 for i in range(self._num_actions)} 13 | 14 | def observe(self, *args, **kwargs) -> None: 15 | env, timestamp, action, actor_extras = args 16 | self._action_counter[int(action)] += 1 17 | 18 | def get_metrics(self, **kwargs): 19 | total_actions = sum(self._action_counter.values()) 20 | return {f'Action: {i}': self._action_counter[i] / total_actions for i in range(self._num_actions)} 21 | -------------------------------------------------------------------------------- /drlearner/core/observers/actions.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from io import BytesIO 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from matplotlib import pyplot as plt 9 | from PIL import Image 10 | import dm_env 11 | 12 | from 
drlearner.core.loggers import ImageLogger 13 | 14 | 15 | class ActionsObserver: 16 | def __init__(self, config): 17 | self.config = config 18 | self.image_logger = ImageLogger(config.logs_dir) 19 | 20 | self.unique_actions = set() 21 | self.ratios = list() 22 | self.episode_actions = list() 23 | 24 | def _log_episode_action(self, timestamp: dm_env.TimeStep, action: np.array): 25 | if not timestamp.last(): 26 | self.episode_actions.append(action) 27 | else: 28 | episode_ratios = self.calculate_actions_ratio() 29 | self.ratios.append(episode_ratios) 30 | 31 | self.episode_actions = list() 32 | 33 | def observe( 34 | self, 35 | *args, 36 | **kwargs, 37 | ) -> None: 38 | env, timestamp, action, actor_extras = args 39 | 40 | action = np.asscalar(action) 41 | self.unique_actions.add(action) 42 | 43 | episode_count = kwargs['episode_count'] 44 | log_action = episode_count % self.config.actions_log_period == 0 45 | if log_action: 46 | self._log_episode_action(timestamp, action) 47 | 48 | def calculate_actions_ratio(self) -> Dict: 49 | """ 50 | Calculates actions ratio per episode 51 | 52 | Returns 53 | ratios: list of action ratios per action type 54 | unique_actions: set of possible actions 55 | """ 56 | counter = Counter(self.episode_actions) 57 | episode_ratios = dict() 58 | 59 | for action in self.unique_actions: 60 | count = counter.get(action, 0) 61 | ratio = count / len(self.episode_actions) 62 | episode_ratios[str(action)] = ratio 63 | 64 | return episode_ratios 65 | 66 | def _plot_actions(self): 67 | """ 68 | Creates actions plot and logs it into tensorboard 69 | """ 70 | for action in self.unique_actions: 71 | values = [ratio.get(str(action), 0) for ratio in self.ratios] 72 | 73 | n = len(self.ratios) 74 | steps = list(range(n)) 75 | 76 | plt.plot(steps, values, c=np.random.rand(3), label=action) 77 | plt.legend() 78 | 79 | plt.title('action ratios per episode') 80 | plt.xlabel('episode') 81 | plt.ylabel('action ratio') 82 | 83 | buffer = BytesIO() 84 | plt.savefig(buffer, format='png') 85 | img = Image.open(buffer) 86 | self.image_logger.write_image('actions_ratio', tf.convert_to_tensor(np.array(img))) 87 | 88 | def get_metrics(self, **kwargs): 89 | episode_count = kwargs['episode_count'] 90 | last_episode = episode_count == self.config.num_episodes - 1 91 | 92 | if last_episode: 93 | self._plot_actions() 94 | 95 | return dict() 96 | -------------------------------------------------------------------------------- /drlearner/core/observers/discomaze_unique_states.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import numpy as np 3 | 4 | MAZE_PATH_COLOR = (0., 0., 0.) 5 | AGENT_COLOR = (1., 1., 1.) 
6 | 7 | 8 | def mask_color_on_rgb(image, color) -> np.ndarray: 9 | """ 10 | Given `image` of shape (H, W, C=3) and `color` pf shape (3,) return 11 | mask of shape (H, W) where pixel on the image have the same color as `color` 12 | """ 13 | return np.isclose(image[..., 0], color[0]) & \ 14 | np.isclose(image[..., 1], color[1]) & \ 15 | np.isclose(image[..., 2], color[2]) 16 | 17 | 18 | class UniqueStatesVisitsCounter: 19 | def __init__(self, total): 20 | self.__total = total 21 | self.__visited = set() 22 | self.__reward_first_visit = [] 23 | self.__reward_repeated_visit = [] 24 | 25 | def add(self, state, intrinsic_reward): 26 | coords = self.get_xy_from_state(state) 27 | 28 | if coords in self.__visited: 29 | self.__reward_repeated_visit.append(float(intrinsic_reward)) 30 | else: 31 | self.__visited.add(coords) 32 | self.__reward_first_visit.append(float(intrinsic_reward)) 33 | 34 | @staticmethod 35 | def get_xy_from_state(state): 36 | mask = mask_color_on_rgb(state, AGENT_COLOR) 37 | coords = np.where(mask) 38 | x, y = coords[1][0], coords[0][0] 39 | return x, y 40 | 41 | def get_number_of_visited(self): 42 | return len(self.__visited) 43 | 44 | def get_fraction_of_visited(self): 45 | return len(self.__visited) / self.__total 46 | 47 | def get_mean_first_visit_reward(self): 48 | return np.mean(self.__reward_first_visit) 49 | 50 | def get_mean_repeated_visit_reward(self): 51 | return np.mean(self.__reward_repeated_visit) 52 | 53 | 54 | class UniqueStatesDiscoMazeObserver: 55 | def __init__(self): 56 | self._state_visit_counter: UniqueStatesVisitsCounter = None 57 | 58 | def reset(self, states_total): 59 | self._state_visit_counter = UniqueStatesVisitsCounter(states_total) 60 | 61 | def observe_first(self, *args, **kwargs) -> None: 62 | env, timestamp, actor_extras = args 63 | 64 | states_total = mask_color_on_rgb( 65 | timestamp.observation.observation, 66 | color=MAZE_PATH_COLOR 67 | ).sum() + 1 # +1 for current position of agent 68 | # TODO: account for states with targets if needed; currently support only 0 targets case 69 | self.reset(states_total) 70 | 71 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 72 | action: np.ndarray, actor_extras, **kwargs) -> None: 73 | self._state_visit_counter.add( 74 | timestep.observation.observation, 75 | actor_extras['intrinsic_reward'] 76 | ) 77 | 78 | def get_metrics(self, **kwargs): 79 | metrics = { 80 | "unique_fraction": self._state_visit_counter.get_fraction_of_visited(), 81 | "first_visit_mean_reward": self._state_visit_counter.get_mean_first_visit_reward(), 82 | "repeated_visit_mean_reward": self._state_visit_counter.get_mean_repeated_visit_reward() 83 | } 84 | return metrics 85 | -------------------------------------------------------------------------------- /drlearner/core/observers/distillation_coef.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import numpy as np 3 | 4 | 5 | class DistillationCoefObserver: 6 | def __init__(self): 7 | self._alphas = None 8 | 9 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep, actor_extras, **kwargs) -> None: 10 | self._alphas = [] 11 | 12 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 13 | action: np.ndarray, actor_extras, **kwargs) -> None: 14 | self._alphas.append(float(actor_extras['alpha'])) 15 | 16 | 17 | def get_metrics(self, **kwargs): 18 | return {'Mean Distillation Alpha': np.mean(self._alphas)} 
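
# --- Hedged addition (illustrative example, not part of the original repository file) ---
# The observers above all follow the informal protocol expected by
# drlearner.core.environment_loop.EnvironmentLoop: the loop calls
#   observe_first(env, timestep, actor_extras, episode=..., step=...)   after env.reset(),
#   observe(env, timestep, action, actor_extras, episode=..., step=...) after every step,
# and merges the dict returned by get_metrics(timestep=..., episode=...) into the
# logged episode results. The class name and metric key below are made up for the sketch.

import dm_env
import numpy as np


class EpisodeRewardStdObserver:
    """Illustrative observer: reports the standard deviation of extrinsic rewards per episode."""

    def __init__(self):
        self._rewards = None

    def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
                      actor_extras, **kwargs) -> None:
        # Reset the per-episode buffer on the first timestep of each episode.
        self._rewards = []

    def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
                action: np.ndarray, actor_extras, **kwargs) -> None:
        # Every environment step provides the resulting timestep and the action taken.
        self._rewards.append(float(timestep.reward))

    def get_metrics(self, **kwargs):
        # The returned keys are added to the episode's logging data by the loop.
        return {'episode_reward_std': float(np.std(self._rewards))}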
-------------------------------------------------------------------------------- /drlearner/core/observers/intrinsic_reward.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import numpy as np 3 | 4 | 5 | class IntrinsicRewardObserver: 6 | def __init__(self): 7 | self._intrinsic_rewards = None 8 | 9 | def observe_first(self, *args, **kwargs) -> None: 10 | env, timestep, actor_extras = args 11 | 12 | self._intrinsic_rewards = [] 13 | self._intrinsic_rewards.append(float(actor_extras['intrinsic_reward'])) 14 | 15 | def observe(self, *args, **kwargs) -> None: 16 | env, timestep, action, actor_extras = args 17 | self._intrinsic_rewards.append(float(actor_extras['intrinsic_reward'])) 18 | 19 | def get_metrics(self, **kwargs): 20 | return { 21 | "intrinsic_rewards_sum": np.sum(self._intrinsic_rewards), 22 | "intrinsic_rewards_mean": np.mean(self._intrinsic_rewards) 23 | } 24 | -------------------------------------------------------------------------------- /drlearner/core/observers/lazy_dict.py: -------------------------------------------------------------------------------- 1 | from collections.abc import MutableMapping 2 | from threading import RLock 3 | from inspect import getfullargspec 4 | from copy import copy 5 | 6 | class LazyDictionaryError(Exception): 7 | pass 8 | 9 | class CircularReferenceError(LazyDictionaryError): 10 | pass 11 | 12 | class ConstantRedefinitionError(LazyDictionaryError): 13 | pass 14 | 15 | class LazyDictionary(MutableMapping): 16 | def __init__(self, values={ }): 17 | self.lock = RLock() 18 | self.values = copy(values) 19 | self.states = {} 20 | for key in self.values: 21 | self.states[key] = 'defined' 22 | 23 | def __len__(self): 24 | return len(self.values) 25 | 26 | def __iter__(self): 27 | return iter(self.values) 28 | 29 | def __getitem__(self, key): 30 | with self.lock: 31 | if key in self.states: 32 | if self.states[key] == 'evaluating': 33 | raise CircularReferenceError('value of "%s" depends on itself' % key) 34 | elif self.states[key] == 'error': 35 | raise self.values[key] 36 | elif self.states[key] == 'defined': 37 | value = self.values[key] 38 | if callable(value): 39 | args= getfullargspec(value).args 40 | if len(args) == 0: 41 | self.states[key] = 'evaluating' 42 | try: 43 | self.values[key] = value() 44 | except Exception as ex: 45 | self.values[key] = ex 46 | self.states[key] = 'error' 47 | raise ex 48 | elif len(args) == 1: 49 | self.states[key] = 'evaluating' 50 | try: 51 | self.values[key] = value(self) 52 | except Exception as ex: 53 | self.values[key] = ex 54 | self.states[key] = 'error' 55 | raise ex 56 | self.states[key] = 'evaluated' 57 | return self.values[key] 58 | 59 | def __contains__(self, key): 60 | return key in self.values 61 | 62 | def __setitem__(self, key, value): 63 | with self.lock: 64 | if key in self.states and self.states[key][0:4] == 'eval': 65 | raise ConstantRedefinitionError('"%s" is immutable' % key) 66 | self.values[key] = value 67 | self.states[key] = 'defined' 68 | 69 | def __delitem__(self, key): 70 | with self.lock: 71 | if key in self.states and self.states[key][0:4] == 'eval': 72 | raise ConstantRedefinitionError('"%s" is immutable' % key) 73 | del self.values[key] 74 | del self.states[key] 75 | 76 | def __str__(self): 77 | return str(self.values) 78 | 79 | def __repr__(self): 80 | return "LazyDictionary({0})".format(repr(self.values)) 81 | -------------------------------------------------------------------------------- 
/drlearner/core/observers/meta_controller.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import numpy as np 3 | 4 | class MetaControllerObserver: 5 | def __init__(self): 6 | self._mixture_indices = None 7 | self._is_eval = None 8 | 9 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep, actor_extras, **kwargs) -> None: 10 | self._mixture_indices = int(actor_extras['mixture_idx']) 11 | self._is_eval = int(actor_extras['is_eval']) 12 | 13 | 14 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, 15 | action: np.ndarray, actor_extras, **kwargs) -> None: 16 | pass 17 | 18 | def get_metrics(self, **kwargs): 19 | return { 20 | 'mixture_idx': self._mixture_indices, 21 | 'is_eval': self._is_eval 22 | } -------------------------------------------------------------------------------- /drlearner/core/observers/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | import platform 4 | 5 | from skvideo import io 6 | import numpy as np 7 | import tensorflow as tf 8 | import dm_env 9 | 10 | from drlearner.core.loggers import ImageLogger 11 | from drlearner.core.observers.lazy_dict import LazyDictionary 12 | 13 | 14 | class VideoObserver(ABC): 15 | def __init__(self, config): 16 | self.config = config 17 | self.env_library = config.env_library 18 | self.log_period = config.video_log_period 19 | 20 | self.platform = platform.system().lower() 21 | 22 | def render(self, env): 23 | """ 24 | Renders current frame 25 | """ 26 | render_funcs = LazyDictionary( 27 | { 28 | 'dm_control': lambda: env.physics.render(camera_id=0), 29 | 'gym': lambda: env.environment.render(mode='rgb_array'), 30 | 'discomaze': lambda: env.render(mode='state_pixels'), 31 | }, 32 | ) 33 | env_lib = self.env_library 34 | 35 | if env_lib in render_funcs.keys(): 36 | return render_funcs[env_lib] 37 | else: 38 | raise ValueError( 39 | f"Unknown environment library: {env_lib}; choose among {list(render_funcs.keys())}", 40 | ) 41 | 42 | def _log_video(self, episode_count): 43 | return True if (episode_count + 1) % self.log_period == 0 else False 44 | 45 | @abstractmethod 46 | def observe(self, env: dm_env.Environment, *args, **kwargs): 47 | pass 48 | 49 | 50 | class StorageVideoObserver(VideoObserver): 51 | def __init__(self, config): 52 | print(f'INIT: {self.__class__.__name__}') 53 | super().__init__(config) 54 | 55 | self.frames = list() 56 | self.videos_dir = self._create_videos_dir() 57 | 58 | def observe(self, env: dm_env.Environment, *args, **kwargs): 59 | frame = self.render(env) 60 | self.frames.append(frame.astype('uint8')) 61 | 62 | def get_metrics(self, **kwargs): 63 | episode = kwargs['episode'] 64 | 65 | if self._log_video(episode): 66 | video_dir = os.path.join(self.videos_dir, f'episode_{episode + 1}.mp4') 67 | io.vwrite(video_dir, np.array(self.frames)) 68 | 69 | self.frames = list() 70 | 71 | return dict() 72 | 73 | def _create_videos_dir(self): 74 | video_dir = os.path.join(self.config.logs_dir, 'episodes') 75 | 76 | if not os.path.exists(video_dir): 77 | os.makedirs(video_dir, exist_ok=True) 78 | 79 | return video_dir 80 | 81 | observe_first = observe 82 | 83 | 84 | class TBVideoObserver(VideoObserver): 85 | def __init__(self, config): 86 | super().__init__(config) 87 | self.image_logger = ImageLogger(config.logs_dir) 88 | 89 | def log_frame(self, env: dm_env.Environment, episode=None, step=None): 90 | frame = self.render(env) 91 
| 92 | self.image_logger.write_image( 93 | f'video_{episode + 1}', 94 | tf.convert_to_tensor(np.array(frame)), 95 | step=step, 96 | ) 97 | 98 | def observe(self, env: dm_env.Environment, *args, **kwargs) -> None: 99 | episode = kwargs['episode'] 100 | step = kwargs['step'] 101 | 102 | if self._log_video(episode): 103 | self.log_frame(env, episode=episode, step=step) 104 | 105 | observe_first = observe -------------------------------------------------------------------------------- /drlearner/drlearner/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import DRLearner 2 | from .builder import DRLearnerBuilder 3 | from .config import DRLearnerConfig 4 | from .distributed_agent import DistributedDRLearnerFromConfig 5 | from .learning import DRLearnerLearner 6 | from .networks import DRLearnerNetworks 7 | from .networks import make_policy_networks 8 | from .networks import networks_zoo 9 | -------------------------------------------------------------------------------- /drlearner/drlearner/actor.py: -------------------------------------------------------------------------------- 1 | """DRLearner JAX actors.""" 2 | from typing import Optional 3 | 4 | import dm_env 5 | import jax 6 | import jax.numpy as jnp 7 | from acme import adders 8 | from acme import types 9 | from acme.agents.jax import actors 10 | from acme.jax import networks as network_lib 11 | from acme.jax import utils 12 | from acme.jax import variable_utils 13 | 14 | from .actor_core import DRLearnerActorCore 15 | 16 | 17 | class DRLearnerActor(actors.GenericActor): 18 | """A generic actor implemented on top of ActorCore. 19 | 20 | An actor based on a policy which takes observations and outputs actions. It 21 | also adds experiences to replay and updates the actor weights from the policy 22 | on the learner. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | actor: DRLearnerActorCore, 28 | mixture_idx: int, 29 | random_key: network_lib.PRNGKey, 30 | variable_client: Optional[variable_utils.VariableClient], 31 | adder: Optional[adders.Adder] = None, 32 | jit: bool = True, 33 | backend: Optional[str] = 'cpu', 34 | per_episode_update: bool = False 35 | ): 36 | """Initializes a feed forward actor. 37 | 38 | Args: 39 | actor: actor core. 40 | random_key: Random key. 41 | variable_client: The variable client to get policy parameters from. 42 | adder: An adder to add experiences to. 43 | jit: Whether to jit the passed ActorCore's pure functions. 44 | backend: Which backend to use when jitting the policy. 
45 | per_episode_update: if True, updates variable client params once at the 46 | beginning of each episode 47 | """ 48 | super(DRLearnerActor, self).__init__(actor, random_key, variable_client, adder, jit, backend, per_episode_update) 49 | if jit: 50 | self._observe = jax.jit(actor.observe) 51 | # self._observe_first = jax.jit(actor.observe_first) 52 | else: 53 | self._observe = actor.observe 54 | self._observe_first = actor.observe_first 55 | 56 | self._mixture_idx = jnp.array(mixture_idx, dtype=jnp.int32) 57 | 58 | def select_action(self, 59 | observation: network_lib.Observation) -> types.NestedArray: 60 | action, self._state = self._policy(self._params, observation, self._state) 61 | return utils.to_numpy(action) 62 | 63 | def observe_first(self, timestep: dm_env.TimeStep): 64 | self._random_key, key = jax.random.split(self._random_key) 65 | self._state = self._init(key, self._mixture_idx, self._state) 66 | self._state = self._observe_first(self._params, timestep, self._state) 67 | if self._adder: 68 | self._adder.add_first(timestep) 69 | if self._variable_client and self._per_episode_update: 70 | self._variable_client.update_and_wait() 71 | 72 | def observe(self, action: network_lib.Action, next_timestep: dm_env.TimeStep): 73 | self._state = self._observe(self._params, action, next_timestep, self._state) 74 | super(DRLearnerActor, self).observe(action, next_timestep) 75 | 76 | def get_extras(self): 77 | return self._get_extras(self._state) -------------------------------------------------------------------------------- /drlearner/drlearner/agent.py: -------------------------------------------------------------------------------- 1 | """Defines local DRLearner agent, using JAX.""" 2 | 3 | from typing import Optional 4 | 5 | from acme import specs 6 | from acme.utils import counting 7 | 8 | from ..core import local_layout 9 | from .builder import DRLearnerBuilder 10 | from .config import DRLearnerConfig 11 | from .networks import make_policy_networks, DRLearnerNetworks 12 | 13 | 14 | class DRLearner(local_layout.LocalLayout): 15 | """Local agent for DRLearner. 16 | 17 | This implements a single-process DRLearner agent. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | spec: specs.EnvironmentSpec, 23 | networks: DRLearnerNetworks, 24 | config: DRLearnerConfig, 25 | seed: int, 26 | workdir: Optional[str] = '~/acme', 27 | counter: Optional[counting.Counter] = None, 28 | logger=None 29 | ): 30 | ngu_builder = DRLearnerBuilder(networks, config, num_actors_per_mixture=1,logger=logger) 31 | super().__init__( 32 | seed=seed, 33 | environment_spec=spec, 34 | builder=ngu_builder, 35 | networks=networks, 36 | policy_network=make_policy_networks(networks, config), 37 | workdir=workdir, 38 | min_replay_size=config.min_replay_size, 39 | samples_per_insert=config.samples_per_insert if config.samples_per_insert \ 40 | else 10 / (config.burn_in_length + config.trace_length), 41 | batch_size=config.batch_size, 42 | num_sgd_steps_per_step=config.num_sgd_steps_per_step, 43 | counter=counter, 44 | ) 45 | 46 | def get_extras(self): 47 | return self._actor.get_extras() 48 | -------------------------------------------------------------------------------- /drlearner/drlearner/builder.py: -------------------------------------------------------------------------------- 1 | """DRLearner Builder.""" 2 | from typing import Callable, Iterator, List, Optional 3 | from copy import deepcopy 4 | import functools 5 | 6 | import acme 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | import reverb 11 | import tensorflow as tf 12 | from acme import adders 13 | from acme import core 14 | from acme import specs 15 | from acme.adders import reverb as adders_reverb 16 | from acme.agents.jax import builders 17 | from acme.datasets import reverb as datasets 18 | from acme.jax import networks as networks_lib 19 | from acme.jax import utils 20 | from acme.jax import variable_utils 21 | from acme.utils import counting 22 | from acme.utils import loggers 23 | 24 | from .config import DRLearnerConfig 25 | from .actor import DRLearnerActor 26 | from .actor_core import get_actor_core 27 | from .learning import DRLearnerLearner 28 | from .networks import DRLearnerNetworks 29 | 30 | # run CPU-only tensorflow for data loading 31 | tf.config.set_visible_devices([], "GPU") 32 | 33 | 34 | class DRLearnerBuilder(builders.ActorLearnerBuilder): 35 | """DRLearner Builder. 36 | 37 | """ 38 | 39 | def __init__(self, 40 | networks: DRLearnerNetworks, 41 | config: DRLearnerConfig, 42 | num_actors_per_mixture: int, 43 | logger: Callable[[], loggers.Logger] = lambda: None, ): 44 | """Creates DRLearner learner, a behavior policy and an eval actor. 45 | 46 | Args: 47 | networks: DRLearner networks, used to build core state spec. 48 | config: a config with DRLearner hps 49 | logger: a logger for the learner 50 | """ 51 | self._networks = networks 52 | self._config = config 53 | self._num_actors_per_mixture = num_actors_per_mixture 54 | self._logger_fn = logger 55 | 56 | # Sequence length for dataset iterator. 57 | self._sequence_length = ( 58 | self._config.burn_in_length + self._config.trace_length + 1) 59 | 60 | # Construct the core state spec. 
61 | dummy_key = jax.random.PRNGKey(0) 62 | intrinsic_initial_state_params = networks.uvfa_net.initial_state.init(dummy_key, 1) 63 | intrinsic_initial_state = networks.uvfa_net.initial_state.apply(intrinsic_initial_state_params, 64 | dummy_key, 1) 65 | extrinsic_initial_state_params = networks.uvfa_net.initial_state.init(dummy_key, 1) 66 | extrinsic_initial_state = networks.uvfa_net.initial_state.apply(extrinsic_initial_state_params, 67 | dummy_key, 1) 68 | intrinsic_core_state_spec = utils.squeeze_batch_dim(intrinsic_initial_state) 69 | extrinsic_core_state_spec = utils.squeeze_batch_dim(extrinsic_initial_state) 70 | self._extra_spec = { 71 | 'intrinsic_core_state': intrinsic_core_state_spec, 72 | 'extrinsic_core_state': extrinsic_core_state_spec 73 | } 74 | 75 | def evaluate_logger(self): 76 | if isinstance(self._logger_fn, functools.partial): 77 | self._logger_fn=self._logger_fn() 78 | 79 | def make_learner( 80 | self, 81 | random_key: networks_lib.PRNGKey, 82 | networks: DRLearnerNetworks, 83 | dataset: Iterator[reverb.ReplaySample], 84 | replay_client: Optional[reverb.Client] = None, 85 | counter: Optional[counting.Counter] = None, 86 | ) -> core.Learner: 87 | # The learner updates the parameters (and initializes them). 88 | self.evaluate_logger() 89 | return DRLearnerLearner( 90 | uvfa_unroll=networks.uvfa_net.unroll, 91 | uvfa_initial_state=networks.uvfa_net.initial_state, 92 | idm_action_pred=networks.embedding_net.predict_action, 93 | distillation_embed=networks.distillation_net.embed_sequence, 94 | batch_size=self._config.batch_size, 95 | random_key=random_key, 96 | burn_in_length=self._config.burn_in_length, 97 | beta_min=self._config.beta_min, 98 | beta_max=self._config.beta_max, 99 | gamma_min=self._config.gamma_min, 100 | gamma_max=self._config.gamma_max, 101 | num_mixtures=self._config.num_mixtures, 102 | target_epsilon=self._config.target_epsilon, 103 | importance_sampling_exponent=( 104 | self._config.importance_sampling_exponent), 105 | max_priority_weight=self._config.max_priority_weight, 106 | target_update_period=self._config.target_update_period, 107 | iterator=dataset, 108 | uvfa_optimizer=optax.adam(self._config.uvfa_learning_rate), 109 | idm_optimizer=optax.adamw(self._config.idm_learning_rate, 110 | weight_decay=self._config.idm_weight_decay), 111 | distillation_optimizer=optax.adamw(self._config.distillation_learning_rate, 112 | weight_decay=self._config.distillation_weight_decay), 113 | idm_clip_steps=self._config.idm_clip_steps, 114 | distillation_clip_steps=self._config.distillation_clip_steps, 115 | retrace_lambda=self._config.retrace_lambda, 116 | tx_pair=self._config.tx_pair, 117 | clip_rewards=self._config.clip_rewards, 118 | max_abs_reward=self._config.max_absolute_reward, 119 | replay_client=replay_client, 120 | counter=counter, 121 | logger=self._logger_fn) 122 | 123 | def make_replay_tables( 124 | self, 125 | environment_spec: specs.EnvironmentSpec, 126 | ) -> List[reverb.Table]: 127 | """Create tables to insert data into.""" 128 | if self._config.samples_per_insert: 129 | samples_per_insert_tolerance = ( 130 | self._config.samples_per_insert_tolerance_rate * 131 | self._config.samples_per_insert) 132 | error_buffer = self._config.min_replay_size * samples_per_insert_tolerance 133 | limiter = reverb.rate_limiters.SampleToInsertRatio( 134 | min_size_to_sample=self._config.min_replay_size, 135 | samples_per_insert=self._config.samples_per_insert, 136 | error_buffer=error_buffer) 137 | else: 138 | limiter = reverb.rate_limiters.MinSize(1) 139 | 140 | # 
add intrinsic rewards and mixture_idx (intrinsic reward beta) to extra_specs 141 | self._extra_spec['intrinsic_reward'] = specs.Array( 142 | shape=environment_spec.rewards.shape, 143 | dtype=jnp.float32, 144 | name='intrinsic_reward' 145 | ) 146 | self._extra_spec['mixture_idx'] = specs.Array( 147 | shape=environment_spec.rewards.shape, 148 | dtype=jnp.int32, 149 | name='mixture_idx' 150 | ) 151 | # add probability of action under behavior policy 152 | self._extra_spec['behavior_action_prob'] = specs.Array( 153 | shape=environment_spec.rewards.shape, 154 | dtype=jnp.float32, 155 | name='behavior_action_prob' 156 | ) 157 | 158 | # add the mode of evaluator 159 | self._extra_spec['is_eval'] = specs.Array( 160 | shape=environment_spec.rewards.shape, 161 | dtype=jnp.int32, 162 | name='is_eval' 163 | ) 164 | 165 | self._extra_spec['alpha'] = specs.Array( 166 | shape=environment_spec.rewards.shape, 167 | dtype=jnp.float32, 168 | name='alpha' 169 | ) 170 | 171 | 172 | return [ 173 | reverb.Table( 174 | name=self._config.replay_table_name, 175 | sampler=reverb.selectors.Prioritized( 176 | self._config.priority_exponent), 177 | remover=reverb.selectors.Fifo(), 178 | max_size=self._config.max_replay_size, 179 | rate_limiter=limiter, 180 | signature=adders_reverb.SequenceAdder.signature( 181 | environment_spec, self._extra_spec)) 182 | ] 183 | 184 | def make_dataset_iterator( 185 | self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: 186 | """Create a dataset iterator to use for learning/updating the agent.""" 187 | dataset = datasets.make_reverb_dataset( 188 | table=self._config.replay_table_name, 189 | server_address=replay_client.server_address, 190 | batch_size=self._config.batch_size, 191 | prefetch_size=self._config.prefetch_size, 192 | num_parallel_calls=self._config.num_parallel_calls) 193 | return dataset.as_numpy_iterator() 194 | 195 | def make_adder(self, 196 | replay_client: reverb.Client) -> Optional[adders.Adder]: 197 | """Create an adder which records data generated by the actor/environment.""" 198 | return adders_reverb.SequenceAdder( 199 | client=replay_client, 200 | period=self._config.sequence_period, 201 | sequence_length=self._sequence_length, 202 | delta_encoded=True) 203 | 204 | def make_actor( 205 | self, 206 | random_key: networks_lib.PRNGKey, 207 | policy_networks, 208 | adder: Optional[adders.Adder] = None, 209 | variable_source: Optional[core.VariableSource] = None, 210 | actor_id: int = 0, 211 | is_evaluator: bool = False, 212 | ) -> acme.Actor: 213 | 214 | # Create variable client. 
215 | variable_client = variable_utils.VariableClient( 216 | variable_source, 217 | key='actor_variables', 218 | update_period=self._config.variable_update_period) 219 | variable_client.update_and_wait() 220 | 221 | intrinsic_initial_state_key1, intrinsic_initial_state_key2, \ 222 | extrinsic_initial_state_key1, extrinsic_initial_state_key2, random_key = jax.random.split(random_key, 5) 223 | intrinsic_actor_initial_state_params = self._networks.uvfa_net.initial_state.init( 224 | intrinsic_initial_state_key1, 1) 225 | intrinsic_actor_initial_state = self._networks.uvfa_net.initial_state.apply( 226 | intrinsic_actor_initial_state_params, intrinsic_initial_state_key2, 1) 227 | extrinsic_actor_initial_state_params = self._networks.uvfa_net.initial_state.init( 228 | extrinsic_initial_state_key1, 1) 229 | extrinsic_actor_initial_state = self._networks.uvfa_net.initial_state.apply( 230 | extrinsic_actor_initial_state_params, extrinsic_initial_state_key2, 1) 231 | 232 | config = deepcopy(self._config) 233 | if is_evaluator: 234 | config.window = self._config.evaluation_window 235 | config.epsilon = self._config.evaluation_epsilon 236 | config.mc_epsilon = self._config.evaluation_mc_epsilon 237 | else: 238 | config.window = self._config.actor_window 239 | config.epsilon = self._config.actor_epsilon 240 | config.mc_epsilon = self._config.actor_mc_epsilon 241 | 242 | 243 | actor_core = get_actor_core(policy_networks, 244 | intrinsic_actor_initial_state, 245 | extrinsic_actor_initial_state, 246 | actor_id, 247 | self._num_actors_per_mixture, 248 | config, 249 | jit=True) 250 | 251 | mixture_idx = actor_id // self._num_actors_per_mixture 252 | 253 | 254 | return DRLearnerActor( 255 | actor_core, mixture_idx, random_key, variable_client, adder, backend='cpu', jit=True) 256 | -------------------------------------------------------------------------------- /drlearner/drlearner/config.py: -------------------------------------------------------------------------------- 1 | """DRLearner config.""" 2 | import dataclasses 3 | 4 | import rlax 5 | from acme.adders import reverb as adders_reverb 6 | 7 | 8 | @dataclasses.dataclass 9 | class DRLearnerConfig: 10 | """Configuration options for DRLearner agent.""" 11 | gamma_min: float = 0.99 12 | gamma_max: float = 0.997 13 | num_mixtures: int = 32 14 | target_update_period: int = 400 15 | evaluation_epsilon: float = 0.01 16 | epsilon: float = 0.01 17 | actor_epsilon: float = 0.01 18 | target_epsilon: float = 0.01 19 | variable_update_period: int = 400 20 | 21 | # Learner options 22 | retrace_lambda: float = 0.95 23 | burn_in_length: int = 40 24 | trace_length: int = 80 25 | sequence_period: int = 40 26 | num_sgd_steps_per_step: int = 1 27 | uvfa_learning_rate: float = 1e-4 28 | idm_learning_rate: float = 5e-4 29 | distillation_learning_rate: float = 5e-4 30 | idm_weight_decay: float = 1e-5 31 | distillation_weight_decay: float = 1e-5 32 | idm_clip_steps: int = 5 33 | distillation_clip_steps: int = 5 34 | clip_rewards: bool = False 35 | max_absolute_reward: float = 1.0 36 | tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR 37 | distillation_moving_average_coef: float = 1e-3 38 | 39 | # Intrinsic reward multipliers 40 | beta_min: float = 0. 41 | beta_max: float = 0.3 42 | 43 | # Embedding network options 44 | observation_embed_dim: int = 128 45 | episodic_memory_num_neighbors: int = 10 46 | episodic_memory_max_size: int = 30_000 47 | episodic_memory_max_similarity: float = 8. 
48 | episodic_memory_cluster_distance: float = 8e-3 49 | episodic_memory_pseudo_counts: float = 1e-3 50 | episodic_memory_epsilon: float = 1e-4 51 | 52 | # Distillation network 53 | distillation_embed_dim: int = 128 54 | max_lifelong_modulation: float = 5.0 55 | 56 | # Replay options 57 | samples_per_insert_tolerance_rate: float = 0.1 58 | samples_per_insert: float = 4.0 59 | min_replay_size: int = 50_000 60 | max_replay_size: int = 100_000 61 | batch_size: int = 64 62 | prefetch_size: int = 2 63 | num_parallel_calls: int = 16 64 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE 65 | 66 | # Priority options 67 | importance_sampling_exponent: float = 0.6 68 | priority_exponent: float = 0.9 69 | max_priority_weight: float = 0.9 70 | 71 | # Meta Controller options 72 | window: int = 160 73 | actor_window: int = 160 74 | evaluation_window: int = 3600 75 | n_arms: int = 32 76 | mc_epsilon: float = 0.5 # Value is set from actor_mc_epsilon or evaluation_mc_epsilon depending on whether the actor acts as evaluator 77 | actor_mc_epsilon: float = 0.5 78 | evaluation_mc_epsilon: float = 0.01 79 | mc_beta: float = 1. 80 | 81 | # Agent's video logging options 82 | env_library: str = None 83 | video_log_period: int = 10 84 | actions_log_period: int = 1 85 | logs_dir: str = 'experiments/default' 86 | num_episodes: int = 50 87 | 88 | -------------------------------------------------------------------------------- /drlearner/drlearner/distributed_agent.py: -------------------------------------------------------------------------------- 1 | """Defines distributed DRLearner agent, using JAX.""" 2 | 3 | import functools 4 | from typing import Callable, Optional, Sequence 5 | 6 | import dm_env 7 | from acme import specs 8 | from acme.jax import utils 9 | from acme.utils import loggers 10 | 11 | from ..core import distributed_layout 12 | from .config import DRLearnerConfig 13 | from .builder import DRLearnerBuilder 14 | from .networks import DRLearnerNetworks, make_policy_networks 15 | 16 | NetworkFactory = Callable[[specs.EnvironmentSpec], DRLearnerNetworks] 17 | EnvironmentFactory = Callable[[int], dm_env.Environment] 18 | 19 | 20 | class DistributedDRLearnerFromConfig(distributed_layout.DistributedLayout): 21 | """Distributed DRLearner agents from config.""" 22 | 23 | def __init__( 24 | self, 25 | environment_factory: EnvironmentFactory, 26 | environment_spec: specs.EnvironmentSpec, 27 | network_factory: NetworkFactory, 28 | config: DRLearnerConfig, 29 | seed: int, 30 | num_actors_per_mixture: int, 31 | workdir: str = '~/acme', 32 | device_prefetch: bool = False, 33 | log_to_bigtable: bool = True, 34 | log_every: float = 10.0, 35 | # TODO: Refactor: `max_episodes` and `max_steps` should be defined on the experiment level, 36 | # not on the agent level, similarly to other experiment-related abstractions 37 | max_episodes: Optional[int] = None, 38 | max_steps: Optional[int] = None, 39 | evaluator_factories: Optional[Sequence[ 40 | distributed_layout.EvaluatorFactory]] = None, 41 | actor_observers=(), 42 | evaluator_observers=(), 43 | learner_logger_fn: Optional[Callable[[], loggers.Logger]] = None, 44 | multithreading_colocate_learner_and_reverb: bool = False 45 | ): 46 | learner_logger_fn = learner_logger_fn or functools.partial(loggers.make_default_logger, 47 | 'learner', log_to_bigtable, 48 | time_delta=log_every, asynchronous=True, 49 | serialize_fn=utils.fetch_devicearray, 50 | steps_key='learner_steps') 51 | drlearner_builder = DRLearnerBuilder( 52 |
networks=network_factory(environment_spec), 53 | config=config, 54 | num_actors_per_mixture=num_actors_per_mixture, 55 | logger=learner_logger_fn) 56 | policy_network_factory = ( 57 | lambda networks: make_policy_networks(networks, config)) 58 | if evaluator_factories is None: 59 | evaluator_policy_network_factory = ( 60 | lambda networks: make_policy_networks(networks, config, evaluation=True)) 61 | evaluator_factories = [ 62 | distributed_layout.default_evaluator_factory( 63 | environment_factory=environment_factory, 64 | network_factory=network_factory, 65 | policy_factory=evaluator_policy_network_factory, 66 | log_to_bigtable=log_to_bigtable, 67 | observers=evaluator_observers 68 | ) 69 | ] 70 | super().__init__( 71 | seed=seed, 72 | environment_factory=environment_factory, 73 | network_factory=network_factory, 74 | builder=drlearner_builder, 75 | policy_network=policy_network_factory, 76 | evaluator_factories=evaluator_factories, 77 | num_actors=num_actors_per_mixture * config.num_mixtures, 78 | environment_spec=environment_spec, 79 | device_prefetch=device_prefetch, 80 | log_to_bigtable=log_to_bigtable, 81 | max_episodes=max_episodes, 82 | max_steps=max_steps, 83 | actor_logger_fn=distributed_layout.get_default_logger_fn( 84 | log_to_bigtable, log_every), 85 | prefetch_size=config.prefetch_size, 86 | checkpointing_config=distributed_layout.CheckpointingConfig( 87 | directory=workdir, add_uid=(workdir == '~/acme')), 88 | observers=actor_observers, 89 | multithreading_colocate_learner_and_reverb=multithreading_colocate_learner_and_reverb 90 | ) 91 | -------------------------------------------------------------------------------- /drlearner/drlearner/drlearner_types.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Optional 2 | 3 | import chex 4 | import jax.numpy as jnp 5 | import optax 6 | from acme.agents.jax import actor_core as actor_core_lib 7 | from acme.jax import networks as networks_lib 8 | from rlax._src.exploration import IntrinsicRewardState 9 | 10 | from .lifelong_curiosity import LifelongCuriosityModulationState 11 | 12 | 13 | @chex.dataclass(frozen=True, mappable_dataclass=False) 14 | class DRLearnerNetworksParams: 15 | """Collection of all parameters of Neural Networks used by DRLearner Agent""" 16 | intrinsic_uvfa_params: networks_lib.Params # intrinsic Universal Value-Function Approximator 17 | extrinsic_uvfa_params: networks_lib.Params # extrinsic Universal Value-Function Approximator 18 | intrinsic_uvfa_target_params: networks_lib.Params # intrinsic UVFA target network 19 | extrinsic_uvfa_target_params: networks_lib.Params # extrinsic UVFA target network 20 | idm_params: networks_lib.Params # Inverse Dynamics Model 21 | distillation_params: networks_lib.Params # Distillation Network 22 | distillation_random_params: networks_lib.Params # Random Distillation Network 23 | 24 | 25 | @chex.dataclass(frozen=True, mappable_dataclass=False) 26 | class DRLearnerNetworksOptStates: 27 | """Collection of optimizer states for all networks trained by the Learner""" 28 | intrinsic_uvfa_opt_state: optax.OptState 29 | extrinsic_uvfa_opt_state: optax.OptState 30 | idm_opt_state: optax.OptState 31 | distillation_opt_state: optax.OptState 32 | 33 | 34 | class TrainingState(NamedTuple): 35 | """DRLearner Learner training state""" 36 | params: DRLearnerNetworksParams 37 | opt_state: DRLearnerNetworksOptStates 38 | steps: jnp.ndarray 39 | random_key: networks_lib.PRNGKey 40 | 41 | 42 | @chex.dataclass(frozen=True, 
mappable_dataclass=False) 43 | class MetaControllerState: 44 | episode_returns_history: jnp.ndarray 45 | episode_count: jnp.ndarray 46 | current_episode_return: jnp.ndarray 47 | mixture_idx_history: jnp.ndarray 48 | beta: jnp.ndarray 49 | gamma: jnp.ndarray 50 | is_eval: bool 51 | num_eval_episodes: jnp.ndarray 52 | 53 | 54 | @chex.dataclass(frozen=True, mappable_dataclass=False) 55 | class DRLearnerActorState: 56 | rng: networks_lib.PRNGKey 57 | epsilon: jnp.ndarray 58 | mixture_idx: jnp.ndarray 59 | intrinsic_recurrent_state: actor_core_lib.RecurrentState 60 | extrinsic_recurrent_state: actor_core_lib.RecurrentState 61 | prev_intrinsic_reward: jnp.ndarray 62 | prev_action_prob: jnp.ndarray 63 | prev_alpha: jnp.ndarray 64 | meta_controller_state: MetaControllerState 65 | lifelong_modulation_state: Optional[LifelongCuriosityModulationState] = None 66 | intrinsic_reward_state: Optional[IntrinsicRewardState] = None 67 | 68 | -------------------------------------------------------------------------------- /drlearner/drlearner/lifelong_curiosity.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import chex 4 | import jax.numpy as jnp 5 | 6 | 7 | @chex.dataclass 8 | class LifelongCuriosityModulationState: 9 | distance_mean: chex.Scalar = 0. 10 | distance_var: chex.Scalar = 1. 11 | 12 | 13 | def lifelong_curiosity_modulation( 14 | learnt_embeddings: chex.Array, 15 | random_embeddings: chex.Array, 16 | max_modulation: float = 5.0, 17 | lifelong_modulation_state: Optional[LifelongCuriosityModulationState] = None, 18 | ma_coef: float = 0.0001): 19 | if not lifelong_modulation_state: 20 | lifelong_modulation_state = LifelongCuriosityModulationState() 21 | 22 | error = jnp.sum((learnt_embeddings - random_embeddings) ** 2, axis=-1) 23 | 24 | distance_mean = lifelong_modulation_state.distance_mean 25 | distance_var = lifelong_modulation_state.distance_var 26 | # exponentially weighted moving average and std 27 | distance_var = (1 - ma_coef) * (distance_var + ma_coef * jnp.mean(error - distance_mean) ** 2) 28 | distance_mean = ma_coef * jnp.mean(error) + (1 - ma_coef) * distance_mean 29 | 30 | alpha = 1 + (error - distance_mean) / jnp.sqrt(distance_var) 31 | alpha = jnp.clip(alpha, 1., max_modulation) 32 | 33 | lifelong_modulation_state = LifelongCuriosityModulationState( 34 | distance_mean=distance_mean, 35 | distance_var=distance_var 36 | ) 37 | 38 | return alpha, lifelong_modulation_state 39 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .networks import DRLearnerNetworks 2 | from .policy_networks import DRLearnerPolicyNetworks 3 | from .policy_networks import make_policy_networks 4 | from .uvfa_network import UVFANetworkInput 5 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/distillation_network.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import haiku as hk 4 | import jax 5 | from acme.jax import networks as networks_lib 6 | from acme.jax import utils 7 | from acme.wrappers.observation_action_reward import OAR 8 | 9 | 10 | @dataclasses.dataclass 11 | class DistillationNetwork: 12 | """Pure functions for DRLearner distillation network""" 13 | embed: networks_lib.FeedForwardNetwork 14 | embed_sequence: 
networks_lib.FeedForwardNetwork 15 | 16 | 17 | def make_distillation_net( 18 | make_distillation_modules, 19 | env_spec) -> DistillationNetwork: 20 | def embed_fn(observation: OAR) -> networks_lib.NetworkOutput: 21 | """ 22 | Embed batch of observations 23 | Args: 24 | observation: jnp.array representing a batch of observations [B, ...] 25 | 26 | Returns: 27 | embedding vectors [B, D] 28 | """ 29 | embedding_torso = make_distillation_modules() 30 | return embedding_torso(observation.observation) 31 | 32 | # transform functions 33 | embed_hk = hk.transform(embed_fn) 34 | 35 | # create dummy batches for networks initialization 36 | observation_batch = utils.add_batch_dim( 37 | utils.zeros_like(env_spec.observations) 38 | ) # [B=1, ...] 39 | 40 | def embed_init(rng): 41 | return embed_hk.init(rng, observation_batch) 42 | 43 | embed = networks_lib.FeedForwardNetwork( 44 | init=embed_init, apply=embed_hk.apply 45 | ) 46 | embed_sequence = networks_lib.FeedForwardNetwork( 47 | init=embed_init, 48 | # vmap over 1-st parameter: apply(params, random_key, data, ...) 49 | apply=jax.vmap(embed_hk.apply, in_axes=(None, None, 0), out_axes=0) 50 | ) 51 | 52 | return DistillationNetwork(embed, embed_sequence) 53 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/embedding_network.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import haiku as hk 4 | import jax.numpy as jnp 5 | from acme.jax import networks as networks_lib 6 | from acme.jax import utils 7 | from acme.wrappers.observation_action_reward import OAR 8 | 9 | 10 | @dataclasses.dataclass 11 | class EmbeddingNetwork: 12 | """Pure functions for DRLearner embedding network""" 13 | predict_action: networks_lib.FeedForwardNetwork 14 | embed: networks_lib.FeedForwardNetwork 15 | 16 | 17 | def make_embedding_net( 18 | make_embedding_modules, 19 | env_spec) -> EmbeddingNetwork: 20 | def embed_fn(observation: OAR) -> networks_lib.NetworkOutput: 21 | """ 22 | Embed batch of observations 23 | Args: 24 | observation: jnp.array representing a batch of observations [B, ...] 25 | 26 | Returns: 27 | embedding vectors [B, D] 28 | """ 29 | embedding_torso, _ = make_embedding_modules() 30 | return embedding_torso(observation.observation) 31 | 32 | def predict_action_fn(observation_tm1: OAR, observation_t: OAR) -> networks_lib.NetworkOutput: 33 | """ 34 | Embed a batch of sequences of two consecutive observations x_{t-1} and x_{t} and predict the batch of actions a_{t-1} 35 | Args: 36 | observation_tm1: observation x_{t-1} [T, B, ...] 37 | observation_t: observation x_{t} [T, B, ...] 38 | 39 | Returns: 40 | prediction logits for discrete action a_{t-1} [T, B, N] 41 | """ 42 | embedding_torso, pred_head = make_embedding_modules() 43 | emb_tm1 = hk.BatchApply(embedding_torso)(observation_tm1.observation) 44 | emb_t = hk.BatchApply(embedding_torso)(observation_t.observation) 45 | return hk.BatchApply(pred_head)(jnp.concatenate([emb_tm1, emb_t], axis=-1)) 46 | 47 | # transform functions 48 | embed_hk = hk.transform(embed_fn) 49 | predict_action_hk = hk.transform(predict_action_fn) 50 | 51 | # create dummy batches for networks initialization 52 | observation_sequences = utils.add_batch_dim( 53 | utils.add_batch_dim( 54 | utils.zeros_like(env_spec.observations) 55 | ) 56 | ) # [T=1, B=1, ...]
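    # (Editor's note: a minimal usage sketch of the pure functions assembled
    # below; `nets`, `key`, `obs_batch`, `obs_tm1` and `obs_t` are hypothetical
    # names that do not exist in this file.)
    #   nets = make_embedding_net(make_embedding_modules, env_spec)
    #   embed_params = nets.embed.init(key, obs_batch)                # obs_batch: OAR, [B, ...]
    #   embeddings = nets.embed.apply(embed_params, key, obs_batch)   # [B, D], e.g. D = observation_embed_dim
    #   idm_params = nets.predict_action.init(key)                    # init uses the dummy [T=1, B=1] batch above
    #   logits = nets.predict_action.apply(idm_params, key, obs_tm1, obs_t)  # [T, B, num_actions]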
57 | 58 | def predict_action_init(rng): 59 | return predict_action_hk.init(rng, observation_sequences, observation_sequences) 60 | 61 | # create FeedForwardNetworks corresponding to embed and action prediction functions 62 | predict_action = networks_lib.FeedForwardNetwork( 63 | init=predict_action_init, apply=predict_action_hk.apply 64 | ) 65 | embed = networks_lib.FeedForwardNetwork( 66 | init=embed_hk.init, apply=embed_hk.apply 67 | ) 68 | 69 | return EmbeddingNetwork(predict_action, embed) 70 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/networks.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from .distillation_network import DistillationNetwork 4 | from .embedding_network import EmbeddingNetwork 5 | from .uvfa_network import UVFANetwork 6 | 7 | 8 | @dataclasses.dataclass 9 | class DRLearnerNetworks: 10 | """Wrapper for all DRLearner learnable networks""" 11 | uvfa_net: UVFANetwork 12 | embedding_net: EmbeddingNetwork 13 | distillation_net: DistillationNetwork 14 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/networks_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .atari import make_atari_nets 2 | from .discomaze import make_discomaze_nets 3 | from .lunar_lander import make_lunar_lander_nets 4 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/networks_zoo/atari.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax.nn 3 | from acme import specs 4 | from acme.jax import networks as networks_lib 5 | 6 | from ..distillation_network import make_distillation_net 7 | from ..embedding_network import make_embedding_net 8 | from ..networks import DRLearnerNetworks 9 | from ..uvfa_network import make_uvfa_net 10 | from ..uvfa_torso import UVFATorso 11 | from ...config import DRLearnerConfig 12 | 13 | 14 | def make_atari_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks: 15 | uvfa_net = make_atari_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size) 16 | embedding_network = make_atari_embedding_net(env_spec, config.observation_embed_dim) 17 | distillation_network = make_atari_distillation_net(env_spec, config.distillation_embed_dim) 18 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network) 19 | 20 | 21 | def make_atari_embedding_net(env_spec, embedding_dim): 22 | def make_atari_embedding_modules(): 23 | embedding_torso = hk.Sequential([ 24 | networks_lib.AtariTorso(), 25 | hk.Linear(output_size=embedding_dim), 26 | jax.nn.relu, 27 | ], name='atari_embedding_torso') 28 | pred_head = hk.Sequential([ 29 | hk.Linear(128), 30 | jax.nn.relu, 31 | hk.Linear(env_spec.actions.num_values) 32 | ], name='atari_action_pred_head') 33 | return embedding_torso, pred_head 34 | 35 | return make_embedding_net( 36 | make_embedding_modules=make_atari_embedding_modules, 37 | env_spec=env_spec 38 | ) 39 | 40 | 41 | def make_atari_distillation_net(env_spec, embedding_dim): 42 | def make_atari_distillation_modules(): 43 | embedding_torso = hk.Sequential([ 44 | networks_lib.AtariTorso(), 45 | hk.Linear(output_size=embedding_dim), 46 | ], name='atari_distillation_torso') 47 | return embedding_torso 48 | 49 | return make_distillation_net( 50 | 
make_distillation_modules=make_atari_distillation_modules, 51 | env_spec=env_spec 52 | ) 53 | 54 | 55 | def make_atari_uvfa_net(env_spec, num_mixtures: int, batch_size: int): 56 | def make_atari_uvfa_modules(): 57 | embedding_torso = make_uvfa_atari_torso(env_spec.actions.num_values, num_mixtures) 58 | recurrent_core = hk.LSTM(512) 59 | head = networks_lib.DuellingMLP( 60 | num_actions=env_spec.actions.num_values, 61 | hidden_sizes=[512] 62 | ) 63 | return embedding_torso, recurrent_core, head 64 | 65 | return make_uvfa_net( 66 | make_uvfa_modules=make_atari_uvfa_modules, 67 | batch_size=batch_size, 68 | env_spec=env_spec 69 | ) 70 | 71 | 72 | def make_uvfa_atari_torso(num_actions: int, num_mixtures: int): 73 | observation_embedding_torso = hk.Sequential([ 74 | networks_lib.AtariTorso(), 75 | hk.Linear(512), 76 | jax.nn.relu 77 | ]) 78 | return UVFATorso( 79 | observation_embedding_torso, 80 | num_actions, num_mixtures, 81 | name='atari_uvfa_torso' 82 | ) 83 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/networks_zoo/discomaze.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax.nn 3 | from acme import specs 4 | from acme.jax import networks as networks_lib 5 | 6 | from ..distillation_network import make_distillation_net 7 | from ..embedding_network import make_embedding_net 8 | from ..networks import DRLearnerNetworks 9 | from ..uvfa_network import make_uvfa_net 10 | from ..uvfa_torso import UVFATorso 11 | from ...config import DRLearnerConfig 12 | 13 | 14 | def make_discomaze_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks: 15 | uvfa_net = make_discomaze_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size) 16 | embedding_network = make_discomaze_embedding_net(env_spec, config.observation_embed_dim) 17 | distillation_network = make_discomaze_distillation_net(env_spec, config.distillation_embed_dim) 18 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network) 19 | 20 | 21 | def make_discomaze_embedding_net(env_spec, embedding_dim): 22 | def make_discomaze_embedding_modules(): 23 | embedding_torso = hk.Sequential([ 24 | hk.Conv2D(16, kernel_shape=3, stride=1), 25 | jax.nn.relu, 26 | hk.Conv2D(32, kernel_shape=3, stride=1), 27 | jax.nn.relu, 28 | hk.Flatten(preserve_dims=-3), 29 | hk.Linear(embedding_dim), 30 | jax.nn.relu 31 | ]) 32 | pred_head = hk.Sequential([ 33 | hk.Linear(32), 34 | jax.nn.relu, 35 | hk.Linear(env_spec.actions.num_values) 36 | ], name='action_pred_head') 37 | return embedding_torso, pred_head 38 | 39 | return make_embedding_net( 40 | make_embedding_modules=make_discomaze_embedding_modules, 41 | env_spec=env_spec 42 | ) 43 | 44 | 45 | def make_discomaze_distillation_net(env_spec, embedding_dim): 46 | def make_discomaze_distillation_modules(): 47 | embedding_torso = hk.Sequential([ 48 | hk.Conv2D(16, kernel_shape=3, stride=1), 49 | jax.nn.relu, 50 | hk.Conv2D(32, kernel_shape=3, stride=1), 51 | jax.nn.relu, 52 | hk.Flatten(preserve_dims=-3), 53 | hk.Linear(embedding_dim), 54 | jax.nn.relu 55 | ]) 56 | return embedding_torso 57 | 58 | return make_distillation_net( 59 | make_distillation_modules=make_discomaze_distillation_modules, 60 | env_spec=env_spec 61 | ) 62 | 63 | 64 | def make_discomaze_uvfa_net(env_spec, num_mixtures: int, batch_size: int): 65 | def make_discomaze_uvfa_modules(): 66 | embedding_torso = 
make_uvfa_discomaze_torso(env_spec.actions.num_values, num_mixtures) 67 | recurrent_core = hk.LSTM(256) 68 | head = networks_lib.DuellingMLP( 69 | num_actions=env_spec.actions.num_values, 70 | hidden_sizes=[256] 71 | ) 72 | return embedding_torso, recurrent_core, head 73 | 74 | return make_uvfa_net( 75 | make_uvfa_modules=make_discomaze_uvfa_modules, 76 | batch_size=batch_size, 77 | env_spec=env_spec 78 | ) 79 | 80 | 81 | def make_uvfa_discomaze_torso(num_actions: int, num_mixtures: int): 82 | observation_embedding_torso = hk.Sequential([ 83 | hk.Conv2D(16, kernel_shape=3, stride=1), 84 | jax.nn.relu, 85 | hk.Conv2D(32, kernel_shape=3, stride=1), 86 | jax.nn.relu, 87 | hk.Flatten(preserve_dims=-3), 88 | hk.Linear(256), 89 | jax.nn.relu 90 | ]) 91 | return UVFATorso( 92 | observation_embedding_torso, 93 | num_actions, num_mixtures, 94 | name='discomaze_uvfa_torso' 95 | ) 96 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/networks_zoo/lunar_lander.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | from acme import specs 3 | from acme.jax import networks as networks_lib 4 | 5 | from ..distillation_network import make_distillation_net 6 | from ..embedding_network import make_embedding_net 7 | from ..networks import DRLearnerNetworks 8 | from ..uvfa_network import make_uvfa_net 9 | from ..uvfa_torso import UVFATorso 10 | from ...config import DRLearnerConfig 11 | 12 | 13 | def make_lunar_lander_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks: 14 | uvfa_net = make_lunar_lander_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size) 15 | embedding_network = make_lunar_lander_embedding_net(env_spec, config.observation_embed_dim) 16 | distillation_network = make_lunar_lander_distillation_net(env_spec, config.distillation_embed_dim) 17 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network) 18 | 19 | 20 | def make_lunar_lander_embedding_net(env_spec, embedding_dim): 21 | def make_mlp_embedding_modules(): 22 | embedding_torso = hk.nets.MLP([16, 32, embedding_dim], name='mlp_embedding_torso') 23 | pred_head = hk.Linear(env_spec.actions.num_values, name='action_pred_head') 24 | return embedding_torso, pred_head 25 | 26 | return make_embedding_net( 27 | make_embedding_modules=make_mlp_embedding_modules, 28 | env_spec=env_spec 29 | ) 30 | 31 | 32 | def make_lunar_lander_distillation_net(env_spec, embedding_dim): 33 | def make_mlp_distillation_modules(): 34 | embedding_torso = hk.nets.MLP([16, 32, embedding_dim], name='mlp_embedding_torso') 35 | return embedding_torso 36 | 37 | return make_distillation_net( 38 | make_distillation_modules=make_mlp_distillation_modules, 39 | env_spec=env_spec 40 | ) 41 | 42 | 43 | def make_lunar_lander_uvfa_net(env_spec, num_mixtures: int, batch_size: int): 44 | def make_mlp_uvfa_modules(): 45 | embedding_torso = make_uvfa_lunar_lander_torso(env_spec.actions.num_values, num_mixtures) 46 | recurrent_core = hk.LSTM(32) 47 | head = networks_lib.DuellingMLP( 48 | num_actions=env_spec.actions.num_values, 49 | hidden_sizes=[32] 50 | ) 51 | return embedding_torso, recurrent_core, head 52 | 53 | return make_uvfa_net( 54 | make_uvfa_modules=make_mlp_uvfa_modules, 55 | batch_size=batch_size, 56 | env_spec=env_spec 57 | ) 58 | 59 | 60 | def make_uvfa_lunar_lander_torso(num_actions: int, num_mixtures: int): 61 | observation_embedding_torso = hk.nets.MLP([16, 32, 16]) 62 | return 
UVFATorso( 63 | observation_embedding_torso, 64 | num_actions, num_mixtures, 65 | name='mlp_uvfa_torso' 66 | ) 67 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/policy_networks.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Callable 3 | 4 | import jax.nn 5 | import jax.numpy as jnp 6 | import jax.random 7 | import rlax 8 | from acme import types 9 | from acme.jax import networks as networks_lib 10 | 11 | from .networks import DRLearnerNetworks 12 | from ..config import DRLearnerConfig 13 | 14 | 15 | @dataclasses.dataclass 16 | class DRLearnerPolicyNetworks: 17 | """Pure functions used by DRLearner actors""" 18 | select_action: Callable 19 | embed_observation: Callable 20 | distillation_embed_observation: Callable 21 | 22 | 23 | def make_policy_networks( 24 | networks: DRLearnerNetworks, 25 | config: DRLearnerConfig, 26 | evaluation: bool = False): 27 | def select_action(intrinsic_params: networks_lib.Params, 28 | extrinsic_params: networks_lib.Params, 29 | key: networks_lib.PRNGKey, 30 | observation: types.NestedArray, 31 | intrinsic_core_state: types.NestedArray, 32 | extrinsic_core_state: types.NestedArray, 33 | epsilon, beta): 34 | intrinsic_key_qnet, extrinsic_key_qnet, key_sample = jax.random.split(key, 3) 35 | intrinsic_q_values, intrinsic_core_state = networks.uvfa_net.forward.apply( 36 | intrinsic_params, intrinsic_key_qnet, observation, intrinsic_core_state) 37 | extrinsic_q_values, extrinsic_core_state = networks.uvfa_net.forward.apply( 38 | extrinsic_params, extrinsic_key_qnet, observation, extrinsic_core_state) 39 | 40 | q_values = config.tx_pair.apply(beta * config.tx_pair.apply_inv(intrinsic_q_values) + 41 | config.tx_pair.apply_inv(extrinsic_q_values)) 42 | epsilon = config.evaluation_epsilon if evaluation else epsilon 43 | action_dist = rlax.epsilon_greedy(epsilon) 44 | action = action_dist.sample(key_sample, q_values) 45 | action_prob = action_dist.probs( 46 | jax.nn.one_hot(jnp.argmax(q_values, axis=-1), num_classes=q_values.shape[-1]) 47 | ) 48 | action_prob = jnp.squeeze(action_prob[:, action], axis=-1) 49 | return action, action_prob, intrinsic_core_state, extrinsic_core_state 50 | 51 | return DRLearnerPolicyNetworks( 52 | select_action=select_action, 53 | embed_observation=networks.embedding_net.embed.apply, 54 | distillation_embed_observation=networks.distillation_net.embed.apply 55 | ) 56 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/uvfa_network.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Tuple, NamedTuple 3 | 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | from acme.jax import networks as networks_lib 7 | from acme.jax import utils 8 | from acme.wrappers.observation_action_reward import OAR 9 | 10 | 11 | @dataclasses.dataclass 12 | class UVFANetwork: 13 | """Pure functions for DRLearner Universal Value Function Approximator""" 14 | initial_state: networks_lib.FeedForwardNetwork 15 | forward: networks_lib.FeedForwardNetwork 16 | unroll: networks_lib.FeedForwardNetwork 17 | 18 | 19 | class UVFANetworkInput(NamedTuple): 20 | """Wrap input specific to DRLearner Recurrent Q-network""" 21 | oar: OAR # observation_t, action_tm1, reward_tm1 22 | intrinsic_reward: jnp.ndarray # ri_tm1 23 | mixture_idx: jnp.ndarray # beta_idx_tm1 24 | 25 | 26 | def make_uvfa_net( 27 | 
make_uvfa_modules, 28 | batch_size: int, 29 | env_spec) -> UVFANetwork: 30 | def initial_state(batch_size: int): 31 | _, recurrent_core, _ = make_uvfa_modules() 32 | return recurrent_core.initial_state(batch_size) 33 | 34 | def forward(input: UVFANetworkInput, 35 | state: hk.LSTMState) -> Tuple[networks_lib.NetworkOutput, hk.LSTMState]: 36 | """ 37 | Estimate action values for batch of inputs 38 | Args: 39 | input: batch of observations, actions, rewards, intrinsic rewards 40 | and mixture indices (beta param labels) 41 | state: recurrent state 42 | Returns: 43 | q_values: predicted action values 44 | new_state: new recurrent state after prediction 45 | """ 46 | embedding_torso, recurrent_core, head = make_uvfa_modules() 47 | 48 | embeddings = embedding_torso(input) 49 | embeddings, new_state = recurrent_core(embeddings, state) 50 | q_values = head(embeddings) 51 | return q_values, new_state 52 | 53 | def unroll(input: UVFANetworkInput, 54 | state: hk.LSTMState) -> Tuple[networks_lib.NetworkOutput, hk.LSTMState]: 55 | """ 56 | Estimate action values for batch of input sequences 57 | Args: 58 | input: batch of observations, actions, rewards, intrinsic rewards 59 | and mixture indices (beta param labels) sequences 60 | state: recurrent state 61 | Returns: 62 | q_values: predicted action values 63 | new_state: new recurrent state after prediction 64 | """ 65 | embedding_torso, recurrent_core, head = make_uvfa_modules() 66 | 67 | embeddings = hk.BatchApply(embedding_torso)(input) 68 | embeddings, new_states = hk.static_unroll(recurrent_core, embeddings, state) 69 | q_values = hk.BatchApply(head)(embeddings) 70 | return q_values, new_states 71 | 72 | # transform functions 73 | initial_state_hk = hk.transform(initial_state) 74 | forward_hk = hk.transform(forward) 75 | unroll_hk = hk.transform(unroll) 76 | 77 | # create dummy batches for networks initialization 78 | observation = utils.zeros_like(env_spec.observations) 79 | intrinsic_reward = utils.zeros_like(env_spec.rewards) 80 | mixture_idxs = utils.zeros_like(env_spec.rewards, dtype=jnp.int32) 81 | uvfa_input_sequences = utils.add_batch_dim( 82 | utils.tile_nested( 83 | UVFANetworkInput(observation, intrinsic_reward, mixture_idxs), 84 | batch_size 85 | ) 86 | ) 87 | 88 | def initial_state_init(rng, batch_size: int): 89 | return initial_state_hk.init(rng, batch_size) 90 | 91 | def unroll_init(rng, initial_state): 92 | return unroll_hk.init(rng, uvfa_input_sequences, initial_state) 93 | 94 | # create FeedForwardNetworks corresponding to UVFA pure functions 95 | initial_state = networks_lib.FeedForwardNetwork( 96 | init=initial_state_init, apply=initial_state_hk.apply 97 | ) 98 | forward = networks_lib.FeedForwardNetwork( 99 | init=forward_hk.init, apply=forward_hk.apply 100 | ) 101 | unroll = networks_lib.FeedForwardNetwork( 102 | init=unroll_init, apply=unroll_hk.apply 103 | ) 104 | 105 | return UVFANetwork(initial_state, forward, unroll) 106 | -------------------------------------------------------------------------------- /drlearner/drlearner/networks/uvfa_torso.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | 6 | class UVFATorso(hk.Module): 7 | def __init__(self, 8 | observation_embedding_torso: hk.Module, 9 | num_actions: int, 10 | num_mixtures: int, 11 | name: str): 12 | super().__init__(name=name) 13 | self._embed = observation_embedding_torso 14 | 15 | self._num_actions = num_actions 16 | self._num_mixtures = num_mixtures 17 | 
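    # (Editor's note: __call__ below flattens a UVFANetworkInput into one feature
    # vector per step: the observation embedding is concatenated with a one-hot of
    # the previous action, tanh-squashed extrinsic and intrinsic rewards, and a
    # one-hot of the mixture index, giving D + A + 2 + M features ("[T?, B, D+A+M+2]"
    # in the shape comments). For example, with the Atari torso defined earlier
    # (D=512), the full ALE action set (A=18) and the default num_mixtures=32,
    # that is 564 features fed into the recurrent core.)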
18 | def __call__(self, input): 19 | oar_t, intrinsic_reward_tm1, mixture_idx_tm1 = input.oar, input.intrinsic_reward, input.mixture_idx, 20 | observation_t, action_tm1, reward_tm1 = oar_t.observation, oar_t.action, oar_t.reward 21 | 22 | features_t = self._embed(observation_t) # [T?, B, D] 23 | action_tm1 = jax.nn.one_hot( 24 | action_tm1, 25 | num_classes=self._num_actions 26 | ) # [T?, B, A] 27 | mixture_idx_tm1 = jax.nn.one_hot( 28 | mixture_idx_tm1, 29 | num_classes=self._num_mixtures 30 | ) # [T?, B, M] 31 | 32 | reward_tm1 = jnp.tanh(reward_tm1) 33 | intrinsic_reward_tm1 = jnp.tanh(intrinsic_reward_tm1) 34 | # Add dummy trailing dimensions to the rewards if necessary. 35 | while reward_tm1.ndim < action_tm1.ndim: 36 | reward_tm1 = jnp.expand_dims(reward_tm1, axis=-1) 37 | 38 | while intrinsic_reward_tm1.ndim < action_tm1.ndim: 39 | intrinsic_reward_tm1 = jnp.expand_dims(intrinsic_reward_tm1, axis=-1) 40 | 41 | embedding = jnp.concatenate( 42 | [features_t, action_tm1, reward_tm1, intrinsic_reward_tm1, mixture_idx_tm1], 43 | axis=-1 44 | ) # [T?, B, D+A+M+2] 45 | return embedding 46 | -------------------------------------------------------------------------------- /drlearner/drlearner/utils.py: -------------------------------------------------------------------------------- 1 | import jax.nn 2 | import jax.numpy as jnp 3 | 4 | 5 | def epsilon_greedy_prob(q_values, epsilon): 6 | """Get probability of actions under epsilon-greedy policy provided estimated q_values""" 7 | num_actions = q_values.shape[0] 8 | max_action = jnp.argmax(q_values) 9 | probs = jnp.full_like(q_values, fill_value=epsilon / num_actions) 10 | probs = probs.at[max_action].set(1 - epsilon * (num_actions - 1) / num_actions) 11 | return probs 12 | 13 | 14 | def get_beta(mixture_idx: jnp.ndarray, beta_min: float, beta_max: float, num_mixtures: int): 15 | beta = jnp.linspace(beta_min, beta_max, num_mixtures)[mixture_idx] 16 | return beta 17 | 18 | 19 | def get_gamma(mixture_idx: jnp.ndarray, gamma_min: float, gamma_max: float, num_mixtures: int): 20 | gamma = jnp.linspace(gamma_min, gamma_max, num_mixtures)[mixture_idx] 21 | return gamma 22 | 23 | 24 | def get_epsilon(actor_id: int, epsilon_base: float, num_actors: int, alpha: float = 8.0): 25 | """Get epsilon parameter for given actor""" 26 | epsilon = epsilon_base ** (1 + alpha * actor_id / ((num_actors - 1) + 0.0001)) 27 | return epsilon 28 | 29 | 30 | def get_beta_ngu(mixture_idx: jnp.ndarray, beta_min: float, beta_max: float, num_mixtures: int): 31 | """Get beta parameter for given number of mixtures and mixture_idx""" 32 | beta = jnp.where( 33 | mixture_idx == num_mixtures - 1, 34 | beta_max, 35 | beta_min + beta_max * jax.nn.sigmoid(10 * (2 * mixture_idx - (num_mixtures - 2)) / (num_mixtures - 2)) 36 | ) 37 | return beta 38 | 39 | 40 | def get_gamma_ngu(mixture_idx: jnp.ndarray, gamma_min: float, gamma_max: float, num_mixtures: int): 41 | """Get gamma parameters for given number of mixtures in descending order""" 42 | gamma = 1 - jnp.exp( 43 | ((num_mixtures - 1 - mixture_idx) * jnp.log(1 - gamma_max) + 44 | mixture_idx * jnp.log(1 - gamma_min)) / (num_mixtures - 1)) 45 | return gamma 46 | 47 | -------------------------------------------------------------------------------- /drlearner/environments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/environments/__init__.py 
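Editor's note: the schedule helpers in drlearner/drlearner/utils.py above map actor and mixture indices to exploration and discount parameters. Below is a minimal sketch of how they behave, assuming the package and its JAX dependencies are importable; the actor count (8) and base epsilon (0.4) are example values rather than values taken from this repository's configs, while the beta/gamma bounds match the DRLearnerConfig defaults.

    import numpy as np
    from drlearner.drlearner.utils import get_beta, get_epsilon, get_gamma

    # Per-actor epsilon decays geometrically from epsilon_base (actor 0)
    # towards roughly epsilon_base ** 9 for the last actor (default alpha=8.0).
    print([round(get_epsilon(i, 0.4, 8), 4) for i in range(8)])

    # Each mixture index selects a linearly spaced (beta, gamma) pair.
    idx = np.arange(4)
    print(get_beta(idx, beta_min=0.0, beta_max=0.3, num_mixtures=4))
    print(get_gamma(idx, gamma_min=0.99, gamma_max=0.997, num_mixtures=4))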
-------------------------------------------------------------------------------- /drlearner/environments/atari.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Shared helpers for different experiment flavours.""" 17 | 18 | import functools 19 | 20 | import dm_env 21 | import gym 22 | from acme import wrappers 23 | 24 | 25 | def make_environment(level: str = 'PongNoFrameskip-v4', 26 | oar_wrapper: bool = False) -> dm_env.Environment: 27 | """Loads the Atari environment.""" 28 | env = gym.make(level, full_action_space=True) 29 | 30 | # Always use episodes of 108k steps as this is standard, matching the paper. 31 | max_episode_len = 108_000 32 | wrapper_list = [ 33 | wrappers.GymAtariAdapter, 34 | functools.partial( 35 | wrappers.AtariWrapper, 36 | action_repeats=4, 37 | pooled_frames=4, 38 | zero_discount_on_life_loss=False, 39 | expose_lives_observation=False, 40 | num_stacked_frames=1, 41 | max_episode_len=max_episode_len, 42 | to_float=True, 43 | grayscaling=True 44 | ), 45 | ] 46 | if oar_wrapper: 47 | # E.g. IMPALA and R2D2 use this particular variant. 48 | wrapper_list.append(wrappers.ObservationActionRewardWrapper) 49 | wrapper_list.append(wrappers.SinglePrecisionWrapper) 50 | 51 | return wrappers.wrap_all(env, wrapper_list) 52 | -------------------------------------------------------------------------------- /drlearner/environments/disco_maze.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union, Optional 2 | 3 | import dm_env 4 | import gym_discomaze 5 | import numpy as np 6 | from acme import wrappers 7 | from acme.wrappers import base 8 | from dm_env import specs 9 | 10 | 11 | class DiscoMazeWrapper(base.EnvironmentWrapper): 12 | def __init__(self, environment: dm_env.Environment, *, to_float: bool = False, 13 | max_episode_len: Optional[int] = None): 14 | """ 15 | The wrapper performs the following actions: 16 | 1. Converts observations to float (if applied) 17 | 2. Truncates episodes to maximum number of steps (if applied). 18 | 3. Remove action that allows no movement. 19 | """ 20 | super(DiscoMazeWrapper, self).__init__(environment) 21 | self._to_float = to_float 22 | 23 | if not max_episode_len: 24 | max_episode_len = np.inf 25 | self._episode_len = 0 26 | self._max_episode_len = max_episode_len 27 | 28 | self._observation_spec = self._init_observation_spec() 29 | self._action_spec = self._init_action_spec() 30 | 31 | def _init_observation_spec(self): 32 | observation_spec = self.environment.observation_spec() 33 | if self._to_float: 34 | observation_shape = observation_spec.shape 35 | dtype = 'float64' 36 | observation_spec = observation_spec.replace( 37 | dtype=dtype, 38 | minimum=(observation_spec.minimum.astype(dtype) / 255.), 39 | maximum=(observation_spec.maximum.astype(dtype) / 255.) 
40 | ) 41 | return observation_spec 42 | 43 | def _init_action_spec(self): 44 | action_spec = self.environment.action_spec() 45 | 46 | action_spec = action_spec.replace(num_values=action_spec.num_values - 1) 47 | return action_spec 48 | 49 | def step(self, action) -> dm_env.TimeStep: 50 | action = action + 1 51 | timestep = self.environment.step(action) 52 | 53 | if self._to_float: 54 | observation = timestep.observation.astype(float) / 255. 55 | timestep = timestep._replace(observation=observation) 56 | 57 | self._episode_len += 1 58 | if self._episode_len == self._max_episode_len: 59 | timestep = timestep._replace(step_type=dm_env.StepType.LAST) 60 | 61 | return timestep 62 | 63 | def reset(self) -> dm_env.TimeStep: 64 | timestep = self.environment.reset() 65 | 66 | if self._to_float: 67 | observation = timestep.observation.astype(float) / 255. 68 | timestep = timestep._replace(observation=observation) 69 | 70 | self._episode_len = 0 71 | return timestep 72 | 73 | def observation_spec(self) -> Union[specs.Array, Sequence[specs.Array]]: 74 | return self._observation_spec 75 | 76 | def action_spec(self) -> Union[specs.Array, Sequence[specs.Array]]: 77 | return self._action_spec 78 | 79 | 80 | def make_discomaze_environment(seed: int) -> dm_env.Environment: 81 | """Create 21x21 disco maze environment with 5 random colors and no target""" 82 | env = gym_discomaze.RandomDiscoMaze(n_row=10, n_col=10, n_colors=5, n_targets=0, generator=seed) 83 | env = wrappers.GymWrapper(env) 84 | env = DiscoMazeWrapper(env, to_float=True, max_episode_len=5000) 85 | env = wrappers.SinglePrecisionWrapper(env) 86 | env = wrappers.ObservationActionRewardWrapper(env) 87 | return env 88 | 89 | 90 | if __name__ == '__main__': 91 | env = make_discomaze_environment(0) 92 | print(env.action_spec().replace(num_values=4)) 93 | -------------------------------------------------------------------------------- /drlearner/environments/lunar_lander.py: -------------------------------------------------------------------------------- 1 | import dm_env 2 | import gym 3 | from acme import wrappers 4 | 5 | 6 | def make_ll_environment(seed: int) -> dm_env.Environment: 7 | env_name = "LunarLander-v2" 8 | 9 | env = gym.make(env_name) 10 | env = wrappers.GymWrapper(env) 11 | env = wrappers.SinglePrecisionWrapper(env) 12 | env = wrappers.ObservationActionRewardWrapper(env) 13 | 14 | return env 15 | -------------------------------------------------------------------------------- /drlearner/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/utils/__init__.py -------------------------------------------------------------------------------- /drlearner/utils/stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from statistics import mean, median 4 | 5 | 6 | def read_config(path): 7 | with open(path, 'r') as file: 8 | config = json.load(file) 9 | 10 | return config 11 | 12 | 13 | class StatsCheckpointer: 14 | def __init__(self): 15 | self.values = defaultdict(list) 16 | self.statistics = defaultdict(dict) 17 | self._config = read_config('./configs/config.json') 18 | 19 | self._target_statistics = { 20 | 'min': min, 21 | 'max': max, 22 | 'mean': mean, 23 | 'median': median, 24 | } 25 | self._target_metrics = ['episode_return'] 26 | 27 | def update(self, result, log=False): 28 | 
for metric in self._target_metrics: 29 | value = float(result[metric]) 30 | self.values[metric].append(value) 31 | 32 | self.evaluate() 33 | self.save() 34 | 35 | if log: 36 | self.log() 37 | 38 | def evaluate(self): 39 | for metric in self._target_metrics: 40 | values = self.values[metric] 41 | 42 | for statistic, function in self._target_statistics.items(): 43 | self.statistics[metric][statistic] = round(function(values), 5) 44 | 45 | def save(self): 46 | path = self._config['statistics_path'] 47 | 48 | with open(path, 'w') as file: 49 | json.dump(self.statistics, file) 50 | 51 | def log(self): 52 | print('=' * 30) 53 | for metric in self._target_metrics: 54 | print(f"{metric.replace('_', ' ').upper()}") 55 | 56 | for statistic, value in self.statistics[metric].items(): 57 | print(f'{statistic}: {value}') 58 | 59 | print('=' * 30) 60 | 61 | def __repr__(self): 62 | return str(self.statistics) 63 | -------------------------------------------------------------------------------- /drlearner/utils/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import time 4 | import os 5 | from dataclasses import asdict 6 | import jax 7 | 8 | from acme import specs 9 | from acme import core 10 | from acme.utils import counting 11 | from acme.utils import observers as observers_lib 12 | from acme.jax import networks as networks_lib 13 | from acme.jax import utils 14 | from acme.utils import loggers 15 | from acme.utils.loggers.tf_summary import TFSummaryLogger 16 | from acme.utils.loggers import base 17 | from drlearner.drlearner.config import DRLearnerConfig 18 | import wandb 19 | 20 | from ..core import distributed_layout 21 | from ..core.environment_loop import EnvironmentLoop 22 | 23 | from typing import Optional, Callable, Any, Mapping, Sequence, TextIO, Union 24 | 25 | 26 | def from_dict_to_dataclass(cls, data): 27 | return cls( 28 | **{ 29 | key: (data[key] if val.default == 30 | val.empty else data.get(key, val.default)) 31 | for key, val in inspect.signature(cls).parameters.items() 32 | } 33 | ) 34 | 35 | 36 | def _format_key(key: str) -> str: 37 | """Internal function for formatting keys in Tensorboard format.""" 38 | return key.title().replace("_", "") 39 | 40 | 41 | class WandbLogger(base.Logger): 42 | """Logs to wandb instance. 43 | 44 | If multiple WandbLoggers are created with the same credentials, results will be 45 | categorized by labels. 46 | """ 47 | 48 | def __init__(self, logdir: str, label: str = "Logs", hyperparams: Optional[Union[None, DRLearnerConfig]] = None, exp_name: Optional[Union[None, str]] = None): 49 | """Initializes the logger. 50 | 51 | Args: 52 | logdir: name of the wandb project 53 | label: label string to use when logging. Default to 'Logs'. 
54 | hyperparams: hyperparams config to be saved by wandb 55 | """ 56 | self._time = time.time() 57 | self.label = label 58 | self._iter = 0 59 | 60 | if wandb.run is None: 61 | wandb.init(project=logdir, group="DDP", name=exp_name) 62 | 63 | if hyperparams is not None: 64 | self.save_hyperparams(hyperparams=hyperparams) 65 | 66 | def save_hyperparams(self, hyperparams: Optional[Union[None, DRLearnerConfig]]): 67 | wandb.config = asdict(hyperparams) 68 | 69 | def write(self, values: base.LoggingData): 70 | for key in values.keys(): 71 | wandb.log({f"{self.label}/{_format_key(key)}": values[key]}) 72 | 73 | def close(self): 74 | wandb.finish() 75 | 76 | 77 | class CloudCSVLogger: 78 | def __init__( 79 | self, 80 | directory_or_file: Union[str, TextIO] = "~/acme", 81 | label: str = "", 82 | time_delta: float = 0.0, 83 | add_uid: bool = True, 84 | flush_every: int = 30, 85 | ): 86 | """Instantiates the logger. 87 | 88 | Args: 89 | directory_or_file: Either a directory path as a string, or a file TextIO 90 | object. 91 | label: Extra label to add to logger. This is added as a suffix to the 92 | directory. 93 | time_delta: Interval in seconds between which writes are dropped to 94 | throttle throughput. 95 | add_uid: Whether to add a UID to the file path. See `paths.process_path` 96 | for details. 97 | flush_every: Interval (in writes) between flushes. 98 | """ 99 | 100 | if flush_every <= 0: 101 | raise ValueError( 102 | f"`flush_every` must be a positive integer (got {flush_every})." 103 | ) 104 | 105 | self._last_log_time = time.time() - time_delta 106 | self._time_delta = time_delta 107 | self._flush_every = flush_every 108 | self._add_uid = add_uid 109 | self._writes = 0 110 | self.file_path = os.path.join(directory_or_file, f"{label}_logs.csv") 111 | self._keys = [] 112 | logging.info("Logging to %s", self.file_path) 113 | 114 | def write(self, data: base.LoggingData): 115 | """Writes a `data` into a row of comma-separated values.""" 116 | # Only log if `time_delta` seconds have passed since last logging event. 117 | now = time.time() 118 | 119 | # TODO(b/192227744): Remove this in favour of filters.TimeFilter. 120 | elapsed = now - self._last_log_time 121 | if elapsed < self._time_delta: 122 | logging.debug( 123 | "Not due to log for another %.2f seconds, dropping data.", 124 | self._time_delta - elapsed, 125 | ) 126 | return 127 | self._last_log_time = now 128 | 129 | # Append row to CSV. 
130 | data = base.to_numpy(data) 131 | if self._writes == 0: 132 | self._keys = data.keys() 133 | with open(self.file_path, "w") as f: 134 | f.write(",".join(self._keys)) 135 | f.write("\n") 136 | f.write( 137 | ",".join(list(map(str, [data[k] for k in self._keys])))) 138 | f.write("\n") 139 | else: 140 | with open(self.file_path, "a") as f: 141 | f.write( 142 | ",".join(list(map(str, [data[k] for k in self._keys])))) 143 | f.write("\n") 144 | self._writes += 1 145 | 146 | 147 | def make_tf_logger( 148 | workdir: str = "~/acme/", 149 | label: str = "learner", 150 | save_data: bool = True, 151 | time_delta: float = 0.0, 152 | asynchronous: bool = False, 153 | print_fn: Optional[Callable[[str], None]] = print, 154 | serialize_fn: Optional[Callable[[Mapping[str, Any]], 155 | str]] = loggers.base.to_numpy, 156 | steps_key: str = "steps", 157 | ) -> loggers.base.Logger: 158 | del steps_key 159 | if not print_fn: 160 | print_fn = logging.info 161 | 162 | terminal_logger = loggers.terminal.TerminalLogger( 163 | label=label, print_fn=print_fn) 164 | 165 | all_loggers = [terminal_logger] 166 | 167 | if save_data: 168 | if "/gcs/" in workdir: 169 | all_loggers.append( 170 | CloudCSVLogger( 171 | directory_or_file=workdir, label=label, time_delta=time_delta 172 | ) 173 | ) 174 | else: 175 | all_loggers.append( 176 | loggers.csv.CSVLogger( 177 | directory_or_file=workdir, label=label, time_delta=time_delta 178 | ) 179 | ) 180 | 181 | tb_workdir = workdir 182 | if "/gcs/" in tb_workdir: 183 | tb_workdir = tb_workdir.replace("/gcs/", "gs://") 184 | all_loggers.append(TFSummaryLogger(logdir=tb_workdir, label=label)) 185 | 186 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn) 187 | logger = loggers.filters.NoneFilter(logger) 188 | 189 | logger = loggers.filters.TimeFilter(logger, time_delta) 190 | return logger 191 | 192 | 193 | def make_wandb_logger( 194 | workdir: str = "~/acme/", 195 | label: str = "learner", 196 | save_data: bool = True, 197 | time_delta: float = 0.0, 198 | asynchronous: bool = False, 199 | print_fn: Optional[Callable[[str], None]] = print, 200 | serialize_fn: Optional[Callable[[Mapping[str, Any]], 201 | str]] = loggers.base.to_numpy, 202 | steps_key: str = "steps", 203 | hyperparams: Optional[Union[None, DRLearnerConfig]] = None, 204 | exp_name: str = None 205 | ) -> loggers.base.Logger: 206 | del steps_key 207 | if not print_fn: 208 | print_fn = logging.info 209 | 210 | terminal_logger = loggers.terminal.TerminalLogger( 211 | label=label, print_fn=print_fn) 212 | 213 | all_loggers = [terminal_logger] 214 | 215 | if save_data: 216 | if "/gcs/" in workdir: 217 | all_loggers.append( 218 | CloudCSVLogger( 219 | directory_or_file=workdir, label=label, time_delta=time_delta 220 | ) 221 | ) 222 | else: 223 | all_loggers.append( 224 | loggers.csv.CSVLogger( 225 | directory_or_file=workdir, label=label, time_delta=time_delta 226 | ) 227 | ) 228 | 229 | tb_workdir = workdir 230 | if "/gcs/" in tb_workdir: 231 | tb_workdir = tb_workdir.replace("/gcs/", "gs://") 232 | all_loggers.append(WandbLogger(logdir=tb_workdir.split( 233 | "/")[-2], label=label, hyperparams=hyperparams, exp_name=exp_name)) 234 | 235 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn) 236 | logger = loggers.filters.NoneFilter(logger) 237 | 238 | logger = loggers.filters.TimeFilter(logger, time_delta) 239 | return logger 240 | 241 | 242 | def all_loggers( 243 | workdir: str = "~/acme/", 244 | label: str = "learner", 245 | save_data: bool = True, 246 | time_delta: float = 0.0, 247 | 
asynchronous: bool = False, 248 | print_fn: Optional[Callable[[str], None]] = print, 249 | serialize_fn: Optional[Callable[[Mapping[str, Any]], 250 | str]] = loggers.base.to_numpy, 251 | steps_key: str = "steps", 252 | hyperparams: Optional[Union[None, DRLearnerConfig]] = None 253 | ) -> loggers.base.Logger: 254 | del steps_key 255 | if not print_fn: 256 | print_fn = logging.info 257 | 258 | terminal_logger = loggers.terminal.TerminalLogger( 259 | label=label, print_fn=print_fn) 260 | 261 | all_loggers = [terminal_logger] 262 | 263 | if save_data: 264 | if "/gcs/" in workdir: 265 | all_loggers.append( 266 | CloudCSVLogger( 267 | directory_or_file=workdir, label=label, time_delta=time_delta 268 | ) 269 | ) 270 | else: 271 | all_loggers.append( 272 | loggers.csv.CSVLogger( 273 | directory_or_file=workdir, label=label, time_delta=time_delta 274 | ) 275 | ) 276 | 277 | tb_workdir = workdir 278 | if "/gcs/" in tb_workdir: 279 | tb_workdir = tb_workdir.replace("/gcs/", "gs://") 280 | all_loggers.append(WandbLogger(logdir=tb_workdir.split( 281 | "/")[-2], label=label, hyperparams=hyperparams)) 282 | all_loggers.append(TFSummaryLogger(logdir=tb_workdir, label=label)) 283 | 284 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn) 285 | logger = loggers.filters.NoneFilter(logger) 286 | 287 | logger = loggers.filters.TimeFilter(logger, time_delta) 288 | return logger 289 | 290 | 291 | def evaluator_factory_logger_choice( 292 | environment_factory: distributed_layout.EnvironmentFactory, 293 | network_factory: distributed_layout.NetworkFactory, 294 | policy_factory: distributed_layout.PolicyFactory, 295 | logger_fn: Callable, 296 | observers: Sequence[observers_lib.EnvLoopObserver] = (), 297 | actor_id: int = 0, 298 | ) -> distributed_layout.EvaluatorFactory: 299 | """Returns an evaluator process with customizable log function.""" 300 | 301 | def evaluator( 302 | random_key: networks_lib.PRNGKey, 303 | variable_source: core.VariableSource, 304 | counter: counting.Counter, 305 | make_actor: distributed_layout.MakeActorFn, 306 | ): 307 | """The evaluation process.""" 308 | 309 | # Create environment and evaluator networks 310 | environment_key, actor_key = jax.random.split(random_key) 311 | environment = environment_factory(utils.sample_uint32(environment_key)) 312 | networks = network_factory(specs.make_environment_spec(environment)) 313 | 314 | actor = make_actor( 315 | random_key, policy_factory(networks), variable_source=variable_source 316 | ) # ToDo: fix actor id for R2D2 317 | 318 | # Create logger and counter. 319 | counter = counting.Counter(counter, "evaluator") 320 | 321 | logger = logger_fn() 322 | 323 | # Create the run loop and return it. 
324 | return EnvironmentLoop(environment, actor, counter, logger, observers=observers) 325 | 326 | return evaluator 327 | -------------------------------------------------------------------------------- /examples/distrun_atari.py: -------------------------------------------------------------------------------- 1 | """Example running Distributed Layout DRLearner, on Atari.""" 2 | 3 | import functools 4 | import logging 5 | import os 6 | 7 | import acme 8 | import launchpad as lp 9 | from absl import app 10 | from absl import flags 11 | from acme import specs 12 | from acme.jax import utils 13 | 14 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo 15 | from drlearner.configs.config_atari import AtariDRLearnerConfig 16 | from drlearner.configs.resources import get_atari_vertex_resources, get_local_resources 17 | from drlearner.core.observers import IntrinsicRewardObserver, MetaControllerObserver, DistillationCoefObserver, StorageVideoObserver 18 | from drlearner.environments.atari import make_environment 19 | from drlearner.drlearner.networks import make_policy_networks 20 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_wandb_logger 21 | 22 | flags.DEFINE_string('level', 'ALE/MontezumaRevenge-v5', 'Which game to play.') 23 | flags.DEFINE_integer('num_steps', 100000, 'Number of steps to train for.') 24 | flags.DEFINE_integer('num_episodes', 1000, 25 | 'Number of episodes to train for.') 26 | flags.DEFINE_string('exp_path', 'experiments/default', 27 | 'Experiment data storage.') 28 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.') 29 | 30 | flags.DEFINE_integer('seed', 0, 'Random seed.') 31 | flags.DEFINE_integer('num_actors_per_mixture', 2, 32 | 'Number of parallel actors per mixture.') 33 | flags.DEFINE_bool('run_on_vertex', False, 34 | 'Whether to run training in multiple processes or on Vertex AI.') 35 | flags.DEFINE_bool('colocate_learner_and_reverb', True, 36 | 'Flag indicating whether to colocate learner and reverb.') 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | def make_program(): 42 | config = AtariDRLearnerConfig 43 | print(config) 44 | 45 | config_dir = os.path.join( 46 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1]) 47 | if not os.path.exists(config_dir): 48 | os.makedirs(config_dir) 49 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f: 50 | f.write(str(config)) 51 | 52 | env = make_environment(FLAGS.level, oar_wrapper=True) 53 | env_spec = acme.make_environment_spec(env) 54 | 55 | def net_factory(env_spec: specs.EnvironmentSpec): 56 | return networks_zoo.make_atari_nets(config, env_spec) 57 | 58 | level = str(FLAGS.level) 59 | 60 | def env_factory(seed: int): 61 | return make_environment(level, oar_wrapper=True) 62 | 63 | observers = [ 64 | IntrinsicRewardObserver(), 65 | MetaControllerObserver(), 66 | DistillationCoefObserver(), 67 | StorageVideoObserver(config) 68 | ] 69 | 70 | evaluator_logger_fn = functools.partial(make_wandb_logger, FLAGS.exp_path, 71 | 'evaluator', save_data=True, 72 | time_delta=1, asynchronous=True, 73 | serialize_fn=utils.fetch_devicearray, 74 | print_fn=logging.info, 75 | steps_key='evaluator_steps', 76 | hyperparams=config, 77 | exp_name=FLAGS.exp_name) 78 | 79 | learner_logger_function = functools.partial(make_wandb_logger, FLAGS.exp_path, 80 | 'learner', save_data=True, 81 | time_delta=1, asynchronous=True, 82 | serialize_fn=utils.fetch_devicearray, 83 | print_fn=logging.info, 84 | steps_key='learner_steps', 85 | hyperparams=config, 86 | exp_name=FLAGS.exp_name) 
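    # (Editor's note: with the flag default num_actors_per_mixture=2 and the
    # config default num_mixtures=32, the distributed layout below spawns
    # num_actors_per_mixture * num_mixtures = 64 actor processes alongside the
    # learner, replay and evaluator nodes. An illustrative local launch, with
    # example flag values only:
    #   python examples/distrun_atari.py --exp_path experiments/my_run \
    #       --exp_name my_run --num_actors_per_mixture 1
    # )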
87 | 88 | program = DistributedDRLearnerFromConfig( 89 | seed=FLAGS.seed, 90 | environment_factory=env_factory, 91 | network_factory=net_factory, 92 | config=config, 93 | workdir=FLAGS.exp_path, 94 | num_actors_per_mixture=FLAGS.num_actors_per_mixture, 95 | max_episodes=FLAGS.num_episodes, 96 | max_steps=FLAGS.num_steps, 97 | environment_spec=env_spec, 98 | actor_observers=observers, 99 | evaluator_observers=observers, 100 | learner_logger_fn=learner_logger_function, 101 | evaluator_factories=[ 102 | evaluator_factory_logger_choice( 103 | environment_factory=env_factory, 104 | network_factory=net_factory, 105 | policy_factory=lambda networks: make_policy_networks( 106 | networks, config, evaluation=True), 107 | logger_fn=evaluator_logger_fn, 108 | observers=observers 109 | ) 110 | ], 111 | 112 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb 113 | 114 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1]) 115 | 116 | return program 117 | 118 | 119 | def main(_): 120 | program = make_program() 121 | 122 | if FLAGS.run_on_vertex: 123 | resources = get_atari_vertex_resources() 124 | lp.launch( 125 | program, 126 | launch_type=lp.LaunchType.VERTEX_AI, 127 | xm_resources=resources, 128 | terminal='current_terminal') 129 | 130 | else: 131 | resources = get_local_resources() 132 | lp.launch( 133 | program, 134 | lp.LaunchType.LOCAL_MULTI_PROCESSING, 135 | local_resources=resources, 136 | terminal='current_terminal' 137 | ) 138 | 139 | 140 | if __name__ == '__main__': 141 | app.run(main) 142 | -------------------------------------------------------------------------------- /examples/distrun_discomaze.py: -------------------------------------------------------------------------------- 1 | """Example running distributed layout DRLearner Agent, on Discomaze environment.""" 2 | import functools 3 | import logging 4 | import os 5 | 6 | import acme 7 | from acme.jax import utils 8 | import launchpad as lp 9 | from absl import app 10 | from absl import flags 11 | from acme import specs 12 | 13 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo 14 | from drlearner.configs.config_discomaze import DiscomazeDRLearnerConfig 15 | from drlearner.core.observers import UniqueStatesDiscoMazeObserver, IntrinsicRewardObserver, ActionProbObserver, DistillationCoefObserver 16 | from drlearner.environments.disco_maze import make_discomaze_environment 17 | from drlearner.drlearner.networks import make_policy_networks 18 | from drlearner.configs.resources import get_toy_env_vertex_resources, get_local_resources 19 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_tf_logger 20 | 21 | flags.DEFINE_string('level', 'DiscoMaze', 'Which game to play.') 22 | flags.DEFINE_integer('num_episodes', 10000000, 23 | 'Number of episodes to train for.') 24 | flags.DEFINE_string('exp_path', 'experiments/default', 25 | 'Experiment data storage.') 26 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.') 27 | flags.DEFINE_integer('seed', 0, 'Random seed.') 28 | flags.DEFINE_integer('num_actors_per_mixture', 1, 29 | 'Number of parallel actors per mixture.') 30 | flags.DEFINE_bool('run_on_vertex', False, 31 | 'Whether to run training in multiple processes or on Vertex AI.') 32 | flags.DEFINE_bool('colocate_learner_and_reverb', False, 33 | 'Flag indicating whether to colocate learner and reverb.') 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | def make_program(): 39 | config = DiscomazeDRLearnerConfig 40 | print(config) 41 | 42 | config_dir = 
os.path.join( 43 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1]) 44 | if not os.path.exists(config_dir): 45 | os.makedirs(config_dir) 46 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f: 47 | f.write(str(config)) 48 | 49 | env = make_discomaze_environment(FLAGS.seed) 50 | env_spec = acme.make_environment_spec(env) 51 | 52 | def net_factory(env_spec: specs.EnvironmentSpec): 53 | return networks_zoo.make_discomaze_nets(config, env_spec) 54 | 55 | observers = [ 56 | UniqueStatesDiscoMazeObserver(), 57 | IntrinsicRewardObserver(), 58 | DistillationCoefObserver(), 59 | ActionProbObserver(num_actions=env_spec.actions.num_values), 60 | ] 61 | 62 | evaluator_logger_fn = functools.partial(make_tf_logger, FLAGS.exp_path, 63 | 'evaluator', save_data=True, 64 | time_delta=1, asynchronous=True, 65 | serialize_fn=utils.fetch_devicearray, 66 | print_fn=logging.info, 67 | steps_key='evaluator_steps') 68 | 69 | learner_logger_function = functools.partial(make_tf_logger, FLAGS.exp_path, 70 | 'learner', save_data=False, 71 | time_delta=1, asynchronous=True, 72 | serialize_fn=utils.fetch_devicearray, 73 | print_fn=logging.info, 74 | steps_key='learner_steps') 75 | 76 | program = DistributedDRLearnerFromConfig( 77 | seed=FLAGS.seed, 78 | environment_factory=make_discomaze_environment, 79 | network_factory=net_factory, 80 | config=config, 81 | num_actors_per_mixture=FLAGS.num_actors_per_mixture, 82 | environment_spec=env_spec, 83 | actor_observers=observers, 84 | learner_logger_fn=learner_logger_function, 85 | evaluator_observers=observers, 86 | evaluator_factories=[ 87 | evaluator_factory_logger_choice( 88 | environment_factory=make_discomaze_environment, 89 | network_factory=net_factory, 90 | policy_factory=lambda networks: make_policy_networks( 91 | networks, config, evaluation=True), 92 | logger_fn=evaluator_logger_fn, 93 | observers=observers 94 | ) 95 | ], 96 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb 97 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1]) 98 | 99 | return program 100 | 101 | 102 | def main(_): 103 | program = make_program() 104 | 105 | if FLAGS.run_on_vertex: 106 | resources = get_toy_env_vertex_resources() 107 | lp.launch( 108 | program, 109 | launch_type=lp.LaunchType.VERTEX_AI, 110 | xm_resources=resources) 111 | else: 112 | resources = get_local_resources() 113 | lp.launch( 114 | program, 115 | lp.LaunchType.LOCAL_MULTI_PROCESSING, 116 | local_resources=resources 117 | ) 118 | 119 | 120 | if __name__ == '__main__': 121 | app.run(main) 122 | -------------------------------------------------------------------------------- /examples/distrun_lunar_lander.py: -------------------------------------------------------------------------------- 1 | """Example running distributed DRLearner agent, on Lunar Lander.""" 2 | 3 | import functools 4 | import logging 5 | import os 6 | 7 | import acme 8 | from acme.jax import utils 9 | import launchpad as lp 10 | from absl import app 11 | from absl import flags 12 | from acme import specs 13 | 14 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo 15 | from drlearner.core.observers import IntrinsicRewardObserver, MetaControllerObserver, DistillationCoefObserver 16 | from drlearner.configs.config_lunar_lander import LunarLanderDRLearnerConfig 17 | from drlearner.configs.resources import get_toy_env_vertex_resources, get_local_resources 18 | from drlearner.environments.lunar_lander import make_ll_environment 19 | from drlearner.drlearner.networks import 
make_policy_networks 20 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_wandb_logger 21 | 22 | flags.DEFINE_string('level', 'LunarLander-v2', 'Which game to play.') 23 | flags.DEFINE_integer('num_episodes', 200, 'Number of episodes to train for.') 24 | flags.DEFINE_integer('num_steps', 10000, 'Number of steps to train for.') 25 | flags.DEFINE_string('exp_path', 'experiments/default', 26 | 'Experiment data storage.') 27 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.') 28 | 29 | flags.DEFINE_integer('num_actors_per_mixture', 1, 30 | 'Number of parallel actors per mixture.') 31 | 32 | flags.DEFINE_integer('seed', 42, 'Random seed.') 33 | flags.DEFINE_bool('run_on_vertex', False, 34 | 'Whether to run training in multiple processes or on Vertex AI.') 35 | flags.DEFINE_bool('colocate_learner_and_reverb', False, 36 | 'Flag indicating whether to colocate learner and reverb.') 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | def make_program(): 42 | config = LunarLanderDRLearnerConfig 43 | 44 | config_dir = os.path.join( 45 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1]) 46 | if not os.path.exists(config_dir): 47 | os.makedirs(config_dir) 48 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f: 49 | f.write(str(config)) 50 | 51 | print(config) 52 | 53 | env = make_ll_environment(0) 54 | env_spec = acme.make_environment_spec(env) 55 | 56 | def net_factory(env_spec: specs.EnvironmentSpec): 57 | return networks_zoo.make_lunar_lander_nets(config, env_spec) 58 | 59 | def env_factory(seed: int): 60 | return make_ll_environment(seed) 61 | 62 | observers = [ 63 | IntrinsicRewardObserver(), 64 | MetaControllerObserver(), 65 | DistillationCoefObserver(), 66 | ] 67 | 68 | evaluator_logger_fn = functools.partial(make_wandb_logger, FLAGS.exp_path, 69 | 'evaluator', save_data=True, 70 | time_delta=1, asynchronous=True, 71 | serialize_fn=utils.fetch_devicearray, 72 | print_fn=logging.info, 73 | steps_key='evaluator_steps', 74 | hyperparams=config, 75 | exp_name=FLAGS.exp_name) 76 | 77 | learner_logger_function = functools.partial(make_wandb_logger, FLAGS.exp_path, 78 | 'learner', save_data=True, 79 | time_delta=1, asynchronous=True, 80 | serialize_fn=utils.fetch_devicearray, 81 | print_fn=logging.info, 82 | steps_key='learner_steps', 83 | hyperparams=config, 84 | exp_name=FLAGS.exp_name) 85 | 86 | program = DistributedDRLearnerFromConfig( 87 | seed=FLAGS.seed, 88 | environment_factory=env_factory, 89 | network_factory=net_factory, 90 | config=config, 91 | workdir=FLAGS.exp_path, 92 | num_actors_per_mixture=FLAGS.num_actors_per_mixture, 93 | max_episodes=FLAGS.num_episodes, 94 | max_steps=FLAGS.num_steps, 95 | environment_spec=env_spec, 96 | actor_observers=observers, 97 | learner_logger_fn=learner_logger_function, 98 | evaluator_observers=observers, 99 | evaluator_factories=[ 100 | evaluator_factory_logger_choice( 101 | environment_factory=make_ll_environment, 102 | network_factory=net_factory, 103 | policy_factory=lambda networks: make_policy_networks( 104 | networks, config, evaluation=True), 105 | logger_fn=evaluator_logger_fn, 106 | observers=observers 107 | ) 108 | ], 109 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb 110 | 111 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1]) 112 | 113 | return program 114 | 115 | 116 | def main(_): 117 | program = make_program() 118 | 119 | if FLAGS.run_on_vertex: 120 | resources = get_toy_env_vertex_resources() 121 | lp.launch( 122 | program, 123 | 
launch_type=lp.LaunchType.VERTEX_AI, 124 | xm_resources=resources, 125 | terminal='tmux_session') 126 | else: 127 | resources = get_local_resources() 128 | lp.launch( 129 | program, 130 | lp.LaunchType.LOCAL_MULTI_PROCESSING, 131 | local_resources=resources, 132 | terminal='tmux_session' 133 | ) 134 | 135 | 136 | if __name__ == '__main__': 137 | app.run(main) 138 | -------------------------------------------------------------------------------- /examples/play_atari.py: -------------------------------------------------------------------------------- 1 | import acme 2 | 3 | from absl import flags 4 | from absl import app 5 | 6 | from drlearner.drlearner import DRLearner, networks_zoo 7 | from drlearner.core.environment_loop import EnvironmentLoop 8 | from drlearner.environments.atari import make_environment 9 | from drlearner.configs.config_atari import AtariDRLearnerConfig 10 | from drlearner.utils.utils import make_wandb_logger 11 | from drlearner.core.observers import StorageVideoObserver 12 | 13 | flags.DEFINE_string('level', 'ALE/MontezumaRevenge-v5', 'Which game to play.') 14 | flags.DEFINE_integer('seed', 11, 'Random seed.') 15 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.') 16 | flags.DEFINE_string('exp_path', 'experiments/play1', 'Run name.') 17 | flags.DEFINE_string('exp_name', 'atari play', 'Run name.') 18 | flags.DEFINE_string( 19 | 'checkpoint_path', 'experiments/mon_24cores1', 'Path to checkpoints/ dir') 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | # TODo: add possibility to freeze mixture index for final evaluation 24 | 25 | 26 | def load_and_evaluate(_): 27 | config = AtariDRLearnerConfig 28 | config.batch_size = 1 29 | config.num_mixtures = 32 30 | config.beta_max = 0. # if num_mixtures == 1 beta == beta_max 31 | config.n_arms = 32 32 | config.logs_dir = FLAGS.exp_path 33 | config.video_log_period = 1 34 | config.env_library = 'gym' 35 | config.actor_epsilon = 0.01 36 | config.epsilon = 0.01 37 | config.mc_epsilon = 0.01 38 | 39 | env = make_environment(FLAGS.level, oar_wrapper=True) 40 | env_spec = acme.make_environment_spec(env) 41 | 42 | agent = DRLearner( 43 | env_spec, 44 | networks=networks_zoo.make_atari_nets(config, env_spec), 45 | config=config, 46 | seed=FLAGS.seed, 47 | workdir=FLAGS.checkpoint_path 48 | ) 49 | 50 | observers = [StorageVideoObserver(config)] 51 | logger = make_wandb_logger( 52 | FLAGS.exp_path, label='evaluator', hyperparams=config, exp_name=FLAGS.exp_name) 53 | 54 | loop = EnvironmentLoop(env, agent, logger=logger, 55 | observers=observers, should_update=False) 56 | loop.run(FLAGS.num_episodes) 57 | 58 | 59 | if __name__ == '__main__': 60 | app.run(load_and_evaluate) 61 | -------------------------------------------------------------------------------- /examples/run_atari.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example running Local Layout DRLearner agent, on Atari-like environments. 3 | 4 | This module contains the main function to run the Atari environment using a Deep Reinforcement Learning (DRL) agent. 5 | 6 | Imports: 7 | os: Provides a way of using operating system dependent functionality. 8 | flags: Command line flag module. 9 | AtariDRLearnerConfig: Configuration for the DRL agent. 10 | make_environment: Function to create an Atari environment. 11 | acme: DeepMind's library of reinforcement learning components. 12 | networks_zoo: Contains the network architectures for the DRL agent. 13 | DRLearner: The DRL agent. 
14 |     make_wandb_logger: Function to create a Weights & Biases logger.
15 |     EnvironmentLoop: Acme's main loop for running environments.
16 | 
17 | Functions:
18 |     main(_):
19 |         The main function to run the Atari environment.
20 | 
21 |         It sets up the environment, the DRL agent, and the logger, and then runs the environment loop for a specified number of episodes.
22 | """
23 | import os
24 | 
25 | import acme
26 | from absl import app
27 | from absl import flags
28 | 
29 | from drlearner.drlearner import networks_zoo, DRLearner
30 | from drlearner.configs.config_atari import AtariDRLearnerConfig
31 | from drlearner.core.environment_loop import EnvironmentLoop
32 | from drlearner.environments.atari import make_environment
33 | from drlearner.core.observers import IntrinsicRewardObserver, DistillationCoefObserver
34 | from drlearner.utils.utils import make_wandb_logger
35 | 
36 | # Command line flags
37 | flags.DEFINE_string('level', 'PongNoFrameskip-v4', 'Which game to play.')
38 | flags.DEFINE_integer('num_episodes', 7, 'Number of episodes to train for.')
39 | flags.DEFINE_string('exp_path', 'experiments/default',
40 |                     'Experiment data storage.')
41 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
42 | flags.DEFINE_integer('seed', 0, 'Random seed.')
43 | 
44 | flags.DEFINE_bool('force_sync_run', False, 'Skip deadlock warning.')
45 | 
46 | FLAGS = flags.FLAGS
47 | 
48 | 
49 | def main(_):
50 |     # Configuration for the DRL agent hyperparameters
51 |     config = AtariDRLearnerConfig
52 |     # To avoid a deadlock when running Reverb in the synchronous set-up,
53 |     # samples_per_insert must be 0 so that the rate limiter is never invoked.
54 |     # @see https://github.com/google-deepmind/acme/issues/207 for additional information.
55 |     if config.samples_per_insert != 0:
56 |         if not FLAGS.force_sync_run:
57 |             while True:
58 |                 user_answer = input("\nThe simulation may deadlock if run in the synchronous set-up with samples_per_insert != 0. "
59 |                                     "Do you want to continue? (yes/no): ")
60 | 
61 |                 if user_answer.lower() in ["yes", "y"]:
62 |                     print("Proceeding...")
63 |                     break
64 |                 elif user_answer.lower() in ["no", "n"]:
65 |                     print("Exiting...")
66 |                     return
67 |                 else:
68 |                     print("Invalid input. 
Please enter yes/no.") 69 | 70 | print(config) 71 | if not os.path.exists(FLAGS.exp_path): 72 | os.makedirs(FLAGS.exp_path) 73 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f: 74 | f.write(str(config)) 75 | 76 | # Create the Atari environment 77 | env = make_environment(FLAGS.level, oar_wrapper=True) 78 | # Create the environment specification 79 | env_spec = acme.make_environment_spec(env) 80 | 81 | # Create the networks for the DRL agent learning algorithm 82 | networks = networks_zoo.make_atari_nets(config, env_spec) 83 | 84 | # Create a Weights & Biases loggers for the environment and the actor 85 | logger_env = make_wandb_logger( 86 | FLAGS.exp_path, label='enviroment', hyperparams=config, exp_name=FLAGS.exp_name) 87 | logger_actor = make_wandb_logger( 88 | FLAGS.exp_path, label='actor', hyperparams=config, exp_name=FLAGS.exp_name) 89 | 90 | # Create the DRL agent 91 | agent = DRLearner( 92 | spec=env_spec, 93 | networks=networks, 94 | config=config, 95 | seed=FLAGS.seed, 96 | workdir=FLAGS.exp_path, 97 | logger=logger_actor 98 | ) 99 | # Create the observers for the DRL agent 100 | # TODO: Add StorageVideoObserver 101 | observers = [IntrinsicRewardObserver(), DistillationCoefObserver()] 102 | 103 | # Create the environment loop 104 | loop = EnvironmentLoop( 105 | environment=env, 106 | actor=agent, 107 | logger=logger_env, 108 | observers=observers 109 | ) 110 | # Run the environment loop for a specified number of episodes 111 | loop.run(FLAGS.num_episodes) 112 | 113 | 114 | if __name__ == '__main__': 115 | app.run(main) 116 | -------------------------------------------------------------------------------- /examples/run_discomaze.py: -------------------------------------------------------------------------------- 1 | """Example running Local Layout DRLearner on DiscoMaze environment.""" 2 | 3 | import os 4 | 5 | import acme 6 | from absl import app 7 | from absl import flags 8 | 9 | from drlearner.drlearner import networks_zoo, DRLearner 10 | from drlearner.configs.config_discomaze import DiscomazeDRLearnerConfig 11 | from drlearner.core.environment_loop import EnvironmentLoop 12 | from drlearner.core.observers import UniqueStatesDiscoMazeObserver, IntrinsicRewardObserver, ActionProbObserver 13 | from drlearner.environments.disco_maze import make_discomaze_environment 14 | from drlearner.utils.utils import make_tf_logger 15 | 16 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.') 17 | flags.DEFINE_string('exp_path', 'experiments/default', 18 | 'Experiment data storage.') 19 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.') 20 | flags.DEFINE_integer('seed', 0, 'Random seed.') 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | 25 | def main(_): 26 | config = DiscomazeDRLearnerConfig 27 | print(config) 28 | if not os.path.exists(FLAGS.exp_path): 29 | os.makedirs(FLAGS.exp_path) 30 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f: 31 | f.write(str(config)) 32 | 33 | env = make_discomaze_environment(FLAGS.seed) 34 | env_spec = acme.make_environment_spec(env) 35 | 36 | networks = networks_zoo.make_discomaze_nets(config, env_spec) 37 | 38 | agent = DRLearner( 39 | env_spec, 40 | networks=networks, 41 | config=config, 42 | seed=FLAGS.seed) 43 | 44 | logger = make_tf_logger(FLAGS.exp_path) 45 | 46 | observers = [ 47 | UniqueStatesDiscoMazeObserver(), 48 | IntrinsicRewardObserver(), 49 | ActionProbObserver(num_actions=env_spec.actions.num_values), 50 | ] 51 | loop = EnvironmentLoop(env, agent, logger=logger, observers=observers) 
52 | loop.run(FLAGS.num_episodes) 53 | 54 | 55 | if __name__ == '__main__': 56 | app.run(main) 57 | -------------------------------------------------------------------------------- /examples/run_lunar_lander.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example running Local Layout DRLearner on Lunar Lander environment. 3 | 4 | This module contains the main function to run the Lunar Lander environment using a Deep Reinforcement Learning (DRL) agent. 5 | 6 | Imports: 7 | os: Provides a way of using operating system dependent functionality. 8 | flags: Command line flag module. 9 | LunarLanderDRLearnerConfig: Configuration for the DRL agent. 10 | make_ll_environment: Function to create a Lunar Lander environment. 11 | acme: DeepMind's library of reinforcement learning components. 12 | networks_zoo: Contains the network architectures for the DRL agent. 13 | DRLearner: The DRL agent. 14 | IntrinsicRewardObserver, DistillationCoefObserver: Observers for the DRL agent. 15 | make_wandb_logger: Function to create a Weights & Biases logger. 16 | EnvironmentLoop: Acme's main loop for running environments. 17 | 18 | Functions: 19 | main(_): 20 | The main function to run the Lunar Lander environment. 21 | 22 | It sets up the environment, the DRL agent, the observers, and the logger, and then runs the environment loop for a specified number of episodes. 23 | """ 24 | import os 25 | 26 | import acme 27 | from absl import app 28 | from absl import flags 29 | 30 | from drlearner.drlearner import networks_zoo, DRLearner 31 | from drlearner.configs.config_lunar_lander import LunarLanderDRLearnerConfig 32 | from drlearner.core.environment_loop import EnvironmentLoop 33 | from drlearner.environments.lunar_lander import make_ll_environment 34 | from drlearner.core.observers import IntrinsicRewardObserver, DistillationCoefObserver, StorageVideoObserver 35 | from drlearner.utils.utils import make_wandb_logger 36 | 37 | 38 | # Command line flags 39 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.') 40 | flags.DEFINE_string('exp_path', 'experiments/default', 41 | 'Experiment data storage.') 42 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.') 43 | flags.DEFINE_integer('seed', 42, 'Random seed.') 44 | 45 | FLAGS = flags.FLAGS 46 | 47 | 48 | def main(_): 49 | # Configuration for the DRL agent hyperparameters 50 | config = LunarLanderDRLearnerConfig 51 | 52 | print(config) 53 | if not os.path.exists(FLAGS.exp_path): 54 | os.makedirs(FLAGS.exp_path) 55 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f: 56 | f.write(str(config)) 57 | 58 | # Create a Weights & Biases loggers for the environment and the actor 59 | logger_env = make_wandb_logger( 60 | FLAGS.exp_path, label='enviroment', hyperparams=config, exp_name=FLAGS.exp_name) 61 | logger_actor = make_wandb_logger( 62 | FLAGS.exp_path, label='actor', hyperparams=config, exp_name=FLAGS.exp_name) 63 | 64 | # Create the Lunar Lander environment 65 | env = make_ll_environment(FLAGS.seed) 66 | # Create the environment specification 67 | env_spec = acme.make_environment_spec(env) 68 | 69 | # Create the networks for the DRL agent learning algorithm 70 | networks = networks_zoo.make_lunar_lander_nets(config, env_spec) 71 | 72 | # Create the DRL agent 73 | agent = DRLearner( 74 | spec=env_spec, 75 | networks=networks, 76 | config=config, 77 | seed=FLAGS.seed, 78 | workdir=FLAGS.exp_path, 79 | logger=logger_actor 80 | ) 81 | # Create the observers for the DRL agent 82 | 
observers = [IntrinsicRewardObserver(), DistillationCoefObserver(), 83 | StorageVideoObserver(config)] 84 | 85 | # Create the environment loop 86 | loop = EnvironmentLoop( 87 | environment=env, 88 | actor=agent, 89 | logger=logger_env, 90 | observers=observers 91 | ) 92 | # Run the environment loop for a specified number of episodes 93 | loop.run(FLAGS.num_episodes) 94 | 95 | 96 | if __name__ == '__main__': 97 | app.run(main) 98 | -------------------------------------------------------------------------------- /external/xm_docker.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 DeepMind Technologies Limited. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities to run PyNodes in Docker containers using XManager.""" 17 | 18 | import atexit 19 | import copy 20 | import dataclasses 21 | from distutils import dir_util 22 | import functools 23 | import os 24 | import pathlib 25 | import shutil 26 | import sys 27 | import tempfile 28 | from typing import Any, List, Optional, Sequence, Tuple 29 | 30 | import cloudpickle 31 | from launchpad.launch import serialization 32 | 33 | try: 34 | from xmanager import xm 35 | except ModuleNotFoundError: 36 | raise Exception('Launchpad requires `xmanager` for XM-based runtimes.' 37 | 'Please run `pip install xmanager`.') 38 | 39 | 40 | _DATA_FILE_NAME = 'job.pkl' 41 | _INIT_FILE_NAME = 'init.pkl' 42 | 43 | 44 | @dataclasses.dataclass 45 | class DockerConfig: 46 | """Local docker launch configuration. 47 | 48 | Attributes: 49 | code_directory: Path to directory containing any user code that may be 50 | required inside the Docker image. The user code from this directory is 51 | copied over into the Docker containers, as the user code may be needed 52 | during program execution. If needed, modify docker_instructions in 53 | xm.PythonContainer construction below if user code needs installation. 54 | docker_requirements: Path to requirements.txt specifying Python packages to 55 | install inside the Docker image. 56 | hw_requirements: Hardware requirements. 57 | python_path: Additional paths to be added to PYTHONPATH prior to executing 58 | an entry point. 
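
  Example (a sketch with illustrative values; the concrete paths are
  placeholders, not defaults):

    docker_config = DockerConfig(
        code_directory='.',
        docker_requirements='requirements.txt',
        hw_requirements=xm.JobRequirements(),
        python_path=['.'])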
59 | """ 60 | code_directory: Optional[str] = None 61 | docker_requirements: Optional[str] = None 62 | hw_requirements: Optional[xm.JobRequirements] = None 63 | python_path: Optional[List[str]] = None 64 | 65 | 66 | def initializer(python_path): 67 | sys.path = python_path + sys.path 68 | 69 | 70 | def to_docker_executables( 71 | nodes: Sequence[Any], 72 | label: str, 73 | docker_config: DockerConfig, 74 | ) -> List[Tuple[xm.PythonContainer, xm.JobRequirements]]: 75 | 76 | """Returns a list of `PythonContainer`s objects for the given `PyNode`s.""" 77 | 78 | if docker_config.code_directory is None or docker_config.docker_requirements is None: 79 | raise ValueError( 80 | 'code_directory and docker_requirements must be specified through' 81 | 'DockerConfig via local_resources when using "xm_docker" launch type.') 82 | 83 | # Generate tmp dir without '_' in the name, Vertex AI fails otherwise. 84 | tmp_dir = '_' 85 | while '_' in tmp_dir: 86 | tmp_dir = tempfile.mkdtemp() 87 | atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True) 88 | 89 | command_line = f'python -m my_process_entry --data_file={_DATA_FILE_NAME}' 90 | 91 | # Add common initialization function for all nodes which sets up PYTHONPATH. 92 | if docker_config.python_path: 93 | command_line += f' --init_file={_INIT_FILE_NAME}' 94 | # Local 'path' is copied under 'tmp_dir' (no /tmp prefix) inside Docker. 95 | python_path = [ 96 | '/' + os.path.basename(tmp_dir) + os.path.abspath(path) 97 | for path in docker_config.python_path 98 | ] 99 | initializer_file_path = pathlib.Path(tmp_dir, _INIT_FILE_NAME) 100 | with open(initializer_file_path, 'wb') as f: 101 | cloudpickle.dump(functools.partial(initializer, python_path), f) 102 | 103 | data_file_path = str(pathlib.Path(tmp_dir, _DATA_FILE_NAME)) 104 | serialization.serialize_functions(data_file_path, label, 105 | [n.function for n in nodes]) 106 | 107 | file_path = pathlib.Path(__file__).absolute() 108 | 109 | # shutil.copy(pathlib.Path(file_path.parent, 'process_entry.py'), tmp_dir) 110 | dir_util.copy_tree(docker_config.code_directory, tmp_dir) 111 | shutil.copy(docker_config.docker_requirements, 112 | pathlib.Path(tmp_dir, 'requirements.txt')) 113 | 114 | workdir_path = pathlib.Path(tmp_dir).name 115 | 116 | if not os.path.exists(docker_config.docker_requirements): 117 | raise FileNotFoundError('Please specify a path to a file with Python' 118 | 'package requirements through' 119 | 'docker_config.docker_requirements.') 120 | job_requirements = docker_config.hw_requirements 121 | if not job_requirements: 122 | job_requirements = xm.JobRequirements() 123 | 124 | # Make a copy of requirements since they are being mutated below. 125 | job_requirements = copy.deepcopy(job_requirements) 126 | 127 | if job_requirements.replicas != 1: 128 | raise ValueError( 129 | 'Number of replicas is computed by the runtime. ' 130 | 'Please do not set it explicitly in the requirements.' 
131 | ) 132 | 133 | job_requirements.replicas = len(nodes) 134 | # python_version = f'{sys.version_info.major}.{sys.version_info.minor}' 135 | python_version = 3.7 136 | 137 | # if label == 'workerpool2': 138 | command_lst = [ 139 | # 'python -m pip install --upgrade jax==0.3.7 jaxlib==0.3.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html', 140 | 'python -m pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', 141 | 'python -m pip install --upgrade flax==0.4.1', 142 | command_line 143 | ] 144 | base_image = 'gcr.io/deeplearning-platform-release/tf-gpu.2-8' 145 | return [(xm.PythonContainer( 146 | path=tmp_dir, 147 | base_image=base_image, 148 | entrypoint=xm.CommandList(command_lst), 149 | docker_instructions=[ 150 | 'ENV XLA_PYTHON_CLIENT_MEM_FRACTION=0.8', 151 | 'ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/lib:/usr/lib:/usr/local/lib:/opt/conda/lib', 152 | 'ENV PYTHONPATH=$PYTHONPATH:$(pwd)', 153 | 'RUN apt-get install -y git', 154 | f'RUN apt-get -y install libpython{python_version}', 155 | f'COPY {workdir_path}/ {workdir_path}', 156 | 157 | f'COPY {workdir_path}/requirements.txt requirements.txt', 158 | 'RUN python -m pip install xmanager', 159 | 'RUN python -m pip install --no-cache-dir -r requirements.txt ', 160 | 'RUN python -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable', 161 | 162 | f'RUN ale-import-roms {workdir_path}/roms/', 163 | f'WORKDIR {workdir_path}', 164 | ]), job_requirements)] 165 | 166 | # else: 167 | # base_image = f'python:{python_version}' 168 | # 169 | # 170 | # return [(xm.PythonContainer( 171 | # path=tmp_dir, 172 | # base_image=base_image, 173 | # entrypoint=xm.CommandList([command_line]), 174 | # docker_instructions=[ 175 | # 'ENV LD_LIBRARY_PATH=/lib:/usr/lib:/usr/local/lib:/usr/local/nvidia/lib64:/usr/local/cuda-11.0/targets/x86_64-linux/lib:/opt/conda/lib:/usr/local/cuda-11.0/targets/x86_64-linux/lib/stubs/', 176 | # 'RUN apt-get install -y git', 177 | # f'RUN apt-get -y install libpython{python_version}', 178 | # f'COPY {workdir_path}/requirements.txt requirements.txt', 179 | # 'RUN python -m pip install xmanager', 180 | # 'RUN python -m pip install --no-cache-dir -r requirements.txt ', 181 | # f'COPY {workdir_path}/ {workdir_path}', 182 | # f'RUN ale-import-roms {workdir_path}/roms/', 183 | # f'WORKDIR {workdir_path}', 184 | # ]), job_requirements)] 185 | 186 | -------------------------------------------------------------------------------- /my_process_entry.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2020 DeepMind Technologies Limited. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
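#
# This entry point is what external/xm_docker.py bakes into each Docker image:
# it loads the pickled node functions from --data_file (optionally running the
# --init_file initializer first) and executes the function selected by the task
# id, taken either from --lp_task_id or from the Vertex AI CLUSTER_SPEC.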
15 | 16 | """Entry of a PythonNode worker.""" 17 | 18 | 19 | import contextlib 20 | import json 21 | import os 22 | import sys 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | import cloudpickle 28 | from launchpad.launch import worker_manager 29 | import six 30 | 31 | import tensorflow as tf 32 | 33 | tf.config.set_visible_devices([], 'GPU') 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_integer( 39 | 'lp_task_id', None, 'a list index deciding which ' 40 | 'worker to run. given a list of workers (obtained from the' 41 | ' data_file)') 42 | flags.DEFINE_string('data_file', '', 43 | 'Pickle file location with entry points for all nodes') 44 | flags.DEFINE_string( 45 | 'lp_job_name', '', 46 | 'The name of the job, used to access the correct pickle file resource when ' 47 | 'using the new launch API') 48 | flags.DEFINE_string( 49 | 'init_file', '', 'Pickle file location containing initialization module ' 50 | 'executed for each node prior to an entry point') 51 | flags.DEFINE_string('flags_to_populate', '{}', '') 52 | 53 | _FLAG_TYPE_MAPPING = { 54 | str: flags.DEFINE_string, 55 | six.text_type: flags.DEFINE_string, 56 | float: flags.DEFINE_float, 57 | int: flags.DEFINE_integer, 58 | bool: flags.DEFINE_boolean, 59 | list: flags.DEFINE_list, 60 | } 61 | 62 | 63 | def _populate_flags(): 64 | """Populate flags that cannot be passed directly to this script.""" 65 | FLAGS(sys.argv, known_only=True) 66 | 67 | flags_to_populate = json.loads(FLAGS.flags_to_populate) 68 | for name, value in flags_to_populate.items(): 69 | value_type = type(value) 70 | if value_type in _FLAG_TYPE_MAPPING: 71 | flag_ctr = _FLAG_TYPE_MAPPING[value_type] 72 | logging.info('Defining flag %s with default value %s', name, value) 73 | flag_ctr( 74 | name, 75 | value, 76 | 'This flag has been auto-generated.', 77 | allow_override=True) 78 | 79 | # JAX doesn't use absl flags and so we need to forward absl flags to JAX 80 | # explicitly. Here's a heuristic to detect JAX flags and forward them. 81 | for arg in sys.argv: 82 | if arg.startswith('--jax_'): 83 | try: 84 | # pytype:disable=import-error 85 | import jax 86 | # pytype:enable=import-error 87 | jax.config.parse_flags_with_absl() 88 | break 89 | except ImportError: 90 | pass 91 | 92 | 93 | def _get_task_id(): 94 | """Returns current task's id.""" 95 | if FLAGS.lp_task_id is None: 96 | # Running under Vertex AI... 97 | cluster_spec = os.environ.get('CLUSTER_SPEC', None) 98 | return json.loads(cluster_spec).get('task').get('index') 99 | 100 | return FLAGS.lp_task_id 101 | 102 | 103 | def main(_): 104 | # Allow for importing modules from the current directory. 105 | sys.path.append(os.getcwd()) 106 | data_file = FLAGS.data_file 107 | init_file = FLAGS.init_file 108 | 109 | if os.environ.get('TF_CONFIG', None): 110 | # For GCP runtime log to STDOUT so that logs are not reported as errors. 111 | logging.get_absl_handler().python_handler.stream = sys.stdout 112 | 113 | if init_file: 114 | init_function = cloudpickle.load(open(init_file, 'rb')) 115 | init_function() 116 | functions = cloudpickle.load(open(data_file, 'rb')) 117 | task_id = _get_task_id() 118 | 119 | # Worker manager is used here to handle termination signals and provide 120 | # preemption support. 
121 | worker_manager.WorkerManager( 122 | register_in_thread=True) 123 | 124 | with contextlib.suppress(): # no-op context manager 125 | functions[task_id]() 126 | 127 | 128 | if __name__ == '__main__': 129 | _populate_flags() 130 | app.run(main) 131 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | numpy==1.22.4 3 | cloudpickle==2.0.0 4 | six==1.16.0 5 | libpython==0.2 6 | chex==0.1.5 7 | Cython==0.29.28 8 | flax==0.4.1 9 | optax==0.1.2 10 | rlax==0.1.4 11 | pyglet==1.5.24 12 | xmanager==0.1.5 13 | pyvirtualdisplay==3.0 14 | sk-video==1.1.10 15 | ffmpeg-python==0.2.0 16 | wandb==0.16.2 17 | tensorrt 18 | tensorflow-gpu==2.8.0 19 | tensorflow_probability==0.15.0 20 | tensorflow_datasets==4.6.0 21 | dm-reverb==0.7.2 22 | dm-launchpad==0.5.2 23 | jax==0.4.3 24 | jaxlib==0.4.3+cuda11.cudnn86 25 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 26 | dm-haiku==0.0.10 27 | dm-sonnet 28 | trfl 29 | atari-py 30 | bsuite 31 | dm-control 32 | gym==0.25.0 33 | gym[accept-rom-license, atari, Box2D] 34 | pygame==2.1.0 35 | rlds 36 | git+https://github.com/google-deepmind/acme.git@4c6351ef8ff3f4045a9a24bee6a994667d89c69c 37 | scipy==1.12.0 -------------------------------------------------------------------------------- /scripts/update_tb.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from google.cloud import storage 3 | 4 | BUCKET_NAME = os.environ['GOOGLE_CLOUD_BUCKET_NAME'] 5 | 6 | 7 | def gcs_download(prefix, save_to, fname=None): 8 | """ prefix - experiment name, i.e. test_pong/ 9 | save_to - path to save the downloaded files to 10 | """ 11 | if not os.path.isdir(os.path.join(save_to, prefix)): 12 | os.mkdir(os.path.join(save_to, prefix)) 13 | storage_client = storage.Client() 14 | bucket = storage_client.bucket(BUCKET_NAME) 15 | 16 | blobs = storage_client.list_blobs(BUCKET_NAME, prefix=prefix, delimiter='/') 17 | for b in blobs: 18 | if not fname or fname in b.name: 19 | blob = bucket.blob(b.name) 20 | print(os.path.join(save_to, b.name)) 21 | blob.download_to_filename(os.path.join(save_to, b.name)) 22 | 23 | 24 | if __name__ == '__main__': 25 | exp_path = sys.argv[1] 26 | save_to = sys.argv[2] 27 | gcs_download(exp_path, save_to, 'tfevents') 28 | --------------------------------------------------------------------------------