├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── artifacts
└── montezuma_base
│ ├── 20240425-115850
│ └── logs
│ │ └── evaluator
│ │ └── logs.csv
│ ├── 20240425-115853
│ └── logs
│ │ └── learner
│ │ └── logs.csv
│ ├── 20240506-133311
│ └── logs
│ │ └── evaluator
│ │ └── logs.csv
│ ├── 20240506-133314
│ └── logs
│ │ └── learner
│ │ └── logs.csv
│ ├── 20240506-135934
│ └── logs
│ │ └── evaluator
│ │ └── logs.csv
│ ├── 20240506-135937
│ └── logs
│ │ └── learner
│ │ └── logs.csv
│ └── config.txt
├── compose.yaml
├── docker-configurations
├── local-docker-images.md
├── python3.10
│ └── Dockerfile
└── python3.7
│ ├── Dockerfile
│ ├── compose.yaml
│ ├── post-install.sh
│ └── requirements.txt
├── docs
├── DRLearner_notes.md
├── atari_pong.md
├── aws-setup.md
├── debug_and_monitor.md
├── docker.md
├── img
│ ├── lunar_lander.png
│ ├── notebook-instance.png
│ ├── tensorboard.png
│ └── wandb.png
├── unity.md
└── vertexai.md
├── drlearner
├── __init__.py
├── configs
│ ├── config_atari.py
│ ├── config_discomaze.py
│ ├── config_lunar_lander.py
│ └── resources
│ │ ├── __init__.py
│ │ ├── atari.py
│ │ ├── local_resources.py
│ │ └── toy_env.py
├── core
│ ├── __init__.py
│ ├── distributed_layout.py
│ ├── environment_loop.py
│ ├── local_layout.py
│ ├── loggers
│ │ ├── __init__.py
│ │ └── image.py
│ └── observers
│ │ ├── __init__.py
│ │ ├── action_dist.py
│ │ ├── actions.py
│ │ ├── discomaze_unique_states.py
│ │ ├── distillation_coef.py
│ │ ├── intrinsic_reward.py
│ │ ├── lazy_dict.py
│ │ ├── meta_controller.py
│ │ └── video.py
├── drlearner
│ ├── __init__.py
│ ├── actor.py
│ ├── actor_core.py
│ ├── agent.py
│ ├── builder.py
│ ├── config.py
│ ├── distributed_agent.py
│ ├── drlearner_types.py
│ ├── learning.py
│ ├── lifelong_curiosity.py
│ ├── networks
│ │ ├── __init__.py
│ │ ├── distillation_network.py
│ │ ├── embedding_network.py
│ │ ├── networks.py
│ │ ├── networks_zoo
│ │ │ ├── __init__.py
│ │ │ ├── atari.py
│ │ │ ├── discomaze.py
│ │ │ └── lunar_lander.py
│ │ ├── policy_networks.py
│ │ ├── uvfa_network.py
│ │ └── uvfa_torso.py
│ └── utils.py
├── environments
│ ├── __init__.py
│ ├── atari.py
│ ├── disco_maze.py
│ └── lunar_lander.py
└── utils
│ ├── __init__.py
│ ├── stats.py
│ └── utils.py
├── examples
├── distrun_atari.py
├── distrun_discomaze.py
├── distrun_lunar_lander.py
├── play_atari.py
├── run_atari.py
├── run_discomaze.py
└── run_lunar_lander.py
├── external
├── vertex.py
└── xm_docker.py
├── my_process_entry.py
├── requirements.txt
└── scripts
└── update_tb.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | **/__pycache__/
3 |
4 | venv/
5 |
6 | checkpoints/
7 | scratch/
8 | experiments/
9 |
10 | *.json
11 | roms/
12 | wandb/
13 | .env
14 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 | ## Basic dependencies.
3 | ADD . /app
4 |
5 | WORKDIR /app
6 |
7 | ### Installing dependencies.
8 | RUN apt-get update \
9 | && apt-get install -y --no-install-recommends \
10 | build-essential \
11 | curl \
12 | wget \
13 | xvfb \
14 | ffmpeg \
15 | xorg-dev \
16 | libsdl2-dev \
17 | swig \
18 | cmake \
19 | unar \
20 | libpython3.10 \
21 | tmux
22 |
23 | # Conda environment
24 |
25 | ENV CONDA_DIR /opt/conda
26 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
27 | /bin/bash ~/miniconda.sh -b -p /opt/conda
28 |
29 | # # Put conda in path so we can use conda activate
30 | ENV PATH=$CONDA_DIR/bin:$PATH
31 |
32 | # SHELL ["/bin/bash", "-c"]
33 | RUN conda create --name drlearner python=3.10 -y
34 | RUN python --version
35 | RUN echo "source activate drlearner" > ~/.bashrc
36 | ENV PATH /opt/conda/envs/drlearner/bin:$PATH
37 |
38 | # RUN conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib
39 | # RUN conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd)
40 |
41 | # Install dependencies (some of them are old; CUDA support may need to be checked)
42 | RUN python3.10 -m pip install --upgrade pip
43 | RUN python3.10 -m pip install --no-cache-dir -r requirements.txt
44 | RUN python3.10 -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
45 | RUN conda install conda-forge::ffmpeg
46 |
47 | # RUN pip install git+https://github.com/google-deepmind/acme.git@4c6351ef8ff3f4045a9a24bee6a994667d89c69c
48 |
49 |
50 | # RUN conda install -c conda-forge cudatoolkit=11.2.2 cudnn=8.1.0
51 |
52 | # Get binaries for Atari games
53 | RUN wget http://www.atarimania.com/roms/Roms.rar
54 | RUN unar Roms.rar
55 | RUN mv Roms roms
56 | RUN ale-import-roms roms/
57 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib
58 | ENV PYTHONPATH=$PYTHONPATH:$(pwd)
59 | ENV XLA_PYTHON_CLIENT_PREALLOCATE='0'
60 |
61 | # ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib
62 | RUN chmod -R 777 ./
63 |
64 |
65 | # CMD ["python3" ,"examples/run_lunar_lander.py"]
66 | CMD ["/bin/bash"]
67 |
68 | # CMD ["/bin/bash python3", "examples/run_atari.py --level PongNoFrameskip-v4 --num_episodes 1000 --exp_path experiments/test_pong/"]
69 |
70 | # sudo docker build -t drlearner:latest .
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DRLearner
3 | Open Source Deep Reinforcement Learning (DRL) library, based on Agent 57 (Badia et al, 2020).
4 | We recommend reading this documentation [page](docs/DRLearner_notes.md) to get the essence of DRLearner.
5 |
6 | # Table of contents
7 | - [DRLearner](#drlearner)
8 | - [Table of contents](#table-of-contents)
9 | - [System Requirements](#system-requirements)
10 | - [Installation](#installation)
11 | - [Running DRLearner Agent](#running-drlearner-agent)
12 | - [Documentation](#documentation)
13 | - [Ongoing Support](#ongoing-support)
14 |
15 |
16 | ## System Requirements
17 |
18 | Hardware and cloud infrastructure used for DRLearner testing are listed below. For more information on specific configurations for running experiments, see the GCP Hardware Specs and Running Experiments sections in [docs/vertexai.md](docs/vertexai.md).
19 |
20 | | Google Cloud Configuration | Local Configuration |
21 | | --- | --- |
22 | | (GCP) | (Local) |
23 | | Tested on Ubuntu 20.04 with Python 3.7 | Tested on Ubuntu 22.04 with Python 3.10 |
24 | | Hardware: NVIDIA Tesla, 500 GB drive | Hardware: 8-core i7 |
25 | 
26 | Depending on the exact OS and hardware, additional packages such as git, Python 3.7, Anaconda/Miniconda, or gcc may be required.
27 |
28 | ## Installation
29 |
30 | We recommend the [Docker-based](docs/docker.md) installation; however, to install from scratch, follow the instructions below:
31 |
32 |
33 | Clone the repo
34 | ```
35 | git clone https://github.com/PatternsandPredictions/DRLearner_beta.git
36 | cd DRLearner_beta/
37 | ```
38 |
39 | Install xvfb for virtual display
40 | ```
41 | sudo apt-get update
42 | sudo apt-get install xvfb
43 | ```
44 |
45 | ### Creating environment
46 |
47 | #### Conda
48 |
49 | Restarting the environment after creating and activating it is recommended to make sure that the environment variables are updated.
50 | ```
51 | sudo apt-get update
52 | sudo apt-get install libpython3.10 ffmpeg swig
53 | conda create --name drlearner python=3.10
54 | conda activate drlearner
55 |
56 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib
57 | export PYTHONPATH=$PYTHONPATH:$(pwd)
58 | conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib
59 | conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd)
60 | ```
61 |
62 | Install packages
63 | ```
64 | pip install --no-cache-dir -r requirements.txt
65 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
66 | ```
67 |
68 | #### Venv
69 | ```
70 | sudo apt-get update
71 | sudo apt-get install libpython3.10 swig ffmpeg -y
72 | python3.10 -m venv venv
73 | source venv/bin/activate
74 |
75 | export PYTHONPATH=$PYTHONPATH:$(pwd)
76 | ```
77 |
78 | Install packages
79 | ```
80 | pip install --no-cache-dir -r requirements.txt
81 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
82 | ```
83 |
84 | ### Binary files for Atari games
85 | ```
86 | sudo apt-get install unrar
87 | wget http://www.atarimania.com/roms/Roms.rar
88 | unrar e Roms.rar roms/
89 | ale-import-roms roms/
90 |
91 | ```
92 |
93 | ## Running DRLearner Agent
94 |
95 | DRLearner comes with the following available environments:
96 | - Lunar Lander:
97 | - [Config](drlearner/configs/config_lunar_lander.py)
98 | - [Synchronous Agent](examples/run_lunar_lander.py)
99 |     - [Asynchronous Agent](examples/distrun_lunar_lander.py)
100 | - Atari:
101 | - [Config](drlearner/configs/config_atari.py)
102 | - [Synchronous Agent](examples/run_atari.py)
103 |     - [Asynchronous Agent](examples/distrun_atari.py)
104 | - [Example](docs/atari_pong.md)
105 | - Disco Maze
106 | - [Config](drlearner/configs/config_discomaze.py)
107 | - [Synchronous Agent](examples/run_discomaze.py)
108 |     - [Asynchronous Agent](examples/distrun_discomaze.py)
109 |
110 | ### Lunar Lander example
111 |
112 | #### Training
113 | ```
114 | python ./examples/run_lunar_lander.py --num_episodes 1000 --exp_path experiments/test_pong/ --exp_name my_first_experiment
115 | ```
116 | Terminal output like the following means that training has been launched successfully:
117 |
118 | `[Enviroment] Mean Distillation Alpha = 1.000 | Action Mean Time = 0.027 | Env Step Mean Time = 0.000 | Episode Length = 63 | Episode Return = -453.10748291015625 | Episodes = 1 | Intrinsic Rewards Mean = 2.422 | Intrinsic Rewards Sum = 155.000 | Observe Mean Time = 0.014 | Steps = 63 | Steps Per Second = 15.544
119 | [Actor] Idm Accuracy = 0.12812499701976776 | Idm Loss = 1.4282478094100952 | Rnd Loss = 0.07360860705375671 | Extrinsic Uvfa Loss = 36.87723159790039 | Intrinsic Uvfa Loss = 19.602252960205078 | Steps = 1 | Time Elapsed = 65.282
120 | `
121 |
122 | To choose the directory where checkpoints and logs are saved, set the `exp_path` flag. If a model already exists in `exp_path`, it will be loaded and training will resume.
123 | To name the experiment in W&B, set the `exp_name` flag.
124 |
125 | #### Observing Lunar Lander in action
126 | To visualize any environment, pass an instance of `StorageVideoObserver` to the environment loop, as shown below. The observer takes an instance of `DRLearnerConfig`, in which you can define the video logging options (such as `video_log_period` and `logs_dir`).
127 |
128 | ```
129 | observers = [IntrinsicRewardObserver(), DistillationCoefObserver(),StorageVideoObserver(config)]
130 | loop = EnvironmentLoop(env, agent, logger=logger_env, observers=observers)
131 | loop.run(FLAGS.num_episodes)
132 | ```
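For reference, a more complete sketch with the imports spelled out. This is a sketch only: the import paths are assumed from the repository layout under `drlearner/core/` and `drlearner/configs/`, and `env`, `agent` and `logger_env` are built elsewhere, e.g. as in `examples/run_lunar_lander.py`.
```python
# Assumed import paths; adjust if the observer classes are exported differently.
from drlearner.configs.config_lunar_lander import LunarLanderDRLearnerConfig
from drlearner.core.environment_loop import EnvironmentLoop
from drlearner.core.observers import (
    DistillationCoefObserver,
    IntrinsicRewardObserver,
    StorageVideoObserver,
)

config = LunarLanderDRLearnerConfig  # already an instance of DRLearnerConfig

# `env`, `agent` and `logger_env` are created as in examples/run_lunar_lander.py.
observers = [
    IntrinsicRewardObserver(),
    DistillationCoefObserver(),
    StorageVideoObserver(config),  # writes episode videos under config.logs_dir
]
loop = EnvironmentLoop(env, agent, logger=logger_env, observers=observers)
loop.run(1000)  # number of episodes
```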
133 | 
134 |
135 |
136 | #### Training with checkpoints (Montezuma)
137 |
138 | The model will pick up from the point where the previous training stopped. Montezuma's Revenge is the most difficult game, so make sure you have enough computational power. The total number of actors is `num_actors_per_mixture * num_mixtures`; if you try to run too many actors, your setup might break. With 16 CPU cores we advise around 12 actors in total.
139 |
140 | ```
141 | python ./examples/distrun_atari.py --exp_path artifacts/montezuma_base --exp_name montezuma_training
142 | ```
143 |
144 | More examples of synchronous and distributed agent training within these environments can be found in `examples/`.
145 |
146 | ## Documentation
147 | - [Debugging and monitoring](docs/debug_and_monitor.md)
148 | - [Docker installation](docs/docker.md)
149 | - [Apptainer on Unity cluster](docs/unity.md)
150 | - [Running on Vertex AI](docs/vertexai.md)
151 | - [Running on AWS](docs/aws-setup.md)
152 |
153 | ## Ongoing Support
154 |
155 | Join the [DRLearner Developers List](https://groups.google.com/g/drlearner?pli=10).
156 |
157 |
--------------------------------------------------------------------------------
/artifacts/montezuma_base/20240506-133311/logs/evaluator/logs.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/artifacts/montezuma_base/20240506-133311/logs/evaluator/logs.csv
--------------------------------------------------------------------------------
/artifacts/montezuma_base/20240506-133314/logs/learner/logs.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/artifacts/montezuma_base/20240506-133314/logs/learner/logs.csv
--------------------------------------------------------------------------------
/artifacts/montezuma_base/config.txt:
--------------------------------------------------------------------------------
1 | DRLearnerConfig(gamma_min=0.99, gamma_max=0.997, num_mixtures=32, target_update_period=400, evaluation_epsilon=0.01, epsilon=0.01, actor_epsilon=0.4, target_epsilon=0.01, variable_update_period=800, retrace_lambda=0.95, burn_in_length=0, trace_length=80, sequence_period=40, num_sgd_steps_per_step=1, uvfa_learning_rate=0.0001, idm_learning_rate=0.0005, distillation_learning_rate=0.0005, idm_weight_decay=1e-05, distillation_weight_decay=1e-05, idm_clip_steps=5, distillation_clip_steps=5, clip_rewards=True, max_absolute_reward=1.0, tx_pair=TxPair(apply=, apply_inv=), distillation_moving_average_coef=0.001, beta_min=0.0, beta_max=0.3, observation_embed_dim=32, episodic_memory_num_neighbors=10, episodic_memory_max_size=1500, episodic_memory_max_similarity=8.0, episodic_memory_cluster_distance=0.008, episodic_memory_pseudo_counts=0.001, episodic_memory_epsilon=0.0001, distillation_embed_dim=128, max_lifelong_modulation=5.0, samples_per_insert_tolerance_rate=0.5, samples_per_insert=2.0, min_replay_size=6250, max_replay_size=100000, batch_size=64, prefetch_size=1, num_parallel_calls=16, replay_table_name='priority_table', importance_sampling_exponent=0.6, priority_exponent=0.9, max_priority_weight=0.9, window=160, actor_window=160, evaluation_window=3600, n_arms=32, mc_epsilon=0.5, actor_mc_epsilon=0.3, evaluation_mc_epsilon=0.01, mc_beta=1.0, env_library='gym', video_log_period=50, actions_log_period=1, logs_dir='experiments/ll/', num_episodes=50)
--------------------------------------------------------------------------------
/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | drlearner:
3 | build: .
4 | volumes:
5 | - .:/app
6 | stdin_open: true # docker run -i
7 | tty: true # docker run -t
8 | env_file:
9 | - .env
10 |
--------------------------------------------------------------------------------
/docker-configurations/local-docker-images.md:
--------------------------------------------------------------------------------
1 | # Docker image with tested GPU support.
2 |
3 | This directory contains a Dockerfile, an updated Docker Compose file and a post-installation script for a Docker-based setup with tested GPU support.
4 | 
5 | ## Prerequisites
6 | **This setup was tested on a Linux host machine only** - most cloud setups run some variant of a Linux host.
7 | 
8 | In order to work successfully with the Docker image, the user should have the local NVIDIA drivers installed as well as the NVIDIA CUDA toolkit.
9 | Both installation from the Ubuntu repo and installation according to the NVIDIA instructions have been tested.
10 |
11 | ### Useful references:
12 | 1. [Download Nvidia drivers](https://www.nvidia.com/download/index.aspx)
13 | 2. [Installation guide for CUDA on linux host](https://www.cherryservers.com/blog/install-cuda-ubuntu)
14 | 3. [Official CUDA toolkit installation guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) - both installations were tested on Ubuntu and the AWS Linux AMI, as well as installation with conda.
15 |
16 | ### Usage
17 | 1. Once the NVIDIA drivers, the CUDA toolkit and nvcc are installed, the installation can be validated by running the following commands:
18 | ```nvcc --version```
19 | ```nvidia-smi```
20 |
21 | 2. Copy the Dockerfile, compose.yaml and post-install.sh from the current directory to the root project directory.
22 | 3. Run ```docker-compose build --no-cache && docker-compose up -d```
23 | 4. Once the system is fully built, enter the container and check the state of the installed NVIDIA libraries using the commands above.
24 | 5. If needed, change the permissions of post-install.sh to make it executable.
25 | 6. Inside the container, run `./post-install.sh` to install the post-installation requirements.
26 | 7. Run any of the examples.
27 |
28 |
--------------------------------------------------------------------------------
/docker-configurations/python3.10/Dockerfile:
--------------------------------------------------------------------------------
1 | # Start from a CUDA development image
2 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
3 |
4 | # Make non-interactive environment.
5 | ENV DEBIAN_FRONTEND noninteractive
6 |
7 | ## Installing dependencies.
8 | RUN apt-get update -y \
9 | && apt-get install -y --no-install-recommends \
10 | build-essential \
11 | python3.10 \
12 | python3.10-dev \
13 | python3-pip \
14 | curl \
15 | wget \
16 | xvfb \
17 | ffmpeg \
18 | xorg-dev \
19 | libsdl2-dev \
20 | swig \
21 | cmake \
22 | git \
23 | unar \
24 | libpython3.10 \
25 | zlib1g-dev \
26 | tmux \
27 | && rm -rf /var/lib/apt/lists/*
28 |
29 | ## Workdir
30 | ADD . /app
31 | WORKDIR /app
32 | # Library paths.
33 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib/:$(pwd)
34 | ENV PYTHONPATH=$PYTHONPATH:$(pwd)
35 |
36 | # Update pip to the latest version & install packages.
37 | RUN python3.10 -m pip install --upgrade pip
38 | RUN python3.10 -m pip install jax==0.4.3
39 | RUN python3.10 -m pip install jaxlib==0.4.3+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
40 | RUN python3.10 -m pip install --no-cache-dir -r requirements.txt
41 | RUN python3.10 -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
42 |
43 | # Atari games.
44 | RUN wget http://www.atarimania.com/roms/Roms.rar
45 | RUN unar Roms.rar
46 | RUN mv Roms roms
47 | RUN ale-import-roms roms/
48 |
49 | RUN chmod +x ./
50 |
51 | CMD ["/bin/bash"]
52 |
--------------------------------------------------------------------------------
/docker-configurations/python3.7/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.1.1-base-ubuntu20.04
2 | # FROM python:3.7
3 | ## Basic dependencies.
4 | ADD . /app
5 |
6 | WORKDIR /app
7 |
8 | RUN ln -snf /usr/share/zoneinfo/$CONTAINER_TIMEZONE /etc/localtime && echo $CONTAINER_TIMEZONE > /etc/timezone
9 |
10 | ### Installing dependencies.
11 | RUN \
12 | --mount=type=cache,target=/var/cache/apt \
13 | apt-get update \
14 | && apt-get install -y --no-install-recommends \
15 | build-essential \
16 | curl \
17 | wget \
18 | xvfb \
19 | ffmpeg \
20 | xorg-dev \
21 | libsdl2-dev \
22 | swig \
23 | cmake \
24 | git \
25 | unar \
26 | libpython3.7
27 |
28 | # Conda environment
29 | ENV CONDA_DIR /opt/conda
30 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
31 | /bin/bash ~/miniconda.sh -b -p /opt/conda
32 |
33 | # # Put conda in path so we can use conda activate
34 | ENV PATH=$CONDA_DIR/bin:$PATH
35 |
36 | # SHELL ["/bin/bash", "-c"]
37 | RUN conda create --name drlearner python=3.7 -y
38 | RUN python --version
39 | RUN echo "source activate drlearner" > ~/.bashrc
40 | ENV PATH /opt/conda/envs/drlearner/bin:$PATH
41 |
42 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib
43 |
44 | ENV PYTHONPATH=$PYTHONPATH:$(pwd)
45 | RUN conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:lib:/usr/lib:/usr/local/lib:~/anaconda3/envs/drlearner/lib:/opt/conda/envs/drlearner/lib:/opt/conda/lib
46 | RUN conda env config vars set PYTHONPATH=$PYTHONPATH:$(pwd)
47 | RUN conda install nvidia/label/cuda-11.3.1::cuda-nvcc -y
48 | RUN conda install -c conda-forge cudatoolkit=11.3.1 cudnn=8.2 -y
49 | RUN conda install -c anaconda git
50 |
51 | RUN chmod +x ./
52 |
53 | CMD ["/bin/bash"]
54 |
55 |
56 |
--------------------------------------------------------------------------------
/docker-configurations/python3.7/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | drlearner:
3 | build: .
4 | volumes:
5 | - .:/app
6 | stdin_open: true # docker run -i
7 | tty: true # docker run -t
8 | deploy:
9 | resources:
10 | reservations:
11 | devices:
12 | - driver: nvidia
13 | count: 1
14 | capabilities: [gpu]
15 |
--------------------------------------------------------------------------------
/docker-configurations/python3.7/post-install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "This is post-installation script"
4 |
5 | pip install pip==21.3
6 | pip install jax==0.3.7 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7 | pip install jaxlib==0.3.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
8 |
9 | pip install setuptools==65.5.0
10 | pip install wheel==0.38.0
11 | pip install git+https://github.com/horus95/lazydict.git
12 | pip install --no-cache-dir -r requirements.txt
13 | pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
14 |
15 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
16 | echo $LD_LIBRARY_PATH
17 |
18 | wget http://www.atarimania.com/roms/Roms.rar
19 | unar Roms.rar
20 | mv Roms roms
21 | ale-import-roms roms/
22 |
--------------------------------------------------------------------------------
/docker-configurations/python3.7/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | ale_py==0.7.5
3 | numpy==1.21.5
4 | cloudpickle==2.0.0
5 | six==1.16.0
6 | dm-acme==0.4.0
7 | libpython==0.2
8 | dm-acme[tf]
9 | chex==0.1.3
10 | Cython==0.29.28
11 | flax==0.4.1
12 | optax==0.1.2
13 | rlax==0.1.2
14 | pyglet==1.5.24
15 | #jax==0.3.15 # 0.4.10 # 0.4.13 # 0.3.7
16 | #jaxlib==0.3.15 # 0.4.10 # 0.4.13 # 0.3.7 pip install --upgrade jaxlib==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.html
17 | dm-haiku==0.0.5
18 | dm-acme[reverb]
19 | #gym[accept-rom-license, atari, Box2D] # ==0.21.0
20 | xmanager==0.1.5
21 | pyvirtualdisplay==3.0
22 | #lazydict==1.0.0b2 # pip install git+https://github.com/horus95/lazydict
23 | sk-video==1.1.10
24 | ffmpeg-python==0.2.0
25 | wandb==0.16.2
26 | ##pip install --upgrade ml_dtypes==0.2.0
27 |
--------------------------------------------------------------------------------
/docs/atari_pong.md:
--------------------------------------------------------------------------------
1 | # Playing Pong on Atari: Your best Pong score ever
2 |
3 | ```
4 | python ./examples/run_atari.py --level PongNoFrameskip-v4 --num_episodes 1000 --exp_path experiments/test_pong/ --exp_name test_pong
5 | ```
6 | Terminal output like the following means that training has been launched successfully:
7 |
8 | `[Learner] Action Mean Time = 0.015 | Env Step Mean Time = 0.005 | Episode Length = 825 | Episode Return = -21.0 | Episodes = 1 | Observe Mean Time = 0.016 | Steps = 825 | Steps Per Second = 24.269`
9 |
10 | Training the model may take up to several hours, depending on the configuration.
11 |
--------------------------------------------------------------------------------
/docs/aws-setup.md:
--------------------------------------------------------------------------------
1 | # AWS installation
2 |
3 | The current repository contains pre-built Docker images that can run on any cloud or on-premise platform.
4 | Running in the designated containers is the recommended approach, as they are compatible with and tested in both GPU and CPU setups,
5 | and they form the basis for the containerized distributed scheme.
6 | 
7 | This file contains installation instructions for running in [Amazon SageMaker](https://aws.amazon.com/sagemaker/), but they are equally applicable to the corresponding EC2 instances.
8 |
9 | ## Pre-requisites
10 |
11 | 1. Familiarity with the AWS cloud is assumed.
12 | 2. A root or IAM account is configured.
13 | 3. **Disclaimer: AWS is a paid service, and any computations incur costs.**
14 | 4. Navigate to the [console](https://us-east-1.console.aws.amazon.com/console/home?region=us-east-1#) for your selected region.
15 | 5. Create or run your [SageMaker](https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/notebook-instances) instance
16 | 6. Open jupyter lab
17 | 7. Upload your files.
18 | 8. Open a Terminal from the available options. Run `nvidia-smi` to make sure that the drivers are successfully installed.
19 | 9. Run `docker compose build && docker compose up -d` according to the instructions.
20 |
21 |
22 | Alternatively, one may set up an appropriate image on EC2 together with the drivers and install the application as per the guide below.
23 |
24 | # CUDA ON EC2 FROM SCRATCH
25 | These instructions help to set up PyTorch with CUDA on an EC2 instance with a plain Ubuntu AMI.
26 |
27 | ## Pre-installation actions
28 | 1) Verify the instance has the CUDA-capable GPU
29 | ```
30 | lspci | grep -i nvidia
31 | ```
32 |
33 | 2) Install kernel headers and development packages
34 | ```
35 | sudo apt-get install linux-headers-$(uname -r)
36 | ```
37 |
38 | ## NVIDIA drivers installation
39 | 1) Download a CUDA keyring for your distribution $distro and architecture $arch
40 | ```
41 | wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.1-1_all.deb
42 | ```
43 | i.e. for Ubuntu 22.04 with x86_64 the command would look as follows:
44 | ```
45 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
46 | ```
47 | 2) Add the downloaded keyring package
48 | ```
49 | sudo dpkg -i cuda-keyring_1.1-1_all.deb
50 | ```
51 | 3) Update the APT repository cache
52 | ```
53 | sudo apt-get update
54 | ```
55 | 4) Install the drivers
56 | ```
57 | sudo apt-get -y install cuda-drivers
58 | ```
59 | 5) Reboot the instance
60 | ```
61 | sudo reboot
62 | ```
63 | 6) Verify the installation
64 | ```
65 | nvidia-smi
66 | ```
67 | It is important to keep in mind the CUDA version displayed in the upper-right corner, as PyTorch needs to be compatible with it.
68 |
69 | **NOTE:** At this stage NVIDIA recommends following the [Post-installation actions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#environment-setup). I skipped them and it worked, but some unexpected errors might occur.
70 | ## PyTorch installation
71 |
72 | ### Install package manager
73 | I used conda but pip+venv *should* also work
74 | 1) Install conda
75 | ```
76 | mkdir -p ~/miniconda3
77 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
78 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
79 | rm -rf ~/miniconda3/miniconda.sh
80 | ```
81 | 2) Initialize conda
82 | ```
83 | ~/miniconda3/bin/conda init bash
84 | ```
85 | 3) Reload bash
86 | ```
87 | source ~/.bashrc
88 | ```
89 | 4) Create a new conda environment
90 | ```
91 | conda create -n env
92 | ```
93 | 5) Activate the newly created environment
94 | ```
95 | conda activate env
96 | ```
97 |
--------------------------------------------------------------------------------
/docs/debug_and_monitor.md:
--------------------------------------------------------------------------------
1 | # Debugging and monitoring
2 | Our framework supports several loggers to help monitor and debug trained agents. All checkpoints are saved in the workdir directory passed to the agent.
3 | Currently we log 18 different parameters: 7 for the actor and 11 for the environment, but these can easily be extended by the user.
4 | Actor's parameters:
5 |
6 | ## Terminal logger
7 | The terminal logger logs progress to standard output. To use it, pass an instance of `TerminalLogger` to the agent and the environment loop. Logs aren't saved anywhere.
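A minimal usage sketch, assuming the `TerminalLogger` from `acme.utils.loggers` (the same class used in the combining example at the bottom of this page):
```python
from acme.utils import loggers

# Prints metrics to stdout; nothing is written to disk.
terminal_logger = loggers.TerminalLogger(label='evaluator', print_fn=print)
terminal_logger.write({'episode_return': -453.1, 'episodes': 1})
```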
8 |
9 | ## Tensorboard logger
10 | This framework supports the standard TensorBoard logger. To use it, pass an instance of `TFSummaryLogger` to the agent and the environment loop. Logs are saved in the workdir directory passed to the `TFSummaryLogger`. To visualize the logs, run the command below, pointing `--logdir` at that directory.
11 | ```
12 | tensorboard --logdir
13 |
14 | ```
15 | 
16 |
17 | ## CSV Logger
18 | This is a standard CSV logger. To use it, pass an instance of `CSVLogger` (or `CloudCSVLogger` when running on Vertex AI) to the agent and the environment loop. Logs are saved in the workdir directory passed to the logger.
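A minimal sketch, assuming the standard `CSVLogger` from `acme.utils.loggers`:
```python
from acme.utils import loggers

# Appends one CSV row per write() call inside the given directory.
csv_logger = loggers.CSVLogger(directory_or_file='experiments/test_run', label='evaluator')
csv_logger.write({'episode_return': -453.1, 'episodes': 1})
```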
19 |
20 | ## Weights and biases
21 | Our framework supports the [W&B](https://wandb.ai/) logger. It is installed along with the Python requirements.txt. To use it, use the code snippet below.
22 | ```
23 | from drlearner.utils.utils import make_wandb_logger
24 | ```
25 | or directly
26 | ```
27 | from drlearner.utils.utils import WandbLogger
28 | ```
29 | Set the `WANDB_API_KEY` environment variable to your personal API key. If you are using Docker / Compose, you can set the environment variable in the `.env` file. You can find the API key in your W&B profile > User settings > Danger zone > Reveal API key.
30 | Logs are saved locally to the `wandb/` directory and to your W&B account in the cloud.
31 | 
32 |
33 | ### Combining loggers
34 | To use more than one logger, see the code snippet below:
35 | ```
36 | wandb_logger = WandbLogger(logdir=tb_workdir, label=label, hyperparams=hyperparams)
37 | tensorboard_logger = TFSummaryLogger(logdir=tb_workdir, label=label)
38 | terminal_logger = loggers.terminal.TerminalLogger(label=label, print_fn=print_fn)
39 | 
40 | all_loggers = [wandb_logger, tensorboard_logger, terminal_logger]
41 |
42 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn)
43 | logger = loggers.filters.NoneFilter(logger)
44 | ```
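For reference, the same combination with the imports and helper values spelled out. This is a sketch: the `TFSummaryLogger` import path and the `WandbLogger` arguments are assumptions based on the snippets above.
```python
from acme.utils import loggers
from acme.utils.loggers.tf_summary import TFSummaryLogger

from drlearner.utils.utils import WandbLogger

label = 'learner'
tb_workdir = 'experiments/test_run'

wandb_logger = WandbLogger(logdir=tb_workdir, label=label, hyperparams={})
tensorboard_logger = TFSummaryLogger(logdir=tb_workdir, label=label)
terminal_logger = loggers.terminal.TerminalLogger(label=label, print_fn=print)

all_loggers = [wandb_logger, tensorboard_logger, terminal_logger]

# Dispatcher fans each write() out to every logger; NoneFilter drops None values.
logger = loggers.aggregators.Dispatcher(all_loggers)
logger = loggers.filters.NoneFilter(logger)
```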
45 |
--------------------------------------------------------------------------------
/docs/docker.md:
--------------------------------------------------------------------------------
1 | # Running in Docker
2 |
3 | Clone the repo
4 | ```
5 | git clone https://github.com/PatternsandPredictions/DRLearner_beta.git
6 | cd DRLearner_beta/
7 | ```
8 |
9 | Install Docker (if not already installed) and Docker Compose (optional)
10 | ```
11 | https://docs.docker.com/desktop/install/linux-install/
12 | https://docs.docker.com/compose/install/linux/
13 | ```
14 |
15 | 1. Use Dockerfile directly
16 | ```
17 | docker build -t drlearner:latest .
18 | docker run -it --name drlearner -d drlearner:latest
19 | ```
20 | 2. Use Docker compose
21 | ```
22 | docker compose up
23 | ```
24 |
25 | Now you can attach yourself to the docker container to play with it.
26 | ```
27 | docker exec -it drlearner bash
28 | ```
29 | ## Dockerfile
30 | Using the Python image, setting "/app" as the workdir and installing essential Linux dependencies:
31 | ```
32 | FROM python:3.10
33 | ADD . /app
34 | WORKDIR /app
35 |
36 | RUN apt-get update \
37 | && apt-get install -y --no-install-recommends \
38 | build-essential \
39 | curl \
40 | wget \
41 | xvfb \
42 | ffmpeg \
43 | xorg-dev \
44 | libsdl2-dev \
45 | swig \
46 | cmake \
47 | unar \
48 | libpython3.10 \
49 | tmux
50 | ```
51 | Downloading conda and creating the environment:
52 | ```
53 | ENV CONDA_DIR /opt/conda
54 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
55 | /bin/bash ~/miniconda.sh -b -p /opt/conda
56 |
57 | RUN conda create --name drlearner python=3.10 -y
58 | RUN python --version
59 | RUN echo "source activate drlearner" > ~/.bashrc
60 | ENV PATH /opt/conda/envs/drlearner/bin:$PATH
61 | ```
62 |
63 | Installing the Python requirements and downloading the game ROMs:
64 | ```
65 | RUN pip install --upgrade pip
66 | RUN pip install --no-cache-dir -r requirements.txt
67 | RUN pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable
68 |
69 | # Get binaries for Atari games
70 | RUN wget http://www.atarimania.com/roms/Roms.rar
71 | RUN unar Roms.rar
72 | RUN mv Roms roms
73 | RUN ale-import-roms roms/
74 | ```
75 |
76 | Setting up environment variables and changing the access mode for all files:
77 | ```
78 | ENV PYTHONPATH=$PYTHONPATH:$(pwd)
79 | RUN chmod -R 777 ./
80 | ```
81 |
82 | Default command for running the container. You can modify it or simply use `CMD ["/bin/bash"]` and attach yourself to the container to run commands directly:
83 | ```
84 | CMD ["python3" ,"examples/run_atari.py", "--level","PongNoFrameskip-v4", "--num_episodes", "1000", "--exp_path", "experiments/test_pong/", "--exp_name", "my_first_experiment"]
85 | or
86 | CMD ["/bin/bash"]
87 | ```
88 |
89 | ## Docker compose
90 | Compose runs one service called drlearner that is built using the Dockerfile in the main directory. Because the volume is mounted as `.:/app`, we don't have to rebuild the container each time we change the codebase. Setting the `stdin_open` and `tty` flags enables the container's interactive mode, so the user can attach to the container and use it interactively.
91 |
92 | ```
93 | services:
94 | drlearner:
95 | build: .
96 | volumes:
97 | - .:/app
98 | stdin_open: true # docker run -i
99 | tty: true # docker run -t
100 | env_file:
101 | - .env
102 | ```
103 | All the environment variables will be read from the `.env` file.
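An example `.env` file (the variable name follows the W&B setup in docs/debug_and_monitor.md; the value is a placeholder):
```
# .env - read by Docker Compose via the env_file option
WANDB_API_KEY=<your_api_key>
```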
104 |
--------------------------------------------------------------------------------
/docs/img/lunar_lander.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/lunar_lander.png
--------------------------------------------------------------------------------
/docs/img/notebook-instance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/notebook-instance.png
--------------------------------------------------------------------------------
/docs/img/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/tensorboard.png
--------------------------------------------------------------------------------
/docs/img/wandb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/docs/img/wandb.png
--------------------------------------------------------------------------------
/docs/unity.md:
--------------------------------------------------------------------------------
1 | # Running an Apptainer container on Unity
2 |
3 | Install `spython` utility to make an Apptainer definition file from Dockerfile:
4 | ```
5 | pip install spython
6 | spython recipe Dockerfile1 > Apptainer1.def
7 | ```
8 | Modify the definition file with the following environment settings:
9 | ```
10 | # Make non-interactive environment.
11 | export TZ='America/New_York'
12 | export DEBIAN_FRONTEND=noninteractive
13 | ```
14 | Build the Apptainer image (sif):
15 | ```
16 | module load apptainer/latest
17 | unset APPTAINER_BINDPATH
18 | apptainer build --fakeroot sifs/drlearner1.sif Apptainer1.def
19 | ```
20 | Allocate a computational node and load the required modules:
21 | ```
22 | salloc -N 1 -n 1 -p gpu-preempt -G 1 -t 2:00:00 --constraint=a100
23 | module load cuda/11.8.0
24 | module load cudnn/8.7.0.84-11.8
25 | ```
26 | Run the container:
27 | ```
28 | apptainer exec --nv sifs/drlearner1.sif bash
29 | ```
30 | Export a user's WANDB key for logging the job (for illustrative purposes only!):
31 | ```
32 | export WANDB_API_KEY=c5180d032d5325b08df49b65f9574c8cd59af6b1
33 | ```
34 | Run the Atari example:
35 | ```
36 | python3.10 examples/distrun_atari.py --exp_path experiments/apptainer_test_distrun_atari --exp_name apptainer_test_distrun_atari
37 | ```
--------------------------------------------------------------------------------
/docs/vertexai.md:
--------------------------------------------------------------------------------
1 | # Running on Vertex AI
2 |
3 | ## Installation and set-up
4 |
5 | 1. (Local) Install `gcloud`.
6 | ```
7 | sudo apt-get install apt-transport-https ca-certificates gnupg curl
8 |
9 | echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
10 |
11 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
12 | sudo apt-get update && sudo apt-get install google-cloud-sdk
13 | ```
14 |
15 | 2. (Local) Set up GCP project.
16 | ```
17 | gcloud init # choose the existing project or create a new one
18 | export GCP_PROJECT=
19 | echo $GCP_PROJECT # make sure it's the DRLearner project
20 | conda env config vars set GCP_PROJECT= # optional
21 | ```
22 | 3. (Local) Authorise the use of GCP services by DRLearner.
23 | ```
24 | gcloud auth application-default login # get credentials to allow DRLearner code calls to GC APIs
25 | export GOOGLE_APPLICATION_CREDENTIALS=/home//.config/gcloud/application_default_credentials.json
26 | conda env config vars set GOOGLE_APPLICATION_CREDENTIALS=/home//.config/gcloud/application_default_credentials.json # optional
27 | ```
28 | 4. (Local) Install and configure Docker.
29 | ```
30 | sudo apt-get remove docker docker-engine docker.io containerd runc
31 | sudo apt-get update && sudo apt-get install lsb-release
32 | sudo mkdir -p /etc/apt/keyrings
33 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
34 | echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
35 | sudo apt-get update
36 | sudo apt-get install docker-ce docker-ce-cli containerd.io
37 |
38 | sudo groupadd docker
39 | sudo usermod -aG docker
40 |
41 | gcloud auth configure-docker
42 | ```
43 |
44 | 5. (GCP console) Enable IAM, Enable Vertex AI, Enable Container Registry in ``.
45 |
46 |
47 | 6. (GCP console) Set up a xmanager service account.
48 | - Create xmanager service account in `IAM & Admin/Service accounts` .
49 | - Add 'Storage Admin', 'Vertex AI Administrator', 'Vertex AI User' , 'Service Account User' roles.
50 |
51 | 7. Set up a Cloud storage bucket.
52 | - (GCP console) Create a Cloud storage bucket in Cloud Storage in `us-central1` region.
53 | - (Local) `export GOOGLE_CLOUD_BUCKET_NAME=`
54 | - (Local, optional) `conda env config vars set GOOGLE_CLOUD_BUCKET_NAME=`
55 |
56 | 8. (Local) Replace `envs/drlearner/lib/python3.10/site-packages/launchpad/nodes/python/xm_docker.py` with `./external/xm_docker.py` (to get the correct Docker instructions)*
57 |
58 | *We can't rebuild the launchpad package with those changes because of its complicated build process (it requires Bazel...).
59 |
60 |
61 | 9. (Local) Replace `envs/drlearner/lib/python3.10/site-packages/xmanager/cloud/vertex.py` with `./external/vertex.py` (to add new machine types, allow web access to nodes from GCP console).
62 |
63 |
64 | 10. (Local) Tensorboard instructions:
65 | - Use `scripts/update_tb.py` to download the current tfevents file, which is saved in ``
66 | ```
67 | python update_tb.py /
68 | ```
69 | ! We recommend syncing the tfevents files regularly and keeping older versions as well,
70 | since Vertex AI silently restarts workers that go down,
71 | and they start writing logs to the tf file from scratch !
72 |
73 | ## GCP Hardware Specs
74 | The hardware requirements for running DRLearner on Vertex AI are specified in `drlearner/configs/resources/` - there are two setups: for easy environment (i.e. Atari Boxing) and a more complex one (i.e. Atari Montezuma Revenge). See the table below.
75 |
76 |
77 | | | Simple env | Complex env |
78 | |---------------|:------------------------------------------:|---------------------------------------------:|
79 | | Actor | e2-standard-4 (4 CPU, 16 RAM) | e2-standard-4 (4 CPU, 16 RAM) |
80 | | Learner | n1-standard-4 (4 CPU, 16 RAM + TESLA P100) | n1-highmem-16 (16 CPU, 104 RAM + TESLA P100) |
81 | | Replay Buffer | e2-highmem-8 (8 CPU, 64 RAM) | e2-highmem-16 (16 CPU, 128 RAM) |
82 |
83 | New configurations can be added using the same `xm_docker.DockerConfig` and `xm.JobRequirements` classes. The machine types available on Vertex AI are listed at https://cloud.google.com/vertex-ai/pricing.
84 | This might require adding the new machine names to `external/vertex.py`, e.g. `'n2-standard-64': (64, 256 * xm.GiB),`, as sketched below.
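For illustration, the machine-type table in `external/vertex.py` maps a machine name to its (vCPU count, RAM) pair, so adding a new type is a one-line entry of the same shape (a sketch only: the dictionary name here is hypothetical):
```python
from xmanager import xm

# Hypothetical excerpt: the actual dictionary name in external/vertex.py may differ.
_MACHINE_TYPES = {
    # ... existing machine types ...
    'n2-standard-64': (64, 256 * xm.GiB),  # new entry: 64 vCPUs, 256 GiB RAM
}
```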
85 |
86 | ## GCP Troubleshooting
87 | In case of any 'Permission denied' issues, go to `IAM & Admin/` in GCP console and try adding 'Service Account User' role to your User, and
88 | 'Compute Storage Admin' role to 'AI Platform Custom Code Service Agent' Service Account.
89 |
90 | ## Running experiments
91 | ```
92 | python ./examples/distrun_atari.py --run_on_vertex --exp_path /gcs/$GOOGLE_CLOUD_BUCKET_NAME/test_pong/ --level PongNoFrameskip-v4 --num_actors_per_mixture 3
93 | ```
94 | - Add `--noxm_build_image_locally` to build Docker images with Cloud Build; otherwise they will be built locally.
95 | - The number of nodes running Actor code is `--num_actors_per_mixture` x `num_mixtures` (the default number of mixtures for Atari is 32), so be careful and don't launch a full-scale experiment before testing that everything works correctly.
96 |
--------------------------------------------------------------------------------
/drlearner/__init__.py:
--------------------------------------------------------------------------------
1 | from .drlearner import *
2 | from .core import *
3 | from .configs import *
4 |
--------------------------------------------------------------------------------
/drlearner/configs/config_atari.py:
--------------------------------------------------------------------------------
1 | from drlearner.drlearner.config import DRLearnerConfig
2 | import rlax
3 | from acme.adders import reverb as adders_reverb
4 |
5 | AtariDRLearnerConfig = DRLearnerConfig(
6 | gamma_min=0.99,
7 | gamma_max=0.997,
8 | num_mixtures=32,
9 | target_update_period=400,
10 | evaluation_epsilon=0.01,
11 | actor_epsilon=0.4,
12 | target_epsilon=0.01,
13 | variable_update_period=800,
14 |
15 | # Learner options
16 | retrace_lambda=0.95,
17 | burn_in_length=0,
18 | trace_length=80,
19 | sequence_period=40,
20 | num_sgd_steps_per_step=1,
21 | uvfa_learning_rate=1e-4,
22 | idm_learning_rate=5e-4,
23 | distillation_learning_rate=5e-4,
24 | idm_weight_decay=1e-5,
25 | distillation_weight_decay=1e-5,
26 | idm_clip_steps=5,
27 | distillation_clip_steps=5,
28 | clip_rewards=True,
29 | max_absolute_reward=1.0,
30 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR,
31 | distillation_moving_average_coef=1e-3,
32 |
33 | # Intrinsic reward multipliers
34 | beta_min=0.,
35 | beta_max=0.3,
36 |
37 | # Embedding network options
38 | observation_embed_dim=32,
39 | episodic_memory_num_neighbors=10,
40 | episodic_memory_max_size=1500,
41 | episodic_memory_max_similarity=8.,
42 | episodic_memory_cluster_distance=8e-3,
43 | episodic_memory_pseudo_counts=1e-3,
44 | episodic_memory_epsilon=1e-4,
45 |
46 | # Distillation network
47 | distillation_embed_dim=128,
48 | max_lifelong_modulation=5.0,
49 |
50 | # Replay options
51 | samples_per_insert_tolerance_rate=0.5,
52 | samples_per_insert=2.,
53 | min_replay_size=6250,
54 | max_replay_size=100_000,
55 | batch_size=64,
56 | prefetch_size=1,
57 | num_parallel_calls=16,
58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE,
59 |
60 | # Priority options
61 | importance_sampling_exponent=0.6,
62 | priority_exponent=0.9,
63 | max_priority_weight=0.9,
64 |
65 | # Meta Controller options
66 | actor_window=160,
67 | evaluation_window=3600,
68 | n_arms=32,
69 | actor_mc_epsilon=0.3,
70 | evaluation_mc_epsilon=0.01,
71 | mc_beta=1.,
72 |
73 | # Agent video logging options
74 | env_library='gym',
75 | video_log_period=50,
76 | actions_log_period=1,
77 | logs_dir='experiments/videos/',
78 | num_episodes=50,
79 |
80 | )
--------------------------------------------------------------------------------
/drlearner/configs/config_discomaze.py:
--------------------------------------------------------------------------------
1 | import rlax
2 | from acme.adders import reverb as adders_reverb
3 |
4 | from drlearner.drlearner.config import DRLearnerConfig
5 |
6 | DiscomazeDRLearnerConfig = DRLearnerConfig(
7 | gamma_min=0.99,
8 | gamma_max=0.99,
9 | num_mixtures=3,
10 | target_update_period=100,
11 | evaluation_epsilon=0.,
12 | actor_epsilon=0.05,
13 | target_epsilon=0.0,
14 | variable_update_period=1000,
15 |
16 | # Learner options
17 | retrace_lambda=0.97,
18 | burn_in_length=0,
19 | trace_length=30,
20 | sequence_period=30,
21 | num_sgd_steps_per_step=1,
22 | uvfa_learning_rate=1e-3,
23 | idm_learning_rate=1e-3,
24 | distillation_learning_rate=1e-3,
25 | idm_weight_decay=1e-5,
26 | distillation_weight_decay=1e-5,
27 | idm_clip_steps=5,
28 | distillation_clip_steps=5,
29 | clip_rewards=False,
30 | max_absolute_reward=1.0,
31 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR,
32 |
33 | # Intrinsic reward multipliers
34 | beta_min=0.,
35 | beta_max=0.5,
36 |
37 | # Embedding network options
38 | observation_embed_dim=16,
39 | episodic_memory_num_neighbors=10,
40 | episodic_memory_max_size=5_000,
41 | episodic_memory_max_similarity=8.,
42 | episodic_memory_cluster_distance=8e-3,
43 | episodic_memory_pseudo_counts=1e-3,
44 | episodic_memory_epsilon=1e-2,
45 |
46 | # Distillation network
47 | distillation_embed_dim=32,
48 | max_lifelong_modulation=5.0,
49 |
50 | # Replay options
51 | samples_per_insert_tolerance_rate=1.0,
52 | samples_per_insert=0.0,
53 | min_replay_size=1,
54 | max_replay_size=100_000,
55 | batch_size=64,
56 | prefetch_size=1,
57 | num_parallel_calls=16,
58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE,
59 |
60 | # Priority options
61 | importance_sampling_exponent=0.6,
62 | priority_exponent=0.9,
63 | max_priority_weight=0.9,
64 |
65 | # Meta Controller options
66 | actor_window=160,
67 | evaluation_window=1000,
68 | n_arms=3,
69 | actor_mc_epsilon=0.3,
70 | evaluation_mc_epsilon=0.01,
71 | mc_beta=1.,
72 |
73 | # Agent video logging options
74 | env_library='discomaze',
75 | video_log_period=10,
76 | actions_log_period=1,
77 | logs_dir='experiments/emb_size_4_nn1_less_learning_steps',
78 | num_episodes=50,
79 | )
80 |
--------------------------------------------------------------------------------
/drlearner/configs/config_lunar_lander.py:
--------------------------------------------------------------------------------
1 | import rlax
2 | from acme.adders import reverb as adders_reverb
3 |
4 | from drlearner.drlearner.config import DRLearnerConfig
5 |
6 | LunarLanderDRLearnerConfig = DRLearnerConfig(
7 | gamma_min=0.99,
8 | gamma_max=0.99,
9 | num_mixtures=3,
10 | target_update_period=50,
11 | evaluation_epsilon=0.01,
12 | actor_epsilon=0.05,
13 | target_epsilon=0.01,
14 | variable_update_period=100,
15 |
16 | # Learner options
17 | retrace_lambda=0.95,
18 | burn_in_length=0,
19 | trace_length=30,
20 | sequence_period=30,
21 | num_sgd_steps_per_step=1,
22 | uvfa_learning_rate=5e-4,
23 | idm_learning_rate=5e-4,
24 | distillation_learning_rate=5e-4,
25 | idm_weight_decay=1e-5,
26 | distillation_weight_decay=1e-5,
27 | idm_clip_steps=5,
28 | distillation_clip_steps=5,
29 | clip_rewards=True,
30 | max_absolute_reward=1.0,
31 | tx_pair=rlax.SIGNED_HYPERBOLIC_PAIR,
32 |
33 | # Intrinsic reward multipliers
34 | beta_min=0.,
35 | beta_max=0.,
36 |
37 | # Embedding network options
38 | observation_embed_dim=1,
39 | episodic_memory_num_neighbors=1,
40 | episodic_memory_max_size=1,
41 | episodic_memory_max_similarity=8.,
42 | episodic_memory_cluster_distance=8e-3,
43 | episodic_memory_pseudo_counts=1e-3,
44 | episodic_memory_epsilon=1e-2,
45 |
46 | # Distillation network
47 | distillation_embed_dim=1,
48 | max_lifelong_modulation=1,
49 |
50 | # Replay options
51 | samples_per_insert_tolerance_rate=1.0,
52 | samples_per_insert=0.0,
53 | min_replay_size=1,
54 | max_replay_size=100_000,
55 | batch_size=64,
56 | prefetch_size=1,
57 | num_parallel_calls=16,
58 | replay_table_name=adders_reverb.DEFAULT_PRIORITY_TABLE,
59 |
60 | # Priority options
61 | importance_sampling_exponent=0.6,
62 | priority_exponent=0.9,
63 | max_priority_weight=0.9,
64 |
65 | # Meta Controller options
66 | actor_window=160,
67 | evaluation_window=1000,
68 | n_arms=3,
69 | actor_mc_epsilon=0.3,
70 | evaluation_mc_epsilon=0.01,
71 | mc_beta=1.,
72 |
73 | # Agent video logging options
74 | env_library='gym',
75 | video_log_period=50,
76 | actions_log_period=1,
77 | logs_dir='experiments/ll/',
78 | num_episodes=50,
79 | )
80 |
--------------------------------------------------------------------------------
/drlearner/configs/resources/__init__.py:
--------------------------------------------------------------------------------
1 | from .atari import get_vertex_resources as get_atari_vertex_resources
2 | from .toy_env import get_vertex_resources as get_toy_env_vertex_resources
3 | from .local_resources import get_local_resources
4 |
--------------------------------------------------------------------------------
/drlearner/configs/resources/atari.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from launchpad.nodes.python import xm_docker
4 | from xmanager import xm
5 | import xmanager.cloud.build_image
6 |
7 |
8 | def get_vertex_resources():
9 | resources = dict()
10 |
11 | resources['learner'] = xm_docker.DockerConfig(
12 | os.getcwd() + '/',
13 | os.getcwd() + '/requirements.txt',
14 | xm.JobRequirements(cpu=16, memory=104 * xm.GiB, P100=1)
15 | )
16 |
17 | resources['counter'] = xm_docker.DockerConfig(
18 | os.getcwd() + '/',
19 | os.getcwd() + '/requirements.txt',
20 | xm.JobRequirements(cpu=2, memory=16 * xm.GiB)
21 | )
22 |
23 | for node in ['actor', 'evaluator']:
24 | resources[node] = xm_docker.DockerConfig(
25 | os.getcwd() + '/',
26 | os.getcwd() + '/requirements.txt',
27 | xm.JobRequirements(cpu=4, memory=16 * xm.GiB)
28 | )
29 |
30 | resources['replay'] = xm_docker.DockerConfig(
31 | os.getcwd() + '/',
32 | os.getcwd() + '/requirements.txt',
33 | xm.JobRequirements(cpu=16, memory=128 * xm.GiB)
34 | )
35 |
36 | return resources
37 |
--------------------------------------------------------------------------------
/drlearner/configs/resources/local_resources.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from launchpad.nodes.python import xm_docker
4 | from xmanager import xm
5 | import xmanager.cloud.build_image
6 | from launchpad.nodes.python import local_multi_processing
7 |
8 |
9 | def get_local_resources():
10 | local_resources = dict(
11 | actor=local_multi_processing.PythonProcess(
12 | env=dict(CUDA_VISIBLE_DEVICES='-1')
13 | ),
14 | counter=local_multi_processing.PythonProcess(
15 | env=dict(CUDA_VISIBLE_DEVICES='-1')
16 | ),
17 | evaluator=local_multi_processing.PythonProcess(
18 | env=dict(CUDA_VISIBLE_DEVICES='-1')
19 | ),
20 | replay=local_multi_processing.PythonProcess(
21 | env=dict(CUDA_VISIBLE_DEVICES='-1')
22 | ),
23 | learner=local_multi_processing.PythonProcess(
24 | env=dict(
25 | # XLA_PYTHON_CLIENT_MEM_FRACTION='0.1',
26 | CUDA_VISIBLE_DEVICES='0',
27 | XLA_PYTHON_CLIENT_PREALLOCATE='0',
28 | LD_LIBRARY_PATH=os.environ.get('LD_LIBRARY_PATH', '') + ':/usr/local/cuda/lib64'))
29 | )
30 | return local_resources
--------------------------------------------------------------------------------
/drlearner/configs/resources/toy_env.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from launchpad.nodes.python import xm_docker
4 | from xmanager import xm
5 | import xmanager.cloud.build_image
6 |
7 |
8 | def get_vertex_resources():
9 | resources = dict()
10 |
11 | resources['learner'] = xm_docker.DockerConfig(
12 | os.getcwd() + '/',
13 | os.getcwd() + '/requirements.txt',
14 | xm.JobRequirements(cpu=4, memory=15 * xm.GiB, P100=1)
15 | )
16 |
17 | resources['counter'] = xm_docker.DockerConfig(
18 | os.getcwd() + '/',
19 | os.getcwd() + '/requirements.txt',
20 | xm.JobRequirements(cpu=2, memory=16 * xm.GiB)
21 | )
22 |
23 | for node in ['actor', 'evaluator']:
24 | resources[node] = xm_docker.DockerConfig(
25 | os.getcwd() + '/',
26 | os.getcwd() + '/requirements.txt',
27 | xm.JobRequirements(cpu=4, memory=16 * xm.GiB)
28 | )
29 |
30 | resources['replay'] = xm_docker.DockerConfig(
31 | os.getcwd() + '/',
32 | os.getcwd() + '/requirements.txt',
33 | xm.JobRequirements(cpu=8, memory=64 * xm.GiB)
34 | )
35 |
36 | return resources
37 |
--------------------------------------------------------------------------------
/drlearner/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/core/__init__.py
--------------------------------------------------------------------------------
/drlearner/core/environment_loop.py:
--------------------------------------------------------------------------------
1 | """A simple agent-environment training loop."""
2 |
3 | import operator
4 | import time
5 | from typing import Optional, Sequence
6 | import platform
7 |
8 | import numpy as np
9 | from pyvirtualdisplay import Display
10 | import tree
11 |
12 | from acme import core
13 | from acme.utils import counting
14 | from acme.utils import loggers
15 | from acme.utils import observers as observers_lib
16 | from acme.utils import signals
17 |
18 | import dm_env
19 | from dm_env import specs
20 |
21 | from drlearner.core.loggers import disable_view_window
22 | from drlearner.core.observers import VideoObserver
23 |
24 |
25 | class EnvironmentLoop(core.Worker):
26 | """A simple RL environment loop.
27 | This takes `Environment` and `Actor` instances and coordinates their
28 | interaction. Agent is updated if `should_update=True`. This can be used as:
29 | loop = EnvironmentLoop(environment, actor)
30 | loop.run(num_episodes)
31 | A `Counter` instance can optionally be given in order to maintain counts
32 | between different Acme components. If not given a local Counter will be
33 | created to maintain counts between calls to the `run` method.
34 | A `Logger` instance can also be passed in order to control the output of the
35 | loop. If not given a platform-specific default logger will be used as defined
36 | by utils.loggers.make_default_logger. A string `label` can be passed to easily
37 | change the label associated with the default logger; this is ignored if a
38 | `Logger` instance is given.
39 | A list of 'Observer' instances can be specified to generate additional metrics
40 | to be logged by the logger. They have access to the 'Environment' instance,
41 | the current timestep datastruct and the current action.
42 | """
43 |
44 | def __init__(
45 | self,
46 | environment: dm_env.Environment,
47 | actor: core.Actor,
48 | counter: Optional[counting.Counter] = None,
49 | logger: Optional[loggers.Logger] = None,
50 | should_update: bool = True,
51 | label: str = 'environment_loop',
52 | observers: Sequence[observers_lib.EnvLoopObserver] = (),
53 | ):
54 | # Internalize agent and environment.
55 | self._environment = environment
56 | self._actor = actor
57 | self._counter = counter or counting.Counter()
58 | self._logger = logger or loggers.make_default_logger(label)
59 | self._should_update = should_update
60 | self._observers = observers
61 |
62 | self.platform = platform.system().lower()
63 |
64 | if any([isinstance(o, VideoObserver) for o in observers]):
65 | if 'linux' in self.platform:
66 | display = Display(visible=0, size=(1400, 900))
67 | display.start()
68 | else:
69 | disable_view_window()
70 |
71 | def run_episode(self, episode_count: int) -> loggers.LoggingData:
72 | """Run one episode.
73 | Each episode is a loop which interacts first with the environment to get an
74 | observation and then gives that observation to the agent in order to retrieve
75 | an action.
76 | Returns:
77 | An instance of `loggers.LoggingData`.
78 | """
79 | # Reset any counts and start the environment.
80 | start_time = time.time()
81 | action_time, env_step_time, observe_time = 0., 0., 0.
82 | episode_steps = 0
83 |
84 | # For evaluation, this keeps track of the total undiscounted reward
85 | # accumulated during the episode.
86 | episode_return = tree.map_structure(_generate_zeros_from_spec,
87 | self._environment.reward_spec())
88 | timestep = self._environment.reset()
89 | # Make the first observation.
90 | self._actor.observe_first(timestep)
91 | actor_extras = self._actor.get_extras()
92 |
93 | for observer in self._observers:
94 | # Initialize the observer with the current state of the env after reset
95 | # and the initial timestep.
96 | if hasattr(observer, 'observe_first'):
97 | observer.observe_first(
98 | self._environment,
99 | timestep,
100 | actor_extras,
101 | episode=episode_count,
102 | step=episode_steps,
103 | )
104 |
105 | # Run an episode.
106 | while not timestep.last():
107 | # Generate an action from the agent's policy and step the environment.
108 | t = time.perf_counter()
109 | action = self._actor.select_action(timestep.observation)
110 | action_time += time.perf_counter() - t
111 |
112 | t = time.perf_counter()
113 | timestep = self._environment.step(action)
114 | env_step_time += time.perf_counter() - t
115 |
116 | # Have the agent observe the timestep and let the actor update itself.
117 | t = time.perf_counter()
118 | self._actor.observe(action, next_timestep=timestep)
119 | observe_time += time.perf_counter() - t
120 |
121 | actor_extras = self._actor.get_extras()
122 | for observer in self._observers:
123 | # One environment step was completed. Observe the current state of the
124 | # environment, the current timestep and the action.
125 | observer.observe(
126 | self._environment,
127 | timestep,
128 | action,
129 | actor_extras,
130 | episode=episode_count,
131 | step=episode_steps,
132 | )
133 | if self._should_update:
134 | self._actor.update()
135 |
136 | # Book-keeping.
137 | episode_steps += 1
138 |
139 | # Equivalent to: episode_return += timestep.reward
140 | # We capture the return value because if timestep.reward is a JAX
141 | # DeviceArray, episode_return will not be mutated in-place. (In all other
142 | # cases, the returned episode_return will be the same object as the
143 | # argument episode_return.)
144 | episode_return = tree.map_structure(operator.iadd,
145 | episode_return,
146 | timestep.reward)
147 |
148 | # Record counts.
149 | counts = self._counter.increment(episodes=1, steps=episode_steps)
150 |
151 | # Collect the results and combine with counts.
152 | steps_per_second = episode_steps / (time.time() - start_time)
153 | result = {
154 | 'episode_length': episode_steps,
155 | 'episode_return': episode_return,
156 | 'steps_per_second': steps_per_second,
157 | 'action_mean_time': action_time / episode_steps,
158 | 'env_step_mean_time': env_step_time / episode_steps,
159 | 'observe_mean_time': observe_time / episode_steps,
160 |
161 | }
162 | result.update(counts)
163 | for observer in self._observers:
164 | if hasattr(observer, 'get_metrics'):
165 | result.update(
166 | observer.get_metrics(timestep=timestep, episode=episode_count),
167 | )
168 |
169 | return result
170 |
171 | def run(self,
172 | num_episodes: Optional[int] = None,
173 | num_steps: Optional[int] = None):
174 | """Perform the run loop.
175 | Run the environment loop either for `num_episodes` episodes or for at
176 | least `num_steps` steps (the last episode is always run until completion,
177 | so the total number of steps may be slightly more than `num_steps`).
178 | At least one of these two arguments has to be None.
179 | Upon termination of an episode a new episode will be started. If the number
180 | of episodes and the number of steps are not given then this will interact
181 | with the environment infinitely.
182 | Args:
183 | num_episodes: number of episodes to run the loop for.
184 | num_steps: minimal number of steps to run the loop for.
185 | Raises:
186 | ValueError: If both 'num_episodes' and 'num_steps' are not None.
187 | """
188 |
189 | if not (num_episodes is None or num_steps is None):
190 | raise ValueError('Either "num_episodes" or "num_steps" should be None.')
191 |
192 | def should_terminate(episode_count: int, step_count: int) -> bool:
193 | return ((num_episodes is not None and episode_count >= num_episodes) or
194 | (num_steps is not None and step_count >= num_steps))
195 |
196 | episode_count, step_count = 0, 0
197 | with signals.runtime_terminator():
198 | while not should_terminate(episode_count, step_count):
199 | result = self.run_episode(episode_count)
200 | episode_count += 1
201 | step_count += result['episode_length']
202 | # Log the given episode results.
203 | self._logger.write(result)
204 |
205 |
206 | # Placeholder for an EnvironmentLoop alias
207 |
208 |
209 | def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray:
210 | return np.zeros(spec.shape, spec.dtype)
211 |
--------------------------------------------------------------------------------
/drlearner/core/local_layout.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Local agent based on builders."""
17 |
18 | from typing import Any, Optional
19 |
20 | from acme import specs
21 | from acme.agents import agent
22 | from acme.agents.jax import builders
23 | from acme.jax import utils
24 | from acme.tf import savers
25 | from acme.utils import counting
26 | import jax
27 | import reverb
28 |
29 |
30 | class LocalLayout(agent.Agent):
31 | """An Agent that runs an algorithm defined by 'builder' on a single machine.
32 | """
33 |
34 | def __init__(
35 | self,
36 | seed: int,
37 | environment_spec: specs.EnvironmentSpec,
38 | builder: builders.GenericActorLearnerBuilder,
39 | networks: Any,
40 | policy_network: Any,
41 | workdir: Optional[str] = '~/acme',
42 | min_replay_size: int = 1000,
43 | samples_per_insert: float = 256.0,
44 | batch_size: int = 256,
45 | num_sgd_steps_per_step: int = 1,
46 | prefetch_size: int = 1,
47 | device_prefetch: bool = True,
48 | counter: Optional[counting.Counter] = None,
49 | checkpoint: bool = True,
50 | ):
51 | """Initialize the agent.
52 |
53 | Args:
54 | seed: A random seed to use for this layout instance.
55 | environment_spec: description of the actions, observations, etc.
56 | builder: builder defining an RL algorithm to train.
57 | networks: network objects to be passed to the learner.
58 | policy_network: function that given an observation returns actions.
59 | workdir: if provided saves the state of the learner and the counter
60 | (if the counter is not None) into workdir.
61 | min_replay_size: minimum replay size before updating.
62 | samples_per_insert: number of samples to take from replay for every insert
63 | that is made.
64 | batch_size: batch size for updates.
65 | num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call.
66 | For performance reasons (especially to reduce TPU host-device transfer
67 | times) it is performance-beneficial to do multiple sgd updates at once,
68 | provided that it does not hurt the training, which needs to be verified
69 | empirically for each environment.
70 | prefetch_size: number of elements to prefetch from the dataset iterator.
71 | device_prefetch: whether prefetching should happen to a device.
72 | counter: counter object used to keep track of steps.
73 | checkpoint: boolean indicating whether to checkpoint the learner
74 | and the counter (if the counter is not None).
75 | """
76 | if prefetch_size < 0:
77 | raise ValueError(f'Prefetch size={prefetch_size} should be non-negative')
78 |
79 | key = jax.random.PRNGKey(seed)
80 |
81 | # Create the replay server and grab its address.
82 | replay_tables = builder.make_replay_tables(environment_spec)
83 | replay_server = reverb.Server(replay_tables, port=None)
84 | replay_client = reverb.Client(f'localhost:{replay_server.port}')
85 |
86 | # Create actor, dataset, and learner for generating, storing, and consuming
87 | # data respectively.
88 | adder = builder.make_adder(replay_client)
89 |
90 | def _is_reverb_queue(reverb_table: reverb.Table,
91 | reverb_client: reverb.Client) -> bool:
92 | """Returns True iff the Reverb Table is actually a queue."""
93 | # TODO(sinopalnikov): make it more generic and check for a table that
94 | # needs special handling on update.
95 | info = reverb_client.server_info()
96 | table_info = info[reverb_table.name]
97 | is_queue = (
98 | table_info.max_times_sampled == 1 and
99 | table_info.sampler_options.fifo and
100 | table_info.remover_options.fifo)
101 | return is_queue
102 |
103 | is_reverb_queue = any(_is_reverb_queue(table, replay_client)
104 | for table in replay_tables)
105 |
106 | dataset = builder.make_dataset_iterator(replay_client)
107 | if prefetch_size > 1:
108 | device = jax.devices()[0] if device_prefetch else None
109 | dataset = utils.prefetch(dataset, buffer_size=prefetch_size,
110 | device=device)
111 | learner_key, key = jax.random.split(key)
112 | learner = builder.make_learner(
113 | random_key=learner_key,
114 | networks=networks,
115 | dataset=dataset,
116 | replay_client=replay_client,
117 | counter=counter)
118 | if not checkpoint or workdir is None:
119 | self._checkpointer = None
120 | else:
121 | objects_to_save = {'learner': learner}
122 | if counter is not None:
123 | objects_to_save.update({'counter': counter})
124 | self._checkpointer = savers.Checkpointer(
125 | objects_to_save,
126 | time_delta_minutes=30,
127 | subdirectory='learner',
128 | directory=workdir,
129 | add_uid=(workdir == '~/acme'))
130 |
131 | actor_key, key = jax.random.split(key)
132 | # use actor_id as 0 for local layout
133 | actor = builder.make_actor(
134 | actor_key, policy_network, adder, variable_source=learner)
135 | self._custom_update_fn = None
136 | if is_reverb_queue:
137 | # Reverb queue requires special handling on update: custom logic to
138 | # decide when it is safe to make a learner step. This is only needed for
139 | # the local agent, where the actor and the learner are running
140 | # synchronously and the learner will deadlock if it makes a step with
141 | # no data available.
142 | def custom_update():
143 | should_update_actor = False
144 | # Run a number of learner steps (usually gradient steps).
145 | # TODO(raveman): This is wrong. When running multi-level learners,
146 | # different levels might have different batch sizes. Find a solution.
147 | while all(table.can_sample(batch_size) for table in replay_tables):
148 | learner.step()
149 | should_update_actor = True
150 |
151 | if should_update_actor:
152 | # "wait=True" to make it more onpolicy
153 | actor.update(wait=True)
154 |
155 | self._custom_update_fn = custom_update
156 |
157 | effective_batch_size = batch_size * num_sgd_steps_per_step
158 | super().__init__(
159 | actor=actor,
160 | learner=learner,
161 | min_observations=max(effective_batch_size, min_replay_size),
162 | observations_per_step=float(effective_batch_size) / samples_per_insert)
163 |
164 | # Save the replay so we don't garbage collect it.
165 | self._replay_server = replay_server
166 |
167 | def update(self):
168 | if self._custom_update_fn:
169 | self._custom_update_fn()
170 | else:
171 | super().update()
172 | if self._checkpointer:
173 | self._checkpointer.save()
174 |
175 | def save(self):
176 | """Checkpoint the state of the agent."""
177 | if self._checkpointer:
178 | self._checkpointer.save(force=True)
179 |
--------------------------------------------------------------------------------
/drlearner/core/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | from drlearner.core.loggers.image import ImageLogger, disable_view_window
2 |
--------------------------------------------------------------------------------
/drlearner/core/loggers/image.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from acme.utils.loggers.tf_summary import TFSummaryLogger
3 |
4 |
5 | def disable_view_window():
6 | """
7 | Disables gym view window
8 | """
9 | from gym.envs.classic_control import rendering
10 | org_constructor = rendering.Viewer.__init__
11 |
12 | def constructor(self, *args, **kwargs):
13 | org_constructor(self, *args, **kwargs)
14 | self.window.set_visible(visible=False)
15 |
16 | rendering.Viewer.__init__ = constructor
17 |
18 |
19 | class ImageLogger(TFSummaryLogger):
20 | def __init__(self, *args, **kwargs):
21 | super().__init__(*args, **kwargs)
22 |
23 | def write_image(self, name, image, step=0):
24 | with self.summary.as_default():
25 | tf.summary.image(name, [image], step=step)
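26 | 
27 | 
28 | # Illustrative usage sketch (not part of the original module): `ImageLogger` is
29 | # constructed with a TensorBoard log directory, as done by the observers in this
30 | # repository, and images are written as tensors, e.g.
31 | #   logger = ImageLogger('experiments/demo')
32 | #   logger.write_image('frame', image_tensor, step=0)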
--------------------------------------------------------------------------------
/drlearner/core/observers/__init__.py:
--------------------------------------------------------------------------------
1 | from .action_dist import ActionProbObserver
2 | from .discomaze_unique_states import UniqueStatesDiscoMazeObserver
3 | from .intrinsic_reward import IntrinsicRewardObserver
4 | from .meta_controller import MetaControllerObserver
5 | from .distillation_coef import DistillationCoefObserver
6 |
7 | from .video import VideoObserver, StorageVideoObserver
8 | from .actions import ActionsObserver
9 |
--------------------------------------------------------------------------------
/drlearner/core/observers/action_dist.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import numpy as np
3 |
4 |
5 | class ActionProbObserver:
6 | def __init__(self, num_actions):
7 | self._num_actions = num_actions
8 | self._action_counter = None
9 |
10 | def observe_first(self, *args, **kwargs) -> None:
11 | # todo: defaultdict
12 | self._action_counter = {i: 0 for i in range(self._num_actions)}
13 |
14 | def observe(self, *args, **kwargs) -> None:
15 | env, timestamp, action, actor_extras = args
16 | self._action_counter[int(action)] += 1
17 |
18 | def get_metrics(self, **kwargs):
19 | total_actions = sum(self._action_counter.values())
20 | return {f'Action: {i}': self._action_counter[i] / total_actions for i in range(self._num_actions)}
21 |
--------------------------------------------------------------------------------
/drlearner/core/observers/actions.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from io import BytesIO
3 | from typing import Dict
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from matplotlib import pyplot as plt
9 | from PIL import Image
10 | import dm_env
11 |
12 | from drlearner.core.loggers import ImageLogger
13 |
14 |
15 | class ActionsObserver:
16 | def __init__(self, config):
17 | self.config = config
18 | self.image_logger = ImageLogger(config.logs_dir)
19 |
20 | self.unique_actions = set()
21 | self.ratios = list()
22 | self.episode_actions = list()
23 |
24 | def _log_episode_action(self, timestamp: dm_env.TimeStep, action: np.ndarray):
25 | if not timestamp.last():
26 | self.episode_actions.append(action)
27 | else:
28 | episode_ratios = self.calculate_actions_ratio()
29 | self.ratios.append(episode_ratios)
30 |
31 | self.episode_actions = list()
32 |
33 | def observe(
34 | self,
35 | *args,
36 | **kwargs,
37 | ) -> None:
38 | env, timestamp, action, actor_extras = args
39 |
40 | action = action.item()  # np.asscalar was removed in modern NumPy
41 | self.unique_actions.add(action)
42 |
43 | episode_count = kwargs['episode']
44 | log_action = episode_count % self.config.actions_log_period == 0
45 | if log_action:
46 | self._log_episode_action(timestamp, action)
47 |
48 | def calculate_actions_ratio(self) -> Dict:
49 | """
50 | Calculates per-episode action ratios.
51 | 
52 | Returns
53 | episode_ratios: dict mapping each action (as a string) to the fraction
54 | of episode steps on which that action was taken
55 | """
56 | counter = Counter(self.episode_actions)
57 | episode_ratios = dict()
58 |
59 | for action in self.unique_actions:
60 | count = counter.get(action, 0)
61 | ratio = count / len(self.episode_actions)
62 | episode_ratios[str(action)] = ratio
63 |
64 | return episode_ratios
65 |
66 | def _plot_actions(self):
67 | """
68 | Creates actions plot and logs it into tensorboard
69 | """
70 | for action in self.unique_actions:
71 | values = [ratio.get(str(action), 0) for ratio in self.ratios]
72 |
73 | n = len(self.ratios)
74 | steps = list(range(n))
75 |
76 | plt.plot(steps, values, c=np.random.rand(3), label=action)
77 | plt.legend()
78 |
79 | plt.title('action ratios per episode')
80 | plt.xlabel('episode')
81 | plt.ylabel('action ratio')
82 |
83 | buffer = BytesIO()
84 | plt.savefig(buffer, format='png')
85 | img = Image.open(buffer)
86 | self.image_logger.write_image('actions_ratio', tf.convert_to_tensor(np.array(img)))
87 |
88 | def get_metrics(self, **kwargs):
89 | episode_count = kwargs['episode']
90 | last_episode = episode_count == self.config.num_episodes - 1
91 |
92 | if last_episode:
93 | self._plot_actions()
94 |
95 | return dict()
96 |
--------------------------------------------------------------------------------
/drlearner/core/observers/discomaze_unique_states.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import numpy as np
3 |
4 | MAZE_PATH_COLOR = (0., 0., 0.)
5 | AGENT_COLOR = (1., 1., 1.)
6 |
7 |
8 | def mask_color_on_rgb(image, color) -> np.ndarray:
9 | """
10 | Given `image` of shape (H, W, C=3) and `color` of shape (3,), return a
11 | mask of shape (H, W) that is True where the image pixel has the same color as `color`.
12 | """
13 | return np.isclose(image[..., 0], color[0]) & \
14 | np.isclose(image[..., 1], color[1]) & \
15 | np.isclose(image[..., 2], color[2])
16 |
17 |
18 | class UniqueStatesVisitsCounter:
19 | def __init__(self, total):
20 | self.__total = total
21 | self.__visited = set()
22 | self.__reward_first_visit = []
23 | self.__reward_repeated_visit = []
24 |
25 | def add(self, state, intrinsic_reward):
26 | coords = self.get_xy_from_state(state)
27 |
28 | if coords in self.__visited:
29 | self.__reward_repeated_visit.append(float(intrinsic_reward))
30 | else:
31 | self.__visited.add(coords)
32 | self.__reward_first_visit.append(float(intrinsic_reward))
33 |
34 | @staticmethod
35 | def get_xy_from_state(state):
36 | mask = mask_color_on_rgb(state, AGENT_COLOR)
37 | coords = np.where(mask)
38 | x, y = coords[1][0], coords[0][0]
39 | return x, y
40 |
41 | def get_number_of_visited(self):
42 | return len(self.__visited)
43 |
44 | def get_fraction_of_visited(self):
45 | return len(self.__visited) / self.__total
46 |
47 | def get_mean_first_visit_reward(self):
48 | return np.mean(self.__reward_first_visit)
49 |
50 | def get_mean_repeated_visit_reward(self):
51 | return np.mean(self.__reward_repeated_visit)
52 |
53 |
54 | class UniqueStatesDiscoMazeObserver:
55 | def __init__(self):
56 | self._state_visit_counter: UniqueStatesVisitsCounter = None
57 |
58 | def reset(self, states_total):
59 | self._state_visit_counter = UniqueStatesVisitsCounter(states_total)
60 |
61 | def observe_first(self, *args, **kwargs) -> None:
62 | env, timestamp, actor_extras = args
63 |
64 | states_total = mask_color_on_rgb(
65 | timestamp.observation.observation,
66 | color=MAZE_PATH_COLOR
67 | ).sum() + 1 # +1 for current position of agent
68 | # TODO: account for states with targets if needed; currently support only 0 targets case
69 | self.reset(states_total)
70 |
71 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
72 | action: np.ndarray, actor_extras, **kwargs) -> None:
73 | self._state_visit_counter.add(
74 | timestep.observation.observation,
75 | actor_extras['intrinsic_reward']
76 | )
77 |
78 | def get_metrics(self, **kwargs):
79 | metrics = {
80 | "unique_fraction": self._state_visit_counter.get_fraction_of_visited(),
81 | "first_visit_mean_reward": self._state_visit_counter.get_mean_first_visit_reward(),
82 | "repeated_visit_mean_reward": self._state_visit_counter.get_mean_repeated_visit_reward()
83 | }
84 | return metrics
85 |
--------------------------------------------------------------------------------
/drlearner/core/observers/distillation_coef.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import numpy as np
3 |
4 |
5 | class DistillationCoefObserver:
6 | def __init__(self):
7 | self._alphas = None
8 |
9 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep, actor_extras, **kwargs) -> None:
10 | self._alphas = []
11 |
12 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
13 | action: np.ndarray, actor_extras, **kwargs) -> None:
14 | self._alphas.append(float(actor_extras['alpha']))
15 |
16 |
17 | def get_metrics(self, **kwargs):
18 | return {'Mean Distillation Alpha': np.mean(self._alphas)}
--------------------------------------------------------------------------------
/drlearner/core/observers/intrinsic_reward.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import numpy as np
3 |
4 |
5 | class IntrinsicRewardObserver:
6 | def __init__(self):
7 | self._intrinsic_rewards = None
8 |
9 | def observe_first(self, *args, **kwargs) -> None:
10 | env, timestep, actor_extras = args
11 |
12 | self._intrinsic_rewards = []
13 | self._intrinsic_rewards.append(float(actor_extras['intrinsic_reward']))
14 |
15 | def observe(self, *args, **kwargs) -> None:
16 | env, timestep, action, actor_extras = args
17 | self._intrinsic_rewards.append(float(actor_extras['intrinsic_reward']))
18 |
19 | def get_metrics(self, **kwargs):
20 | return {
21 | "intrinsic_rewards_sum": np.sum(self._intrinsic_rewards),
22 | "intrinsic_rewards_mean": np.mean(self._intrinsic_rewards)
23 | }
24 |
--------------------------------------------------------------------------------
/drlearner/core/observers/lazy_dict.py:
--------------------------------------------------------------------------------
1 | from collections.abc import MutableMapping
2 | from threading import RLock
3 | from inspect import getfullargspec
4 | from copy import copy
5 |
6 | class LazyDictionaryError(Exception):
7 | pass
8 |
9 | class CircularReferenceError(LazyDictionaryError):
10 | pass
11 |
12 | class ConstantRedefinitionError(LazyDictionaryError):
13 | pass
14 |
15 | class LazyDictionary(MutableMapping):
16 | def __init__(self, values={ }):
17 | self.lock = RLock()
18 | self.values = copy(values)
19 | self.states = {}
20 | for key in self.values:
21 | self.states[key] = 'defined'
22 |
23 | def __len__(self):
24 | return len(self.values)
25 |
26 | def __iter__(self):
27 | return iter(self.values)
28 |
29 | def __getitem__(self, key):
30 | with self.lock:
31 | if key in self.states:
32 | if self.states[key] == 'evaluating':
33 | raise CircularReferenceError('value of "%s" depends on itself' % key)
34 | elif self.states[key] == 'error':
35 | raise self.values[key]
36 | elif self.states[key] == 'defined':
37 | value = self.values[key]
38 | if callable(value):
39 | args = getfullargspec(value).args
40 | if len(args) == 0:
41 | self.states[key] = 'evaluating'
42 | try:
43 | self.values[key] = value()
44 | except Exception as ex:
45 | self.values[key] = ex
46 | self.states[key] = 'error'
47 | raise ex
48 | elif len(args) == 1:
49 | self.states[key] = 'evaluating'
50 | try:
51 | self.values[key] = value(self)
52 | except Exception as ex:
53 | self.values[key] = ex
54 | self.states[key] = 'error'
55 | raise ex
56 | self.states[key] = 'evaluated'
57 | return self.values[key]
58 |
59 | def __contains__(self, key):
60 | return key in self.values
61 |
62 | def __setitem__(self, key, value):
63 | with self.lock:
64 | if key in self.states and self.states[key][0:4] == 'eval':
65 | raise ConstantRedefinitionError('"%s" is immutable' % key)
66 | self.values[key] = value
67 | self.states[key] = 'defined'
68 |
69 | def __delitem__(self, key):
70 | with self.lock:
71 | if key in self.states and self.states[key][0:4] == 'eval':
72 | raise ConstantRedefinitionError('"%s" is immutable' % key)
73 | del self.values[key]
74 | del self.states[key]
75 |
76 | def __str__(self):
77 | return str(self.values)
78 |
79 | def __repr__(self):
80 | return "LazyDictionary({0})".format(repr(self.values))
81 |
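82 | # Illustrative usage sketch (not part of the original module): values may be
83 | # plain objects or zero-/one-argument callables; callables are evaluated lazily
84 | # on first access, the result is cached, and the key then becomes immutable.
85 | #   d = LazyDictionary({'a': 2, 'b': lambda: 3, 'c': lambda d: d['a'] + d['b']})
86 | #   d['c']        # -> 5 (evaluated on first access)
87 | #   d['c'] = 0    # raises ConstantRedefinitionError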
--------------------------------------------------------------------------------
/drlearner/core/observers/meta_controller.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import numpy as np
3 |
4 | class MetaControllerObserver:
5 | def __init__(self):
6 | self._mixture_indices = None
7 | self._is_eval = None
8 |
9 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep, actor_extras, **kwargs) -> None:
10 | self._mixture_indices = int(actor_extras['mixture_idx'])
11 | self._is_eval = int(actor_extras['is_eval'])
12 |
13 |
14 | def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
15 | action: np.ndarray, actor_extras, **kwargs) -> None:
16 | pass
17 |
18 | def get_metrics(self, **kwargs):
19 | return {
20 | 'mixture_idx': self._mixture_indices,
21 | 'is_eval': self._is_eval
22 | }
--------------------------------------------------------------------------------
/drlearner/core/observers/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | import platform
4 |
5 | from skvideo import io
6 | import numpy as np
7 | import tensorflow as tf
8 | import dm_env
9 |
10 | from drlearner.core.loggers import ImageLogger
11 | from drlearner.core.observers.lazy_dict import LazyDictionary
12 |
13 |
14 | class VideoObserver(ABC):
15 | def __init__(self, config):
16 | self.config = config
17 | self.env_library = config.env_library
18 | self.log_period = config.video_log_period
19 |
20 | self.platform = platform.system().lower()
21 |
22 | def render(self, env):
23 | """
24 | Renders current frame
25 | """
26 | render_funcs = LazyDictionary(
27 | {
28 | 'dm_control': lambda: env.physics.render(camera_id=0),
29 | 'gym': lambda: env.environment.render(mode='rgb_array'),
30 | 'discomaze': lambda: env.render(mode='state_pixels'),
31 | },
32 | )
33 | env_lib = self.env_library
34 |
35 | if env_lib in render_funcs.keys():
36 | return render_funcs[env_lib]
37 | else:
38 | raise ValueError(
39 | f"Unknown environment library: {env_lib}; choose among {list(render_funcs.keys())}",
40 | )
41 |
42 | def _log_video(self, episode_count):
43 | return (episode_count + 1) % self.log_period == 0
44 |
45 | @abstractmethod
46 | def observe(self, env: dm_env.Environment, *args, **kwargs):
47 | pass
48 |
49 |
50 | class StorageVideoObserver(VideoObserver):
51 | def __init__(self, config):
52 | print(f'INIT: {self.__class__.__name__}')
53 | super().__init__(config)
54 |
55 | self.frames = list()
56 | self.videos_dir = self._create_videos_dir()
57 |
58 | def observe(self, env: dm_env.Environment, *args, **kwargs):
59 | frame = self.render(env)
60 | self.frames.append(frame.astype('uint8'))
61 |
62 | def get_metrics(self, **kwargs):
63 | episode = kwargs['episode']
64 |
65 | if self._log_video(episode):
66 | video_dir = os.path.join(self.videos_dir, f'episode_{episode + 1}.mp4')
67 | io.vwrite(video_dir, np.array(self.frames))
68 |
69 | self.frames = list()
70 |
71 | return dict()
72 |
73 | def _create_videos_dir(self):
74 | video_dir = os.path.join(self.config.logs_dir, 'episodes')
75 |
76 | if not os.path.exists(video_dir):
77 | os.makedirs(video_dir, exist_ok=True)
78 |
79 | return video_dir
80 |
81 | observe_first = observe
82 |
83 |
84 | class TBVideoObserver(VideoObserver):
85 | def __init__(self, config):
86 | super().__init__(config)
87 | self.image_logger = ImageLogger(config.logs_dir)
88 |
89 | def log_frame(self, env: dm_env.Environment, episode=None, step=None):
90 | frame = self.render(env)
91 |
92 | self.image_logger.write_image(
93 | f'video_{episode + 1}',
94 | tf.convert_to_tensor(np.array(frame)),
95 | step=step,
96 | )
97 |
98 | def observe(self, env: dm_env.Environment, *args, **kwargs) -> None:
99 | episode = kwargs['episode']
100 | step = kwargs['step']
101 |
102 | if self._log_video(episode):
103 | self.log_frame(env, episode=episode, step=step)
104 |
105 | observe_first = observe
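106 | 
107 | 
108 | # Illustrative usage sketch (assumes a DRLearnerConfig with `env_library`,
109 | # `video_log_period` and `logs_dir` set): the observer is passed to the
110 | # environment loop together with the other observers, e.g.
111 | #   loop = EnvironmentLoop(environment, actor, observers=[StorageVideoObserver(config)])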
--------------------------------------------------------------------------------
/drlearner/drlearner/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import DRLearner
2 | from .builder import DRLearnerBuilder
3 | from .config import DRLearnerConfig
4 | from .distributed_agent import DistributedDRLearnerFromConfig
5 | from .learning import DRLearnerLearner
6 | from .networks import DRLearnerNetworks
7 | from .networks import make_policy_networks
8 | from .networks import networks_zoo
9 |
--------------------------------------------------------------------------------
/drlearner/drlearner/actor.py:
--------------------------------------------------------------------------------
1 | """DRLearner JAX actors."""
2 | from typing import Optional
3 |
4 | import dm_env
5 | import jax
6 | import jax.numpy as jnp
7 | from acme import adders
8 | from acme import types
9 | from acme.agents.jax import actors
10 | from acme.jax import networks as network_lib
11 | from acme.jax import utils
12 | from acme.jax import variable_utils
13 |
14 | from .actor_core import DRLearnerActorCore
15 |
16 |
17 | class DRLearnerActor(actors.GenericActor):
18 | """A generic actor implemented on top of ActorCore.
19 |
20 | An actor based on a policy which takes observations and outputs actions. It
21 | also adds experiences to replay and updates the actor weights from the policy
22 | on the learner.
23 | """
24 |
25 | def __init__(
26 | self,
27 | actor: DRLearnerActorCore,
28 | mixture_idx: int,
29 | random_key: network_lib.PRNGKey,
30 | variable_client: Optional[variable_utils.VariableClient],
31 | adder: Optional[adders.Adder] = None,
32 | jit: bool = True,
33 | backend: Optional[str] = 'cpu',
34 | per_episode_update: bool = False
35 | ):
36 | """Initializes a feed forward actor.
37 |
38 | Args:
39 | actor: actor core.
40 | random_key: Random key.
41 | variable_client: The variable client to get policy parameters from.
42 | adder: An adder to add experiences to.
43 | jit: Whether to jit the passed ActorCore's pure functions.
44 | backend: Which backend to use when jitting the policy.
45 | per_episode_update: if True, updates variable client params once at the
46 | beginning of each episode
47 | """
48 | super(DRLearnerActor, self).__init__(actor, random_key, variable_client, adder, jit, backend, per_episode_update)
49 | if jit:
50 | self._observe = jax.jit(actor.observe)
51 | # self._observe_first = jax.jit(actor.observe_first)
52 | else:
53 | self._observe = actor.observe
54 | self._observe_first = actor.observe_first
55 |
56 | self._mixture_idx = jnp.array(mixture_idx, dtype=jnp.int32)
57 |
58 | def select_action(self,
59 | observation: network_lib.Observation) -> types.NestedArray:
60 | action, self._state = self._policy(self._params, observation, self._state)
61 | return utils.to_numpy(action)
62 |
63 | def observe_first(self, timestep: dm_env.TimeStep):
64 | self._random_key, key = jax.random.split(self._random_key)
65 | self._state = self._init(key, self._mixture_idx, self._state)
66 | self._state = self._observe_first(self._params, timestep, self._state)
67 | if self._adder:
68 | self._adder.add_first(timestep)
69 | if self._variable_client and self._per_episode_update:
70 | self._variable_client.update_and_wait()
71 |
72 | def observe(self, action: network_lib.Action, next_timestep: dm_env.TimeStep):
73 | self._state = self._observe(self._params, action, next_timestep, self._state)
74 | super(DRLearnerActor, self).observe(action, next_timestep)
75 |
76 | def get_extras(self):
77 | return self._get_extras(self._state)
--------------------------------------------------------------------------------
/drlearner/drlearner/agent.py:
--------------------------------------------------------------------------------
1 | """Defines local DRLearner agent, using JAX."""
2 |
3 | from typing import Optional
4 |
5 | from acme import specs
6 | from acme.utils import counting
7 |
8 | from ..core import local_layout
9 | from .builder import DRLearnerBuilder
10 | from .config import DRLearnerConfig
11 | from .networks import make_policy_networks, DRLearnerNetworks
12 |
13 |
14 | class DRLearner(local_layout.LocalLayout):
15 | """Local agent for DRLearner.
16 |
17 | This implements a single-process DRLearner agent.
18 | """
19 |
20 | def __init__(
21 | self,
22 | spec: specs.EnvironmentSpec,
23 | networks: DRLearnerNetworks,
24 | config: DRLearnerConfig,
25 | seed: int,
26 | workdir: Optional[str] = '~/acme',
27 | counter: Optional[counting.Counter] = None,
28 | logger=None
29 | ):
30 | ngu_builder = DRLearnerBuilder(networks, config, num_actors_per_mixture=1, logger=logger)
31 | super().__init__(
32 | seed=seed,
33 | environment_spec=spec,
34 | builder=ngu_builder,
35 | networks=networks,
36 | policy_network=make_policy_networks(networks, config),
37 | workdir=workdir,
38 | min_replay_size=config.min_replay_size,
39 | samples_per_insert=config.samples_per_insert if config.samples_per_insert \
40 | else 10 / (config.burn_in_length + config.trace_length),
41 | batch_size=config.batch_size,
42 | num_sgd_steps_per_step=config.num_sgd_steps_per_step,
43 | counter=counter,
44 | )
45 |
46 | def get_extras(self):
47 | return self._actor.get_extras()
48 |
--------------------------------------------------------------------------------
/drlearner/drlearner/builder.py:
--------------------------------------------------------------------------------
1 | """DRLearner Builder."""
2 | from typing import Callable, Iterator, List, Optional
3 | from copy import deepcopy
4 | import functools
5 |
6 | import acme
7 | import jax
8 | import jax.numpy as jnp
9 | import optax
10 | import reverb
11 | import tensorflow as tf
12 | from acme import adders
13 | from acme import core
14 | from acme import specs
15 | from acme.adders import reverb as adders_reverb
16 | from acme.agents.jax import builders
17 | from acme.datasets import reverb as datasets
18 | from acme.jax import networks as networks_lib
19 | from acme.jax import utils
20 | from acme.jax import variable_utils
21 | from acme.utils import counting
22 | from acme.utils import loggers
23 |
24 | from .config import DRLearnerConfig
25 | from .actor import DRLearnerActor
26 | from .actor_core import get_actor_core
27 | from .learning import DRLearnerLearner
28 | from .networks import DRLearnerNetworks
29 |
30 | # run CPU-only tensorflow for data loading
31 | tf.config.set_visible_devices([], "GPU")
32 |
33 |
34 | class DRLearnerBuilder(builders.ActorLearnerBuilder):
35 | """DRLearner Builder.
36 |
37 | """
38 |
39 | def __init__(self,
40 | networks: DRLearnerNetworks,
41 | config: DRLearnerConfig,
42 | num_actors_per_mixture: int,
43 | logger: Callable[[], loggers.Logger] = lambda: None, ):
44 | """Creates DRLearner learner, a behavior policy and an eval actor.
45 |
46 | Args:
47 | networks: DRLearner networks, used to build core state spec.
48 | config: a config with DRLearner hps
49 | logger: a logger for the learner
50 | """
51 | self._networks = networks
52 | self._config = config
53 | self._num_actors_per_mixture = num_actors_per_mixture
54 | self._logger_fn = logger
55 |
56 | # Sequence length for dataset iterator.
57 | self._sequence_length = (
58 | self._config.burn_in_length + self._config.trace_length + 1)
59 |
60 | # Construct the core state spec.
61 | dummy_key = jax.random.PRNGKey(0)
62 | intrinsic_initial_state_params = networks.uvfa_net.initial_state.init(dummy_key, 1)
63 | intrinsic_initial_state = networks.uvfa_net.initial_state.apply(intrinsic_initial_state_params,
64 | dummy_key, 1)
65 | extrinsic_initial_state_params = networks.uvfa_net.initial_state.init(dummy_key, 1)
66 | extrinsic_initial_state = networks.uvfa_net.initial_state.apply(extrinsic_initial_state_params,
67 | dummy_key, 1)
68 | intrinsic_core_state_spec = utils.squeeze_batch_dim(intrinsic_initial_state)
69 | extrinsic_core_state_spec = utils.squeeze_batch_dim(extrinsic_initial_state)
70 | self._extra_spec = {
71 | 'intrinsic_core_state': intrinsic_core_state_spec,
72 | 'extrinsic_core_state': extrinsic_core_state_spec
73 | }
74 |
75 | def evaluate_logger(self):
76 | if isinstance(self._logger_fn, functools.partial):
77 | self._logger_fn = self._logger_fn()
78 |
79 | def make_learner(
80 | self,
81 | random_key: networks_lib.PRNGKey,
82 | networks: DRLearnerNetworks,
83 | dataset: Iterator[reverb.ReplaySample],
84 | replay_client: Optional[reverb.Client] = None,
85 | counter: Optional[counting.Counter] = None,
86 | ) -> core.Learner:
87 | # The learner updates the parameters (and initializes them).
88 | self.evaluate_logger()
89 | return DRLearnerLearner(
90 | uvfa_unroll=networks.uvfa_net.unroll,
91 | uvfa_initial_state=networks.uvfa_net.initial_state,
92 | idm_action_pred=networks.embedding_net.predict_action,
93 | distillation_embed=networks.distillation_net.embed_sequence,
94 | batch_size=self._config.batch_size,
95 | random_key=random_key,
96 | burn_in_length=self._config.burn_in_length,
97 | beta_min=self._config.beta_min,
98 | beta_max=self._config.beta_max,
99 | gamma_min=self._config.gamma_min,
100 | gamma_max=self._config.gamma_max,
101 | num_mixtures=self._config.num_mixtures,
102 | target_epsilon=self._config.target_epsilon,
103 | importance_sampling_exponent=(
104 | self._config.importance_sampling_exponent),
105 | max_priority_weight=self._config.max_priority_weight,
106 | target_update_period=self._config.target_update_period,
107 | iterator=dataset,
108 | uvfa_optimizer=optax.adam(self._config.uvfa_learning_rate),
109 | idm_optimizer=optax.adamw(self._config.idm_learning_rate,
110 | weight_decay=self._config.idm_weight_decay),
111 | distillation_optimizer=optax.adamw(self._config.distillation_learning_rate,
112 | weight_decay=self._config.distillation_weight_decay),
113 | idm_clip_steps=self._config.idm_clip_steps,
114 | distillation_clip_steps=self._config.distillation_clip_steps,
115 | retrace_lambda=self._config.retrace_lambda,
116 | tx_pair=self._config.tx_pair,
117 | clip_rewards=self._config.clip_rewards,
118 | max_abs_reward=self._config.max_absolute_reward,
119 | replay_client=replay_client,
120 | counter=counter,
121 | logger=self._logger_fn)
122 |
123 | def make_replay_tables(
124 | self,
125 | environment_spec: specs.EnvironmentSpec,
126 | ) -> List[reverb.Table]:
127 | """Create tables to insert data into."""
128 | if self._config.samples_per_insert:
129 | samples_per_insert_tolerance = (
130 | self._config.samples_per_insert_tolerance_rate *
131 | self._config.samples_per_insert)
132 | error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
133 | limiter = reverb.rate_limiters.SampleToInsertRatio(
134 | min_size_to_sample=self._config.min_replay_size,
135 | samples_per_insert=self._config.samples_per_insert,
136 | error_buffer=error_buffer)
137 | else:
138 | limiter = reverb.rate_limiters.MinSize(1)
139 |
140 | # add intrinsic rewards and mixture_idx (intrinsic reward beta) to extra_specs
141 | self._extra_spec['intrinsic_reward'] = specs.Array(
142 | shape=environment_spec.rewards.shape,
143 | dtype=jnp.float32,
144 | name='intrinsic_reward'
145 | )
146 | self._extra_spec['mixture_idx'] = specs.Array(
147 | shape=environment_spec.rewards.shape,
148 | dtype=jnp.int32,
149 | name='mixture_idx'
150 | )
151 | # add probability of action under behavior policy
152 | self._extra_spec['behavior_action_prob'] = specs.Array(
153 | shape=environment_spec.rewards.shape,
154 | dtype=jnp.float32,
155 | name='behavior_action_prob'
156 | )
157 |
158 | # add the mode of evaluator
159 | self._extra_spec['is_eval'] = specs.Array(
160 | shape=environment_spec.rewards.shape,
161 | dtype=jnp.int32,
162 | name='is_eval'
163 | )
164 |
165 | self._extra_spec['alpha'] = specs.Array(
166 | shape=environment_spec.rewards.shape,
167 | dtype=jnp.float32,
168 | name='alpha'
169 | )
170 |
171 |
172 | return [
173 | reverb.Table(
174 | name=self._config.replay_table_name,
175 | sampler=reverb.selectors.Prioritized(
176 | self._config.priority_exponent),
177 | remover=reverb.selectors.Fifo(),
178 | max_size=self._config.max_replay_size,
179 | rate_limiter=limiter,
180 | signature=adders_reverb.SequenceAdder.signature(
181 | environment_spec, self._extra_spec))
182 | ]
183 |
184 | def make_dataset_iterator(
185 | self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
186 | """Create a dataset iterator to use for learning/updating the agent."""
187 | dataset = datasets.make_reverb_dataset(
188 | table=self._config.replay_table_name,
189 | server_address=replay_client.server_address,
190 | batch_size=self._config.batch_size,
191 | prefetch_size=self._config.prefetch_size,
192 | num_parallel_calls=self._config.num_parallel_calls)
193 | return dataset.as_numpy_iterator()
194 |
195 | def make_adder(self,
196 | replay_client: reverb.Client) -> Optional[adders.Adder]:
197 | """Create an adder which records data generated by the actor/environment."""
198 | return adders_reverb.SequenceAdder(
199 | client=replay_client,
200 | period=self._config.sequence_period,
201 | sequence_length=self._sequence_length,
202 | delta_encoded=True)
203 |
204 | def make_actor(
205 | self,
206 | random_key: networks_lib.PRNGKey,
207 | policy_networks,
208 | adder: Optional[adders.Adder] = None,
209 | variable_source: Optional[core.VariableSource] = None,
210 | actor_id: int = 0,
211 | is_evaluator: bool = False,
212 | ) -> acme.Actor:
213 |
214 | # Create variable client.
215 | variable_client = variable_utils.VariableClient(
216 | variable_source,
217 | key='actor_variables',
218 | update_period=self._config.variable_update_period)
219 | variable_client.update_and_wait()
220 |
221 | intrinsic_initial_state_key1, intrinsic_initial_state_key2, \
222 | extrinsic_initial_state_key1, extrinsic_initial_state_key2, random_key = jax.random.split(random_key, 5)
223 | intrinsic_actor_initial_state_params = self._networks.uvfa_net.initial_state.init(
224 | intrinsic_initial_state_key1, 1)
225 | intrinsic_actor_initial_state = self._networks.uvfa_net.initial_state.apply(
226 | intrinsic_actor_initial_state_params, intrinsic_initial_state_key2, 1)
227 | extrinsic_actor_initial_state_params = self._networks.uvfa_net.initial_state.init(
228 | extrinsic_initial_state_key1, 1)
229 | extrinsic_actor_initial_state = self._networks.uvfa_net.initial_state.apply(
230 | extrinsic_actor_initial_state_params, extrinsic_initial_state_key2, 1)
231 |
232 | config = deepcopy(self._config)
233 | if is_evaluator:
234 | config.window = self._config.evaluation_window
235 | config.epsilon = self._config.evaluation_epsilon
236 | config.mc_epsilon = self._config.evaluation_mc_epsilon
237 | else:
238 | config.window = self._config.actor_window
239 | config.epsilon = self._config.actor_epsilon
240 | config.mc_epsilon = self._config.actor_mc_epsilon
241 |
242 |
243 | actor_core = get_actor_core(policy_networks,
244 | intrinsic_actor_initial_state,
245 | extrinsic_actor_initial_state,
246 | actor_id,
247 | self._num_actors_per_mixture,
248 | config,
249 | jit=True)
250 |
251 | mixture_idx = actor_id // self._num_actors_per_mixture
252 |
253 |
254 | return DRLearnerActor(
255 | actor_core, mixture_idx, random_key, variable_client, adder, backend='cpu', jit=True)
256 |
--------------------------------------------------------------------------------
/drlearner/drlearner/config.py:
--------------------------------------------------------------------------------
1 | """DRLearner config."""
2 | import dataclasses
3 |
4 | import rlax
5 | from acme.adders import reverb as adders_reverb
6 |
7 |
8 | @dataclasses.dataclass
9 | class DRLearnerConfig:
10 | """Configuration options for DRLearner agent."""
11 | gamma_min: float = 0.99
12 | gamma_max: float = 0.997
13 | num_mixtures: int = 32
14 | target_update_period: int = 400
15 | evaluation_epsilon: float = 0.01
16 | epsilon: float = 0.01
17 | actor_epsilon: float = 0.01
18 | target_epsilon: float = 0.01
19 | variable_update_period: int = 400
20 |
21 | # Learner options
22 | retrace_lambda: float = 0.95
23 | burn_in_length: int = 40
24 | trace_length: int = 80
25 | sequence_period: int = 40
26 | num_sgd_steps_per_step: int = 1
27 | uvfa_learning_rate: float = 1e-4
28 | idm_learning_rate: float = 5e-4
29 | distillation_learning_rate: float = 5e-4
30 | idm_weight_decay: float = 1e-5
31 | distillation_weight_decay: float = 1e-5
32 | idm_clip_steps: int = 5
33 | distillation_clip_steps: int = 5
34 | clip_rewards: bool = False
35 | max_absolute_reward: float = 1.0
36 | tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR
37 | distillation_moving_average_coef: float = 1e-3
38 |
39 | # Intrinsic reward multipliers
40 | beta_min: float = 0.
41 | beta_max: float = 0.3
42 |
43 | # Embedding network options
44 | observation_embed_dim: int = 128
45 | episodic_memory_num_neighbors: int = 10
46 | episodic_memory_max_size: int = 30_000
47 | episodic_memory_max_similarity: float = 8.
48 | episodic_memory_cluster_distance: float = 8e-3
49 | episodic_memory_pseudo_counts: float = 1e-3
50 | episodic_memory_epsilon: float = 1e-4
51 |
52 | # Distillation network
53 | distillation_embed_dim: int = 128
54 | max_lifelong_modulation: float = 5.0
55 |
56 | # Replay options
57 | samples_per_insert_tolerance_rate: float = 0.1
58 | samples_per_insert: float = 4.0
59 | min_replay_size: int = 50_000
60 | max_replay_size: int = 100_000
61 | batch_size: int = 64
62 | prefetch_size: int = 2
63 | num_parallel_calls: int = 16
64 | replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
65 |
66 | # Priority options
67 | importance_sampling_exponent: float = 0.6
68 | priority_exponent: float = 0.9
69 | max_priority_weight: float = 0.9
70 |
71 | # Meta Controller options
72 | window: int = 160
73 | actor_window: int = 160
74 | evaluation_window: int = 3600
75 | n_arms: int = 32
76 |     mc_epsilon: float = 0.5  # Set from actor_mc_epsilon or evaluation_mc_epsilon, depending on whether the actor acts as an evaluator
77 | actor_mc_epsilon: float = 0.5
78 | evaluation_mc_epsilon: float = 0.01
79 | mc_beta: float = 1.
80 |
81 | # Agent's video logging options
82 | env_library: str = None
83 | video_log_period: int = 10
84 | actions_log_period: int = 1
85 | logs_dir: str = 'experiments/default'
86 | num_episodes: int = 50
87 |
88 |
--------------------------------------------------------------------------------
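The config above is a plain dataclass, so experiments can override individual fields without editing the file. A minimal sketch (not part of the repository) of constructing and tweaking a config; the override values are illustrative assumptions, not tuned settings:

import dataclasses

from drlearner.drlearner.config import DRLearnerConfig

# Override a few defaults for a small local run (values are placeholders).
config = DRLearnerConfig(num_mixtures=4, batch_size=16, min_replay_size=1_000)

# dataclasses.replace returns a modified copy, leaving `config` untouched;
# the builder does something similar with deepcopy when switching an actor
# into evaluation mode.
eval_config = dataclasses.replace(config, epsilon=config.evaluation_epsilon)
print(config.num_mixtures, eval_config.epsilon)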
/drlearner/drlearner/distributed_agent.py:
--------------------------------------------------------------------------------
1 | """Defines distributed DRLearner agent, using JAX."""
2 |
3 | import functools
4 | from typing import Callable, Optional, Sequence
5 |
6 | import dm_env
7 | from acme import specs
8 | from acme.jax import utils
9 | from acme.utils import loggers
10 |
11 | from ..core import distributed_layout
12 | from .config import DRLearnerConfig
13 | from .builder import DRLearnerBuilder
14 | from .networks import DRLearnerNetworks, make_policy_networks
15 |
16 | NetworkFactory = Callable[[specs.EnvironmentSpec], DRLearnerNetworks]
17 | EnvironmentFactory = Callable[[int], dm_env.Environment]
18 |
19 |
20 | class DistributedDRLearnerFromConfig(distributed_layout.DistributedLayout):
21 | """Distributed DRLearner agents from config."""
22 |
23 | def __init__(
24 | self,
25 | environment_factory: EnvironmentFactory,
26 | environment_spec: specs.EnvironmentSpec,
27 | network_factory: NetworkFactory,
28 | config: DRLearnerConfig,
29 | seed: int,
30 | num_actors_per_mixture: int,
31 | workdir: str = '~/acme',
32 | device_prefetch: bool = False,
33 | log_to_bigtable: bool = True,
34 | log_every: float = 10.0,
35 |             # TODO: Refactor: `max_episodes` and `max_steps` should be defined at the experiment level,
36 |             # not at the agent level, similarly to other experiment-related abstractions.
37 | max_episodes: Optional[int] = None,
38 | max_steps: Optional[int] = None,
39 | evaluator_factories: Optional[Sequence[
40 | distributed_layout.EvaluatorFactory]] = None,
41 | actor_observers=(),
42 | evaluator_observers=(),
43 | learner_logger_fn: Optional[Callable[[], loggers.Logger]] = None,
44 | multithreading_colocate_learner_and_reverb: bool = False
45 | ):
46 | learner_logger_fn = learner_logger_fn or functools.partial(loggers.make_default_logger,
47 | 'learner', log_to_bigtable,
48 | time_delta=log_every, asynchronous=True,
49 | serialize_fn=utils.fetch_devicearray,
50 | steps_key='learner_steps')
51 | drlearner_builder = DRLearnerBuilder(
52 | networks=network_factory(environment_spec),
53 | config=config,
54 | num_actors_per_mixture=num_actors_per_mixture,
55 | logger=learner_logger_fn)
56 | policy_network_factory = (
57 | lambda networks: make_policy_networks(networks, config))
58 | if evaluator_factories is None:
59 | evaluator_policy_network_factory = (
60 | lambda networks: make_policy_networks(networks, config, evaluation=True))
61 | evaluator_factories = [
62 | distributed_layout.default_evaluator_factory(
63 | environment_factory=environment_factory,
64 | network_factory=network_factory,
65 | policy_factory=evaluator_policy_network_factory,
66 | log_to_bigtable=log_to_bigtable,
67 | observers=evaluator_observers
68 | )
69 | ]
70 | super().__init__(
71 | seed=seed,
72 | environment_factory=environment_factory,
73 | network_factory=network_factory,
74 | builder=drlearner_builder,
75 | policy_network=policy_network_factory,
76 | evaluator_factories=evaluator_factories,
77 | num_actors=num_actors_per_mixture * config.num_mixtures,
78 | environment_spec=environment_spec,
79 | device_prefetch=device_prefetch,
80 | log_to_bigtable=log_to_bigtable,
81 | max_episodes=max_episodes,
82 | max_steps=max_steps,
83 | actor_logger_fn=distributed_layout.get_default_logger_fn(
84 | log_to_bigtable, log_every),
85 | prefetch_size=config.prefetch_size,
86 | checkpointing_config=distributed_layout.CheckpointingConfig(
87 | directory=workdir, add_uid=(workdir == '~/acme')),
88 | observers=actor_observers,
89 | multithreading_colocate_learner_and_reverb=multithreading_colocate_learner_and_reverb
90 | )
91 |
--------------------------------------------------------------------------------
/drlearner/drlearner/drlearner_types.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Optional
2 |
3 | import chex
4 | import jax.numpy as jnp
5 | import optax
6 | from acme.agents.jax import actor_core as actor_core_lib
7 | from acme.jax import networks as networks_lib
8 | from rlax._src.exploration import IntrinsicRewardState
9 |
10 | from .lifelong_curiosity import LifelongCuriosityModulationState
11 |
12 |
13 | @chex.dataclass(frozen=True, mappable_dataclass=False)
14 | class DRLearnerNetworksParams:
15 | """Collection of all parameters of Neural Networks used by DRLearner Agent"""
16 | intrinsic_uvfa_params: networks_lib.Params # intrinsic Universal Value-Function Approximator
17 | extrinsic_uvfa_params: networks_lib.Params # extrinsic Universal Value-Function Approximator
18 | intrinsic_uvfa_target_params: networks_lib.Params # intrinsic UVFA target network
19 | extrinsic_uvfa_target_params: networks_lib.Params # extrinsic UVFA target network
20 | idm_params: networks_lib.Params # Inverse Dynamics Model
21 | distillation_params: networks_lib.Params # Distillation Network
22 | distillation_random_params: networks_lib.Params # Random Distillation Network
23 |
24 |
25 | @chex.dataclass(frozen=True, mappable_dataclass=False)
26 | class DRLearnerNetworksOptStates:
27 | """Collection of optimizer states for all networks trained by the Learner"""
28 | intrinsic_uvfa_opt_state: optax.OptState
29 | extrinsic_uvfa_opt_state: optax.OptState
30 | idm_opt_state: optax.OptState
31 | distillation_opt_state: optax.OptState
32 |
33 |
34 | class TrainingState(NamedTuple):
35 | """DRLearner Learner training state"""
36 | params: DRLearnerNetworksParams
37 | opt_state: DRLearnerNetworksOptStates
38 | steps: jnp.ndarray
39 | random_key: networks_lib.PRNGKey
40 |
41 |
42 | @chex.dataclass(frozen=True, mappable_dataclass=False)
43 | class MetaControllerState:
44 | episode_returns_history: jnp.ndarray
45 | episode_count: jnp.ndarray
46 | current_episode_return: jnp.ndarray
47 | mixture_idx_history: jnp.ndarray
48 | beta: jnp.ndarray
49 | gamma: jnp.ndarray
50 | is_eval: bool
51 | num_eval_episodes: jnp.ndarray
52 |
53 |
54 | @chex.dataclass(frozen=True, mappable_dataclass=False)
55 | class DRLearnerActorState:
56 | rng: networks_lib.PRNGKey
57 | epsilon: jnp.ndarray
58 | mixture_idx: jnp.ndarray
59 | intrinsic_recurrent_state: actor_core_lib.RecurrentState
60 | extrinsic_recurrent_state: actor_core_lib.RecurrentState
61 | prev_intrinsic_reward: jnp.ndarray
62 | prev_action_prob: jnp.ndarray
63 | prev_alpha: jnp.ndarray
64 | meta_controller_state: MetaControllerState
65 | lifelong_modulation_state: Optional[LifelongCuriosityModulationState] = None
66 | intrinsic_reward_state: Optional[IntrinsicRewardState] = None
67 |
68 |
--------------------------------------------------------------------------------
/drlearner/drlearner/lifelong_curiosity.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import chex
4 | import jax.numpy as jnp
5 |
6 |
7 | @chex.dataclass
8 | class LifelongCuriosityModulationState:
9 | distance_mean: chex.Scalar = 0.
10 | distance_var: chex.Scalar = 1.
11 |
12 |
13 | def lifelong_curiosity_modulation(
14 | learnt_embeddings: chex.Array,
15 | random_embeddings: chex.Array,
16 | max_modulation: float = 5.0,
17 | lifelong_modulation_state: Optional[LifelongCuriosityModulationState] = None,
18 | ma_coef: float = 0.0001):
19 | if not lifelong_modulation_state:
20 | lifelong_modulation_state = LifelongCuriosityModulationState()
21 |
22 | error = jnp.sum((learnt_embeddings - random_embeddings) ** 2, axis=-1)
23 |
24 | distance_mean = lifelong_modulation_state.distance_mean
25 | distance_var = lifelong_modulation_state.distance_var
26 |     # exponentially weighted moving average of the error mean and variance
27 | distance_var = (1 - ma_coef) * (distance_var + ma_coef * jnp.mean(error - distance_mean) ** 2)
28 | distance_mean = ma_coef * jnp.mean(error) + (1 - ma_coef) * distance_mean
29 |
30 | alpha = 1 + (error - distance_mean) / jnp.sqrt(distance_var)
31 | alpha = jnp.clip(alpha, 1., max_modulation)
32 |
33 | lifelong_modulation_state = LifelongCuriosityModulationState(
34 | distance_mean=distance_mean,
35 | distance_var=distance_var
36 | )
37 |
38 | return alpha, lifelong_modulation_state
39 |
--------------------------------------------------------------------------------
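A hedged usage sketch for lifelong_curiosity_modulation (assumption: it is called once per batch with the learnt and random distillation embeddings, and the returned state is carried across calls so the running error statistics keep updating):

import jax.numpy as jnp

from drlearner.drlearner.lifelong_curiosity import (
    LifelongCuriosityModulationState, lifelong_curiosity_modulation)

learnt = 0.5 * jnp.ones((8, 128))   # dummy learnt embeddings [B, D]
random = jnp.zeros((8, 128))        # dummy random-network embeddings [B, D]

state = LifelongCuriosityModulationState()
alpha, state = lifelong_curiosity_modulation(
    learnt, random, max_modulation=5.0,
    lifelong_modulation_state=state, ma_coef=1e-3)

# alpha is clipped to [1, max_modulation]; values near 1 mean the distillation
# error is close to its running mean, i.e. little lifelong novelty.
print(alpha.shape)  # (8,)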
/drlearner/drlearner/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .networks import DRLearnerNetworks
2 | from .policy_networks import DRLearnerPolicyNetworks
3 | from .policy_networks import make_policy_networks
4 | from .uvfa_network import UVFANetworkInput
5 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/distillation_network.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import haiku as hk
4 | import jax
5 | from acme.jax import networks as networks_lib
6 | from acme.jax import utils
7 | from acme.wrappers.observation_action_reward import OAR
8 |
9 |
10 | @dataclasses.dataclass
11 | class DistillationNetwork:
12 | """Pure functions for DRLearner distillation network"""
13 | embed: networks_lib.FeedForwardNetwork
14 | embed_sequence: networks_lib.FeedForwardNetwork
15 |
16 |
17 | def make_distillation_net(
18 | make_distillation_modules,
19 | env_spec) -> DistillationNetwork:
20 | def embed_fn(observation: OAR) -> networks_lib.NetworkOutput:
21 | """
22 | Embed batch of observations
23 | Args:
24 | observation: jnp.array representing a batch of observations [B, ...]
25 |
26 | Returns:
27 | embedding vectors [B, D]
28 | """
29 | embedding_torso = make_distillation_modules()
30 | return embedding_torso(observation.observation)
31 |
32 | # transform functions
33 | embed_hk = hk.transform(embed_fn)
34 |
35 | # create dummy batches for networks initialization
36 | observation_batch = utils.add_batch_dim(
37 | utils.zeros_like(env_spec.observations)
38 | ) # [B=1, ...]
39 |
40 | def embed_init(rng):
41 | return embed_hk.init(rng, observation_batch)
42 |
43 | embed = networks_lib.FeedForwardNetwork(
44 | init=embed_init, apply=embed_hk.apply
45 | )
46 | embed_sequence = networks_lib.FeedForwardNetwork(
47 | init=embed_init,
48 |         # vmap over the leading axis of the data argument in apply(params, random_key, data); params and rng are shared
49 | apply=jax.vmap(embed_hk.apply, in_axes=(None, None, 0), out_axes=0)
50 | )
51 |
52 | return DistillationNetwork(embed, embed_sequence)
53 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/embedding_network.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import haiku as hk
4 | import jax.numpy as jnp
5 | from acme.jax import networks as networks_lib
6 | from acme.jax import utils
7 | from acme.wrappers.observation_action_reward import OAR
8 |
9 |
10 | @dataclasses.dataclass
11 | class EmbeddingNetwork:
12 | """Pure functions for DRLearner embedding network"""
13 | predict_action: networks_lib.FeedForwardNetwork
14 | embed: networks_lib.FeedForwardNetwork
15 |
16 |
17 | def make_embedding_net(
18 | make_embedding_modules,
19 | env_spec) -> EmbeddingNetwork:
20 | def embed_fn(observation: OAR) -> networks_lib.NetworkOutput:
21 | """
22 | Embed batch of observations
23 | Args:
24 | observation: jnp.array representing a batch of observations [B, ...]
25 |
26 | Returns:
27 | embedding vectors [B, D]
28 | """
29 | embedding_torso, _ = make_embedding_modules()
30 | return embedding_torso(observation.observation)
31 |
32 | def predict_action_fn(observation_tm1: OAR, observation_t: OAR) -> networks_lib.NetworkOutput:
33 | """
34 |         Embed batches of sequences of two consecutive observations x_{t-1} and x_{t} and predict the batch of actions a_{t-1}
35 | Args:
36 | observation_tm1: observation x_{t-1} [T, B, ...]
37 | observation_t: observation x_{t} [T, B, ...]
38 |
39 | Returns:
40 |             prediction logits for discrete action a_{t-1} [T, B, N]
41 | """
42 | embedding_torso, pred_head = make_embedding_modules()
43 | emb_tm1 = hk.BatchApply(embedding_torso)(observation_tm1.observation)
44 | emb_t = hk.BatchApply(embedding_torso)(observation_t.observation)
45 | return hk.BatchApply(pred_head)(jnp.concatenate([emb_tm1, emb_t], axis=-1))
46 |
47 | # transform functions
48 | embed_hk = hk.transform(embed_fn)
49 | predict_action_hk = hk.transform(predict_action_fn)
50 |
51 | # create dummy batches for networks initialization
52 | observation_sequences = utils.add_batch_dim(
53 | utils.add_batch_dim(
54 | utils.zeros_like(env_spec.observations)
55 | )
56 | ) # [T=1, B=1, ...]
57 |
58 | def predict_action_init(rng):
59 | return predict_action_hk.init(rng, observation_sequences, observation_sequences)
60 |
61 | # create FeedForwardNetworks corresponding to embed and action prediction functions
62 | predict_action = networks_lib.FeedForwardNetwork(
63 | init=predict_action_init, apply=predict_action_hk.apply
64 | )
65 | embed = networks_lib.FeedForwardNetwork(
66 | init=embed_hk.init, apply=embed_hk.apply
67 | )
68 |
69 | return EmbeddingNetwork(predict_action, embed)
70 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/networks.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | from .distillation_network import DistillationNetwork
4 | from .embedding_network import EmbeddingNetwork
5 | from .uvfa_network import UVFANetwork
6 |
7 |
8 | @dataclasses.dataclass
9 | class DRLearnerNetworks:
10 | """Wrapper for all DRLearner learnable networks"""
11 | uvfa_net: UVFANetwork
12 | embedding_net: EmbeddingNetwork
13 | distillation_net: DistillationNetwork
14 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/networks_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .atari import make_atari_nets
2 | from .discomaze import make_discomaze_nets
3 | from .lunar_lander import make_lunar_lander_nets
4 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/networks_zoo/atari.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | import jax.nn
3 | from acme import specs
4 | from acme.jax import networks as networks_lib
5 |
6 | from ..distillation_network import make_distillation_net
7 | from ..embedding_network import make_embedding_net
8 | from ..networks import DRLearnerNetworks
9 | from ..uvfa_network import make_uvfa_net
10 | from ..uvfa_torso import UVFATorso
11 | from ...config import DRLearnerConfig
12 |
13 |
14 | def make_atari_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks:
15 | uvfa_net = make_atari_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size)
16 | embedding_network = make_atari_embedding_net(env_spec, config.observation_embed_dim)
17 | distillation_network = make_atari_distillation_net(env_spec, config.distillation_embed_dim)
18 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network)
19 |
20 |
21 | def make_atari_embedding_net(env_spec, embedding_dim):
22 | def make_atari_embedding_modules():
23 | embedding_torso = hk.Sequential([
24 | networks_lib.AtariTorso(),
25 | hk.Linear(output_size=embedding_dim),
26 | jax.nn.relu,
27 | ], name='atari_embedding_torso')
28 | pred_head = hk.Sequential([
29 | hk.Linear(128),
30 | jax.nn.relu,
31 | hk.Linear(env_spec.actions.num_values)
32 | ], name='atari_action_pred_head')
33 | return embedding_torso, pred_head
34 |
35 | return make_embedding_net(
36 | make_embedding_modules=make_atari_embedding_modules,
37 | env_spec=env_spec
38 | )
39 |
40 |
41 | def make_atari_distillation_net(env_spec, embedding_dim):
42 | def make_atari_distillation_modules():
43 | embedding_torso = hk.Sequential([
44 | networks_lib.AtariTorso(),
45 | hk.Linear(output_size=embedding_dim),
46 | ], name='atari_distillation_torso')
47 | return embedding_torso
48 |
49 | return make_distillation_net(
50 | make_distillation_modules=make_atari_distillation_modules,
51 | env_spec=env_spec
52 | )
53 |
54 |
55 | def make_atari_uvfa_net(env_spec, num_mixtures: int, batch_size: int):
56 | def make_atari_uvfa_modules():
57 | embedding_torso = make_uvfa_atari_torso(env_spec.actions.num_values, num_mixtures)
58 | recurrent_core = hk.LSTM(512)
59 | head = networks_lib.DuellingMLP(
60 | num_actions=env_spec.actions.num_values,
61 | hidden_sizes=[512]
62 | )
63 | return embedding_torso, recurrent_core, head
64 |
65 | return make_uvfa_net(
66 | make_uvfa_modules=make_atari_uvfa_modules,
67 | batch_size=batch_size,
68 | env_spec=env_spec
69 | )
70 |
71 |
72 | def make_uvfa_atari_torso(num_actions: int, num_mixtures: int):
73 | observation_embedding_torso = hk.Sequential([
74 | networks_lib.AtariTorso(),
75 | hk.Linear(512),
76 | jax.nn.relu
77 | ])
78 | return UVFATorso(
79 | observation_embedding_torso,
80 | num_actions, num_mixtures,
81 | name='atari_uvfa_torso'
82 | )
83 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/networks_zoo/discomaze.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | import jax.nn
3 | from acme import specs
4 | from acme.jax import networks as networks_lib
5 |
6 | from ..distillation_network import make_distillation_net
7 | from ..embedding_network import make_embedding_net
8 | from ..networks import DRLearnerNetworks
9 | from ..uvfa_network import make_uvfa_net
10 | from ..uvfa_torso import UVFATorso
11 | from ...config import DRLearnerConfig
12 |
13 |
14 | def make_discomaze_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks:
15 | uvfa_net = make_discomaze_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size)
16 | embedding_network = make_discomaze_embedding_net(env_spec, config.observation_embed_dim)
17 | distillation_network = make_discomaze_distillation_net(env_spec, config.distillation_embed_dim)
18 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network)
19 |
20 |
21 | def make_discomaze_embedding_net(env_spec, embedding_dim):
22 | def make_discomaze_embedding_modules():
23 | embedding_torso = hk.Sequential([
24 | hk.Conv2D(16, kernel_shape=3, stride=1),
25 | jax.nn.relu,
26 | hk.Conv2D(32, kernel_shape=3, stride=1),
27 | jax.nn.relu,
28 | hk.Flatten(preserve_dims=-3),
29 | hk.Linear(embedding_dim),
30 | jax.nn.relu
31 | ])
32 | pred_head = hk.Sequential([
33 | hk.Linear(32),
34 | jax.nn.relu,
35 | hk.Linear(env_spec.actions.num_values)
36 | ], name='action_pred_head')
37 | return embedding_torso, pred_head
38 |
39 | return make_embedding_net(
40 | make_embedding_modules=make_discomaze_embedding_modules,
41 | env_spec=env_spec
42 | )
43 |
44 |
45 | def make_discomaze_distillation_net(env_spec, embedding_dim):
46 | def make_discomaze_distillation_modules():
47 | embedding_torso = hk.Sequential([
48 | hk.Conv2D(16, kernel_shape=3, stride=1),
49 | jax.nn.relu,
50 | hk.Conv2D(32, kernel_shape=3, stride=1),
51 | jax.nn.relu,
52 | hk.Flatten(preserve_dims=-3),
53 | hk.Linear(embedding_dim),
54 | jax.nn.relu
55 | ])
56 | return embedding_torso
57 |
58 | return make_distillation_net(
59 | make_distillation_modules=make_discomaze_distillation_modules,
60 | env_spec=env_spec
61 | )
62 |
63 |
64 | def make_discomaze_uvfa_net(env_spec, num_mixtures: int, batch_size: int):
65 | def make_discomaze_uvfa_modules():
66 | embedding_torso = make_uvfa_discomaze_torso(env_spec.actions.num_values, num_mixtures)
67 | recurrent_core = hk.LSTM(256)
68 | head = networks_lib.DuellingMLP(
69 | num_actions=env_spec.actions.num_values,
70 | hidden_sizes=[256]
71 | )
72 | return embedding_torso, recurrent_core, head
73 |
74 | return make_uvfa_net(
75 | make_uvfa_modules=make_discomaze_uvfa_modules,
76 | batch_size=batch_size,
77 | env_spec=env_spec
78 | )
79 |
80 |
81 | def make_uvfa_discomaze_torso(num_actions: int, num_mixtures: int):
82 | observation_embedding_torso = hk.Sequential([
83 | hk.Conv2D(16, kernel_shape=3, stride=1),
84 | jax.nn.relu,
85 | hk.Conv2D(32, kernel_shape=3, stride=1),
86 | jax.nn.relu,
87 | hk.Flatten(preserve_dims=-3),
88 | hk.Linear(256),
89 | jax.nn.relu
90 | ])
91 | return UVFATorso(
92 | observation_embedding_torso,
93 | num_actions, num_mixtures,
94 | name='discomaze_uvfa_torso'
95 | )
96 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/networks_zoo/lunar_lander.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | from acme import specs
3 | from acme.jax import networks as networks_lib
4 |
5 | from ..distillation_network import make_distillation_net
6 | from ..embedding_network import make_embedding_net
7 | from ..networks import DRLearnerNetworks
8 | from ..uvfa_network import make_uvfa_net
9 | from ..uvfa_torso import UVFATorso
10 | from ...config import DRLearnerConfig
11 |
12 |
13 | def make_lunar_lander_nets(config: DRLearnerConfig, env_spec: specs.EnvironmentSpec) -> DRLearnerNetworks:
14 | uvfa_net = make_lunar_lander_uvfa_net(env_spec, num_mixtures=config.num_mixtures, batch_size=config.batch_size)
15 | embedding_network = make_lunar_lander_embedding_net(env_spec, config.observation_embed_dim)
16 | distillation_network = make_lunar_lander_distillation_net(env_spec, config.distillation_embed_dim)
17 | return DRLearnerNetworks(uvfa_net, embedding_network, distillation_network)
18 |
19 |
20 | def make_lunar_lander_embedding_net(env_spec, embedding_dim):
21 | def make_mlp_embedding_modules():
22 | embedding_torso = hk.nets.MLP([16, 32, embedding_dim], name='mlp_embedding_torso')
23 | pred_head = hk.Linear(env_spec.actions.num_values, name='action_pred_head')
24 | return embedding_torso, pred_head
25 |
26 | return make_embedding_net(
27 | make_embedding_modules=make_mlp_embedding_modules,
28 | env_spec=env_spec
29 | )
30 |
31 |
32 | def make_lunar_lander_distillation_net(env_spec, embedding_dim):
33 | def make_mlp_distillation_modules():
34 | embedding_torso = hk.nets.MLP([16, 32, embedding_dim], name='mlp_embedding_torso')
35 | return embedding_torso
36 |
37 | return make_distillation_net(
38 | make_distillation_modules=make_mlp_distillation_modules,
39 | env_spec=env_spec
40 | )
41 |
42 |
43 | def make_lunar_lander_uvfa_net(env_spec, num_mixtures: int, batch_size: int):
44 | def make_mlp_uvfa_modules():
45 | embedding_torso = make_uvfa_lunar_lander_torso(env_spec.actions.num_values, num_mixtures)
46 | recurrent_core = hk.LSTM(32)
47 | head = networks_lib.DuellingMLP(
48 | num_actions=env_spec.actions.num_values,
49 | hidden_sizes=[32]
50 | )
51 | return embedding_torso, recurrent_core, head
52 |
53 | return make_uvfa_net(
54 | make_uvfa_modules=make_mlp_uvfa_modules,
55 | batch_size=batch_size,
56 | env_spec=env_spec
57 | )
58 |
59 |
60 | def make_uvfa_lunar_lander_torso(num_actions: int, num_mixtures: int):
61 | observation_embedding_torso = hk.nets.MLP([16, 32, 16])
62 | return UVFATorso(
63 | observation_embedding_torso,
64 | num_actions, num_mixtures,
65 | name='mlp_uvfa_torso'
66 | )
67 |
--------------------------------------------------------------------------------
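A sketch of wiring the zoo together, assuming the LunarLander environment helper from drlearner/environments/lunar_lander.py and default config values. It only builds the pure-function bundles; no parameters are initialized yet:

from acme import specs

from drlearner.drlearner.config import DRLearnerConfig
from drlearner.drlearner.networks.networks_zoo import make_lunar_lander_nets
from drlearner.environments.lunar_lander import make_ll_environment

env = make_ll_environment(seed=0)
env_spec = specs.make_environment_spec(env)

# Bundle of UVFA, embedding (inverse-dynamics) and distillation networks.
networks = make_lunar_lander_nets(DRLearnerConfig(), env_spec)
print(type(networks.uvfa_net), type(networks.embedding_net), type(networks.distillation_net))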
/drlearner/drlearner/networks/policy_networks.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Callable
3 |
4 | import jax.nn
5 | import jax.numpy as jnp
6 | import jax.random
7 | import rlax
8 | from acme import types
9 | from acme.jax import networks as networks_lib
10 |
11 | from .networks import DRLearnerNetworks
12 | from ..config import DRLearnerConfig
13 |
14 |
15 | @dataclasses.dataclass
16 | class DRLearnerPolicyNetworks:
17 | """Pure functions used by DRLearner actors"""
18 | select_action: Callable
19 | embed_observation: Callable
20 | distillation_embed_observation: Callable
21 |
22 |
23 | def make_policy_networks(
24 | networks: DRLearnerNetworks,
25 | config: DRLearnerConfig,
26 | evaluation: bool = False):
27 | def select_action(intrinsic_params: networks_lib.Params,
28 | extrinsic_params: networks_lib.Params,
29 | key: networks_lib.PRNGKey,
30 | observation: types.NestedArray,
31 | intrinsic_core_state: types.NestedArray,
32 | extrinsic_core_state: types.NestedArray,
33 | epsilon, beta):
34 | intrinsic_key_qnet, extrinsic_key_qnet, key_sample = jax.random.split(key, 3)
35 | intrinsic_q_values, intrinsic_core_state = networks.uvfa_net.forward.apply(
36 | intrinsic_params, intrinsic_key_qnet, observation, intrinsic_core_state)
37 | extrinsic_q_values, extrinsic_core_state = networks.uvfa_net.forward.apply(
38 | extrinsic_params, extrinsic_key_qnet, observation, extrinsic_core_state)
39 |
40 | q_values = config.tx_pair.apply(beta * config.tx_pair.apply_inv(intrinsic_q_values) +
41 | config.tx_pair.apply_inv(extrinsic_q_values))
42 | epsilon = config.evaluation_epsilon if evaluation else epsilon
43 | action_dist = rlax.epsilon_greedy(epsilon)
44 | action = action_dist.sample(key_sample, q_values)
45 | action_prob = action_dist.probs(
46 | jax.nn.one_hot(jnp.argmax(q_values, axis=-1), num_classes=q_values.shape[-1])
47 | )
48 | action_prob = jnp.squeeze(action_prob[:, action], axis=-1)
49 | return action, action_prob, intrinsic_core_state, extrinsic_core_state
50 |
51 | return DRLearnerPolicyNetworks(
52 | select_action=select_action,
53 | embed_observation=networks.embedding_net.embed.apply,
54 | distillation_embed_observation=networks.distillation_net.embed.apply
55 | )
56 |
--------------------------------------------------------------------------------
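The interesting step in select_action is how the two value heads are mixed. An illustrative sketch (not repository code) of that combination with rlax's signed-hyperbolic transform pair, the default tx_pair in the config:

import jax.numpy as jnp
import rlax

tx_pair = rlax.SIGNED_HYPERBOLIC_PAIR

intrinsic_q = jnp.array([0.10, 0.40, 0.20])   # h-transformed intrinsic values
extrinsic_q = jnp.array([1.00, 0.50, 0.80])   # h-transformed extrinsic values
beta = 0.3                                    # intrinsic reward weight of the mixture

# Map back to the raw return scale, mix with beta, and re-squash.
q_values = tx_pair.apply(
    beta * tx_pair.apply_inv(intrinsic_q) + tx_pair.apply_inv(extrinsic_q))
greedy_action = jnp.argmax(q_values)          # epsilon-greedy mostly picks this
print(q_values, greedy_action)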
/drlearner/drlearner/networks/uvfa_network.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Tuple, NamedTuple
3 |
4 | import haiku as hk
5 | import jax.numpy as jnp
6 | from acme.jax import networks as networks_lib
7 | from acme.jax import utils
8 | from acme.wrappers.observation_action_reward import OAR
9 |
10 |
11 | @dataclasses.dataclass
12 | class UVFANetwork:
13 | """Pure functions for DRLearner Universal Value Function Approximator"""
14 | initial_state: networks_lib.FeedForwardNetwork
15 | forward: networks_lib.FeedForwardNetwork
16 | unroll: networks_lib.FeedForwardNetwork
17 |
18 |
19 | class UVFANetworkInput(NamedTuple):
20 | """Wrap input specific to DRLearner Recurrent Q-network"""
21 | oar: OAR # observation_t, action_tm1, reward_tm1
22 | intrinsic_reward: jnp.ndarray # ri_tm1
23 | mixture_idx: jnp.ndarray # beta_idx_tm1
24 |
25 |
26 | def make_uvfa_net(
27 | make_uvfa_modules,
28 | batch_size: int,
29 | env_spec) -> UVFANetwork:
30 | def initial_state(batch_size: int):
31 | _, recurrent_core, _ = make_uvfa_modules()
32 | return recurrent_core.initial_state(batch_size)
33 |
34 | def forward(input: UVFANetworkInput,
35 | state: hk.LSTMState) -> Tuple[networks_lib.NetworkOutput, hk.LSTMState]:
36 | """
37 | Estimate action values for batch of inputs
38 | Args:
39 | input: batch of observations, actions, rewards, intrinsic rewards
40 | and mixture indices (beta param labels)
41 | state: recurrent state
42 | Returns:
43 | q_values: predicted action values
44 | new_state: new recurrent state after prediction
45 | """
46 | embedding_torso, recurrent_core, head = make_uvfa_modules()
47 |
48 | embeddings = embedding_torso(input)
49 | embeddings, new_state = recurrent_core(embeddings, state)
50 | q_values = head(embeddings)
51 | return q_values, new_state
52 |
53 | def unroll(input: UVFANetworkInput,
54 | state: hk.LSTMState) -> Tuple[networks_lib.NetworkOutput, hk.LSTMState]:
55 | """
56 | Estimate action values for batch of input sequences
57 | Args:
58 | input: batch of observations, actions, rewards, intrinsic rewards
59 | and mixture indices (beta param labels) sequences
60 | state: recurrent state
61 | Returns:
62 | q_values: predicted action values
63 | new_state: new recurrent state after prediction
64 | """
65 | embedding_torso, recurrent_core, head = make_uvfa_modules()
66 |
67 | embeddings = hk.BatchApply(embedding_torso)(input)
68 | embeddings, new_states = hk.static_unroll(recurrent_core, embeddings, state)
69 | q_values = hk.BatchApply(head)(embeddings)
70 | return q_values, new_states
71 |
72 | # transform functions
73 | initial_state_hk = hk.transform(initial_state)
74 | forward_hk = hk.transform(forward)
75 | unroll_hk = hk.transform(unroll)
76 |
77 | # create dummy batches for networks initialization
78 | observation = utils.zeros_like(env_spec.observations)
79 | intrinsic_reward = utils.zeros_like(env_spec.rewards)
80 | mixture_idxs = utils.zeros_like(env_spec.rewards, dtype=jnp.int32)
81 | uvfa_input_sequences = utils.add_batch_dim(
82 | utils.tile_nested(
83 | UVFANetworkInput(observation, intrinsic_reward, mixture_idxs),
84 | batch_size
85 | )
86 | )
87 |
88 | def initial_state_init(rng, batch_size: int):
89 | return initial_state_hk.init(rng, batch_size)
90 |
91 | def unroll_init(rng, initial_state):
92 | return unroll_hk.init(rng, uvfa_input_sequences, initial_state)
93 |
94 | # create FeedForwardNetworks corresponding to UVFA pure functions
95 | initial_state = networks_lib.FeedForwardNetwork(
96 | init=initial_state_init, apply=initial_state_hk.apply
97 | )
98 | forward = networks_lib.FeedForwardNetwork(
99 | init=forward_hk.init, apply=forward_hk.apply
100 | )
101 | unroll = networks_lib.FeedForwardNetwork(
102 | init=unroll_init, apply=unroll_hk.apply
103 | )
104 |
105 | return UVFANetwork(initial_state, forward, unroll)
106 |
--------------------------------------------------------------------------------
/drlearner/drlearner/networks/uvfa_torso.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | import jax
3 | import jax.numpy as jnp
4 |
5 |
6 | class UVFATorso(hk.Module):
7 | def __init__(self,
8 | observation_embedding_torso: hk.Module,
9 | num_actions: int,
10 | num_mixtures: int,
11 | name: str):
12 | super().__init__(name=name)
13 | self._embed = observation_embedding_torso
14 |
15 | self._num_actions = num_actions
16 | self._num_mixtures = num_mixtures
17 |
18 | def __call__(self, input):
19 |         oar_t, intrinsic_reward_tm1, mixture_idx_tm1 = input.oar, input.intrinsic_reward, input.mixture_idx
20 | observation_t, action_tm1, reward_tm1 = oar_t.observation, oar_t.action, oar_t.reward
21 |
22 | features_t = self._embed(observation_t) # [T?, B, D]
23 | action_tm1 = jax.nn.one_hot(
24 | action_tm1,
25 | num_classes=self._num_actions
26 | ) # [T?, B, A]
27 | mixture_idx_tm1 = jax.nn.one_hot(
28 | mixture_idx_tm1,
29 | num_classes=self._num_mixtures
30 | ) # [T?, B, M]
31 |
32 | reward_tm1 = jnp.tanh(reward_tm1)
33 | intrinsic_reward_tm1 = jnp.tanh(intrinsic_reward_tm1)
34 | # Add dummy trailing dimensions to the rewards if necessary.
35 | while reward_tm1.ndim < action_tm1.ndim:
36 | reward_tm1 = jnp.expand_dims(reward_tm1, axis=-1)
37 |
38 | while intrinsic_reward_tm1.ndim < action_tm1.ndim:
39 | intrinsic_reward_tm1 = jnp.expand_dims(intrinsic_reward_tm1, axis=-1)
40 |
41 | embedding = jnp.concatenate(
42 | [features_t, action_tm1, reward_tm1, intrinsic_reward_tm1, mixture_idx_tm1],
43 | axis=-1
44 | ) # [T?, B, D+A+M+2]
45 | return embedding
46 |
--------------------------------------------------------------------------------
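A quick shape check for the torso (the sizes and the toy hk.Linear embedding are assumptions): the output width should be D + A + M + 2, i.e. observation features, one-hot previous action, one-hot mixture index, plus the two tanh-squashed rewards.

import haiku as hk
import jax
import jax.numpy as jnp
from acme.wrappers.observation_action_reward import OAR

from drlearner.drlearner.networks import UVFANetworkInput
from drlearner.drlearner.networks.uvfa_torso import UVFATorso

NUM_ACTIONS, NUM_MIXTURES, EMBED_DIM, BATCH = 4, 8, 16, 2

def torso_fn(inputs):
    # Toy observation torso: a single linear layer standing in for the zoo torsos.
    torso = UVFATorso(hk.Linear(EMBED_DIM), NUM_ACTIONS, NUM_MIXTURES, name='toy_uvfa_torso')
    return torso(inputs)

oar = OAR(observation=jnp.zeros((BATCH, 6)),
          action=jnp.zeros((BATCH,), dtype=jnp.int32),
          reward=jnp.zeros((BATCH,)))
inputs = UVFANetworkInput(oar=oar,
                          intrinsic_reward=jnp.zeros((BATCH,)),
                          mixture_idx=jnp.zeros((BATCH,), dtype=jnp.int32))

forward = hk.transform(torso_fn)
params = forward.init(jax.random.PRNGKey(0), inputs)
embedding = forward.apply(params, None, inputs)
print(embedding.shape)  # (BATCH, EMBED_DIM + NUM_ACTIONS + NUM_MIXTURES + 2) == (2, 30)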
/drlearner/drlearner/utils.py:
--------------------------------------------------------------------------------
1 | import jax.nn
2 | import jax.numpy as jnp
3 |
4 |
5 | def epsilon_greedy_prob(q_values, epsilon):
6 | """Get probability of actions under epsilon-greedy policy provided estimated q_values"""
7 | num_actions = q_values.shape[0]
8 | max_action = jnp.argmax(q_values)
9 | probs = jnp.full_like(q_values, fill_value=epsilon / num_actions)
10 | probs = probs.at[max_action].set(1 - epsilon * (num_actions - 1) / num_actions)
11 | return probs
12 |
13 |
14 | def get_beta(mixture_idx: jnp.ndarray, beta_min: float, beta_max: float, num_mixtures: int):
15 | beta = jnp.linspace(beta_min, beta_max, num_mixtures)[mixture_idx]
16 | return beta
17 |
18 |
19 | def get_gamma(mixture_idx: jnp.ndarray, gamma_min: float, gamma_max: float, num_mixtures: int):
20 | gamma = jnp.linspace(gamma_min, gamma_max, num_mixtures)[mixture_idx]
21 | return gamma
22 |
23 |
24 | def get_epsilon(actor_id: int, epsilon_base: float, num_actors: int, alpha: float = 8.0):
25 | """Get epsilon parameter for given actor"""
26 | epsilon = epsilon_base ** (1 + alpha * actor_id / ((num_actors - 1) + 0.0001))
27 | return epsilon
28 |
29 |
30 | def get_beta_ngu(mixture_idx: jnp.ndarray, beta_min: float, beta_max: float, num_mixtures: int):
31 | """Get beta parameter for given number of mixtures and mixture_idx"""
32 | beta = jnp.where(
33 | mixture_idx == num_mixtures - 1,
34 | beta_max,
35 | beta_min + beta_max * jax.nn.sigmoid(10 * (2 * mixture_idx - (num_mixtures - 2)) / (num_mixtures - 2))
36 | )
37 | return beta
38 |
39 |
40 | def get_gamma_ngu(mixture_idx: jnp.ndarray, gamma_min: float, gamma_max: float, num_mixtures: int):
41 | """Get gamma parameters for given number of mixtures in descending order"""
42 | gamma = 1 - jnp.exp(
43 | ((num_mixtures - 1 - mixture_idx) * jnp.log(1 - gamma_max) +
44 | mixture_idx * jnp.log(1 - gamma_min)) / (num_mixtures - 1))
45 | return gamma
46 |
47 |
--------------------------------------------------------------------------------
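For reference, a small sketch of how these helpers generate per-mixture and per-actor exploration parameters; the counts and epsilon_base below are illustrative assumptions (the linear schedules use the config's beta/gamma ranges):

import jax.numpy as jnp

from drlearner.drlearner.utils import get_beta, get_epsilon, get_gamma

num_mixtures = 8
mixture_idxs = jnp.arange(num_mixtures)

# Linearly spaced intrinsic-reward weights and discounts, one per mixture.
betas = get_beta(mixture_idxs, beta_min=0.0, beta_max=0.3, num_mixtures=num_mixtures)
gammas = get_gamma(mixture_idxs, gamma_min=0.99, gamma_max=0.997, num_mixtures=num_mixtures)

# Per-actor epsilons decay with the actor id (an Ape-X-style exploration schedule).
epsilons = [get_epsilon(actor_id, epsilon_base=0.4, num_actors=4) for actor_id in range(4)]
print(betas, gammas, epsilons)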
/drlearner/environments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/environments/__init__.py
--------------------------------------------------------------------------------
/drlearner/environments/atari.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Shared helpers for different experiment flavours."""
17 |
18 | import functools
19 |
20 | import dm_env
21 | import gym
22 | from acme import wrappers
23 |
24 |
25 | def make_environment(level: str = 'PongNoFrameskip-v4',
26 | oar_wrapper: bool = False) -> dm_env.Environment:
27 | """Loads the Atari environment."""
28 | env = gym.make(level, full_action_space=True)
29 |
30 | # Always use episodes of 108k steps as this is standard, matching the paper.
31 | max_episode_len = 108_000
32 | wrapper_list = [
33 | wrappers.GymAtariAdapter,
34 | functools.partial(
35 | wrappers.AtariWrapper,
36 | action_repeats=4,
37 | pooled_frames=4,
38 | zero_discount_on_life_loss=False,
39 | expose_lives_observation=False,
40 | num_stacked_frames=1,
41 | max_episode_len=max_episode_len,
42 | to_float=True,
43 | grayscaling=True
44 | ),
45 | ]
46 | if oar_wrapper:
47 | # E.g. IMPALA and R2D2 use this particular variant.
48 | wrapper_list.append(wrappers.ObservationActionRewardWrapper)
49 | wrapper_list.append(wrappers.SinglePrecisionWrapper)
50 |
51 | return wrappers.wrap_all(env, wrapper_list)
52 |
--------------------------------------------------------------------------------
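Minimal usage sketch, assuming the Atari ROMs are installed and licensed locally; the level string is just the function's default:

from drlearner.environments.atari import make_environment

env = make_environment('PongNoFrameskip-v4', oar_wrapper=True)
timestep = env.reset()

# With oar_wrapper=True every observation is an OAR namedtuple holding the
# frame, the previous action and the previous reward.
print(timestep.first(), type(timestep.observation))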
/drlearner/environments/disco_maze.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Union, Optional
2 |
3 | import dm_env
4 | import gym_discomaze
5 | import numpy as np
6 | from acme import wrappers
7 | from acme.wrappers import base
8 | from dm_env import specs
9 |
10 |
11 | class DiscoMazeWrapper(base.EnvironmentWrapper):
12 | def __init__(self, environment: dm_env.Environment, *, to_float: bool = False,
13 | max_episode_len: Optional[int] = None):
14 | """
15 | The wrapper performs the following actions:
16 | 1. Converts observations to float (if applied)
17 | 2. Truncates episodes to maximum number of steps (if applied).
18 | 3. Remove action that allows no movement.
19 | """
20 | super(DiscoMazeWrapper, self).__init__(environment)
21 | self._to_float = to_float
22 |
23 | if not max_episode_len:
24 | max_episode_len = np.inf
25 | self._episode_len = 0
26 | self._max_episode_len = max_episode_len
27 |
28 | self._observation_spec = self._init_observation_spec()
29 | self._action_spec = self._init_action_spec()
30 |
31 | def _init_observation_spec(self):
32 | observation_spec = self.environment.observation_spec()
33 | if self._to_float:
34 | observation_shape = observation_spec.shape
35 | dtype = 'float64'
36 | observation_spec = observation_spec.replace(
37 | dtype=dtype,
38 | minimum=(observation_spec.minimum.astype(dtype) / 255.),
39 | maximum=(observation_spec.maximum.astype(dtype) / 255.)
40 | )
41 | return observation_spec
42 |
43 | def _init_action_spec(self):
44 | action_spec = self.environment.action_spec()
45 |
46 | action_spec = action_spec.replace(num_values=action_spec.num_values - 1)
47 | return action_spec
48 |
49 | def step(self, action) -> dm_env.TimeStep:
50 | action = action + 1
51 | timestep = self.environment.step(action)
52 |
53 | if self._to_float:
54 | observation = timestep.observation.astype(float) / 255.
55 | timestep = timestep._replace(observation=observation)
56 |
57 | self._episode_len += 1
58 | if self._episode_len == self._max_episode_len:
59 | timestep = timestep._replace(step_type=dm_env.StepType.LAST)
60 |
61 | return timestep
62 |
63 | def reset(self) -> dm_env.TimeStep:
64 | timestep = self.environment.reset()
65 |
66 | if self._to_float:
67 | observation = timestep.observation.astype(float) / 255.
68 | timestep = timestep._replace(observation=observation)
69 |
70 | self._episode_len = 0
71 | return timestep
72 |
73 | def observation_spec(self) -> Union[specs.Array, Sequence[specs.Array]]:
74 | return self._observation_spec
75 |
76 | def action_spec(self) -> Union[specs.Array, Sequence[specs.Array]]:
77 | return self._action_spec
78 |
79 |
80 | def make_discomaze_environment(seed: int) -> dm_env.Environment:
81 | """Create 21x21 disco maze environment with 5 random colors and no target"""
82 | env = gym_discomaze.RandomDiscoMaze(n_row=10, n_col=10, n_colors=5, n_targets=0, generator=seed)
83 | env = wrappers.GymWrapper(env)
84 | env = DiscoMazeWrapper(env, to_float=True, max_episode_len=5000)
85 | env = wrappers.SinglePrecisionWrapper(env)
86 | env = wrappers.ObservationActionRewardWrapper(env)
87 | return env
88 |
89 |
90 | if __name__ == '__main__':
91 | env = make_discomaze_environment(0)
92 | print(env.action_spec().replace(num_values=4))
93 |
--------------------------------------------------------------------------------
/drlearner/environments/lunar_lander.py:
--------------------------------------------------------------------------------
1 | import dm_env
2 | import gym
3 | from acme import wrappers
4 |
5 |
6 | def make_ll_environment(seed: int) -> dm_env.Environment:
7 | env_name = "LunarLander-v2"
8 |
9 | env = gym.make(env_name)
10 | env = wrappers.GymWrapper(env)
11 | env = wrappers.SinglePrecisionWrapper(env)
12 | env = wrappers.ObservationActionRewardWrapper(env)
13 |
14 | return env
15 |
--------------------------------------------------------------------------------
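Usage sketch (assumes gym's Box2D extra is installed for LunarLander-v2); the seed argument is accepted for interface compatibility even though the helper does not use it:

from acme import specs

from drlearner.environments.lunar_lander import make_ll_environment

env = make_ll_environment(seed=0)
env_spec = specs.make_environment_spec(env)

# LunarLander-v2 exposes 4 discrete actions; observations arrive as OAR tuples.
print(env_spec.actions.num_values)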
/drlearner/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatternsandPredictions/DRLearner_beta/c8a94428c62f1544fd09d9170c45c379f17dc55c/drlearner/utils/__init__.py
--------------------------------------------------------------------------------
/drlearner/utils/stats.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | from statistics import mean, median
4 |
5 |
6 | def read_config(path):
7 | with open(path, 'r') as file:
8 | config = json.load(file)
9 |
10 | return config
11 |
12 |
13 | class StatsCheckpointer:
14 | def __init__(self):
15 | self.values = defaultdict(list)
16 | self.statistics = defaultdict(dict)
17 | self._config = read_config('./configs/config.json')
18 |
19 | self._target_statistics = {
20 | 'min': min,
21 | 'max': max,
22 | 'mean': mean,
23 | 'median': median,
24 | }
25 | self._target_metrics = ['episode_return']
26 |
27 | def update(self, result, log=False):
28 | for metric in self._target_metrics:
29 | value = float(result[metric])
30 | self.values[metric].append(value)
31 |
32 | self.evaluate()
33 | self.save()
34 |
35 | if log:
36 | self.log()
37 |
38 | def evaluate(self):
39 | for metric in self._target_metrics:
40 | values = self.values[metric]
41 |
42 | for statistic, function in self._target_statistics.items():
43 | self.statistics[metric][statistic] = round(function(values), 5)
44 |
45 | def save(self):
46 | path = self._config['statistics_path']
47 |
48 | with open(path, 'w') as file:
49 | json.dump(self.statistics, file)
50 |
51 | def log(self):
52 | print('=' * 30)
53 | for metric in self._target_metrics:
54 | print(f"{metric.replace('_', ' ').upper()}")
55 |
56 | for statistic, value in self.statistics[metric].items():
57 | print(f'{statistic}: {value}')
58 |
59 | print('=' * 30)
60 |
61 | def __repr__(self):
62 | return str(self.statistics)
63 |
--------------------------------------------------------------------------------
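A hedged sketch of StatsCheckpointer in use. It assumes a ./configs/config.json that provides at least a writable 'statistics_path'; both the path and the episode returns below are illustrative:

from drlearner.utils.stats import StatsCheckpointer

# Expects ./configs/config.json such as {"statistics_path": "experiments/default/stats.json"}.
checkpointer = StatsCheckpointer()

for episode_return in (10.0, 25.0, 5.0):
    # Each dict mirrors an environment-loop result containing 'episode_return'.
    checkpointer.update({'episode_return': episode_return}, log=True)

print(checkpointer)  # accumulated min/max/mean/median of the tracked metrics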
/drlearner/utils/utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import logging
3 | import time
4 | import os
5 | from dataclasses import asdict
6 | import jax
7 |
8 | from acme import specs
9 | from acme import core
10 | from acme.utils import counting
11 | from acme.utils import observers as observers_lib
12 | from acme.jax import networks as networks_lib
13 | from acme.jax import utils
14 | from acme.utils import loggers
15 | from acme.utils.loggers.tf_summary import TFSummaryLogger
16 | from acme.utils.loggers import base
17 | from drlearner.drlearner.config import DRLearnerConfig
18 | import wandb
19 |
20 | from ..core import distributed_layout
21 | from ..core.environment_loop import EnvironmentLoop
22 |
23 | from typing import Optional, Callable, Any, Mapping, Sequence, TextIO, Union
24 |
25 |
26 | def from_dict_to_dataclass(cls, data):
27 | return cls(
28 | **{
29 | key: (data[key] if val.default ==
30 | val.empty else data.get(key, val.default))
31 | for key, val in inspect.signature(cls).parameters.items()
32 | }
33 | )
34 |
35 |
36 | def _format_key(key: str) -> str:
37 | """Internal function for formatting keys in Tensorboard format."""
38 | return key.title().replace("_", "")
39 |
40 |
41 | class WandbLogger(base.Logger):
42 | """Logs to wandb instance.
43 |
44 | If multiple WandbLoggers are created with the same credentials, results will be
45 | categorized by labels.
46 | """
47 |
48 |     def __init__(self, logdir: str, label: str = "Logs", hyperparams: Optional[DRLearnerConfig] = None, exp_name: Optional[str] = None):
49 | """Initializes the logger.
50 |
51 | Args:
52 | logdir: name of the wandb project
53 | label: label string to use when logging. Default to 'Logs'.
54 | hyperparams: hyperparams config to be saved by wandb
55 | """
56 | self._time = time.time()
57 | self.label = label
58 | self._iter = 0
59 |
60 | if wandb.run is None:
61 | wandb.init(project=logdir, group="DDP", name=exp_name)
62 |
63 | if hyperparams is not None:
64 | self.save_hyperparams(hyperparams=hyperparams)
65 |
66 |     def save_hyperparams(self, hyperparams: Optional[DRLearnerConfig]):
67 | wandb.config = asdict(hyperparams)
68 |
69 | def write(self, values: base.LoggingData):
70 | for key in values.keys():
71 | wandb.log({f"{self.label}/{_format_key(key)}": values[key]})
72 |
73 | def close(self):
74 | wandb.finish()
75 |
76 |
77 | class CloudCSVLogger:
78 | def __init__(
79 | self,
80 | directory_or_file: Union[str, TextIO] = "~/acme",
81 | label: str = "",
82 | time_delta: float = 0.0,
83 | add_uid: bool = True,
84 | flush_every: int = 30,
85 | ):
86 | """Instantiates the logger.
87 |
88 | Args:
89 | directory_or_file: Either a directory path as a string, or a file TextIO
90 | object.
91 | label: Extra label to add to logger. This is added as a suffix to the
92 | directory.
93 | time_delta: Interval in seconds between which writes are dropped to
94 | throttle throughput.
95 | add_uid: Whether to add a UID to the file path. See `paths.process_path`
96 | for details.
97 | flush_every: Interval (in writes) between flushes.
98 | """
99 |
100 | if flush_every <= 0:
101 | raise ValueError(
102 | f"`flush_every` must be a positive integer (got {flush_every})."
103 | )
104 |
105 | self._last_log_time = time.time() - time_delta
106 | self._time_delta = time_delta
107 | self._flush_every = flush_every
108 | self._add_uid = add_uid
109 | self._writes = 0
110 | self.file_path = os.path.join(directory_or_file, f"{label}_logs.csv")
111 | self._keys = []
112 | logging.info("Logging to %s", self.file_path)
113 |
114 | def write(self, data: base.LoggingData):
115 | """Writes a `data` into a row of comma-separated values."""
116 | # Only log if `time_delta` seconds have passed since last logging event.
117 | now = time.time()
118 |
119 | # TODO(b/192227744): Remove this in favour of filters.TimeFilter.
120 | elapsed = now - self._last_log_time
121 | if elapsed < self._time_delta:
122 | logging.debug(
123 | "Not due to log for another %.2f seconds, dropping data.",
124 | self._time_delta - elapsed,
125 | )
126 | return
127 | self._last_log_time = now
128 |
129 | # Append row to CSV.
130 | data = base.to_numpy(data)
131 | if self._writes == 0:
132 | self._keys = data.keys()
133 | with open(self.file_path, "w") as f:
134 | f.write(",".join(self._keys))
135 | f.write("\n")
136 | f.write(
137 | ",".join(list(map(str, [data[k] for k in self._keys]))))
138 | f.write("\n")
139 | else:
140 | with open(self.file_path, "a") as f:
141 | f.write(
142 | ",".join(list(map(str, [data[k] for k in self._keys]))))
143 | f.write("\n")
144 | self._writes += 1
145 |
146 |
147 | def make_tf_logger(
148 | workdir: str = "~/acme/",
149 | label: str = "learner",
150 | save_data: bool = True,
151 | time_delta: float = 0.0,
152 | asynchronous: bool = False,
153 | print_fn: Optional[Callable[[str], None]] = print,
154 | serialize_fn: Optional[Callable[[Mapping[str, Any]],
155 | str]] = loggers.base.to_numpy,
156 | steps_key: str = "steps",
157 | ) -> loggers.base.Logger:
158 | del steps_key
159 | if not print_fn:
160 | print_fn = logging.info
161 |
162 | terminal_logger = loggers.terminal.TerminalLogger(
163 | label=label, print_fn=print_fn)
164 |
165 | all_loggers = [terminal_logger]
166 |
167 | if save_data:
168 | if "/gcs/" in workdir:
169 | all_loggers.append(
170 | CloudCSVLogger(
171 | directory_or_file=workdir, label=label, time_delta=time_delta
172 | )
173 | )
174 | else:
175 | all_loggers.append(
176 | loggers.csv.CSVLogger(
177 | directory_or_file=workdir, label=label, time_delta=time_delta
178 | )
179 | )
180 |
181 | tb_workdir = workdir
182 | if "/gcs/" in tb_workdir:
183 | tb_workdir = tb_workdir.replace("/gcs/", "gs://")
184 | all_loggers.append(TFSummaryLogger(logdir=tb_workdir, label=label))
185 |
186 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn)
187 | logger = loggers.filters.NoneFilter(logger)
188 |
189 | logger = loggers.filters.TimeFilter(logger, time_delta)
190 | return logger
191 |
192 |
193 | def make_wandb_logger(
194 | workdir: str = "~/acme/",
195 | label: str = "learner",
196 | save_data: bool = True,
197 | time_delta: float = 0.0,
198 | asynchronous: bool = False,
199 | print_fn: Optional[Callable[[str], None]] = print,
200 | serialize_fn: Optional[Callable[[Mapping[str, Any]],
201 | str]] = loggers.base.to_numpy,
202 | steps_key: str = "steps",
203 |         hyperparams: Optional[DRLearnerConfig] = None,
204 |         exp_name: Optional[str] = None
205 | ) -> loggers.base.Logger:
206 | del steps_key
207 | if not print_fn:
208 | print_fn = logging.info
209 |
210 | terminal_logger = loggers.terminal.TerminalLogger(
211 | label=label, print_fn=print_fn)
212 |
213 | all_loggers = [terminal_logger]
214 |
215 | if save_data:
216 | if "/gcs/" in workdir:
217 | all_loggers.append(
218 | CloudCSVLogger(
219 | directory_or_file=workdir, label=label, time_delta=time_delta
220 | )
221 | )
222 | else:
223 | all_loggers.append(
224 | loggers.csv.CSVLogger(
225 | directory_or_file=workdir, label=label, time_delta=time_delta
226 | )
227 | )
228 |
229 | tb_workdir = workdir
230 | if "/gcs/" in tb_workdir:
231 | tb_workdir = tb_workdir.replace("/gcs/", "gs://")
232 | all_loggers.append(WandbLogger(logdir=tb_workdir.split(
233 | "/")[-2], label=label, hyperparams=hyperparams, exp_name=exp_name))
234 |
235 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn)
236 | logger = loggers.filters.NoneFilter(logger)
237 |
238 | logger = loggers.filters.TimeFilter(logger, time_delta)
239 | return logger
240 |
241 |
242 | def all_loggers(
243 | workdir: str = "~/acme/",
244 | label: str = "learner",
245 | save_data: bool = True,
246 | time_delta: float = 0.0,
247 | asynchronous: bool = False,
248 | print_fn: Optional[Callable[[str], None]] = print,
249 | serialize_fn: Optional[Callable[[Mapping[str, Any]],
250 | str]] = loggers.base.to_numpy,
251 | steps_key: str = "steps",
252 |         hyperparams: Optional[DRLearnerConfig] = None
253 | ) -> loggers.base.Logger:
254 | del steps_key
255 | if not print_fn:
256 | print_fn = logging.info
257 |
258 | terminal_logger = loggers.terminal.TerminalLogger(
259 | label=label, print_fn=print_fn)
260 |
261 | all_loggers = [terminal_logger]
262 |
263 | if save_data:
264 | if "/gcs/" in workdir:
265 | all_loggers.append(
266 | CloudCSVLogger(
267 | directory_or_file=workdir, label=label, time_delta=time_delta
268 | )
269 | )
270 | else:
271 | all_loggers.append(
272 | loggers.csv.CSVLogger(
273 | directory_or_file=workdir, label=label, time_delta=time_delta
274 | )
275 | )
276 |
277 | tb_workdir = workdir
278 | if "/gcs/" in tb_workdir:
279 | tb_workdir = tb_workdir.replace("/gcs/", "gs://")
280 | all_loggers.append(WandbLogger(logdir=tb_workdir.split(
281 | "/")[-2], label=label, hyperparams=hyperparams))
282 | all_loggers.append(TFSummaryLogger(logdir=tb_workdir, label=label))
283 |
284 | logger = loggers.aggregators.Dispatcher(all_loggers, serialize_fn)
285 | logger = loggers.filters.NoneFilter(logger)
286 |
287 | logger = loggers.filters.TimeFilter(logger, time_delta)
288 | return logger
289 |
290 |
291 | def evaluator_factory_logger_choice(
292 | environment_factory: distributed_layout.EnvironmentFactory,
293 | network_factory: distributed_layout.NetworkFactory,
294 | policy_factory: distributed_layout.PolicyFactory,
295 | logger_fn: Callable,
296 | observers: Sequence[observers_lib.EnvLoopObserver] = (),
297 | actor_id: int = 0,
298 | ) -> distributed_layout.EvaluatorFactory:
299 | """Returns an evaluator process with customizable log function."""
300 |
301 | def evaluator(
302 | random_key: networks_lib.PRNGKey,
303 | variable_source: core.VariableSource,
304 | counter: counting.Counter,
305 | make_actor: distributed_layout.MakeActorFn,
306 | ):
307 | """The evaluation process."""
308 |
309 | # Create environment and evaluator networks
310 | environment_key, actor_key = jax.random.split(random_key)
311 | environment = environment_factory(utils.sample_uint32(environment_key))
312 | networks = network_factory(specs.make_environment_spec(environment))
313 |
314 | actor = make_actor(
315 | random_key, policy_factory(networks), variable_source=variable_source
316 | ) # ToDo: fix actor id for R2D2
317 |
318 | # Create logger and counter.
319 | counter = counting.Counter(counter, "evaluator")
320 |
321 | logger = logger_fn()
322 |
323 | # Create the run loop and return it.
324 | return EnvironmentLoop(environment, actor, counter, logger, observers=observers)
325 |
326 | return evaluator
327 |
--------------------------------------------------------------------------------
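The from_dict_to_dataclass helper above is handy for rebuilding a config from a partial dict (e.g. loaded from JSON); missing keys fall back to the dataclass defaults. An illustrative sketch:

from drlearner.drlearner.config import DRLearnerConfig
from drlearner.utils.utils import from_dict_to_dataclass

overrides = {'num_mixtures': 4, 'batch_size': 16}   # partial, JSON-style dict
config = from_dict_to_dataclass(DRLearnerConfig, overrides)
print(config.num_mixtures, config.gamma_max)        # 4 0.997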
/examples/distrun_atari.py:
--------------------------------------------------------------------------------
1 | """Example running Distributed Layout DRLearner, on Atari."""
2 |
3 | import functools
4 | import logging
5 | import os
6 |
7 | import acme
8 | import launchpad as lp
9 | from absl import app
10 | from absl import flags
11 | from acme import specs
12 | from acme.jax import utils
13 |
14 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo
15 | from drlearner.configs.config_atari import AtariDRLearnerConfig
16 | from drlearner.configs.resources import get_atari_vertex_resources, get_local_resources
17 | from drlearner.core.observers import IntrinsicRewardObserver, MetaControllerObserver, DistillationCoefObserver, StorageVideoObserver
18 | from drlearner.environments.atari import make_environment
19 | from drlearner.drlearner.networks import make_policy_networks
20 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_wandb_logger
21 |
22 | flags.DEFINE_string('level', 'ALE/MontezumaRevenge-v5', 'Which game to play.')
23 | flags.DEFINE_integer('num_steps', 100000, 'Number of steps to train for.')
24 | flags.DEFINE_integer('num_episodes', 1000,
25 | 'Number of episodes to train for.')
26 | flags.DEFINE_string('exp_path', 'experiments/default',
27 | 'Experiment data storage.')
28 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
29 |
30 | flags.DEFINE_integer('seed', 0, 'Random seed.')
31 | flags.DEFINE_integer('num_actors_per_mixture', 2,
32 | 'Number of parallel actors per mixture.')
33 | flags.DEFINE_bool('run_on_vertex', False,
34 | 'Whether to run training in multiple processes or on Vertex AI.')
35 | flags.DEFINE_bool('colocate_learner_and_reverb', True,
36 | 'Flag indicating whether to colocate learner and reverb.')
37 |
38 | FLAGS = flags.FLAGS
39 |
40 |
41 | def make_program():
42 | config = AtariDRLearnerConfig
43 | print(config)
44 |
45 | config_dir = os.path.join(
46 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1])
47 | if not os.path.exists(config_dir):
48 | os.makedirs(config_dir)
49 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f:
50 | f.write(str(config))
51 |
52 | env = make_environment(FLAGS.level, oar_wrapper=True)
53 | env_spec = acme.make_environment_spec(env)
54 |
55 | def net_factory(env_spec: specs.EnvironmentSpec):
56 | return networks_zoo.make_atari_nets(config, env_spec)
57 |
58 | level = str(FLAGS.level)
59 |
60 | def env_factory(seed: int):
61 | return make_environment(level, oar_wrapper=True)
62 |
63 | observers = [
64 | IntrinsicRewardObserver(),
65 | MetaControllerObserver(),
66 | DistillationCoefObserver(),
67 | StorageVideoObserver(config)
68 | ]
69 |
70 | evaluator_logger_fn = functools.partial(make_wandb_logger, FLAGS.exp_path,
71 | 'evaluator', save_data=True,
72 | time_delta=1, asynchronous=True,
73 | serialize_fn=utils.fetch_devicearray,
74 | print_fn=logging.info,
75 | steps_key='evaluator_steps',
76 | hyperparams=config,
77 | exp_name=FLAGS.exp_name)
78 |
79 | learner_logger_function = functools.partial(make_wandb_logger, FLAGS.exp_path,
80 | 'learner', save_data=True,
81 | time_delta=1, asynchronous=True,
82 | serialize_fn=utils.fetch_devicearray,
83 | print_fn=logging.info,
84 | steps_key='learner_steps',
85 | hyperparams=config,
86 | exp_name=FLAGS.exp_name)
87 |
88 | program = DistributedDRLearnerFromConfig(
89 | seed=FLAGS.seed,
90 | environment_factory=env_factory,
91 | network_factory=net_factory,
92 | config=config,
93 | workdir=FLAGS.exp_path,
94 | num_actors_per_mixture=FLAGS.num_actors_per_mixture,
95 | max_episodes=FLAGS.num_episodes,
96 | max_steps=FLAGS.num_steps,
97 | environment_spec=env_spec,
98 | actor_observers=observers,
99 | evaluator_observers=observers,
100 | learner_logger_fn=learner_logger_function,
101 | evaluator_factories=[
102 | evaluator_factory_logger_choice(
103 | environment_factory=env_factory,
104 | network_factory=net_factory,
105 | policy_factory=lambda networks: make_policy_networks(
106 | networks, config, evaluation=True),
107 | logger_fn=evaluator_logger_fn,
108 | observers=observers
109 | )
110 | ],
111 |
112 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb
113 |
114 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1])
115 |
116 | return program
117 |
118 |
119 | def main(_):
120 | program = make_program()
121 |
122 | if FLAGS.run_on_vertex:
123 | resources = get_atari_vertex_resources()
124 | lp.launch(
125 | program,
126 | launch_type=lp.LaunchType.VERTEX_AI,
127 | xm_resources=resources,
128 | terminal='current_terminal')
129 |
130 | else:
131 | resources = get_local_resources()
132 | lp.launch(
133 | program,
134 | lp.LaunchType.LOCAL_MULTI_PROCESSING,
135 | local_resources=resources,
136 | terminal='current_terminal'
137 | )
138 |
139 |
140 | if __name__ == '__main__':
141 | app.run(main)
142 |
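143 | # Illustrative invocations only (flag names are defined above; paths and run
144 | # names are placeholders):
145 | #
146 | #   # local multi-process run
147 | #   python examples/distrun_atari.py --exp_path experiments/montezuma --exp_name "montezuma run"
148 | #
149 | #   # run on Vertex AI with learner and reverb colocated
150 | #   python examples/distrun_atari.py --run_on_vertex --colocate_learner_and_reverb --exp_path experiments/montezuma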
--------------------------------------------------------------------------------
/examples/distrun_discomaze.py:
--------------------------------------------------------------------------------
1 | """Example running distributed layout DRLearner Agent, on Discomaze environment."""
2 | import functools
3 | import logging
4 | import os
5 |
6 | import acme
7 | from acme.jax import utils
8 | import launchpad as lp
9 | from absl import app
10 | from absl import flags
11 | from acme import specs
12 |
13 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo
14 | from drlearner.configs.config_discomaze import DiscomazeDRLearnerConfig
15 | from drlearner.core.observers import UniqueStatesDiscoMazeObserver, IntrinsicRewardObserver, ActionProbObserver, DistillationCoefObserver
16 | from drlearner.environments.disco_maze import make_discomaze_environment
17 | from drlearner.drlearner.networks import make_policy_networks
18 | from drlearner.configs.resources import get_toy_env_vertex_resources, get_local_resources
19 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_tf_logger
20 |
21 | flags.DEFINE_string('level', 'DiscoMaze', 'Which game to play.')
22 | flags.DEFINE_integer('num_episodes', 10000000,
23 | 'Number of episodes to train for.')
24 | flags.DEFINE_string('exp_path', 'experiments/default',
25 | 'Experiment data storage.')
26 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
27 | flags.DEFINE_integer('seed', 0, 'Random seed.')
28 | flags.DEFINE_integer('num_actors_per_mixture', 1,
29 | 'Number of parallel actors per mixture.')
30 | flags.DEFINE_bool('run_on_vertex', False,
31 |                   'Whether to run training on Vertex AI instead of locally in multiple processes.')
32 | flags.DEFINE_bool('colocate_learner_and_reverb', False,
33 | 'Flag indicating whether to colocate learner and reverb.')
34 |
35 | FLAGS = flags.FLAGS
36 |
37 |
38 | def make_program():
39 | config = DiscomazeDRLearnerConfig
40 | print(config)
41 |
42 | config_dir = os.path.join(
43 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1])
44 | if not os.path.exists(config_dir):
45 | os.makedirs(config_dir)
46 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f:
47 | f.write(str(config))
48 |
49 | env = make_discomaze_environment(FLAGS.seed)
50 | env_spec = acme.make_environment_spec(env)
51 |
52 | def net_factory(env_spec: specs.EnvironmentSpec):
53 | return networks_zoo.make_discomaze_nets(config, env_spec)
54 |
55 | observers = [
56 | UniqueStatesDiscoMazeObserver(),
57 | IntrinsicRewardObserver(),
58 | DistillationCoefObserver(),
59 | ActionProbObserver(num_actions=env_spec.actions.num_values),
60 | ]
61 |
62 | evaluator_logger_fn = functools.partial(make_tf_logger, FLAGS.exp_path,
63 | 'evaluator', save_data=True,
64 | time_delta=1, asynchronous=True,
65 | serialize_fn=utils.fetch_devicearray,
66 | print_fn=logging.info,
67 | steps_key='evaluator_steps')
68 |
69 | learner_logger_function = functools.partial(make_tf_logger, FLAGS.exp_path,
70 | 'learner', save_data=False,
71 | time_delta=1, asynchronous=True,
72 | serialize_fn=utils.fetch_devicearray,
73 | print_fn=logging.info,
74 | steps_key='learner_steps')
75 |
76 | program = DistributedDRLearnerFromConfig(
77 | seed=FLAGS.seed,
78 | environment_factory=make_discomaze_environment,
79 | network_factory=net_factory,
80 | config=config,
81 | num_actors_per_mixture=FLAGS.num_actors_per_mixture,
82 | environment_spec=env_spec,
83 | actor_observers=observers,
84 | learner_logger_fn=learner_logger_function,
85 | evaluator_observers=observers,
86 | evaluator_factories=[
87 | evaluator_factory_logger_choice(
88 | environment_factory=make_discomaze_environment,
89 | network_factory=net_factory,
90 | policy_factory=lambda networks: make_policy_networks(
91 | networks, config, evaluation=True),
92 | logger_fn=evaluator_logger_fn,
93 | observers=observers
94 | )
95 | ],
96 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb
97 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1])
98 |
99 | return program
100 |
101 |
102 | def main(_):
103 | program = make_program()
104 |
105 | if FLAGS.run_on_vertex:
106 | resources = get_toy_env_vertex_resources()
107 | lp.launch(
108 | program,
109 | launch_type=lp.LaunchType.VERTEX_AI,
110 | xm_resources=resources)
111 | else:
112 | resources = get_local_resources()
113 | lp.launch(
114 | program,
115 | lp.LaunchType.LOCAL_MULTI_PROCESSING,
116 | local_resources=resources
117 | )
118 |
119 |
120 | if __name__ == '__main__':
121 | app.run(main)
122 |
--------------------------------------------------------------------------------
/examples/distrun_lunar_lander.py:
--------------------------------------------------------------------------------
1 | """Example running distributed DRLearner agent, on Lunar Lander."""
2 |
3 | import functools
4 | import logging
5 | import os
6 |
7 | import acme
8 | from acme.jax import utils
9 | import launchpad as lp
10 | from absl import app
11 | from absl import flags
12 | from acme import specs
13 |
14 | from drlearner.drlearner import DistributedDRLearnerFromConfig, networks_zoo
15 | from drlearner.core.observers import IntrinsicRewardObserver, MetaControllerObserver, DistillationCoefObserver
16 | from drlearner.configs.config_lunar_lander import LunarLanderDRLearnerConfig
17 | from drlearner.configs.resources import get_toy_env_vertex_resources, get_local_resources
18 | from drlearner.environments.lunar_lander import make_ll_environment
19 | from drlearner.drlearner.networks import make_policy_networks
20 | from drlearner.utils.utils import evaluator_factory_logger_choice, make_wandb_logger
21 |
22 | flags.DEFINE_string('level', 'LunarLander-v2', 'Which game to play.')
23 | flags.DEFINE_integer('num_episodes', 200, 'Number of episodes to train for.')
24 | flags.DEFINE_integer('num_steps', 10000, 'Number of steps to train for.')
25 | flags.DEFINE_string('exp_path', 'experiments/default',
26 | 'Experiment data storage.')
27 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
28 |
29 | flags.DEFINE_integer('num_actors_per_mixture', 1,
30 | 'Number of parallel actors per mixture.')
31 |
32 | flags.DEFINE_integer('seed', 42, 'Random seed.')
33 | flags.DEFINE_bool('run_on_vertex', False,
34 |                   'Whether to run training on Vertex AI instead of locally in multiple processes.')
35 | flags.DEFINE_bool('colocate_learner_and_reverb', False,
36 | 'Flag indicating whether to colocate learner and reverb.')
37 |
38 | FLAGS = flags.FLAGS
39 |
40 |
41 | def make_program():
42 | config = LunarLanderDRLearnerConfig
43 |
44 | config_dir = os.path.join(
45 | 'experiments/', FLAGS.exp_path.strip('/').split('/')[-1])
46 | if not os.path.exists(config_dir):
47 | os.makedirs(config_dir)
48 | with open(os.path.join(config_dir, 'config.txt'), 'w') as f:
49 | f.write(str(config))
50 |
51 | print(config)
52 |
53 | env = make_ll_environment(0)
54 | env_spec = acme.make_environment_spec(env)
55 |
56 | def net_factory(env_spec: specs.EnvironmentSpec):
57 | return networks_zoo.make_lunar_lander_nets(config, env_spec)
58 |
59 | def env_factory(seed: int):
60 | return make_ll_environment(seed)
61 |
62 | observers = [
63 | IntrinsicRewardObserver(),
64 | MetaControllerObserver(),
65 | DistillationCoefObserver(),
66 | ]
67 |
68 | evaluator_logger_fn = functools.partial(make_wandb_logger, FLAGS.exp_path,
69 | 'evaluator', save_data=True,
70 | time_delta=1, asynchronous=True,
71 | serialize_fn=utils.fetch_devicearray,
72 | print_fn=logging.info,
73 | steps_key='evaluator_steps',
74 | hyperparams=config,
75 | exp_name=FLAGS.exp_name)
76 |
77 | learner_logger_function = functools.partial(make_wandb_logger, FLAGS.exp_path,
78 | 'learner', save_data=True,
79 | time_delta=1, asynchronous=True,
80 | serialize_fn=utils.fetch_devicearray,
81 | print_fn=logging.info,
82 | steps_key='learner_steps',
83 | hyperparams=config,
84 | exp_name=FLAGS.exp_name)
85 |
86 | program = DistributedDRLearnerFromConfig(
87 | seed=FLAGS.seed,
88 | environment_factory=env_factory,
89 | network_factory=net_factory,
90 | config=config,
91 | workdir=FLAGS.exp_path,
92 | num_actors_per_mixture=FLAGS.num_actors_per_mixture,
93 | max_episodes=FLAGS.num_episodes,
94 | max_steps=FLAGS.num_steps,
95 | environment_spec=env_spec,
96 | actor_observers=observers,
97 | learner_logger_fn=learner_logger_function,
98 | evaluator_observers=observers,
99 | evaluator_factories=[
100 | evaluator_factory_logger_choice(
101 | environment_factory=make_ll_environment,
102 | network_factory=net_factory,
103 | policy_factory=lambda networks: make_policy_networks(
104 | networks, config, evaluation=True),
105 | logger_fn=evaluator_logger_fn,
106 | observers=observers
107 | )
108 | ],
109 | multithreading_colocate_learner_and_reverb=FLAGS.colocate_learner_and_reverb
110 |
111 | ).build(name=FLAGS.exp_path.strip('/').split('/')[-1])
112 |
113 | return program
114 |
115 |
116 | def main(_):
117 | program = make_program()
118 |
119 | if FLAGS.run_on_vertex:
120 | resources = get_toy_env_vertex_resources()
121 | lp.launch(
122 | program,
123 | launch_type=lp.LaunchType.VERTEX_AI,
124 | xm_resources=resources,
125 | terminal='tmux_session')
126 | else:
127 | resources = get_local_resources()
128 | lp.launch(
129 | program,
130 | lp.LaunchType.LOCAL_MULTI_PROCESSING,
131 | local_resources=resources,
132 | terminal='tmux_session'
133 | )
134 |
135 |
136 | if __name__ == '__main__':
137 | app.run(main)
138 |
--------------------------------------------------------------------------------
/examples/play_atari.py:
--------------------------------------------------------------------------------
1 | import acme
2 |
3 | from absl import flags
4 | from absl import app
5 |
6 | from drlearner.drlearner import DRLearner, networks_zoo
7 | from drlearner.core.environment_loop import EnvironmentLoop
8 | from drlearner.environments.atari import make_environment
9 | from drlearner.configs.config_atari import AtariDRLearnerConfig
10 | from drlearner.utils.utils import make_wandb_logger
11 | from drlearner.core.observers import StorageVideoObserver
12 |
13 | flags.DEFINE_string('level', 'ALE/MontezumaRevenge-v5', 'Which game to play.')
14 | flags.DEFINE_integer('seed', 11, 'Random seed.')
15 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.')
16 | flags.DEFINE_string('exp_path', 'experiments/play1', 'Experiment data storage.')
17 | flags.DEFINE_string('exp_name', 'atari play', 'Run name.')
18 | flags.DEFINE_string(
19 | 'checkpoint_path', 'experiments/mon_24cores1', 'Path to checkpoints/ dir')
20 |
21 | FLAGS = flags.FLAGS
22 |
23 | # TODO: add an option to freeze the mixture index for final evaluation
24 |
25 |
26 | def load_and_evaluate(_):
27 | config = AtariDRLearnerConfig
28 | config.batch_size = 1
29 | config.num_mixtures = 32
30 | config.beta_max = 0. # if num_mixtures == 1 beta == beta_max
31 | config.n_arms = 32
32 | config.logs_dir = FLAGS.exp_path
33 | config.video_log_period = 1
34 | config.env_library = 'gym'
35 | config.actor_epsilon = 0.01
36 | config.epsilon = 0.01
37 | config.mc_epsilon = 0.01
38 |
39 | env = make_environment(FLAGS.level, oar_wrapper=True)
40 | env_spec = acme.make_environment_spec(env)
41 |
42 | agent = DRLearner(
43 | env_spec,
44 | networks=networks_zoo.make_atari_nets(config, env_spec),
45 | config=config,
46 | seed=FLAGS.seed,
47 | workdir=FLAGS.checkpoint_path
48 | )
49 |
50 | observers = [StorageVideoObserver(config)]
51 | logger = make_wandb_logger(
52 | FLAGS.exp_path, label='evaluator', hyperparams=config, exp_name=FLAGS.exp_name)
53 |
54 | loop = EnvironmentLoop(env, agent, logger=logger,
55 | observers=observers, should_update=False)
56 | loop.run(FLAGS.num_episodes)
57 |
58 |
59 | if __name__ == '__main__':
60 | app.run(load_and_evaluate)
61 |
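62 | # Illustrative invocation only. The checkpoint path must point at an experiment
63 | # directory containing a checkpoints/ folder (the flag's default is reused here
64 | # purely as an example):
65 | #
66 | #   python examples/play_atari.py --checkpoint_path experiments/mon_24cores1 --exp_path experiments/play1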
--------------------------------------------------------------------------------
/examples/run_atari.py:
--------------------------------------------------------------------------------
1 | """
2 | Example running the Local Layout DRLearner agent on Atari-like environments.
3 |
4 | This module contains the main function to run the Atari environment using a Deep Reinforcement Learning (DRL) agent.
5 |
6 | Imports:
7 | os: Provides a way of using operating system dependent functionality.
8 | flags: Command line flag module.
9 | AtariDRLearnerConfig: Configuration for the DRL agent.
10 | make_environment: Function to create an Atari environment.
11 | acme: DeepMind's library of reinforcement learning components.
12 | networks_zoo: Contains the network architectures for the DRL agent.
13 | DRLearner: The DRL agent.
14 | make_wandb_logger: Function to create a Weights & Biases logger.
15 | EnvironmentLoop: Acme's main loop for running environments.
16 |
17 | Functions:
18 | main(_):
19 | The main function to run the Atari environment.
20 |
21 | It sets up the environment, the DRL agent, and the logger, and then runs the environment loop for a specified number of episodes.
22 | """
23 | import os
24 |
25 | import acme
26 | from absl import app
27 | from absl import flags
28 |
29 | from drlearner.drlearner import networks_zoo, DRLearner
30 | from drlearner.configs.config_atari import AtariDRLearnerConfig
31 | from drlearner.core.environment_loop import EnvironmentLoop
32 | from drlearner.environments.atari import make_environment
33 | from drlearner.core.observers import IntrinsicRewardObserver, DistillationCoefObserver
34 | from drlearner.utils.utils import make_wandb_logger
35 |
36 | # Command line flags
37 | flags.DEFINE_string('level', 'PongNoFrameskip-v4', 'Which game to play.')
38 | flags.DEFINE_integer('num_episodes', 7, 'Number of episodes to train for.')
39 | flags.DEFINE_string('exp_path', 'experiments/default',
40 | 'Experiment data storage.')
41 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
42 | flags.DEFINE_integer('seed', 0, 'Random seed.')
43 |
44 | flags.DEFINE_bool('force_sync_run', False, 'Skip deadlock warning.')
45 |
46 | FLAGS = flags.FLAGS
47 |
48 |
49 | def main(_):
50 | # Configuration for the DRL agent hyperparameters
51 | config = AtariDRLearnerConfig
52 |     # Running Reverb in the synchronous set-up can deadlock unless
53 |     # config.samples_per_insert == 0, which ensures the rate limiter is never called.
54 |     # See https://github.com/google-deepmind/acme/issues/207 for additional information.
55 | if config.samples_per_insert != 0:
56 | if not FLAGS.force_sync_run:
57 | while True:
58 |                 user_answer = input("\nThe simulation may deadlock if run in the synchronous set-up with samples_per_insert != 0. "
59 |                                     "Do you want to continue? (yes/no): ")
60 |
61 | if user_answer.lower() in ["yes", "y"]:
62 | print("Proceeding...")
63 | break
64 | elif user_answer.lower() in ["no", "n"]:
65 | print("Exiting...")
66 | return
67 | else:
68 | print("Invalid input. Please enter yes/no.")
69 |
70 | print(config)
71 | if not os.path.exists(FLAGS.exp_path):
72 | os.makedirs(FLAGS.exp_path)
73 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f:
74 | f.write(str(config))
75 |
76 | # Create the Atari environment
77 | env = make_environment(FLAGS.level, oar_wrapper=True)
78 | # Create the environment specification
79 | env_spec = acme.make_environment_spec(env)
80 |
81 | # Create the networks for the DRL agent learning algorithm
82 | networks = networks_zoo.make_atari_nets(config, env_spec)
83 |
84 |     # Create Weights & Biases loggers for the environment and the actor
85 |     logger_env = make_wandb_logger(
86 |         FLAGS.exp_path, label='environment', hyperparams=config, exp_name=FLAGS.exp_name)
87 | logger_actor = make_wandb_logger(
88 | FLAGS.exp_path, label='actor', hyperparams=config, exp_name=FLAGS.exp_name)
89 |
90 | # Create the DRL agent
91 | agent = DRLearner(
92 | spec=env_spec,
93 | networks=networks,
94 | config=config,
95 | seed=FLAGS.seed,
96 | workdir=FLAGS.exp_path,
97 | logger=logger_actor
98 | )
99 | # Create the observers for the DRL agent
100 | # TODO: Add StorageVideoObserver
101 | observers = [IntrinsicRewardObserver(), DistillationCoefObserver()]
102 |
103 | # Create the environment loop
104 | loop = EnvironmentLoop(
105 | environment=env,
106 | actor=agent,
107 | logger=logger_env,
108 | observers=observers
109 | )
110 | # Run the environment loop for a specified number of episodes
111 | loop.run(FLAGS.num_episodes)
112 |
113 |
114 | if __name__ == '__main__':
115 | app.run(main)
116 |
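117 | # Illustrative only: to run this synchronous layout without the interactive
118 | # prompt above, the rate limiter can be disabled before the agent is built,
119 | # e.g.
120 | #
121 | #   config = AtariDRLearnerConfig
122 | #   config.samples_per_insert = 0
123 | #   # ... then build the agent and the loop as in main() above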
--------------------------------------------------------------------------------
/examples/run_discomaze.py:
--------------------------------------------------------------------------------
1 | """Example running Local Layout DRLearner on DiscoMaze environment."""
2 |
3 | import os
4 |
5 | import acme
6 | from absl import app
7 | from absl import flags
8 |
9 | from drlearner.drlearner import networks_zoo, DRLearner
10 | from drlearner.configs.config_discomaze import DiscomazeDRLearnerConfig
11 | from drlearner.core.environment_loop import EnvironmentLoop
12 | from drlearner.core.observers import UniqueStatesDiscoMazeObserver, IntrinsicRewardObserver, ActionProbObserver
13 | from drlearner.environments.disco_maze import make_discomaze_environment
14 | from drlearner.utils.utils import make_tf_logger
15 |
16 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.')
17 | flags.DEFINE_string('exp_path', 'experiments/default',
18 | 'Experiment data storage.')
19 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
20 | flags.DEFINE_integer('seed', 0, 'Random seed.')
21 |
22 | FLAGS = flags.FLAGS
23 |
24 |
25 | def main(_):
26 | config = DiscomazeDRLearnerConfig
27 | print(config)
28 | if not os.path.exists(FLAGS.exp_path):
29 | os.makedirs(FLAGS.exp_path)
30 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f:
31 | f.write(str(config))
32 |
33 | env = make_discomaze_environment(FLAGS.seed)
34 | env_spec = acme.make_environment_spec(env)
35 |
36 | networks = networks_zoo.make_discomaze_nets(config, env_spec)
37 |
38 | agent = DRLearner(
39 | env_spec,
40 | networks=networks,
41 | config=config,
42 | seed=FLAGS.seed)
43 |
44 | logger = make_tf_logger(FLAGS.exp_path)
45 |
46 | observers = [
47 | UniqueStatesDiscoMazeObserver(),
48 | IntrinsicRewardObserver(),
49 | ActionProbObserver(num_actions=env_spec.actions.num_values),
50 | ]
51 | loop = EnvironmentLoop(env, agent, logger=logger, observers=observers)
52 | loop.run(FLAGS.num_episodes)
53 |
54 |
55 | if __name__ == '__main__':
56 | app.run(main)
57 |
--------------------------------------------------------------------------------
/examples/run_lunar_lander.py:
--------------------------------------------------------------------------------
1 | """
2 | Example running Local Layout DRLearner on Lunar Lander environment.
3 |
4 | This module contains the main function to run the Lunar Lander environment using a Deep Reinforcement Learning (DRL) agent.
5 |
6 | Imports:
7 | os: Provides a way of using operating system dependent functionality.
8 | flags: Command line flag module.
9 | LunarLanderDRLearnerConfig: Configuration for the DRL agent.
10 | make_ll_environment: Function to create a Lunar Lander environment.
11 | acme: DeepMind's library of reinforcement learning components.
12 | networks_zoo: Contains the network architectures for the DRL agent.
13 | DRLearner: The DRL agent.
14 | IntrinsicRewardObserver, DistillationCoefObserver: Observers for the DRL agent.
15 | make_wandb_logger: Function to create a Weights & Biases logger.
16 | EnvironmentLoop: Acme's main loop for running environments.
17 |
18 | Functions:
19 | main(_):
20 | The main function to run the Lunar Lander environment.
21 |
22 | It sets up the environment, the DRL agent, the observers, and the logger, and then runs the environment loop for a specified number of episodes.
23 | """
24 | import os
25 |
26 | import acme
27 | from absl import app
28 | from absl import flags
29 |
30 | from drlearner.drlearner import networks_zoo, DRLearner
31 | from drlearner.configs.config_lunar_lander import LunarLanderDRLearnerConfig
32 | from drlearner.core.environment_loop import EnvironmentLoop
33 | from drlearner.environments.lunar_lander import make_ll_environment
34 | from drlearner.core.observers import IntrinsicRewardObserver, DistillationCoefObserver, StorageVideoObserver
35 | from drlearner.utils.utils import make_wandb_logger
36 |
37 |
38 | # Command line flags
39 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to train for.')
40 | flags.DEFINE_string('exp_path', 'experiments/default',
41 | 'Experiment data storage.')
42 | flags.DEFINE_string('exp_name', 'my first run', 'Run name.')
43 | flags.DEFINE_integer('seed', 42, 'Random seed.')
44 |
45 | FLAGS = flags.FLAGS
46 |
47 |
48 | def main(_):
49 | # Configuration for the DRL agent hyperparameters
50 | config = LunarLanderDRLearnerConfig
51 |
52 | print(config)
53 | if not os.path.exists(FLAGS.exp_path):
54 | os.makedirs(FLAGS.exp_path)
55 | with open(os.path.join(FLAGS.exp_path, 'config.txt'), 'w') as f:
56 | f.write(str(config))
57 |
58 | # Create a Weights & Biases loggers for the environment and the actor
59 | logger_env = make_wandb_logger(
60 | FLAGS.exp_path, label='enviroment', hyperparams=config, exp_name=FLAGS.exp_name)
61 | logger_actor = make_wandb_logger(
62 | FLAGS.exp_path, label='actor', hyperparams=config, exp_name=FLAGS.exp_name)
63 |
64 | # Create the Lunar Lander environment
65 | env = make_ll_environment(FLAGS.seed)
66 | # Create the environment specification
67 | env_spec = acme.make_environment_spec(env)
68 |
69 | # Create the networks for the DRL agent learning algorithm
70 | networks = networks_zoo.make_lunar_lander_nets(config, env_spec)
71 |
72 | # Create the DRL agent
73 | agent = DRLearner(
74 | spec=env_spec,
75 | networks=networks,
76 | config=config,
77 | seed=FLAGS.seed,
78 | workdir=FLAGS.exp_path,
79 | logger=logger_actor
80 | )
81 | # Create the observers for the DRL agent
82 | observers = [IntrinsicRewardObserver(), DistillationCoefObserver(),
83 | StorageVideoObserver(config)]
84 |
85 | # Create the environment loop
86 | loop = EnvironmentLoop(
87 | environment=env,
88 | actor=agent,
89 | logger=logger_env,
90 | observers=observers
91 | )
92 | # Run the environment loop for a specified number of episodes
93 | loop.run(FLAGS.num_episodes)
94 |
95 |
96 | if __name__ == '__main__':
97 | app.run(main)
98 |
--------------------------------------------------------------------------------
/external/xm_docker.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # Copyright 2020 DeepMind Technologies Limited. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities to run PyNodes in Docker containers using XManager."""
17 |
18 | import atexit
19 | import copy
20 | import dataclasses
21 | from distutils import dir_util
22 | import functools
23 | import os
24 | import pathlib
25 | import shutil
26 | import sys
27 | import tempfile
28 | from typing import Any, List, Optional, Sequence, Tuple
29 |
30 | import cloudpickle
31 | from launchpad.launch import serialization
32 |
33 | try:
34 | from xmanager import xm
35 | except ModuleNotFoundError:
36 |   raise Exception('Launchpad requires `xmanager` for XM-based runtimes. '
37 |                   'Please run `pip install xmanager`.')
38 |
39 |
40 | _DATA_FILE_NAME = 'job.pkl'
41 | _INIT_FILE_NAME = 'init.pkl'
42 |
43 |
44 | @dataclasses.dataclass
45 | class DockerConfig:
46 | """Local docker launch configuration.
47 |
48 | Attributes:
49 | code_directory: Path to directory containing any user code that may be
50 | required inside the Docker image. The user code from this directory is
51 | copied over into the Docker containers, as the user code may be needed
52 |       during program execution. If the user code needs installation, modify
53 |       docker_instructions in the xm.PythonContainer construction below.
54 | docker_requirements: Path to requirements.txt specifying Python packages to
55 | install inside the Docker image.
56 | hw_requirements: Hardware requirements.
57 | python_path: Additional paths to be added to PYTHONPATH prior to executing
58 | an entry point.
59 | """
60 | code_directory: Optional[str] = None
61 | docker_requirements: Optional[str] = None
62 | hw_requirements: Optional[xm.JobRequirements] = None
63 | python_path: Optional[List[str]] = None
64 |
65 |
66 | def initializer(python_path):
67 | sys.path = python_path + sys.path
68 |
69 |
70 | def to_docker_executables(
71 | nodes: Sequence[Any],
72 | label: str,
73 | docker_config: DockerConfig,
74 | ) -> List[Tuple[xm.PythonContainer, xm.JobRequirements]]:
75 |
76 |   """Returns a list of `PythonContainer` objects for the given `PyNode`s."""
77 |
78 | if docker_config.code_directory is None or docker_config.docker_requirements is None:
79 | raise ValueError(
80 |         'code_directory and docker_requirements must be specified through '
81 | 'DockerConfig via local_resources when using "xm_docker" launch type.')
82 |
83 |   # Generate a tmp dir without '_' in its name; Vertex AI fails otherwise.
84 | tmp_dir = '_'
85 | while '_' in tmp_dir:
86 | tmp_dir = tempfile.mkdtemp()
87 | atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)
88 |
89 | command_line = f'python -m my_process_entry --data_file={_DATA_FILE_NAME}'
90 |
91 | # Add common initialization function for all nodes which sets up PYTHONPATH.
92 | if docker_config.python_path:
93 | command_line += f' --init_file={_INIT_FILE_NAME}'
94 | # Local 'path' is copied under 'tmp_dir' (no /tmp prefix) inside Docker.
95 | python_path = [
96 | '/' + os.path.basename(tmp_dir) + os.path.abspath(path)
97 | for path in docker_config.python_path
98 | ]
99 | initializer_file_path = pathlib.Path(tmp_dir, _INIT_FILE_NAME)
100 | with open(initializer_file_path, 'wb') as f:
101 | cloudpickle.dump(functools.partial(initializer, python_path), f)
102 |
103 | data_file_path = str(pathlib.Path(tmp_dir, _DATA_FILE_NAME))
104 | serialization.serialize_functions(data_file_path, label,
105 | [n.function for n in nodes])
106 |
107 | file_path = pathlib.Path(__file__).absolute()
108 |
109 | # shutil.copy(pathlib.Path(file_path.parent, 'process_entry.py'), tmp_dir)
110 | dir_util.copy_tree(docker_config.code_directory, tmp_dir)
111 |   if not os.path.exists(docker_config.docker_requirements):
112 |     raise FileNotFoundError('Please specify a path to a file with Python '
113 |                             'package requirements through '
114 |                             'docker_config.docker_requirements.')
115 |
116 |   shutil.copy(docker_config.docker_requirements,
117 |               pathlib.Path(tmp_dir, 'requirements.txt'))
118 |
119 |   workdir_path = pathlib.Path(tmp_dir).name
120 | job_requirements = docker_config.hw_requirements
121 | if not job_requirements:
122 | job_requirements = xm.JobRequirements()
123 |
124 | # Make a copy of requirements since they are being mutated below.
125 | job_requirements = copy.deepcopy(job_requirements)
126 |
127 | if job_requirements.replicas != 1:
128 | raise ValueError(
129 | 'Number of replicas is computed by the runtime. '
130 | 'Please do not set it explicitly in the requirements.'
131 | )
132 |
133 | job_requirements.replicas = len(nodes)
134 | # python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
135 |   python_version = '3.7'
136 |
137 | # if label == 'workerpool2':
138 | command_lst = [
139 | # 'python -m pip install --upgrade jax==0.3.7 jaxlib==0.3.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html',
140 | 'python -m pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
141 | 'python -m pip install --upgrade flax==0.4.1',
142 | command_line
143 | ]
144 | base_image = 'gcr.io/deeplearning-platform-release/tf-gpu.2-8'
145 | return [(xm.PythonContainer(
146 | path=tmp_dir,
147 | base_image=base_image,
148 | entrypoint=xm.CommandList(command_lst),
149 | docker_instructions=[
150 | 'ENV XLA_PYTHON_CLIENT_MEM_FRACTION=0.8',
151 | 'ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/lib:/usr/lib:/usr/local/lib:/opt/conda/lib',
152 | 'ENV PYTHONPATH=$PYTHONPATH:$(pwd)',
153 | 'RUN apt-get install -y git',
154 | f'RUN apt-get -y install libpython{python_version}',
155 | f'COPY {workdir_path}/ {workdir_path}',
156 |
157 | f'COPY {workdir_path}/requirements.txt requirements.txt',
158 | 'RUN python -m pip install xmanager',
159 | 'RUN python -m pip install --no-cache-dir -r requirements.txt ',
160 | 'RUN python -m pip install git+https://github.com/ivannz/gymDiscoMaze.git@stable',
161 |
162 | f'RUN ale-import-roms {workdir_path}/roms/',
163 | f'WORKDIR {workdir_path}',
164 | ]), job_requirements)]
165 |
166 | # else:
167 | # base_image = f'python:{python_version}'
168 | #
169 | #
170 | # return [(xm.PythonContainer(
171 | # path=tmp_dir,
172 | # base_image=base_image,
173 | # entrypoint=xm.CommandList([command_line]),
174 | # docker_instructions=[
175 | # 'ENV LD_LIBRARY_PATH=/lib:/usr/lib:/usr/local/lib:/usr/local/nvidia/lib64:/usr/local/cuda-11.0/targets/x86_64-linux/lib:/opt/conda/lib:/usr/local/cuda-11.0/targets/x86_64-linux/lib/stubs/',
176 | # 'RUN apt-get install -y git',
177 | # f'RUN apt-get -y install libpython{python_version}',
178 | # f'COPY {workdir_path}/requirements.txt requirements.txt',
179 | # 'RUN python -m pip install xmanager',
180 | # 'RUN python -m pip install --no-cache-dir -r requirements.txt ',
181 | # f'COPY {workdir_path}/ {workdir_path}',
182 | # f'RUN ale-import-roms {workdir_path}/roms/',
183 | # f'WORKDIR {workdir_path}',
184 | # ]), job_requirements)]
185 |
186 |
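187 | # Illustrative only: a DockerConfig as consumed by to_docker_executables above.
188 | # The paths are hypothetical and must point at real directories/files.
189 | #
190 | #   docker_config = DockerConfig(
191 | #       code_directory='.',
192 | #       docker_requirements='requirements.txt',
193 | #       hw_requirements=xm.JobRequirements(),
194 | #       python_path=['.'],
195 | #   )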
--------------------------------------------------------------------------------
/my_process_entry.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # Copyright 2020 DeepMind Technologies Limited. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Entry of a PythonNode worker."""
17 |
18 |
19 | import contextlib
20 | import json
21 | import os
22 | import sys
23 |
24 | from absl import app
25 | from absl import flags
26 | from absl import logging
27 | import cloudpickle
28 | from launchpad.launch import worker_manager
29 | import six
30 |
31 | import tensorflow as tf
32 |
33 | tf.config.set_visible_devices([], 'GPU')
34 |
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | flags.DEFINE_integer(
39 |     'lp_task_id', None, 'A list index deciding which '
40 |     'worker to run, given the list of workers (obtained from the '
41 |     'data_file).')
42 | flags.DEFINE_string('data_file', '',
43 | 'Pickle file location with entry points for all nodes')
44 | flags.DEFINE_string(
45 | 'lp_job_name', '',
46 | 'The name of the job, used to access the correct pickle file resource when '
47 | 'using the new launch API')
48 | flags.DEFINE_string(
49 | 'init_file', '', 'Pickle file location containing initialization module '
50 | 'executed for each node prior to an entry point')
51 | flags.DEFINE_string('flags_to_populate', '{}',
52 |                     'JSON mapping of flag names to default values to define at runtime.')
52 |
53 | _FLAG_TYPE_MAPPING = {
54 | str: flags.DEFINE_string,
55 | six.text_type: flags.DEFINE_string,
56 | float: flags.DEFINE_float,
57 | int: flags.DEFINE_integer,
58 | bool: flags.DEFINE_boolean,
59 | list: flags.DEFINE_list,
60 | }
61 |
62 |
63 | def _populate_flags():
64 | """Populate flags that cannot be passed directly to this script."""
65 | FLAGS(sys.argv, known_only=True)
66 |
67 | flags_to_populate = json.loads(FLAGS.flags_to_populate)
68 | for name, value in flags_to_populate.items():
69 | value_type = type(value)
70 | if value_type in _FLAG_TYPE_MAPPING:
71 | flag_ctr = _FLAG_TYPE_MAPPING[value_type]
72 | logging.info('Defining flag %s with default value %s', name, value)
73 | flag_ctr(
74 | name,
75 | value,
76 | 'This flag has been auto-generated.',
77 | allow_override=True)
78 |
79 | # JAX doesn't use absl flags and so we need to forward absl flags to JAX
80 | # explicitly. Here's a heuristic to detect JAX flags and forward them.
81 | for arg in sys.argv:
82 | if arg.startswith('--jax_'):
83 | try:
84 | # pytype:disable=import-error
85 | import jax
86 | # pytype:enable=import-error
87 | jax.config.parse_flags_with_absl()
88 | break
89 | except ImportError:
90 | pass
91 |
92 |
93 | def _get_task_id():
94 | """Returns current task's id."""
95 | if FLAGS.lp_task_id is None:
96 | # Running under Vertex AI...
97 | cluster_spec = os.environ.get('CLUSTER_SPEC', None)
98 | return json.loads(cluster_spec).get('task').get('index')
99 |
100 | return FLAGS.lp_task_id
101 |
102 |
103 | def main(_):
104 | # Allow for importing modules from the current directory.
105 | sys.path.append(os.getcwd())
106 | data_file = FLAGS.data_file
107 | init_file = FLAGS.init_file
108 |
109 | if os.environ.get('TF_CONFIG', None):
110 |     # For the GCP runtime, log to STDOUT so that logs are not reported as errors.
111 | logging.get_absl_handler().python_handler.stream = sys.stdout
112 |
113 | if init_file:
114 | init_function = cloudpickle.load(open(init_file, 'rb'))
115 | init_function()
116 | functions = cloudpickle.load(open(data_file, 'rb'))
117 | task_id = _get_task_id()
118 |
119 | # Worker manager is used here to handle termination signals and provide
120 | # preemption support.
121 | worker_manager.WorkerManager(
122 | register_in_thread=True)
123 |
124 | with contextlib.suppress(): # no-op context manager
125 | functions[task_id]()
126 |
127 |
128 | if __name__ == '__main__':
129 | _populate_flags()
130 | app.run(main)
131 |
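132 | # Illustrative only: under Vertex AI the task id is resolved from the
133 | # CLUSTER_SPEC environment variable, which contains JSON such as the following
134 | # (fields other than task.index are assumptions here):
135 | #
136 | #   {"task": {"type": "workerpool0", "index": 0}}
137 | #
138 | # Similarly, --flags_to_populate accepts a JSON mapping of flag names to
139 | # default values, e.g. --flags_to_populate='{"seed": 0}'.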
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | numpy==1.22.4
3 | cloudpickle==2.0.0
4 | six==1.16.0
5 | libpython==0.2
6 | chex==0.1.5
7 | Cython==0.29.28
8 | flax==0.4.1
9 | optax==0.1.2
10 | rlax==0.1.4
11 | pyglet==1.5.24
12 | xmanager==0.1.5
13 | pyvirtualdisplay==3.0
14 | sk-video==1.1.10
15 | ffmpeg-python==0.2.0
16 | wandb==0.16.2
17 | tensorrt
18 | tensorflow-gpu==2.8.0
19 | tensorflow_probability==0.15.0
20 | tensorflow_datasets==4.6.0
21 | dm-reverb==0.7.2
22 | dm-launchpad==0.5.2
23 | jax==0.4.3
24 | jaxlib==0.4.3+cuda11.cudnn86
25 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
26 | dm-haiku==0.0.10
27 | dm-sonnet
28 | trfl
29 | atari-py
30 | bsuite
31 | dm-control
32 | gym==0.25.0
33 | gym[accept-rom-license, atari, Box2D]
34 | pygame==2.1.0
35 | rlds
36 | git+https://github.com/google-deepmind/acme.git@4c6351ef8ff3f4045a9a24bee6a994667d89c69c
37 | scipy==1.12.0
--------------------------------------------------------------------------------
/scripts/update_tb.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from google.cloud import storage
3 |
4 | BUCKET_NAME = os.environ['GOOGLE_CLOUD_BUCKET_NAME']
5 |
6 |
7 | def gcs_download(prefix, save_to, fname=None):
8 |     """prefix - experiment name, e.g. test_pong/; fname - optional blob-name filter
9 |     save_to - path to save the downloaded files to
10 |     """
11 | if not os.path.isdir(os.path.join(save_to, prefix)):
12 | os.mkdir(os.path.join(save_to, prefix))
13 | storage_client = storage.Client()
14 | bucket = storage_client.bucket(BUCKET_NAME)
15 |
16 | blobs = storage_client.list_blobs(BUCKET_NAME, prefix=prefix, delimiter='/')
17 | for b in blobs:
18 | if not fname or fname in b.name:
19 | blob = bucket.blob(b.name)
20 | print(os.path.join(save_to, b.name))
21 | blob.download_to_filename(os.path.join(save_to, b.name))
22 |
23 |
24 | if __name__ == '__main__':
25 | exp_path = sys.argv[1]
26 | save_to = sys.argv[2]
27 | gcs_download(exp_path, save_to, 'tfevents')
28 |
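29 | # Illustrative usage only (the bucket name comes from GOOGLE_CLOUD_BUCKET_NAME;
30 | # the local target directory is a placeholder):
31 | #
32 | #   python scripts/update_tb.py test_pong/ ./tb_logs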
--------------------------------------------------------------------------------