├── LICENSE ├── README.md ├── requirements.txt ├── revisiting_rainbow ├── Agents │ ├── dqn_agent_new.py │ ├── implicit_quantile_agent_new.py │ ├── quantile_agent_new.py │ └── rainbow_agent_new.py ├── Configs │ ├── dqn_acrobot.gin │ ├── dqn_asterix.gin │ ├── dqn_breakout.gin │ ├── dqn_cartpole.gin │ ├── dqn_freeway.gin │ ├── dqn_invaders.gin │ ├── dqn_lunarlander.gin │ ├── dqn_mountaincar.gin │ ├── dqn_seaquest.gin │ ├── implicit_acrobot.gin │ ├── implicit_asterix.gin │ ├── implicit_breakout.gin │ ├── implicit_cartpole.gin │ ├── implicit_freeway.gin │ ├── implicit_invaders.gin │ ├── implicit_lunarlander.gin │ ├── implicit_mountaincar.gin │ ├── implicit_seaquest.gin │ ├── quantile_acrobot.gin │ ├── quantile_asterix.gin │ ├── quantile_breakout.gin │ ├── quantile_cartpole.gin │ ├── quantile_freeway.gin │ ├── quantile_invaders.gin │ ├── quantile_lunarlander.gin │ ├── quantile_mountaincar.gin │ ├── quantile_seaquest.gin │ ├── rainbow_acrobot.gin │ ├── rainbow_asterix.gin │ ├── rainbow_breakout.gin │ ├── rainbow_cartpole.gin │ ├── rainbow_freeway.gin │ ├── rainbow_invaders.gin │ ├── rainbow_lunarlander.gin │ ├── rainbow_mountaincar.gin │ └── rainbow_seaquest.gin ├── external_configurations.py ├── minatar_env.py ├── networks_new.py └── test_main.ipynb └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisiting Rainbow: Promoting more insightful and inclusive deep reinforcement learning research 2 | 3 | In this work we argue that, despite the community’s emphasis on large-scale environments, the traditional small-scale environments can 4 | still yield valuable scientific insights and can help reduce the barriers to entry for underprivileged communities. 
To substantiate our claims, we empirically revisit the paper that introduced the Rainbow algorithm [Hessel et al., 2018][Hessel] and present some new insights into the algorithms used by Rainbow. 5 | 6 | Our Rainbow agent adds the following components to the Rainbow agent provided by Dopamine ([Pablo Samuel Castro et al., 2018][castro]): 7 | 8 | * Noisy nets ([Meire Fortunato et al., 2018][fortunato]) 9 | * Dueling networks ([Ziyu Wang et al., 2016][wang]) 10 | * Double Q-learning ([Hado van Hasselt et al., 2016][hasselt]) 11 | * Munchausen Reinforcement Learning ([Nino Vieillard et al., 2020][Vieillard]) 12 | 13 | If you are interested in learning more about Revisiting Rainbow, consider checking the following resources: 14 | 15 | * **Paper:** [arxiv.org/abs/2011.14826][arXiv_rev] 16 | * **Blog:** [https://psc-g.github.io/posts/...][blog] 17 | * **Deep RL Workshop talk, NeurIPS 2020:** [https://slideslive.com/38941329/...][video] 18 | 19 | 20 | ## Quick Start 21 | To use the algorithms proposed in the Revisiting Rainbow paper, you need Python 3 installed; make sure pip is also up to date. If you want to run the MinAtar experiments, you also need to install MinAtar; please see the accompanying paper ([Young et al., 2019][young]) and repository ([github][young_repo]). 22 | 23 | 1. Clone the repo: 24 | ```bash 25 | git clone https://github.com/JohanSamir/revisiting_rainbow 26 | ``` 27 | If you prefer running the algorithms in a virtualenv, you can do the following before step 2: 28 | 29 | ```bash 30 | python3 -m venv venv 31 | source venv/bin/activate 32 | # Upgrade pip 33 | pip install --upgrade pip 34 | ``` 35 | 36 | 2. Finally, set up the environment and install Revisiting Rainbow's dependencies: 37 | ```bash 38 | pip install -U pip 39 | pip install -r revisiting_rainbow/requirements.txt 40 | ``` 41 | 42 | ## Running tests 43 | 44 | Check the following Colab notebook, [`revisiting_rainbow/test_main.ipynb`](https://github.com/JohanSamir/revisiting_rainbow/blob/main/revisiting_rainbow/test_main.ipynb), to run the basic DQN agent. 45 | 46 | ## References 47 | 48 | [Hado van Hasselt, Arthur Guez, and David Silver. *Deep reinforcement learning with double q-learning*. 49 | In Proceedings of the Thirtieth AAAI Conference on Artificial Intelligence (AAAI), 2016.][hasselt] 50 | 51 | [Matteo Hessel, Joseph Modayil, Hado van Hasselt, Tom Schaul, Georg Ostrovski, Will Dabney, Dan 52 | Horgan, Bilal Piot, Mohammad Azar, and David Silver. *Rainbow: Combining Improvements in Deep Reinforcement Learning*. 53 | In Proceedings of the AAAI Conference on Artificial Intelligence, 2018.][Hessel] 54 | 55 | [Meire Fortunato, Mohammad Gheshlaghi Azar, Bilal Piot, Jacob Menick, Ian Osband, Alexander 56 | Graves, Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, and 57 | Shane Legg. *Noisy networks for exploration*. In Proceedings of the International Conference on 58 | Learning Representations (ICLR 2018), Vancouver (Canada), 2018.][fortunato] 59 | 60 | [Pablo Samuel Castro, Subhodeep Moitra, Carles Gelada, Saurabh Kumar, and Marc G. Bellemare. 61 | *Dopamine: A Research Framework for Deep Reinforcement Learning*, 2018.][castro] 62 | 63 | [Kenny Young and Tian Tian. *MinAtar: An Atari-inspired testbed for thorough and reproducible reinforcement learning experiments*, 2019.][young] 64 | 65 | [Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, and Nando de Freitas. *Dueling network architectures for deep reinforcement learning*.
In Proceedings of the 33rd International 66 | Conference on Machine Learning, volume 48, pages 1995–2003, 2016.][wang] 67 | 68 | [Vieillard, N., Pietquin, O., and Geist, M. Munchausen Reinforcement Learning. In Advances in Neural Information Processing Systems (NeurIPS), 2020.][Vieillard] 69 | 70 | [fortunato]: https://arxiv.org/abs/1706.10295 71 | [hasselt]: https://arxiv.org/abs/1509.06461 72 | [wang]: https://arxiv.org/abs/1511.06581 73 | [castro]: https://arxiv.org/abs/1812.06110 74 | [Hessel]: https://arxiv.org/abs/1710.02298 75 | [young]: https://arxiv.org/abs/1903.03176 76 | [Vieillard]: https://arxiv.org/abs/2007.14430 77 | [young_repo]: https://github.com/kenjyoung/MinAtar 78 | [arXiv_rev]: https://arxiv.org/abs/2011.14826 79 | [blog]: https://psc-g.github.io/posts/research/rl/revisiting_rainbow/ 80 | [video]: https://slideslive.com/38941329/revisiting-rainbow-promoting-more-insightful-and-inclusive-deep-reinforcement-learning-research 81 | 82 | ## Giving credit 83 | If you use Revisiting Rainbow in your research, please cite the following: 84 | 85 | Johan S. Obando-Ceron and Pablo Samuel Castro (2021). Revisiting Rainbow: Promoting more insightful and inclusive deep reinforcement learning research. Proceedings of the 38th International Conference on Machine Learning, ICML 2021. [*arXiv preprint*][arXiv_rev] 86 | 87 | In BibTeX format: 88 | 89 | ``` 90 | @inproceedings{obando2020revisiting, 91 | title={Revisiting Rainbow: Promoting more insightful and inclusive deep reinforcement learning research}, 92 | author={Obando-Ceron, Johan S and Castro, Pablo Samuel}, 93 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 94 | year = {2021}, 95 | series = {Proceedings of Machine Learning Research}, 96 | publisher = {PMLR}, 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dopamine-rl==3.1.10 2 | git+https://github.com/kenjyoung/MinAtar 3 | -------------------------------------------------------------------------------- /revisiting_rainbow/Agents/dqn_agent_new.py: -------------------------------------------------------------------------------- 1 | """Compact implementation of a DQN agent. 2 | 3 | Specifically, we implement the following components: 4 | 5 | * prioritized replay 6 | * huber_loss 7 | * mse_loss 8 | * double_dqn 9 | * noisy 10 | * dueling 11 | * Munchausen 12 | 13 | Details in: 14 | "Human-level control through deep reinforcement learning" by Mnih et al. (2015). 15 | "Noisy Networks for Exploration" by Fortunato et al. (2017). 16 | "Deep Reinforcement Learning with Double Q-learning" by van Hasselt et al. (2015). 17 | "Dueling Network Architectures for Deep Reinforcement Learning" by Wang et al. (2015). 18 | "Munchausen Reinforcement Learning" by Vieillard et al. (2020).
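The Munchausen option (target_opt=2 in train() below) augments the reward with a clipped, scaled log-policy term alpha * tau * log pi(a|s) computed from the target network, following Vieillard et al. (2020).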
19 | 20 | """ 21 | import time 22 | import copy 23 | import functools 24 | from dopamine.jax import networks 25 | from dopamine.jax.agents.dqn import dqn_agent 26 | from dopamine.replay_memory import prioritized_replay_buffer 27 | import gin 28 | import jax 29 | import jax.numpy as jnp 30 | import numpy as onp 31 | import tensorflow as tf 32 | import jax.scipy.special as scp 33 | from flax import linen as nn 34 | 35 | 36 | def mse_loss(targets, predictions): 37 | return jnp.mean(jnp.power((targets - (predictions)),2)) 38 | 39 | 40 | @functools.partial(jax.jit, static_argnums=(0, 9,10,11,12,13, 14)) 41 | def train(network_def, target_params, optimizer, states, actions, next_states, rewards, 42 | terminals, loss_weights, cumulative_gamma, target_opt, mse_inf,tau,alpha,clip_value_min, rng): 43 | 44 | 45 | """Run the training step.""" 46 | online_params = optimizer.target 47 | def loss_fn(params, rng_input, target, loss_multipliers): 48 | def q_online(state): 49 | return network_def.apply(params, state, rng=rng_input) 50 | 51 | q_values = jax.vmap(q_online)(states).q_values 52 | q_values = jnp.squeeze(q_values) 53 | replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions) 54 | 55 | if mse_inf: 56 | loss = jax.vmap(mse_loss)(target, replay_chosen_q) 57 | else: 58 | loss = jax.vmap(dqn_agent.huber_loss)(target, replay_chosen_q) 59 | 60 | mean_loss = jnp.mean(loss_multipliers * loss) 61 | return mean_loss, loss 62 | 63 | rng, rng2, rng3, rng4 = jax.random.split(rng, 4) 64 | 65 | def q_target(state): 66 | return network_def.apply(target_params, state, rng=rng2) 67 | 68 | def q_target_online(state): 69 | return network_def.apply(online_params, state, rng=rng4) 70 | 71 | if target_opt == 0: 72 | target = dqn_agent.target_q(q_target, next_states, rewards, terminals, cumulative_gamma) 73 | elif target_opt == 1: 74 | #Double DQN 75 | target = target_DDQN(q_target_online, q_target, next_states, rewards, terminals, cumulative_gamma) 76 | 77 | elif target_opt == 2: 78 | #Munchausen 79 | target = target_m_dqn(q_target_online, q_target, states,next_states,actions,rewards,terminals, 80 | cumulative_gamma,tau,alpha,clip_value_min) 81 | else: 82 | print('error') 83 | 84 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 85 | (mean_loss, loss), grad = grad_fn(online_params, rng3, target, loss_weights) 86 | optimizer = optimizer.apply_gradient(grad) 87 | return optimizer, loss, mean_loss 88 | 89 | def target_DDQN(model, target_network, next_states, rewards, terminals, cumulative_gamma): 90 | """Compute the target Q-value. Double DQN""" 91 | next_q_values = jax.vmap(model, in_axes=(0))(next_states).q_values 92 | next_q_values = jnp.squeeze(next_q_values) 93 | replay_next_qt_max = jnp.argmax(next_q_values, axis=1) 94 | next_q_state_values = jax.vmap(target_network, in_axes=(0))(next_states).q_values 95 | 96 | q_values = jnp.squeeze(next_q_state_values) 97 | replay_chosen_q = jax.vmap(lambda t, u: t[u])(q_values, replay_next_qt_max) 98 | 99 | return jax.lax.stop_gradient(rewards + cumulative_gamma * replay_chosen_q * 100 | (1. 
- terminals)) 101 | 102 | def stable_scaled_log_softmax(x, tau, axis=-1): 103 | max_x = jnp.amax(x, axis=axis, keepdims=True) 104 | y = x - max_x 105 | tau_lse = max_x + tau * jnp.log(jnp.sum(jnp.exp(y / tau), axis=axis, keepdims=True)) 106 | return x - tau_lse 107 | 108 | def stable_softmax(x, tau, axis=-1): 109 | max_x = jnp.amax(x, axis=axis, keepdims=True) 110 | y = x - max_x 111 | return nn.softmax(y/tau, axis=axis) 112 | 113 | def target_m_dqn(model, target_network, states, next_states, actions,rewards, terminals, 114 | cumulative_gamma,tau,alpha,clip_value_min): 115 | """Compute the target Q-value. Munchausen DQN""" 116 | 117 | #---------------------------------------- 118 | q_state_values = jax.vmap(target_network, in_axes=(0))(states).q_values 119 | q_state_values = jnp.squeeze(q_state_values) 120 | 121 | next_q_values = jax.vmap(target_network, in_axes=(0))(next_states).q_values 122 | next_q_values = jnp.squeeze(next_q_values) 123 | #---------------------------------------- 124 | 125 | tau_log_pi_next = stable_scaled_log_softmax(next_q_values, tau, axis=1) 126 | pi_target = stable_softmax(next_q_values,tau, axis=1) 127 | replay_log_policy = stable_scaled_log_softmax(q_state_values, tau, axis=1) 128 | 129 | #---------------------------------------- 130 | 131 | replay_next_qt_softmax = jnp.sum((next_q_values-tau_log_pi_next)*pi_target,axis=1) 132 | 133 | replay_action_one_hot = nn.one_hot(actions, q_state_values.shape[-1]) 134 | tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=1) 135 | 136 | #a_max=1 137 | tau_log_pi_a = jnp.clip(tau_log_pi_a, a_min=clip_value_min,a_max=1) 138 | 139 | munchausen_term = alpha * tau_log_pi_a 140 | modified_bellman = (rewards + munchausen_term +cumulative_gamma * replay_next_qt_softmax * 141 | (1. - jnp.float32(terminals))) 142 | 143 | return jax.lax.stop_gradient(modified_bellman) 144 | 145 | 146 | @functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11)) 147 | def select_action(network_def, params, state, rng, num_actions, eval_mode, 148 | epsilon_eval, epsilon_train, epsilon_decay_period, 149 | training_steps, min_replay_history, epsilon_fn): 150 | 151 | epsilon = jnp.where(eval_mode, 152 | epsilon_eval, 153 | epsilon_fn(epsilon_decay_period, 154 | training_steps, 155 | min_replay_history, 156 | epsilon_train)) 157 | 158 | rng, rng1, rng2, rng3 = jax.random.split(rng, num=4) 159 | selected_action = jnp.argmax(network_def.apply(params, state, rng=rng3).q_values) 160 | p = jax.random.uniform(rng1) 161 | return rng, jnp.where(p <= epsilon, 162 | jax.random.randint(rng2, (), 0, num_actions), 163 | selected_action) 164 | 165 | @gin.configurable 166 | class JaxDQNAgentNew(dqn_agent.JaxDQNAgent): 167 | """A compact implementation of a simplified Rainbow agent.""" 168 | 169 | def __init__(self, 170 | num_actions, 171 | 172 | tau, 173 | alpha=1, 174 | clip_value_min=-10, 175 | 176 | net_conf = None, 177 | env = "CartPole", 178 | normalize_obs = True, 179 | hidden_layer=2, 180 | neurons=512, 181 | replay_scheme='prioritized', 182 | noisy = False, 183 | dueling = False, 184 | initzer = 'xavier_uniform', 185 | target_opt=0, 186 | mse_inf=False, 187 | network=networks.NatureDQNNetwork, 188 | optimizer='adam', 189 | epsilon_fn=dqn_agent.linearly_decaying_epsilon, 190 | seed=None): 191 | """Initializes the agent and constructs the necessary components. 192 | 193 | Args: 194 | num_actions: int, number of actions the agent can take at any state. 195 | observation_shape: tuple of ints or an int. 
If single int, the observation 196 | is assumed to be a 2D square. 197 | observation_dtype: DType, specifies the type of the observations. Note 198 | that if your inputs are continuous, you should set this to jnp.float32. 199 | stack_size: int, number of frames to use in state stack. 200 | network: flax.nn Module that is initialized by shape in _create_network 201 | below. See dopamine.jax.networks.RainbowNetwork as an example. 202 | num_atoms: int, the number of buckets of the value function distribution. 203 | vmax: float, the value distribution support is [-vmax, vmax]. 204 | gamma: float, discount factor with the usual RL meaning. 205 | update_horizon: int, horizon at which updates are performed, the 'n' in 206 | n-step update. 207 | min_replay_history: int, number of transitions that should be experienced 208 | before the agent begins training its value function. 209 | update_period: int, period between DQN updates. 210 | target_update_period: int, update period for the target network. 211 | epsilon_fn: function expecting 4 parameters: 212 | (decay_period, step, warmup_steps, epsilon). This function should return 213 | the epsilon value used for exploration during training. 214 | epsilon_train: float, the value to which the agent's epsilon is eventually 215 | decayed during training. 216 | epsilon_eval: float, epsilon used when evaluating the agent. 217 | epsilon_decay_period: int, length of the epsilon decay schedule. 218 | replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the 219 | replay memory. 220 | optimizer: str, name of optimizer to use. 221 | summary_writer: SummaryWriter object for outputting training statistics. 222 | Summary writing disabled if set to None. 223 | summary_writing_frequency: int, frequency with which summaries will be 224 | written. Lower values will result in slower training. 225 | allow_partial_reload: bool, whether we allow reloading a partial agent 226 | (for instance, only the network parameters). 227 | """ 228 | # We need this because some tools convert round floats into ints. 
229 | seed = int(time.time() * 1e6) if seed is None else seed 230 | self._net_conf = net_conf 231 | self._env = env 232 | self._normalize_obs = normalize_obs 233 | self._hidden_layer = hidden_layer 234 | self._neurons=neurons 235 | self._noisy = noisy 236 | self._dueling = dueling 237 | self._initzer = initzer 238 | self._target_opt = target_opt 239 | self._mse_inf = mse_inf 240 | self._tau = tau 241 | self._alpha = alpha 242 | self._clip_value_min = clip_value_min 243 | self._rng = jax.random.PRNGKey(seed) 244 | 245 | super(JaxDQNAgentNew, self).__init__( 246 | num_actions= num_actions, 247 | network= functools.partial(network, 248 | num_actions=num_actions, 249 | net_conf=self._net_conf, 250 | env=self._env, 251 | normalize_obs=self._normalize_obs, 252 | hidden_layer=self._hidden_layer, 253 | neurons=self._neurons, 254 | noisy=self._noisy, 255 | dueling=self._dueling, 256 | initzer=self._initzer), 257 | optimizer=optimizer, 258 | epsilon_fn=dqn_agent.identity_epsilon if self._noisy == True else epsilon_fn) 259 | 260 | 261 | self._replay_scheme = replay_scheme 262 | 263 | def _build_networks_and_optimizer(self): 264 | self._rng, rng = jax.random.split(self._rng) 265 | online_network_params = self.network_def.init( 266 | rng, x=self.state, rng=self._rng) 267 | optimizer_def = dqn_agent.create_optimizer(self._optimizer_name) 268 | self.optimizer = optimizer_def.create(online_network_params) 269 | self.target_network_params = copy.deepcopy(online_network_params) 270 | 271 | 272 | def _build_replay_buffer(self): 273 | """Creates the prioritized replay buffer used by the agent.""" 274 | return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer( 275 | observation_shape=self.observation_shape, 276 | stack_size=self.stack_size, 277 | update_horizon=self.update_horizon, 278 | gamma=self.gamma, 279 | observation_dtype=self.observation_dtype) 280 | 281 | def _train_step(self): 282 | """Runs a single training step. 283 | 284 | Runs training if both: 285 | (1) A minimum number of frames have been added to the replay buffer. 286 | (2) `training_steps` is a multiple of `update_period`. 287 | 288 | Also, syncs weights from online_network to target_network if training steps 289 | is a multiple of target update period. 290 | """ 291 | # Run a train op at the rate of self.update_period if enough training steps 292 | # have been run. This matches the Nature DQN behaviour. 293 | if self._replay.add_count > self.min_replay_history: 294 | if self.training_steps % self.update_period == 0: 295 | self._sample_from_replay_buffer() 296 | 297 | if self._replay_scheme == 'prioritized': 298 | # The original prioritized experience replay uses a linear exponent 299 | # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 300 | # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) 301 | # suggested a fixed exponent actually performs better, except on Pong. 302 | probs = self.replay_elements['sampling_probabilities'] 303 | # Weight the loss by the inverse priorities. 
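        # Importance-sampling correction for prioritized replay: w_i = (1 / p_i)^0.5, normalized by
        # the largest weight so all corrections lie in (0, 1]; the 1e-10 avoids division by zero.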
304 | loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) 305 | loss_weights /= jnp.max(loss_weights) 306 | else: 307 | loss_weights = jnp.ones(self.replay_elements['state'].shape[0]) 308 | 309 | 310 | self.optimizer, loss, mean_loss = train(self.network_def, 311 | self.target_network_params, 312 | self.optimizer, 313 | self.replay_elements['state'], 314 | self.replay_elements['action'], 315 | self.replay_elements['next_state'], 316 | self.replay_elements['reward'], 317 | self.replay_elements['terminal'], 318 | loss_weights, 319 | self.cumulative_gamma, 320 | self._target_opt, 321 | self._mse_inf, 322 | self._tau, 323 | self._alpha, 324 | self._clip_value_min, 325 | self._rng) 326 | 327 | if self._replay_scheme == 'prioritized': 328 | # Rainbow and prioritized replay are parametrized by an exponent 329 | # alpha, but in both cases it is set to 0.5 - for simplicity's sake we 330 | # leave it as is here, using the more direct sqrt(). Taking the square 331 | # root "makes sense", as we are dealing with a squared loss. Add a 332 | # small nonzero value to the loss to avoid 0 priority items. While 333 | # technically this may be okay, setting all items to 0 priority will 334 | # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. 335 | self._replay.set_priority(self.replay_elements['indices'], 336 | jnp.sqrt(loss + 1e-10)) 337 | 338 | if (self.summary_writer is not None and 339 | self.training_steps > 0 and 340 | self.training_steps % self.summary_writing_frequency == 0): 341 | summary = tf.compat.v1.Summary(value=[ 342 | tf.compat.v1.Summary.Value(tag='HuberLoss', simple_value=mean_loss)]) 343 | self.summary_writer.add_summary(summary, self.training_steps) 344 | if self.training_steps % self.target_update_period == 0: 345 | self._sync_weights() 346 | 347 | self.training_steps += 1 348 | 349 | def _store_transition(self, 350 | last_observation, 351 | action, 352 | reward, 353 | is_terminal, 354 | priority=None): 355 | """Stores a transition when in training mode. 356 | Stores the following tuple in the replay buffer (last_observation, action, 357 | reward, is_terminal, priority). 358 | Args: 359 | last_observation: Last observation, type determined via observation_type 360 | parameter in the replay_memory constructor. 361 | action: An integer, the action taken. 362 | reward: A float, the reward. 363 | is_terminal: Boolean indicating if the current state is a terminal state. 364 | priority: Float. Priority of sampling the transition. If None, the default 365 | priority will be used. If replay scheme is uniform, the default priority 366 | is 1. If the replay scheme is prioritized, the default priority is the 367 | maximum ever seen [Schaul et al., 2015]. 368 | """ 369 | if priority is None: 370 | if self._replay_scheme == 'uniform': 371 | priority = 1. 372 | else: 373 | priority = self._replay.sum_tree.max_recorded_priority 374 | 375 | if not self.eval_mode: 376 | self._replay.add(last_observation, action, reward, is_terminal, priority) 377 | 378 | def begin_episode(self, observation): 379 | """Returns the agent's first action for this episode. 380 | Args: 381 | observation: numpy array, the environment's initial observation. 382 | Returns: 383 | int, the selected action. 
384 | """ 385 | self._reset_state() 386 | self._record_observation(observation) 387 | 388 | if not self.eval_mode: 389 | self._train_step() 390 | 391 | self._rng, self.action = select_action(self.network_def, 392 | self.online_params, 393 | self.state, 394 | self._rng, 395 | self.num_actions, 396 | self.eval_mode, 397 | self.epsilon_eval, 398 | self.epsilon_train, 399 | self.epsilon_decay_period, 400 | self.training_steps, 401 | self.min_replay_history, 402 | self.epsilon_fn) 403 | self.action = onp.asarray(self.action) 404 | return self.action 405 | 406 | def step(self, reward, observation): 407 | """Records the most recent transition and returns the agent's next action. 408 | We store the observation of the last time step since we want to store it 409 | with the reward. 410 | Args: 411 | reward: float, the reward received from the agent's most recent action. 412 | observation: numpy array, the most recent observation. 413 | Returns: 414 | int, the selected action. 415 | """ 416 | self._last_observation = self._observation 417 | self._record_observation(observation) 418 | 419 | if not self.eval_mode: 420 | self._store_transition(self._last_observation, self.action, reward, False) 421 | self._train_step() 422 | 423 | self._rng, self.action = select_action(self.network_def, 424 | self.online_params, 425 | self.state, 426 | self._rng, 427 | self.num_actions, 428 | self.eval_mode, 429 | self.epsilon_eval, 430 | self.epsilon_train, 431 | self.epsilon_decay_period, 432 | self.training_steps, 433 | self.min_replay_history, 434 | self.epsilon_fn) 435 | self.action = onp.asarray(self.action) 436 | return self.action -------------------------------------------------------------------------------- /revisiting_rainbow/Agents/implicit_quantile_agent_new.py: -------------------------------------------------------------------------------- 1 | """The implicit quantile networks (IQN) agent. 2 | 3 | The agent follows the description given in "Implicit Quantile Networks for 4 | Distributional RL" (Dabney et. al, 2018). 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import copy 12 | import functools 13 | import time 14 | from dopamine.jax import networks 15 | from dopamine.jax.agents.dqn import dqn_agent 16 | from dopamine.replay_memory import prioritized_replay_buffer 17 | from flax import linen as nn 18 | import gin 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as onp 22 | import tensorflow as tf 23 | import jax.scipy.special as scp 24 | import jax.lax 25 | 26 | 27 | 28 | @functools.partial( 29 | jax.vmap, 30 | in_axes=(None, None, None, 0, 0, 0, None, None, None, None, None), 31 | out_axes=(None, 0)) 32 | def target_quantile_values_fun(network_def, online_params, target_params, 33 | next_states, rewards, terminals, 34 | num_tau_prime_samples, num_quantile_samples, 35 | cumulative_gamma, double_dqn, rng): 36 | 37 | rewards = jnp.tile(rewards, [num_tau_prime_samples]) 38 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 39 | # Incorporate terminal state to discount factor. 40 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 41 | gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples]) 42 | rng, rng1, rng2 = jax.random.split(rng, num=3) 43 | # Compute Q-values which are used for action selection for the next states 44 | # in the replay buffer. Compute the argmax over the Q-values. 
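  # When double_dqn is True, the greedy next action comes from the online network's quantile
  # estimates; otherwise it comes from the target network. The target network always evaluates
  # the chosen action below.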
45 | if double_dqn: 46 | outputs_action = network_def.apply(online_params, 47 | next_states, 48 | num_quantiles=num_quantile_samples, 49 | rng=rng1) 50 | else: 51 | outputs_action = network_def.apply(target_params, 52 | next_states, 53 | num_quantiles=num_quantile_samples, 54 | rng=rng1) 55 | target_quantile_values_action = outputs_action.quantile_values 56 | target_q_values = jnp.squeeze( 57 | jnp.mean(target_quantile_values_action, axis=0)) 58 | # Shape: batch_size. 59 | next_qt_argmax = jnp.argmax(target_q_values) 60 | # Get the indices of the maximium Q-value across the action dimension. 61 | # Shape of next_qt_argmax: (num_tau_prime_samples x batch_size). 62 | next_state_target_outputs = network_def.apply( 63 | target_params, 64 | next_states, 65 | num_quantiles=num_tau_prime_samples, 66 | rng=rng2) 67 | next_qt_argmax = jnp.tile(next_qt_argmax, [num_tau_prime_samples]) 68 | 69 | target_quantile_vals = ( 70 | jax.vmap(lambda x, y: x[y])(next_state_target_outputs.quantile_values, 71 | next_qt_argmax)) 72 | 73 | target_quantile_vals = rewards + gamma_with_terminal * target_quantile_vals 74 | # We return with an extra dimension, which is expected by train. 75 | 76 | return rng, jax.lax.stop_gradient(target_quantile_vals[:, None]) 77 | 78 | 79 | def stable_scaled_log_softmax(x, tau, axis=-1): 80 | max_x = jnp.amax(x, axis=axis, keepdims=True) 81 | y = x - max_x 82 | tau_lse = max_x + tau * jnp.log(jnp.sum(jnp.exp(y / tau), axis=axis, keepdims=True)) 83 | return x - tau_lse 84 | 85 | def stable_softmax(x, tau, axis=-1): 86 | max_x = jnp.amax(x, axis=axis, keepdims=True) 87 | y = x - max_x 88 | return jax.nn.softmax(y/tau, axis=axis) 89 | 90 | 91 | @functools.partial( 92 | jax.vmap, 93 | in_axes=(None, None, None, 0, 0, 0, 0, 0, None, None, None, None, None,None, None, None,None), 94 | out_axes=(None, 0)) 95 | 96 | def munchau_target_quantile_values_fun(network_def, online_network, target_params, 97 | states,actions,next_states, rewards, terminals, 98 | num_tau_prime_samples, num_quantile_samples, 99 | cumulative_gamma, double_dqn, rng,tau,alpha,clip_value_min,num_actions): 100 | #Build the munchausen target for return values at given quantiles. 101 | del double_dqn 102 | 103 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 104 | # Incorporate terminal state to discount factor. 
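  # Effective discount: cumulative_gamma (gamma^n for the n-step update) for non-terminal
  # transitions and 0 for terminal ones, tiled to match the num_tau_prime_samples target quantiles.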
105 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 106 | gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples]) 107 | 108 | rng, rng1, rng2 = jax.random.split(rng, num=3) 109 | #------------------------------------------------------------------------ 110 | replay_net_target_outputs = network_def.apply(target_params, next_states,num_quantiles=num_tau_prime_samples, rng=rng) 111 | replay_net_target_quantile_values = replay_net_target_outputs.quantile_values 112 | 113 | target_next_action = network_def.apply(target_params, next_states,num_quantiles=num_quantile_samples, rng=rng1) 114 | target_next_quantile_values_action = target_next_action.quantile_values 115 | _replay_next_target_q_values = jnp.squeeze(jnp.mean(target_next_quantile_values_action, axis=0)) 116 | 117 | outputs_action = network_def.apply(target_params, states,num_quantiles=num_quantile_samples, rng=rng2) 118 | q_state_values = outputs_action.quantile_values 119 | _replay_target_q_values = jnp.squeeze(jnp.mean(q_state_values, axis=0)) 120 | #------------------------------------------------------------------------ 121 | 122 | replay_action_one_hot = jax.nn.one_hot(actions,num_actions) 123 | replay_next_log_policy = stable_scaled_log_softmax(_replay_next_target_q_values, tau, axis=0) 124 | replay_next_policy = stable_softmax(_replay_next_target_q_values,tau, axis=0) 125 | replay_log_policy = stable_scaled_log_softmax(_replay_target_q_values, tau, axis=0) 126 | 127 | #------------------------------------------------------------------------ 128 | 129 | tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=0) 130 | tau_log_pi_a = jnp.clip(tau_log_pi_a, a_min=clip_value_min,a_max=0) 131 | munchausen_term = alpha * tau_log_pi_a 132 | 133 | rewards = rewards + munchausen_term 134 | rewards = jnp.tile(rewards, [num_tau_prime_samples]) 135 | 136 | weighted_logits = (replay_next_policy * (replay_net_target_quantile_values-replay_next_log_policy)) 137 | 138 | target_quantile_values = jnp.sum(weighted_logits, axis=1) 139 | target_quantile_values = rewards + gamma_with_terminal * target_quantile_values 140 | 141 | return rng, jax.lax.stop_gradient(target_quantile_values[:, None]) 142 | 143 | 144 | @functools.partial(jax.jit, static_argnums=(0, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) 145 | def train(network_def, target_params, optimizer, states, actions, next_states, rewards, 146 | terminals, loss_weights, target_opt, num_tau_samples, num_tau_prime_samples, 147 | num_quantile_samples, cumulative_gamma, double_dqn, kappa, tau,alpha,clip_value_min, num_actions,rng): 148 | """Run a training step.""" 149 | online_params = optimizer.target 150 | def loss_fn(params, rng_input, target_quantile_vals, loss_multipliers): 151 | def online(state): 152 | return network_def.apply(params, state, num_quantiles=num_tau_samples, rng=rng_input) 153 | 154 | model_output = jax.vmap(online)(states) 155 | quantile_values = model_output.quantile_values 156 | quantiles = model_output.quantiles 157 | chosen_action_quantile_values = jax.vmap(lambda x, y: x[:, y][:, None])( 158 | quantile_values, actions) 159 | # Shape of bellman_erors and huber_loss: 160 | # batch_size x num_tau_prime_samples x num_tau_samples x 1. 
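    # Pairwise TD errors: each target quantile value is compared with each online quantile value of
    # the chosen action; the quantile Huber loss below weights these errors by their quantile levels.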
161 | bellman_errors = (target_quantile_vals[:, :, None, :] - 162 | chosen_action_quantile_values[:, None, :, :]) 163 | # The huber loss (see Section 2.3 of the paper) is defined via two cases: 164 | # case_one: |bellman_errors| <= kappa 165 | # case_two: |bellman_errors| > kappa 166 | huber_loss_case_one = ( 167 | (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) * 168 | 0.5 * bellman_errors ** 2) 169 | huber_loss_case_two = ( 170 | (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) * 171 | kappa * (jnp.abs(bellman_errors) - 0.5 * kappa)) 172 | huber_loss = huber_loss_case_one + huber_loss_case_two 173 | # Tile by num_tau_prime_samples along a new dimension. Shape is now 174 | # batch_size x num_tau_prime_samples x num_tau_samples x 1. 175 | # These quantiles will be used for computation of the quantile huber loss 176 | # below (see section 2.3 of the paper). 177 | quantiles = jnp.tile(quantiles[:, None, :, :], 178 | [1, num_tau_prime_samples, 1, 1]).astype(jnp.float32) 179 | # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. 180 | quantile_huber_loss = (jnp.abs(quantiles - jax.lax.stop_gradient( 181 | (bellman_errors < 0).astype(jnp.float32))) * huber_loss) / kappa 182 | # Sum over current quantile value (num_tau_samples) dimension, 183 | # average over target quantile value (num_tau_prime_samples) dimension. 184 | # Shape: batch_size x num_tau_prime_samples x 1. 185 | loss = jnp.sum(quantile_huber_loss, axis=2) 186 | loss = jnp.squeeze(jnp.mean(loss, axis=1), axis=-1) 187 | 188 | mean_loss = jnp.mean(loss_multipliers * loss) 189 | 190 | return mean_loss, loss 191 | 192 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 193 | 194 | if target_opt == 0: 195 | rng, target_quantile_vals = target_quantile_values_fun( 196 | network_def, 197 | online_params, 198 | target_params, 199 | next_states, 200 | rewards, 201 | terminals, 202 | num_tau_prime_samples, 203 | num_quantile_samples, 204 | cumulative_gamma, 205 | double_dqn, 206 | rng) 207 | 208 | elif target_opt == 1: 209 | rng, target_quantile_vals = munchau_target_quantile_values_fun( 210 | network_def, 211 | online_params, 212 | target_params, 213 | states, 214 | actions, 215 | next_states, 216 | rewards, 217 | terminals, 218 | num_tau_prime_samples, 219 | num_quantile_samples, 220 | cumulative_gamma, 221 | double_dqn, 222 | rng, 223 | tau, 224 | alpha, 225 | clip_value_min, 226 | num_actions 227 | ) 228 | 229 | else: 230 | print('error') 231 | 232 | rng, rng_input = jax.random.split(rng) 233 | (mean_loss, loss), grad = grad_fn(online_params, rng_input, target_quantile_vals, loss_weights) 234 | optimizer = optimizer.apply_gradient(grad) 235 | return rng, optimizer, loss, mean_loss 236 | 237 | 238 | @functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 9, 11, 12, 13)) 239 | def select_action(network_def, params, state, rng, num_quantile_samples, num_actions, 240 | eval_mode, epsilon_eval, epsilon_train, epsilon_decay_period, 241 | training_steps, min_replay_history, epsilon_fn, tau, model): 242 | 243 | epsilon = jnp.where(eval_mode, 244 | epsilon_eval, 245 | epsilon_fn(epsilon_decay_period, 246 | training_steps, 247 | min_replay_history, 248 | epsilon_train)) 249 | 250 | rng, rng1, rng2 = jax.random.split(rng, num=3) 251 | 252 | selected_action = jnp.argmax(jnp.mean( 253 | network_def.apply(params, state, 254 | num_quantiles=num_quantile_samples, 255 | rng=rng2).quantile_values, axis=0), 256 | axis=0) 257 | 258 | p = jax.random.uniform(rng1) 259 | return rng, jnp.where(p <= epsilon, 260 | 
jax.random.randint(rng2, (), 0, num_actions), 261 | selected_action) 262 | 263 | 264 | @gin.configurable 265 | class JaxImplicitQuantileAgentNew(dqn_agent.JaxDQNAgent): 266 | """An extension of Rainbow to perform implicit quantile regression.""" 267 | 268 | def __init__(self, 269 | num_actions, 270 | 271 | tau, 272 | alpha=1, 273 | clip_value_min=-10, 274 | target_opt=0, 275 | 276 | net_conf = None, 277 | env = "CartPole", 278 | hidden_layer=2, 279 | neurons=512, 280 | noisy = False, 281 | dueling = False, 282 | initzer = 'variance_scaling', 283 | 284 | observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE, 285 | observation_dtype=dqn_agent.NATURE_DQN_DTYPE, 286 | stack_size=dqn_agent.NATURE_DQN_STACK_SIZE, 287 | network=networks.ImplicitQuantileNetwork, 288 | kappa=1.0, 289 | num_tau_samples=32, 290 | num_tau_prime_samples=32, 291 | num_quantile_samples=32, 292 | quantile_embedding_dim=64, 293 | double_dqn=False, 294 | gamma=0.99, 295 | update_horizon=1, 296 | min_replay_history=20000, 297 | update_period=4, 298 | target_update_period=8000, 299 | epsilon_fn=dqn_agent.linearly_decaying_epsilon, 300 | epsilon_train=0.01, 301 | epsilon_eval=0.001, 302 | epsilon_decay_period=250000, 303 | replay_scheme='prioritized', 304 | optimizer='adam', 305 | summary_writer=None, 306 | summary_writing_frequency=500, 307 | seed=None): 308 | """Initializes the agent and constructs the necessary components. 309 | 310 | Most of this constructor's parameters are IQN-specific hyperparameters whose 311 | values are taken from Dabney et al. (2018). 312 | 313 | Args: 314 | num_actions: int, number of actions the agent can take at any state. 315 | observation_shape: tuple of ints or an int. If single int, the observation 316 | is assumed to be a 2D square. 317 | observation_dtype: DType, specifies the type of the observations. Note 318 | that if your inputs are continuous, you should set this to jnp.float32. 319 | stack_size: int, number of frames to use in state stack. 320 | network: flax.nn Module that is initialized by shape in _create_network 321 | below. See dopamine.jax.networks.JaxImplicitQuantileNetwork as an 322 | example. 323 | kappa: float, Huber loss cutoff. 324 | num_tau_samples: int, number of online quantile samples for loss 325 | estimation. 326 | num_tau_prime_samples: int, number of target quantile samples for loss 327 | estimation. 328 | num_quantile_samples: int, number of quantile samples for computing 329 | Q-values. 330 | quantile_embedding_dim: int, embedding dimension for the quantile input. 331 | double_dqn: boolean, whether to perform double DQN style learning 332 | as described in Van Hasselt et al.: https://arxiv.org/abs/1509.06461. 333 | gamma: float, discount factor with the usual RL meaning. 334 | update_horizon: int, horizon at which updates are performed, the 'n' in 335 | n-step update. 336 | min_replay_history: int, number of transitions that should be experienced 337 | before the agent begins training its value function. 338 | update_period: int, period between DQN updates. 339 | target_update_period: int, update period for the target network. 340 | epsilon_fn: function expecting 4 parameters: 341 | (decay_period, step, warmup_steps, epsilon). This function should return 342 | the epsilon value used for exploration during training. 343 | epsilon_train: float, the value to which the agent's epsilon is eventually 344 | decayed during training. 345 | epsilon_eval: float, epsilon used when evaluating the agent. 346 | epsilon_decay_period: int, length of the epsilon decay schedule. 
347 | replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the 348 | replay memory. 349 | optimizer: str, name of optimizer to use. 350 | summary_writer: SummaryWriter object for outputting training statistics. 351 | Summary writing disabled if set to None. 352 | summary_writing_frequency: int, frequency with which summaries will be 353 | written. Lower values will result in slower training. 354 | """ 355 | 356 | seed = int(time.time() * 1e6) if seed is None else seed 357 | self._net_conf = net_conf 358 | self._env = env 359 | self._hidden_layer = hidden_layer 360 | self._neurons=neurons 361 | self._noisy = noisy 362 | self._dueling = dueling 363 | self._initzer = initzer 364 | 365 | self._tau = tau 366 | self._alpha = alpha 367 | self._clip_value_min = clip_value_min 368 | self._target_opt = target_opt 369 | self._rng = jax.random.PRNGKey(seed) 370 | 371 | self.kappa = kappa 372 | self._replay_scheme = replay_scheme 373 | 374 | # num_tau_samples = N below equation (3) in the paper. 375 | self.num_tau_samples = num_tau_samples 376 | # num_tau_prime_samples = N' below equation (3) in the paper. 377 | self.num_tau_prime_samples = num_tau_prime_samples 378 | # num_quantile_samples = k below equation (3) in the paper. 379 | self.num_quantile_samples = num_quantile_samples 380 | # quantile_embedding_dim = n above equation (4) in the paper. 381 | self.quantile_embedding_dim = quantile_embedding_dim 382 | # option to perform double dqn. 383 | self.double_dqn = double_dqn 384 | 385 | 386 | super(JaxImplicitQuantileAgentNew, self).__init__( 387 | num_actions=num_actions, 388 | observation_shape=observation_shape, 389 | observation_dtype=observation_dtype, 390 | stack_size=stack_size, 391 | network=functools.partial(network, 392 | num_actions=num_actions, 393 | net_conf=self._net_conf, 394 | env=self._env, 395 | hidden_layer=self._hidden_layer, 396 | neurons=self._neurons, 397 | noisy=self._noisy, 398 | dueling=self._dueling, 399 | initzer=self._initzer, 400 | quantile_embedding_dim=quantile_embedding_dim), 401 | gamma=gamma, 402 | update_horizon=update_horizon, 403 | min_replay_history=min_replay_history, 404 | update_period=update_period, 405 | target_update_period=target_update_period, 406 | epsilon_fn=epsilon_fn, 407 | epsilon_train=epsilon_train, 408 | epsilon_eval=epsilon_eval, 409 | epsilon_decay_period=epsilon_decay_period, 410 | optimizer=optimizer, 411 | summary_writer=summary_writer, 412 | summary_writing_frequency=summary_writing_frequency) 413 | 414 | self._num_actions=num_actions 415 | self._replay = self._build_replay_buffer() 416 | 417 | def _build_networks_and_optimizer(self): 418 | self._rng, rng = jax.random.split(self._rng) 419 | online_network_params = self.network_def.init( 420 | rng, x=self.state, num_quantiles=self.num_tau_samples, rng=self._rng) 421 | optimizer_def = dqn_agent.create_optimizer(self._optimizer_name) 422 | self.optimizer = optimizer_def.create(online_network_params) 423 | self.target_network_params = copy.deepcopy(online_network_params) 424 | 425 | def begin_episode(self, observation): 426 | """Returns the agent's first action for this episode. 427 | 428 | Args: 429 | observation: numpy array, the environment's initial observation. 430 | 431 | Returns: 432 | int, the selected action. 
433 | """ 434 | self._reset_state() 435 | self._record_observation(observation) 436 | 437 | if not self.eval_mode: 438 | self._train_step() 439 | 440 | self._rng, self.action = select_action(self.network_def, 441 | self.online_params, 442 | self.state, 443 | self._rng, 444 | self.num_quantile_samples, 445 | self.num_actions, 446 | self.eval_mode, 447 | self.epsilon_eval, 448 | self.epsilon_train, 449 | self.epsilon_decay_period, 450 | self.training_steps, 451 | self.min_replay_history, 452 | self.epsilon_fn, 453 | self._tau, 454 | self.optimizer) 455 | self.action = onp.asarray(self.action) 456 | return self.action 457 | 458 | def step(self, reward, observation): 459 | """Records the most recent transition and returns the agent's next action. 460 | 461 | We store the observation of the last time step since we want to store it 462 | with the reward. 463 | 464 | Args: 465 | reward: float, the reward received from the agent's most recent action. 466 | observation: numpy array, the most recent observation. 467 | 468 | Returns: 469 | int, the selected action. 470 | """ 471 | self._last_observation = self._observation 472 | self._record_observation(observation) 473 | 474 | if not self.eval_mode: 475 | self._store_transition(self._last_observation, self.action, reward, False) 476 | self._train_step() 477 | 478 | self._rng, self.action = select_action(self.network_def, 479 | self.online_params, 480 | self.state, 481 | self._rng, 482 | self.num_quantile_samples, 483 | self.num_actions, 484 | self.eval_mode, 485 | self.epsilon_eval, 486 | self.epsilon_train, 487 | self.epsilon_decay_period, 488 | self.training_steps, 489 | self.min_replay_history, 490 | self.epsilon_fn, 491 | self._tau, 492 | self.optimizer) 493 | self.action = onp.asarray(self.action) 494 | return self.action 495 | 496 | def _build_replay_buffer(self): 497 | """Creates the replay buffer used by the agent.""" 498 | if self._replay_scheme not in ['uniform', 'prioritized']: 499 | raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme)) 500 | # Both replay schemes use the same data structure, but the 'uniform' scheme 501 | # sets all priorities to the same value (which yields uniform sampling). 502 | return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer( 503 | observation_shape=self.observation_shape, 504 | stack_size=self.stack_size, 505 | update_horizon=self.update_horizon, 506 | gamma=self.gamma, 507 | observation_dtype=self.observation_dtype) 508 | 509 | def _train_step(self): 510 | """Runs a single training step. 511 | 512 | Runs training if both: 513 | (1) A minimum number of frames have been added to the replay buffer. 514 | (2) `training_steps` is a multiple of `update_period`. 515 | 516 | Also, syncs weights from online_network to target_network if training steps 517 | is a multiple of target update period. 518 | """ 519 | if self._replay.add_count > self.min_replay_history: 520 | if self.training_steps % self.update_period == 0: 521 | self._sample_from_replay_buffer() 522 | 523 | if self._replay_scheme == 'prioritized': 524 | # The original prioritized experience replay uses a linear exponent 525 | # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 526 | # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) 527 | # suggested a fixed exponent actually performs better, except on Pong. 
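        # Weight the loss by the inverse square root of the sampling probabilities (importance
        # sampling with the fixed 0.5 exponent discussed above), normalized by the maximum weight.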
528 | probs = self.replay_elements['sampling_probabilities'] 529 | loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) 530 | loss_weights /= jnp.max(loss_weights) 531 | else: 532 | loss_weights = jnp.ones(self.replay_elements['state'].shape[0]) 533 | 534 | self._rng, self.optimizer, loss, mean_loss= train( 535 | self.network_def, 536 | self.target_network_params, 537 | self.optimizer, 538 | self.replay_elements['state'], 539 | self.replay_elements['action'], 540 | self.replay_elements['next_state'], 541 | self.replay_elements['reward'], 542 | self.replay_elements['terminal'], 543 | loss_weights, 544 | self._target_opt, 545 | self.num_tau_samples, 546 | self.num_tau_prime_samples, 547 | self.num_quantile_samples, 548 | self.cumulative_gamma, 549 | self.double_dqn, 550 | self.kappa, 551 | self._tau, 552 | self._alpha, 553 | self._clip_value_min, 554 | self._num_actions, 555 | self._rng) 556 | 557 | if self._replay_scheme == 'prioritized': 558 | # Rainbow and prioritized replay are parametrized by an exponent 559 | # alpha, but in both cases it is set to 0.5 - for simplicity's sake we 560 | # leave it as is here, using the more direct sqrt(). Taking the square 561 | # root "makes sense", as we are dealing with a squared loss. Add a 562 | # small nonzero value to the loss to avoid 0 priority items. While 563 | # technically this may be okay, setting all items to 0 priority will 564 | # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. 565 | self._replay.set_priority(self.replay_elements['indices'], 566 | jnp.sqrt(loss + 1e-10)) 567 | 568 | 569 | if (self.summary_writer is not None and 570 | self.training_steps > 0 and 571 | self.training_steps % self.summary_writing_frequency == 0): 572 | summary = tf.compat.v1.Summary(value=[ 573 | tf.compat.v1.Summary.Value(tag='ImplicitLoss', 574 | simple_value=mean_loss)]) 575 | self.summary_writer.add_summary(summary, self.training_steps) 576 | if self.training_steps % self.target_update_period == 0: 577 | self._sync_weights() 578 | 579 | self.training_steps += 1 580 | 581 | 582 | 583 | def _store_transition(self, 584 | last_observation, 585 | action, 586 | reward, 587 | is_terminal, 588 | priority=None): 589 | """Stores a transition when in training mode. 590 | Stores the following tuple in the replay buffer (last_observation, action, 591 | reward, is_terminal, priority). 592 | Args: 593 | last_observation: Last observation, type determined via observation_type 594 | parameter in the replay_memory constructor. 595 | action: An integer, the action taken. 596 | reward: A float, the reward. 597 | is_terminal: Boolean indicating if the current state is a terminal state. 598 | priority: Float. Priority of sampling the transition. If None, the default 599 | priority will be used. If replay scheme is uniform, the default priority 600 | is 1. If the replay scheme is prioritized, the default priority is the 601 | maximum ever seen [Schaul et al., 2015]. 602 | """ 603 | if priority is None: 604 | if self._replay_scheme == 'uniform': 605 | priority = 1. 606 | else: 607 | priority = self._replay.sum_tree.max_recorded_priority 608 | 609 | if not self.eval_mode: 610 | self._replay.add(last_observation, action, reward, is_terminal, priority) -------------------------------------------------------------------------------- /revisiting_rainbow/Agents/quantile_agent_new.py: -------------------------------------------------------------------------------- 1 | """An extension of Rainbow to perform quantile regression. 
2 | 3 | This loss is computed as in "Distributional Reinforcement Learning with Quantile 4 | Regression" - Dabney et al., 2017. 5 | 6 | Specifically, we implement the following components: 7 | 8 | * n-step updates 9 | * prioritized replay 10 | * double_dqn 11 | * noisy 12 | * dueling 13 | 14 | """ 15 | import copy 16 | import time 17 | import functools 18 | from dopamine.jax import networks 19 | from dopamine.jax.agents.dqn import dqn_agent 20 | from dopamine.replay_memory import prioritized_replay_buffer 21 | import gin 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as onp 25 | import tensorflow as tf 26 | 27 | 28 | @functools.partial(jax.vmap, in_axes=(None, None, 0, 0, 0, None)) 29 | def target_distributionDouble(model, target_network, next_states, rewards, terminals, 30 | cumulative_gamma): 31 | """Builds the Quantile target distribution as per Dabney et al. (2017). 32 | 33 | Args: 34 | target_network: Jax Module used for the target network. 35 | next_states: numpy array of batched next states. 36 | rewards: numpy array of batched rewards. 37 | terminals: numpy array of batched terminals. 38 | cumulative_gamma: float, cumulative gamma to use (static_argnum). 39 | 40 | Returns: 41 | The target distribution from the replay. 42 | """ 43 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 44 | # Incorporate terminal state to discount factor. 45 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 46 | 47 | next_state_target_outputs = model(next_states) 48 | q_values = jnp.squeeze(next_state_target_outputs.q_values) 49 | next_qt_argmax = jnp.argmax(q_values) 50 | 51 | next_dist = target_network(next_states) 52 | logits = jnp.squeeze(next_dist.logits) 53 | next_logits = logits[next_qt_argmax] 54 | return jax.lax.stop_gradient(rewards + gamma_with_terminal * next_logits) 55 | 56 | 57 | @functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, None)) 58 | def target_distribution(target_network, next_states, rewards, terminals, 59 | cumulative_gamma): 60 | """Builds the Quantile target distribution as per Dabney et al. (2017). 61 | 62 | Args: 63 | target_network: Jax Module used for the target network. 64 | next_states: numpy array of batched next states. 65 | rewards: numpy array of batched rewards. 66 | terminals: numpy array of batched terminals. 67 | cumulative_gamma: float, cumulative gamma to use (static_argnum). 68 | 69 | Returns: 70 | The target distribution from the replay. 71 | """ 72 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 73 | # Incorporate terminal state to discount factor.
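# Worked example of the target built below (illustrative numbers, not from the
# original file): for a non-terminal transition with reward = 1.0,
# cumulative_gamma = 0.99 and next-state quantile estimates
# next_logits = [0.0, 2.0], the returned target is
# 1.0 + 0.99 * [0.0, 2.0] = [1.0, 2.98], i.e. one Bellman target per quantile.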
74 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 75 | next_state_target_outputs = target_network(next_states) 76 | q_values = jnp.squeeze(next_state_target_outputs.q_values) 77 | next_qt_argmax = jnp.argmax(q_values) 78 | logits = jnp.squeeze(next_state_target_outputs.logits) 79 | next_logits = logits[next_qt_argmax] 80 | return jax.lax.stop_gradient(rewards + gamma_with_terminal * next_logits) 81 | 82 | 83 | @functools.partial(jax.jit, static_argnums=(0, 9, 10, 11, 12)) 84 | def train(network_def, target_params, optimizer, states, actions, next_states, rewards, 85 | terminals, loss_weights, kappa, num_atoms, cumulative_gamma, double_dqn, rng): 86 | """Run a training step.""" 87 | 88 | online_params = optimizer.target 89 | def loss_fn(params,rng_input, target, loss_multipliers): 90 | def q_online(state): 91 | return network_def.apply(params, state, rng=rng_input) 92 | 93 | logits = jax.vmap(q_online)(states).logits 94 | logits = jnp.squeeze(logits) 95 | # Fetch the logits for its selected action. We use vmap to perform this 96 | # indexing across the batch. 97 | chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions) 98 | bellman_errors = (target[:, None, :] - 99 | chosen_action_logits[:, :, None]) # Input `u' of Eq. 9. 100 | # Eq. 9 of paper. 101 | huber_loss = ( 102 | (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) * 103 | 0.5 * bellman_errors ** 2 + 104 | (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) * 105 | kappa * (jnp.abs(bellman_errors) - 0.5 * kappa)) 106 | 107 | tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / 108 | num_atoms) # Quantile midpoints. See Lemma 2 of paper. 109 | # Eq. 10 of paper. 110 | tau_bellman_diff = jnp.abs( 111 | tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32)) 112 | quantile_huber_loss = tau_bellman_diff * huber_loss 113 | # Sum over tau dimension, average over target value dimension. 
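# Descriptive note on the reduction below: each entry of quantile_huber_loss is
# |tau_hat_i - 1{u < 0}| * L_kappa(u) for the corresponding Bellman error u, so
# taking the mean over axis 2 (the target-quantile dimension) and the sum over
# axis 1 (the estimated-quantile dimension) yields one scalar loss per
# transition in the batch.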
114 | loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1) 115 | 116 | mean_loss = jnp.mean(loss_multipliers * loss) 117 | return mean_loss, loss 118 | 119 | rng, rng2, rng3, rng4 = jax.random.split(rng, 4) 120 | def q_target(state): 121 | return network_def.apply(target_params, state, rng=rng2) 122 | 123 | def q_target_online(state): 124 | return network_def.apply(online_params, state, rng=rng4) 125 | 126 | if double_dqn: 127 | target = target_distributionDouble(q_target_online, q_target, next_states, rewards, terminals, cumulative_gamma) 128 | else: 129 | target = target_distribution(q_target, next_states, rewards, terminals, cumulative_gamma) 130 | 131 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 132 | (mean_loss, loss), grad = grad_fn(online_params, rng3, target, loss_weights) 133 | optimizer = optimizer.apply_gradient(grad) 134 | return optimizer, loss, mean_loss 135 | 136 | @functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11)) 137 | def select_action(network_def, params, state, rng, num_actions, eval_mode, 138 | epsilon_eval, epsilon_train, epsilon_decay_period, 139 | training_steps, min_replay_history, epsilon_fn): 140 | 141 | epsilon = jnp.where(eval_mode, 142 | epsilon_eval, 143 | epsilon_fn(epsilon_decay_period, 144 | training_steps, 145 | min_replay_history, 146 | epsilon_train)) 147 | 148 | rng, rng1, rng2, rng3 = jax.random.split(rng, num=4) 149 | selected_action = jnp.argmax(network_def.apply(params, state, rng=rng3).q_values) 150 | p = jax.random.uniform(rng1) 151 | return rng, jnp.where(p <= epsilon, 152 | jax.random.randint(rng2, (), 0, num_actions), 153 | selected_action) 154 | 155 | @gin.configurable 156 | class JaxQuantileAgentNew(dqn_agent.JaxDQNAgent): 157 | """An implementation of Quantile regression DQN agent.""" 158 | 159 | def __init__(self, 160 | num_actions, 161 | 162 | kappa=1.0, 163 | num_atoms=200, 164 | noisy = False, 165 | dueling = False, 166 | initzer = 'variance_scaling', 167 | net_conf = None, 168 | env = "CartPole", 169 | normalize_obs = True, 170 | hidden_layer=2, 171 | neurons=512, 172 | double_dqn=False, 173 | replay_scheme='prioritized', 174 | optimizer='adam', 175 | network=networks.QuantileNetwork, 176 | epsilon_fn=dqn_agent.linearly_decaying_epsilon, 177 | seed=None): 178 | """Initializes the agent and constructs the Graph. 179 | 180 | Args: 181 | num_actions: Int, number of actions the agent can take at any state. 182 | observation_shape: tuple of ints or an int. If single int, the observation 183 | is assumed to be a 2D square. 184 | observation_dtype: DType, specifies the type of the observations. Note 185 | that if your inputs are continuous, you should set this to jnp.float32. 186 | stack_size: int, number of frames to use in state stack. 187 | network: tf.Keras.Model, expects 3 parameters: num_actions, num_atoms, 188 | network_type. A call to this object will return an instantiation of the 189 | network provided. The network returned can be run with different inputs 190 | to create different outputs. See 191 | dopamine.discrete_domains.jax.networks.QuantileNetwork as an example. 192 | kappa: Float, Huber loss cutoff. 193 | num_atoms: Int, the number of buckets for the value function distribution. 194 | gamma: Float, exponential decay factor as commonly used in the RL 195 | literature. 196 | update_horizon: Int, horizon at which updates are performed, the 'n' in 197 | n-step update. 198 | min_replay_history: Int, number of stored transitions for training to 199 | start. 
200 | update_period: Int, period between DQN updates. 201 | target_update_period: Int, update period for the target network. 202 | epsilon_fn: Function expecting 4 parameters: (decay_period, step, 203 | warmup_steps, epsilon), and which returns the epsilon value used for 204 | exploration during training. 205 | epsilon_train: Float, final epsilon for training. 206 | epsilon_eval: Float, epsilon during evaluation. 207 | epsilon_decay_period: Int, number of steps for epsilon to decay. 208 | replay_scheme: String, replay memory scheme to be used. Choices are: 209 | uniform - Standard (DQN) replay buffer (Mnih et al., 2015) 210 | prioritized - Prioritized replay buffer (Schaul et al., 2015) 211 | optimizer: str, name of optimizer to use. 212 | summary_writer: SummaryWriter object for outputting training statistics. 213 | Summary writing disabled if set to None. 214 | summary_writing_frequency: int, frequency with which summaries will be 215 | written. Lower values will result in slower training. 216 | allow_partial_reload: bool, whether we allow reloading a partial agent 217 | (for instance, only the network parameters). 218 | """ 219 | seed = int(time.time() * 1e6) if seed is None else seed 220 | self._num_atoms = num_atoms 221 | self._kappa = kappa 222 | self._replay_scheme = replay_scheme 223 | self._double_dqn = double_dqn 224 | self._net_conf = net_conf 225 | self._env = env 226 | self._normalize_obs = normalize_obs 227 | self._hidden_layer = hidden_layer 228 | self._neurons = neurons 229 | self._noisy = noisy 230 | self._dueling = dueling 231 | self._initzer = initzer 232 | self._rng = jax.random.PRNGKey(seed) 233 | 234 | super(JaxQuantileAgentNew, self).__init__( 235 | num_actions=num_actions, 236 | optimizer=optimizer, 237 | epsilon_fn=dqn_agent.identity_epsilon if self._noisy else epsilon_fn, 238 | network=functools.partial(network, num_atoms=self._num_atoms, net_conf=self._net_conf, 239 | env=self._env, 240 | normalize_obs=self._normalize_obs, 241 | hidden_layer=self._hidden_layer, 242 | neurons=self._neurons, 243 | noisy=self._noisy, 244 | dueling=self._dueling, 245 | initzer=self._initzer)) 246 | 247 | def _build_networks_and_optimizer(self): 248 | self._rng, rng = jax.random.split(self._rng) 249 | online_network_params = self.network_def.init(rng, x=self.state, rng=self._rng) 250 | optimizer_def = dqn_agent.create_optimizer(self._optimizer_name) 251 | self.optimizer = optimizer_def.create(online_network_params) 252 | self.target_network_params = copy.deepcopy(online_network_params) 253 | 254 | 255 | def _build_replay_buffer(self): 256 | """Creates the replay buffer used by the agent.""" 257 | if self._replay_scheme not in ['uniform', 'prioritized']: 258 | raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme)) 259 | # Both replay schemes use the same data structure, but the 'uniform' scheme 260 | # sets all priorities to the same value (which yields uniform sampling).
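# Descriptive note: both schemes share the prioritized buffer below; 'uniform'
# simply stores every transition with priority 1.0 (see _store_transition), so
# sampling degenerates to uniform, while 'prioritized' stores new transitions
# at the maximum recorded priority and later re-prioritizes them from the
# training loss.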
261 | return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer( 262 | observation_shape=self.observation_shape, 263 | stack_size=self.stack_size, 264 | update_horizon=self.update_horizon, 265 | gamma=self.gamma, 266 | observation_dtype=self.observation_dtype) 267 | 268 | def begin_episode(self, observation): 269 | self._reset_state() 270 | self._record_observation(observation) 271 | 272 | if not self.eval_mode: 273 | self._train_step() 274 | 275 | self._rng, self.action = select_action(self.network_def, 276 | self.online_params, 277 | self.state, 278 | self._rng, 279 | self.num_actions, 280 | self.eval_mode, 281 | self.epsilon_eval, 282 | self.epsilon_train, 283 | self.epsilon_decay_period, 284 | self.training_steps, 285 | self.min_replay_history, 286 | self.epsilon_fn) 287 | self.action = onp.asarray(self.action) 288 | return self.action 289 | 290 | def step(self, reward, observation): 291 | self._last_observation = self._observation 292 | self._record_observation(observation) 293 | 294 | if not self.eval_mode: 295 | self._store_transition(self._last_observation, self.action, reward, False) 296 | self._train_step() 297 | 298 | self._rng, self.action = select_action(self.network_def, 299 | self.online_params, 300 | self.state, 301 | self._rng, 302 | self.num_actions, 303 | self.eval_mode, 304 | self.epsilon_eval, 305 | self.epsilon_train, 306 | self.epsilon_decay_period, 307 | self.training_steps, 308 | self.min_replay_history, 309 | self.epsilon_fn) 310 | self.action = onp.asarray(self.action) 311 | return self.action 312 | 313 | 314 | def _train_step(self): 315 | """Runs a single training step. 316 | 317 | Runs training if both: 318 | (1) A minimum number of frames have been added to the replay buffer. 319 | (2) `training_steps` is a multiple of `update_period`. 320 | 321 | Also, syncs weights from online_network to target_network if training steps 322 | is a multiple of target update period. 323 | """ 324 | if self._replay.add_count > self.min_replay_history: 325 | if self.training_steps % self.update_period == 0: 326 | self._sample_from_replay_buffer() 327 | 328 | if self._replay_scheme == 'prioritized': 329 | # The original prioritized experience replay uses a linear exponent 330 | # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 331 | # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) 332 | # suggested a fixed exponent actually performs better, except on Pong. 333 | probs = self.replay_elements['sampling_probabilities'] 334 | loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) 335 | loss_weights /= jnp.max(loss_weights) 336 | else: 337 | loss_weights = jnp.ones(self.replay_elements['state'].shape[0]) 338 | 339 | self.optimizer, loss, mean_loss = train( 340 | self.network_def, 341 | self.target_network_params, 342 | self.optimizer, 343 | self.replay_elements['state'], 344 | self.replay_elements['action'], 345 | self.replay_elements['next_state'], 346 | self.replay_elements['reward'], 347 | self.replay_elements['terminal'], 348 | loss_weights, 349 | self._kappa, 350 | self._num_atoms, 351 | self.cumulative_gamma, 352 | self._double_dqn, 353 | self._rng) 354 | 355 | if self._replay_scheme == 'prioritized': 356 | # Rainbow and prioritized replay are parametrized by an exponent 357 | # alpha, but in both cases it is set to 0.5 - for simplicity's sake we 358 | # leave it as is here, using the more direct sqrt(). Taking the square 359 | # root "makes sense", as we are dealing with a squared loss. 
Add a 360 | # small nonzero value to the loss to avoid 0 priority items. While 361 | # technically this may be okay, setting all items to 0 priority will 362 | # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. 363 | self._replay.set_priority(self.replay_elements['indices'], 364 | jnp.sqrt(loss + 1e-10)) 365 | 366 | if self.summary_writer is not None: 367 | summary = tf.compat.v1.Summary(value=[ 368 | tf.compat.v1.Summary.Value(tag='QuantileLoss', 369 | simple_value=mean_loss)]) 370 | self.summary_writer.add_summary(summary, self.training_steps) 371 | if self.training_steps % self.target_update_period == 0: 372 | self._sync_weights() 373 | 374 | self.training_steps += 1 375 | 376 | def _store_transition(self, 377 | last_observation, 378 | action, 379 | reward, 380 | is_terminal, 381 | priority=None): 382 | """Stores a transition when in training mode. 383 | 384 | Stores the following tuple in the replay buffer (last_observation, action, 385 | reward, is_terminal, priority). 386 | 387 | Args: 388 | last_observation: Last observation, type determined via observation_type 389 | parameter in the replay_memory constructor. 390 | action: An integer, the action taken. 391 | reward: A float, the reward. 392 | is_terminal: Boolean indicating if the current state is a terminal state. 393 | priority: Float. Priority of sampling the transition. If None, the default 394 | priority will be used. If replay scheme is uniform, the default priority 395 | is 1. If the replay scheme is prioritized, the default priority is the 396 | maximum ever seen [Schaul et al., 2015]. 397 | """ 398 | if priority is None: 399 | if self._replay_scheme == 'uniform': 400 | priority = 1. 401 | else: 402 | priority = self._replay.sum_tree.max_recorded_priority 403 | 404 | if not self.eval_mode: 405 | self._replay.add(last_observation, action, reward, is_terminal, priority) -------------------------------------------------------------------------------- /revisiting_rainbow/Agents/rainbow_agent_new.py: -------------------------------------------------------------------------------- 1 | """Compact implementation of a simplified Rainbow agent in Jax. 2 | 3 | Specifically, we implement the following components from Rainbow: 4 | 5 | * n-step updates 6 | * prioritized replay 7 | * distributional RL 8 | * double_dqn 9 | * noisy 10 | * dueling 11 | 12 | Details in "Rainbow: Combining Improvements in Deep Reinforcement Learning" by 13 | Hessel et al. (2018). 14 | """ 15 | import copy 16 | import time 17 | import functools 18 | from dopamine.jax import networks 19 | from dopamine.jax.agents.dqn import dqn_agent 20 | from dopamine.replay_memory import prioritized_replay_buffer 21 | from flax import nn 22 | import gin 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as onp 26 | import tensorflow as tf 27 | 28 | 29 | @functools.partial(jax.jit, static_argnums=(0, 10, 11)) 30 | def train(network_def, target_params, optimizer, states, actions, next_states, rewards, 31 | terminals, loss_weights, support, cumulative_gamma, double_dqn, rng): 32 | """Run a training step.""" 33 | online_params = optimizer.target 34 | def loss_fn(params, rng_input, target, loss_multipliers): 35 | def q_online(state): 36 | return network_def.apply(params, state, support, rng=rng_input) 37 | 38 | logits = jax.vmap(q_online)(states).logits 39 | # Fetch the logits for its selected action. We use vmap to perform this 40 | # indexing across the batch. 
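# Shape sketch for the indexing below (assuming the usual Dopamine
# RainbowNetwork outputs): logits is (batch_size, num_actions, num_atoms) after
# the vmap above and actions is (batch_size,), so vmapping x[y] over the batch
# returns the (batch_size, num_atoms) logits of the actions actually taken.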
41 | chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions) 42 | loss = jax.vmap(networks.softmax_cross_entropy_loss_with_logits)(target, chosen_action_logits) 43 | 44 | mean_loss = jnp.mean(loss_multipliers * loss) 45 | return mean_loss, loss 46 | 47 | rng, rng2, rng3, rng4 = jax.random.split(rng, 4) 48 | # Use the weighted mean loss for gradient computation 49 | def q_target(state): 50 | return network_def.apply(target_params, state, support, rng=rng2) 51 | 52 | def q_target_online(state): 53 | return network_def.apply(online_params, state, support, rng=rng4) 54 | 55 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 56 | 57 | if double_dqn: 58 | target = target_distributionDouble(q_target_online, q_target, next_states, rewards, terminals, support, cumulative_gamma) 59 | else: 60 | target = target_distribution(q_target, next_states, rewards, terminals, support, cumulative_gamma) 61 | 62 | # Get the unweighted loss without taking its mean for updating priorities. 63 | (mean_loss, loss), grad = grad_fn(online_params, rng3, target, loss_weights) 64 | optimizer = optimizer.apply_gradient(grad) 65 | return optimizer, loss, mean_loss 66 | 67 | 68 | @functools.partial(jax.vmap, in_axes=(None, None, 0, 0, 0, None, None)) 69 | def target_distributionDouble(model, target_network, next_states, rewards, terminals, 70 | support, cumulative_gamma): 71 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 72 | # Incorporate terminal state to discount factor. 73 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 74 | target_support = rewards + gamma_with_terminal * support 75 | 76 | next_state_target_outputs = model(next_states) 77 | q_values = jnp.squeeze(next_state_target_outputs.q_values) 78 | next_qt_argmax = jnp.argmax(q_values) 79 | 80 | next_dist = target_network(next_states) 81 | probabilities = jnp.squeeze(next_dist.probabilities) 82 | next_probabilities = probabilities[next_qt_argmax] 83 | 84 | return jax.lax.stop_gradient(project_distribution(target_support, next_probabilities, support)) 85 | 86 | @functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, None, None)) 87 | def target_distribution(target_network, next_states, rewards, terminals, 88 | support, cumulative_gamma): 89 | is_terminal_multiplier = 1. - terminals.astype(jnp.float32) 90 | # Incorporate terminal state to discount factor. 
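# Descriptive note: the lines below form the shifted support
# Tz = rewards + gamma * (1 - terminal) * support, pick the greedy action from
# the target network's expected Q-values, and project that action's probability
# mass back onto the fixed support with project_distribution (Eq. 7 of
# Bellemare et al., 2017).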
91 | gamma_with_terminal = cumulative_gamma * is_terminal_multiplier 92 | target_support = rewards + gamma_with_terminal * support 93 | 94 | next_state_target_outputs = target_network(next_states) 95 | q_values = jnp.squeeze(next_state_target_outputs.q_values) 96 | next_qt_argmax = jnp.argmax(q_values) 97 | 98 | probabilities = jnp.squeeze(next_state_target_outputs.probabilities) 99 | next_probabilities = probabilities[next_qt_argmax] 100 | 101 | return jax.lax.stop_gradient(project_distribution(target_support, next_probabilities, support)) 102 | 103 | @functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11)) 104 | def select_action(network_def, params, state, rng, num_actions, eval_mode, 105 | epsilon_eval, epsilon_train, epsilon_decay_period, 106 | training_steps, min_replay_history, epsilon_fn, support): 107 | epsilon = jnp.where(eval_mode, 108 | epsilon_eval, 109 | epsilon_fn(epsilon_decay_period, 110 | training_steps, 111 | min_replay_history, 112 | epsilon_train)) 113 | 114 | rng, rng1, rng2, rng3 = jax.random.split(rng, num=4) 115 | p = jax.random.uniform(rng1) 116 | return rng, jnp.where( 117 | p <= epsilon, 118 | jax.random.randint(rng2, (), 0, num_actions), 119 | jnp.argmax(network_def.apply(params, state, support, rng=rng3).q_values)) 120 | 121 | @gin.configurable 122 | class JaxRainbowAgentNew(dqn_agent.JaxDQNAgent): 123 | """A compact implementation of a simplified Rainbow agent.""" 124 | 125 | def __init__(self, 126 | num_actions, 127 | 128 | noisy = False, 129 | dueling = False, 130 | initzer = 'variance_scaling', 131 | net_conf = None, 132 | env = "CartPole", 133 | normalize_obs = True, 134 | hidden_layer=2, 135 | neurons=512, 136 | num_atoms=51, 137 | vmax=10., 138 | double_dqn=False, 139 | replay_scheme='prioritized', 140 | optimizer='adam', 141 | network=networks.RainbowNetwork, 142 | epsilon_fn=dqn_agent.linearly_decaying_epsilon, 143 | seed=None): 144 | """Initializes the agent and constructs the necessary components. 145 | 146 | Args: 147 | num_actions: int, number of actions the agent can take at any state. 148 | observation_shape: tuple of ints or an int. If single int, the observation 149 | is assumed to be a 2D square. 150 | observation_dtype: DType, specifies the type of the observations. Note 151 | that if your inputs are continuous, you should set this to jnp.float32. 152 | stack_size: int, number of frames to use in state stack. 153 | network: flax.nn Module that is initialized by shape in _create_network 154 | below. See dopamine.jax.networks.RainbowNetwork as an example. 155 | num_atoms: int, the number of buckets of the value function distribution. 156 | vmax: float, the value distribution support is [-vmax, vmax]. 157 | gamma: float, discount factor with the usual RL meaning. 158 | update_horizon: int, horizon at which updates are performed, the 'n' in 159 | n-step update. 160 | min_replay_history: int, number of transitions that should be experienced 161 | before the agent begins training its value function. 162 | update_period: int, period between DQN updates. 163 | target_update_period: int, update period for the target network. 164 | epsilon_fn: function expecting 4 parameters: 165 | (decay_period, step, warmup_steps, epsilon). This function should return 166 | the epsilon value used for exploration during training. 167 | epsilon_train: float, the value to which the agent's epsilon is eventually 168 | decayed during training. 169 | epsilon_eval: float, epsilon used when evaluating the agent. 
170 | epsilon_decay_period: int, length of the epsilon decay schedule. 171 | replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the 172 | replay memory. 173 | optimizer: str, name of optimizer to use. 174 | summary_writer: SummaryWriter object for outputting training statistics. 175 | Summary writing disabled if set to None. 176 | summary_writing_frequency: int, frequency with which summaries will be 177 | written. Lower values will result in slower training. 178 | allow_partial_reload: bool, whether we allow reloading a partial agent 179 | (for instance, only the network parameters). 180 | """ 181 | # We need this because some tools convert round floats into ints. 182 | 183 | vmax = float(vmax) 184 | seed = int(time.time() * 1e6) if seed is None else seed 185 | self._num_atoms = num_atoms 186 | self._support = jnp.linspace(-vmax, vmax, num_atoms) 187 | self._replay_scheme = replay_scheme 188 | self._double_dqn = double_dqn 189 | self._net_conf = net_conf 190 | self._env = env 191 | self._normalize_obs = normalize_obs 192 | self._hidden_layer= hidden_layer 193 | self._neurons=neurons 194 | self._noisy = noisy 195 | self._dueling = dueling 196 | self._initzer = initzer 197 | self._rng = jax.random.PRNGKey(seed) 198 | 199 | super(JaxRainbowAgentNew, self).__init__( 200 | num_actions=num_actions, 201 | network=functools.partial(network, 202 | num_atoms=num_atoms, 203 | net_conf=self._net_conf, 204 | env=self._env, 205 | normalize_obs=self._normalize_obs, 206 | hidden_layer=self._hidden_layer, 207 | neurons=self._neurons, 208 | noisy=self._noisy, 209 | dueling=self._dueling, 210 | initzer=self._initzer), 211 | 212 | epsilon_fn = dqn_agent.identity_epsilon if self._noisy == True else epsilon_fn, 213 | optimizer=optimizer) 214 | 215 | def _build_networks_and_optimizer(self): 216 | self._rng, rng = jax.random.split(self._rng) 217 | online_network_params = self.network_def.init(rng, x=self.state, support=self._support, rng=self._rng) 218 | optimizer_def = dqn_agent.create_optimizer(self._optimizer_name) 219 | self.optimizer = optimizer_def.create(online_network_params) 220 | self.target_network_params = copy.deepcopy(online_network_params) 221 | 222 | 223 | def _build_replay_buffer(self): 224 | """Creates the replay buffer used by the agent.""" 225 | if self._replay_scheme not in ['uniform', 'prioritized']: 226 | raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme)) 227 | # Both replay schemes use the same data structure, but the 'uniform' scheme 228 | # sets all priorities to the same value (which yields uniform sampling). 229 | return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer( 230 | observation_shape=self.observation_shape, 231 | stack_size=self.stack_size, 232 | update_horizon=self.update_horizon, 233 | gamma=self.gamma, 234 | observation_dtype=self.observation_dtype) 235 | 236 | def begin_episode(self, observation): 237 | self._reset_state() 238 | self._record_observation(observation) 239 | 240 | if not self.eval_mode: 241 | self._train_step() 242 | 243 | self._rng, self.action = select_action(self.network_def, 244 | self.online_params, 245 | self.state, 246 | self._rng, 247 | self.num_actions, 248 | self.eval_mode, 249 | self.epsilon_eval, 250 | self.epsilon_train, 251 | self.epsilon_decay_period, 252 | self.training_steps, 253 | self.min_replay_history, 254 | self.epsilon_fn, 255 | self._support) 256 | # TODO(psc): Why a numpy array? Why not an int? 
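# A plausible answer to the TODO above (note only, no behavioural change):
# select_action returns a JAX array living on the accelerator, and
# onp.asarray converts it to a host-side numpy value so the Dopamine runner
# and the replay buffer receive an ordinary numpy scalar.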
257 | self.action = onp.asarray(self.action) 258 | return self.action 259 | 260 | def step(self, reward, observation): 261 | self._last_observation = self._observation 262 | self._record_observation(observation) 263 | 264 | if not self.eval_mode: 265 | self._store_transition(self._last_observation, self.action, reward, False) 266 | self._train_step() 267 | 268 | self._rng, self.action = select_action(self.network_def, 269 | self.online_params, 270 | self.state, 271 | self._rng, 272 | self.num_actions, 273 | self.eval_mode, 274 | self.epsilon_eval, 275 | self.epsilon_train, 276 | self.epsilon_decay_period, 277 | self.training_steps, 278 | self.min_replay_history, 279 | self.epsilon_fn, 280 | self._support) 281 | self.action = onp.asarray(self.action) 282 | return self.action 283 | 284 | def _train_step(self): 285 | """Runs a single training step. 286 | 287 | Runs training if both: 288 | (1) A minimum number of frames have been added to the replay buffer. 289 | (2) `training_steps` is a multiple of `update_period`. 290 | 291 | Also, syncs weights from online_network to target_network if training steps 292 | is a multiple of target update period. 293 | """ 294 | if self._replay.add_count > self.min_replay_history: 295 | if self.training_steps % self.update_period == 0: 296 | self._sample_from_replay_buffer() 297 | 298 | if self._replay_scheme == 'prioritized': 299 | # The original prioritized experience replay uses a linear exponent 300 | # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 301 | # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) 302 | # suggested a fixed exponent actually performs better, except on Pong. 303 | probs = self.replay_elements['sampling_probabilities'] 304 | # Weight the loss by the inverse priorities. 305 | loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) 306 | loss_weights /= jnp.max(loss_weights) 307 | else: 308 | loss_weights = jnp.ones(self.replay_elements['state'].shape[0]) 309 | 310 | self.optimizer, loss, mean_loss = train( 311 | self.network_def, 312 | self.target_network_params, 313 | self.optimizer, 314 | self.replay_elements['state'], 315 | self.replay_elements['action'], 316 | self.replay_elements['next_state'], 317 | self.replay_elements['reward'], 318 | self.replay_elements['terminal'], 319 | loss_weights, 320 | self._support, 321 | self.cumulative_gamma, 322 | self._double_dqn, 323 | self._rng) 324 | 325 | if self._replay_scheme == 'prioritized': 326 | # Rainbow and prioritized replay are parametrized by an exponent 327 | # alpha, but in both cases it is set to 0.5 - for simplicity's sake we 328 | # leave it as is here, using the more direct sqrt(). Taking the square 329 | # root "makes sense", as we are dealing with a squared loss. Add a 330 | # small nonzero value to the loss to avoid 0 priority items. While 331 | # technically this may be okay, setting all items to 0 priority will 332 | # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. 
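# Worked example of the update below (illustrative numbers): per-transition
# losses of [0.0, 4.0] become priorities of roughly [1e-5, 2.0] after
# sqrt(loss + 1e-10), so even zero-loss transitions keep a small nonzero
# chance of being sampled again.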
333 | self._replay.set_priority(self.replay_elements['indices'], 334 | jnp.sqrt(loss + 1e-10)) 335 | 336 | if self.summary_writer is not None: 337 | summary = tf.compat.v1.Summary(value=[ 338 | tf.compat.v1.Summary.Value(tag='CrossEntropyLoss', 339 | simple_value=mean_loss)]) 340 | self.summary_writer.add_summary(summary, self.training_steps) 341 | if self.training_steps % self.target_update_period == 0: 342 | self._sync_weights() 343 | 344 | self.training_steps += 1 345 | 346 | def _store_transition(self, 347 | last_observation, 348 | action, 349 | reward, 350 | is_terminal, 351 | priority=None): 352 | """Stores a transition when in training mode. 353 | 354 | Stores the following tuple in the replay buffer (last_observation, action, 355 | reward, is_terminal, priority). 356 | 357 | Args: 358 | last_observation: Last observation, type determined via observation_type 359 | parameter in the replay_memory constructor. 360 | action: An integer, the action taken. 361 | reward: A float, the reward. 362 | is_terminal: Boolean indicating if the current state is a terminal state. 363 | priority: Float. Priority of sampling the transition. If None, the default 364 | priority will be used. If replay scheme is uniform, the default priority 365 | is 1. If the replay scheme is prioritized, the default priority is the 366 | maximum ever seen [Schaul et al., 2015]. 367 | """ 368 | if priority is None: 369 | if self._replay_scheme == 'uniform': 370 | priority = 1. 371 | else: 372 | priority = self._replay.sum_tree.max_recorded_priority 373 | 374 | if not self.eval_mode: 375 | self._replay.add(last_observation, action, reward, is_terminal, priority) 376 | 377 | 378 | def project_distribution(supports, weights, target_support): 379 | """Projects a batch of (support, weights) onto target_support. 380 | 381 | Based on equation (7) in (Bellemare et al., 2017): 382 | https://arxiv.org/abs/1707.06887 383 | In the rest of the comments we will refer to this equation simply as Eq7. 384 | 385 | Args: 386 | supports: Jax array of shape (num_dims) defining supports for 387 | the distribution. 388 | weights: Jax array of shape (num_dims) defining weights on the 389 | original support points. Although for the CategoricalDQN agent these 390 | weights are probabilities, it is not required that they are. 391 | target_support: Jax array of shape (num_dims) defining support of the 392 | projected distribution. The values must be monotonically increasing. Vmin 393 | and Vmax will be inferred from the first and last elements of this Jax 394 | array, respectively. The values in this Jax array must be equally spaced. 395 | 396 | Returns: 397 | A Jax array of shape (num_dims) with the projection of a batch 398 | of (support, weights) onto target_support. 399 | 400 | Raises: 401 | ValueError: If target_support has no dimensions, or if shapes of supports, 402 | weights, and target_support are incompatible. 403 | """ 404 | v_min, v_max = target_support[0], target_support[-1] 405 | # `N` in Eq7. 406 | num_dims = target_support.shape[0] 407 | # delta_z = `\Delta z` in Eq7. 408 | delta_z = (v_max - v_min) / (num_dims - 1) 409 | # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7. 410 | clipped_support = jnp.clip(supports, v_min, v_max) 411 | # numerator = `|clipped_support - z_i|` in Eq7. 412 | numerator = jnp.abs(clipped_support - target_support[:, None]) 413 | quotient = 1 - (numerator / delta_z) 414 | # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7. 
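# Worked example of the projection (illustrative numbers): with
# target_support = [-1, 0, 1] (delta_z = 1) and a single shifted atom at 0.6
# carrying weight 1.0, the clipped quotient below is [0, 0.4, 0.6], so the
# atom's mass is split linearly between its neighbours 0 and 1.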
415 | clipped_quotient = jnp.clip(quotient, 0, 1) 416 | # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))` in Eq7. 417 | inner_prod = clipped_quotient * weights 418 | return jnp.squeeze(jnp.sum(inner_prod, -1)) -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_acrobot.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.ACROBOT_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.ACROBOT_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.ACROBOT_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxDQNAgentNew.optimizer = 'adam' 22 | JaxDQNAgentNew.net_conf = 'classic' 23 | JaxDQNAgentNew.env = 'Acrobot' 24 | JaxDQNAgentNew.normalize_obs = True 25 | JaxDQNAgentNew.hidden_layer = 2 26 | JaxDQNAgentNew.neurons = 512 27 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 28 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 29 | JaxDQNAgentNew.mse_inf = True 30 | JaxDQNAgentNew.noisy = False 31 | JaxDQNAgentNew.dueling = False 32 | JaxDQNAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 38 | JaxDQNAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | JaxDQNAgentNew.tau = 100 40 | JaxDQNAgentNew.alpha = 1 41 | JaxDQNAgentNew.clip_value_min = -1e3 42 | 43 | create_optimizer = @dqn_agent.create_optimizer 44 | create_optimizer.learning_rate = 0.001 45 | create_optimizer.eps = 3.125e-4 46 | 47 | create_gym_environment.environment_name = 'Acrobot' 48 | create_gym_environment.version = 'v1' 49 | 50 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 51 | Runner.num_iterations = 30 52 | Runner.training_steps = 1000 53 | Runner.max_steps_per_episode = 500 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_asterix.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.ASTERIX_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 1#3 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxDQNAgentNew.optimizer = 'adam' 23 | 
JaxDQNAgentNew.net_conf = 'minatar' 24 | JaxDQNAgentNew.env = None 25 | JaxDQNAgentNew.normalize_obs = False 26 | JaxDQNAgentNew.hidden_layer = 0 27 | JaxDQNAgentNew.neurons = None 28 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 29 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 30 | JaxDQNAgentNew.mse_inf = False 31 | JaxDQNAgentNew.noisy = False 32 | JaxDQNAgentNew.dueling = False 33 | JaxDQNAgentNew.initzer = @variance_scaling() 34 | variance_scaling.scale=1 35 | variance_scaling.mode='fan_avg' 36 | variance_scaling.distribution='uniform' 37 | 38 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 39 | JaxDQNAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | JaxDQNAgentNew.tau = 0.03 41 | JaxDQNAgentNew.alpha = 0.9 42 | JaxDQNAgentNew.clip_value_min = -1 43 | 44 | create_optimizer = @dqn_agent.create_optimizer 45 | create_optimizer.learning_rate = 0.00025#0.001 46 | create_optimizer.eps = 3.125e-4#0.01 47 | 48 | create_minatar_env.game_name = 'asterix' 49 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 50 | 51 | Runner.num_iterations = 10 52 | Runner.training_steps = 1000000 53 | Runner.max_steps_per_episode = 100000000 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_breakout.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.BREAKOUT_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 1#3 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxDQNAgentNew.optimizer = 'adam' 23 | JaxDQNAgentNew.net_conf = 'minatar' 24 | JaxDQNAgentNew.env = None 25 | JaxDQNAgentNew.normalize_obs = False 26 | JaxDQNAgentNew.hidden_layer = 0 27 | JaxDQNAgentNew.neurons = None 28 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 29 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 30 | JaxDQNAgentNew.mse_inf = True 31 | JaxDQNAgentNew.noisy = False 32 | JaxDQNAgentNew.dueling = False 33 | JaxDQNAgentNew.initzer = @variance_scaling() 34 | variance_scaling.scale=1 35 | variance_scaling.mode='fan_avg' 36 | variance_scaling.distribution='uniform' 37 | 38 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 39 | JaxDQNAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | JaxDQNAgentNew.tau = 0.03 41 | JaxDQNAgentNew.alpha = 0.9 42 | JaxDQNAgentNew.clip_value_min = -1 43 | 44 | create_optimizer = @dqn_agent.create_optimizer 45 | create_optimizer.learning_rate = 0.00025#0.001 46 | create_optimizer.eps = 3.125e-4#0.01 47 | 48 | create_minatar_env.game_name = 'breakout' 49 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 50 | 51 | Runner.num_iterations = 10 52 | Runner.training_steps = 1000000 53 | Runner.max_steps_per_episode = 
100000000 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_cartpole.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.CARTPOLE_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.CARTPOLE_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxDQNAgentNew.optimizer = 'adam' 22 | JaxDQNAgentNew.net_conf = 'classic' 23 | JaxDQNAgentNew.env = 'CartPole' 24 | JaxDQNAgentNew.normalize_obs = True 25 | JaxDQNAgentNew.hidden_layer = 2 26 | JaxDQNAgentNew.neurons = 512 27 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 28 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 29 | JaxDQNAgentNew.mse_inf = True 30 | JaxDQNAgentNew.noisy = False 31 | JaxDQNAgentNew.dueling = False 32 | JaxDQNAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxDQNAgentNew.tau = 100 38 | JaxDQNAgentNew.alpha = 1 39 | JaxDQNAgentNew.clip_value_min = -1e3 40 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 41 | JaxDQNAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 42 | 43 | create_optimizer = @dqn_agent.create_optimizer 44 | create_optimizer.learning_rate = 0.001 45 | create_optimizer.eps = 3.125e-4 46 | 47 | create_gym_environment.environment_name = 'CartPole' 48 | create_gym_environment.version = 'v0' 49 | 50 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 51 | Runner.num_iterations = 30 52 | Runner.training_steps = 1000 53 | Runner.max_steps_per_episode = 200 # Default max episode length. 
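# Note on the Munchausen-related flags above (interpretation, not a new
# setting): target_opt selects the update rule (0: DQN, 1: Double DQN,
# 2: Munchausen DQN); tau, alpha and clip_value_min correspond to the
# Munchausen temperature, scaling term and log-policy clip of
# Vieillard et al. (2020), and with target_opt = 0 they are presumably unused.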
54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_freeway.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.FREEWAY_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 1 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxDQNAgentNew.optimizer = 'adam' 23 | JaxDQNAgentNew.net_conf = 'minatar' 24 | JaxDQNAgentNew.env = None 25 | JaxDQNAgentNew.normalize_obs = False 26 | JaxDQNAgentNew.hidden_layer = 0 27 | JaxDQNAgentNew.neurons = None 28 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 29 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 30 | JaxDQNAgentNew.mse_inf = True 31 | JaxDQNAgentNew.noisy = False 32 | JaxDQNAgentNew.dueling = False 33 | JaxDQNAgentNew.initzer = @variance_scaling() 34 | variance_scaling.scale=1 35 | variance_scaling.mode='fan_avg' 36 | variance_scaling.distribution='uniform' 37 | 38 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 39 | JaxDQNAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | JaxDQNAgentNew.tau = 0.03 41 | JaxDQNAgentNew.alpha = 0.9 42 | JaxDQNAgentNew.clip_value_min = -1 43 | 44 | create_optimizer = @dqn_agent.create_optimizer 45 | create_optimizer.learning_rate = 0.00025 46 | create_optimizer.eps = 3.125e-4 47 | 48 | create_minatar_env.game_name = 'freeway' 49 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 50 | 51 | Runner.num_iterations = 10 52 | Runner.training_steps = 1000000 53 | Runner.max_steps_per_episode = 100000000 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_invaders.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SPACE_INVADERS_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 1 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxDQNAgentNew.optimizer = 'adam' 23 | JaxDQNAgentNew.net_conf = 'minatar' 24 | JaxDQNAgentNew.env = None 25 | JaxDQNAgentNew.normalize_obs = False 26 | JaxDQNAgentNew.hidden_layer = 0 27 | 
JaxDQNAgentNew.neurons = None 28 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 29 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 30 | JaxDQNAgentNew.mse_inf = True 31 | JaxDQNAgentNew.noisy = False 32 | JaxDQNAgentNew.dueling = False 33 | JaxDQNAgentNew.initzer = @variance_scaling() 34 | variance_scaling.scale=1 35 | variance_scaling.mode='fan_avg' 36 | variance_scaling.distribution='uniform' 37 | 38 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 39 | JaxDQNAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | JaxDQNAgentNew.tau = 0.03 41 | JaxDQNAgentNew.alpha = 0.9 42 | JaxDQNAgentNew.clip_value_min = -1 43 | 44 | create_optimizer = @dqn_agent.create_optimizer 45 | create_optimizer.learning_rate = 0.00025 46 | create_optimizer.eps = 3.125e-4 47 | 48 | create_minatar_env.game_name = 'space_invaders' 49 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 50 | 51 | Runner.num_iterations = 10 52 | Runner.training_steps = 1000000 53 | Runner.max_steps_per_episode = 100000000 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_lunarlander.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.LUNAR_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.LUNAR_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.LUNAR_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 300 20 | 21 | JaxDQNAgentNew.optimizer = 'adam' 22 | JaxDQNAgentNew.net_conf = 'classic' 23 | JaxDQNAgentNew.env = 'LunarLander' 24 | JaxDQNAgentNew.normalize_obs = False 25 | JaxDQNAgentNew.hidden_layer = 2 26 | JaxDQNAgentNew.neurons = 512 27 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 28 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 29 | JaxDQNAgentNew.mse_inf = True 30 | JaxDQNAgentNew.noisy = False 31 | JaxDQNAgentNew.dueling = False 32 | JaxDQNAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 38 | JaxDQNAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | JaxDQNAgentNew.tau = 100 40 | JaxDQNAgentNew.alpha = 1 41 | JaxDQNAgentNew.clip_value_min = -1e3 42 | 43 | create_optimizer = @dqn_agent.create_optimizer 44 | create_optimizer.learning_rate = 0.001 45 | create_optimizer.eps = 3.125e-4 46 | 47 | create_gym_environment.environment_name = 'LunarLander' 48 | create_gym_environment.version = 'v2' 49 | 50 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 51 | Runner.num_iterations = 30 52 | Runner.training_steps = 4000 53 | Runner.max_steps_per_episode = 1000 # Default max episode length. 
54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_mountaincar.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.MOUNTAINCAR_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.MOUNTAINCAR_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxDQNAgentNew.optimizer = 'adam' 22 | JaxDQNAgentNew.net_conf = 'classic' 23 | JaxDQNAgentNew.env = 'MountainCar' 24 | JaxDQNAgentNew.normalize_obs = True 25 | JaxDQNAgentNew.hidden_layer = 2 26 | JaxDQNAgentNew.neurons = 512 27 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 28 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 29 | JaxDQNAgentNew.mse_inf = True 30 | JaxDQNAgentNew.noisy = False 31 | JaxDQNAgentNew.dueling = False 32 | JaxDQNAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 38 | JaxDQNAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | JaxDQNAgentNew.tau = 100 40 | JaxDQNAgentNew.alpha = 1 41 | JaxDQNAgentNew.clip_value_min = -1e3 42 | 43 | create_optimizer = @dqn_agent.create_optimizer 44 | create_optimizer.learning_rate = 0.01 45 | create_optimizer.eps = 3.125e-4 46 | 47 | create_gym_environment.environment_name = 'MountainCar' 48 | create_gym_environment.version = 'v0' 49 | 50 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 51 | Runner.num_iterations = 30 52 | Runner.training_steps = 1000 53 | Runner.max_steps_per_episode = 600 # Default max episode length. 
54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/dqn_seaquest.gin: -------------------------------------------------------------------------------- 1 | import dopamine.discrete_domains.gym_lib 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.run_experiment 4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import dqn_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SEAQUEST_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 1 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxDQNAgentNew.optimizer = 'adam' 23 | JaxDQNAgentNew.net_conf = 'minatar' 24 | JaxDQNAgentNew.env = None 25 | JaxDQNAgentNew.normalize_obs = False 26 | JaxDQNAgentNew.hidden_layer = 0 27 | JaxDQNAgentNew.neurons = None 28 | JaxDQNAgentNew.replay_scheme = 'uniform' #'prioritized' or 'uniform' 29 | JaxDQNAgentNew.target_opt = 0 #0:DQN 1:Double DQN 2:Munchausen DQN 30 | JaxDQNAgentNew.mse_inf = True 31 | JaxDQNAgentNew.noisy = False 32 | JaxDQNAgentNew.dueling = False 33 | JaxDQNAgentNew.initzer = @variance_scaling() 34 | variance_scaling.scale=1 35 | variance_scaling.mode='fan_avg' 36 | variance_scaling.distribution='uniform' 37 | 38 | JaxDQNAgentNew.network = @networks_new.DQNNetwork 39 | JaxDQNAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | JaxDQNAgentNew.tau = 0.03 41 | JaxDQNAgentNew.alpha = 0.9 42 | JaxDQNAgentNew.clip_value_min = -1 43 | 44 | create_optimizer = @dqn_agent.create_optimizer 45 | create_optimizer.learning_rate = 0.00025 46 | create_optimizer.eps = 3.125e-4 47 | 48 | create_minatar_env.game_name = 'seaquest' 49 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 50 | 51 | Runner.num_iterations = 10 52 | Runner.training_steps = 1000000 53 | Runner.max_steps_per_episode = 100000000 54 | 55 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 56 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 57 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_acrobot.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import external_configurations 12 | 13 | JaxImplicitQuantileAgentNew.observation_shape = %gym_lib.ACROBOT_OBSERVATION_SHAPE 14 | JaxImplicitQuantileAgentNew.observation_dtype = %jax_networks.ACROBOT_OBSERVATION_DTYPE 15 | JaxImplicitQuantileAgentNew.stack_size = %gym_lib.ACROBOT_STACK_SIZE 16 | JaxImplicitQuantileAgentNew.gamma = 0.99 17 | JaxImplicitQuantileAgentNew.update_horizon = 1#3 18 | JaxImplicitQuantileAgentNew.min_replay_history = 500 # agent steps 19 | JaxImplicitQuantileAgentNew.update_period = 2 20 
| JaxImplicitQuantileAgentNew.target_update_period = 100 # agent step 21 | 22 | 23 | JaxImplicitQuantileAgentNew.net_conf = 'classic' 24 | JaxImplicitQuantileAgentNew.env = 'Acrobot' 25 | JaxImplicitQuantileAgentNew.hidden_layer = 2 26 | JaxImplicitQuantileAgentNew.neurons = 512 27 | JaxImplicitQuantileAgentNew.noisy = False 28 | JaxImplicitQuantileAgentNew.double_dqn = False 29 | JaxImplicitQuantileAgentNew.dueling = False 30 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 31 | variance_scaling.scale=1 32 | variance_scaling.mode='fan_avg' 33 | variance_scaling.distribution='uniform' 34 | 35 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'#'prioritized' 36 | JaxImplicitQuantileAgentNew.kappa = 1.0 37 | 38 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 39 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 40 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 41 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 42 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 43 | 44 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 45 | JaxImplicitQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon #@dqn_agent.linearly_decaying_epsilon 46 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 47 | 48 | JaxImplicitQuantileAgentNew.tau = 0.03 49 | JaxImplicitQuantileAgentNew.alpha = 1 50 | JaxImplicitQuantileAgentNew.clip_value_min = -1 51 | 52 | create_optimizer.learning_rate = 0.0001 53 | create_optimizer.eps = 3.125e-4 54 | 55 | create_gym_environment.environment_name = 'Acrobot' 56 | create_gym_environment.version = 'v1' 57 | 58 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 59 | Runner.num_iterations = 30 60 | Runner.training_steps = 1000 61 | Runner.max_steps_per_episode = 500 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_asterix.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import minatar_env 12 | import external_configurations 13 | 14 | JaxImplicitQuantileAgentNew.observation_shape = %minatar_env.ASTERIX_SHAPE 15 | JaxImplicitQuantileAgentNew.observation_dtype = %minatar_env.DTYPE 16 | JaxImplicitQuantileAgentNew.stack_size = 1 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1 19 | JaxImplicitQuantileAgentNew.min_replay_history = 1000 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 4 21 | JaxImplicitQuantileAgentNew.target_update_period = 1000 # agent step 22 | 23 | 24 | JaxImplicitQuantileAgentNew.net_conf = 'minatar' 25 | JaxImplicitQuantileAgentNew.env = None 26 | JaxImplicitQuantileAgentNew.hidden_layer = 0 27 | JaxImplicitQuantileAgentNew.neurons = None 28 | JaxImplicitQuantileAgentNew.noisy = False 29 | JaxImplicitQuantileAgentNew.double_dqn = False 30 | JaxImplicitQuantileAgentNew.dueling = False 31 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 32 | 
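Note: the @variance_scaling() reference here, together with the scale/mode/distribution bindings that follow (and that recur in every config in this directory), presumably resolves via external_configurations.py to JAX's variance-scaling initializer; with scale=1, mode='fan_avg', distribution='uniform' that is Glorot/Xavier uniform. A minimal sketch of what those three bindings amount to, assuming the gin name maps to jax.nn.initializers.variance_scaling:

import jax
import jax.numpy as jnp

# scale=1 + fan_avg + uniform is equivalent to Glorot/Xavier uniform initialization.
initzer = jax.nn.initializers.variance_scaling(
    scale=1.0, mode='fan_avg', distribution='uniform')
weights = initzer(jax.random.PRNGKey(0), (128, 512), jnp.float32)  # sample one weight matrix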
variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized 37 | JaxImplicitQuantileAgentNew.kappa = 1.0 38 | 39 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 40 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 41 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 42 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 43 | 44 | JaxImplicitQuantileAgentNew.tau = 0.03 45 | JaxImplicitQuantileAgentNew.alpha = 0.9 46 | JaxImplicitQuantileAgentNew.clip_value_min = -1 47 | 48 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 49 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 50 | JaxImplicitQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 51 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 52 | 53 | create_optimizer.learning_rate = 0.0001 54 | create_optimizer.eps =0.0003125 55 | 56 | create_minatar_env.game_name ='asterix' 57 | 58 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 59 | Runner.num_iterations = 10 60 | Runner.training_steps = 1000000 61 | Runner.max_steps_per_episode = 100000000 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_breakout.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import minatar_env 12 | import external_configurations 13 | 14 | JaxImplicitQuantileAgentNew.observation_shape = %minatar_env.BREAKOUT_SHAPE 15 | JaxImplicitQuantileAgentNew.observation_dtype = %minatar_env.DTYPE 16 | JaxImplicitQuantileAgentNew.stack_size = 1 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1 19 | JaxImplicitQuantileAgentNew.min_replay_history = 1000 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 4 21 | JaxImplicitQuantileAgentNew.target_update_period = 1000 # agent step 22 | 23 | 24 | JaxImplicitQuantileAgentNew.net_conf = 'minatar' 25 | JaxImplicitQuantileAgentNew.env = None 26 | JaxImplicitQuantileAgentNew.hidden_layer = 0 27 | JaxImplicitQuantileAgentNew.neurons = None 28 | JaxImplicitQuantileAgentNew.noisy = False 29 | JaxImplicitQuantileAgentNew.double_dqn = False 30 | JaxImplicitQuantileAgentNew.dueling = False 31 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 32 | variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized 37 | JaxImplicitQuantileAgentNew.kappa = 1.0 38 | 39 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 40 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 41 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 42 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 43 | 44 | JaxImplicitQuantileAgentNew.tau = 0.03 45 | JaxImplicitQuantileAgentNew.alpha = 0.9 46 | 
JaxImplicitQuantileAgentNew.clip_value_min = -1 47 | 48 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 49 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 50 | JaxImplicitQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 51 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 52 | 53 | create_optimizer.learning_rate = 0.0001 54 | create_optimizer.eps =0.0003125 55 | 56 | create_minatar_env.game_name ='breakout' 57 | 58 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 59 | Runner.num_iterations = 10 60 | Runner.training_steps = 1000000 61 | Runner.max_steps_per_episode = 100000000 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_cartpole.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import external_configurations 12 | 13 | JaxImplicitQuantileAgentNew.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE 14 | JaxImplicitQuantileAgentNew.observation_dtype = %jax_networks.CARTPOLE_OBSERVATION_DTYPE 15 | JaxImplicitQuantileAgentNew.stack_size = %gym_lib.CARTPOLE_STACK_SIZE 16 | JaxImplicitQuantileAgentNew.gamma = 0.99 17 | JaxImplicitQuantileAgentNew.update_horizon = 1#3 18 | JaxImplicitQuantileAgentNew.min_replay_history = 500 # agent steps 19 | JaxImplicitQuantileAgentNew.update_period = 2 20 | JaxImplicitQuantileAgentNew.target_update_period = 100 # agent step 21 | 22 | 23 | JaxImplicitQuantileAgentNew.net_conf = 'classic' 24 | JaxImplicitQuantileAgentNew.env = 'CartPole' 25 | JaxImplicitQuantileAgentNew.hidden_layer = 2 26 | JaxImplicitQuantileAgentNew.neurons = 512 27 | JaxImplicitQuantileAgentNew.noisy = False 28 | JaxImplicitQuantileAgentNew.double_dqn = False 29 | JaxImplicitQuantileAgentNew.dueling = False 30 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 31 | variance_scaling.scale=1 32 | variance_scaling.mode='fan_avg' 33 | variance_scaling.distribution='uniform' 34 | 35 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized or uniform 36 | JaxImplicitQuantileAgentNew.kappa = 1.0 37 | 38 | JaxImplicitQuantileAgentNew.num_tau_samples = 32#64 39 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32#64 40 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 41 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 42 | 43 | JaxImplicitQuantileAgentNew.tau = 0.03#100 44 | JaxImplicitQuantileAgentNew.alpha = 1 45 | JaxImplicitQuantileAgentNew.clip_value_min = -1#-1e3 46 | 47 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 48 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 49 | JaxImplicitQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon #@dqn_agent.linearly_decaying_epsilon 50 | JaxImplicitQuantileAgentNew.target_opt = 1 # 0 target_quantile and 1 munchau_target_quantile 51 | 52 | create_optimizer.learning_rate = 0.0001 53 | create_optimizer.eps =0.0003125 54 | 55 | 
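Note: the tau, alpha and clip_value_min bindings above parameterize the Munchausen modification, which target_opt selects (1 for the implicit-quantile agent, per the "munchau_target_quantile" comment; 2 for JaxDQNAgentNew). A rough single-transition sketch of the Munchausen target these hyperparameters define, with illustrative names rather than the repo's own functions:

import jax.numpy as jnp
from jax import nn

def munchausen_target(q_next_target, q_curr_target, action, reward, terminal,
                      gamma=0.99, tau=0.03, alpha=0.9, clip_value_min=-1.0):
    """r + alpha * clip(tau * log pi(a|s)) + gamma * soft value of s', all from the target net."""
    log_pi_next = nn.log_softmax(q_next_target / tau)            # log pi(.|s')
    pi_next = jnp.exp(log_pi_next)
    soft_value = jnp.sum(pi_next * (q_next_target - tau * log_pi_next))
    log_pi_a = nn.log_softmax(q_curr_target / tau)[action]       # log pi(a|s) of the taken action
    bonus = alpha * jnp.clip(tau * log_pi_a, clip_value_min, 0.0)
    return reward + bonus + (1.0 - terminal) * gamma * soft_value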
create_gym_environment.environment_name = 'CartPole' 56 | create_gym_environment.version = 'v0' 57 | 58 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 59 | Runner.num_iterations = 30 60 | Runner.training_steps = 1000 61 | Runner.max_steps_per_episode = 200 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_freeway.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import minatar_env 12 | import external_configurations 13 | 14 | JaxImplicitQuantileAgentNew.observation_shape = %minatar_env.FREEWAY_SHAPE 15 | JaxImplicitQuantileAgentNew.observation_dtype = %minatar_env.DTYPE 16 | JaxImplicitQuantileAgentNew.stack_size = 1 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1 19 | JaxImplicitQuantileAgentNew.min_replay_history = 1000 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 4 21 | JaxImplicitQuantileAgentNew.target_update_period = 1000 # agent step 22 | 23 | 24 | JaxImplicitQuantileAgentNew.net_conf = 'minatar' 25 | JaxImplicitQuantileAgentNew.env = None 26 | JaxImplicitQuantileAgentNew.hidden_layer = 0 27 | JaxImplicitQuantileAgentNew.neurons = None 28 | JaxImplicitQuantileAgentNew.noisy = False 29 | JaxImplicitQuantileAgentNew.double_dqn = False 30 | JaxImplicitQuantileAgentNew.dueling = False 31 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 32 | variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized 37 | JaxImplicitQuantileAgentNew.kappa = 1.0 38 | 39 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 40 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 41 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 42 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 43 | 44 | JaxImplicitQuantileAgentNew.tau = 0.03 45 | JaxImplicitQuantileAgentNew.alpha = 0.9 46 | JaxImplicitQuantileAgentNew.clip_value_min = -1 47 | 48 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 49 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 50 | JaxImplicitQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 51 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 52 | 53 | create_optimizer.learning_rate = 0.0001 54 | create_optimizer.eps =0.0003125 55 | 56 | create_minatar_env.game_name ='freeway' 57 | 58 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 59 | Runner.num_iterations = 10 60 | Runner.training_steps = 1000000 61 | Runner.max_steps_per_episode = 100000000 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_invaders.gin: 
-------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import minatar_env 12 | import external_configurations 13 | 14 | JaxImplicitQuantileAgentNew.observation_shape = %minatar_env.SPACE_INVADERS_SHAPE 15 | JaxImplicitQuantileAgentNew.observation_dtype = %minatar_env.DTYPE 16 | JaxImplicitQuantileAgentNew.stack_size = 1 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1 19 | JaxImplicitQuantileAgentNew.min_replay_history = 1000 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 4 21 | JaxImplicitQuantileAgentNew.target_update_period = 1000 # agent step 22 | 23 | 24 | JaxImplicitQuantileAgentNew.net_conf = 'minatar' 25 | JaxImplicitQuantileAgentNew.env = None 26 | JaxImplicitQuantileAgentNew.hidden_layer = 0 27 | JaxImplicitQuantileAgentNew.neurons = None 28 | JaxImplicitQuantileAgentNew.noisy = False 29 | JaxImplicitQuantileAgentNew.double_dqn = False 30 | JaxImplicitQuantileAgentNew.dueling = False 31 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 32 | variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized 37 | JaxImplicitQuantileAgentNew.kappa = 1.0 38 | 39 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 40 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 41 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 42 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 43 | 44 | JaxImplicitQuantileAgentNew.tau = 0.03 45 | JaxImplicitQuantileAgentNew.alpha = 0.9 46 | JaxImplicitQuantileAgentNew.clip_value_min = -1 47 | 48 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 49 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 50 | JaxImplicitQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 51 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 52 | 53 | create_optimizer.learning_rate = 0.0001 54 | create_optimizer.eps =0.0003125 55 | 56 | create_minatar_env.game_name ='space_invaders' 57 | 58 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 59 | Runner.num_iterations = 10 60 | Runner.training_steps = 1000000 61 | Runner.max_steps_per_episode = 100000000 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_lunarlander.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import external_configurations 12 | 13 | JaxImplicitQuantileAgentNew.observation_shape 
= %gym_lib.LUNAR_OBSERVATION_SHAPE 14 | JaxImplicitQuantileAgentNew.observation_dtype = %jax_networks.LUNAR_OBSERVATION_DTYPE 15 | JaxImplicitQuantileAgentNew.stack_size = %gym_lib.LUNAR_STACK_SIZE 16 | 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1#3 19 | JaxImplicitQuantileAgentNew.min_replay_history = 500 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 2 21 | JaxImplicitQuantileAgentNew.target_update_period = 100 # agent step 22 | 23 | JaxImplicitQuantileAgentNew.net_conf = 'classic' 24 | JaxImplicitQuantileAgentNew.env = 'LunarLander' 25 | JaxImplicitQuantileAgentNew.hidden_layer = 2 26 | JaxImplicitQuantileAgentNew.neurons = 512 27 | JaxImplicitQuantileAgentNew.noisy = False 28 | JaxImplicitQuantileAgentNew.double_dqn = False 29 | JaxImplicitQuantileAgentNew.dueling = False 30 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 31 | variance_scaling.scale=1 32 | variance_scaling.mode='fan_avg' 33 | variance_scaling.distribution='uniform' 34 | 35 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 36 | JaxImplicitQuantileAgentNew.kappa = 1.0 37 | 38 | JaxImplicitQuantileAgentNew.tau = 0.03 39 | JaxImplicitQuantileAgentNew.alpha = 1 40 | JaxImplicitQuantileAgentNew.clip_value_min = -1 41 | 42 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 43 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 44 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 45 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 46 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 47 | 48 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 49 | JaxImplicitQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon #@dqn_agent.linearly_decaying_epsilon 50 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 51 | 52 | create_optimizer.learning_rate = 0.001 53 | create_optimizer.eps = 3.125e-4 54 | 55 | create_gym_environment.environment_name = 'LunarLander' 56 | create_gym_environment.version = 'v2' 57 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 58 | 59 | Runner.num_iterations = 30 60 | Runner.training_steps = 4000 61 | Runner.max_steps_per_episode = 1000 62 | 63 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 64 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 65 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_mountaincar.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import external_configurations 12 | 13 | JaxImplicitQuantileAgentNew.observation_shape = %gym_lib.MOUNTAINCAR_OBSERVATION_SHAPE 14 | JaxImplicitQuantileAgentNew.observation_dtype = %jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE 15 | JaxImplicitQuantileAgentNew.stack_size = %gym_lib.MOUNTAINCAR_STACK_SIZE 16 | JaxImplicitQuantileAgentNew.gamma = 0.99 17 | JaxImplicitQuantileAgentNew.update_horizon = 1#3 18 | JaxImplicitQuantileAgentNew.min_replay_history = 500 # agent steps 19 | JaxImplicitQuantileAgentNew.update_period = 2 20 | 
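Note: the num_tau_samples / num_tau_prime_samples / quantile_embedding_dim bindings used throughout these implicit-quantile configs control how many quantile fractions are sampled and how each is embedded. A sketch of the standard IQN cosine embedding that quantile_embedding_dim = 64 refers to (illustrative, not the repo's networks_new.ImplicitQuantileNetwork code):

import jax
import jax.numpy as jnp

def cosine_embedding(key, num_tau_samples=32, embedding_dim=64):
    """Sample tau ~ U(0, 1) and expand each with cosine features, as in IQN."""
    tau = jax.random.uniform(key, (num_tau_samples, 1))
    i = jnp.arange(1, embedding_dim + 1)
    # Shape (num_tau_samples, embedding_dim); the network then applies Dense + ReLU.
    return tau, jnp.cos(jnp.pi * i * tau)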
JaxImplicitQuantileAgentNew.target_update_period = 100 # agent step 21 | 22 | JaxImplicitQuantileAgentNew.net_conf = 'classic' 23 | JaxImplicitQuantileAgentNew.env = 'MountainCar' 24 | JaxImplicitQuantileAgentNew.hidden_layer = 2 25 | JaxImplicitQuantileAgentNew.neurons = 512 26 | JaxImplicitQuantileAgentNew.noisy = False 27 | JaxImplicitQuantileAgentNew.double_dqn = False 28 | JaxImplicitQuantileAgentNew.dueling = False 29 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 30 | variance_scaling.scale=1 31 | variance_scaling.mode='fan_avg' 32 | variance_scaling.distribution='uniform' 33 | 34 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 35 | JaxImplicitQuantileAgentNew.kappa = 1.0 36 | 37 | JaxImplicitQuantileAgentNew.tau = 0.03 38 | JaxImplicitQuantileAgentNew.alpha = 1 39 | JaxImplicitQuantileAgentNew.clip_value_min = -1 40 | 41 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 42 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 43 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 44 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 45 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 46 | 47 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 48 | JaxImplicitQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon #@dqn_agent.linearly_decaying_epsilon 49 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 50 | 51 | create_optimizer.learning_rate = 0.0001 52 | create_optimizer.eps = 3.125e-4 53 | 54 | create_gym_environment.environment_name = 'MountainCar' 55 | create_gym_environment.version = 'v0' 56 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 57 | 58 | Runner.num_iterations = 30 59 | Runner.training_steps = 1000 60 | Runner.max_steps_per_episode = 600 61 | 62 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 63 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 64 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/implicit_seaquest.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.implicit_quantile.implicit_quantile_agent 2 | import dopamine.discrete_domains.run_experiment 3 | import dopamine.replay_memory.prioritized_replay_buffer 4 | 5 | import dopamine.jax.agents.dqn.dqn_agent 6 | import dopamine.jax.networks 7 | import dopamine.discrete_domains.gym_lib 8 | 9 | import networks_new 10 | import implicit_quantile_agent_new 11 | import minatar_env 12 | import external_configurations 13 | 14 | JaxImplicitQuantileAgentNew.observation_shape = %minatar_env.SEAQUEST_SHAPE 15 | JaxImplicitQuantileAgentNew.observation_dtype = %minatar_env.DTYPE 16 | JaxImplicitQuantileAgentNew.stack_size = 1 17 | JaxImplicitQuantileAgentNew.gamma = 0.99 18 | JaxImplicitQuantileAgentNew.update_horizon = 1 19 | JaxImplicitQuantileAgentNew.min_replay_history = 1000 # agent steps 20 | JaxImplicitQuantileAgentNew.update_period = 4 21 | JaxImplicitQuantileAgentNew.target_update_period = 1000 # agent step 22 | 23 | JaxImplicitQuantileAgentNew.net_conf = 'minatar' 24 | JaxImplicitQuantileAgentNew.env = None 25 | JaxImplicitQuantileAgentNew.hidden_layer = 0 26 | JaxImplicitQuantileAgentNew.neurons = None 27 | JaxImplicitQuantileAgentNew.noisy = False 28 | JaxImplicitQuantileAgentNew.double_dqn = False 29 | JaxImplicitQuantileAgentNew.dueling = False 30 | JaxImplicitQuantileAgentNew.initzer = @variance_scaling() 31 | 
variance_scaling.scale=1 32 | variance_scaling.mode='fan_avg' 33 | variance_scaling.distribution='uniform' 34 | 35 | JaxImplicitQuantileAgentNew.replay_scheme = 'uniform'# prioritized 36 | JaxImplicitQuantileAgentNew.kappa = 1.0 37 | 38 | JaxImplicitQuantileAgentNew.num_tau_samples = 32 39 | JaxImplicitQuantileAgentNew.num_tau_prime_samples = 32 40 | JaxImplicitQuantileAgentNew.num_quantile_samples = 32 41 | JaxImplicitQuantileAgentNew.quantile_embedding_dim = 64 42 | 43 | JaxImplicitQuantileAgentNew.tau = 0.03 44 | JaxImplicitQuantileAgentNew.alpha = 0.9 45 | JaxImplicitQuantileAgentNew.clip_value_min = -1 46 | 47 | JaxImplicitQuantileAgentNew.optimizer = 'adam' 48 | JaxImplicitQuantileAgentNew.network = @networks_new.ImplicitQuantileNetwork 49 | JaxImplicitQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 50 | JaxImplicitQuantileAgentNew.target_opt = 0 # 0 target_quantile and 1 munchau_target_quantile 51 | 52 | create_optimizer.learning_rate = 0.0001 53 | create_optimizer.eps =0.0003125 54 | 55 | create_minatar_env.game_name ='seaquest' 56 | 57 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 58 | Runner.num_iterations = 10 59 | Runner.training_steps = 1000000 60 | Runner.max_steps_per_episode = 100000000 61 | 62 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 63 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 64 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_acrobot.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.ACROBOT_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.ACROBOT_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.ACROBOT_STACK_SIZE 14 | JaxDQNAgent.gamma = 0.99 15 | JaxDQNAgent.update_horizon = 1 16 | JaxDQNAgent.min_replay_history = 500 # agent steps 17 | JaxDQNAgent.update_period = 2 18 | JaxDQNAgent.target_update_period = 100 # agent steps 19 | 20 | JaxQuantileAgentNew.optimizer = 'adam' 21 | JaxQuantileAgentNew.kappa = 1.0 22 | JaxQuantileAgentNew.num_atoms = 51 23 | JaxQuantileAgentNew.net_conf = 'classic' 24 | JaxQuantileAgentNew.env = 'Acrobot' 25 | JaxQuantileAgentNew.normalize_obs = True 26 | JaxQuantileAgentNew.hidden_layer = 2 27 | JaxQuantileAgentNew.neurons = 512 28 | JaxQuantileAgentNew.double_dqn = False 29 | JaxQuantileAgentNew.noisy = False 30 | JaxQuantileAgentNew.dueling = False 31 | JaxQuantileAgentNew.replay_scheme = 'uniform'#'prioritized' 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 38 | JaxQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | 40 | create_optimizer.learning_rate = 0.0001 41 | create_optimizer.eps = 0.0003125 42 | 43 | create_gym_environment.environment_name = 'Acrobot' 44 | create_gym_environment.version = 'v1' 45 | 46 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 1000 49 | 
Runner.max_steps_per_episode = 500 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_asterix.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.ASTERIX_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 1000 # agent steps 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 1000 # agent steps 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'minatar' 25 | JaxQuantileAgentNew.env = None 26 | JaxQuantileAgentNew.normalize_obs = False 27 | JaxQuantileAgentNew.hidden_layer = 0 28 | JaxQuantileAgentNew.neurons = None 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_minatar_env.game_name ='asterix' 45 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 46 | 47 | Runner.num_iterations = 10 48 | Runner.training_steps = 1000000 49 | Runner.max_steps_per_episode = 100000000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_breakout.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.BREAKOUT_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 1000 # agent steps 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 1000 # agent steps 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'minatar' 25 | 
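Note: two exploration schedules recur in these files: Dopamine's linearly_decaying_epsilon for the MinAtar configs and identity_epsilon (a constant epsilon, leaving exploration to noisy nets where enabled) for the classic-control ones. An illustrative call with made-up step counts, assuming the usual (decay_period, step, warmup_steps, epsilon) argument order:

from dopamine.jax.agents.dqn import dqn_agent

# Linear decay from 1.0 down to epsilon over decay_period steps, once
# warmup_steps (min_replay_history) transitions have been collected.
eps = dqn_agent.linearly_decaying_epsilon(250000, 50000, 1000, 0.01)

# Constant schedule: always returns epsilon, ignoring the other arguments.
eps_const = dqn_agent.identity_epsilon(250000, 50000, 1000, 0.01)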
JaxQuantileAgentNew.env = None 26 | JaxQuantileAgentNew.normalize_obs = False 27 | JaxQuantileAgentNew.hidden_layer = 0 28 | JaxQuantileAgentNew.neurons = None 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_minatar_env.game_name ='breakout' 45 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 46 | 47 | Runner.num_iterations = 10 48 | Runner.training_steps = 1000000 49 | Runner.max_steps_per_episode = 100000000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_cartpole.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.CARTPOLE_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.CARTPOLE_STACK_SIZE 14 | JaxDQNAgent.gamma = 0.99 15 | JaxDQNAgent.update_horizon = 1 16 | JaxDQNAgent.min_replay_history = 500 # agent steps 17 | JaxDQNAgent.update_period = 2 18 | JaxDQNAgent.target_update_period = 100 # agent steps 19 | 20 | JaxQuantileAgentNew.optimizer = 'adam' 21 | JaxQuantileAgentNew.kappa = 1.0 22 | JaxQuantileAgentNew.num_atoms = 51 23 | JaxQuantileAgentNew.net_conf = 'classic' 24 | JaxQuantileAgentNew.env = 'CartPole' 25 | JaxQuantileAgentNew.normalize_obs = True 26 | JaxQuantileAgentNew.hidden_layer = 2 27 | JaxQuantileAgentNew.neurons = 512 28 | JaxQuantileAgentNew.double_dqn = False 29 | JaxQuantileAgentNew.noisy = False 30 | JaxQuantileAgentNew.dueling = False 31 | JaxQuantileAgentNew.initzer = @variance_scaling() 32 | variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 37 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 38 | JaxQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | 40 | create_optimizer.learning_rate = 0.0001 41 | create_optimizer.eps = 0.0003125 42 | 43 | create_gym_environment.environment_name = 'CartPole' 44 | create_gym_environment.version = 'v0' 45 | 46 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 1000 49 | Runner.max_steps_per_episode = 200 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- 
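Note: each .gin file above is meant to be parsed before building a Dopamine runner. A minimal launch sketch; the repo's own entry point is test_main.ipynb, so the helper below, its constructor arguments and the base directory are only illustrative assumptions:

import gin
from dopamine.discrete_domains import run_experiment
import quantile_agent_new  # from revisiting_rainbow/Agents

def create_agent(unused_sess, environment, summary_writer=None):
    # Gin fills in every other hyperparameter from the parsed config file.
    return quantile_agent_new.JaxQuantileAgentNew(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)

gin.parse_config_files_and_bindings(
    ['revisiting_rainbow/Configs/quantile_cartpole.gin'], bindings=[])
runner = run_experiment.TrainRunner('/tmp/quantile_cartpole', create_agent)
runner.run_experiment()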
/revisiting_rainbow/Configs/quantile_freeway.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.FREEWAY_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 1000 # agent steps 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 1000 # agent steps 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'minatar' 25 | JaxQuantileAgentNew.env = None 26 | JaxQuantileAgentNew.normalize_obs = False 27 | JaxQuantileAgentNew.hidden_layer = 0 28 | JaxQuantileAgentNew.neurons = None 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_minatar_env.game_name ='freeway' 45 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 46 | 47 | Runner.num_iterations = 10 48 | Runner.training_steps = 1000000 49 | Runner.max_steps_per_episode = 100000000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_invaders.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SPACE_INVADERS_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 1000 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 1000 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'minatar' 25 | JaxQuantileAgentNew.env = None 26 | JaxQuantileAgentNew.normalize_obs = False 27 | JaxQuantileAgentNew.hidden_layer = 0 28 | JaxQuantileAgentNew.neurons = None 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | 
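Note: in these quantile configs, kappa = 1.0 is the Huber threshold of the quantile-regression (QR-DQN) loss and num_atoms = 51 the number of fixed quantiles. A sketch of the per-quantile loss term those settings parameterize (illustrative, not the repo's implementation):

import jax.numpy as jnp

def quantile_huber_loss(td_error, tau_hat, kappa=1.0):
    """Quantile Huber loss rho_tau^kappa(u) for one TD error and one quantile midpoint."""
    abs_err = jnp.abs(td_error)
    huber = jnp.where(abs_err <= kappa,
                      0.5 * td_error ** 2,
                      kappa * (abs_err - 0.5 * kappa))
    return jnp.abs(tau_hat - (td_error < 0.0).astype(jnp.float32)) * huber / kappa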
JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_minatar_env.game_name ='space_invaders' 45 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 46 | 47 | Runner.num_iterations = 10 48 | Runner.training_steps = 1000000 49 | Runner.max_steps_per_episode = 100000000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_lunarlander.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.LUNAR_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.LUNAR_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.LUNAR_STACK_SIZE 14 | JaxDQNAgent.gamma = 0.99 15 | JaxDQNAgent.update_horizon = 1 16 | JaxDQNAgent.min_replay_history = 500 # agent steps 17 | JaxDQNAgent.update_period = 4 18 | JaxDQNAgent.target_update_period = 300 # agent steps 19 | 20 | JaxQuantileAgentNew.optimizer = 'adam' 21 | JaxQuantileAgentNew.kappa = 1.0 22 | JaxQuantileAgentNew.num_atoms = 51 23 | JaxQuantileAgentNew.net_conf = 'classic' 24 | JaxQuantileAgentNew.env = 'LunarLander' 25 | JaxQuantileAgentNew.normalize_obs = False 26 | JaxQuantileAgentNew.hidden_layer = 2 27 | JaxQuantileAgentNew.neurons = 512 28 | JaxQuantileAgentNew.double_dqn = False 29 | JaxQuantileAgentNew.noisy = False 30 | JaxQuantileAgentNew.dueling = False 31 | JaxQuantileAgentNew.initzer = @variance_scaling() 32 | variance_scaling.scale=1 33 | variance_scaling.mode='fan_avg' 34 | variance_scaling.distribution='uniform' 35 | 36 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 37 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 38 | JaxQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 39 | 40 | create_optimizer.learning_rate = 1e-3 41 | create_optimizer.eps = 3.125e-4 42 | 43 | create_gym_environment.environment_name = 'LunarLander' 44 | create_gym_environment.version = 'v2' 45 | 46 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 4000 49 | Runner.max_steps_per_episode = 1000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_mountaincar.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import 
dopamine.discrete_domains.run_experiment 5 | import dopamine.discrete_domains.run_experiment 6 | import dopamine.replay_memory.prioritized_replay_buffer 7 | 8 | import networks_new 9 | import quantile_agent_new 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %gym_lib.MOUNTAINCAR_OBSERVATION_SHAPE 13 | JaxDQNAgent.observation_dtype = %jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE 14 | JaxDQNAgent.stack_size = %gym_lib.MOUNTAINCAR_STACK_SIZE 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 500 # agent steps 18 | JaxDQNAgent.update_period = 2 19 | JaxDQNAgent.target_update_period = 100 # agent steps 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'classic' 25 | JaxQuantileAgentNew.env = 'MountainCar' 26 | JaxQuantileAgentNew.normalize_obs = True 27 | JaxQuantileAgentNew.hidden_layer = 2 28 | JaxQuantileAgentNew.neurons = 512 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_gym_environment.environment_name = 'MountainCar' 45 | create_gym_environment.version = 'v0' 46 | 47 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 48 | Runner.num_iterations = 30 49 | Runner.training_steps = 1000 50 | Runner.max_steps_per_episode = 600 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/quantile_seaquest.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import quantile_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SEAQUEST_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 1 17 | JaxDQNAgent.min_replay_history = 1000 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 1000 20 | 21 | JaxQuantileAgentNew.optimizer = 'adam' 22 | JaxQuantileAgentNew.kappa = 1.0 23 | JaxQuantileAgentNew.num_atoms = 51 24 | JaxQuantileAgentNew.net_conf = 'minatar' 25 | JaxQuantileAgentNew.env = None 26 | JaxQuantileAgentNew.normalize_obs = False 27 | JaxQuantileAgentNew.hidden_layer = 0 28 | JaxQuantileAgentNew.neurons = None 29 | JaxQuantileAgentNew.double_dqn = False 30 | JaxQuantileAgentNew.noisy = False 31 | JaxQuantileAgentNew.dueling = False 32 | JaxQuantileAgentNew.initzer = @variance_scaling() 33 | variance_scaling.scale=1 34 | variance_scaling.mode='fan_avg' 35 | 
variance_scaling.distribution='uniform' 36 | 37 | JaxQuantileAgentNew.replay_scheme = 'uniform' #'prioritized' 38 | JaxQuantileAgentNew.network = @networks_new.QuantileNetwork 39 | JaxQuantileAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | 44 | create_minatar_env.game_name ='seaquest' 45 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 46 | 47 | Runner.num_iterations = 10 48 | Runner.training_steps = 1000000 49 | Runner.max_steps_per_episode = 100000000 50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_acrobot.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.ACROBOT_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.ACROBOT_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.ACROBOT_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 3 # Rainbow 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 2 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxRainbowAgentNew.optimizer = 'adam' 22 | JaxRainbowAgentNew.noisy = True 23 | JaxRainbowAgentNew.dueling = True 24 | JaxRainbowAgentNew.initzer = @variance_scaling() 25 | variance_scaling.scale=1 26 | variance_scaling.mode='fan_avg' 27 | variance_scaling.distribution='uniform' 28 | 29 | JaxRainbowAgentNew.double_dqn = True 30 | JaxRainbowAgentNew.net_conf = 'classic' 31 | JaxRainbowAgentNew.env = 'Acrobot' 32 | JaxRainbowAgentNew.normalize_obs = True 33 | JaxRainbowAgentNew.hidden_layer = 2 34 | JaxRainbowAgentNew.neurons = 512 35 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 36 | JaxRainbowAgentNew.vmax = 200. 37 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 38 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 39 | JaxRainbowAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | create_gym_environment.environment_name = 'Acrobot' 44 | create_gym_environment.version = 'v1' 45 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 46 | 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 1000 49 | Runner.max_steps_per_episode = 500 # Default max episode length. 
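Note: the OutOfGraphPrioritizedReplayBuffer bindings that close every config in this directory map directly onto the buffer's constructor; the classic-control configs use a 50k-transition buffer with batch size 128, the MinAtar ones 100k with batch size 32. A sketch with the classic-control values (observation shape and dtype are illustrative stand-ins for the gym_lib / jax_networks constants):

import numpy as np
from dopamine.replay_memory import prioritized_replay_buffer

buffer = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
    observation_shape=(6, 1),      # e.g. Acrobot's 6-dimensional observation
    stack_size=1,
    replay_capacity=50000,
    batch_size=128,
    observation_dtype=np.float32)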
50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_asterix.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.ASTERIX_SHAPE 13 | JaxDQNAgent.observation_dtype =%minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 3 # Rainbow 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxRainbowAgentNew.optimizer = 'adam' 23 | JaxRainbowAgentNew.noisy = True 24 | JaxRainbowAgentNew.dueling = True 25 | JaxRainbowAgentNew.initzer = @variance_scaling() 26 | variance_scaling.scale=1 27 | variance_scaling.mode='fan_avg' 28 | variance_scaling.distribution='uniform' 29 | 30 | JaxRainbowAgentNew.double_dqn = True 31 | JaxRainbowAgentNew.net_conf = 'minatar' 32 | JaxRainbowAgentNew.env = None 33 | JaxRainbowAgentNew.normalize_obs = False 34 | JaxRainbowAgentNew.hidden_layer = 0 35 | JaxRainbowAgentNew.neurons = None 36 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 37 | JaxRainbowAgentNew.vmax = 100. 38 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 39 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 40 | JaxRainbowAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 41 | 42 | create_optimizer.learning_rate = 0.0001 43 | create_optimizer.eps = 0.0003125 44 | 45 | create_minatar_env.game_name = 'asterix' 46 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 47 | 48 | Runner.num_iterations = 10 49 | Runner.training_steps = 1000000 50 | Runner.max_steps_per_episode = 100000000 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_breakout.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.BREAKOUT_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 3 # Rainbow 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxRainbowAgentNew.optimizer = 'adam' 23 | JaxRainbowAgentNew.noisy = True 24 | JaxRainbowAgentNew.dueling = True 25 | JaxRainbowAgentNew.initzer = @variance_scaling() 26 | variance_scaling.scale=1 27 | variance_scaling.mode='fan_avg' 28 | 
variance_scaling.distribution='uniform' 29 | 30 | JaxRainbowAgentNew.double_dqn = True 31 | JaxRainbowAgentNew.net_conf = 'minatar' 32 | JaxRainbowAgentNew.env = None 33 | JaxRainbowAgentNew.normalize_obs = False 34 | JaxRainbowAgentNew.hidden_layer = 0 35 | JaxRainbowAgentNew.neurons = None 36 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 37 | JaxRainbowAgentNew.vmax = 100. 38 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 39 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 40 | JaxRainbowAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 41 | 42 | create_optimizer.learning_rate = 0.0001 43 | create_optimizer.eps = 0.0003125 44 | 45 | create_minatar_env.game_name = 'breakout' 46 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 47 | 48 | Runner.num_iterations = 10 49 | Runner.training_steps = 1000000 50 | Runner.max_steps_per_episode = 100000000 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_cartpole.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.CARTPOLE_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.CARTPOLE_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 3 # Rainbow 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 2 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxRainbowAgentNew.optimizer = 'adam' 22 | JaxRainbowAgentNew.noisy = True 23 | JaxRainbowAgentNew.dueling = True 24 | JaxRainbowAgentNew.initzer = @variance_scaling() 25 | variance_scaling.scale=1 26 | variance_scaling.mode='fan_avg' 27 | variance_scaling.distribution='uniform' 28 | 29 | JaxRainbowAgentNew.double_dqn = True 30 | JaxRainbowAgentNew.net_conf = 'classic' 31 | JaxRainbowAgentNew.env = 'CartPole' 32 | JaxRainbowAgentNew.normalize_obs = True 33 | JaxRainbowAgentNew.hidden_layer = 2 34 | JaxRainbowAgentNew.neurons = 512 35 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 36 | JaxRainbowAgentNew.vmax = 200. 37 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 38 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 39 | JaxRainbowAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | create_gym_environment.environment_name = 'CartPole' 44 | create_gym_environment.version = 'v0' 45 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 46 | 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 1000 49 | Runner.max_steps_per_episode = 200 # Default max episode length. 
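Note: num_atoms = 51 together with vmax (100/200/300 depending on the environment) fixes the support of the categorical return distribution. Assuming JaxRainbowAgentNew keeps Dopamine's convention of a symmetric support on [-vmax, vmax], the CartPole binding above amounts to:

import jax.numpy as jnp

num_atoms, vmax = 51, 200.0
support = jnp.linspace(-vmax, vmax, num_atoms)  # 51 evenly spaced return values
delta_z = support[1] - support[0]               # bin width used by the C51 projection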
50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_freeway.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.FREEWAY_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 3 # Rainbow 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxRainbowAgentNew.optimizer = 'adam' 23 | JaxRainbowAgentNew.noisy = True 24 | JaxRainbowAgentNew.dueling = True 25 | JaxRainbowAgentNew.initzer = @variance_scaling() 26 | variance_scaling.scale=1 27 | variance_scaling.mode='fan_avg' 28 | variance_scaling.distribution='uniform' 29 | 30 | JaxRainbowAgentNew.double_dqn = True 31 | JaxRainbowAgentNew.net_conf = 'minatar' 32 | JaxRainbowAgentNew.env = None 33 | JaxRainbowAgentNew.normalize_obs = False 34 | JaxRainbowAgentNew.hidden_layer = 0 35 | JaxRainbowAgentNew.neurons = None 36 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 37 | JaxRainbowAgentNew.vmax = 100. 38 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 39 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 40 | JaxRainbowAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 41 | 42 | create_optimizer.learning_rate = 0.0001 43 | create_optimizer.eps = 0.0003125 44 | 45 | create_minatar_env.game_name = 'freeway' 46 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 47 | 48 | Runner.num_iterations = 10 49 | Runner.training_steps = 1000000 50 | Runner.max_steps_per_episode = 100000000 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_invaders.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SPACE_INVADERS_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 3 # Rainbow 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxRainbowAgentNew.optimizer = 'adam' 23 | JaxRainbowAgentNew.noisy = True 24 | JaxRainbowAgentNew.dueling = True 25 | JaxRainbowAgentNew.initzer = @variance_scaling() 26 | variance_scaling.scale=1 27 | variance_scaling.mode='fan_avg' 28 | 
variance_scaling.distribution='uniform' 29 | 30 | JaxRainbowAgentNew.double_dqn = True 31 | JaxRainbowAgentNew.net_conf = 'minatar' 32 | JaxRainbowAgentNew.env = None 33 | JaxRainbowAgentNew.normalize_obs = False 34 | JaxRainbowAgentNew.hidden_layer = 0 35 | JaxRainbowAgentNew.neurons = None 36 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 37 | JaxRainbowAgentNew.vmax = 100. 38 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 39 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 40 | JaxRainbowAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 41 | 42 | create_optimizer.learning_rate = 0.0001 43 | create_optimizer.eps = 0.0003125 44 | 45 | create_minatar_env.game_name = 'space_invaders' 46 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 47 | 48 | Runner.num_iterations = 10 49 | Runner.training_steps = 1000000 50 | Runner.max_steps_per_episode = 100000000 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_lunarlander.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.LUNAR_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.LUNAR_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.LUNAR_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 3 # Rainbow 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 4 19 | JaxDQNAgent.target_update_period = 300 20 | 21 | JaxRainbowAgentNew.optimizer = 'adam' 22 | JaxRainbowAgentNew.noisy = True 23 | JaxRainbowAgentNew.dueling = True 24 | JaxRainbowAgentNew.initzer = @variance_scaling() 25 | variance_scaling.scale=1 26 | variance_scaling.mode='fan_avg' 27 | variance_scaling.distribution='uniform' 28 | 29 | JaxRainbowAgentNew.double_dqn = True 30 | JaxRainbowAgentNew.net_conf = 'classic' 31 | JaxRainbowAgentNew.env = 'LunarLander' 32 | JaxRainbowAgentNew.normalize_obs = False 33 | JaxRainbowAgentNew.hidden_layer = 2 34 | JaxRainbowAgentNew.neurons = 512 35 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 36 | JaxRainbowAgentNew.vmax = 300#200. 37 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 38 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 39 | JaxRainbowAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 40 | 41 | create_optimizer.learning_rate = 1e-3 42 | create_optimizer.eps = 3.125e-4 43 | create_gym_environment.environment_name = 'LunarLander' 44 | create_gym_environment.version = 'v2' 45 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 46 | 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 4000 49 | Runner.max_steps_per_episode = 1000 # Default max episode length. 
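# Annotation (not part of the original config): 30 iterations x 4,000 training
# steps gives a 120,000-step run. LunarLander-v2 is conventionally considered
# solved at an average return of 200, so the distributional support of
# [-vmax, vmax] = [-300, 300] covers the achievable return range with headroom.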
50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_mountaincar.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import external_configurations 10 | 11 | JaxDQNAgent.observation_shape = %gym_lib.MOUNTAINCAR_OBSERVATION_SHAPE 12 | JaxDQNAgent.observation_dtype = %jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE 13 | JaxDQNAgent.stack_size = %gym_lib.MOUNTAINCAR_STACK_SIZE 14 | 15 | JaxDQNAgent.gamma = 0.99 16 | JaxDQNAgent.update_horizon = 3 # Rainbow 17 | JaxDQNAgent.min_replay_history = 500 18 | JaxDQNAgent.update_period = 2 19 | JaxDQNAgent.target_update_period = 100 20 | 21 | JaxRainbowAgentNew.optimizer = 'adam' 22 | JaxRainbowAgentNew.noisy = True 23 | JaxRainbowAgentNew.dueling = True 24 | JaxRainbowAgentNew.initzer = @variance_scaling() 25 | variance_scaling.scale=1 26 | variance_scaling.mode='fan_avg' 27 | variance_scaling.distribution='uniform' 28 | 29 | JaxRainbowAgentNew.double_dqn = True 30 | JaxRainbowAgentNew.net_conf = 'classic' 31 | JaxRainbowAgentNew.env = 'MountainCar' 32 | JaxRainbowAgentNew.normalize_obs = True 33 | JaxRainbowAgentNew.hidden_layer = 2 34 | JaxRainbowAgentNew.neurons = 512 35 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 36 | JaxRainbowAgentNew.vmax = 100. 37 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 38 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 39 | JaxRainbowAgentNew.epsilon_fn = @dqn_agent.identity_epsilon 40 | 41 | create_optimizer.learning_rate = 0.0001 42 | create_optimizer.eps = 0.0003125 43 | create_gym_environment.environment_name = 'MountainCar' 44 | create_gym_environment.version = 'v0' 45 | TrainRunner.create_environment_fn = @gym_lib.create_gym_environment 46 | 47 | Runner.num_iterations = 30 48 | Runner.training_steps = 1000 49 | Runner.max_steps_per_episode = 600 # Default max episode length. 
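# Annotation (not part of the original config): MountainCar-v0 yields -1 per
# step until the goal is reached, and Dopamine's gym_lib.create_gym_environment
# unwraps Gym's built-in 200-step TimeLimit, so the 600-step cap above is what
# actually ends unsuccessful episodes here.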
50 | 51 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 50000 52 | OutOfGraphPrioritizedReplayBuffer.batch_size = 128 53 | -------------------------------------------------------------------------------- /revisiting_rainbow/Configs/rainbow_seaquest.gin: -------------------------------------------------------------------------------- 1 | import dopamine.jax.agents.dqn.dqn_agent 2 | import dopamine.jax.networks 3 | import dopamine.discrete_domains.gym_lib 4 | import dopamine.discrete_domains.run_experiment 5 | import dopamine.replay_memory.prioritized_replay_buffer 6 | 7 | import networks_new 8 | import rainbow_agent_new 9 | import minatar_env 10 | import external_configurations 11 | 12 | JaxDQNAgent.observation_shape = %minatar_env.SEAQUEST_SHAPE 13 | JaxDQNAgent.observation_dtype = %minatar_env.DTYPE 14 | JaxDQNAgent.stack_size = 1 15 | 16 | JaxDQNAgent.gamma = 0.99 17 | JaxDQNAgent.update_horizon = 3 # Rainbow 18 | JaxDQNAgent.min_replay_history = 1000 19 | JaxDQNAgent.update_period = 4 20 | JaxDQNAgent.target_update_period = 1000 21 | 22 | JaxRainbowAgentNew.optimizer = 'adam' 23 | JaxRainbowAgentNew.noisy = True 24 | JaxRainbowAgentNew.dueling = True 25 | JaxRainbowAgentNew.initzer = @variance_scaling() 26 | variance_scaling.scale=1 27 | variance_scaling.mode='fan_avg' 28 | variance_scaling.distribution='uniform' 29 | 30 | JaxRainbowAgentNew.double_dqn = True 31 | JaxRainbowAgentNew.net_conf = 'minatar' 32 | JaxRainbowAgentNew.env = None 33 | JaxRainbowAgentNew.normalize_obs = False 34 | JaxRainbowAgentNew.hidden_layer = 0 35 | JaxRainbowAgentNew.neurons = None 36 | JaxRainbowAgentNew.num_atoms = 51 # Original 51 37 | JaxRainbowAgentNew.vmax = 100. 38 | JaxRainbowAgentNew.replay_scheme = 'prioritized' 39 | JaxRainbowAgentNew.network = @networks_new.RainbowDQN 40 | JaxRainbowAgentNew.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon 41 | 42 | create_optimizer.learning_rate = 0.0001 43 | create_optimizer.eps = 0.0003125 44 | 45 | create_minatar_env.game_name = 'seaquest' 46 | TrainRunner.create_environment_fn = @minatar_env.create_minatar_env 47 | 48 | Runner.num_iterations = 10 49 | Runner.training_steps = 1000000 50 | Runner.max_steps_per_episode = 100000000 51 | 52 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 100000 53 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 54 | -------------------------------------------------------------------------------- /revisiting_rainbow/external_configurations.py: -------------------------------------------------------------------------------- 1 | """External configuration .gin""" 2 | 3 | from gin import config 4 | from flax import linen as nn 5 | 6 | config.external_configurable(nn.initializers.zeros, 'nn.initializers.zeros') 7 | config.external_configurable(nn.initializers.ones, 'nn.initializers.ones') 8 | config.external_configurable(nn.initializers.variance_scaling, 'nn.initializers.variance_scaling') -------------------------------------------------------------------------------- /revisiting_rainbow/minatar_env.py: -------------------------------------------------------------------------------- 1 | """MinAtar environment made compatible for Dopamine.""" 2 | 3 | from dopamine.discrete_domains import atari_lib 4 | from flax import nn 5 | import gin 6 | import jax 7 | import jax.numpy as jnp 8 | import minatar 9 | 10 | 11 | gin.constant('minatar_env.ASTERIX_SHAPE', (10, 10, 4)) 12 | gin.constant('minatar_env.BREAKOUT_SHAPE', (10, 10, 4)) 13 | gin.constant('minatar_env.FREEWAY_SHAPE', (10, 10, 7)) 14 | 
gin.constant('minatar_env.SEAQUEST_SHAPE', (10, 10, 10)) 15 | gin.constant('minatar_env.SPACE_INVADERS_SHAPE', (10, 10, 6)) 16 | gin.constant('minatar_env.DTYPE', jnp.float64) 17 | 18 | 19 | class MinAtarEnv(object): 20 | def __init__(self, game_name): 21 | self.env = minatar.Environment(env_name=game_name) 22 | self.env.n = self.env.num_actions() 23 | self.game_over = False 24 | 25 | @property 26 | def observation_space(self): 27 | return self.env.state_shape() 28 | 29 | @property 30 | def action_space(self): 31 | return self.env # Only used for the `n` parameter. 32 | 33 | @property 34 | def reward_range(self): 35 | pass # Unused 36 | 37 | @property 38 | def metadata(self): 39 | pass # Unused 40 | 41 | def reset(self): 42 | self.game_over = False 43 | self.env.reset() 44 | return self.env.state() 45 | 46 | def step(self, action): 47 | r, terminal = self.env.act(action) 48 | self.game_over = terminal 49 | return self.env.state(), r, terminal, None 50 | 51 | 52 | @gin.configurable 53 | def create_minatar_env(game_name): 54 | return MinAtarEnv(game_name) 55 | -------------------------------------------------------------------------------- /revisiting_rainbow/networks_new.py: -------------------------------------------------------------------------------- 1 | """Various networks for Jax Dopamine agents.""" 2 | 3 | from dopamine.discrete_domains import atari_lib 4 | from dopamine.discrete_domains import gym_lib 5 | from flax import linen as nn 6 | import gin 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as onp 10 | from jax import random 11 | import math 12 | 13 | from jax.tree_util import tree_flatten, tree_map 14 | 15 | #--------------------------------------------------------------------------------------------------------------------- 16 | 17 | env_inf = {"CartPole":{"MIN_VALS": jnp.array([-2.4, -5., -math.pi/12., -math.pi*2.]),"MAX_VALS": jnp.array([2.4, 5., math.pi/12., math.pi*2.])}, 18 | "Acrobot":{"MIN_VALS": jnp.array([-1., -1., -1., -1., -5., -5.]),"MAX_VALS": jnp.array([1., 1., 1., 1., 5., 5.])}, 19 | "MountainCar":{"MIN_VALS":jnp.array([-1.2, -0.07]),"MAX_VALS": jnp.array([0.6, 0.07])} 20 | } 21 | 22 | prn_inf = {"count":0, "rng2_":None, "rng3_":None} 23 | 24 | #--------------------------------------------------------------------------------------------------------------------- 25 | class NoisyNetwork(nn.Module): 26 | features: int 27 | rng: int 28 | bias_in: bool 29 | 30 | @nn.compact 31 | def __call__(self, x): 32 | 33 | def sample_noise(rng_input, shape): 34 | noise = jax.random.normal(rng_input,shape) 35 | return noise 36 | 37 | def f(x): 38 | return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5)) 39 | # Initializer of \mu and \sigma 40 | 41 | def mu_init(key, shape, rng): 42 | low = -1*1/jnp.power(x.shape[-1], 0.5) 43 | high = 1*1/jnp.power(x.shape[-1], 0.5) 44 | return random.uniform(rng, shape=shape, dtype=jnp.float32, minval=low, maxval=high) 45 | 46 | def sigma_init(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)*(0.1 / jnp.sqrt(x.shape[-1])) 47 | 48 | rng, rng2, rng3, rng4, rng5 = jax.random.split(self.rng, 5) 49 | 50 | if prn_inf["count"] == 0: 51 | prn_inf["rng2_"] = rng2 52 | prn_inf["rng3_"] = rng3 53 | prn_inf["count"] = prn_inf["count"]+1 54 | 55 | # Sample noise from gaussian 56 | p = sample_noise(prn_inf["rng2_"], [x.shape[-1], 1]) 57 | q = sample_noise(prn_inf["rng3_"], [1, self.features]) 58 | f_p = f(p); f_q = f(q) 59 | 60 | w_epsilon = f_p*f_q; b_epsilon = jnp.squeeze(f_q) 61 | w_mu = self.param('kernel', mu_init, 
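                      # Annotation: NoisyNet-style factorized Gaussian noise
                      # (Fortunato et al.): the effective weight used below is
                      # w = w_mu + w_sigma * w_epsilon, where
                      # w_epsilon = f(p) f(q) and f(x) = sign(x) * sqrt(|x|),
                      # built from the noise samples p and q drawn above.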
(x.shape[-1], self.features), rng4) 62 | w_sigma = self.param('kernell', sigma_init, (x.shape[-1], self.features)) 63 | w = w_mu + jnp.multiply(w_sigma, w_epsilon) 64 | ret = jnp.matmul(x, w) 65 | 66 | b_mu = self.param('bias', mu_init, (self.features,), rng5) 67 | b_sigma = self.param('biass',sigma_init, (self.features,)) 68 | b = b_mu + jnp.multiply(b_sigma, b_epsilon) 69 | 70 | return jnp.where(self.bias_in, ret + b, ret) 71 | 72 | #---------------------------------------------< DQNNetwork >---------------------------------------------------------- 73 | 74 | @gin.configurable 75 | class DQNNetwork(nn.Module): 76 | 77 | num_actions:int 78 | net_conf: str 79 | env: str 80 | normalize_obs:bool 81 | noisy: bool 82 | dueling: bool 83 | initzer:str 84 | hidden_layer: int 85 | neurons: int 86 | 87 | @nn.compact 88 | def __call__(self, x , rng): 89 | 90 | if self.net_conf == 'minatar': 91 | x = x.squeeze(3) 92 | x = x.astype(jnp.float32) 93 | x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.initzer)(x) 94 | x = jax.nn.relu(x) 95 | x = x.reshape((-1)) 96 | 97 | elif self.net_conf == 'atari': 98 | # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will 99 | # have removed the true batch dimension. 100 | x = x.astype(jnp.float32) / 255. 101 | x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), 102 | kernel_init=self.initzer)(x) 103 | x = jax.nn.relu(x) 104 | x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), 105 | kernel_init=self.initzer)(x) 106 | x = jax.nn.relu(x) 107 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), 108 | kernel_init=self.initzer)(x) 109 | x = jax.nn.relu(x) 110 | x = x.reshape((-1)) # flatten 111 | 112 | elif self.net_conf == 'classic': 113 | #classic environments 114 | x = x.astype(jnp.float32) 115 | x = x.reshape((-1)) 116 | 117 | if self.env is not None and self.env in env_inf: 118 | x = x - env_inf[self.env]['MIN_VALS'] 119 | x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS'] 120 | x = 2.0 * x - 1.0 121 | 122 | if self.noisy: 123 | def net(x, features, rng): 124 | return NoisyNetwork(features, rng=rng, bias_in=True)(x) 125 | else: 126 | def net(x, features, rng): 127 | return nn.Dense(features, kernel_init=self.initzer)(x) 128 | 129 | for _ in range(self.hidden_layer): 130 | x = net(x, features=self.neurons, rng=rng) 131 | x = jax.nn.relu(x) 132 | 133 | adv = net(x, features=self.num_actions, rng=rng) 134 | val = net(x, features=1, rng=rng) 135 | 136 | dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True))) 137 | non_dueling_q = net(x, features=self.num_actions, rng=rng) 138 | 139 | q_values = jnp.where(self.dueling, dueling_q, non_dueling_q) 140 | return atari_lib.DQNNetworkType(q_values) 141 | #---------------------------------------------< RainbowDQN >---------------------------------------------------------- 142 | 143 | @gin.configurable 144 | class RainbowDQN(nn.Module): 145 | 146 | num_actions:int 147 | net_conf:str 148 | env:str 149 | normalize_obs:bool 150 | noisy:bool 151 | dueling:bool 152 | initzer:str 153 | num_atoms:int 154 | hidden_layer:int 155 | neurons:int 156 | 157 | @nn.compact 158 | def __call__(self, x, support, rng): 159 | 160 | if self.net_conf == 'minatar': 161 | x = x.squeeze(3) 162 | x = x.astype(jnp.float32) 163 | x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.initzer)(x) 164 | x = jax.nn.relu(x) 165 | x = x.reshape((-1)) 166 | 167 | elif self.net_conf == 'atari': 168 | # We need to add a "batch dimension" as nn.Conv 
expects it, yet vmap will 169 | # have removed the true batch dimension. 170 | x = x.astype(jnp.float32) / 255. 171 | x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), 172 | kernel_init=self.initzer)(x) 173 | x = jax.nn.relu(x) 174 | x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), 175 | kernel_init=self.initzer)(x) 176 | x = jax.nn.relu(x) 177 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), 178 | kernel_init=self.initzer)(x) 179 | x = jax.nn.relu(x) 180 | x = x.reshape((-1)) # flatten 181 | 182 | elif self.net_conf == 'classic': 183 | x = x.astype(jnp.float32) 184 | x = x.reshape((-1)) 185 | 186 | if self.env is not None and self.env in env_inf: 187 | x = x - env_inf[self.env]['MIN_VALS'] 188 | x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS'] 189 | x = 2.0 * x - 1.0 190 | 191 | if self.noisy: 192 | def net(x, features, rng): 193 | return NoisyNetwork(features, rng=rng, bias_in=True)(x) 194 | else: 195 | def net(x, features, rng): 196 | return nn.Dense(features, kernel_init=self.initzer)(x) 197 | 198 | for _ in range(self.hidden_layer): 199 | x = net(x, features=self.neurons, rng=rng) 200 | x = jax.nn.relu(x) 201 | 202 | if self.dueling: 203 | adv = net(x,features=self.num_actions * self.num_atoms, rng=rng) 204 | value = net(x, features=self.num_atoms, rng=rng) 205 | 206 | adv = adv.reshape((self.num_actions, self.num_atoms)) 207 | value = value.reshape((1, self.num_atoms)) 208 | 209 | logits = value + (adv - (jnp.mean(adv, -2, keepdims=True))) 210 | probabilities = nn.softmax(logits) 211 | q_values = jnp.sum(support * probabilities, axis=1) 212 | 213 | else: 214 | x = net(x, features=self.num_actions * self.num_atoms, rng=rng) 215 | logits = x.reshape((self.num_actions, self.num_atoms)) 216 | probabilities = nn.softmax(logits) 217 | q_values = jnp.sum(support * probabilities, axis=1) 218 | 219 | return atari_lib.RainbowNetworkType(q_values, logits, probabilities) 220 | 221 | #---------------------------------------------< QuantileNetwork >---------------------------------------------------------- 222 | 223 | @gin.configurable 224 | class QuantileNetwork(nn.Module): 225 | 226 | num_actions:int 227 | net_conf:str 228 | env:str 229 | normalize_obs:bool 230 | noisy:bool 231 | dueling:bool 232 | initzer:str 233 | num_atoms:int 234 | hidden_layer:int 235 | neurons:int 236 | 237 | @nn.compact 238 | def __call__(self, x, rng): 239 | 240 | if self.net_conf == 'minatar': 241 | x = x.squeeze(3) 242 | x = x.astype(jnp.float32) 243 | x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.initzer)(x) 244 | x = jax.nn.relu(x) 245 | x = x.reshape((-1)) 246 | 247 | elif self.net_conf == 'atari': 248 | # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will 249 | # have removed the true batch dimension. 250 | x = x.astype(jnp.float32) / 255. 
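      # Annotation: pixels are scaled to [0, 1]; the three nn.Conv layers that
      # follow (32, 64, 64 filters with 8x8, 4x4, 3x3 kernels) form the
      # standard Nature-DQN convolutional torso.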
251 | x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), 252 | kernel_init=self.initzer)(x) 253 | x = jax.nn.relu(x) 254 | x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), 255 | kernel_init=self.initzer)(x) 256 | x = jax.nn.relu(x) 257 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), 258 | kernel_init=self.initzer)(x) 259 | x = jax.nn.relu(x) 260 | x = x.reshape((-1)) # flatten 261 | 262 | elif self.net_conf == 'classic': 263 | #classic environments 264 | x = x.astype(jnp.float32) 265 | x = x.reshape((-1)) 266 | 267 | if self.env is not None and self.env in env_inf: 268 | x = x - env_inf[self.env]['MIN_VALS'] 269 | x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS'] 270 | x = 2.0 * x - 1.0 271 | 272 | if self.noisy: 273 | def net(x, features, rng): 274 | return NoisyNetwork(features, rng=rng, bias_in=True)(x) 275 | else: 276 | def net(x, features, rng): 277 | return nn.Dense(features, kernel_init=self.initzer)(x) 278 | 279 | for _ in range(self.hidden_layer): 280 | x = net(x, features=self.neurons, rng=rng) 281 | x = jax.nn.relu(x) 282 | 283 | if self.dueling: 284 | adv = net(x,features=self.num_actions * self.num_atoms, rng=rng) 285 | value = net(x, features=self.num_atoms, rng=rng) 286 | adv = adv.reshape((self.num_actions, self.num_atoms)) 287 | value = value.reshape((1, self.num_atoms)) 288 | 289 | logits = value + (adv - (jnp.mean(adv, -2, keepdims=True))) 290 | probabilities = nn.softmax(logits) 291 | q_values = jnp.mean(logits, axis=1) 292 | 293 | else: 294 | x = net(x, features=self.num_actions * self.num_atoms, rng=rng) 295 | logits = x.reshape((self.num_actions, self.num_atoms)) 296 | probabilities = nn.softmax(logits) 297 | q_values = jnp.mean(logits, axis=1) 298 | 299 | return atari_lib.RainbowNetworkType(q_values, logits, probabilities) 300 | 301 | #---------------------------------------------< IQ-Network >---------------------------------------------------------- 302 | @gin.configurable 303 | class ImplicitQuantileNetwork(nn.Module): 304 | 305 | num_actions:int 306 | net_conf:str 307 | env:str 308 | noisy:bool 309 | dueling:bool 310 | initzer:str 311 | quantile_embedding_dim:int 312 | hidden_layer:int 313 | neurons:int 314 | 315 | @nn.compact 316 | def __call__(self, x, num_quantiles, rng): 317 | 318 | if self.net_conf == 'minatar': 319 | x = x.squeeze(3) 320 | x = x.astype(jnp.float32) 321 | x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), kernel_init=self.initzer)(x) 322 | x = jax.nn.relu(x) 323 | x = x.reshape((-1)) 324 | 325 | elif self.net_conf == 'atari': 326 | # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will 327 | # have removed the true batch dimension. 328 | x = x.astype(jnp.float32) / 255. 
329 | x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), 330 | kernel_init=self.initzer)(x) 331 | x = jax.nn.relu(x) 332 | x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), 333 | kernel_init=self.initzer)(x) 334 | x = jax.nn.relu(x) 335 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), 336 | kernel_init=self.initzer)(x) 337 | x = jax.nn.relu(x) 338 | x = x.reshape((-1)) # flatten 339 | 340 | elif self.net_conf == 'classic': 341 | #classic environments 342 | x = x.astype(jnp.float32) 343 | x = x.reshape((-1)) 344 | 345 | if self.env is not None and self.env in env_inf: 346 | x = x - env_inf[self.env]['MIN_VALS'] 347 | x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS'] 348 | x = 2.0 * x - 1.0 349 | 350 | if self.noisy: 351 | def net(x, features, rng): 352 | return NoisyNetwork(features, rng=rng, bias_in=True)(x) 353 | else: 354 | def net(x, features, rng): 355 | return nn.Dense(features, kernel_init=self.initzer)(x) 356 | 357 | for _ in range(self.hidden_layer): 358 | x = net(x, features=self.neurons, rng=rng) 359 | x = jax.nn.relu(x) 360 | 361 | state_vector_length = x.shape[-1] 362 | state_net_tiled = jnp.tile(x, [num_quantiles, 1]) 363 | quantiles_shape = [num_quantiles, 1] 364 | quantiles = jax.random.uniform(rng, shape=quantiles_shape) 365 | quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim]) 366 | quantile_net = ( 367 | jnp.arange(1, self.quantile_embedding_dim + 1, 1).astype(jnp.float32) 368 | * onp.pi 369 | * quantile_net) 370 | quantile_net = jnp.cos(quantile_net) 371 | quantile_net = nn.Dense(features=state_vector_length, 372 | kernel_init=self.initzer)(quantile_net) 373 | quantile_net = jax.nn.relu(quantile_net) 374 | x = state_net_tiled * quantile_net 375 | 376 | adv = net(x,features=self.num_actions, rng=rng) 377 | val = net(x, features=1, rng=rng) 378 | dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True))) 379 | non_dueling_q = net(x, features=self.num_actions, rng=rng) 380 | quantile_values = jnp.where(self.dueling, dueling_q, non_dueling_q) 381 | 382 | return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for revisiting_rainbow. 2 | 3 | This script will install the algorithms presented in revisiting_rainbow paper as a Python module. 
4 | 5 | See: https://github.com/JohanSamir/revisiting_rainbow 6 | """ 7 | 8 | import pathlib 9 | from setuptools import find_packages 10 | from setuptools import setup 11 | 12 | here = pathlib.Path(__file__).parent.resolve() 13 | 14 | long_description = (here / 'README.md').read_text(encoding='utf-8') 15 | 16 | install_requires = [ 17 | 'dopamine-rl>= 3.1.10 ', 18 | ] 19 | 20 | rev_rainbow_description = ( 21 | 'Revisiting Rainbow: Promoting more insightful and inclusive deep reinforcement learning research') 22 | 23 | setup( 24 | name='revisiting_rainbow', 25 | version='1.0.0', 26 | description=rev_rainbow_description, 27 | long_description=long_description, 28 | long_description_content_type='text/markdown', 29 | url='https://github.com/JohanSamir/revisiting_rainbow', 30 | author='Johan S Obando-Ceron and Pablo Samuel Castro', 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 34 | 'Intended Audience :: Developers', 35 | 'Intended Audience :: Education', 36 | 'Intended Audience :: Science/Research', 37 | 38 | 'License :: OSI Approved :: Apache Software License', 39 | 40 | 'Programming Language :: Python :: 3', 41 | 'Programming Language :: Python :: 3.5', 42 | 'Programming Language :: Python :: 3.6', 43 | 'Programming Language :: Python :: 3.7', 44 | 'Programming Language :: Python :: 3.8', 45 | 'Programming Language :: Python :: 3 :: Only', 46 | 47 | 'Topic :: Scientific/Engineering', 48 | 'Topic :: Scientific/Engineering :: Mathematics', 49 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 50 | 'Topic :: Software Development', 51 | 'Topic :: Software Development :: Libraries', 52 | 'Topic :: Software Development :: Libraries :: Python Modules', 53 | 54 | ], 55 | keywords='dopamine, reinforcement, machine, learning, research', 56 | install_requires=install_requires, 57 | project_urls={ # Optional 58 | 'Bug Reports': 'https://github.com/JohanSamir/revisiting_rainbow/issues', 59 | 'Source': 'https://github.com/JohanSamir/revisiting_rainbow', 60 | }, 61 | license='Apache 2.0', 62 | ) 63 | 64 | --------------------------------------------------------------------------------
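Usage sketch (an annotation, not a file from this repository): the snippet below shows one way the pieces above are typically wired together — a gin config from Configs/ is parsed, and Dopamine's TrainRunner then builds the environment and agent from those bindings. The log directory, the module import paths, and the create_agent signature are assumptions; the authors' own entry point is test_main.ipynb, and different Dopamine releases pass slightly different arguments to create_agent_fn.

# Minimal, hedged training sketch. ASSUMPTIONS: dopamine-rl is installed, the
# revisiting_rainbow/ and Agents/ directories are on PYTHONPATH, and
# '/tmp/rainbow_cartpole' is just a placeholder log directory.
import gin
from dopamine.discrete_domains import run_experiment

import external_configurations  # registers nn.initializers.* with gin
import networks_new             # registers RainbowDQN etc. with gin
import rainbow_agent_new        # JaxRainbowAgentNew (lives in Agents/)

gin.parse_config_file('Configs/rainbow_cartpole.gin')

def create_agent(sess, environment, summary_writer=None):
  # The JAX agents ignore the TF session; it is only part of the signature
  # older Dopamine Runners use when constructing the agent.
  del sess, summary_writer
  return rainbow_agent_new.JaxRainbowAgentNew(
      num_actions=environment.action_space.n)

runner = run_experiment.TrainRunner('/tmp/rainbow_cartpole', create_agent)
runner.run_experiment()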